import torch
from torch import nn
from torch.nn import functional as F

"""Adopted from <https://github.com/wuhuikai/DeepGuidedFilter/>"""


class DeepGuidedFilterRefiner(nn.Module):
    """Guided-filter style refiner with a learned coefficient predictor.

    Local statistics (means, covariance, variance over a 3x3 window) are
    computed from the ``base_*`` inputs, a small stack of 1x1 convolutions
    turns them (together with ``base_hid``) into a per-pixel linear
    coefficient ``A``, and the resulting linear model ``A * x + b`` is
    bilinearly resized to the spatial size of ``fine_src`` and applied there.

    Channel layout required by the layers below: ``fine_src`` and
    ``base_src`` have 3 channels (presumably RGB -- confirm with caller),
    ``base_fgr`` and ``base_pha`` together have 4 channels, and ``base_hid``
    has ``hid_channels`` channels.
    """

    def __init__(self, hid_channels=16):
        super().__init__()

        # Depthwise 3x3 convolution whose kernel starts out as a plain
        # 3x3 average.  The weight remains a regular module parameter.
        self.box_filter = nn.Conv2d(4, 4, kernel_size=3, padding=1, bias=False, groups=4)
        nn.init.constant_(self.box_filter.weight, 1 / 9)

        # Input to the predictor: 4 covariance + 4 variance channels plus
        # the hidden features.
        stat_channels = 4 * 2 + hid_channels
        layers = [
            nn.Conv2d(stat_channels, hid_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(hid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(hid_channels, hid_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(hid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(hid_channels, 4, kernel_size=1, bias=True),
        ]
        self.conv = nn.Sequential(*layers)

    def forward_single_frame(self, fine_src, base_src, base_fgr, base_pha, base_hid):
        """Refine one batch of 4-D (N, C, H, W) inputs.

        Returns a ``(fgr, pha)`` tuple with 3 and 1 channels respectively,
        both at the spatial size of ``fine_src``.
        """
        # Guide images: the source with its channel-wise mean appended as
        # a fourth channel.
        guide_lo = torch.cat([base_src, torch.mean(base_src, dim=1, keepdim=True)], dim=1)
        target_lo = torch.cat([base_fgr, base_pha], dim=1)

        # Windowed first and second order statistics.
        guide_mean = self.box_filter(guide_lo)
        target_mean = self.box_filter(target_lo)
        covariance = self.box_filter(guide_lo * target_lo) - guide_mean * target_mean
        variance = self.box_filter(guide_lo * guide_lo) - guide_mean * guide_mean

        # Linear model target ~= scale * guide + offset.
        features = torch.cat([covariance, variance, base_hid], dim=1)
        scale = self.conv(features)
        offset = target_mean - scale * guide_mean

        # Bring the coefficients to the spatial size of the fine source.
        height, width = fine_src.shape[2:]
        scale = F.interpolate(scale, size=(height, width), mode='bilinear', align_corners=False)
        offset = F.interpolate(offset, size=(height, width), mode='bilinear', align_corners=False)

        guide_hi = torch.cat([fine_src, torch.mean(fine_src, dim=1, keepdim=True)], dim=1)
        refined = scale * guide_hi + offset

        fgr, pha = torch.split(refined, [3, 1], dim=1)
        return fgr, pha

    def forward_time_series(self, fine_src, base_src, base_fgr, base_pha, base_hid):
        """Refine 5-D inputs by folding dim 1 into dim 0 and back.

        The first two dimensions of every input are merged, the frames are
        processed by ``forward_single_frame``, and the leading dimension of
        each output is split back into the original two sizes.
        """
        batch, steps = fine_src.shape[:2]

        flat_fine_src = torch.flatten(fine_src, 0, 1)
        flat_base_src = torch.flatten(base_src, 0, 1)
        flat_base_fgr = torch.flatten(base_fgr, 0, 1)
        flat_base_pha = torch.flatten(base_pha, 0, 1)
        flat_base_hid = torch.flatten(base_hid, 0, 1)

        flat_fgr, flat_pha = self.forward_single_frame(
            flat_fine_src, flat_base_src, flat_base_fgr, flat_base_pha, flat_base_hid)

        fgr = flat_fgr.unflatten(0, (batch, steps))
        pha = flat_pha.unflatten(0, (batch, steps))
        return fgr, pha

    def forward(self, fine_src, base_src, base_fgr, base_pha, base_hid):
        """Dispatch on the rank of ``fine_src``: 5-D goes to the time-series path."""
        if fine_src.dim() != 5:
            return self.forward_single_frame(fine_src, base_src, base_fgr, base_pha, base_hid)
        return self.forward_time_series(fine_src, base_src, base_fgr, base_pha, base_hid)