# train_loss.py
import torch
from torch.nn import functional as F
  3. # --------------------------------------------------------------------------------- Train Loss
  4. def matting_loss(pred_fgr, pred_pha, true_fgr, true_pha):
  5. """
  6. Args:
  7. pred_fgr: Shape(B, T, 3, H, W)
  8. pred_pha: Shape(B, T, 1, H, W)
  9. true_fgr: Shape(B, T, 3, H, W)
  10. true_pha: Shape(B, T, 1, H, W)
  11. """
  12. loss = dict()
  13. # Alpha losses
  14. loss['pha_l1'] = F.l1_loss(pred_pha, true_pha)
  15. loss['pha_laplacian'] = laplacian_loss(pred_pha.flatten(0, 1), true_pha.flatten(0, 1))
  16. loss['pha_coherence'] = F.mse_loss(pred_pha[:, 1:] - pred_pha[:, :-1],
  17. true_pha[:, 1:] - true_pha[:, :-1]) * 5
  18. # Foreground losses
  19. true_msk = true_pha.gt(0)
  20. pred_fgr = pred_fgr * true_msk
  21. true_fgr = true_fgr * true_msk
  22. loss['fgr_l1'] = F.l1_loss(pred_fgr, true_fgr)
  23. loss['fgr_coherence'] = F.mse_loss(pred_fgr[:, 1:] - pred_fgr[:, :-1],
  24. true_fgr[:, 1:] - true_fgr[:, :-1]) * 5
  25. # Total
  26. loss['total'] = loss['pha_l1'] + loss['pha_coherence'] + loss['pha_laplacian'] \
  27. + loss['fgr_l1'] + loss['fgr_coherence']
  28. return loss
  29. def segmentation_loss(pred_seg, true_seg):
  30. """
  31. Args:
  32. pred_seg: Shape(B, T, 1, H, W)
  33. true_seg: Shape(B, T, 1, H, W)
  34. """
  35. return F.binary_cross_entropy_with_logits(pred_seg, true_seg)
  36. # ----------------------------------------------------------------------------- Laplacian Loss
  37. def laplacian_loss(pred, true, max_levels=5):
  38. kernel = gauss_kernel(device=pred.device, dtype=pred.dtype)
  39. pred_pyramid = laplacian_pyramid(pred, kernel, max_levels)
  40. true_pyramid = laplacian_pyramid(true, kernel, max_levels)
  41. loss = 0
  42. for level in range(max_levels):
  43. loss += (2 ** level) * F.l1_loss(pred_pyramid[level], true_pyramid[level])
  44. return loss / max_levels
  45. def laplacian_pyramid(img, kernel, max_levels):
  46. current = img
  47. pyramid = []
  48. for _ in range(max_levels):
  49. current = crop_to_even_size(current)
  50. down = downsample(current, kernel)
  51. up = upsample(down, kernel)
  52. diff = current - up
  53. pyramid.append(diff)
  54. current = down
  55. return pyramid
  56. def gauss_kernel(device='cpu', dtype=torch.float32):
  57. kernel = torch.tensor([[1, 4, 6, 4, 1],
  58. [4, 16, 24, 16, 4],
  59. [6, 24, 36, 24, 6],
  60. [4, 16, 24, 16, 4],
  61. [1, 4, 6, 4, 1]], device=device, dtype=dtype)
  62. kernel /= 256
  63. kernel = kernel[None, None, :, :]
  64. return kernel
  65. def gauss_convolution(img, kernel):
  66. B, C, H, W = img.shape
  67. img = img.reshape(B * C, 1, H, W)
  68. img = F.pad(img, (2, 2, 2, 2), mode='reflect')
  69. img = F.conv2d(img, kernel)
  70. img = img.reshape(B, C, H, W)
  71. return img
  72. def downsample(img, kernel):
  73. img = gauss_convolution(img, kernel)
  74. img = img[:, :, ::2, ::2]
  75. return img
  76. def upsample(img, kernel):
  77. B, C, H, W = img.shape
  78. out = torch.zeros((B, C, H * 2, W * 2), device=img.device, dtype=img.dtype)
  79. out[:, :, ::2, ::2] = img * 4
  80. out = gauss_convolution(out, kernel)
  81. return out
  82. def crop_to_even_size(img):
  83. H, W = img.shape[2:]
  84. H = H - H % 2
  85. W = W - W % 2
  86. return img[:, :, :H, :W]