// network_kernels.cu

#include "dark_cuda.h"
#include <stdio.h>
#include <time.h>
#include <assert.h>
#include "network.h"
#include "image.h"
#include "data.h"
#include "utils.h"
#include "parser.h"
#include "crop_layer.h"
#include "connected_layer.h"
#include "rnn_layer.h"
#include "gru_layer.h"
#include "crnn_layer.h"
#include "detection_layer.h"
#include "region_layer.h"
#include "convolutional_layer.h"
#include "activation_layer.h"
#include "maxpool_layer.h"
#include "reorg_layer.h"
#include "avgpool_layer.h"
#include "normalization_layer.h"
#include "batchnorm_layer.h"
#include "cost_layer.h"
#include "local_layer.h"
#include "softmax_layer.h"
#include "dropout_layer.h"
#include "route_layer.h"
#include "shortcut_layer.h"
#include "blas.h"
//#ifdef OPENCV
//#include <opencv2/highgui/highgui_c.h>
//#endif
#include "http_stream.h"

float * get_network_output_gpu_layer(network net, int i);
float * get_network_delta_gpu_layer(network net, int i);
float * get_network_output_gpu(network net);

typedef struct time_benchmark_layers {
    float time;
    int layer_id, layer_type;
} time_benchmark_layers;
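
// qsort comparator: orders benchmarked layers by average time, slowest first (descending).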
int time_comparator(const void *pa, const void *pb)
{
    time_benchmark_layers a = *(time_benchmark_layers *)pa;
    time_benchmark_layers b = *(time_benchmark_layers *)pb;
    float diff = a.time - b.time;
    if (diff < 0) return 1;
    else if (diff > 0) return -1;
    return 0;
}
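
// Forward pass over all layers on the GPU. Zeroes each layer's delta when training,
// optionally times every layer (net.benchmark_layers) with an exponential moving average,
// and chains each layer's output_gpu into the next layer's input.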
void forward_network_gpu(network net, network_state state)
{
    static time_benchmark_layers *avg_time_per_layer = NULL;
    static time_benchmark_layers *sorted_avg_time_per_layer = NULL;
    double start_time, end_time;
    if (net.benchmark_layers) {
        if (!avg_time_per_layer) {
            avg_time_per_layer = (time_benchmark_layers *)calloc(net.n, sizeof(time_benchmark_layers));
            sorted_avg_time_per_layer = (time_benchmark_layers *)calloc(net.n, sizeof(time_benchmark_layers));
        }
        cudaDeviceSynchronize();
    }

    //printf("\n");
    state.workspace = net.workspace;
    int i;
    for(i = 0; i < net.n; ++i){
        state.index = i;
        layer l = net.layers[i];
        if(l.delta_gpu && state.train){
            fill_ongpu(l.outputs * l.batch, 0, l.delta_gpu, 1);
        }

        if (net.benchmark_layers) {
            start_time = get_time_point();
        }

        l.forward_gpu(l, state);

        if (net.benchmark_layers) {
            CHECK_CUDA(cudaDeviceSynchronize());
            end_time = get_time_point();
            const double took_time = (end_time - start_time) / 1000;
            const double alpha = 0.9;
            if (avg_time_per_layer[i].time == 0) {
                avg_time_per_layer[i].layer_id = i;
                avg_time_per_layer[i].layer_type = l.type;
                avg_time_per_layer[i].time = took_time;
            }
            else avg_time_per_layer[i].time = avg_time_per_layer[i].time * alpha + took_time * (1 - alpha);

            sorted_avg_time_per_layer[i] = avg_time_per_layer[i];
            printf("\n layer %d - type: %d - %lf ms - avg_time %lf ms \n", i, l.type, took_time, avg_time_per_layer[i].time);
        }

        if(net.wait_stream)
            cudaStreamSynchronize(get_cuda_stream());
        state.input = l.output_gpu;
        //cudaDeviceSynchronize();

        /*
        cuda_pull_array(l.output_gpu, l.output, l.batch*l.outputs);
        if (l.out_w >= 0 && l.out_h >= 1 && l.c >= 3) {
            int j;
            for (j = 0; j < l.out_c; ++j) {
                image img = make_image(l.out_w, l.out_h, 3);
                memcpy(img.data, l.output + l.out_w*l.out_h*j, l.out_w*l.out_h * 1 * sizeof(float));
                memcpy(img.data + l.out_w*l.out_h * 1, l.output + l.out_w*l.out_h*j, l.out_w*l.out_h * 1 * sizeof(float));
                memcpy(img.data + l.out_w*l.out_h * 2, l.output + l.out_w*l.out_h*j, l.out_w*l.out_h * 1 * sizeof(float));
                char buff[256];
                sprintf(buff, "layer-%d slice-%d", i, j);
                show_image(img, buff);
                save_image(img, buff);
            }
            cvWaitKey(0); // wait press-key in console
            cvDestroyAllWindows();
        }
        */
    }

    if (net.benchmark_layers) {
        printf("\n\nSorted by time:\n");
        qsort(sorted_avg_time_per_layer, net.n, sizeof(time_benchmark_layers), time_comparator);
        for (i = 0; i < net.n; ++i) {
            //printf("layer %d - type: %d - avg_time %lf ms \n", avg_time_per_layer[i].layer_id, avg_time_per_layer[i].layer_type, avg_time_per_layer[i].time);
            printf("%d - layer %d - type: %d - avg_time %lf ms \n", i, sorted_avg_time_per_layer[i].layer_id, sorted_avg_time_per_layer[i].layer_type, sorted_avg_time_per_layer[i].time);
        }
    }

    //cudaStreamSynchronize(get_cuda_stream()); // sync CUDA-functions
    //cudaDeviceSynchronize();
}
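
// Backward pass in reverse layer order. For each layer, state.input/state.delta point at
// the previous layer's output_gpu/delta_gpu (or a shared buffer when optimized_memory is
// enabled); layers flagged stopbackward/onlyforward stop or skip the backward step.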
void backward_network_gpu(network net, network_state state)
{
    state.workspace = net.workspace;
    int i;
    float * original_input = state.input;
    float * original_delta = state.delta;
    for(i = net.n-1; i >= 0; --i){
        state.index = i;
        layer l = net.layers[i];
        if (l.stopbackward) break;
        if(i == 0){
            state.input = original_input;
            state.delta = original_delta;
        }else{
            layer prev = net.layers[i-1];
            state.input = prev.output_gpu;
            state.delta = prev.delta_gpu;
            if (net.optimized_memory && !prev.keep_delta_gpu) {
                state.delta = net.state_delta_gpu;
            }
        }
        if (l.onlyforward) continue;
        l.backward_gpu(l, state);

        if (i != 0) {
            layer prev = net.layers[i - 1];
            if (net.optimized_memory && state.delta && !prev.keep_delta_gpu) {
                if (prev.delta_gpu != state.delta) simple_copy_ongpu(prev.outputs*prev.batch, state.delta, prev.delta_gpu);
                fill_ongpu(prev.outputs*prev.batch, 0, net.state_delta_gpu, 1);
            }
        }

        /*
        if(i != 0)
        {
            layer l = net.layers[i - 1];
            int state_delta_nan_inf = is_nan_or_inf(state.delta, l.outputs * l.batch);
            int state_input_nan_inf = is_nan_or_inf(state.input, l.outputs * l.batch);
            printf("\n i - %d is_nan_or_inf(s.delta) = %d \n", i, state_delta_nan_inf);
            printf(" i - %d is_nan_or_inf(s.input) = %d \n", i, state_input_nan_inf);
            if (state_delta_nan_inf || state_input_nan_inf) { printf(" found "); getchar(); }
        }
        */
    }
}
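
// Applies the weight update to every layer that defines update_gpu, passing the current
// learning rate, momentum, and decay; the l.deform flag is cleared once training passes
// half of max_batches iterations.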
void update_network_gpu(network net)
{
    cuda_set_device(net.gpu_index);
    const int iteration_num = (*net.seen) / (net.batch * net.subdivisions);
    int i;
    int update_batch = net.batch*net.subdivisions * get_sequence_value(net);
    float rate = get_current_rate(net);
    for(i = 0; i < net.n; ++i){
        layer l = net.layers[i];
        l.t = get_current_batch(net);
        if (iteration_num > (net.max_batches * 1 / 2)) l.deform = 0;
        if(l.update_gpu){
            l.update_gpu(l, update_batch, rate, net.momentum, net.decay);
        }
    }
}
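
// One training forward + backward pass: pushes input x and truth y to the GPU (allocating
// the device buffers on first use), and when CUDNN_HALF is enabled converts the FP32
// weights of CONVOLUTIONAL / CRNN / CONV_LSTM layers to FP16 before the forward pass.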
void forward_backward_network_gpu(network net, float *x, float *y)
{
    network_state state;
    state.index = 0;
    state.net = net;
    int x_size = get_network_input_size(net)*net.batch;
    int y_size = get_network_output_size(net)*net.batch;
    if(net.layers[net.n-1].truths) y_size = net.layers[net.n-1].truths*net.batch;
    if(!*net.input_gpu){
        *net.input_gpu = cuda_make_array(x, x_size);
        *net.truth_gpu = cuda_make_array(y, y_size);
    }else{
        cuda_push_array(*net.input_gpu, x, x_size);
        cuda_push_array(*net.truth_gpu, y, y_size);
    }
    state.input = *net.input_gpu;
    state.delta = 0;
    state.truth = *net.truth_gpu;
    state.train = 1;
#if defined(CUDNN_HALF) && defined(CUDNN)
    int i;
    for (i = 0; i < net.n; ++i) {
        layer l = net.layers[i];
        if (net.cudnn_half){
            if (l.type == CONVOLUTIONAL && l.weights_gpu && l.weights_gpu16) {
                assert((l.nweights) > 0);
                cuda_convert_f32_to_f16(l.weights_gpu, l.nweights, l.weights_gpu16);
            }
            else if (l.type == CRNN && l.input_layer->weights_gpu && l.input_layer->weights_gpu16) {
                assert((l.input_layer->c*l.input_layer->n*l.input_layer->size*l.input_layer->size) > 0);
                cuda_convert_f32_to_f16(l.input_layer->weights_gpu, l.input_layer->nweights, l.input_layer->weights_gpu16);
                cuda_convert_f32_to_f16(l.self_layer->weights_gpu, l.self_layer->nweights, l.self_layer->weights_gpu16);
                cuda_convert_f32_to_f16(l.output_layer->weights_gpu, l.output_layer->nweights, l.output_layer->weights_gpu16);
            }
            else if (l.type == CONV_LSTM && l.wf->weights_gpu && l.wf->weights_gpu16) {
                assert((l.wf->c * l.wf->n * l.wf->size * l.wf->size) > 0);
                if (l.peephole) {
                    cuda_convert_f32_to_f16(l.vf->weights_gpu, l.vf->nweights, l.vf->weights_gpu16);
                    cuda_convert_f32_to_f16(l.vi->weights_gpu, l.vi->nweights, l.vi->weights_gpu16);
                    cuda_convert_f32_to_f16(l.vo->weights_gpu, l.vo->nweights, l.vo->weights_gpu16);
                }
                cuda_convert_f32_to_f16(l.wf->weights_gpu, l.wf->nweights, l.wf->weights_gpu16);
                cuda_convert_f32_to_f16(l.wi->weights_gpu, l.wi->nweights, l.wi->weights_gpu16);
                cuda_convert_f32_to_f16(l.wg->weights_gpu, l.wg->nweights, l.wg->weights_gpu16);
                cuda_convert_f32_to_f16(l.wo->weights_gpu, l.wo->nweights, l.wo->weights_gpu16);
                cuda_convert_f32_to_f16(l.uf->weights_gpu, l.uf->nweights, l.uf->weights_gpu16);
                cuda_convert_f32_to_f16(l.ui->weights_gpu, l.ui->nweights, l.ui->weights_gpu16);
                cuda_convert_f32_to_f16(l.ug->weights_gpu, l.ug->nweights, l.ug->weights_gpu16);
                cuda_convert_f32_to_f16(l.uo->weights_gpu, l.uo->nweights, l.uo->weights_gpu16);
            }
        }
    }
#endif
    forward_network_gpu(net, state);
    //cudaStreamSynchronize(get_cuda_stream());
    backward_network_gpu(net, state);
}
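
// Trains on one mini-batch: forward/backward pass, then a weight update once
// net.subdivisions * sequence mini-batches have been accumulated. Returns the network cost.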
float train_network_datum_gpu(network net, float *x, float *y)
{
    *net.seen += net.batch;
    forward_backward_network_gpu(net, x, y);
    float error = get_network_cost(net);
    //if (((*net.seen) / net.batch) % net.subdivisions == 0) update_network_gpu(net);
    const int sequence = get_sequence_value(net);
    if (((*net.seen) / net.batch) % (net.subdivisions*sequence) == 0) update_network_gpu(net);

    return error;
}

typedef struct {
    network net;
    data d;
    float *err;
} train_args;

void *train_thread(void *ptr)
{
    train_args args = *(train_args*)ptr;
    free(ptr);
    cuda_set_device(args.net.gpu_index);
    *args.err = train_network(args.net, args.d);
    return 0;
}

pthread_t train_network_in_thread(network net, data d, float *err)
{
    pthread_t thread;
    train_args *ptr = (train_args *)calloc(1, sizeof(train_args));
    ptr->net = net;
    ptr->d = d;
    ptr->err = err;
    if(pthread_create(&thread, 0, train_thread, ptr)) error("Thread creation failed");
    return thread;
}
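
// Per-layer helpers used by multi-GPU training: pull/push copy weights and accumulated
// updates between the GPU and host buffers; merge/scale/distribute combine them across
// network replicas on the host and write the result back to each GPU.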
void pull_updates(layer l)
{
    if(l.type == CONVOLUTIONAL){
        cuda_pull_array(l.bias_updates_gpu, l.bias_updates, l.n);
        cuda_pull_array(l.weight_updates_gpu, l.weight_updates, l.nweights);
        if(l.scale_updates) cuda_pull_array(l.scale_updates_gpu, l.scale_updates, l.n);
    } else if(l.type == CONNECTED){
        cuda_pull_array(l.bias_updates_gpu, l.bias_updates, l.outputs);
        cuda_pull_array(l.weight_updates_gpu, l.weight_updates, l.outputs*l.inputs);
    }
}

void push_updates(layer l)
{
    if(l.type == CONVOLUTIONAL){
        cuda_push_array(l.bias_updates_gpu, l.bias_updates, l.n);
        cuda_push_array(l.weight_updates_gpu, l.weight_updates, l.nweights);
        if(l.scale_updates) cuda_push_array(l.scale_updates_gpu, l.scale_updates, l.n);
    } else if(l.type == CONNECTED){
        cuda_push_array(l.bias_updates_gpu, l.bias_updates, l.outputs);
        cuda_push_array(l.weight_updates_gpu, l.weight_updates, l.outputs*l.inputs);
    }
}

void update_layer(layer l, network net)
{
    int update_batch = net.batch*net.subdivisions;
    float rate = get_current_rate(net);
    l.t = get_current_batch(net);
    if(l.update_gpu){
        l.update_gpu(l, update_batch, rate, net.momentum, net.decay);
    }
}

void merge_weights(layer l, layer base)
{
    if (l.type == CONVOLUTIONAL) {
        axpy_cpu(l.n, 1, l.biases, 1, base.biases, 1);
        axpy_cpu(l.nweights, 1, l.weights, 1, base.weights, 1);
        if (l.scales) {
            axpy_cpu(l.n, 1, l.scales, 1, base.scales, 1);
        }
    } else if(l.type == CONNECTED) {
        axpy_cpu(l.outputs, 1, l.biases, 1, base.biases, 1);
        axpy_cpu(l.outputs*l.inputs, 1, l.weights, 1, base.weights, 1);
    }
}

void scale_weights(layer l, float s)
{
    if (l.type == CONVOLUTIONAL) {
        scal_cpu(l.n, s, l.biases, 1);
        scal_cpu(l.nweights, s, l.weights, 1);
        if (l.scales) {
            scal_cpu(l.n, s, l.scales, 1);
        }
    } else if(l.type == CONNECTED) {
        scal_cpu(l.outputs, s, l.biases, 1);
        scal_cpu(l.outputs*l.inputs, s, l.weights, 1);
    }
}

void pull_weights(layer l)
{
    if(l.type == CONVOLUTIONAL){
        cuda_pull_array(l.biases_gpu, l.biases, l.n);
        cuda_pull_array(l.weights_gpu, l.weights, l.nweights);
        if(l.scales) cuda_pull_array(l.scales_gpu, l.scales, l.n);
    } else if(l.type == CONNECTED){
        cuda_pull_array(l.biases_gpu, l.biases, l.outputs);
        cuda_pull_array(l.weights_gpu, l.weights, l.outputs*l.inputs);
    }
}

void push_weights(layer l)
{
    if(l.type == CONVOLUTIONAL){
        cuda_push_array(l.biases_gpu, l.biases, l.n);
        cuda_push_array(l.weights_gpu, l.weights, l.nweights);
        if(l.scales) cuda_push_array(l.scales_gpu, l.scales, l.n);
    } else if(l.type == CONNECTED){
        cuda_push_array(l.biases_gpu, l.biases, l.outputs);
        cuda_push_array(l.weights_gpu, l.weights, l.outputs*l.inputs);
    }
}

void distribute_weights(layer l, layer base)
{
    if(l.type == CONVOLUTIONAL){
        cuda_push_array(l.biases_gpu, base.biases, l.n);
        cuda_push_array(l.weights_gpu, base.weights, l.nweights);
        if(base.scales) cuda_push_array(l.scales_gpu, base.scales, l.n);
    } else if(l.type == CONNECTED){
        cuda_push_array(l.biases_gpu, base.biases, l.outputs);
        cuda_push_array(l.weights_gpu, base.weights, l.outputs*l.inputs);
    }
}

void merge_updates(layer l, layer base)
{
    if (l.type == CONVOLUTIONAL) {
        axpy_cpu(l.n, 1, l.bias_updates, 1, base.bias_updates, 1);
        axpy_cpu(l.nweights, 1, l.weight_updates, 1, base.weight_updates, 1);
        if (l.scale_updates) {
            axpy_cpu(l.n, 1, l.scale_updates, 1, base.scale_updates, 1);
        }
    } else if(l.type == CONNECTED) {
        axpy_cpu(l.outputs, 1, l.bias_updates, 1, base.bias_updates, 1);
        axpy_cpu(l.outputs*l.inputs, 1, l.weight_updates, 1, base.weight_updates, 1);
    }
}

void distribute_updates(layer l, layer base)
{
    if(l.type == CONVOLUTIONAL){
        cuda_push_array(l.bias_updates_gpu, base.bias_updates, l.n);
        cuda_push_array(l.weight_updates_gpu, base.weight_updates, l.nweights);
        if(base.scale_updates) cuda_push_array(l.scale_updates_gpu, base.scale_updates, l.n);
    } else if(l.type == CONNECTED){
        cuda_push_array(l.bias_updates_gpu, base.bias_updates, l.outputs);
        cuda_push_array(l.weight_updates_gpu, base.weight_updates, l.outputs*l.inputs);
    }
}
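
// Averages the weights of layer j across all n network replicas: pulls each replica's
// weights to the host, sums them into nets[0]'s copy, scales by 1/n, then pushes the
// averaged weights back to every GPU.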
void sync_layer(network *nets, int n, int j)
{
    //printf("Syncing layer %d\n", j);
    int i;
    network net = nets[0];
    layer base = net.layers[j];
    cuda_set_device(net.gpu_index);
    pull_weights(base);
    for (i = 1; i < n; ++i) {
        cuda_set_device(nets[i].gpu_index);
        layer l = nets[i].layers[j];
        pull_weights(l);
        merge_weights(l, base);
    }
    scale_weights(base, 1./n);
    for (i = 0; i < n; ++i) {
        cuda_set_device(nets[i].gpu_index);
        layer l = nets[i].layers[j];
        distribute_weights(l, base);
    }
    //printf("Done syncing layer %d\n", j);
}

typedef struct{
    network *nets;
    int n;
    int j;
} sync_args;

void *sync_layer_thread(void *ptr)
{
    sync_args args = *(sync_args*)ptr;
    sync_layer(args.nets, args.n, args.j);
    free(ptr);
    return 0;
}

pthread_t sync_layer_in_thread(network *nets, int n, int j)
{
    pthread_t thread;
    sync_args *ptr = (sync_args *)calloc(1, sizeof(sync_args));
    ptr->nets = nets;
    ptr->n = n;
    ptr->j = j;
    if(pthread_create(&thread, 0, sync_layer_thread, ptr)) error("Thread creation failed");
    return thread;
}
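
// Synchronizes all layers of n replicas, one pthread per layer. The shared *seen counter
// is advanced to account for images processed on the other replicas since the last sync
// and copied to every replica so they agree on the iteration count.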
void sync_nets(network *nets, int n, int interval)
{
    int j;
    int layers = nets[0].n;
    pthread_t *threads = (pthread_t *) calloc(layers, sizeof(pthread_t));

    *nets[0].seen += interval * (n-1) * nets[0].batch * nets[0].subdivisions;
    for (j = 0; j < n; ++j){
        *nets[j].seen = *nets[0].seen;
    }
    for (j = 0; j < layers; ++j) {
        threads[j] = sync_layer_in_thread(nets, n, j);
    }
    for (j = 0; j < layers; ++j) {
        pthread_join(threads[j], 0);
    }
    free(threads);
}
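
// Multi-GPU training step: splits the data batch d across n replicas, trains each replica
// in its own thread, averages the reported errors, and re-synchronizes weights every
// `interval` batches.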
float train_networks(network *nets, int n, data d, int interval)
{
    int i;
#ifdef _DEBUG
    int batch = nets[0].batch;
    int subdivisions = nets[0].subdivisions;
    assert(batch * subdivisions * n == d.X.rows);
#endif
    pthread_t *threads = (pthread_t *) calloc(n, sizeof(pthread_t));
    float *errors = (float *) calloc(n, sizeof(float));

    float sum = 0;
    for(i = 0; i < n; ++i){
        data p = get_data_part(d, i, n);
        threads[i] = train_network_in_thread(nets[i], p, errors + i);
    }
    for(i = 0; i < n; ++i){
        pthread_join(threads[i], 0);
        //printf("%f\n", errors[i]);
        sum += errors[i];
    }
    //cudaDeviceSynchronize();
    if (get_current_batch(nets[0]) % interval == 0) {
        printf("Syncing... ");
        fflush(stdout);
        sync_nets(nets, n, interval);
        printf("Done!\n");
    }
    //cudaDeviceSynchronize();
    free(threads);
    free(errors);
    return (float)sum/(n);
}

float *get_network_output_layer_gpu(network net, int i)
{
    layer l = net.layers[i];
    if(l.type != REGION) cuda_pull_array(l.output_gpu, l.output, l.outputs*l.batch);
    return l.output;
}

float *get_network_output_gpu(network net)
{
    int i;
    for(i = net.n-1; i > 0; --i) if(net.layers[i].type != COST) break;
    return get_network_output_layer_gpu(net, i);
}
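
// Inference entry point: copies the input into pinned host memory, pushes it to the GPU,
// runs a forward pass with train = 0, and returns a host pointer to the output of the
// last non-COST layer.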
float *network_predict_gpu(network net, float *input)
{
    if (net.gpu_index != cuda_get_device())
        cuda_set_device(net.gpu_index);
    int size = get_network_input_size(net) * net.batch;
    network_state state;
    state.index = 0;
    state.net = net;
    //state.input = cuda_make_array(input, size);   // memory will be allocated in the parse_network_cfg_custom()
    state.input = net.input_state_gpu;
    memcpy(net.input_pinned_cpu, input, size * sizeof(float));
    cuda_push_array(state.input, net.input_pinned_cpu, size);
    state.truth = 0;
    state.train = 0;
    state.delta = 0;
    forward_network_gpu(net, state);
    float *out = get_network_output_gpu(net);
    //cuda_free(state.input);   // will be freed in the free_network()
    return out;
}