generate_imagematte_with_background_video.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  """
  Usage:

      python generate_imagematte_with_background_video.py \
          --imagematte-dir ../matting-data/Distinctions/test \
          --background-dir ../matting-data/BackgroundVideos_mp4/test \
          --resolution 512 \
          --out-dir ../matting-data/evaluation/distinction_motion_sd/ \
          --random-seed 11

  Seed (--random-seed) used for each dataset:

      10 - distinction-static
      11 - distinction-motion
      12 - adobe-static
      13 - adobe-motion
  """
  14. import argparse
  15. import os
  16. import pims
  17. import numpy as np
  18. import random
  19. from multiprocessing import Pool
  20. from PIL import Image
  21. # from tqdm import tqdm
  22. from tqdm.contrib.concurrent import process_map
  23. from torchvision import transforms
  24. from torchvision.transforms import functional as F
  # ---------------------------------------------------------------------------
  # Command-line interface. Parsed at import time because the worker function
  # `process` reads `args` and the shuffled filename lists as module globals.
  # ---------------------------------------------------------------------------
  parser = argparse.ArgumentParser()
  parser.add_argument('--imagematte-dir', type=str, required=True)
  parser.add_argument('--background-dir', type=str, required=True)
  parser.add_argument('--num-samples', type=int, default=20)
  parser.add_argument('--num-frames', type=int, default=100)
  parser.add_argument('--resolution', type=int, required=True)
  parser.add_argument('--out-dir', type=str, required=True)
  parser.add_argument('--random-seed', type=int)
  parser.add_argument('--extension', type=str, default='.png')
  args = parser.parse_args()

  # Seed Python's RNG so the two shuffles below are repeatable for a given
  # --random-seed. When the option is omitted the seed is None and Python
  # seeds itself from a non-deterministic source.
  random.seed(args.random_seed)

  # os.listdir() returns entries in arbitrary (filesystem-dependent) order.
  # Sort first so that a seeded shuffle produces the same sequence everywhere.
  imagematte_filenames = sorted(os.listdir(os.path.join(args.imagematte_dir, 'fgr')))
  random.shuffle(imagematte_filenames)

  # Fixed list of background clips (file names inside --background-dir).
  # The literal order is the input to the seeded shuffle, so it is kept as-is.
  background_filenames = [
      "0000.mp4",
      "0007.mp4",
      "0008.mp4",
      "0010.mp4",
      "0013.mp4",
      "0015.mp4",
      "0016.mp4",
      "0018.mp4",
      "0021.mp4",
      "0029.mp4",
      "0033.mp4",
      "0035.mp4",
      "0039.mp4",
      "0050.mp4",
      "0052.mp4",
      "0055.mp4",
      "0060.mp4",
      "0063.mp4",
      "0087.mp4",
      "0086.mp4",
      "0090.mp4",
      "0101.mp4",
      "0110.mp4",
      "0117.mp4",
      "0120.mp4",
      "0122.mp4",
      "0123.mp4",
      "0125.mp4",
      "0128.mp4",
      "0131.mp4",
      "0172.mp4",
      "0176.mp4",
      "0181.mp4",
      "0187.mp4",
      "0193.mp4",
      "0198.mp4",
      "0220.mp4",
      "0221.mp4",
      "0224.mp4",
      "0229.mp4",
      "0233.mp4",
      "0238.mp4",
      "0241.mp4",
      "0245.mp4",
      "0246.mp4",
  ]
  random.shuffle(background_filenames)
  def lerp(a, b, percentage):
      """Linearly interpolate between ``a`` and ``b``.

      Returns ``a`` when ``percentage`` is 0, ``b`` when it is 1, and a
      proportional blend of the two for values in between.
      """
      weight_a = 1 - percentage
      weight_b = percentage
      return a * weight_a + b * weight_b
  def motion_affine(*imgs):
      """Apply one shared, gradually changing random affine warp to frame lists.

      Two random affine parameter sets (A and B) are sampled once. Frame ``t``
      of every sequence is then warped with parameters interpolated from A
      toward B, so all sequences passed in receive exactly the same motion.
      How far the motion travels toward B by the last frame is itself random
      (``variation_over_time`` in [0, 1)).

      Args:
          *imgs: One or more lists of frames, indexed in lockstep. The image
              size is read from ``imgs[0][0].size`` (assumes PIL images —
              TODO confirm for other inputs). The lists are modified in place.

      Returns:
          The same ``imgs`` tuple, each frame replaced by its warped version.
          Empty input (no lists, or an empty first list) is returned untouched.
      """
      # Nothing to warp: avoid indexing imgs[0][0] on empty input.
      if not imgs or not imgs[0]:
          return imgs

      config = dict(degrees=(-10, 10), translate=(0.1, 0.1),
                    scale_ranges=(0.9, 1.1), shears=(-5, 5), img_size=imgs[0][0].size)
      # NOTE(review): depending on the torchvision version, get_params may draw
      # from torch's RNG rather than Python's `random`; this file only seeds
      # `random` — confirm the motion is reproducible for a given seed.
      angleA, (transXA, transYA), scaleA, (shearXA, shearYA) = transforms.RandomAffine.get_params(**config)
      angleB, (transXB, transYB), scaleB, (shearXB, shearYB) = transforms.RandomAffine.get_params(**config)

      T = len(imgs[0])
      variation_over_time = random.random()
      for t in range(T):
          # Normalized position in the clip. A single-frame clip has no
          # duration, so it stays at the start pose instead of dividing by zero.
          progress = t / (T - 1) if T > 1 else 0.0
          percentage = progress * variation_over_time
          angle = lerp(angleA, angleB, percentage)
          transX = lerp(transXA, transXB, percentage)
          transY = lerp(transYA, transYB, percentage)
          scale = lerp(scaleA, scaleB, percentage)
          shearX = lerp(shearXA, shearXB, percentage)
          shearY = lerp(shearYA, shearYB, percentage)
          # Same parameters for every sequence keeps e.g. foreground and alpha aligned.
          for img in imgs:
              img[t] = F.affine(img[t], angle, (transX, transY), scale, (shearX, shearY), F.InterpolationMode.BILINEAR)
      return imgs
  def _letterbox_geometry(width, height, resolution):
      """Return the resize target and padding that fit an image into a square.

      The longer side is scaled to ``resolution`` (sizes are truncated to int)
      and any remaining space is split between the two sides, the extra pixel
      of an odd remainder going to the right/bottom.

      Returns:
          ``((new_w, new_h), [left, top, right, bottom])``.
      """
      scale = resolution / max(height, width)
      new_w, new_h = int(width * scale), int(height * scale)
      extra_w = max(resolution - new_w, 0)
      extra_h = max(resolution - new_h, 0)
      left, top = extra_w // 2, extra_h // 2
      return (new_w, new_h), [left, top, extra_w - left, extra_h - top]


  def _cover_square(img, resolution):
      """Scale ``img`` so its shorter side is ``resolution``, then center-crop to a square."""
      w, h = img.size
      scale = resolution / min(h, w)
      img = img.resize((int(w * scale), int(h * scale)))
      return F.center_crop(img, (resolution, resolution))


  def process(i):
      """Render sample ``i``: an animated image matte composited over a video.

      Picks the ``i``-th image matte and background video (both wrap around
      their lists), animates the still foreground/alpha pair with
      ``motion_affine``, and for each of ``args.num_frames`` frames writes

          <out-dir>/<i:04>/{fgr,pha,bgr,com}/<t:04><extension>

      where ``fgr``/``pha`` are letterboxed to ``args.resolution`` square,
      ``bgr`` is the background frame scaled and center-cropped to the same
      square, and ``com`` is ``fgr * pha + bgr * (1 - pha)``.

      Args:
          i: Sample index; also selects the output directory name.
      """
      imagematte_filename = imagematte_filenames[i % len(imagematte_filenames)]
      background_filename = background_filenames[i % len(background_filenames)]
      out_path = os.path.join(args.out_dir, str(i).zfill(4))

      # Every second pass through the image matte list is mirrored horizontally.
      flip = i // len(imagematte_filenames) % 2 == 1

      # Use the reader as a context manager so it is closed when the sample is
      # done (or fails); workers are reused for several samples.
      with pims.PyAVVideoReader(os.path.join(args.background_dir, background_filename)) as bgrs:
          for kind in ('fgr', 'pha', 'com', 'bgr'):
              os.makedirs(os.path.join(out_path, kind), exist_ok=True)

          with Image.open(os.path.join(args.imagematte_dir, 'fgr', imagematte_filename)) as fgr, \
               Image.open(os.path.join(args.imagematte_dir, 'pha', imagematte_filename)) as pha:
              # convert() returns new images, so they stay usable after the files close.
              fgr = fgr.convert('RGB')
              pha = pha.convert('L')

          fgrs = [fgr] * args.num_frames
          phas = [pha] * args.num_frames
          fgrs, phas = motion_affine(fgrs, phas)

          for t in range(args.num_frames):
              frame_name = str(t).zfill(4) + args.extension

              # Foreground and alpha share one geometry (taken from the foreground).
              fgr, pha = fgrs[t], phas[t]
              size, padding = _letterbox_geometry(*fgr.size, args.resolution)
              fgr = F.pad(fgr.resize(size), padding)
              pha = F.pad(pha.resize(size), padding)
              if flip:
                  fgr = fgr.transpose(Image.FLIP_LEFT_RIGHT)
                  pha = pha.transpose(Image.FLIP_LEFT_RIGHT)
              fgr.save(os.path.join(out_path, 'fgr', frame_name))
              pha.save(os.path.join(out_path, 'pha', frame_name))

              # NOTE(review): presumably fails if the video has fewer than
              # num_frames frames — confirm against the background clips used.
              bgr = _cover_square(Image.fromarray(bgrs[t]).convert('RGB'), args.resolution)
              bgr.save(os.path.join(out_path, 'bgr', frame_name))

              # Alpha in [0, 1] with a trailing channel axis so it broadcasts over RGB.
              alpha = np.asarray(pha).astype(float)[:, :, None] / 255
              com = Image.fromarray(np.uint8(np.asarray(fgr) * alpha + np.asarray(bgr) * (1 - alpha)))
              com.save(os.path.join(out_path, 'com', frame_name))
  if __name__ == '__main__':
      # Fan the sample indices out over worker processes with a progress bar.
      # `process` writes its results to disk and returns nothing, so the list
      # returned by process_map is not kept.
      process_map(process, range(args.num_samples), max_workers=10)