import torch
from torch.nn import functional as F


# --------------------------------------------------------------------------------- Train Loss


def _temporal_coherence(pred: torch.Tensor, true: torch.Tensor) -> torch.Tensor:
    """Return 5x the MSE between consecutive-frame differences of ``pred`` and ``true``.

    Frames are indexed along dim 1. With fewer than two frames there are no
    differences to compare, and ``F.mse_loss`` on empty tensors yields NaN,
    so a zero scalar (same dtype/device as ``pred``) is returned instead.
    """
    if pred.shape[1] < 2:
        return pred.new_zeros(())
    return F.mse_loss(pred[:, 1:] - pred[:, :-1],
                      true[:, 1:] - true[:, :-1]) * 5


def matting_loss(pred_fgr: torch.Tensor, pred_pha: torch.Tensor,
                 true_fgr: torch.Tensor, true_pha: torch.Tensor) -> dict:
    """Compute the foreground/alpha matting losses.

    Args:
        pred_fgr: Shape(B, T, 3, H, W)
        pred_pha: Shape(B, T, 1, H, W)
        true_fgr: Shape(B, T, 3, H, W)
        true_pha: Shape(B, T, 1, H, W)

    Returns:
        dict of scalar tensors with keys 'pha_l1', 'pha_laplacian',
        'pha_coherence', 'fgr_l1', 'fgr_coherence' and their sum 'total'.
        Coherence terms are 0 when T < 2 (no consecutive frames to compare).
    """
    loss = {}
    # Alpha losses
    loss['pha_l1'] = F.l1_loss(pred_pha, true_pha)
    # Merge batch and time so the pyramid sees (B*T, 1, H, W) images.
    loss['pha_laplacian'] = laplacian_loss(pred_pha.flatten(0, 1), true_pha.flatten(0, 1))
    loss['pha_coherence'] = _temporal_coherence(pred_pha, true_pha)
    # Foreground losses: only pixels with non-zero ground-truth alpha contribute.
    true_msk = true_pha.gt(0)
    pred_fgr = pred_fgr * true_msk
    true_fgr = true_fgr * true_msk
    loss['fgr_l1'] = F.l1_loss(pred_fgr, true_fgr)
    loss['fgr_coherence'] = _temporal_coherence(pred_fgr, true_fgr)
    # Total
    loss['total'] = loss['pha_l1'] + loss['pha_coherence'] + loss['pha_laplacian'] \
                  + loss['fgr_l1'] + loss['fgr_coherence']
    return loss


def segmentation_loss(pred_seg: torch.Tensor, true_seg: torch.Tensor) -> torch.Tensor:
    """Binary cross-entropy on segmentation logits.

    Args:
        pred_seg: Shape(B, T, 1, H, W), raw logits.
        true_seg: Shape(B, T, 1, H, W), targets.
    """
    return F.binary_cross_entropy_with_logits(pred_seg, true_seg)


# ----------------------------------------------------------------------------- Laplacian Loss


def laplacian_loss(pred: torch.Tensor, true: torch.Tensor, max_levels=5) -> torch.Tensor:
    """L1 distance between Laplacian pyramids of ``pred`` and ``true``.

    Both inputs are (N, C, H, W). Level ``i`` is weighted by ``2 ** i`` and
    the weighted sum is divided by ``max_levels``.

    Because every level uses reflect padding of 2 pixels (see
    ``gauss_convolution``), small inputs can fail once a deeper level
    shrinks below 3 pixels in height or width.
    """
    kernel = gauss_kernel(device=pred.device, dtype=pred.dtype)
    pred_pyramid = laplacian_pyramid(pred, kernel, max_levels)
    true_pyramid = laplacian_pyramid(true, kernel, max_levels)
    loss = 0
    for level, (pred_level, true_level) in enumerate(zip(pred_pyramid, true_pyramid)):
        loss += (2 ** level) * F.l1_loss(pred_level, true_level)
    return loss / max_levels


def laplacian_pyramid(img: torch.Tensor, kernel: torch.Tensor, max_levels) -> list:
    """Build a ``max_levels``-deep Laplacian pyramid of ``img`` (finest level first).

    Each level is the (even-cropped) image minus the blur-downsample-upsample
    reconstruction of itself; the downsampled image feeds the next level.
    """
    current = img
    pyramid = []
    for _ in range(max_levels):
        # Even size guarantees upsample(downsample(x)) has the same shape as x.
        current = crop_to_even_size(current)
        down = downsample(current, kernel)
        pyramid.append(current - upsample(down, kernel))
        current = down
    return pyramid


def gauss_kernel(device='cpu', dtype=torch.float32) -> torch.Tensor:
    """Return the 5x5 binomial blur kernel, shape (1, 1, 5, 5), normalized to sum to 1."""
    kernel = torch.tensor([[1,  4,  6,  4, 1],
                           [4, 16, 24, 16, 4],
                           [6, 24, 36, 24, 6],
                           [4, 16, 24, 16, 4],
                           [1,  4,  6,  4, 1]], device=device, dtype=dtype)
    kernel /= 256  # entries sum to 256
    kernel = kernel[None, None, :, :]
    return kernel


def gauss_convolution(img: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
    """Blur each channel of a (B, C, H, W) image independently; output has the same shape.

    Uses reflect padding of 2 pixels, which PyTorch only allows when H and W
    are both greater than 2.
    """
    B, C, H, W = img.shape
    # Fold channels into the batch so a single-channel kernel blurs every channel.
    img = img.reshape(B * C, 1, H, W)
    img = F.pad(img, (2, 2, 2, 2), mode='reflect')
    img = F.conv2d(img, kernel)
    img = img.reshape(B, C, H, W)
    return img


def downsample(img: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
    """Blur, then keep every other row and column (halves H and W, rounding up)."""
    img = gauss_convolution(img, kernel)
    img = img[:, :, ::2, ::2]
    return img


def upsample(img: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
    """Double H and W by zero-insertion followed by a blur."""
    B, C, H, W = img.shape
    out = torch.zeros((B, C, H * 2, W * 2), device=img.device, dtype=img.dtype)
    # x4 compensates for the three inserted zeros per 2x2 cell before blurring.
    out[:, :, ::2, ::2] = img * 4
    out = gauss_convolution(out, kernel)
    return out


def crop_to_even_size(img: torch.Tensor) -> torch.Tensor:
    """Drop the last row/column of a (..., H, W)-indexed 4D image if H/W is odd."""
    H, W = img.shape[2:]
    H = H - H % 2
    W = W - W % 2
    return img[:, :, :H, :W]