- """
- # First update `train_config.py` to set paths to your dataset locations.
- # You may want to change `--num-workers` according to your machine's memory.
- # The default num-workers=8 may cause dataloader to exit unexpectedly when
- # machine is out of memory.
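#
# A minimal sketch of what `DATA_PATHS` in `train_config.py` is expected to
# provide (keys inferred from the dataset constructors in this file; the
# paths themselves are placeholders you must fill in):
#
# DATA_PATHS = {
#     'videomatte':        {'train': '<path>', 'valid': '<path>'},
#     'imagematte':        {'train': '<path>', 'valid': '<path>'},
#     'background_images': {'train': '<path>', 'valid': '<path>'},
#     'background_videos': {'train': '<path>', 'valid': '<path>'},
#     'coco_panoptic':     {'imgdir': '<path>', 'anndir': '<path>', 'annfile': '<path>'},
#     'spd':               {'imgdir': '<path>', 'segdir': '<path>'},
#     'youtubevis':        {'videodir': '<path>', 'annfile': '<path>'},
# }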

# Stage 1
python train.py \
    --model-variant mobilenetv3 \
    --dataset videomatte \
    --resolution-lr 512 \
    --seq-length-lr 15 \
    --learning-rate-backbone 0.0001 \
    --learning-rate-aspp 0.0002 \
    --learning-rate-decoder 0.0002 \
    --learning-rate-refiner 0 \
    --checkpoint-dir checkpoint/stage1 \
    --log-dir log/stage1 \
    --epoch-start 0 \
    --epoch-end 20

# Stage 2
python train.py \
    --model-variant mobilenetv3 \
    --dataset videomatte \
    --resolution-lr 512 \
    --seq-length-lr 50 \
    --learning-rate-backbone 0.00005 \
    --learning-rate-aspp 0.0001 \
    --learning-rate-decoder 0.0001 \
    --learning-rate-refiner 0 \
    --checkpoint checkpoint/stage1/epoch-19.pth \
    --checkpoint-dir checkpoint/stage2 \
    --log-dir log/stage2 \
    --epoch-start 20 \
    --epoch-end 22

# Stage 3
python train.py \
    --model-variant mobilenetv3 \
    --dataset videomatte \
    --train-hr \
    --resolution-lr 512 \
    --resolution-hr 2048 \
    --seq-length-lr 40 \
    --seq-length-hr 6 \
    --learning-rate-backbone 0.00001 \
    --learning-rate-aspp 0.00001 \
    --learning-rate-decoder 0.00001 \
    --learning-rate-refiner 0.0002 \
    --checkpoint checkpoint/stage2/epoch-21.pth \
    --checkpoint-dir checkpoint/stage3 \
    --log-dir log/stage3 \
    --epoch-start 22 \
    --epoch-end 23

# Stage 4
python train.py \
    --model-variant mobilenetv3 \
    --dataset imagematte \
    --train-hr \
    --resolution-lr 512 \
    --resolution-hr 2048 \
    --seq-length-lr 40 \
    --seq-length-hr 6 \
    --learning-rate-backbone 0.00001 \
    --learning-rate-aspp 0.00001 \
    --learning-rate-decoder 0.00005 \
    --learning-rate-refiner 0.0002 \
    --checkpoint checkpoint/stage3/epoch-22.pth \
    --checkpoint-dir checkpoint/stage4 \
    --log-dir log/stage4 \
    --epoch-start 23 \
    --epoch-end 28
"""
import argparse
import torch
import random
import os
from torch import nn
from torch import distributed as dist
from torch import multiprocessing as mp
from torch.nn import functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Adam
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader, ConcatDataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid
from torchvision.transforms.functional import center_crop
from tqdm import tqdm

from dataset.videomatte import (
    VideoMatteDataset,
    VideoMatteTrainAugmentation,
    VideoMatteValidAugmentation,
)
from dataset.imagematte import (
    ImageMatteDataset,
    ImageMatteAugmentation,
)
from dataset.coco import (
    CocoPanopticDataset,
    CocoPanopticTrainAugmentation,
)
from dataset.spd import (
    SuperviselyPersonDataset,
)
from dataset.youtubevis import (
    YouTubeVISDataset,
    YouTubeVISAugmentation,
)
from dataset.augmentation import (
    TrainFrameSampler,
    ValidFrameSampler,
)
from model import MattingNetwork
from train_config import DATA_PATHS
from train_loss import matting_loss, segmentation_loss
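

# One Trainer process is spawned per GPU (see the __main__ block at the
# bottom); its constructor runs the whole pipeline in order: argument parsing,
# distributed setup, dataset and model initialization, training, and cleanup.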
class Trainer:
    def __init__(self, rank, world_size):
        self.parse_args()
        self.init_distributed(rank, world_size)
        self.init_datasets()
        self.init_model()
        self.init_writer()
        self.train()
        self.cleanup()

    def parse_args(self):
        parser = argparse.ArgumentParser()
        # Model
        parser.add_argument('--model-variant', type=str, required=True, choices=['mobilenetv3', 'resnet50'])
        # Matting dataset
        parser.add_argument('--dataset', type=str, required=True, choices=['videomatte', 'imagematte'])
        # Learning rate
        parser.add_argument('--learning-rate-backbone', type=float, required=True)
        parser.add_argument('--learning-rate-aspp', type=float, required=True)
        parser.add_argument('--learning-rate-decoder', type=float, required=True)
        parser.add_argument('--learning-rate-refiner', type=float, required=True)
        # Training setting
        parser.add_argument('--train-hr', action='store_true')
        parser.add_argument('--resolution-lr', type=int, default=512)
        parser.add_argument('--resolution-hr', type=int, default=2048)
        parser.add_argument('--seq-length-lr', type=int, required=True)
        parser.add_argument('--seq-length-hr', type=int, default=6)
        parser.add_argument('--downsample-ratio', type=float, default=0.25)
        parser.add_argument('--batch-size-per-gpu', type=int, default=1)
        parser.add_argument('--num-workers', type=int, default=8)
        parser.add_argument('--epoch-start', type=int, default=0)
        parser.add_argument('--epoch-end', type=int, default=16)
        # Tensorboard logging
        parser.add_argument('--log-dir', type=str, required=True)
        parser.add_argument('--log-train-loss-interval', type=int, default=20)
        parser.add_argument('--log-train-images-interval', type=int, default=500)
        # Checkpoint loading and saving
        parser.add_argument('--checkpoint', type=str)
        parser.add_argument('--checkpoint-dir', type=str, required=True)
        parser.add_argument('--checkpoint-save-interval', type=int, default=500)
        # Distributed
        parser.add_argument('--distributed-addr', type=str, default='localhost')
        parser.add_argument('--distributed-port', type=str, default='12355')
        # Debugging
        parser.add_argument('--disable-progress-bar', action='store_true')
        parser.add_argument('--disable-validation', action='store_true')
        parser.add_argument('--disable-mixed-precision', action='store_true')
        self.args = parser.parse_args()

    def init_distributed(self, rank, world_size):
        self.rank = rank
        self.world_size = world_size
        self.log('Initializing distributed')
        os.environ['MASTER_ADDR'] = self.args.distributed_addr
        os.environ['MASTER_PORT'] = self.args.distributed_port
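        # NCCL is the standard backend for multi-GPU training on CUDA devices.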
        dist.init_process_group("nccl", rank=rank, world_size=world_size)

    def init_datasets(self):
        self.log('Initializing matting datasets')
        size_hr = (self.args.resolution_hr, self.args.resolution_hr)
        size_lr = (self.args.resolution_lr, self.args.resolution_lr)

        # Matting datasets:
        if self.args.dataset == 'videomatte':
            self.dataset_lr_train = VideoMatteDataset(
                videomatte_dir=DATA_PATHS['videomatte']['train'],
                background_image_dir=DATA_PATHS['background_images']['train'],
                background_video_dir=DATA_PATHS['background_videos']['train'],
                size=self.args.resolution_lr,
                seq_length=self.args.seq_length_lr,
                seq_sampler=TrainFrameSampler(),
                transform=VideoMatteTrainAugmentation(size_lr))
            if self.args.train_hr:
                self.dataset_hr_train = VideoMatteDataset(
                    videomatte_dir=DATA_PATHS['videomatte']['train'],
                    background_image_dir=DATA_PATHS['background_images']['train'],
                    background_video_dir=DATA_PATHS['background_videos']['train'],
                    size=self.args.resolution_hr,
                    seq_length=self.args.seq_length_hr,
                    seq_sampler=TrainFrameSampler(),
                    transform=VideoMatteTrainAugmentation(size_hr))
            self.dataset_valid = VideoMatteDataset(
                videomatte_dir=DATA_PATHS['videomatte']['valid'],
                background_image_dir=DATA_PATHS['background_images']['valid'],
                background_video_dir=DATA_PATHS['background_videos']['valid'],
                size=self.args.resolution_hr if self.args.train_hr else self.args.resolution_lr,
                seq_length=self.args.seq_length_hr if self.args.train_hr else self.args.seq_length_lr,
                seq_sampler=ValidFrameSampler(),
                transform=VideoMatteValidAugmentation(size_hr if self.args.train_hr else size_lr))
        else:
            self.dataset_lr_train = ImageMatteDataset(
                imagematte_dir=DATA_PATHS['imagematte']['train'],
                background_image_dir=DATA_PATHS['background_images']['train'],
                background_video_dir=DATA_PATHS['background_videos']['train'],
                size=self.args.resolution_lr,
                seq_length=self.args.seq_length_lr,
                seq_sampler=TrainFrameSampler(),
                transform=ImageMatteAugmentation(size_lr))
            if self.args.train_hr:
                self.dataset_hr_train = ImageMatteDataset(
                    imagematte_dir=DATA_PATHS['imagematte']['train'],
                    background_image_dir=DATA_PATHS['background_images']['train'],
                    background_video_dir=DATA_PATHS['background_videos']['train'],
                    size=self.args.resolution_hr,
                    seq_length=self.args.seq_length_hr,
                    seq_sampler=TrainFrameSampler(),
                    transform=ImageMatteAugmentation(size_hr))
            self.dataset_valid = ImageMatteDataset(
                imagematte_dir=DATA_PATHS['imagematte']['valid'],
                background_image_dir=DATA_PATHS['background_images']['valid'],
                background_video_dir=DATA_PATHS['background_videos']['valid'],
                size=self.args.resolution_hr if self.args.train_hr else self.args.resolution_lr,
                seq_length=self.args.seq_length_hr if self.args.train_hr else self.args.seq_length_lr,
                seq_sampler=ValidFrameSampler(),
                transform=ImageMatteAugmentation(size_hr if self.args.train_hr else size_lr))

        # Matting dataloaders:
        self.datasampler_lr_train = DistributedSampler(
            dataset=self.dataset_lr_train,
            rank=self.rank,
            num_replicas=self.world_size,
            shuffle=True)
        self.dataloader_lr_train = DataLoader(
            dataset=self.dataset_lr_train,
            batch_size=self.args.batch_size_per_gpu,
            num_workers=self.args.num_workers,
            sampler=self.datasampler_lr_train,
            pin_memory=True)
        if self.args.train_hr:
            self.datasampler_hr_train = DistributedSampler(
                dataset=self.dataset_hr_train,
                rank=self.rank,
                num_replicas=self.world_size,
                shuffle=True)
            self.dataloader_hr_train = DataLoader(
                dataset=self.dataset_hr_train,
                batch_size=self.args.batch_size_per_gpu,
                num_workers=self.args.num_workers,
                sampler=self.datasampler_hr_train,
                pin_memory=True)
        self.dataloader_valid = DataLoader(
            dataset=self.dataset_valid,
            batch_size=self.args.batch_size_per_gpu,
            num_workers=self.args.num_workers,
            pin_memory=True)

        # Segmentation datasets
        self.log('Initializing image segmentation datasets')
        self.dataset_seg_image = ConcatDataset([
            CocoPanopticDataset(
                imgdir=DATA_PATHS['coco_panoptic']['imgdir'],
                anndir=DATA_PATHS['coco_panoptic']['anndir'],
                annfile=DATA_PATHS['coco_panoptic']['annfile'],
                transform=CocoPanopticTrainAugmentation(size_lr)),
            SuperviselyPersonDataset(
                imgdir=DATA_PATHS['spd']['imgdir'],
                segdir=DATA_PATHS['spd']['segdir'],
                transform=CocoPanopticTrainAugmentation(size_lr))
        ])
        self.datasampler_seg_image = DistributedSampler(
            dataset=self.dataset_seg_image,
            rank=self.rank,
            num_replicas=self.world_size,
            shuffle=True)
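        # Image datasets have no temporal dimension, so the batch size is
        # scaled by seq_length_lr to keep the number of frames per step
        # comparable to the video passes (each image is later treated as a
        # sequence of length 1 in train()).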
        self.dataloader_seg_image = DataLoader(
            dataset=self.dataset_seg_image,
            batch_size=self.args.batch_size_per_gpu * self.args.seq_length_lr,
            num_workers=self.args.num_workers,
            sampler=self.datasampler_seg_image,
            pin_memory=True)

        self.log('Initializing video segmentation datasets')
        self.dataset_seg_video = YouTubeVISDataset(
            videodir=DATA_PATHS['youtubevis']['videodir'],
            annfile=DATA_PATHS['youtubevis']['annfile'],
            size=self.args.resolution_lr,
            seq_length=self.args.seq_length_lr,
            seq_sampler=TrainFrameSampler(speed=[1]),
            transform=YouTubeVISAugmentation(size_lr))
        self.datasampler_seg_video = DistributedSampler(
            dataset=self.dataset_seg_video,
            rank=self.rank,
            num_replicas=self.world_size,
            shuffle=True)
        self.dataloader_seg_video = DataLoader(
            dataset=self.dataset_seg_video,
            batch_size=self.args.batch_size_per_gpu,
            num_workers=self.args.num_workers,
            sampler=self.datasampler_seg_video,
            pin_memory=True)

    def init_model(self):
        self.log('Initializing model')
        self.model = MattingNetwork(self.args.model_variant, pretrained_backbone=True).to(self.rank)

        if self.args.checkpoint:
            self.log(f'Restoring from checkpoint: {self.args.checkpoint}')
            self.log(self.model.load_state_dict(
                torch.load(self.args.checkpoint, map_location=f'cuda:{self.rank}')))

        self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
        self.model_ddp = DDP(self.model, device_ids=[self.rank], broadcast_buffers=False, find_unused_parameters=True)
        self.optimizer = Adam([
            {'params': self.model.backbone.parameters(), 'lr': self.args.learning_rate_backbone},
            {'params': self.model.aspp.parameters(), 'lr': self.args.learning_rate_aspp},
            {'params': self.model.decoder.parameters(), 'lr': self.args.learning_rate_decoder},
            {'params': self.model.project_mat.parameters(), 'lr': self.args.learning_rate_decoder},
            {'params': self.model.project_seg.parameters(), 'lr': self.args.learning_rate_decoder},
            {'params': self.model.refiner.parameters(), 'lr': self.args.learning_rate_refiner},
        ])
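        # GradScaler performs loss scaling for the mixed-precision training
        # steps in train_mat() and train_seg().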
        self.scaler = GradScaler()

    def init_writer(self):
        if self.rank == 0:
            self.log('Initializing writer')
            self.writer = SummaryWriter(self.args.log_dir)

    def train(self):
        for epoch in range(self.args.epoch_start, self.args.epoch_end):
            self.epoch = epoch
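            # Derive the global step from the epoch so logging and checkpoint
            # intervals stay consistent when resuming via --epoch-start.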
            self.step = epoch * len(self.dataloader_lr_train)

            if not self.args.disable_validation:
                self.validate()

            self.log(f'Training epoch: {epoch}')
            for true_fgr, true_pha, true_bgr in tqdm(self.dataloader_lr_train, disable=self.args.disable_progress_bar, dynamic_ncols=True):
                # Low resolution pass
                self.train_mat(true_fgr, true_pha, true_bgr, downsample_ratio=1, tag='lr')

                # High resolution pass
                if self.args.train_hr:
                    true_fgr, true_pha, true_bgr = self.load_next_mat_hr_sample()
                    self.train_mat(true_fgr, true_pha, true_bgr, downsample_ratio=self.args.downsample_ratio, tag='hr')

                # Segmentation pass
                if self.step % 2 == 0:
                    true_img, true_seg = self.load_next_seg_video_sample()
                    self.train_seg(true_img, true_seg, log_label='seg_video')
                else:
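                    # Image batches lack a temporal dimension; unsqueeze(1)
                    # turns them into length-1 sequences of shape (B, 1, C, H, W).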
                    true_img, true_seg = self.load_next_seg_image_sample()
                    self.train_seg(true_img.unsqueeze(1), true_seg.unsqueeze(1), log_label='seg_image')

                if self.step % self.args.checkpoint_save_interval == 0:
                    self.save()

                self.step += 1

    def train_mat(self, true_fgr, true_pha, true_bgr, downsample_ratio, tag):
        true_fgr = true_fgr.to(self.rank, non_blocking=True)
        true_pha = true_pha.to(self.rank, non_blocking=True)
        true_bgr = true_bgr.to(self.rank, non_blocking=True)
        true_fgr, true_pha, true_bgr = self.random_crop(true_fgr, true_pha, true_bgr)
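        # Composite the network input from ground truth:
        # src = fgr * pha + bgr * (1 - pha).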
        true_src = true_fgr * true_pha + true_bgr * (1 - true_pha)

        with autocast(enabled=not self.args.disable_mixed_precision):
            pred_fgr, pred_pha = self.model_ddp(true_src, downsample_ratio=downsample_ratio)[:2]
            loss = matting_loss(pred_fgr, pred_pha, true_fgr, true_pha)

        self.scaler.scale(loss['total']).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()

        if self.rank == 0 and self.step % self.args.log_train_loss_interval == 0:
            for loss_name, loss_value in loss.items():
                self.writer.add_scalar(f'train_{tag}_{loss_name}', loss_value, self.step)

        if self.rank == 0 and self.step % self.args.log_train_images_interval == 0:
            self.writer.add_image(f'train_{tag}_pred_fgr', make_grid(pred_fgr.flatten(0, 1), nrow=pred_fgr.size(1)), self.step)
            self.writer.add_image(f'train_{tag}_pred_pha', make_grid(pred_pha.flatten(0, 1), nrow=pred_pha.size(1)), self.step)
            self.writer.add_image(f'train_{tag}_true_fgr', make_grid(true_fgr.flatten(0, 1), nrow=true_fgr.size(1)), self.step)
            self.writer.add_image(f'train_{tag}_true_pha', make_grid(true_pha.flatten(0, 1), nrow=true_pha.size(1)), self.step)
            self.writer.add_image(f'train_{tag}_true_src', make_grid(true_src.flatten(0, 1), nrow=true_src.size(1)), self.step)

    def train_seg(self, true_img, true_seg, log_label):
        true_img = true_img.to(self.rank, non_blocking=True)
        true_seg = true_seg.to(self.rank, non_blocking=True)

        true_img, true_seg = self.random_crop(true_img, true_seg)

        with autocast(enabled=not self.args.disable_mixed_precision):
            pred_seg = self.model_ddp(true_img, segmentation_pass=True)[0]
            loss = segmentation_loss(pred_seg, true_seg)

        self.scaler.scale(loss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()

        if self.rank == 0 and (self.step - self.step % 2) % self.args.log_train_loss_interval == 0:
            self.writer.add_scalar(f'{log_label}_loss', loss, self.step)

        if self.rank == 0 and (self.step - self.step % 2) % self.args.log_train_images_interval == 0:
            self.writer.add_image(f'{log_label}_pred_seg', make_grid(pred_seg.flatten(0, 1).float().sigmoid(), nrow=self.args.seq_length_lr), self.step)
            self.writer.add_image(f'{log_label}_true_seg', make_grid(true_seg.flatten(0, 1), nrow=self.args.seq_length_lr), self.step)
            self.writer.add_image(f'{log_label}_true_img', make_grid(true_img.flatten(0, 1), nrow=self.args.seq_length_lr), self.step)
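
    # The three loaders below each pull one batch from an auxiliary
    # dataloader, recreating the iterator (and reshuffling via set_epoch)
    # whenever it is exhausted or has not been created yet.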
    def load_next_mat_hr_sample(self):
        try:
            sample = next(self.dataiterator_mat_hr)
        except (AttributeError, StopIteration):
            self.datasampler_hr_train.set_epoch(self.datasampler_hr_train.epoch + 1)
            self.dataiterator_mat_hr = iter(self.dataloader_hr_train)
            sample = next(self.dataiterator_mat_hr)
        return sample

    def load_next_seg_video_sample(self):
        try:
            sample = next(self.dataiterator_seg_video)
        except (AttributeError, StopIteration):
            self.datasampler_seg_video.set_epoch(self.datasampler_seg_video.epoch + 1)
            self.dataiterator_seg_video = iter(self.dataloader_seg_video)
            sample = next(self.dataiterator_seg_video)
        return sample

    def load_next_seg_image_sample(self):
        try:
            sample = next(self.dataiterator_seg_image)
        except (AttributeError, StopIteration):
            self.datasampler_seg_image.set_epoch(self.datasampler_seg_image.epoch + 1)
            self.dataiterator_seg_image = iter(self.dataloader_seg_image)
            sample = next(self.dataiterator_seg_image)
        return sample

    def validate(self):
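        # Only rank 0 runs validation; all other processes wait at the barrier.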
        if self.rank == 0:
            self.log(f'Validating at the start of epoch: {self.epoch}')
            self.model_ddp.eval()
            total_loss, total_count = 0, 0
            with torch.no_grad():
                with autocast(enabled=not self.args.disable_mixed_precision):
                    for true_fgr, true_pha, true_bgr in tqdm(self.dataloader_valid, disable=self.args.disable_progress_bar, dynamic_ncols=True):
                        true_fgr = true_fgr.to(self.rank, non_blocking=True)
                        true_pha = true_pha.to(self.rank, non_blocking=True)
                        true_bgr = true_bgr.to(self.rank, non_blocking=True)
                        true_src = true_fgr * true_pha + true_bgr * (1 - true_pha)
                        batch_size = true_src.size(0)
                        pred_fgr, pred_pha = self.model(true_src)[:2]
                        total_loss += matting_loss(pred_fgr, pred_pha, true_fgr, true_pha)['total'].item() * batch_size
                        total_count += batch_size
            avg_loss = total_loss / total_count
            self.log(f'Validation set average loss: {avg_loss}')
            self.writer.add_scalar('valid_loss', avg_loss, self.step)
            self.model_ddp.train()
        dist.barrier()
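
    # Apply the same random scale-and-crop to every tensor in `imgs` so that
    # inputs and targets stay spatially aligned; the crop size is sampled
    # between half and full resolution.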
    def random_crop(self, *imgs):
        h, w = imgs[0].shape[-2:]
        w = random.choice(range(w // 2, w))
        h = random.choice(range(h // 2, h))
        results = []
        for img in imgs:
            B, T = img.shape[:2]
            img = img.flatten(0, 1)
            img = F.interpolate(img, (max(h, w), max(h, w)), mode='bilinear', align_corners=False)
            img = center_crop(img, (h, w))
            img = img.reshape(B, T, *img.shape[1:])
            results.append(img)
        return results

    def save(self):
        if self.rank == 0:
            os.makedirs(self.args.checkpoint_dir, exist_ok=True)
            torch.save(self.model.state_dict(), os.path.join(self.args.checkpoint_dir, f'epoch-{self.epoch}.pth'))
            self.log('Model saved')
        dist.barrier()

    def cleanup(self):
        dist.destroy_process_group()

    def log(self, msg):
        print(f'[GPU{self.rank}] {msg}')
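
# Spawn one training process per available GPU. mp.spawn passes the process
# rank as the implicit first argument, followed by `args`, so each Trainer is
# constructed as Trainer(rank, world_size).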
if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    mp.spawn(
        Trainer,
        nprocs=world_size,
        args=(world_size,),
        join=True)