- """
- python inference_speed_test.py \
- --model-variant mobilenetv3 \
- --resolution 1920 1080 \
- --downsample-ratio 0.25 \
- --precision float32
- """
import argparse

import torch
from tqdm import tqdm

from model.model import MattingNetwork
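
# cuDNN autotuning: the input shape is identical on every iteration, so the
# autotuner's one-time kernel search is amortized over the whole benchmark.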
torch.backends.cudnn.benchmark = True


class InferenceSpeedTest:
    def __init__(self):
        self.parse_args()
        self.init_model()
        self.loop()

    def parse_args(self):
        parser = argparse.ArgumentParser()
        parser.add_argument('--model-variant', type=str, required=True)
        parser.add_argument('--resolution', type=int, required=True, nargs=2)
        parser.add_argument('--downsample-ratio', type=float, required=True)
        # Restrict to the two dtypes that init_model() knows how to map.
        parser.add_argument('--precision', type=str, default='float32', choices=['float32', 'float16'])
        # Accepted for CLI compatibility; not referenced anywhere in this script.
        parser.add_argument('--disable-refiner', action='store_true')
        self.args = parser.parse_args()

    def init_model(self):
        self.device = 'cuda'
        self.precision = {'float32': torch.float32, 'float16': torch.float16}[self.args.precision]
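        # float16 halves weight/activation memory traffic and can be much
        # faster on GPUs with tensor cores, at some cost in numeric range.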
        self.model = MattingNetwork(self.args.model_variant)
        self.model = self.model.to(device=self.device, dtype=self.precision).eval()
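        # Scripting (rather than tracing) preserves the model's control flow;
        # freezing then inlines weights and attributes as constants so the
        # compiled graph can be optimized further.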
        self.model = torch.jit.script(self.model)
        self.model = torch.jit.freeze(self.model)

    def loop(self):
        w, h = self.args.resolution
        # Synthetic input: a single RGB frame at the requested resolution.
        src = torch.randn((1, 3, h, w), device=self.device, dtype=self.precision)
        with torch.no_grad():
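            # The four recurrent states start empty and are fed back in on
            # every subsequent frame, mirroring real video inference.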
            rec = None, None, None, None
            for _ in tqdm(range(1000)):
                fgr, pha, *rec = self.model(src, *rec, self.args.downsample_ratio)
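                # Kernels launch asynchronously; synchronizing every iteration
                # makes tqdm's it/s (= FPS at batch size 1) reflect real GPU time.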
                torch.cuda.synchronize()


if __name__ == '__main__':
    InferenceSpeedTest()
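
# A possible refinement, not part of the original script: CUDA-event timing
# excludes Python/tqdm overhead entirely. A sketch of what loop() could use:
#
#   start = torch.cuda.Event(enable_timing=True)
#   end = torch.cuda.Event(enable_timing=True)
#   start.record()
#   for _ in range(1000):
#       fgr, pha, *rec = self.model(src, *rec, self.args.downsample_ratio)
#   end.record()
#   torch.cuda.synchronize()  # wait for both events to complete
#   print(f'{1000 / (start.elapsed_time(end) / 1000):.1f} FPS')  # elapsed_time is in ms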