from torch import nn
from torch.hub import load_state_dict_from_url
from torchvision.models.mobilenetv3 import MobileNetV3, InvertedResidualConfig
from torchvision.transforms.functional import normalize


class MobileNetV3LargeEncoder(MobileNetV3):
    """MobileNetV3-Large backbone that returns four intermediate feature maps.

    The classification head (``avgpool`` and ``classifier``) is deleted at the
    end of ``__init__``, so this module is only meant to be used through the
    ``forward*`` methods defined below.

    Args:
        pretrained: If True, load the published ``mobilenet_v3_large`` weights
            from download.pytorch.org before the head is removed.
    """

    def __init__(self, pretrained: bool = False):
        super().__init__(
            # Columns follow torchvision's InvertedResidualConfig signature:
            # (input_channels, kernel, expanded_channels, out_channels,
            #  use_se, activation, stride, dilation, width_mult).
            # "C1".."C4" mark the block that starts each new stage.
            # NOTE(review): the C4 stage uses dilation 2; torchvision's block is
            # expected to replace the stride-2 downsampling with dilation in
            # that case — confirm against the installed torchvision version.
            inverted_residual_setting=[
                InvertedResidualConfig( 16, 3,  16,  16, False, "RE", 1, 1, 1),
                InvertedResidualConfig( 16, 3,  64,  24, False, "RE", 2, 1, 1),  # C1
                InvertedResidualConfig( 24, 3,  72,  24, False, "RE", 1, 1, 1),
                InvertedResidualConfig( 24, 5,  72,  40,  True, "RE", 2, 1, 1),  # C2
                InvertedResidualConfig( 40, 5, 120,  40,  True, "RE", 1, 1, 1),
                InvertedResidualConfig( 40, 5, 120,  40,  True, "RE", 1, 1, 1),
                InvertedResidualConfig( 40, 3, 240,  80, False, "HS", 2, 1, 1),  # C3
                InvertedResidualConfig( 80, 3, 200,  80, False, "HS", 1, 1, 1),
                InvertedResidualConfig( 80, 3, 184,  80, False, "HS", 1, 1, 1),
                InvertedResidualConfig( 80, 3, 184,  80, False, "HS", 1, 1, 1),
                InvertedResidualConfig( 80, 3, 480, 112,  True, "HS", 1, 1, 1),
                InvertedResidualConfig(112, 3, 672, 112,  True, "HS", 1, 1, 1),
                InvertedResidualConfig(112, 5, 672, 160,  True, "HS", 2, 2, 1),  # C4
                InvertedResidualConfig(160, 5, 960, 160,  True, "HS", 1, 2, 1),
                InvertedResidualConfig(160, 5, 960, 160,  True, "HS", 1, 2, 1),
            ],
            last_channel=1280,
        )

        if pretrained:
            # Load while avgpool/classifier still exist so every key in the
            # downloaded state dict has a matching parameter; the head is
            # discarded immediately afterwards.
            self.load_state_dict(load_state_dict_from_url(
                'https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth'))

        # The encoder only exposes intermediate features; drop the classifier.
        del self.avgpool
        del self.classifier

    def forward_single_frame(self, x):
        """Run the backbone on a batch of frames and return multi-scale features.

        Args:
            x: Image batch; presumably (N, 3, H, W) RGB in [0, 1] — TODO confirm
                with callers. It is mean/std-normalized here before the stem.

        Returns:
            ``[f1, f2, f3, f4]``: the activations after ``features[1]``,
            ``features[3]``, ``features[6]`` and ``features[16]``. Assuming
            torchvision's layout (``features[0]`` is the stem conv and
            ``features[1..15]`` are the inverted-residual blocks configured in
            ``__init__``), f1/f2/f3 have 16/24/40 channels and f4 is the output
            of the final feature layer.
        """
        # Per-channel mean/std expected by torchvision's pretrained ImageNet models.
        x = normalize(x, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

        # Layers are applied one at a time so the tap points are explicit.
        x = self.features[0](x)
        x = self.features[1](x)
        f1 = x
        x = self.features[2](x)
        x = self.features[3](x)
        f2 = x
        x = self.features[4](x)
        x = self.features[5](x)
        x = self.features[6](x)
        f3 = x
        x = self.features[7](x)
        x = self.features[8](x)
        x = self.features[9](x)
        x = self.features[10](x)
        x = self.features[11](x)
        x = self.features[12](x)
        x = self.features[13](x)
        x = self.features[14](x)
        x = self.features[15](x)
        x = self.features[16](x)
        f4 = x
        return [f1, f2, f3, f4]

    def forward_time_series(self, x):
        """Run the backbone on a clip by folding time into the batch dimension.

        Args:
            x: 5-D tensor whose first two dimensions are (B, T); each of the
                B*T slices is processed as an independent frame.

        Returns:
            The same four feature maps as ``forward_single_frame``, each with
            its leading dimension unflattened back into (B, T).
        """
        B, T = x.shape[:2]
        features = self.forward_single_frame(x.flatten(0, 1))
        features = [f.unflatten(0, (B, T)) for f in features]
        return features

    def forward(self, x):
        """Dispatch on input rank.

        A 5-D input is treated as a (B, T, ...) clip and routed to
        ``forward_time_series``; any other rank goes to
        ``forward_single_frame`` unchanged.
        """
        if x.ndim == 5:
            return self.forward_time_series(x)
        else:
            return self.forward_single_frame(x)