fast_guided_filter.py 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. import torch
  2. from torch import nn
  3. from torch.nn import functional as F
  4. """
  5. Adapted from <https://github.com/wuhuikai/DeepGuidedFilter/>
  6. """
  7. class FastGuidedFilterRefiner(nn.Module):
  8. def __init__(self, *args, **kwargs):
  9. super().__init__()
  10. self.guilded_filter = FastGuidedFilter(1)
  11. def forward_single_frame(self, fine_src, base_src, base_fgr, base_pha):
  12. fine_src_gray = fine_src.mean(1, keepdim=True)
  13. base_src_gray = base_src.mean(1, keepdim=True)
  14. fgr, pha = self.guilded_filter(
  15. torch.cat([base_src, base_src_gray], dim=1),
  16. torch.cat([base_fgr, base_pha], dim=1),
  17. torch.cat([fine_src, fine_src_gray], dim=1)).split([3, 1], dim=1)
  18. return fgr, pha
  19. def forward_time_series(self, fine_src, base_src, base_fgr, base_pha):
  20. B, T = fine_src.shape[:2]
  21. fgr, pha = self.forward_single_frame(
  22. fine_src.flatten(0, 1),
  23. base_src.flatten(0, 1),
  24. base_fgr.flatten(0, 1),
  25. base_pha.flatten(0, 1))
  26. fgr = fgr.unflatten(0, (B, T))
  27. pha = pha.unflatten(0, (B, T))
  28. return fgr, pha
  29. def forward(self, fine_src, base_src, base_fgr, base_pha, base_hid):
  30. if fine_src.ndim == 5:
  31. return self.forward_time_series(fine_src, base_src, base_fgr, base_pha)
  32. else:
  33. return self.forward_single_frame(fine_src, base_src, base_fgr, base_pha)
  34. class FastGuidedFilter(nn.Module):
  35. def __init__(self, r: int, eps: float = 1e-5):
  36. super().__init__()
  37. self.r = r
  38. self.eps = eps
  39. self.boxfilter = BoxFilter(r)
  40. def forward(self, lr_x, lr_y, hr_x):
  41. mean_x = self.boxfilter(lr_x)
  42. mean_y = self.boxfilter(lr_y)
  43. cov_xy = self.boxfilter(lr_x * lr_y) - mean_x * mean_y
  44. var_x = self.boxfilter(lr_x * lr_x) - mean_x * mean_x
  45. A = cov_xy / (var_x + self.eps)
  46. b = mean_y - A * mean_x
  47. A = F.interpolate(A, hr_x.shape[2:], mode='bilinear', align_corners=False)
  48. b = F.interpolate(b, hr_x.shape[2:], mode='bilinear', align_corners=False)
  49. return A * hr_x + b
  50. class BoxFilter(nn.Module):
  51. def __init__(self, r):
  52. super(BoxFilter, self).__init__()
  53. self.r = r
  54. def forward(self, x):
  55. # Note: The original implementation at <https://github.com/wuhuikai/DeepGuidedFilter/>
  56. # uses faster box blur. However, it may not be friendly for ONNX export.
  57. # We are switching to use simple convolution for box blur.
  58. kernel_size = 2 * self.r + 1
  59. kernel_x = torch.full((x.data.shape[1], 1, 1, kernel_size), 1 / kernel_size, device=x.device, dtype=x.dtype)
  60. kernel_y = torch.full((x.data.shape[1], 1, kernel_size, 1), 1 / kernel_size, device=x.device, dtype=x.dtype)
  61. x = F.conv2d(x, kernel_x, padding=(0, self.r), groups=x.data.shape[1])
  62. x = F.conv2d(x, kernel_y, padding=(self.r, 0), groups=x.data.shape[1])
  63. return x