mobilenetv3.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. import torch
  2. from torch import nn
  3. from torchvision.models.mobilenetv3 import MobileNetV3, InvertedResidualConfig
  4. from torchvision.transforms.functional import normalize
  5. class MobileNetV3LargeEncoder(MobileNetV3):
  6. def __init__(self, pretrained: bool = False):
  7. super().__init__(
  8. inverted_residual_setting=[
  9. InvertedResidualConfig( 16, 3, 16, 16, False, "RE", 1, 1, 1),
  10. InvertedResidualConfig( 16, 3, 64, 24, False, "RE", 2, 1, 1), # C1
  11. InvertedResidualConfig( 24, 3, 72, 24, False, "RE", 1, 1, 1),
  12. InvertedResidualConfig( 24, 5, 72, 40, True, "RE", 2, 1, 1), # C2
  13. InvertedResidualConfig( 40, 5, 120, 40, True, "RE", 1, 1, 1),
  14. InvertedResidualConfig( 40, 5, 120, 40, True, "RE", 1, 1, 1),
  15. InvertedResidualConfig( 40, 3, 240, 80, False, "HS", 2, 1, 1), # C3
  16. InvertedResidualConfig( 80, 3, 200, 80, False, "HS", 1, 1, 1),
  17. InvertedResidualConfig( 80, 3, 184, 80, False, "HS", 1, 1, 1),
  18. InvertedResidualConfig( 80, 3, 184, 80, False, "HS", 1, 1, 1),
  19. InvertedResidualConfig( 80, 3, 480, 112, True, "HS", 1, 1, 1),
  20. InvertedResidualConfig(112, 3, 672, 112, True, "HS", 1, 1, 1),
  21. InvertedResidualConfig(112, 5, 672, 160, True, "HS", 2, 2, 1), # C4
  22. InvertedResidualConfig(160, 5, 960, 160, True, "HS", 1, 2, 1),
  23. InvertedResidualConfig(160, 5, 960, 160, True, "HS", 1, 2, 1),
  24. ],
  25. last_channel=1280
  26. )
  27. if pretrained:
  28. self.load_state_dict(torch.hub.load_state_dict_from_url(
  29. 'https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth'))
  30. del self.avgpool
  31. del self.classifier
  32. def forward_single_frame(self, x):
  33. x = normalize(x, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  34. x = self.features[0](x)
  35. x = self.features[1](x)
  36. f1 = x
  37. x = self.features[2](x)
  38. x = self.features[3](x)
  39. f2 = x
  40. x = self.features[4](x)
  41. x = self.features[5](x)
  42. x = self.features[6](x)
  43. f3 = x
  44. x = self.features[7](x)
  45. x = self.features[8](x)
  46. x = self.features[9](x)
  47. x = self.features[10](x)
  48. x = self.features[11](x)
  49. x = self.features[12](x)
  50. x = self.features[13](x)
  51. x = self.features[14](x)
  52. x = self.features[15](x)
  53. x = self.features[16](x)
  54. f4 = x
  55. return [f1, f2, f3, f4]
  56. def forward_time_series(self, x):
  57. B, T = x.shape[:2]
  58. features = self.forward_single_frame(x.flatten(0, 1))
  59. features = [f.unflatten(0, (B, T)) for f in features]
  60. return features
  61. def forward(self, x):
  62. if x.ndim == 5:
  63. return self.forward_time_series(x)
  64. else:
  65. return self.forward_single_frame(x)