  1. """
  2. # First update `train_config.py` to set paths to your dataset locations.
  3. # You may want to change `--num-workers` according to your machine's memory.
  4. # The default num-workers=8 may cause dataloader to exit unexpectedly when
  5. # machine is out of memory.
  6. # Stage 1
  7. python train.py \
  8. --model-variant mobilenetv3 \
  9. --dataset videomatte \
  10. --resolution-lr 512 \
  11. --seq-length-lr 15 \
  12. --learning-rate-backbone 0.0001 \
  13. --learning-rate-aspp 0.0002 \
  14. --learning-rate-decoder 0.0002 \
  15. --learning-rate-refiner 0 \
  16. --checkpoint-dir checkpoint/stage1 \
  17. --log-dir log/stage1 \
  18. --epoch-start 0 \
  19. --epoch-end 20
  20. # Stage 2
  21. python train.py \
  22. --model-variant mobilenetv3 \
  23. --dataset videomatte \
  24. --resolution-lr 512 \
  25. --seq-length-lr 50 \
  26. --learning-rate-backbone 0.00005 \
  27. --learning-rate-aspp 0.0001 \
  28. --learning-rate-decoder 0.0001 \
  29. --learning-rate-refiner 0 \
  30. --checkpoint checkpoint/stage1/epoch-19.pth \
  31. --checkpoint-dir checkpoint/stage2 \
  32. --log-dir log/stage2 \
  33. --epoch-start 20 \
  34. --epoch-end 22
  35. # Stage 3
  36. python train.py \
  37. --model-variant mobilenetv3 \
  38. --dataset videomatte \
  39. --train-hr \
  40. --resolution-lr 512 \
  41. --resolution-hr 2048 \
  42. --seq-length-lr 40 \
  43. --seq-length-hr 6 \
  44. --learning-rate-backbone 0.00001 \
  45. --learning-rate-aspp 0.00001 \
  46. --learning-rate-decoder 0.00001 \
  47. --learning-rate-refiner 0.0002 \
  48. --checkpoint checkpoint/stage2/epoch-21.pth \
  49. --checkpoint-dir checkpoint/stage3 \
  50. --log-dir log/stage3 \
  51. --epoch-start 22 \
  52. --epoch-end 23
  53. # Stage 4
  54. python train.py \
  55. --model-variant mobilenetv3 \
  56. --dataset imagematte \
  57. --train-hr \
  58. --resolution-lr 512 \
  59. --resolution-hr 2048 \
  60. --seq-length-lr 40 \
  61. --seq-length-hr 6 \
  62. --learning-rate-backbone 0.00001 \
  63. --learning-rate-aspp 0.00001 \
  64. --learning-rate-decoder 0.00005 \
  65. --learning-rate-refiner 0.0002 \
  66. --checkpoint checkpoint/stage3/epoch-22.pth \
  67. --checkpoint-dir checkpoint/stage4 \
  68. --log-dir log/stage4 \
  69. --epoch-start 23 \
  70. --epoch-end 28
  71. """

import argparse
import torch
import random
import os
from torch import nn
from torch import distributed as dist
from torch import multiprocessing as mp
from torch.nn import functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Adam
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader, ConcatDataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid
from torchvision.transforms.functional import center_crop
from tqdm import tqdm

from dataset.videomatte import (
    VideoMatteDataset,
    VideoMatteTrainAugmentation,
    VideoMatteValidAugmentation,
)
from dataset.imagematte import (
    ImageMatteDataset,
    ImageMatteAugmentation,
)
from dataset.coco import (
    CocoPanopticDataset,
    CocoPanopticTrainAugmentation,
)
from dataset.spd import (
    SuperviselyPersonDataset,
)
from dataset.youtubevis import (
    YouTubeVISDataset,
    YouTubeVISAugmentation,
)
from dataset.augmentation import (
    TrainFrameSampler,
    ValidFrameSampler,
)
from model import MattingNetwork
from train_config import DATA_PATHS
from train_loss import matting_loss, segmentation_loss

class Trainer:
    def __init__(self, rank, world_size):
        self.parse_args()
        self.init_distributed(rank, world_size)
        self.init_datasets()
        self.init_model()
        self.init_writer()
        self.train()
        self.cleanup()

    def parse_args(self):
        parser = argparse.ArgumentParser()
        # Model
        parser.add_argument('--model-variant', type=str, required=True, choices=['mobilenetv3', 'resnet50'])
        # Matting dataset
        parser.add_argument('--dataset', type=str, required=True, choices=['videomatte', 'imagematte'])
        # Learning rate
        parser.add_argument('--learning-rate-backbone', type=float, required=True)
        parser.add_argument('--learning-rate-aspp', type=float, required=True)
        parser.add_argument('--learning-rate-decoder', type=float, required=True)
        parser.add_argument('--learning-rate-refiner', type=float, required=True)
        # Training setting
        parser.add_argument('--train-hr', action='store_true')
        parser.add_argument('--resolution-lr', type=int, default=512)
        parser.add_argument('--resolution-hr', type=int, default=2048)
        parser.add_argument('--seq-length-lr', type=int, required=True)
        parser.add_argument('--seq-length-hr', type=int, default=6)
        parser.add_argument('--downsample-ratio', type=float, default=0.25)
        parser.add_argument('--batch-size-per-gpu', type=int, default=1)
        parser.add_argument('--num-workers', type=int, default=8)
        parser.add_argument('--epoch-start', type=int, default=0)
        parser.add_argument('--epoch-end', type=int, default=16)
        # Tensorboard logging
        parser.add_argument('--log-dir', type=str, required=True)
        parser.add_argument('--log-train-loss-interval', type=int, default=20)
        parser.add_argument('--log-train-images-interval', type=int, default=500)
        # Checkpoint loading and saving
        parser.add_argument('--checkpoint', type=str)
        parser.add_argument('--checkpoint-dir', type=str, required=True)
        parser.add_argument('--checkpoint-save-interval', type=int, default=500)
        # Distributed
        parser.add_argument('--distributed-addr', type=str, default='localhost')
        parser.add_argument('--distributed-port', type=str, default='12355')
        # Debugging
        parser.add_argument('--disable-progress-bar', action='store_true')
        parser.add_argument('--disable-validation', action='store_true')
        parser.add_argument('--disable-mixed-precision', action='store_true')
        self.args = parser.parse_args()

    def init_distributed(self, rank, world_size):
        self.rank = rank
        self.world_size = world_size
        self.log('Initializing distributed')
        os.environ['MASTER_ADDR'] = self.args.distributed_addr
        os.environ['MASTER_PORT'] = self.args.distributed_port
        dist.init_process_group("nccl", rank=rank, world_size=world_size)

    def init_datasets(self):
        self.log('Initializing matting datasets')
        size_hr = (self.args.resolution_hr, self.args.resolution_hr)
        size_lr = (self.args.resolution_lr, self.args.resolution_lr)

        # Matting datasets:
        if self.args.dataset == 'videomatte':
            self.dataset_lr_train = VideoMatteDataset(
                videomatte_dir=DATA_PATHS['videomatte']['train'],
                background_image_dir=DATA_PATHS['background_images']['train'],
                background_video_dir=DATA_PATHS['background_videos']['train'],
                size=self.args.resolution_lr,
                seq_length=self.args.seq_length_lr,
                seq_sampler=TrainFrameSampler(),
                transform=VideoMatteTrainAugmentation(size_lr))
            if self.args.train_hr:
                self.dataset_hr_train = VideoMatteDataset(
                    videomatte_dir=DATA_PATHS['videomatte']['train'],
                    background_image_dir=DATA_PATHS['background_images']['train'],
                    background_video_dir=DATA_PATHS['background_videos']['train'],
                    size=self.args.resolution_hr,
                    seq_length=self.args.seq_length_hr,
                    seq_sampler=TrainFrameSampler(),
                    transform=VideoMatteTrainAugmentation(size_hr))
            self.dataset_valid = VideoMatteDataset(
                videomatte_dir=DATA_PATHS['videomatte']['valid'],
                background_image_dir=DATA_PATHS['background_images']['valid'],
                background_video_dir=DATA_PATHS['background_videos']['valid'],
                size=self.args.resolution_hr if self.args.train_hr else self.args.resolution_lr,
                seq_length=self.args.seq_length_hr if self.args.train_hr else self.args.seq_length_lr,
                seq_sampler=ValidFrameSampler(),
                transform=VideoMatteValidAugmentation(size_hr if self.args.train_hr else size_lr))
        else:
            self.dataset_lr_train = ImageMatteDataset(
                imagematte_dir=DATA_PATHS['imagematte']['train'],
                background_image_dir=DATA_PATHS['background_images']['train'],
                background_video_dir=DATA_PATHS['background_videos']['train'],
                size=self.args.resolution_lr,
                seq_length=self.args.seq_length_lr,
                seq_sampler=TrainFrameSampler(),
                transform=ImageMatteAugmentation(size_lr))
            if self.args.train_hr:
                self.dataset_hr_train = ImageMatteDataset(
                    imagematte_dir=DATA_PATHS['imagematte']['train'],
                    background_image_dir=DATA_PATHS['background_images']['train'],
                    background_video_dir=DATA_PATHS['background_videos']['train'],
                    size=self.args.resolution_hr,
                    seq_length=self.args.seq_length_hr,
                    seq_sampler=TrainFrameSampler(),
                    transform=ImageMatteAugmentation(size_hr))
            self.dataset_valid = ImageMatteDataset(
                imagematte_dir=DATA_PATHS['imagematte']['valid'],
                background_image_dir=DATA_PATHS['background_images']['valid'],
                background_video_dir=DATA_PATHS['background_videos']['valid'],
                size=self.args.resolution_hr if self.args.train_hr else self.args.resolution_lr,
                seq_length=self.args.seq_length_hr if self.args.train_hr else self.args.seq_length_lr,
                seq_sampler=ValidFrameSampler(),
                transform=ImageMatteAugmentation(size_hr if self.args.train_hr else size_lr))

        # Matting dataloaders:
        self.datasampler_lr_train = DistributedSampler(
            dataset=self.dataset_lr_train,
            rank=self.rank,
            num_replicas=self.world_size,
            shuffle=True)
        self.dataloader_lr_train = DataLoader(
            dataset=self.dataset_lr_train,
            batch_size=self.args.batch_size_per_gpu,
            num_workers=self.args.num_workers,
            sampler=self.datasampler_lr_train,
            pin_memory=True)
        if self.args.train_hr:
            self.datasampler_hr_train = DistributedSampler(
                dataset=self.dataset_hr_train,
                rank=self.rank,
                num_replicas=self.world_size,
                shuffle=True)
            self.dataloader_hr_train = DataLoader(
                dataset=self.dataset_hr_train,
                batch_size=self.args.batch_size_per_gpu,
                num_workers=self.args.num_workers,
                sampler=self.datasampler_hr_train,
                pin_memory=True)
        self.dataloader_valid = DataLoader(
            dataset=self.dataset_valid,
            batch_size=self.args.batch_size_per_gpu,
            num_workers=self.args.num_workers,
            pin_memory=True)

        # Segmentation datasets:
        self.log('Initializing image segmentation datasets')
        self.dataset_seg_image = ConcatDataset([
            CocoPanopticDataset(
                imgdir=DATA_PATHS['coco_panoptic']['imgdir'],
                anndir=DATA_PATHS['coco_panoptic']['anndir'],
                annfile=DATA_PATHS['coco_panoptic']['annfile'],
                transform=CocoPanopticTrainAugmentation(size_lr)),
            SuperviselyPersonDataset(
                imgdir=DATA_PATHS['spd']['imgdir'],
                segdir=DATA_PATHS['spd']['segdir'],
                transform=CocoPanopticTrainAugmentation(size_lr))
        ])
        self.datasampler_seg_image = DistributedSampler(
            dataset=self.dataset_seg_image,
            rank=self.rank,
            num_replicas=self.world_size,
            shuffle=True)
        self.dataloader_seg_image = DataLoader(
            dataset=self.dataset_seg_image,
            batch_size=self.args.batch_size_per_gpu * self.args.seq_length_lr,
            num_workers=self.args.num_workers,
            sampler=self.datasampler_seg_image,
            pin_memory=True)

        self.log('Initializing video segmentation datasets')
        self.dataset_seg_video = YouTubeVISDataset(
            videodir=DATA_PATHS['youtubevis']['videodir'],
            annfile=DATA_PATHS['youtubevis']['annfile'],
            size=self.args.resolution_lr,
            seq_length=self.args.seq_length_lr,
            seq_sampler=TrainFrameSampler(speed=[1]),
            transform=YouTubeVISAugmentation(size_lr))
        self.datasampler_seg_video = DistributedSampler(
            dataset=self.dataset_seg_video,
            rank=self.rank,
            num_replicas=self.world_size,
            shuffle=True)
        self.dataloader_seg_video = DataLoader(
            dataset=self.dataset_seg_video,
            batch_size=self.args.batch_size_per_gpu,
            num_workers=self.args.num_workers,
            sampler=self.datasampler_seg_video,
            pin_memory=True)
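
    # Note that the image segmentation loader batches
    # `batch_size_per_gpu * seq_length_lr` images: the training loop turns each
    # image into a one-frame clip via `unsqueeze(1)`, so a segmentation step
    # sees about as many frames as a low-resolution matting step.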

    def init_model(self):
        self.log('Initializing model')
        self.model = MattingNetwork(self.args.model_variant, pretrained_backbone=True).to(self.rank)

        if self.args.checkpoint:
            self.log(f'Restoring from checkpoint: {self.args.checkpoint}')
            self.log(self.model.load_state_dict(
                torch.load(self.args.checkpoint, map_location=f'cuda:{self.rank}')))

        self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
        self.model_ddp = DDP(self.model, device_ids=[self.rank], broadcast_buffers=False, find_unused_parameters=True)
        self.optimizer = Adam([
            {'params': self.model.backbone.parameters(), 'lr': self.args.learning_rate_backbone},
            {'params': self.model.aspp.parameters(), 'lr': self.args.learning_rate_aspp},
            {'params': self.model.decoder.parameters(), 'lr': self.args.learning_rate_decoder},
            {'params': self.model.project_mat.parameters(), 'lr': self.args.learning_rate_decoder},
            {'params': self.model.project_seg.parameters(), 'lr': self.args.learning_rate_decoder},
            {'params': self.model.refiner.parameters(), 'lr': self.args.learning_rate_refiner},
        ])
        self.scaler = GradScaler()
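
    # One optimizer parameter group per sub-module lets the staged schedule in
    # the docstring tune each part independently; e.g. `--learning-rate-refiner 0`
    # in stages 1-2 leaves the refiner effectively untrained until the
    # high-resolution stages.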

    def init_writer(self):
        if self.rank == 0:
            self.log('Initializing writer')
            self.writer = SummaryWriter(self.args.log_dir)

    def train(self):
        for epoch in range(self.args.epoch_start, self.args.epoch_end):
            self.epoch = epoch
            self.step = epoch * len(self.dataloader_lr_train)

            if not self.args.disable_validation:
                self.validate()

            self.log(f'Training epoch: {epoch}')
            for true_fgr, true_pha, true_bgr in tqdm(self.dataloader_lr_train, disable=self.args.disable_progress_bar, dynamic_ncols=True):
                # Low resolution pass
                self.train_mat(true_fgr, true_pha, true_bgr, downsample_ratio=1, tag='lr')

                # High resolution pass
                if self.args.train_hr:
                    true_fgr, true_pha, true_bgr = self.load_next_mat_hr_sample()
                    self.train_mat(true_fgr, true_pha, true_bgr, downsample_ratio=self.args.downsample_ratio, tag='hr')

                # Segmentation pass
                if self.step % 2 == 0:
                    true_img, true_seg = self.load_next_seg_video_sample()
                    self.train_seg(true_img, true_seg, log_label='seg_video')
                else:
                    true_img, true_seg = self.load_next_seg_image_sample()
                    self.train_seg(true_img.unsqueeze(1), true_seg.unsqueeze(1), log_label='seg_image')

                if self.step % self.args.checkpoint_save_interval == 0:
                    self.save()

                self.step += 1
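
    # Each training step therefore combines a low-resolution matting pass, an
    # optional high-resolution matting pass, and one segmentation pass that
    # alternates between video data (even steps) and image data (odd steps).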

    def train_mat(self, true_fgr, true_pha, true_bgr, downsample_ratio, tag):
        true_fgr = true_fgr.to(self.rank, non_blocking=True)
        true_pha = true_pha.to(self.rank, non_blocking=True)
        true_bgr = true_bgr.to(self.rank, non_blocking=True)
        true_fgr, true_pha, true_bgr = self.random_crop(true_fgr, true_pha, true_bgr)
        true_src = true_fgr * true_pha + true_bgr * (1 - true_pha)

        with autocast(enabled=not self.args.disable_mixed_precision):
            pred_fgr, pred_pha = self.model_ddp(true_src, downsample_ratio=downsample_ratio)[:2]
            loss = matting_loss(pred_fgr, pred_pha, true_fgr, true_pha)

        self.scaler.scale(loss['total']).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()

        if self.rank == 0 and self.step % self.args.log_train_loss_interval == 0:
            for loss_name, loss_value in loss.items():
                self.writer.add_scalar(f'train_{tag}_{loss_name}', loss_value, self.step)

        if self.rank == 0 and self.step % self.args.log_train_images_interval == 0:
            self.writer.add_image(f'train_{tag}_pred_fgr', make_grid(pred_fgr.flatten(0, 1), nrow=pred_fgr.size(1)), self.step)
            self.writer.add_image(f'train_{tag}_pred_pha', make_grid(pred_pha.flatten(0, 1), nrow=pred_pha.size(1)), self.step)
            self.writer.add_image(f'train_{tag}_true_fgr', make_grid(true_fgr.flatten(0, 1), nrow=true_fgr.size(1)), self.step)
            self.writer.add_image(f'train_{tag}_true_pha', make_grid(true_pha.flatten(0, 1), nrow=true_pha.size(1)), self.step)
            self.writer.add_image(f'train_{tag}_true_src', make_grid(true_src.flatten(0, 1), nrow=true_src.size(1)), self.step)
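
    # Ground-truth sources are composited on the fly using the standard matting
    # equation: src = fgr * pha + bgr * (1 - pha).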

    def train_seg(self, true_img, true_seg, log_label):
        true_img = true_img.to(self.rank, non_blocking=True)
        true_seg = true_seg.to(self.rank, non_blocking=True)
        true_img, true_seg = self.random_crop(true_img, true_seg)

        with autocast(enabled=not self.args.disable_mixed_precision):
            pred_seg = self.model_ddp(true_img, segmentation_pass=True)[0]
            loss = segmentation_loss(pred_seg, true_seg)

        self.scaler.scale(loss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()

        if self.rank == 0 and (self.step - self.step % 2) % self.args.log_train_loss_interval == 0:
            self.writer.add_scalar(f'{log_label}_loss', loss, self.step)

        if self.rank == 0 and (self.step - self.step % 2) % self.args.log_train_images_interval == 0:
            self.writer.add_image(f'{log_label}_pred_seg', make_grid(pred_seg.flatten(0, 1).float().sigmoid(), nrow=self.args.seq_length_lr), self.step)
            self.writer.add_image(f'{log_label}_true_seg', make_grid(true_seg.flatten(0, 1), nrow=self.args.seq_length_lr), self.step)
            self.writer.add_image(f'{log_label}_true_img', make_grid(true_img.flatten(0, 1), nrow=self.args.seq_length_lr), self.step)
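
    # Segmentation logging tests `self.step - self.step % 2` (the step rounded
    # down to even) so the video and image passes share the same logging cadence.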

    def load_next_mat_hr_sample(self):
        try:
            sample = next(self.dataiterator_mat_hr)
        except (AttributeError, StopIteration):
            # Iterator not yet created or exhausted: reshuffle and start a fresh pass.
            self.datasampler_hr_train.set_epoch(self.datasampler_hr_train.epoch + 1)
            self.dataiterator_mat_hr = iter(self.dataloader_hr_train)
            sample = next(self.dataiterator_mat_hr)
        return sample

    def load_next_seg_video_sample(self):
        try:
            sample = next(self.dataiterator_seg_video)
        except (AttributeError, StopIteration):
            self.datasampler_seg_video.set_epoch(self.datasampler_seg_video.epoch + 1)
            self.dataiterator_seg_video = iter(self.dataloader_seg_video)
            sample = next(self.dataiterator_seg_video)
        return sample

    def load_next_seg_image_sample(self):
        try:
            sample = next(self.dataiterator_seg_image)
        except (AttributeError, StopIteration):
            self.datasampler_seg_image.set_epoch(self.datasampler_seg_image.epoch + 1)
            self.dataiterator_seg_image = iter(self.dataloader_seg_image)
            sample = next(self.dataiterator_seg_image)
        return sample

    def validate(self):
        if self.rank == 0:
            self.log(f'Validating at the start of epoch: {self.epoch}')
            self.model_ddp.eval()
            total_loss, total_count = 0, 0
            with torch.no_grad():
                with autocast(enabled=not self.args.disable_mixed_precision):
                    for true_fgr, true_pha, true_bgr in tqdm(self.dataloader_valid, disable=self.args.disable_progress_bar, dynamic_ncols=True):
                        true_fgr = true_fgr.to(self.rank, non_blocking=True)
                        true_pha = true_pha.to(self.rank, non_blocking=True)
                        true_bgr = true_bgr.to(self.rank, non_blocking=True)
                        true_src = true_fgr * true_pha + true_bgr * (1 - true_pha)
                        batch_size = true_src.size(0)
                        pred_fgr, pred_pha = self.model(true_src)[:2]
                        total_loss += matting_loss(pred_fgr, pred_pha, true_fgr, true_pha)['total'].item() * batch_size
                        total_count += batch_size
            avg_loss = total_loss / total_count
            self.log(f'Validation set average loss: {avg_loss}')
            self.writer.add_scalar('valid_loss', avg_loss, self.step)
            self.model_ddp.train()
        dist.barrier()
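
    # Validation runs only on rank 0; the validation loader is not sharded with
    # a DistributedSampler, and the other ranks simply wait at the barrier.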

    def random_crop(self, *imgs):
        h, w = imgs[0].shape[-2:]
        w = random.choice(range(w // 2, w))
        h = random.choice(range(h // 2, h))
        results = []
        for img in imgs:
            B, T = img.shape[:2]
            img = img.flatten(0, 1)
            img = F.interpolate(img, (max(h, w), max(h, w)), mode='bilinear', align_corners=False)
            img = center_crop(img, (h, w))
            img = img.reshape(B, T, *img.shape[1:])
            results.append(img)
        return results
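
    # `random_crop` draws one target size between half and full resolution per
    # call, resizes each clip to a square that covers it, then center-crops, so
    # all frames (and all tensors passed together) share the identical crop.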

    def save(self):
        if self.rank == 0:
            os.makedirs(self.args.checkpoint_dir, exist_ok=True)
            torch.save(self.model.state_dict(), os.path.join(self.args.checkpoint_dir, f'epoch-{self.epoch}.pth'))
            self.log('Model saved')
        dist.barrier()

    def cleanup(self):
        dist.destroy_process_group()

    def log(self, msg):
        print(f'[GPU{self.rank}] {msg}')

if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    mp.spawn(
        Trainer,
        nprocs=world_size,
        args=(world_size,),
        join=True)
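
# Example launch (a sketch: CUDA_VISIBLE_DEVICES is a standard CUDA environment
# variable, not a flag defined by this script). This would run Stage 1 on two
# GPUs with the flags from the docstring above:
#
#   CUDA_VISIBLE_DEVICES=0,1 python train.py \
#       --model-variant mobilenetv3 \
#       --dataset videomatte \
#       --resolution-lr 512 \
#       --seq-length-lr 15 \
#       --learning-rate-backbone 0.0001 \
#       --learning-rate-aspp 0.0002 \
#       --learning-rate-decoder 0.0002 \
#       --learning-rate-refiner 0 \
#       --checkpoint-dir checkpoint/stage1 \
#       --log-dir log/stage1 \
#       --epoch-start 0 \
#       --epoch-end 20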