gen_anchors.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. '''
  2. Created on Feb 20, 2017
  3. @author: jumabek
  4. '''
  5. from os import listdir
  6. from os.path import isfile, join
  7. import argparse
  8. #import cv2
  9. import numpy as np
  10. import sys
  11. import os
  12. import shutil
  13. import random
  14. import math
  15. width_in_cfg_file = 416.
  16. height_in_cfg_file = 416.
  17. def IOU(x,centroids):
  18. similarities = []
  19. k = len(centroids)
  20. for centroid in centroids:
  21. c_w,c_h = centroid
  22. w,h = x
  23. if c_w>=w and c_h>=h:
  24. similarity = w*h/(c_w*c_h)
  25. elif c_w>=w and c_h<=h:
  26. similarity = w*c_h/(w*h + (c_w-w)*c_h)
  27. elif c_w<=w and c_h>=h:
  28. similarity = c_w*h/(w*h + c_w*(c_h-h))
  29. else: #means both w,h are bigger than c_w and c_h respectively
  30. similarity = (c_w*c_h)/(w*h)
  31. similarities.append(similarity) # will become (k,) shape
  32. return np.array(similarities)
  33. def avg_IOU(X,centroids):
  34. n,d = X.shape
  35. sum = 0.
  36. for i in range(X.shape[0]):
  37. #note IOU() will return array which contains IoU for each centroid and X[i] // slightly ineffective, but I am too lazy
  38. sum+= max(IOU(X[i],centroids))
  39. return sum/n
  40. def write_anchors_to_file(centroids,X,anchor_file):
  41. f = open(anchor_file,'w')
  42. anchors = centroids.copy()
  43. print(anchors.shape)
  44. for i in range(anchors.shape[0]):
  45. anchors[i][0]*=width_in_cfg_file/32.
  46. anchors[i][1]*=height_in_cfg_file/32.
  47. widths = anchors[:,0]
  48. sorted_indices = np.argsort(widths)
  49. print('Anchors = ', anchors[sorted_indices])
  50. for i in sorted_indices[:-1]:
  51. f.write('%0.2f,%0.2f, '%(anchors[i,0],anchors[i,1]))
  52. #there should not be comma after last anchor, that's why
  53. f.write('%0.2f,%0.2f\n'%(anchors[sorted_indices[-1:],0],anchors[sorted_indices[-1:],1]))
  54. f.write('%f\n'%(avg_IOU(X,centroids)))
  55. print()
  56. def kmeans(X,centroids,eps,anchor_file):
  57. N = X.shape[0]
  58. iterations = 0
  59. k,dim = centroids.shape
  60. prev_assignments = np.ones(N)*(-1)
  61. iter = 0
  62. old_D = np.zeros((N,k))
  63. while True:
  64. D = []
  65. iter+=1
  66. for i in range(N):
  67. d = 1 - IOU(X[i],centroids)
  68. D.append(d)
  69. D = np.array(D) # D.shape = (N,k)
  70. print("iter {}: dists = {}".format(iter,np.sum(np.abs(old_D-D))))
  71. #assign samples to centroids
  72. assignments = np.argmin(D,axis=1)
  73. if (assignments == prev_assignments).all() :
  74. print("Centroids = ",centroids)
  75. write_anchors_to_file(centroids,X,anchor_file)
  76. return
  77. #calculate new centroids
  78. centroid_sums=np.zeros((k,dim),np.float)
  79. for i in range(N):
  80. centroid_sums[assignments[i]]+=X[i]
  81. for j in range(k):
  82. centroids[j] = centroid_sums[j]/(np.sum(assignments==j))
  83. prev_assignments = assignments.copy()
  84. old_D = D.copy()
  85. def main(argv):
  86. parser = argparse.ArgumentParser()
  87. parser.add_argument('-filelist', default = '\\path\\to\\voc\\filelist\\train.txt',
  88. help='path to filelist\n' )
  89. parser.add_argument('-output_dir', default = 'generated_anchors/anchors', type = str,
  90. help='Output anchor directory\n' )
  91. parser.add_argument('-num_clusters', default = 0, type = int,
  92. help='number of clusters\n' )
  93. args = parser.parse_args()
  94. if not os.path.exists(args.output_dir):
  95. os.mkdir(args.output_dir)
  96. f = open(args.filelist)
  97. lines = [line.rstrip('\n') for line in f.readlines()]
  98. annotation_dims = []
  99. size = np.zeros((1,1,3))
  100. for line in lines:
  101. #line = line.replace('images','labels')
  102. #line = line.replace('img1','labels')
  103. line = line.replace('JPEGImages','labels')
  104. line = line.replace('.jpg','.txt')
  105. line = line.replace('.png','.txt')
  106. print(line)
  107. f2 = open(line)
  108. for line in f2.readlines():
  109. line = line.rstrip('\n')
  110. w,h = line.split(' ')[3:]
  111. #print(w,h)
  112. annotation_dims.append(tuple(map(float,(w,h))))
  113. annotation_dims = np.array(annotation_dims)
  114. eps = 0.005
  115. if args.num_clusters == 0:
  116. for num_clusters in range(1,11): #we make 1 through 10 clusters
  117. anchor_file = join( args.output_dir,'anchors%d.txt'%(num_clusters))
  118. indices = [ random.randrange(annotation_dims.shape[0]) for i in range(num_clusters)]
  119. centroids = annotation_dims[indices]
  120. kmeans(annotation_dims,centroids,eps,anchor_file)
  121. print('centroids.shape', centroids.shape)
  122. else:
  123. anchor_file = join( args.output_dir,'anchors%d.txt'%(args.num_clusters))
  124. indices = [ random.randrange(annotation_dims.shape[0]) for i in range(args.num_clusters)]
  125. centroids = annotation_dims[indices]
  126. kmeans(annotation_dims,centroids,eps,anchor_file)
  127. print('centroids.shape', centroids.shape)
  128. if __name__=="__main__":
  129. main(sys.argv)