# spd.py
import os

from PIL import Image
from torch.utils.data import Dataset
  4. class SuperviselyPersonDataset(Dataset):
  5. def __init__(self, imgdir, segdir, transform=None):
  6. self.img_dir = imgdir
  7. self.img_files = sorted(os.listdir(imgdir))
  8. self.seg_dir = segdir
  9. self.seg_files = sorted(os.listdir(segdir))
  10. assert len(self.img_files) == len(self.seg_files)
  11. self.transform = transform
  12. def __len__(self):
  13. return len(self.img_files)
  14. def __getitem__(self, idx):
  15. with Image.open(os.path.join(self.img_dir, self.img_files[idx])) as img, \
  16. Image.open(os.path.join(self.seg_dir, self.seg_files[idx])) as seg:
  17. img = img.convert('RGB')
  18. seg = seg.convert('L')
  19. if self.transform is not None:
  20. img, seg = self.transform(img, seg)
  21. return img, seg