/* compare.c */
  1. #include <stdio.h>
  2. #include "network.h"
  3. #include "detection_layer.h"
  4. #include "cost_layer.h"
  5. #include "utils.h"
  6. #include "parser.h"
  7. #include "box.h"
  8. void train_compare(char *cfgfile, char *weightfile)
  9. {
  10. srand(time(0));
  11. float avg_loss = -1;
  12. char *base = basecfg(cfgfile);
  13. char* backup_directory = "backup/";
  14. printf("%s\n", base);
  15. network net = parse_network_cfg(cfgfile);
  16. if(weightfile){
  17. load_weights(&net, weightfile);
  18. }
  19. printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
  20. int imgs = 1024;
  21. list *plist = get_paths("data/compare.train.list");
  22. char **paths = (char **)list_to_array(plist);
  23. int N = plist->size;
  24. printf("%d\n", N);
  25. clock_t time;
  26. pthread_t load_thread;
  27. data train;
  28. data buffer;
  29. load_args args = {0};
  30. args.w = net.w;
  31. args.h = net.h;
  32. args.paths = paths;
  33. args.classes = 20;
  34. args.n = imgs;
  35. args.m = N;
  36. args.d = &buffer;
  37. args.type = COMPARE_DATA;
  38. load_thread = load_data_in_thread(args);
  39. int epoch = *net.seen/N;
  40. int i = 0;
  41. while(1){
  42. ++i;
  43. time=clock();
  44. pthread_join(load_thread, 0);
  45. train = buffer;
  46. load_thread = load_data_in_thread(args);
  47. printf("Loaded: %lf seconds\n", sec(clock()-time));
  48. time=clock();
  49. float loss = train_network(net, train);
  50. if(avg_loss == -1) avg_loss = loss;
  51. avg_loss = avg_loss*.9 + loss*.1;
  52. printf("%.3f: %f, %f avg, %lf seconds, %ld images\n", (float)*net.seen/N, loss, avg_loss, sec(clock()-time), *net.seen);
  53. free_data(train);
  54. if(i%100 == 0){
  55. char buff[256];
  56. sprintf(buff, "%s/%s_%d_minor_%d.weights",backup_directory,base, epoch, i);
  57. save_weights(net, buff);
  58. }
  59. if(*net.seen/N > epoch){
  60. epoch = *net.seen/N;
  61. i = 0;
  62. char buff[256];
  63. sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch);
  64. save_weights(net, buff);
  65. if(epoch%22 == 0) net.learning_rate *= .1;
  66. }
  67. }
  68. pthread_join(load_thread, 0);
  69. free_data(buffer);
  70. free_network(net);
  71. free_ptrs((void**)paths, plist->size);
  72. free_list(plist);
  73. free(base);
  74. }
  75. void validate_compare(char *filename, char *weightfile)
  76. {
  77. int i = 0;
  78. network net = parse_network_cfg(filename);
  79. if(weightfile){
  80. load_weights(&net, weightfile);
  81. }
  82. srand(time(0));
  83. list *plist = get_paths("data/compare.val.list");
  84. //list *plist = get_paths("data/compare.val.old");
  85. char **paths = (char **)list_to_array(plist);
  86. int N = plist->size/2;
  87. free_list(plist);
  88. clock_t time;
  89. int correct = 0;
  90. int total = 0;
  91. int splits = 10;
  92. int num = (i+1)*N/splits - i*N/splits;
  93. data val, buffer;
  94. load_args args = {0};
  95. args.w = net.w;
  96. args.h = net.h;
  97. args.paths = paths;
  98. args.classes = 20;
  99. args.n = num;
  100. args.m = 0;
  101. args.d = &buffer;
  102. args.type = COMPARE_DATA;
  103. pthread_t load_thread = load_data_in_thread(args);
  104. for(i = 1; i <= splits; ++i){
  105. time=clock();
  106. pthread_join(load_thread, 0);
  107. val = buffer;
  108. num = (i+1)*N/splits - i*N/splits;
  109. char **part = paths+(i*N/splits);
  110. if(i != splits){
  111. args.paths = part;
  112. load_thread = load_data_in_thread(args);
  113. }
  114. printf("Loaded: %d images in %lf seconds\n", val.X.rows, sec(clock()-time));
  115. time=clock();
  116. matrix pred = network_predict_data(net, val);
  117. int j,k;
  118. for(j = 0; j < val.y.rows; ++j){
  119. for(k = 0; k < 20; ++k){
  120. if(val.y.vals[j][k*2] != val.y.vals[j][k*2+1]){
  121. ++total;
  122. if((val.y.vals[j][k*2] < val.y.vals[j][k*2+1]) == (pred.vals[j][k*2] < pred.vals[j][k*2+1])){
  123. ++correct;
  124. }
  125. }
  126. }
  127. }
  128. free_matrix(pred);
  129. printf("%d: Acc: %f, %lf seconds, %d images\n", i, (float)correct/total, sec(clock()-time), val.X.rows);
  130. free_data(val);
  131. }
  132. }
  133. typedef struct {
  134. network net;
  135. char *filename;
  136. int class_id;
  137. int classes;
  138. float elo;
  139. float *elos;
  140. } sortable_bbox;
  141. int total_compares = 0;
  142. int current_class_id = 0;
  143. int elo_comparator(const void*a, const void *b)
  144. {
  145. sortable_bbox box1 = *(sortable_bbox*)a;
  146. sortable_bbox box2 = *(sortable_bbox*)b;
  147. if(box1.elos[current_class_id] == box2.elos[current_class_id]) return 0;
  148. if(box1.elos[current_class_id] > box2.elos[current_class_id]) return -1;
  149. return 1;
  150. }
  151. int bbox_comparator(const void *a, const void *b)
  152. {
  153. ++total_compares;
  154. sortable_bbox box1 = *(sortable_bbox*)a;
  155. sortable_bbox box2 = *(sortable_bbox*)b;
  156. network net = box1.net;
  157. int class_id = box1.class_id;
  158. image im1 = load_image_color(box1.filename, net.w, net.h);
  159. image im2 = load_image_color(box2.filename, net.w, net.h);
  160. float* X = (float*)xcalloc(net.w * net.h * net.c, sizeof(float));
  161. memcpy(X, im1.data, im1.w*im1.h*im1.c*sizeof(float));
  162. memcpy(X+im1.w*im1.h*im1.c, im2.data, im2.w*im2.h*im2.c*sizeof(float));
  163. float *predictions = network_predict(net, X);
  164. free_image(im1);
  165. free_image(im2);
  166. free(X);
  167. if (predictions[class_id*2] > predictions[class_id*2+1]){
  168. return 1;
  169. }
  170. return -1;
  171. }
  172. void bbox_update(sortable_bbox *a, sortable_bbox *b, int class_id, int result)
  173. {
  174. int k = 32;
  175. float EA = 1./(1+pow(10, (b->elos[class_id] - a->elos[class_id])/400.));
  176. float EB = 1./(1+pow(10, (a->elos[class_id] - b->elos[class_id])/400.));
  177. float SA = result ? 1 : 0;
  178. float SB = result ? 0 : 1;
  179. a->elos[class_id] += k*(SA - EA);
  180. b->elos[class_id] += k*(SB - EB);
  181. }
  182. void bbox_fight(network net, sortable_bbox *a, sortable_bbox *b, int classes, int class_id)
  183. {
  184. image im1 = load_image_color(a->filename, net.w, net.h);
  185. image im2 = load_image_color(b->filename, net.w, net.h);
  186. float* X = (float*)xcalloc(net.w * net.h * net.c, sizeof(float));
  187. memcpy(X, im1.data, im1.w*im1.h*im1.c*sizeof(float));
  188. memcpy(X+im1.w*im1.h*im1.c, im2.data, im2.w*im2.h*im2.c*sizeof(float));
  189. float *predictions = network_predict(net, X);
  190. ++total_compares;
  191. int i;
  192. for(i = 0; i < classes; ++i){
  193. if(class_id < 0 || class_id == i){
  194. int result = predictions[i*2] > predictions[i*2+1];
  195. bbox_update(a, b, i, result);
  196. }
  197. }
  198. free_image(im1);
  199. free_image(im2);
  200. free(X);
  201. }
  202. void SortMaster3000(char *filename, char *weightfile)
  203. {
  204. int i = 0;
  205. network net = parse_network_cfg(filename);
  206. if(weightfile){
  207. load_weights(&net, weightfile);
  208. }
  209. srand(time(0));
  210. set_batch_network(&net, 1);
  211. list *plist = get_paths("data/compare.sort.list");
  212. //list *plist = get_paths("data/compare.val.old");
  213. char **paths = (char **)list_to_array(plist);
  214. int N = plist->size;
  215. free_list(plist);
  216. sortable_bbox* boxes = (sortable_bbox*)xcalloc(N, sizeof(sortable_bbox));
  217. printf("Sorting %d boxes...\n", N);
  218. for(i = 0; i < N; ++i){
  219. boxes[i].filename = paths[i];
  220. boxes[i].net = net;
  221. boxes[i].class_id = 7;
  222. boxes[i].elo = 1500;
  223. }
  224. clock_t time=clock();
  225. qsort(boxes, N, sizeof(sortable_bbox), bbox_comparator);
  226. for(i = 0; i < N; ++i){
  227. printf("%s\n", boxes[i].filename);
  228. }
  229. printf("Sorted in %d compares, %f secs\n", total_compares, sec(clock()-time));
  230. }
  231. void BattleRoyaleWithCheese(char *filename, char *weightfile)
  232. {
  233. int classes = 20;
  234. int i,j;
  235. network net = parse_network_cfg(filename);
  236. if(weightfile){
  237. load_weights(&net, weightfile);
  238. }
  239. srand(time(0));
  240. set_batch_network(&net, 1);
  241. list *plist = get_paths("data/compare.sort.list");
  242. //list *plist = get_paths("data/compare.small.list");
  243. //list *plist = get_paths("data/compare.cat.list");
  244. //list *plist = get_paths("data/compare.val.old");
  245. char **paths = (char **)list_to_array(plist);
  246. int N = plist->size;
  247. int total = N;
  248. free_list(plist);
  249. sortable_bbox* boxes = (sortable_bbox*)xcalloc(N, sizeof(sortable_bbox));
  250. printf("Battling %d boxes...\n", N);
  251. for(i = 0; i < N; ++i){
  252. boxes[i].filename = paths[i];
  253. boxes[i].net = net;
  254. boxes[i].classes = classes;
  255. boxes[i].elos = (float*)xcalloc(classes, sizeof(float));
  256. for(j = 0; j < classes; ++j){
  257. boxes[i].elos[j] = 1500;
  258. }
  259. }
  260. int round;
  261. clock_t time=clock();
  262. for(round = 1; round <= 4; ++round){
  263. clock_t round_time=clock();
  264. printf("Round: %d\n", round);
  265. shuffle(boxes, N, sizeof(sortable_bbox));
  266. for(i = 0; i < N/2; ++i){
  267. bbox_fight(net, boxes+i*2, boxes+i*2+1, classes, -1);
  268. }
  269. printf("Round: %f secs, %d remaining\n", sec(clock()-round_time), N);
  270. }
  271. int class_id;
  272. for (class_id = 0; class_id < classes; ++class_id){
  273. N = total;
  274. current_class_id = class_id;
  275. qsort(boxes, N, sizeof(sortable_bbox), elo_comparator);
  276. N /= 2;
  277. for(round = 1; round <= 100; ++round){
  278. clock_t round_time=clock();
  279. printf("Round: %d\n", round);
  280. sorta_shuffle(boxes, N, sizeof(sortable_bbox), 10);
  281. for(i = 0; i < N/2; ++i){
  282. bbox_fight(net, boxes+i*2, boxes+i*2+1, classes, class_id);
  283. }
  284. qsort(boxes, N, sizeof(sortable_bbox), elo_comparator);
  285. if(round <= 20) N = (N*9/10)/2*2;
  286. printf("Round: %f secs, %d remaining\n", sec(clock()-round_time), N);
  287. }
  288. char buff[256];
  289. sprintf(buff, "results/battle_%d.log", class_id);
  290. FILE *outfp = fopen(buff, "w");
  291. for(i = 0; i < N; ++i){
  292. fprintf(outfp, "%s %f\n", boxes[i].filename, boxes[i].elos[class_id]);
  293. }
  294. fclose(outfp);
  295. }
  296. printf("Tournament in %d compares, %f secs\n", total_compares, sec(clock()-time));
  297. }
  298. void run_compare(int argc, char **argv)
  299. {
  300. if(argc < 4){
  301. fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
  302. return;
  303. }
  304. char *cfg = argv[3];
  305. char *weights = (argc > 4) ? argv[4] : 0;
  306. //char *filename = (argc > 5) ? argv[5]: 0;
  307. if(0==strcmp(argv[2], "train")) train_compare(cfg, weights);
  308. else if(0==strcmp(argv[2], "valid")) validate_compare(cfg, weights);
  309. else if(0==strcmp(argv[2], "sort")) SortMaster3000(cfg, weights);
  310. else if(0==strcmp(argv[2], "battle")) BattleRoyaleWithCheese(cfg, weights);
  311. /*
  312. else if(0==strcmp(argv[2], "train")) train_coco(cfg, weights);
  313. else if(0==strcmp(argv[2], "extract")) extract_boxes(cfg, weights);
  314. else if(0==strcmp(argv[2], "valid")) validate_recall(cfg, weights);
  315. */
  316. }