# decoder.py (6.9 KB)
  1. import torch
  2. from torch import Tensor
  3. from torch import nn
  4. from torch.nn import functional as F
  5. from typing import Tuple, Optional
  6. class RecurrentDecoder(nn.Module):
  7. def __init__(self, feature_channels, decoder_channels):
  8. super().__init__()
  9. self.avgpool = AvgPool()
  10. self.decode4 = BottleneckBlock(feature_channels[3])
  11. self.decode3 = UpsamplingBlock(feature_channels[3], feature_channels[2], 3, decoder_channels[0])
  12. self.decode2 = UpsamplingBlock(decoder_channels[0], feature_channels[1], 3, decoder_channels[1])
  13. self.decode1 = UpsamplingBlock(decoder_channels[1], feature_channels[0], 3, decoder_channels[2])
  14. self.decode0 = OutputBlock(decoder_channels[2], 3, decoder_channels[3])
  15. def forward(self,
  16. s0: Tensor, f1: Tensor, f2: Tensor, f3: Tensor, f4: Tensor,
  17. r1: Optional[Tensor], r2: Optional[Tensor],
  18. r3: Optional[Tensor], r4: Optional[Tensor]):
  19. s1, s2, s3 = self.avgpool(s0)
  20. x4, r4 = self.decode4(f4, r4)
  21. x3, r3 = self.decode3(x4, f3, s3, r3)
  22. x2, r2 = self.decode2(x3, f2, s2, r2)
  23. x1, r1 = self.decode1(x2, f1, s1, r1)
  24. x0 = self.decode0(x1, s0)
  25. return x0, r1, r2, r3, r4
  26. class AvgPool(nn.Module):
  27. def __init__(self):
  28. super().__init__()
  29. self.avgpool = nn.AvgPool2d(2, 2, count_include_pad=False, ceil_mode=True)
  30. def forward_single_frame(self, s0):
  31. s1 = self.avgpool(s0)
  32. s2 = self.avgpool(s1)
  33. s3 = self.avgpool(s2)
  34. return s1, s2, s3
  35. def forward_time_series(self, s0):
  36. B, T = s0.shape[:2]
  37. s0 = s0.flatten(0, 1)
  38. s1, s2, s3 = self.forward_single_frame(s0)
  39. s1 = s1.unflatten(0, (B, T))
  40. s2 = s2.unflatten(0, (B, T))
  41. s3 = s3.unflatten(0, (B, T))
  42. return s1, s2, s3
  43. def forward(self, s0):
  44. if s0.ndim == 5:
  45. return self.forward_time_series(s0)
  46. else:
  47. return self.forward_single_frame(s0)
  48. class BottleneckBlock(nn.Module):
  49. def __init__(self, channels):
  50. super().__init__()
  51. self.channels = channels
  52. self.gru = ConvGRU(channels // 2)
  53. def forward(self, x, r: Optional[Tensor]):
  54. a, b = x.split(self.channels // 2, dim=-3)
  55. b, r = self.gru(b, r)
  56. x = torch.cat([a, b], dim=-3)
  57. return x, r
  58. class UpsamplingBlock(nn.Module):
  59. def __init__(self, in_channels, skip_channels, src_channels, out_channels):
  60. super().__init__()
  61. self.out_channels = out_channels
  62. self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
  63. self.conv = nn.Sequential(
  64. nn.Conv2d(in_channels + skip_channels + src_channels, out_channels, 3, 1, 1, bias=False),
  65. nn.BatchNorm2d(out_channels),
  66. nn.ReLU(True),
  67. )
  68. self.gru = ConvGRU(out_channels // 2)
  69. def forward_single_frame(self, x, f, s, r: Optional[Tensor]):
  70. x = self.upsample(x)
  71. x = x[:, :, :s.size(2), :s.size(3)]
  72. x = torch.cat([x, f, s], dim=1)
  73. x = self.conv(x)
  74. a, b = x.split(self.out_channels // 2, dim=1)
  75. b, r = self.gru(b, r)
  76. x = torch.cat([a, b], dim=1)
  77. return x, r
  78. def forward_time_series(self, x, f, s, r: Optional[Tensor]):
  79. B, T, _, H, W = s.shape
  80. x = x.flatten(0, 1)
  81. f = f.flatten(0, 1)
  82. s = s.flatten(0, 1)
  83. x = self.upsample(x)
  84. x = x[:, :, :H, :W]
  85. x = torch.cat([x, f, s], dim=1)
  86. x = self.conv(x)
  87. x = x.unflatten(0, (B, T))
  88. a, b = x.split(self.out_channels // 2, dim=2)
  89. b, r = self.gru(b, r)
  90. x = torch.cat([a, b], dim=2)
  91. return x, r
  92. def forward(self, x, f, s, r: Optional[Tensor]):
  93. if x.ndim == 5:
  94. return self.forward_time_series(x, f, s, r)
  95. else:
  96. return self.forward_single_frame(x, f, s, r)
  97. class OutputBlock(nn.Module):
  98. def __init__(self, in_channels, src_channels, out_channels):
  99. super().__init__()
  100. self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
  101. self.conv = nn.Sequential(
  102. nn.Conv2d(in_channels + src_channels, out_channels, 3, 1, 1, bias=False),
  103. nn.BatchNorm2d(out_channels),
  104. nn.ReLU(True),
  105. nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
  106. nn.BatchNorm2d(out_channels),
  107. nn.ReLU(True),
  108. )
  109. def forward_single_frame(self, x, s):
  110. x = self.upsample(x)
  111. x = x[:, :, :s.size(2), :s.size(3)]
  112. x = torch.cat([x, s], dim=1)
  113. x = self.conv(x)
  114. return x
  115. def forward_time_series(self, x, s):
  116. B, T, _, H, W = s.shape
  117. x = x.flatten(0, 1)
  118. s = s.flatten(0, 1)
  119. x = self.upsample(x)
  120. x = x[:, :, :H, :W]
  121. x = torch.cat([x, s], dim=1)
  122. x = self.conv(x)
  123. x = x.unflatten(0, (B, T))
  124. return x
  125. def forward(self, x, s):
  126. if x.ndim == 5:
  127. return self.forward_time_series(x, s)
  128. else:
  129. return self.forward_single_frame(x, s)
  130. class ConvGRU(nn.Module):
  131. def __init__(self,
  132. channels: int,
  133. kernel_size: int = 3,
  134. padding: int = 1):
  135. super().__init__()
  136. self.channels = channels
  137. self.ih = nn.Sequential(
  138. nn.Conv2d(channels * 2, channels * 2, kernel_size, padding=padding),
  139. nn.Sigmoid()
  140. )
  141. self.hh = nn.Sequential(
  142. nn.Conv2d(channels * 2, channels, kernel_size, padding=padding),
  143. nn.Tanh()
  144. )
  145. def forward_single_frame(self, x, h):
  146. r, z = self.ih(torch.cat([x, h], dim=1)).split(self.channels, dim=1)
  147. c = self.hh(torch.cat([x, r * h], dim=1))
  148. h = (1 - z) * h + z * c
  149. return h, h
  150. def forward_time_series(self, x, h):
  151. o = []
  152. for xt in x.unbind(dim=1):
  153. ot, h = self.forward_single_frame(xt, h)
  154. o.append(ot)
  155. o = torch.stack(o, dim=1)
  156. return o, h
  157. def forward(self, x, h: Optional[Tensor]):
  158. if h is None:
  159. h = torch.zeros((x.size(0), x.size(-3), x.size(-2), x.size(-1)),
  160. device=x.device, dtype=x.dtype)
  161. if x.ndim == 5:
  162. return self.forward_time_series(x, h)
  163. else:
  164. return self.forward_single_frame(x, h)
  165. class Projection(nn.Module):
  166. def __init__(self, in_channels, out_channels):
  167. super().__init__()
  168. self.conv = nn.Conv2d(in_channels, out_channels, 1)
  169. def forward_single_frame(self, x):
  170. return self.conv(x)
  171. def forward_time_series(self, x):
  172. B, T = x.shape[:2]
  173. return self.conv(x.flatten(0, 1)).unflatten(0, (B, T))
  174. def forward(self, x):
  175. if x.ndim == 5:
  176. return self.forward_time_series(x)
  177. else:
  178. return self.forward_single_frame(x)