inference_speed_test.py

  1. """
  2. python inference_speed_test.py \
  3. --model-variant mobilenetv3 \
  4. --resolution 1920 1080 \
  5. --downsample-ratio 0.25 \
  6. --precision float32
  7. """

import argparse
import torch
from tqdm import tqdm

from model.model import MattingNetwork

# Let cuDNN benchmark and cache the fastest kernels for the fixed input shape.
torch.backends.cudnn.benchmark = True


class InferenceSpeedTest:
    def __init__(self):
        self.parse_args()
        self.init_model()
        self.loop()

    def parse_args(self):
        parser = argparse.ArgumentParser()
        parser.add_argument('--model-variant', type=str, required=True)
        parser.add_argument('--resolution', type=int, required=True, nargs=2)
        parser.add_argument('--downsample-ratio', type=float, required=True)
        parser.add_argument('--precision', type=str, default='float32')
        parser.add_argument('--disable-refiner', action='store_true')  # accepted but not referenced elsewhere in this script
        self.args = parser.parse_args()

    def init_model(self):
        self.device = 'cuda'
        self.precision = {'float32': torch.float32, 'float16': torch.float16}[self.args.precision]
        self.model = MattingNetwork(self.args.model_variant)
        self.model = self.model.to(device=self.device, dtype=self.precision).eval()
        # TorchScript-compile and freeze the model to strip Python overhead
        # and fold constants before benchmarking.
        self.model = torch.jit.script(self.model)
        self.model = torch.jit.freeze(self.model)

    def loop(self):
        w, h = self.args.resolution
        src = torch.randn((1, 3, h, w), device=self.device, dtype=self.precision)
        with torch.no_grad():
            rec = None, None, None, None  # the model's four recurrent states start empty
            for _ in tqdm(range(1000)):
                fgr, pha, *rec = self.model(src, *rec, self.args.downsample_ratio)
                torch.cuda.synchronize()  # wait for the GPU so tqdm's it/s is accurate
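

# The tqdm readout above reports iterations/sec including Python-side overhead.
# Below is a minimal, hypothetical sketch (not part of the original script) of
# timing the same loop with GPU-side timestamps via torch.cuda.Event instead;
# the name `timed_fps` and the `iters` parameter are illustrative assumptions.
def timed_fps(model, src, downsample_ratio, iters=100):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    rec = None, None, None, None
    with torch.no_grad():
        start.record()
        for _ in range(iters):
            fgr, pha, *rec = model(src, *rec, downsample_ratio)
        end.record()
        torch.cuda.synchronize()  # ensure both events have completed
    return iters / (start.elapsed_time(end) / 1000)  # elapsed_time is in ms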


if __name__ == '__main__':
    InferenceSpeedTest()