evaluate_hr.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. """
  2. HR (High-Resolution) evaluation. We found using numpy is very slow for high resolution, so we moved it to PyTorch using CUDA.
  3. Note, the script only does evaluation. You will need to first inference yourself and save the results to disk
  4. Expected directory format for both prediction and ground-truth is:
  5. videomatte_1920x1080
  6. ├── videomatte_motion
  7. ├── pha
  8. ├── 0000
  9. ├── 0000.png
  10. ├── fgr
  11. ├── 0000
  12. ├── 0000.png
  13. ├── videomatte_static
  14. ├── pha
  15. ├── 0000
  16. ├── 0000.png
  17. ├── fgr
  18. ├── 0000
  19. ├── 0000.png
  20. Prediction must have the exact file structure and file name as the ground-truth,
  21. meaning that if the ground-truth is png/jpg, prediction should be png/jpg.
  22. Example usage:
  23. python evaluate_hr.py \
  24. --pred-dir pred/videomatte_1920x1080 \
  25. --true-dir true/videomatte_1920x1080
  26. An excel sheet with evaluation results will be written to "pred/videomatte_1920x1080/videomatte_1920x1080.xlsx"
  27. """
  28. import argparse
  29. import os
  30. import cv2
  31. import kornia
  32. import numpy as np
  33. import xlsxwriter
  34. import torch
  35. from concurrent.futures import ThreadPoolExecutor
  36. from tqdm import tqdm
  37. class Evaluator:
  38. def __init__(self):
  39. self.parse_args()
  40. self.init_metrics()
  41. self.evaluate()
  42. self.write_excel()
  43. def parse_args(self):
  44. parser = argparse.ArgumentParser()
  45. parser.add_argument('--pred-dir', type=str, required=True)
  46. parser.add_argument('--true-dir', type=str, required=True)
  47. parser.add_argument('--num-workers', type=int, default=48)
  48. parser.add_argument('--metrics', type=str, nargs='+', default=[
  49. 'pha_mad', 'pha_mse', 'pha_grad', 'pha_dtssd', 'fgr_mse'])
  50. self.args = parser.parse_args()
  51. def init_metrics(self):
  52. self.mad = MetricMAD()
  53. self.mse = MetricMSE()
  54. self.grad = MetricGRAD()
  55. self.dtssd = MetricDTSSD()
  56. def evaluate(self):
  57. tasks = []
  58. position = 0
  59. with ThreadPoolExecutor(max_workers=self.args.num_workers) as executor:
  60. for dataset in sorted(os.listdir(self.args.pred_dir)):
  61. if os.path.isdir(os.path.join(self.args.pred_dir, dataset)):
  62. for clip in sorted(os.listdir(os.path.join(self.args.pred_dir, dataset))):
  63. future = executor.submit(self.evaluate_worker, dataset, clip, position)
  64. tasks.append((dataset, clip, future))
  65. position += 1
  66. self.results = [(dataset, clip, future.result()) for dataset, clip, future in tasks]
  67. def write_excel(self):
  68. workbook = xlsxwriter.Workbook(os.path.join(self.args.pred_dir, f'{os.path.basename(self.args.pred_dir)}.xlsx'))
  69. summarysheet = workbook.add_worksheet('summary')
  70. metricsheets = [workbook.add_worksheet(metric) for metric in self.results[0][2].keys()]
  71. for i, metric in enumerate(self.results[0][2].keys()):
  72. summarysheet.write(i, 0, metric)
  73. summarysheet.write(i, 1, f'={metric}!B2')
  74. for row, (dataset, clip, metrics) in enumerate(self.results):
  75. for metricsheet, metric in zip(metricsheets, metrics.values()):
  76. # Write the header
  77. if row == 0:
  78. metricsheet.write(1, 0, 'Average')
  79. metricsheet.write(1, 1, f'=AVERAGE(C2:ZZ2)')
  80. for col in range(len(metric)):
  81. metricsheet.write(0, col + 2, col)
  82. colname = xlsxwriter.utility.xl_col_to_name(col + 2)
  83. metricsheet.write(1, col + 2, f'=AVERAGE({colname}3:{colname}9999)')
  84. metricsheet.write(row + 2, 0, dataset)
  85. metricsheet.write(row + 2, 1, clip)
  86. metricsheet.write_row(row + 2, 2, metric)
  87. workbook.close()
  88. def evaluate_worker(self, dataset, clip, position):
  89. framenames = sorted(os.listdir(os.path.join(self.args.pred_dir, dataset, clip, 'pha')))
  90. metrics = {metric_name : [] for metric_name in self.args.metrics}
  91. pred_pha_tm1 = None
  92. true_pha_tm1 = None
  93. for i, framename in enumerate(tqdm(framenames, desc=f'{dataset} {clip}', position=position, dynamic_ncols=True)):
  94. true_pha = cv2.imread(os.path.join(self.args.true_dir, dataset, clip, 'pha', framename), cv2.IMREAD_GRAYSCALE)
  95. pred_pha = cv2.imread(os.path.join(self.args.pred_dir, dataset, clip, 'pha', framename), cv2.IMREAD_GRAYSCALE)
  96. true_pha = torch.from_numpy(true_pha).cuda(non_blocking=True).float().div_(255)
  97. pred_pha = torch.from_numpy(pred_pha).cuda(non_blocking=True).float().div_(255)
  98. if 'pha_mad' in self.args.metrics:
  99. metrics['pha_mad'].append(self.mad(pred_pha, true_pha))
  100. if 'pha_mse' in self.args.metrics:
  101. metrics['pha_mse'].append(self.mse(pred_pha, true_pha))
  102. if 'pha_grad' in self.args.metrics:
  103. metrics['pha_grad'].append(self.grad(pred_pha, true_pha))
  104. if 'pha_conn' in self.args.metrics:
  105. metrics['pha_conn'].append(self.conn(pred_pha, true_pha))
  106. if 'pha_dtssd' in self.args.metrics:
  107. if i == 0:
  108. metrics['pha_dtssd'].append(0)
  109. else:
  110. metrics['pha_dtssd'].append(self.dtssd(pred_pha, pred_pha_tm1, true_pha, true_pha_tm1))
  111. pred_pha_tm1 = pred_pha
  112. true_pha_tm1 = true_pha
  113. if 'fgr_mse' in self.args.metrics:
  114. true_fgr = cv2.imread(os.path.join(self.args.true_dir, dataset, clip, 'fgr', framename), cv2.IMREAD_COLOR)
  115. pred_fgr = cv2.imread(os.path.join(self.args.pred_dir, dataset, clip, 'fgr', framename), cv2.IMREAD_COLOR)
  116. true_fgr = torch.from_numpy(true_fgr).float().div_(255)
  117. pred_fgr = torch.from_numpy(pred_fgr).float().div_(255)
  118. true_msk = true_pha > 0
  119. metrics['fgr_mse'].append(self.mse(pred_fgr[true_msk], true_fgr[true_msk]))
  120. return metrics
  121. class MetricMAD:
  122. def __call__(self, pred, true):
  123. return (pred - true).abs_().mean() * 1e3
  124. class MetricMSE:
  125. def __call__(self, pred, true):
  126. return ((pred - true) ** 2).mean() * 1e3
  127. class MetricGRAD:
  128. def __init__(self, sigma=1.4):
  129. self.filter_x, self.filter_y = self.gauss_filter(sigma)
  130. self.filter_x = torch.from_numpy(self.filter_x).unsqueeze(0).cuda()
  131. self.filter_y = torch.from_numpy(self.filter_y).unsqueeze(0).cuda()
  132. def __call__(self, pred, true):
  133. true_grad = self.gauss_gradient(true)
  134. pred_grad = self.gauss_gradient(pred)
  135. return ((true_grad - pred_grad) ** 2).sum() / 1000
  136. def gauss_gradient(self, img):
  137. img_filtered_x = kornia.filters.filter2D(img[None, None, :, :], self.filter_x, border_type='replicate')[0, 0]
  138. img_filtered_y = kornia.filters.filter2D(img[None, None, :, :], self.filter_y, border_type='replicate')[0, 0]
  139. return (img_filtered_x**2 + img_filtered_y**2).sqrt()
  140. @staticmethod
  141. def gauss_filter(sigma, epsilon=1e-2):
  142. half_size = np.ceil(sigma * np.sqrt(-2 * np.log(np.sqrt(2 * np.pi) * sigma * epsilon)))
  143. size = np.int(2 * half_size + 1)
  144. # create filter in x axis
  145. filter_x = np.zeros((size, size))
  146. for i in range(size):
  147. for j in range(size):
  148. filter_x[i, j] = MetricGRAD.gaussian(i - half_size, sigma) * MetricGRAD.dgaussian(
  149. j - half_size, sigma)
  150. # normalize filter
  151. norm = np.sqrt((filter_x**2).sum())
  152. filter_x = filter_x / norm
  153. filter_y = np.transpose(filter_x)
  154. return filter_x, filter_y
  155. @staticmethod
  156. def gaussian(x, sigma):
  157. return np.exp(-x**2 / (2 * sigma**2)) / (sigma * np.sqrt(2 * np.pi))
  158. @staticmethod
  159. def dgaussian(x, sigma):
  160. return -x * MetricGRAD.gaussian(x, sigma) / sigma**2
  161. class MetricDTSSD:
  162. def __call__(self, pred_t, pred_tm1, true_t, true_tm1):
  163. dtSSD = ((pred_t - pred_tm1) - (true_t - true_tm1)) ** 2
  164. dtSSD = dtSSD.sum() / true_t.numel()
  165. dtSSD = dtSSD.sqrt()
  166. return dtSSD * 1e2
  167. if __name__ == '__main__':
  168. Evaluator()