deep_guided_filter.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. import torch
  2. from torch import nn
  3. from torch.nn import functional as F
  4. """
  5. Adapted from <https://github.com/wuhuikai/DeepGuidedFilter/>
  6. """
  7. class DeepGuidedFilterRefiner(nn.Module):
  8. def __init__(self, hid_channels=16):
  9. super().__init__()
  10. self.box_filter = nn.Conv2d(4, 4, kernel_size=3, padding=1, bias=False, groups=4)
  11. self.box_filter.weight.data[...] = 1 / 9
  12. self.conv = nn.Sequential(
  13. nn.Conv2d(4 * 2 + hid_channels, hid_channels, kernel_size=1, bias=False),
  14. nn.BatchNorm2d(hid_channels),
  15. nn.ReLU(True),
  16. nn.Conv2d(hid_channels, hid_channels, kernel_size=1, bias=False),
  17. nn.BatchNorm2d(hid_channels),
  18. nn.ReLU(True),
  19. nn.Conv2d(hid_channels, 4, kernel_size=1, bias=True)
  20. )
  21. def forward_single_frame(self, fine_src, base_src, base_fgr, base_pha, base_hid):
  22. fine_x = torch.cat([fine_src, fine_src.mean(1, keepdim=True)], dim=1)
  23. base_x = torch.cat([base_src, base_src.mean(1, keepdim=True)], dim=1)
  24. base_y = torch.cat([base_fgr, base_pha], dim=1)
  25. mean_x = self.box_filter(base_x)
  26. mean_y = self.box_filter(base_y)
  27. cov_xy = self.box_filter(base_x * base_y) - mean_x * mean_y
  28. var_x = self.box_filter(base_x * base_x) - mean_x * mean_x
  29. A = self.conv(torch.cat([cov_xy, var_x, base_hid], dim=1))
  30. b = mean_y - A * mean_x
  31. H, W = fine_src.shape[2:]
  32. A = F.interpolate(A, (H, W), mode='bilinear', align_corners=False)
  33. b = F.interpolate(b, (H, W), mode='bilinear', align_corners=False)
  34. out = A * fine_x + b
  35. fgr, pha = out.split([3, 1], dim=1)
  36. return fgr, pha
  37. def forward_time_series(self, fine_src, base_src, base_fgr, base_pha, base_hid):
  38. B, T = fine_src.shape[:2]
  39. fgr, pha = self.forward_single_frame(
  40. fine_src.flatten(0, 1),
  41. base_src.flatten(0, 1),
  42. base_fgr.flatten(0, 1),
  43. base_pha.flatten(0, 1),
  44. base_hid.flatten(0, 1))
  45. fgr = fgr.unflatten(0, (B, T))
  46. pha = pha.unflatten(0, (B, T))
  47. return fgr, pha
  48. def forward(self, fine_src, base_src, base_fgr, base_pha, base_hid):
  49. if fine_src.ndim == 5:
  50. return self.forward_time_series(fine_src, base_src, base_fgr, base_pha, base_hid)
  51. else:
  52. return self.forward_single_frame(fine_src, base_src, base_fgr, base_pha, base_hid)