import os
import re

import av
import numpy as np
import pims
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms.functional import to_pil_image


class VideoReader(Dataset):
    """Random-access dataset over the frames of a video file.

    Frames are decoded with ``pims.PyAVVideoReader`` and returned as PIL
    images, optionally passed through ``transform``.

    Args:
        path: Path of the video file to read.
        transform: Optional callable applied to each PIL frame.
    """

    def __init__(self, path, transform=None):
        self.video = pims.PyAVVideoReader(path)
        self.rate = self.video.frame_rate
        self.transform = transform

    @property
    def frame_rate(self):
        """Frame rate reported by the underlying pims reader."""
        return self.rate

    def __len__(self):
        return len(self.video)

    def __getitem__(self, idx):
        frame = self.video[idx]
        # np.asarray turns whatever pims returns into a plain ndarray for PIL.
        frame = Image.fromarray(np.asarray(frame))
        if self.transform is not None:
            frame = self.transform(frame)
        return frame


class VideoWriter:
    """Encode batches of frame tensors into an H.264 (yuv420p) video.

    Can be used as a context manager; ``close`` is called on exit.

    Args:
        path: Output file path passed to ``av.open``.
        frame_rate: Frame rate of the output; it is rounded to the nearest
            integer before being passed to ``add_stream``.
        bit_rate: Target bit rate of the H.264 stream.
    """

    def __init__(self, path, frame_rate, bit_rate=1000000):
        self.container = av.open(path, mode='w')
        self.stream = self.container.add_stream('h264', rate=round(frame_rate))
        self.stream.pix_fmt = 'yuv420p'
        self.stream.bit_rate = bit_rate

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # never swallow exceptions

    def write(self, frames):
        """Encode and mux a batch of frames.

        Args:
            frames: Float tensor shaped ``[T, C, H, W]`` with values expected
                in ``[0, 1]`` (out-of-range values are clamped). A single
                channel is replicated to RGB; otherwise C is presumably 3.
        """
        self.stream.width = frames.size(3)
        self.stream.height = frames.size(2)
        if frames.size(1) == 1:
            frames = frames.repeat(1, 3, 1, 1)  # convert grayscale to RGB
        # Clamp before the uint8 cast: .byte() does not saturate, so values
        # slightly outside [0, 1] (e.g. small negative model outputs) would
        # otherwise wrap around instead of clipping to 0/255.
        frames = frames.clamp(0, 1).mul(255).byte().cpu().permute(0, 2, 3, 1).numpy()
        for frame in frames:
            frame = av.VideoFrame.from_ndarray(frame, format='rgb24')
            self.container.mux(self.stream.encode(frame))

    def close(self):
        """Flush packets still buffered in the encoder and close the file."""
        self.container.mux(self.stream.encode())
        self.container.close()


def _natural_sort_key(name):
    """Sort key that orders digit runs numerically ('2.png' before '10.png')."""
    parts = re.split(r'(\d+)', name)
    # With a capturing group, re.split alternates text / digit-run, so odd
    # indices are always digit runs and element types line up across names.
    # The raw name breaks ties such as '01.png' vs '1.png' deterministically.
    return [int(part) if i % 2 else part for i, part in enumerate(parts)], name


class ImageSequenceReader(Dataset):
    """Dataset over the image files of a directory, in natural sort order.

    Natural ordering keeps frame sequences in numeric order even when names
    are unpadded or exceed their padding width (e.g. ``9999.jpg`` followed by
    ``10000.jpg`` as written by ``ImageSequenceWriter``). For equal-width
    zero-padded names it is identical to plain lexicographic order.

    Args:
        path: Directory whose entries are all read as images.
        transform: Optional callable applied to each loaded PIL image.
    """

    def __init__(self, path, transform=None):
        self.path = path
        self.files = sorted(os.listdir(path), key=_natural_sort_key)
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        with Image.open(os.path.join(self.path, self.files[idx])) as img:
            # PIL decodes lazily; force the read so the image stays usable
            # after the file handle is closed by the context manager.
            img.load()
        if self.transform is not None:
            return self.transform(img)
        return img


class ImageSequenceWriter:
    """Write frames as numbered image files (``0000.<ext>``, ``0001.<ext>``, ...).

    Can be used as a context manager for symmetry with ``VideoWriter``.

    Args:
        path: Output directory; created if it does not exist.
        extension: File extension, which also selects the PIL save format.
    """

    def __init__(self, path, extension='jpg'):
        self.path = path
        self.extension = extension
        self.counter = 0
        os.makedirs(path, exist_ok=True)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # never swallow exceptions

    def write(self, frames):
        """Save each frame of a ``[T, C, H, W]`` batch as its own image file."""
        for frame in frames:
            # to_pil_image scales float tensors with an unsaturated
            # mul(255).byte(); clamp so out-of-range values clip, not wrap.
            if torch.is_tensor(frame) and frame.is_floating_point():
                frame = frame.clamp(0, 1)
            filename = f'{self.counter:04d}.{self.extension}'
            to_pil_image(frame).save(os.path.join(self.path, filename))
            self.counter += 1

    def close(self):
        """No-op: every file is fully written by ``write``."""
        pass