123456789101112131415161718192021222324252627282930313233343536373839 |
"""Torch Hub entrypoints for RobustVideoMatting.

Load a matting model:

    model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3")
    model = torch.hub.load("PeterL1n/RobustVideoMatting", "resnet50")

Fetch the video conversion helper:

    convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter")
"""

# Package names declared for torch.hub, following the hubconf ``dependencies``
# convention for packages required before the entrypoints can be used.
dependencies = ["torch", "torchvision"]
- import torch
- from model import MattingNetwork
def mobilenetv3(pretrained: bool = True, progress: bool = True):
    """Build a ``MattingNetwork`` with the ``'mobilenetv3'`` variant.

    Args:
        pretrained: If true, download the v1.0.0 ``rvm_mobilenetv3.pth``
            release checkpoint and load it into the network.
        progress: Forwarded to ``torch.hub.load_state_dict_from_url`` to
            control the download progress display.

    Returns:
        The constructed ``MattingNetwork``.
    """
    net = MattingNetwork('mobilenetv3')
    if not pretrained:
        return net
    checkpoint_url = 'https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_mobilenetv3.pth'
    # map_location='cpu' deserializes the tensors onto the CPU regardless of
    # the device they were saved from.
    weights = torch.hub.load_state_dict_from_url(checkpoint_url, map_location='cpu', progress=progress)
    net.load_state_dict(weights)
    return net
def resnet50(pretrained: bool = True, progress: bool = True):
    """Build a ``MattingNetwork`` with the ``'resnet50'`` variant.

    Args:
        pretrained: If true, download the v1.0.0 ``rvm_resnet50.pth``
            release checkpoint and load it into the network.
        progress: Forwarded to ``torch.hub.load_state_dict_from_url`` to
            control the download progress display.

    Returns:
        The constructed ``MattingNetwork``.
    """
    net = MattingNetwork('resnet50')
    if not pretrained:
        return net
    checkpoint_url = 'https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_resnet50.pth'
    # map_location='cpu' deserializes the tensors onto the CPU regardless of
    # the device they were saved from.
    weights = torch.hub.load_state_dict_from_url(checkpoint_url, map_location='cpu', progress=progress)
    net.load_state_dict(weights)
    return net
def converter():
    """Return ``convert_video`` from the repository's ``inference`` module.

    The import happens inside this function, so the other entrypoints in
    this file do not depend on ``inference`` being importable.

    Returns:
        The ``inference.convert_video`` callable.

    Raises:
        ModuleNotFoundError: If ``inference`` or a module it imports cannot be
            found. The message contains the original error plus the
            ``pip install`` hint, and the original error is chained as
            ``__cause__``.
    """
    try:
        from inference import convert_video
    except ModuleNotFoundError as error:
        # Printing and falling through would return None, so
        # torch.hub.load(..., "converter") would hand back None and the user
        # would only see "'NoneType' object is not callable" later. Fail here
        # with the actionable hint instead.
        raise ModuleNotFoundError(
            f'{error}. Please run "pip install av tqdm pims"',
            name=error.name,
        ) from error
    return convert_video
|