hubconf.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839
  1. """
  2. Loading model
  3. model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3")
  4. model = torch.hub.load("PeterL1n/RobustVideoMatting", "resnet50")
  5. Converter API
  6. convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter")
  7. """
  8. dependencies = ['torch', 'torchvision']
  9. import torch
  10. from model import MattingNetwork
  11. def mobilenetv3(pretrained: bool = True, progress: bool = True):
  12. model = MattingNetwork('mobilenetv3')
  13. if pretrained:
  14. url = 'https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_mobilenetv3.pth'
  15. model.load_state_dict(torch.hub.load_state_dict_from_url(url, map_location='cpu', progress=progress))
  16. return model
  17. def resnet50(pretrained: bool = True, progress: bool = True):
  18. model = MattingNetwork('resnet50')
  19. if pretrained:
  20. url = 'https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_resnet50.pth'
  21. model.load_state_dict(torch.hub.load_state_dict_from_url(url, map_location='cpu', progress=progress))
  22. return model
  23. def converter():
  24. try:
  25. from inference import convert_video
  26. return convert_video
  27. except ModuleNotFoundError as error:
  28. print(error)
  29. print('Please run "pip install av tqdm pims"')