# model.py
from typing import Optional, List

import torch
from torch import Tensor
from torch import nn
from torch.nn import functional as F

from .mobilenetv3 import MobileNetV3LargeEncoder
from .resnet import ResNet50Encoder
from .lraspp import LRASPP
from .decoder import RecurrentDecoder, Projection
from .fast_guided_filter import FastGuidedFilterRefiner
from .deep_guided_filter import DeepGuidedFilterRefiner
  12. class MattingNetwork(nn.Module):
  13. def __init__(self,
  14. variant: str = 'mobilenetv3',
  15. refiner: str = 'deep_guided_filter',
  16. pretrained_backbone: bool = False):
  17. super().__init__()
  18. assert variant in ['mobilenetv3', 'resnet50']
  19. assert refiner in ['fast_guided_filter', 'deep_guided_filter']
  20. if variant == 'mobilenetv3':
  21. self.backbone = MobileNetV3LargeEncoder(pretrained_backbone)
  22. self.aspp = LRASPP(960, 128)
  23. self.decoder = RecurrentDecoder([16, 24, 40, 128], [80, 40, 32, 16])
  24. else:
  25. self.backbone = ResNet50Encoder(pretrained_backbone)
  26. self.aspp = LRASPP(2048, 256)
  27. self.decoder = RecurrentDecoder([64, 256, 512, 256], [128, 64, 32, 16])
  28. self.project_mat = Projection(16, 4)
  29. self.project_seg = Projection(16, 1)
  30. if refiner == 'deep_guided_filter':
  31. self.refiner = DeepGuidedFilterRefiner()
  32. else:
  33. self.refiner = FastGuidedFilterRefiner()
  34. def forward(self,
  35. src: Tensor,
  36. r1: Optional[Tensor] = None,
  37. r2: Optional[Tensor] = None,
  38. r3: Optional[Tensor] = None,
  39. r4: Optional[Tensor] = None,
  40. downsample_ratio: float = 1,
  41. segmentation_pass: bool = False):
  42. if downsample_ratio != 1:
  43. src_sm = self._interpolate(src, scale_factor=downsample_ratio)
  44. else:
  45. src_sm = src
  46. f1, f2, f3, f4 = self.backbone(src_sm)
  47. f4 = self.aspp(f4)
  48. hid, *rec = self.decoder(src_sm, f1, f2, f3, f4, r1, r2, r3, r4)
  49. if not segmentation_pass:
  50. fgr_residual, pha = self.project_mat(hid).split([3, 1], dim=-3)
  51. if downsample_ratio != 1:
  52. fgr_residual, pha = self.refiner(src, src_sm, fgr_residual, pha, hid)
  53. fgr = fgr_residual + src
  54. fgr = fgr.clamp(0., 1.)
  55. pha = pha.clamp(0., 1.)
  56. return [fgr, pha, *rec]
  57. else:
  58. seg = self.project_seg(hid)
  59. return [seg, *rec]
  60. def _interpolate(self, x: Tensor, scale_factor: float):
  61. if x.ndim == 5:
  62. B, T = x.shape[:2]
  63. x = F.interpolate(x.flatten(0, 1), scale_factor=scale_factor,
  64. mode='bilinear', align_corners=False, recompute_scale_factor=False)
  65. x = x.unflatten(0, (B, T))
  66. else:
  67. x = F.interpolate(x, scale_factor=scale_factor,
  68. mode='bilinear', align_corners=False, recompute_scale_factor=False)
  69. return x