reval_voc_py3.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. #!/usr/bin/env python
  2. # Adapted from ->
  3. # --------------------------------------------------------
  4. # Fast R-CNN
  5. # Copyright (c) 2015 Microsoft
  6. # Licensed under The MIT License [see LICENSE for details]
  7. # Written by Ross Girshick
  8. # --------------------------------------------------------
  9. # <- Written by Yaping Sun
  10. """Reval = re-eval. Re-evaluate saved detections."""
  11. import os, sys, argparse
  12. import numpy as np
  13. import _pickle as cPickle
  14. #import cPickle
  15. from voc_eval_py3 import voc_eval
  16. def parse_args():
  17. """
  18. Parse input arguments
  19. """
  20. parser = argparse.ArgumentParser(description='Re-evaluate results')
  21. parser.add_argument('output_dir', nargs=1, help='results directory',
  22. type=str)
  23. parser.add_argument('--voc_dir', dest='voc_dir', default='data/VOCdevkit', type=str)
  24. parser.add_argument('--year', dest='year', default='2017', type=str)
  25. parser.add_argument('--image_set', dest='image_set', default='test', type=str)
  26. parser.add_argument('--classes', dest='class_file', default='data/voc.names', type=str)
  27. if len(sys.argv) == 1:
  28. parser.print_help()
  29. sys.exit(1)
  30. args = parser.parse_args()
  31. return args
  32. def get_voc_results_file_template(image_set, out_dir = 'results'):
  33. filename = 'comp4_det_' + image_set + '_{:s}.txt'
  34. path = os.path.join(out_dir, filename)
  35. return path
  36. def do_python_eval(devkit_path, year, image_set, classes, output_dir = 'results'):
  37. annopath = os.path.join(
  38. devkit_path,
  39. 'VOC' + year,
  40. 'Annotations',
  41. '{}.xml')
  42. imagesetfile = os.path.join(
  43. devkit_path,
  44. 'VOC' + year,
  45. 'ImageSets',
  46. 'Main',
  47. image_set + '.txt')
  48. cachedir = os.path.join(devkit_path, 'annotations_cache')
  49. aps = []
  50. # The PASCAL VOC metric changed in 2010
  51. use_07_metric = True if int(year) < 2010 else False
  52. print('VOC07 metric? ' + ('Yes' if use_07_metric else 'No'))
  53. print('devkit_path=',devkit_path,', year = ',year)
  54. if not os.path.isdir(output_dir):
  55. os.mkdir(output_dir)
  56. for i, cls in enumerate(classes):
  57. if cls == '__background__':
  58. continue
  59. filename = get_voc_results_file_template(image_set).format(cls)
  60. rec, prec, ap = voc_eval(
  61. filename, annopath, imagesetfile, cls, cachedir, ovthresh=0.5,
  62. use_07_metric=use_07_metric)
  63. aps += [ap]
  64. print('AP for {} = {:.4f}'.format(cls, ap))
  65. with open(os.path.join(output_dir, cls + '_pr.pkl'), 'wb') as f:
  66. cPickle.dump({'rec': rec, 'prec': prec, 'ap': ap}, f)
  67. print('Mean AP = {:.4f}'.format(np.mean(aps)))
  68. print('~~~~~~~~')
  69. print('Results:')
  70. for ap in aps:
  71. print('{:.3f}'.format(ap))
  72. print('{:.3f}'.format(np.mean(aps)))
  73. print('~~~~~~~~')
  74. print('')
  75. print('--------------------------------------------------------------')
  76. print('Results computed with the **unofficial** Python eval code.')
  77. print('Results should be very close to the official MATLAB eval code.')
  78. print('-- Thanks, The Management')
  79. print('--------------------------------------------------------------')
  80. if __name__ == '__main__':
  81. args = parse_args()
  82. output_dir = os.path.abspath(args.output_dir[0])
  83. with open(args.class_file, 'r') as f:
  84. lines = f.readlines()
  85. classes = [t.strip('\n') for t in lines]
  86. print('Evaluating detections')
  87. do_python_eval(args.voc_dir, args.year, args.image_set, classes, output_dir)