yolo_layer.c 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779
  1. #include "yolo_layer.h"
  2. #include "activations.h"
  3. #include "blas.h"
  4. #include "box.h"
  5. #include "dark_cuda.h"
  6. #include "utils.h"
  7. #include <stdio.h>
  8. #include <assert.h>
  9. #include <string.h>
  10. #include <stdlib.h>
  11. layer make_yolo_layer(int batch, int w, int h, int n, int total, int *mask, int classes, int max_boxes)
  12. {
  13. int i;
  14. layer l = { (LAYER_TYPE)0 };
  15. l.type = YOLO;
  16. l.n = n;
  17. l.total = total;
  18. l.batch = batch;
  19. l.h = h;
  20. l.w = w;
  21. l.c = n*(classes + 4 + 1);
  22. l.out_w = l.w;
  23. l.out_h = l.h;
  24. l.out_c = l.c;
  25. l.classes = classes;
  26. l.cost = (float*)xcalloc(1, sizeof(float));
  27. l.biases = (float*)xcalloc(total * 2, sizeof(float));
  28. if(mask) l.mask = mask;
  29. else{
  30. l.mask = (int*)xcalloc(n, sizeof(int));
  31. for(i = 0; i < n; ++i){
  32. l.mask[i] = i;
  33. }
  34. }
  35. l.bias_updates = (float*)xcalloc(n * 2, sizeof(float));
  36. l.outputs = h*w*n*(classes + 4 + 1);
  37. l.inputs = l.outputs;
  38. l.max_boxes = max_boxes;
  39. l.truths = l.max_boxes*(4 + 1); // 90*(4 + 1);
  40. l.delta = (float*)xcalloc(batch * l.outputs, sizeof(float));
  41. l.output = (float*)xcalloc(batch * l.outputs, sizeof(float));
  42. for(i = 0; i < total*2; ++i){
  43. l.biases[i] = .5;
  44. }
  45. l.forward = forward_yolo_layer;
  46. l.backward = backward_yolo_layer;
  47. #ifdef GPU
  48. l.forward_gpu = forward_yolo_layer_gpu;
  49. l.backward_gpu = backward_yolo_layer_gpu;
  50. l.output_gpu = cuda_make_array(l.output, batch*l.outputs);
  51. l.delta_gpu = cuda_make_array(l.delta, batch*l.outputs);
  52. free(l.output);
  53. if (cudaSuccess == cudaHostAlloc(&l.output, batch*l.outputs*sizeof(float), cudaHostRegisterMapped)) l.output_pinned = 1;
  54. else {
  55. cudaGetLastError(); // reset CUDA-error
  56. l.output = (float*)xcalloc(batch * l.outputs, sizeof(float));
  57. }
  58. free(l.delta);
  59. if (cudaSuccess == cudaHostAlloc(&l.delta, batch*l.outputs*sizeof(float), cudaHostRegisterMapped)) l.delta_pinned = 1;
  60. else {
  61. cudaGetLastError(); // reset CUDA-error
  62. l.delta = (float*)xcalloc(batch * l.outputs, sizeof(float));
  63. }
  64. #endif
  65. fprintf(stderr, "yolo\n");
  66. srand(time(0));
  67. return l;
  68. }
  69. void resize_yolo_layer(layer *l, int w, int h)
  70. {
  71. l->w = w;
  72. l->h = h;
  73. l->outputs = h*w*l->n*(l->classes + 4 + 1);
  74. l->inputs = l->outputs;
  75. if (!l->output_pinned) l->output = (float*)xrealloc(l->output, l->batch*l->outputs * sizeof(float));
  76. if (!l->delta_pinned) l->delta = (float*)xrealloc(l->delta, l->batch*l->outputs*sizeof(float));
  77. #ifdef GPU
  78. if (l->output_pinned) {
  79. CHECK_CUDA(cudaFreeHost(l->output));
  80. if (cudaSuccess != cudaHostAlloc(&l->output, l->batch*l->outputs * sizeof(float), cudaHostRegisterMapped)) {
  81. cudaGetLastError(); // reset CUDA-error
  82. l->output = (float*)xcalloc(l->batch * l->outputs, sizeof(float));
  83. l->output_pinned = 0;
  84. }
  85. }
  86. if (l->delta_pinned) {
  87. CHECK_CUDA(cudaFreeHost(l->delta));
  88. if (cudaSuccess != cudaHostAlloc(&l->delta, l->batch*l->outputs * sizeof(float), cudaHostRegisterMapped)) {
  89. cudaGetLastError(); // reset CUDA-error
  90. l->delta = (float*)xcalloc(l->batch * l->outputs, sizeof(float));
  91. l->delta_pinned = 0;
  92. }
  93. }
  94. cuda_free(l->delta_gpu);
  95. cuda_free(l->output_gpu);
  96. l->delta_gpu = cuda_make_array(l->delta, l->batch*l->outputs);
  97. l->output_gpu = cuda_make_array(l->output, l->batch*l->outputs);
  98. #endif
  99. }
  100. box get_yolo_box(float *x, float *biases, int n, int index, int i, int j, int lw, int lh, int w, int h, int stride)
  101. {
  102. box b;
  103. // ln - natural logarithm (base = e)
  104. // x` = t.x * lw - i; // x = ln(x`/(1-x`)) // x - output of previous conv-layer
  105. // y` = t.y * lh - i; // y = ln(y`/(1-y`)) // y - output of previous conv-layer
  106. // w = ln(t.w * net.w / anchors_w); // w - output of previous conv-layer
  107. // h = ln(t.h * net.h / anchors_h); // h - output of previous conv-layer
  108. b.x = (i + x[index + 0*stride]) / lw;
  109. b.y = (j + x[index + 1*stride]) / lh;
  110. b.w = exp(x[index + 2*stride]) * biases[2*n] / w;
  111. b.h = exp(x[index + 3*stride]) * biases[2*n+1] / h;
  112. return b;
  113. }
  114. ious delta_yolo_box(box truth, float *x, float *biases, int n, int index, int i, int j, int lw, int lh, int w, int h, float *delta, float scale, int stride, float iou_normalizer, IOU_LOSS iou_loss, int accumulate)
  115. {
  116. ious all_ious = { 0 };
  117. // i - step in layer width
  118. // j - step in layer height
  119. // Returns a box in absolute coordinates
  120. box pred = get_yolo_box(x, biases, n, index, i, j, lw, lh, w, h, stride);
  121. all_ious.iou = box_iou(pred, truth);
  122. all_ious.giou = box_giou(pred, truth);
  123. all_ious.diou = box_diou(pred, truth);
  124. all_ious.ciou = box_ciou(pred, truth);
  125. // avoid nan in dx_box_iou
  126. if (pred.w == 0) { pred.w = 1.0; }
  127. if (pred.h == 0) { pred.h = 1.0; }
  128. if (iou_loss == MSE) // old loss
  129. {
  130. float tx = (truth.x*lw - i);
  131. float ty = (truth.y*lh - j);
  132. float tw = log(truth.w*w / biases[2 * n]);
  133. float th = log(truth.h*h / biases[2 * n + 1]);
  134. // accumulate delta
  135. delta[index + 0 * stride] += scale * (tx - x[index + 0 * stride]) * iou_normalizer;
  136. delta[index + 1 * stride] += scale * (ty - x[index + 1 * stride]) * iou_normalizer;
  137. delta[index + 2 * stride] += scale * (tw - x[index + 2 * stride]) * iou_normalizer;
  138. delta[index + 3 * stride] += scale * (th - x[index + 3 * stride]) * iou_normalizer;
  139. }
  140. else {
  141. // https://github.com/generalized-iou/g-darknet
  142. // https://arxiv.org/abs/1902.09630v2
  143. // https://giou.stanford.edu/
  144. all_ious.dx_iou = dx_box_iou(pred, truth, iou_loss);
  145. // jacobian^t (transpose)
  146. //float dx = (all_ious.dx_iou.dl + all_ious.dx_iou.dr);
  147. //float dy = (all_ious.dx_iou.dt + all_ious.dx_iou.db);
  148. //float dw = ((-0.5 * all_ious.dx_iou.dl) + (0.5 * all_ious.dx_iou.dr));
  149. //float dh = ((-0.5 * all_ious.dx_iou.dt) + (0.5 * all_ious.dx_iou.db));
  150. // jacobian^t (transpose)
  151. float dx = all_ious.dx_iou.dt;
  152. float dy = all_ious.dx_iou.db;
  153. float dw = all_ious.dx_iou.dl;
  154. float dh = all_ious.dx_iou.dr;
  155. // predict exponential, apply gradient of e^delta_t ONLY for w,h
  156. dw *= exp(x[index + 2 * stride]);
  157. dh *= exp(x[index + 3 * stride]);
  158. // normalize iou weight
  159. dx *= iou_normalizer;
  160. dy *= iou_normalizer;
  161. dw *= iou_normalizer;
  162. dh *= iou_normalizer;
  163. if (!accumulate) {
  164. delta[index + 0 * stride] = 0;
  165. delta[index + 1 * stride] = 0;
  166. delta[index + 2 * stride] = 0;
  167. delta[index + 3 * stride] = 0;
  168. }
  169. // accumulate delta
  170. delta[index + 0 * stride] += dx;
  171. delta[index + 1 * stride] += dy;
  172. delta[index + 2 * stride] += dw;
  173. delta[index + 3 * stride] += dh;
  174. }
  175. return all_ious;
  176. }
  /*
   * Divides the four box deltas of one prediction by the number of classes
   * whose delta is strictly positive.
   *
   *   class_index - offset of the first class delta
   *   box_index   - offset of the first box delta
   *   stride      - distance between consecutive entries
   *   classes     - number of class deltas to inspect
   *   delta       - gradient buffer, modified in place
   *
   * Nothing is changed when no class delta is positive.
   */
  void averages_yolo_deltas(int class_index, int box_index, int stride, int classes, float *delta)
  {
      int positives = 0;
      int k;

      for (k = 0; k < classes; ++k) {
          positives += (delta[class_index + stride*k] > 0) ? 1 : 0;
      }

      if (positives <= 0) return;

      for (k = 0; k < 4; ++k) {
          delta[box_index + k * stride] /= positives;
      }
  }
  191. void delta_yolo_class(float *output, float *delta, int index, int class_id, int classes, int stride, float *avg_cat, int focal_loss, float label_smooth_eps, float *classes_multipliers)
  192. {
  193. int n;
  194. if (delta[index + stride*class_id]){
  195. delta[index + stride*class_id] = (1 - label_smooth_eps) - output[index + stride*class_id];
  196. if (classes_multipliers) delta[index + stride*class_id] *= classes_multipliers[class_id];
  197. if(avg_cat) *avg_cat += output[index + stride*class_id];
  198. return;
  199. }
  200. // Focal loss
  201. if (focal_loss) {
  202. // Focal Loss
  203. float alpha = 0.5; // 0.25 or 0.5
  204. //float gamma = 2; // hardcoded in many places of the grad-formula
  205. int ti = index + stride*class_id;
  206. float pt = output[ti] + 0.000000000000001F;
  207. // http://fooplot.com/#W3sidHlwZSI6MCwiZXEiOiItKDEteCkqKDIqeCpsb2coeCkreC0xKSIsImNvbG9yIjoiIzAwMDAwMCJ9LHsidHlwZSI6MTAwMH1d
  208. float grad = -(1 - pt) * (2 * pt*logf(pt) + pt - 1); // http://blog.csdn.net/linmingan/article/details/77885832
  209. //float grad = (1 - pt) * (2 * pt*logf(pt) + pt - 1); // https://github.com/unsky/focal-loss
  210. for (n = 0; n < classes; ++n) {
  211. delta[index + stride*n] = (((n == class_id) ? 1 : 0) - output[index + stride*n]);
  212. delta[index + stride*n] *= alpha*grad;
  213. if (n == class_id && avg_cat) *avg_cat += output[index + stride*n];
  214. }
  215. }
  216. else {
  217. // default
  218. for (n = 0; n < classes; ++n) {
  219. delta[index + stride*n] = ((n == class_id) ? (1 - label_smooth_eps) : (0 + label_smooth_eps/classes)) - output[index + stride*n];
  220. if (classes_multipliers && n == class_id) delta[index + stride*class_id] *= classes_multipliers[class_id];
  221. if (n == class_id && avg_cat) *avg_cat += output[index + stride*n];
  222. }
  223. }
  224. }
  /*
   * Returns 1 when any of the `classes` values starting at class_index
   * (spaced by stride) is strictly greater than conf_thresh, otherwise 0.
   *
   * objectness and class_id are accepted for interface compatibility but are
   * not used in the comparison.
   */
  int compare_yolo_class(float *output, int classes, int class_index, int stride, float objectness, int class_id, float conf_thresh)
  {
      int c = 0;

      (void)objectness;
      (void)class_id;

      while (c < classes) {
          const float prob = output[class_index + stride*c];
          if (prob > conf_thresh) return 1;
          ++c;
      }
      return 0;
  }
  237. static int entry_index(layer l, int batch, int location, int entry)
  238. {
  239. int n = location / (l.w*l.h);
  240. int loc = location % (l.w*l.h);
  241. return batch*l.outputs + n*l.w*l.h*(4+l.classes+1) + entry*l.w*l.h + loc;
  242. }
  243. void forward_yolo_layer(const layer l, network_state state)
  244. {
  245. int i, j, b, t, n;
  246. memcpy(l.output, state.input, l.outputs*l.batch * sizeof(float));
  247. #ifndef GPU
  248. for (b = 0; b < l.batch; ++b) {
  249. for (n = 0; n < l.n; ++n) {
  250. int index = entry_index(l, b, n*l.w*l.h, 0);
  251. activate_array(l.output + index, 2 * l.w*l.h, LOGISTIC); // x,y,
  252. scal_add_cpu(2 * l.w*l.h, l.scale_x_y, -0.5*(l.scale_x_y - 1), l.output + index, 1); // scale x,y
  253. index = entry_index(l, b, n*l.w*l.h, 4);
  254. activate_array(l.output + index, (1 + l.classes)*l.w*l.h, LOGISTIC);
  255. }
  256. }
  257. #endif
  258. // delta is zeroed
  259. memset(l.delta, 0, l.outputs * l.batch * sizeof(float));
  260. if (!state.train) return;
  261. //float avg_iou = 0;
  262. float tot_iou = 0;
  263. float tot_giou = 0;
  264. float tot_diou = 0;
  265. float tot_ciou = 0;
  266. float tot_iou_loss = 0;
  267. float tot_giou_loss = 0;
  268. float tot_diou_loss = 0;
  269. float tot_ciou_loss = 0;
  270. float recall = 0;
  271. float recall75 = 0;
  272. float avg_cat = 0;
  273. float avg_obj = 0;
  274. float avg_anyobj = 0;
  275. int count = 0;
  276. int class_count = 0;
  277. *(l.cost) = 0;
  278. for (b = 0; b < l.batch; ++b) {
  279. for (j = 0; j < l.h; ++j) {
  280. for (i = 0; i < l.w; ++i) {
  281. for (n = 0; n < l.n; ++n) {
  282. int box_index = entry_index(l, b, n*l.w*l.h + j*l.w + i, 0);
  283. box pred = get_yolo_box(l.output, l.biases, l.mask[n], box_index, i, j, l.w, l.h, state.net.w, state.net.h, l.w*l.h);
  284. float best_match_iou = 0;
  285. int best_match_t = 0;
  286. float best_iou = 0;
  287. int best_t = 0;
  288. for (t = 0; t < l.max_boxes; ++t) {
  289. box truth = float_to_box_stride(state.truth + t*(4 + 1) + b*l.truths, 1);
  290. int class_id = state.truth[t*(4 + 1) + b*l.truths + 4];
  291. if (class_id >= l.classes) {
  292. printf(" Warning: in txt-labels class_id=%d >= classes=%d in cfg-file. In txt-labels class_id should be [from 0 to %d] \n", class_id, l.classes, l.classes - 1);
  293. printf(" truth.x = %f, truth.y = %f, truth.w = %f, truth.h = %f, class_id = %d \n", truth.x, truth.y, truth.w, truth.h, class_id);
  294. getchar();
  295. continue; // if label contains class_id more than number of classes in the cfg-file
  296. }
  297. if (!truth.x) break; // continue;
  298. int class_index = entry_index(l, b, n*l.w*l.h + j*l.w + i, 4 + 1);
  299. int obj_index = entry_index(l, b, n*l.w*l.h + j*l.w + i, 4);
  300. float objectness = l.output[obj_index];
  301. int class_id_match = compare_yolo_class(l.output, l.classes, class_index, l.w*l.h, objectness, class_id, 0.25f);
  302. float iou = box_iou(pred, truth);
  303. if (iou > best_match_iou && class_id_match == 1) {
  304. best_match_iou = iou;
  305. best_match_t = t;
  306. }
  307. if (iou > best_iou) {
  308. best_iou = iou;
  309. best_t = t;
  310. }
  311. }
  312. int obj_index = entry_index(l, b, n*l.w*l.h + j*l.w + i, 4);
  313. avg_anyobj += l.output[obj_index];
  314. l.delta[obj_index] = l.cls_normalizer * (0 - l.output[obj_index]);
  315. if (best_match_iou > l.ignore_thresh) {
  316. l.delta[obj_index] = 0;
  317. }
  318. if (best_iou > l.truth_thresh) {
  319. l.delta[obj_index] = l.cls_normalizer * (1 - l.output[obj_index]);
  320. int class_id = state.truth[best_t*(4 + 1) + b*l.truths + 4];
  321. if (l.map) class_id = l.map[class_id];
  322. int class_index = entry_index(l, b, n*l.w*l.h + j*l.w + i, 4 + 1);
  323. delta_yolo_class(l.output, l.delta, class_index, class_id, l.classes, l.w*l.h, 0, l.focal_loss, l.label_smooth_eps, l.classes_multipliers);
  324. box truth = float_to_box_stride(state.truth + best_t*(4 + 1) + b*l.truths, 1);
  325. const float class_multiplier = (l.classes_multipliers) ? l.classes_multipliers[class_id] : 1.0f;
  326. delta_yolo_box(truth, l.output, l.biases, l.mask[n], box_index, i, j, l.w, l.h, state.net.w, state.net.h, l.delta, (2 - truth.w*truth.h), l.w*l.h, l.iou_normalizer * class_multiplier, l.iou_loss, 1);
  327. }
  328. }
  329. }
  330. }
  331. for (t = 0; t < l.max_boxes; ++t) {
  332. box truth = float_to_box_stride(state.truth + t*(4 + 1) + b*l.truths, 1);
  333. if (truth.x < 0 || truth.y < 0 || truth.x > 1 || truth.y > 1 || truth.w < 0 || truth.h < 0) {
  334. char buff[256];
  335. printf(" Wrong label: truth.x = %f, truth.y = %f, truth.w = %f, truth.h = %f \n", truth.x, truth.y, truth.w, truth.h);
  336. sprintf(buff, "echo \"Wrong label: truth.x = %f, truth.y = %f, truth.w = %f, truth.h = %f\" >> bad_label.list",
  337. truth.x, truth.y, truth.w, truth.h);
  338. system(buff);
  339. }
  340. int class_id = state.truth[t*(4 + 1) + b*l.truths + 4];
  341. if (class_id >= l.classes) continue; // if label contains class_id more than number of classes in the cfg-file
  342. if (!truth.x) break; // continue;
  343. float best_iou = 0;
  344. int best_n = 0;
  345. i = (truth.x * l.w);
  346. j = (truth.y * l.h);
  347. box truth_shift = truth;
  348. truth_shift.x = truth_shift.y = 0;
  349. for (n = 0; n < l.total; ++n) {
  350. box pred = { 0 };
  351. pred.w = l.biases[2 * n] / state.net.w;
  352. pred.h = l.biases[2 * n + 1] / state.net.h;
  353. float iou = box_iou(pred, truth_shift);
  354. if (iou > best_iou) {
  355. best_iou = iou;
  356. best_n = n;
  357. }
  358. }
  359. int mask_n = int_index(l.mask, best_n, l.n);
  360. if (mask_n >= 0) {
  361. int class_id = state.truth[t*(4 + 1) + b*l.truths + 4];
  362. if (l.map) class_id = l.map[class_id];
  363. int box_index = entry_index(l, b, mask_n*l.w*l.h + j*l.w + i, 0);
  364. const float class_multiplier = (l.classes_multipliers) ? l.classes_multipliers[class_id] : 1.0f;
  365. ious all_ious = delta_yolo_box(truth, l.output, l.biases, best_n, box_index, i, j, l.w, l.h, state.net.w, state.net.h, l.delta, (2 - truth.w*truth.h), l.w*l.h, l.iou_normalizer * class_multiplier, l.iou_loss, 1);
  366. // range is 0 <= 1
  367. tot_iou += all_ious.iou;
  368. tot_iou_loss += 1 - all_ious.iou;
  369. // range is -1 <= giou <= 1
  370. tot_giou += all_ious.giou;
  371. tot_giou_loss += 1 - all_ious.giou;
  372. tot_diou += all_ious.diou;
  373. tot_diou_loss += 1 - all_ious.diou;
  374. tot_ciou += all_ious.ciou;
  375. tot_ciou_loss += 1 - all_ious.ciou;
  376. int obj_index = entry_index(l, b, mask_n*l.w*l.h + j*l.w + i, 4);
  377. avg_obj += l.output[obj_index];
  378. l.delta[obj_index] = class_multiplier * l.cls_normalizer * (1 - l.output[obj_index]);
  379. int class_index = entry_index(l, b, mask_n*l.w*l.h + j*l.w + i, 4 + 1);
  380. delta_yolo_class(l.output, l.delta, class_index, class_id, l.classes, l.w*l.h, &avg_cat, l.focal_loss, l.label_smooth_eps, l.classes_multipliers);
  381. ++count;
  382. ++class_count;
  383. if (all_ious.iou > .5) recall += 1;
  384. if (all_ious.iou > .75) recall75 += 1;
  385. }
  386. // iou_thresh
  387. for (n = 0; n < l.total; ++n) {
  388. int mask_n = int_index(l.mask, n, l.n);
  389. if (mask_n >= 0 && n != best_n && l.iou_thresh < 1.0f) {
  390. box pred = { 0 };
  391. pred.w = l.biases[2 * n] / state.net.w;
  392. pred.h = l.biases[2 * n + 1] / state.net.h;
  393. float iou = box_iou(pred, truth_shift);
  394. // iou, n
  395. if (iou > l.iou_thresh) {
  396. int class_id = state.truth[t*(4 + 1) + b*l.truths + 4];
  397. if (l.map) class_id = l.map[class_id];
  398. int box_index = entry_index(l, b, mask_n*l.w*l.h + j*l.w + i, 0);
  399. const float class_multiplier = (l.classes_multipliers) ? l.classes_multipliers[class_id] : 1.0f;
  400. ious all_ious = delta_yolo_box(truth, l.output, l.biases, n, box_index, i, j, l.w, l.h, state.net.w, state.net.h, l.delta, (2 - truth.w*truth.h), l.w*l.h, l.iou_normalizer * class_multiplier, l.iou_loss, 1);
  401. // range is 0 <= 1
  402. tot_iou += all_ious.iou;
  403. tot_iou_loss += 1 - all_ious.iou;
  404. // range is -1 <= giou <= 1
  405. tot_giou += all_ious.giou;
  406. tot_giou_loss += 1 - all_ious.giou;
  407. tot_diou += all_ious.diou;
  408. tot_diou_loss += 1 - all_ious.diou;
  409. tot_ciou += all_ious.ciou;
  410. tot_ciou_loss += 1 - all_ious.ciou;
  411. int obj_index = entry_index(l, b, mask_n*l.w*l.h + j*l.w + i, 4);
  412. avg_obj += l.output[obj_index];
  413. l.delta[obj_index] = class_multiplier * l.cls_normalizer * (1 - l.output[obj_index]);
  414. int class_index = entry_index(l, b, mask_n*l.w*l.h + j*l.w + i, 4 + 1);
  415. delta_yolo_class(l.output, l.delta, class_index, class_id, l.classes, l.w*l.h, &avg_cat, l.focal_loss, l.label_smooth_eps, l.classes_multipliers);
  416. ++count;
  417. ++class_count;
  418. if (all_ious.iou > .5) recall += 1;
  419. if (all_ious.iou > .75) recall75 += 1;
  420. }
  421. }
  422. }
  423. }
  424. // averages the deltas obtained by the function: delta_yolo_box()_accumulate
  425. for (j = 0; j < l.h; ++j) {
  426. for (i = 0; i < l.w; ++i) {
  427. for (n = 0; n < l.n; ++n) {
  428. int box_index = entry_index(l, b, n*l.w*l.h + j*l.w + i, 0);
  429. int class_index = entry_index(l, b, n*l.w*l.h + j*l.w + i, 4 + 1);
  430. const int stride = l.w*l.h;
  431. averages_yolo_deltas(class_index, box_index, stride, l.classes, l.delta);
  432. }
  433. }
  434. }
  435. }
  436. //*(l.cost) = pow(mag_array(l.delta, l.outputs * l.batch), 2);
  437. //printf("Region %d Avg IOU: %f, Class: %f, Obj: %f, No Obj: %f, .5R: %f, .75R: %f, count: %d\n", state.index, avg_iou / count, avg_cat / class_count, avg_obj / count, avg_anyobj / (l.w*l.h*l.n*l.batch), recall / count, recall75 / count, count);
  438. int stride = l.w*l.h;
  439. float* no_iou_loss_delta = (float *)calloc(l.batch * l.outputs, sizeof(float));
  440. memcpy(no_iou_loss_delta, l.delta, l.batch * l.outputs * sizeof(float));
  441. for (b = 0; b < l.batch; ++b) {
  442. for (j = 0; j < l.h; ++j) {
  443. for (i = 0; i < l.w; ++i) {
  444. for (n = 0; n < l.n; ++n) {
  445. int index = entry_index(l, b, n*l.w*l.h + j*l.w + i, 0);
  446. no_iou_loss_delta[index + 0 * stride] = 0;
  447. no_iou_loss_delta[index + 1 * stride] = 0;
  448. no_iou_loss_delta[index + 2 * stride] = 0;
  449. no_iou_loss_delta[index + 3 * stride] = 0;
  450. }
  451. }
  452. }
  453. }
  454. float classification_loss = l.cls_normalizer * pow(mag_array(no_iou_loss_delta, l.outputs * l.batch), 2);
  455. free(no_iou_loss_delta);
  456. float loss = pow(mag_array(l.delta, l.outputs * l.batch), 2);
  457. float iou_loss = loss - classification_loss;
  458. float avg_iou_loss = 0;
  459. // gIOU loss + MSE (objectness) loss
  460. if (l.iou_loss == MSE) {
  461. *(l.cost) = pow(mag_array(l.delta, l.outputs * l.batch), 2);
  462. }
  463. else {
  464. // Always compute classification loss both for iou + cls loss and for logging with mse loss
  465. // TODO: remove IOU loss fields before computing MSE on class
  466. // probably split into two arrays
  467. if (l.iou_loss == GIOU) {
  468. avg_iou_loss = count > 0 ? l.iou_normalizer * (tot_giou_loss / count) : 0;
  469. }
  470. else {
  471. avg_iou_loss = count > 0 ? l.iou_normalizer * (tot_iou_loss / count) : 0;
  472. }
  473. *(l.cost) = avg_iou_loss + classification_loss;
  474. }
  475. loss /= l.batch;
  476. classification_loss /= l.batch;
  477. iou_loss /= l.batch;
  478. printf("v3 (%s loss, Normalizer: (iou: %f, cls: %f) Region %d Avg (IOU: %f, GIOU: %f), Class: %f, Obj: %f, No Obj: %f, .5R: %f, .75R: %f, count: %d, loss = %f, class_loss = %f, iou_loss = %f\n",
  479. (l.iou_loss == MSE ? "mse" : (l.iou_loss == GIOU ? "giou" : "iou")), l.iou_normalizer, l.cls_normalizer, state.index, tot_iou / count, tot_giou / count, avg_cat / class_count, avg_obj / count, avg_anyobj / (l.w*l.h*l.n*l.batch), recall / count, recall75 / count, count,
  480. loss, classification_loss, iou_loss);
  481. }
  482. void backward_yolo_layer(const layer l, network_state state)
  483. {
  484. axpy_cpu(l.batch*l.inputs, 1, l.delta, 1, state.delta, 1);
  485. }
  486. // Converts output of the network to detection boxes
  487. // w,h: image width,height
  488. // netw,neth: network width,height
  489. // relative: 1 (all callers seems to pass TRUE)
  490. void correct_yolo_boxes(detection *dets, int n, int w, int h, int netw, int neth, int relative, int letter)
  491. {
  492. int i;
  493. // network height (or width)
  494. int new_w = 0;
  495. // network height (or width)
  496. int new_h = 0;
  497. // Compute scale given image w,h vs network w,h
  498. // I think this "rotates" the image to match network to input image w/h ratio
  499. // new_h and new_w are really just network width and height
  500. if (letter) {
  501. if (((float)netw / w) < ((float)neth / h)) {
  502. new_w = netw;
  503. new_h = (h * netw) / w;
  504. }
  505. else {
  506. new_h = neth;
  507. new_w = (w * neth) / h;
  508. }
  509. }
  510. else {
  511. new_w = netw;
  512. new_h = neth;
  513. }
  514. // difference between network width and "rotated" width
  515. float deltaw = netw - new_w;
  516. // difference between network height and "rotated" height
  517. float deltah = neth - new_h;
  518. // ratio between rotated network width and network width
  519. float ratiow = (float)new_w / netw;
  520. // ratio between rotated network width and network width
  521. float ratioh = (float)new_h / neth;
  522. for (i = 0; i < n; ++i) {
  523. box b = dets[i].bbox;
  524. // x = ( x - (deltaw/2)/netw ) / ratiow;
  525. // x - [(1/2 the difference of the network width and rotated width) / (network width)]
  526. b.x = (b.x - deltaw / 2. / netw) / ratiow;
  527. b.y = (b.y - deltah / 2. / neth) / ratioh;
  528. // scale to match rotation of incoming image
  529. b.w *= 1 / ratiow;
  530. b.h *= 1 / ratioh;
  531. // relative seems to always be == 1, I don't think we hit this condition, ever.
  532. if (!relative) {
  533. b.x *= w;
  534. b.w *= w;
  535. b.y *= h;
  536. b.h *= h;
  537. }
  538. dets[i].bbox = b;
  539. }
  540. }
  541. /*
  542. void correct_yolo_boxes(detection *dets, int n, int w, int h, int netw, int neth, int relative, int letter)
  543. {
  544. int i;
  545. int new_w=0;
  546. int new_h=0;
  547. if (letter) {
  548. if (((float)netw / w) < ((float)neth / h)) {
  549. new_w = netw;
  550. new_h = (h * netw) / w;
  551. }
  552. else {
  553. new_h = neth;
  554. new_w = (w * neth) / h;
  555. }
  556. }
  557. else {
  558. new_w = netw;
  559. new_h = neth;
  560. }
  561. for (i = 0; i < n; ++i){
  562. box b = dets[i].bbox;
  563. b.x = (b.x - (netw - new_w)/2./netw) / ((float)new_w/netw);
  564. b.y = (b.y - (neth - new_h)/2./neth) / ((float)new_h/neth);
  565. b.w *= (float)netw/new_w;
  566. b.h *= (float)neth/new_h;
  567. if(!relative){
  568. b.x *= w;
  569. b.w *= w;
  570. b.y *= h;
  571. b.h *= h;
  572. }
  573. dets[i].bbox = b;
  574. }
  575. }
  576. */
  577. int yolo_num_detections(layer l, float thresh)
  578. {
  579. int i, n;
  580. int count = 0;
  581. for (i = 0; i < l.w*l.h; ++i){
  582. for(n = 0; n < l.n; ++n){
  583. int obj_index = entry_index(l, 0, n*l.w*l.h + i, 4);
  584. if(l.output[obj_index] > thresh){
  585. ++count;
  586. }
  587. }
  588. }
  589. return count;
  590. }
  591. void avg_flipped_yolo(layer l)
  592. {
  593. int i,j,n,z;
  594. float *flip = l.output + l.outputs;
  595. for (j = 0; j < l.h; ++j) {
  596. for (i = 0; i < l.w/2; ++i) {
  597. for (n = 0; n < l.n; ++n) {
  598. for(z = 0; z < l.classes + 4 + 1; ++z){
  599. int i1 = z*l.w*l.h*l.n + n*l.w*l.h + j*l.w + i;
  600. int i2 = z*l.w*l.h*l.n + n*l.w*l.h + j*l.w + (l.w - i - 1);
  601. float swap = flip[i1];
  602. flip[i1] = flip[i2];
  603. flip[i2] = swap;
  604. if(z == 0){
  605. flip[i1] = -flip[i1];
  606. flip[i2] = -flip[i2];
  607. }
  608. }
  609. }
  610. }
  611. }
  612. for(i = 0; i < l.outputs; ++i){
  613. l.output[i] = (l.output[i] + flip[i])/2.;
  614. }
  615. }
  616. int get_yolo_detections(layer l, int w, int h, int netw, int neth, float thresh, int *map, int relative, detection *dets, int letter)
  617. {
  618. //printf("\n l.batch = %d, l.w = %d, l.h = %d, l.n = %d \n", l.batch, l.w, l.h, l.n);
  619. int i,j,n;
  620. float *predictions = l.output;
  621. // This snippet below is not necessary
  622. // Need to comment it in order to batch processing >= 2 images
  623. //if (l.batch == 2) avg_flipped_yolo(l);
  624. int count = 0;
  625. for (i = 0; i < l.w*l.h; ++i){
  626. int row = i / l.w;
  627. int col = i % l.w;
  628. for(n = 0; n < l.n; ++n){
  629. int obj_index = entry_index(l, 0, n*l.w*l.h + i, 4);
  630. float objectness = predictions[obj_index];
  631. //if(objectness <= thresh) continue; // incorrect behavior for Nan values
  632. if (objectness > thresh) {
  633. //printf("\n objectness = %f, thresh = %f, i = %d, n = %d \n", objectness, thresh, i, n);
  634. int box_index = entry_index(l, 0, n*l.w*l.h + i, 0);
  635. dets[count].bbox = get_yolo_box(predictions, l.biases, l.mask[n], box_index, col, row, l.w, l.h, netw, neth, l.w*l.h);
  636. dets[count].objectness = objectness;
  637. dets[count].classes = l.classes;
  638. for (j = 0; j < l.classes; ++j) {
  639. int class_index = entry_index(l, 0, n*l.w*l.h + i, 4 + 1 + j);
  640. float prob = objectness*predictions[class_index];
  641. dets[count].prob[j] = (prob > thresh) ? prob : 0;
  642. }
  643. ++count;
  644. }
  645. }
  646. }
  647. correct_yolo_boxes(dets, count, w, h, netw, neth, relative, letter);
  648. return count;
  649. }
  650. #ifdef GPU
  651. void forward_yolo_layer_gpu(const layer l, network_state state)
  652. {
  653. //copy_ongpu(l.batch*l.inputs, state.input, 1, l.output_gpu, 1);
  654. simple_copy_ongpu(l.batch*l.inputs, state.input, l.output_gpu);
  655. int b, n;
  656. for (b = 0; b < l.batch; ++b){
  657. for(n = 0; n < l.n; ++n){
  658. int index = entry_index(l, b, n*l.w*l.h, 0);
  659. // y = 1./(1. + exp(-x))
  660. // x = ln(y/(1-y)) // ln - natural logarithm (base = e)
  661. // if(y->1) x -> inf
  662. // if(y->0) x -> -inf
  663. activate_array_ongpu(l.output_gpu + index, 2*l.w*l.h, LOGISTIC); // x,y
  664. if (l.scale_x_y != 1) scal_add_ongpu(2 * l.w*l.h, l.scale_x_y, -0.5*(l.scale_x_y - 1), l.output_gpu + index, 1); // scale x,y
  665. index = entry_index(l, b, n*l.w*l.h, 4);
  666. activate_array_ongpu(l.output_gpu + index, (1+l.classes)*l.w*l.h, LOGISTIC); // classes and objectness
  667. }
  668. }
  669. if(!state.train || l.onlyforward){
  670. //cuda_pull_array(l.output_gpu, l.output, l.batch*l.outputs);
  671. cuda_pull_array_async(l.output_gpu, l.output, l.batch*l.outputs);
  672. CHECK_CUDA(cudaPeekAtLastError());
  673. return;
  674. }
  675. float *in_cpu = (float *)xcalloc(l.batch*l.inputs, sizeof(float));
  676. cuda_pull_array(l.output_gpu, l.output, l.batch*l.outputs);
  677. memcpy(in_cpu, l.output, l.batch*l.outputs*sizeof(float));
  678. float *truth_cpu = 0;
  679. if (state.truth) {
  680. int num_truth = l.batch*l.truths;
  681. truth_cpu = (float *)xcalloc(num_truth, sizeof(float));
  682. cuda_pull_array(state.truth, truth_cpu, num_truth);
  683. }
  684. network_state cpu_state = state;
  685. cpu_state.net = state.net;
  686. cpu_state.index = state.index;
  687. cpu_state.train = state.train;
  688. cpu_state.truth = truth_cpu;
  689. cpu_state.input = in_cpu;
  690. forward_yolo_layer(l, cpu_state);
  691. //forward_yolo_layer(l, state);
  692. cuda_push_array(l.delta_gpu, l.delta, l.batch*l.outputs);
  693. free(in_cpu);
  694. if (cpu_state.truth) free(cpu_state.truth);
  695. }
  696. void backward_yolo_layer_gpu(const layer l, network_state state)
  697. {
  698. axpy_ongpu(l.batch*l.inputs, 1, l.delta_gpu, 1, state.delta, 1);
  699. }
  700. #endif