import random

import easing_functions as ef
import torch
from torchvision import transforms
from torchvision.transforms import functional as F


class MotionAugmentation:
    """Temporally-coherent augmentation for video-matting frame sequences.

    Operates on aligned sequences of foregrounds (``fgrs``), alpha mattes
    (``phas``) and backgrounds (``bgrs``) — presumably lists of PIL images,
    since they are converted with ``F.to_tensor`` below (TODO confirm with
    the caller). Animated augmentations (affine, color jitter, blur)
    interpolate between two random endpoint parameter sets across time using
    a randomly chosen easing curve, so the augmentation looks like smooth
    camera / lighting motion rather than per-frame noise.

    Args:
        size: output (H, W) for the final resized crop.
        prob_fgr_affine: probability of animating the foreground+alpha affine.
        prob_bgr_affine: probability budget for background affine motion
            (split evenly between background-only and all-layers motion).
        prob_noise: probability of adding film-grain style noise.
        prob_color_jitter: probability of animated color jitter (drawn
            independently for foreground and background).
        prob_hflip: probability of horizontal flip (drawn independently for
            foreground+alpha and background).
        prob_grayscale, prob_sharpness, prob_blur, prob_pause: probabilities
            for the remaining augmentations.
        static_affine: if True, additionally apply one non-animated affine to
            the whole sequence (foreground shrunk, background enlarged).
        aspect_ratio_range: ratio range for the random resized crop.
    """

    def __init__(self,
                 size,
                 prob_fgr_affine,
                 prob_bgr_affine,
                 prob_noise,
                 prob_color_jitter,
                 prob_grayscale,
                 prob_sharpness,
                 prob_blur,
                 prob_hflip,
                 prob_pause,
                 static_affine=True,
                 aspect_ratio_range=(0.9, 1.1)):
        self.size = size
        self.prob_fgr_affine = prob_fgr_affine
        self.prob_bgr_affine = prob_bgr_affine
        self.prob_noise = prob_noise
        self.prob_color_jitter = prob_color_jitter
        self.prob_grayscale = prob_grayscale
        self.prob_sharpness = prob_sharpness
        self.prob_blur = prob_blur
        self.prob_hflip = prob_hflip
        self.prob_pause = prob_pause
        self.static_affine = static_affine
        self.aspect_ratio_range = aspect_ratio_range

    def __call__(self, fgrs, phas, bgrs):
        """Augment one sequence. Returns (fgrs, phas, bgrs) as stacked tensors."""
        # Foreground affine: foreground and alpha must move together.
        if random.random() < self.prob_fgr_affine:
            fgrs, phas = self._motion_affine(fgrs, phas)

        # Background affine: half the budget animates the background alone,
        # half animates all three layers together (hence prob / 2 each).
        if random.random() < self.prob_bgr_affine / 2:
            bgrs = self._motion_affine(bgrs)
        if random.random() < self.prob_bgr_affine / 2:
            fgrs, phas, bgrs = self._motion_affine(fgrs, phas, bgrs)

        # Still (non-animated) affine: shrink fg, enlarge bg for variety.
        if self.static_affine:
            fgrs, phas = self._static_affine(fgrs, phas, scale_ranges=(0.5, 1))
            bgrs = self._static_affine(bgrs, scale_ranges=(1, 1.5))

        # To tensor: (T, C, H, W) per sequence.
        fgrs = torch.stack([F.to_tensor(fgr) for fgr in fgrs])
        phas = torch.stack([F.to_tensor(pha) for pha in phas])
        bgrs = torch.stack([F.to_tensor(bgr) for bgr in bgrs])

        # Resize: fg/alpha share one crop (they must stay aligned),
        # background gets its own independent crop.
        params = transforms.RandomResizedCrop.get_params(fgrs, scale=(1, 1), ratio=self.aspect_ratio_range)
        fgrs = F.resized_crop(fgrs, *params, self.size, interpolation=F.InterpolationMode.BILINEAR)
        phas = F.resized_crop(phas, *params, self.size, interpolation=F.InterpolationMode.BILINEAR)
        params = transforms.RandomResizedCrop.get_params(bgrs, scale=(1, 1), ratio=self.aspect_ratio_range)
        bgrs = F.resized_crop(bgrs, *params, self.size, interpolation=F.InterpolationMode.BILINEAR)

        # Horizontal flip: independent draws for fg+alpha and for bg.
        if random.random() < self.prob_hflip:
            fgrs = F.hflip(fgrs)
            phas = F.hflip(phas)
        if random.random() < self.prob_hflip:
            bgrs = F.hflip(bgrs)

        # Noise
        if random.random() < self.prob_noise:
            fgrs, bgrs = self._motion_noise(fgrs, bgrs)

        # Color jitter: independent draws for fg and bg.
        if random.random() < self.prob_color_jitter:
            fgrs = self._motion_color_jitter(fgrs)
        if random.random() < self.prob_color_jitter:
            bgrs = self._motion_color_jitter(bgrs)

        # Grayscale (applied to fg and bg together so they match).
        if random.random() < self.prob_grayscale:
            fgrs = F.rgb_to_grayscale(fgrs, num_output_channels=3).contiguous()
            bgrs = F.rgb_to_grayscale(bgrs, num_output_channels=3).contiguous()

        # Sharpen (same strength on all layers, alpha included).
        if random.random() < self.prob_sharpness:
            sharpness = random.random() * 8
            fgrs = F.adjust_sharpness(fgrs, sharpness)
            phas = F.adjust_sharpness(phas, sharpness)
            bgrs = F.adjust_sharpness(bgrs, sharpness)

        # Blur: fg+alpha only, bg only, or all layers — prob / 3 each.
        if random.random() < self.prob_blur / 3:
            fgrs, phas = self._motion_blur(fgrs, phas)
        if random.random() < self.prob_blur / 3:
            bgrs = self._motion_blur(bgrs)
        if random.random() < self.prob_blur / 3:
            fgrs, phas, bgrs = self._motion_blur(fgrs, phas, bgrs)

        # Pause: freeze a random span of frames to simulate a static shot.
        if random.random() < self.prob_pause:
            fgrs, phas, bgrs = self._motion_pause(fgrs, phas, bgrs)

        return fgrs, phas, bgrs

    def _static_affine(self, *imgs, scale_ranges):
        """Apply one fixed random affine to every frame of every sequence."""
        params = transforms.RandomAffine.get_params(
            degrees=(-10, 10), translate=(0.1, 0.1), scale_ranges=scale_ranges,
            shears=(-5, 5), img_size=imgs[0][0].size)
        imgs = [[F.affine(t, *params, F.InterpolationMode.BILINEAR) for t in img] for img in imgs]
        return imgs if len(imgs) > 1 else imgs[0]

    def _motion_affine(self, *imgs):
        """Animate an affine transform from random endpoint A to endpoint B.

        Per-frame parameters are interpolated along a random easing curve, so
        all sequences in *imgs move identically (keeps fg/alpha aligned).
        """
        config = dict(degrees=(-10, 10), translate=(0.1, 0.1),
                      scale_ranges=(0.9, 1.1), shears=(-5, 5), img_size=imgs[0][0].size)
        angleA, (transXA, transYA), scaleA, (shearXA, shearYA) = transforms.RandomAffine.get_params(**config)
        angleB, (transXB, transYB), scaleB, (shearXB, shearYB) = transforms.RandomAffine.get_params(**config)

        T = len(imgs[0])
        easing = random_easing_fn()
        for t in range(T):
            percentage = easing(t / (T - 1))
            angle = lerp(angleA, angleB, percentage)
            transX = lerp(transXA, transXB, percentage)
            transY = lerp(transYA, transYB, percentage)
            scale = lerp(scaleA, scaleB, percentage)
            shearX = lerp(shearXA, shearXB, percentage)
            shearY = lerp(shearYA, shearYB, percentage)
            for img in imgs:
                img[t] = F.affine(img[t], angle, (transX, transY), scale, (shearX, shearY), F.InterpolationMode.BILINEAR)
        return imgs if len(imgs) > 1 else imgs[0]

    def _motion_noise(self, *imgs):
        """Add film-grain noise: low-res noise upsampled to simulate grain size."""
        grain_size = random.random() * 3 + 1  # range 1 ~ 4
        monochrome = random.random() < 0.5
        for img in imgs:
            T, C, H, W = img.shape
            noise = torch.randn((T, 1 if monochrome else C, round(H / grain_size), round(W / grain_size)))
            # Coarser grain gets proportionally weaker amplitude.
            noise.mul_(random.random() * 0.2 / grain_size)
            if grain_size != 1:
                noise = F.resize(noise, (H, W))
            img.add_(noise).clamp_(0, 1)
        return imgs if len(imgs) > 1 else imgs[0]

    def _motion_color_jitter(self, *imgs):
        """Animate brightness/contrast/saturation/hue between two random endpoints."""
        brightnessA, brightnessB, contrastA, contrastB, saturationA, saturationB, hueA, hueB \
            = torch.randn(8).mul(0.1).tolist()
        strength = random.random() * 0.2
        easing = random_easing_fn()
        T = len(imgs[0])
        for t in range(T):
            percentage = easing(t / (T - 1)) * strength
            for img in imgs:
                img[t] = F.adjust_brightness(img[t], max(1 + lerp(brightnessA, brightnessB, percentage), 0.1))
                img[t] = F.adjust_contrast(img[t], max(1 + lerp(contrastA, contrastB, percentage), 0.1))
                # BUGFIX: previously interpolated the brightness endpoints here,
                # leaving saturationA/saturationB unused.
                img[t] = F.adjust_saturation(img[t], max(1 + lerp(saturationA, saturationB, percentage), 0.1))
                img[t] = F.adjust_hue(img[t], min(0.5, max(-0.5, lerp(hueA, hueB, percentage) * 0.1)))
        return imgs if len(imgs) > 1 else imgs[0]

    def _motion_blur(self, *imgs):
        """Animate gaussian blur strength between two random endpoints."""
        blurA = random.random() * 10
        blurB = random.random() * 10

        T = len(imgs[0])
        easing = random_easing_fn()
        for t in range(T):
            percentage = easing(t / (T - 1))
            blur = max(lerp(blurA, blurB, percentage), 0)
            if blur != 0:
                kernel_size = int(blur * 2)
                if kernel_size % 2 == 0:
                    kernel_size += 1  # gaussian_blur requires an odd kernel size
                for img in imgs:
                    img[t] = F.gaussian_blur(img[t], kernel_size, sigma=blur)
        # BUGFIX: the return was indented inside the frame loop, so only the
        # first frame was ever processed before returning.
        return imgs if len(imgs) > 1 else imgs[0]

    def _motion_pause(self, *imgs):
        """Freeze a random span of frames by repeating one frame."""
        T = len(imgs[0])
        pause_frame = random.choice(range(T - 1))
        pause_length = random.choice(range(T - pause_frame))
        for img in imgs:
            img[pause_frame + 1: pause_frame + pause_length] = img[pause_frame]
        return imgs if len(imgs) > 1 else imgs[0]


def lerp(a, b, percentage):
    """Linear interpolation from a to b by the given fraction."""
    return a * (1 - percentage) + b * percentage


def random_easing_fn():
    """Return a random easing callable mapping [0, 1] -> interpolation factor.

    Biased 20% toward plain linear easing; otherwise a uniform pick over the
    easing_functions catalog plus the custom Step function.
    """
    if random.random() < 0.2:
        return ef.LinearInOut()
    else:
        return random.choice([
            ef.BackEaseIn,
            ef.BackEaseOut,
            ef.BackEaseInOut,
            ef.BounceEaseIn,
            ef.BounceEaseOut,
            ef.BounceEaseInOut,
            ef.CircularEaseIn,
            ef.CircularEaseOut,
            ef.CircularEaseInOut,
            ef.CubicEaseIn,
            ef.CubicEaseOut,
            ef.CubicEaseInOut,
            ef.ExponentialEaseIn,
            ef.ExponentialEaseOut,
            ef.ExponentialEaseInOut,
            ef.ElasticEaseIn,
            ef.ElasticEaseOut,
            ef.ElasticEaseInOut,
            ef.QuadEaseIn,
            ef.QuadEaseOut,
            ef.QuadEaseInOut,
            ef.QuarticEaseIn,
            ef.QuarticEaseOut,
            ef.QuarticEaseInOut,
            ef.QuinticEaseIn,
            ef.QuinticEaseOut,
            ef.QuinticEaseInOut,
            ef.SineEaseIn,
            ef.SineEaseOut,
            ef.SineEaseInOut,
            Step,
        ])()


class Step:  # Custom easing function for sudden change.
    def __call__(self, value):
        return 0 if value < 0.5 else 1


# ---------------------------- Frame Sampler ----------------------------


class TrainFrameSampler:
    """Samples frame indices with random speed, shift, and direction."""

    def __init__(self, speed=[0.5, 1, 2, 3, 4, 5]):
        self.speed = speed

    def __call__(self, seq_length):
        frames = list(range(seq_length))

        # Speed up
        speed = random.choice(self.speed)
        frames = [int(f * speed) for f in frames]

        # Shift
        shift = random.choice(range(seq_length))
        frames = [f + shift for f in frames]

        # Reverse
        if random.random() < 0.5:
            frames = frames[::-1]
        return frames


class ValidFrameSampler:
    """Deterministic sampler for validation: frames in order, no augmentation."""

    def __call__(self, seq_length):
        return range(seq_length)