reval_voc.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. #!/usr/bin/env python
  2. # Adapted from ->
  3. # --------------------------------------------------------
  4. # Fast R-CNN
  5. # Copyright (c) 2015 Microsoft
  6. # Licensed under The MIT License [see LICENSE for details]
  7. # Written by Ross Girshick
  8. # --------------------------------------------------------
  9. # <- Written by Yaping Sun
  10. """Reval = re-eval. Re-evaluate saved detections."""
  11. import os, sys, argparse
  12. import numpy as np
  13. import cPickle
  14. from voc_eval import voc_eval
  15. def parse_args():
  16. """
  17. Parse input arguments
  18. """
  19. parser = argparse.ArgumentParser(description='Re-evaluate results')
  20. parser.add_argument('output_dir', nargs=1, help='results directory',
  21. type=str)
  22. parser.add_argument('--voc_dir', dest='voc_dir', default='data/VOCdevkit', type=str)
  23. parser.add_argument('--year', dest='year', default='2017', type=str)
  24. parser.add_argument('--image_set', dest='image_set', default='test', type=str)
  25. parser.add_argument('--classes', dest='class_file', default='data/voc.names', type=str)
  26. if len(sys.argv) == 1:
  27. parser.print_help()
  28. sys.exit(1)
  29. args = parser.parse_args()
  30. return args
  31. def get_voc_results_file_template(image_set, out_dir = 'results'):
  32. filename = 'comp4_det_' + image_set + '_{:s}.txt'
  33. path = os.path.join(out_dir, filename)
  34. return path
  35. def do_python_eval(devkit_path, year, image_set, classes, output_dir = 'results'):
  36. annopath = os.path.join(
  37. devkit_path,
  38. 'VOC' + year,
  39. 'Annotations',
  40. '{:s}.xml')
  41. imagesetfile = os.path.join(
  42. devkit_path,
  43. 'VOC' + year,
  44. 'ImageSets',
  45. 'Main',
  46. image_set + '.txt')
  47. cachedir = os.path.join(devkit_path, 'annotations_cache')
  48. aps = []
  49. # The PASCAL VOC metric changed in 2010
  50. use_07_metric = True if int(year) < 2010 else False
  51. print 'VOC07 metric? ' + ('Yes' if use_07_metric else 'No')
  52. if not os.path.isdir(output_dir):
  53. os.mkdir(output_dir)
  54. for i, cls in enumerate(classes):
  55. if cls == '__background__':
  56. continue
  57. filename = get_voc_results_file_template(image_set).format(cls)
  58. rec, prec, ap = voc_eval(
  59. filename, annopath, imagesetfile, cls, cachedir, ovthresh=0.5,
  60. use_07_metric=use_07_metric)
  61. aps += [ap]
  62. print('AP for {} = {:.4f}'.format(cls, ap))
  63. with open(os.path.join(output_dir, cls + '_pr.pkl'), 'w') as f:
  64. cPickle.dump({'rec': rec, 'prec': prec, 'ap': ap}, f)
  65. print('Mean AP = {:.4f}'.format(np.mean(aps)))
  66. print('~~~~~~~~')
  67. print('Results:')
  68. for ap in aps:
  69. print('{:.3f}'.format(ap))
  70. print('{:.3f}'.format(np.mean(aps)))
  71. print('~~~~~~~~')
  72. print('')
  73. print('--------------------------------------------------------------')
  74. print('Results computed with the **unofficial** Python eval code.')
  75. print('Results should be very close to the official MATLAB eval code.')
  76. print('-- Thanks, The Management')
  77. print('--------------------------------------------------------------')
  78. if __name__ == '__main__':
  79. args = parse_args()
  80. output_dir = os.path.abspath(args.output_dir[0])
  81. with open(args.class_file, 'r') as f:
  82. lines = f.readlines()
  83. classes = [t.strip('\n') for t in lines]
  84. print 'Evaluating detections'
  85. do_python_eval(args.voc_dir, args.year, args.image_set, classes, output_dir)