import os

from torch.utils.data import Dataset
from PIL import Image


class SuperviselyPersonDataset(Dataset):
    """Paired image / segmentation-mask dataset read from two directories.

    Samples are matched purely by position: the i-th entry of the sorted
    listing of ``imgdir`` is paired with the i-th entry of the sorted
    listing of ``segdir``. File names themselves are not compared, so the
    two directories must sort into the same order for pairs to line up.
    """

    def __init__(self, imgdir, segdir, transform=None):
        """Index both directories and remember the optional transform.

        Args:
            imgdir: Directory containing the input images.
            segdir: Directory containing the segmentation masks.
            transform: Optional callable invoked as ``transform(img, seg)``
                and expected to return an ``(img, seg)`` pair.

        Raises:
            AssertionError: If the two directories do not contain the same
                number of entries.
        """
        self.img_dir = imgdir
        self.img_files = sorted(os.listdir(imgdir))
        self.seg_dir = segdir
        self.seg_files = sorted(os.listdir(segdir))
        # Raised explicitly (instead of via an ``assert`` statement) so the
        # check is not stripped under ``python -O``. The exception type is
        # kept as AssertionError so existing callers see the same error.
        if len(self.img_files) != len(self.seg_files):
            raise AssertionError(
                f"image/mask count mismatch: {len(self.img_files)} entries "
                f"in {imgdir!r} vs {len(self.seg_files)} in {segdir!r}"
            )
        self.transform = transform

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.img_files)

    def __getitem__(self, idx):
        """Load the pair at position ``idx``.

        The image is converted to mode ``'RGB'`` and the mask to mode
        ``'L'``. If a transform was supplied it is applied to both jointly.

        Returns:
            A tuple ``(img, seg)``: the converted PIL images, or whatever
            the transform returned for them.
        """
        img_path = os.path.join(self.img_dir, self.img_files[idx])
        seg_path = os.path.join(self.seg_dir, self.seg_files[idx])

        # ``convert`` yields new image objects, so the results stay usable
        # after the underlying files are closed by the ``with`` block.
        with Image.open(img_path) as img, Image.open(seg_path) as seg:
            img = img.convert('RGB')
            seg = seg.convert('L')

        if self.transform is not None:
            img, seg = self.transform(img, seg)

        return img, seg