  1. """
  2. python inference.py \
  3. --variant mobilenetv3 \
  4. --checkpoint "CHECKPOINT" \
  5. --device cuda \
  6. --input-source "input.mp4" \
  7. --output-type video \
  8. --output-composition "composition.mp4" \
  9. --output-alpha "alpha.mp4" \
  10. --output-foreground "foreground.mp4" \
  11. --output-video-mbps 4 \
  12. --seq-chunk 1
  13. """
import torch
import os
from torch.utils.data import DataLoader
from torchvision import transforms
from typing import Optional, Tuple
from tqdm.auto import tqdm

from inference_utils import VideoReader, VideoWriter, ImageSequenceReader, ImageSequenceWriter


def convert_video(model,
                  input_source: str,
                  input_resize: Optional[Tuple[int, int]] = None,
                  downsample_ratio: Optional[float] = None,
                  output_type: str = 'video',
                  output_composition: Optional[str] = None,
                  output_alpha: Optional[str] = None,
                  output_foreground: Optional[str] = None,
                  output_video_mbps: Optional[float] = None,
                  seq_chunk: int = 1,
                  num_workers: int = 0,
                  progress: bool = True,
                  device: Optional[str] = None,
                  dtype: Optional[torch.dtype] = None):
  35. """
  36. Args:
  37. input_source:A video file, or an image sequence directory. Images must be sorted in accending order, support png and jpg.
  38. input_resize: If provided, the input are first resized to (w, h).
  39. downsample_ratio: The model's downsample_ratio hyperparameter. If not provided, model automatically set one.
  40. output_type: Options: ["video", "png_sequence"].
  41. output_composition:
  42. The composition output path. File path if output_type == 'video'. Directory path if output_type == 'png_sequence'.
  43. If output_type == 'video', the composition has green screen background.
  44. If output_type == 'png_sequence'. the composition is RGBA png images.
  45. output_alpha: The alpha output from the model.
  46. output_foreground: The foreground output from the model.
  47. seq_chunk: Number of frames to process at once. Increase it for better parallelism.
  48. num_workers: PyTorch's DataLoader workers. Only use >0 for image input.
  49. progress: Show progress bar.
  50. device: Only need to manually provide if model is a TorchScript freezed model.
  51. dtype: Only need to manually provide if model is a TorchScript freezed model.
  52. """

    assert downsample_ratio is None or (downsample_ratio > 0 and downsample_ratio <= 1), 'Downsample ratio must be between 0 (exclusive) and 1 (inclusive).'
    assert any([output_composition, output_alpha, output_foreground]), 'Must provide at least one output.'
    assert output_type in ['video', 'png_sequence'], 'Only "video" and "png_sequence" output modes are supported.'
    assert seq_chunk >= 1, 'Sequence chunk must be >= 1'
    assert num_workers >= 0, 'Number of workers must be >= 0'
    assert output_video_mbps is None or output_type == 'video', 'Mbps is not available for png_sequence output.'

    # Initialize transform
    if input_resize is not None:
        transform = transforms.Compose([
            transforms.Resize(input_resize[::-1]),
            transforms.ToTensor()
        ])
    else:
        transform = transforms.ToTensor()
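    # Note: torchvision's transforms.Resize expects (h, w), so the user-facing (w, h)
    # tuple from input_resize is reversed above via input_resize[::-1].
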
    # Initialize reader
    if os.path.isfile(input_source):
        source = VideoReader(input_source, transform)
    else:
        source = ImageSequenceReader(input_source, transform)
    reader = DataLoader(source, batch_size=seq_chunk, pin_memory=True, num_workers=num_workers)
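    # batch_size=seq_chunk means each DataLoader batch holds `seq_chunk` consecutive
    # frames; they become the T dimension of the [B, T, C, H, W] input in the loop below.
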
    # Initialize writers
    if output_type == 'video':
        frame_rate = source.frame_rate if isinstance(source, VideoReader) else 30
        output_video_mbps = 1 if output_video_mbps is None else output_video_mbps
        if output_composition is not None:
            writer_com = VideoWriter(
                path=output_composition,
                frame_rate=frame_rate,
                bit_rate=int(output_video_mbps * 1000000))
        if output_alpha is not None:
            writer_pha = VideoWriter(
                path=output_alpha,
                frame_rate=frame_rate,
                bit_rate=int(output_video_mbps * 1000000))
        if output_foreground is not None:
            writer_fgr = VideoWriter(
                path=output_foreground,
                frame_rate=frame_rate,
                bit_rate=int(output_video_mbps * 1000000))
    else:
        if output_composition is not None:
            writer_com = ImageSequenceWriter(output_composition, 'png')
        if output_alpha is not None:
            writer_pha = ImageSequenceWriter(output_alpha, 'png')
        if output_foreground is not None:
            writer_fgr = ImageSequenceWriter(output_foreground, 'png')

    # Inference
    model = model.eval()
    if device is None or dtype is None:
        param = next(model.parameters())
        dtype = param.dtype
        device = param.device

    if (output_composition is not None) and (output_type == 'video'):
        bgr = torch.tensor([120, 255, 155], device=device, dtype=dtype).div(255).view(1, 1, 3, 1, 1)
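
    # The model is recurrent: `rec` below carries its four hidden states across
    # chunks, so frames must be fed to it in temporal order.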
    try:
        with torch.no_grad():
            bar = tqdm(total=len(source), disable=not progress, dynamic_ncols=True)
            rec = [None] * 4
            for src in reader:

                if downsample_ratio is None:
                    downsample_ratio = auto_downsample_ratio(*src.shape[2:])

                src = src.to(device, dtype, non_blocking=True).unsqueeze(0)  # [B, T, C, H, W]
                fgr, pha, *rec = model(src, *rec, downsample_ratio)

                if output_foreground is not None:
                    writer_fgr.write(fgr[0])
                if output_alpha is not None:
                    writer_pha.write(pha[0])
                if output_composition is not None:
                    if output_type == 'video':
                        # Composite the foreground over the green-screen background.
                        com = fgr * pha + bgr * (1 - pha)
                    else:
                        # Zero out fully transparent pixels and pack RGBA channels.
                        fgr = fgr * pha.gt(0)
                        com = torch.cat([fgr, pha], dim=-3)
                    writer_com.write(com[0])

                bar.update(src.size(1))
    finally:
        # Clean up
        if output_composition is not None:
            writer_com.close()
        if output_alpha is not None:
            writer_pha.close()
        if output_foreground is not None:
            writer_fgr.close()


def auto_downsample_ratio(h, w):
    """
    Automatically find a downsample ratio so that the largest side of the resolution is at most 512 px.
    """
    return min(512 / max(h, w), 1)
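
# Example: for a 1920x1080 input, auto_downsample_ratio returns min(512 / 1920, 1) ≈ 0.267,
# so the downsampled resolution the model works at is about 512 px on its longer side.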


class Converter:
    def __init__(self, variant: str, checkpoint: str, device: str):
        self.model = MattingNetwork(variant).eval().to(device)
        self.model.load_state_dict(torch.load(checkpoint, map_location=device))
        self.model = torch.jit.script(self.model)
        self.model = torch.jit.freeze(self.model)
        self.device = device

    def convert(self, *args, **kwargs):
        convert_video(self.model, device=self.device, dtype=torch.float32, *args, **kwargs)
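
# torch.jit.script() compiles the network to TorchScript and torch.jit.freeze() inlines
# its parameters, so the frozen module no longer exposes parameters(); that is why
# convert() passes device and dtype to convert_video() explicitly instead of letting
# it infer them from next(model.parameters()).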


if __name__ == '__main__':
    import argparse
    from model import MattingNetwork

    parser = argparse.ArgumentParser()
    parser.add_argument('--variant', type=str, required=True, choices=['mobilenetv3', 'resnet50'])
    parser.add_argument('--checkpoint', type=str, required=True)
    parser.add_argument('--device', type=str, required=True)
    parser.add_argument('--input-source', type=str, required=True)
    parser.add_argument('--input-resize', type=int, default=None, nargs=2)
    parser.add_argument('--downsample-ratio', type=float)
    parser.add_argument('--output-composition', type=str)
    parser.add_argument('--output-alpha', type=str)
    parser.add_argument('--output-foreground', type=str)
    parser.add_argument('--output-type', type=str, required=True, choices=['video', 'png_sequence'])
    parser.add_argument('--output-video-mbps', type=int, default=1)
    parser.add_argument('--seq-chunk', type=int, default=1)
    parser.add_argument('--num-workers', type=int, default=0)
    parser.add_argument('--disable-progress', action='store_true')
    args = parser.parse_args()

    converter = Converter(args.variant, args.checkpoint, args.device)
    converter.convert(
        input_source=args.input_source,
        input_resize=args.input_resize,
        downsample_ratio=args.downsample_ratio,
        output_type=args.output_type,
        output_composition=args.output_composition,
        output_alpha=args.output_alpha,
        output_foreground=args.output_foreground,
        output_video_mbps=args.output_video_mbps,
        seq_chunk=args.seq_chunk,
        num_workers=args.num_workers,
        progress=not args.disable_progress
    )
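
    # Example invocation for png_sequence mode (output path below is a placeholder);
    # the composition is written as a directory of RGBA png frames:
    #
    #   python inference.py \
    #       --variant mobilenetv3 \
    #       --checkpoint "CHECKPOINT" \
    #       --device cuda \
    #       --input-source "input.mp4" \
    #       --output-type png_sequence \
    #       --output-composition "composition_frames"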