Official repository for the paper Robust High-Resolution Video Matting with Temporal Guidance. RVM is specifically designed for robust human video matting. Unlike existing neural models that process frames as independent images, RVM uses a recurrent neural network to process videos with temporal memory. RVM can perform matting in real time on any video without additional inputs. It achieves 76 FPS at 4K and 104 FPS at HD on an Nvidia GTX 1080 Ti GPU. The project was developed at ByteDance Inc.
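To make "temporal memory" concrete, here is a toy sketch contrasting per-frame processing with recurrent processing on a single pixel's alpha value. It is an illustration only: the recurrent cell is a hypothetical exponential moving average, not RVM's learned recurrent layers, and the names recurrent_step, run_independent and run_recurrent are ours.

```python
from typing import List, Optional, Tuple


def recurrent_step(frame: float, state: Optional[float], keep: float = 0.5) -> Tuple[float, float]:
    """Hypothetical toy recurrent cell: blend the current frame with the carried state.

    This exponential moving average is NOT RVM's architecture; it only shows
    how a hidden state flows from one frame to the next.
    """
    if state is None:  # First frame: no temporal memory yet.
        state = frame
    new_state = keep * state + (1 - keep) * frame
    return new_state, new_state  # (output, state for the next frame)


def run_independent(frames: List[float]) -> List[float]:
    """Per-frame processing with a pass-through stand-in model: no memory at all."""
    return list(frames)


def run_recurrent(frames: List[float]) -> List[float]:
    """Recurrent processing: the state returned for frame t is fed back in at t + 1."""
    outputs: List[float] = []
    state: Optional[float] = None
    for frame in frames:
        out, state = recurrent_step(frame, state)
        outputs.append(out)
    return outputs


# One pixel's noisy "alpha" that flickers from frame to frame.
noisy = [1.0, 0.0, 1.0, 0.0, 1.0]
print(run_independent(noisy))  # [1.0, 0.0, 1.0, 0.0, 1.0]
print(run_recurrent(noisy))    # [1.0, 0.5, 0.75, 0.375, 0.6875]
```

Each recurrent output depends on every earlier frame through the state, which is the property the real model exploits; how much history is kept (here the fixed keep=0.5) is what a learned recurrent layer decides for itself.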
Watch the showreel video (YouTube, Bilibili) to see the model's performance.
All footage in the video is available in Google Drive.
We recommend the MobileNetV3 models for most use cases. The ResNet50 models are the larger variant, with small performance improvements. Our model is available in various inference frameworks; see the inference documentation for more instructions.
Framework | Download | Notes |
---|---|---|
PyTorch | rvm_mobilenetv3.pth, rvm_resnet50.pth | Official weights for PyTorch. Doc |
TorchHub | Nothing to download. | Easiest way to use our model in your PyTorch project. Doc |
TorchScript | rvm_mobilenetv3_fp32.torchscript, rvm_mobilenetv3_fp16.torchscript, rvm_resnet50_fp32.torchscript, rvm_resnet50_fp16.torchscript | For mobile inference, consider exporting int8 quantized models yourself. Doc |
ONNX | rvm_mobilenetv3_fp32.onnx, rvm_mobilenetv3_fp16.onnx, rvm_resnet50_fp32.onnx, rvm_resnet50_fp16.onnx | Tested on ONNX Runtime with CPU and CUDA backends. Provided models use opset 12. Doc, Exporter |
TensorFlow | rvm_mobilenetv3_tf.zip, rvm_resnet50_tf.zip | TensorFlow 2 SavedModel. Doc |
TensorFlow.js | rvm_mobilenetv3_tfjs_int8.zip | Run the model on the web. Demo, Starter Code |
CoreML | rvm_mobilenetv3_1280x720_s0.375_fp16.mlmodel, rvm_mobilenetv3_1280x720_s0.375_int8.mlmodel, rvm_mobilenetv3_1920x1080_s0.25_fp16.mlmodel, rvm_mobilenetv3_1920x1080_s0.25_int8.mlmodel | CoreML does not support dynamic resolution; you can export models for other resolutions yourself. Models require iOS 13+. `s` denotes `downsample_ratio`. Doc, Exporter |
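The CoreML file names encode the input resolution and `s` (the `downsample_ratio`). The sketch below parses those fields and computes the size of the downsampled frame, assuming the ratio simply scales both width and height; the rounding and the helper names parse_coreml_name and internal_resolution are ours, not part of the repository.

```python
import re
from typing import Tuple


def parse_coreml_name(filename: str) -> Tuple[int, int, float]:
    """Extract (width, height, downsample_ratio) from names like '..._1280x720_s0.375_fp16.mlmodel'."""
    match = re.search(r"_(\d+)x(\d+)_s([0-9.]+)_", filename)
    if match is None:
        raise ValueError(f"no '<W>x<H>_s<ratio>' field in {filename!r}")
    width, height, ratio = match.groups()
    return int(width), int(height), float(ratio)


def internal_resolution(width: int, height: int, downsample_ratio: float) -> Tuple[int, int]:
    """Downsampled frame size, assuming both sides are scaled by downsample_ratio."""
    return round(width * downsample_ratio), round(height * downsample_ratio)


for name in [
    "rvm_mobilenetv3_1280x720_s0.375_fp16.mlmodel",
    "rvm_mobilenetv3_1920x1080_s0.25_int8.mlmodel",
]:
    w, h, s = parse_coreml_name(name)
    print(name, "->", internal_resolution(w, h, s))  # both -> (480, 270)
```

Under this assumption both exported resolutions end up at 480x270, and so do the ratios used in the speed notes (0.25 for HD, 0.125 for 4K), which suggests the ratios were picked to keep the downsampled frame roughly constant in size.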
All models are available in Google Drive and Baidu Pan (code: gym7).
1. Install dependencies:

```sh
pip install -r requirements_inference.txt
```
2. Load the model:

```python
import torch
from model import MattingNetwork

model = MattingNetwork('mobilenetv3').eval().cuda()  # or "resnet50" (then load rvm_resnet50.pth)
model.load_state_dict(torch.load('rvm_mobilenetv3.pth'))
```
3. To convert videos, we provide a simple conversion API:

```python
from inference import convert_video

convert_video(
    model,                         # The model, can be on any device (cpu or cuda).
    input_source='input.mp4',      # A video file or an image sequence directory.
    output_type='video',           # Choose "video" or "png_sequence".
    output_composition='com.mp4',  # File path if video; directory path if png sequence.
    output_alpha="pha.mp4",        # [Optional] Output the raw alpha prediction.
    output_foreground="fgr.mp4",   # [Optional] Output the raw foreground prediction.
    output_video_mbps=4,           # Output video bitrate in Mbps. Not needed for png sequence.
    downsample_ratio=None,         # A hyperparameter to adjust or use None for auto.
    seq_chunk=12,                  # Process n frames at once for better parallelism.
)
```
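Two of the converter arguments are easy to reason about with a little arithmetic. The sketch below shows one plausible automatic rule for `downsample_ratio=None` (shrink so the longer side is about 512 px, never upscale) and what `seq_chunk` does to a frame sequence. The 512 px target and the helper names auto_downsample_ratio and chunk_frames are assumptions for illustration; check inference.py for the rule it actually applies.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def auto_downsample_ratio(height: int, width: int, target: int = 512) -> float:
    """Hypothetical auto rule: scale so the longer side is about `target` px, never upscale."""
    return min(target / max(height, width), 1.0)


def chunk_frames(frames: Sequence[T], seq_chunk: int) -> List[Sequence[T]]:
    """Split a frame sequence into consecutive chunks of at most seq_chunk frames."""
    if seq_chunk < 1:
        raise ValueError("seq_chunk must be >= 1")
    return [frames[i:i + seq_chunk] for i in range(0, len(frames), seq_chunk)]


print(auto_downsample_ratio(1080, 1920))  # 512 / 1920, about 0.267
print(auto_downsample_ratio(480, 640))    # 0.8
print(auto_downsample_ratio(256, 256))    # 1.0 (already small, so no downsampling)
print([len(c) for c in chunk_frames(list(range(30)), seq_chunk=12)])  # [12, 12, 6]
```

Chunking only changes how many frames go through the network per call; assuming the recurrent states are passed from one call to the next as in the manual loop below, the frames are still processed in order, and a larger `seq_chunk` mainly trades memory for parallelism.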
4. Or write your own inference code:

```python
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from inference_utils import VideoReader, VideoWriter

reader = VideoReader('input.mp4', transform=ToTensor())
writer = VideoWriter('output.mp4', frame_rate=30)
bgr = torch.tensor([.47, 1, .6]).view(3, 1, 1).cuda()  # Green background.
rec = [None] * 4                                        # Initial recurrent states.
downsample_ratio = 0.25                                 # Adjust based on your video.

with torch.no_grad():
    for src in DataLoader(reader):                      # RGB tensor normalized to 0 ~ 1.
        fgr, pha, *rec = model(src.cuda(), *rec, downsample_ratio)  # Cycle the recurrent states.
        com = fgr * pha + bgr * (1 - pha)               # Composite to green background.
        writer.write(com)                               # Write frame.
writer.close()                                          # Flush the encoder and finalize the file.
```
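The composite line `com = fgr * pha + bgr * (1 - pha)` is standard alpha blending, where `bgr` means "background" (the green used here is in RGB order). The pure-Python sketch below applies the same formula per pixel so the three cases are easy to check by hand; the helper names composite_pixel and composite_frame are ours, and the nested-list frame layout is only for illustration.

```python
from typing import List, Tuple

RGB = Tuple[float, float, float]

GREEN: RGB = (0.47, 1.0, 0.6)  # Same green background values as in the loop above.


def composite_pixel(fgr: RGB, pha: float, bgr: RGB = GREEN) -> RGB:
    """Alpha-blend one pixel: com = fgr * pha + bgr * (1 - pha), with pha in [0, 1]."""
    return tuple(f * pha + b * (1 - pha) for f, b in zip(fgr, bgr))


def composite_frame(fgr: List[List[RGB]], pha: List[List[float]], bgr: RGB = GREEN) -> List[List[RGB]]:
    """Apply composite_pixel to every pixel of an H x W frame stored as nested lists."""
    return [
        [composite_pixel(f, a, bgr) for f, a in zip(f_row, a_row)]
        for f_row, a_row in zip(fgr, pha)
    ]


red: RGB = (1.0, 0.0, 0.0)
print(composite_pixel(red, 1.0))  # (1.0, 0.0, 0.0): opaque foreground is kept
print(composite_pixel(red, 0.0))  # (0.47, 1.0, 0.6): background shows through
print(composite_pixel(red, 0.5))  # roughly (0.735, 0.5, 0.3): an even mix
```

In the PyTorch loop the same arithmetic runs on whole tensors at once: `bgr` has shape (3, 1, 1), so broadcasting stretches it across every pixel of `fgr` and `pha`.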
5. The models and converter API are also available through TorchHub.

```python
import torch

# Load the model.
model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3")  # or "resnet50"

# Converter API.
convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter")
```

Please see the inference documentation for details on the `downsample_ratio` hyperparameter, more converter arguments, and more advanced usage.
Please refer to the training documentation to train and evaluate your own model.
Speed is measured with inference_speed_test.py for reference.
GPU | dtype | HD (1920x1080) | 4K (3840x2160) |
---|---|---|---|
RTX 3090 | FP16 | 172 FPS | 154 FPS |
RTX 2060 Super | FP16 | 134 FPS | 108 FPS |
GTX 1080 Ti | FP32 | 104 FPS | 74 FPS |
Note: HD uses `downsample_ratio=0.25` and 4K uses `downsample_ratio=0.125`. All tests use batch size 1 and frame chunk 1.