import torch
from torch import Tensor
from torch import nn
from torch.nn import functional as F
from typing import Optional, List

from .mobilenetv3 import MobileNetV3LargeEncoder
from .resnet import ResNet50Encoder
from .lraspp import LRASPP
from .decoder import RecurrentDecoder, Projection
from .fast_guided_filter import FastGuidedFilterRefiner
from .deep_guided_filter import DeepGuidedFilterRefiner


class MattingNetwork(nn.Module):
    """Matting model: backbone encoder -> LRASPP -> recurrent decoder -> projection heads.

    Args:
        variant: ``'mobilenetv3'`` or ``'resnet50'``. Selects the encoder together
            with the matching LRASPP / decoder channel configuration.
        refiner: ``'deep_guided_filter'`` or ``'fast_guided_filter'``. This module
            is only invoked by ``forward`` when ``downsample_ratio != 1``.
        pretrained_backbone: passed straight through to the encoder constructor.
    """

    def __init__(self,
                 variant: str = 'mobilenetv3',
                 refiner: str = 'deep_guided_filter',
                 pretrained_backbone: bool = False):
        super().__init__()
        assert variant in ['mobilenetv3', 'resnet50']
        assert refiner in ['fast_guided_filter', 'deep_guided_filter']

        # Anything other than 'mobilenetv3' falls through to the ResNet-50 config.
        # Submodules are created and assigned in a fixed order (backbone, aspp,
        # decoder, heads, refiner), which fixes weight-init RNG use and the
        # registration order of parameters.
        is_mobilenet = variant == 'mobilenetv3'
        self.backbone = (MobileNetV3LargeEncoder(pretrained_backbone) if is_mobilenet
                         else ResNet50Encoder(pretrained_backbone))
        self.aspp = LRASPP(960, 128) if is_mobilenet else LRASPP(2048, 256)
        self.decoder = (RecurrentDecoder([16, 24, 40, 128], [80, 40, 32, 16]) if is_mobilenet
                        else RecurrentDecoder([64, 256, 512, 256], [128, 64, 32, 16]))

        # Heads over the decoder's hidden output: forward() splits project_mat's
        # output into 3 foreground-residual channels + 1 alpha channel.
        self.project_mat = Projection(16, 4)
        self.project_seg = Projection(16, 1)

        self.refiner = (DeepGuidedFilterRefiner() if refiner == 'deep_guided_filter'
                        else FastGuidedFilterRefiner())

    def forward(self,
                src: Tensor,
                r1: Optional[Tensor] = None,
                r2: Optional[Tensor] = None,
                r3: Optional[Tensor] = None,
                r4: Optional[Tensor] = None,
                downsample_ratio: float = 1,
                segmentation_pass: bool = False):
        """Predict a matte (or a coarse segmentation) for ``src``.

        Args:
            src: input frames. ``_interpolate`` handles both 5-D input (first two
                axes folded together) and other ranks. Channels presumably sit at
                dim -3, matching the head split below — TODO confirm with callers.
            r1, r2, r3, r4: recurrent inputs forwarded unchanged to the decoder.
            downsample_ratio: when not 1, the backbone/decoder run on a bilinearly
                resized copy of ``src`` and ``self.refiner`` combines the coarse
                prediction with the full-resolution ``src``.
            segmentation_pass: return the segmentation head instead of the matte.

        Returns:
            ``[fgr, pha, *rec]`` with ``fgr`` and ``pha`` clamped to [0, 1], or
            ``[seg, *rec]`` when ``segmentation_pass`` is set. ``rec`` is everything
            the decoder returns after its first output (presumably the updated
            recurrent states to pass back as r1..r4 — confirm in RecurrentDecoder).
        """
        src_sm = (self._interpolate(src, scale_factor=downsample_ratio)
                  if downsample_ratio != 1 else src)

        f1, f2, f3, f4 = self.backbone(src_sm)
        hid, *rec = self.decoder(src_sm, f1, f2, f3, self.aspp(f4), r1, r2, r3, r4)

        if segmentation_pass:
            return [self.project_seg(hid), *rec]

        fgr_residual, pha = self.project_mat(hid).split([3, 1], dim=-3)
        if downsample_ratio != 1:
            # The refiner sees both resolutions; its outputs are added to the
            # full-resolution src below, so they must come back at src's size.
            fgr_residual, pha = self.refiner(src, src_sm, fgr_residual, pha, hid)
        # The head predicts a residual on top of the input, not the foreground itself.
        fgr = (fgr_residual + src).clamp(0., 1.)
        return [fgr, pha.clamp(0., 1.), *rec]

    def _interpolate(self, x: Tensor, scale_factor: float):
        # Bilinear resize. A 5-D tensor has its first two axes merged into one
        # batch axis for F.interpolate, then split back to (B, T) afterwards.
        if x.ndim == 5:
            B, T = x.shape[:2]
            merged = F.interpolate(x.flatten(0, 1), scale_factor=scale_factor,
                                   mode='bilinear', align_corners=False,
                                   recompute_scale_factor=False)
            return merged.unflatten(0, (B, T))
        return F.interpolate(x, scale_factor=scale_factor,
                             mode='bilinear', align_corners=False,
                             recompute_scale_factor=False)