123456789101112131415161718192021222324252627 |
- import os
- from torch.utils.data import Dataset
- from PIL import Image
class SuperviselyPersonDataset(Dataset):
    """Paired image / segmentation-mask dataset read from two directories.

    Images and masks are paired purely by position: each directory listing
    is sorted by filename and the i-th image is matched with the i-th mask.
    Every entry returned by ``os.listdir`` is used (nothing is filtered, so
    hidden files or subdirectories would be picked up too). Corresponding
    files should therefore sort into the same order, e.g. by sharing a base
    name.

    Args:
        imgdir: Directory containing the input images.
        segdir: Directory containing the segmentation masks.
        transform: Optional callable invoked as ``transform(img, seg)``. It
            must return an ``(img, seg)`` pair. Because a single callable
            receives both, spatial augmentations can be applied to the image
            and the mask consistently.
        img_mode: PIL mode the image is converted to (default ``'RGB'``).
        seg_mode: PIL mode the mask is converted to (default ``'L'``).

    Raises:
        ValueError: If the two directories contain a different number of
            entries.
    """

    def __init__(self, imgdir: str, segdir: str, transform=None, *,
                 img_mode: str = 'RGB', seg_mode: str = 'L'):
        self.img_dir = imgdir
        self.img_files = sorted(os.listdir(imgdir))
        self.seg_dir = segdir
        self.seg_files = sorted(os.listdir(segdir))
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would let misaligned pairs through silently.
        if len(self.img_files) != len(self.seg_files):
            raise ValueError(
                f"image/mask count mismatch: {len(self.img_files)} files in "
                f"{imgdir!r} but {len(self.seg_files)} files in {segdir!r}"
            )
        self.transform = transform
        self.img_mode = img_mode
        self.seg_mode = seg_mode

    def __len__(self) -> int:
        """Return the number of image/mask pairs."""
        return len(self.img_files)

    def __getitem__(self, idx: int):
        """Load, convert and (optionally) transform the ``idx``-th pair.

        Returns:
            ``(img, seg)``: PIL images in ``img_mode`` / ``seg_mode`` when no
            transform is set, otherwise whatever ``transform`` returns.
        """
        img_path = os.path.join(self.img_dir, self.img_files[idx])
        seg_path = os.path.join(self.seg_dir, self.seg_files[idx])
        with Image.open(img_path) as img_src, Image.open(seg_path) as seg_src:
            # convert() loads the pixel data and returns new, independent
            # images, so the source files can be closed when this block exits.
            img = img_src.convert(self.img_mode)
            seg = seg_src.convert(self.seg_mode)

        if self.transform is not None:
            img, seg = self.transform(img, seg)

        return img, seg
|