inference_utils.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. import av
  2. import os
  3. import pims
  4. import numpy as np
  5. from torch.utils.data import Dataset
  6. from torchvision.transforms.functional import to_pil_image
  7. from PIL import Image
  8. class VideoReader(Dataset):
  9. def __init__(self, path, transform=None):
  10. self.video = pims.PyAVVideoReader(path)
  11. self.rate = self.video.frame_rate
  12. self.transform = transform
  13. @property
  14. def frame_rate(self):
  15. return self.rate
  16. def __len__(self):
  17. return len(self.video)
  18. def __getitem__(self, idx):
  19. frame = self.video[idx]
  20. frame = Image.fromarray(np.asarray(frame))
  21. if self.transform is not None:
  22. frame = self.transform(frame)
  23. return frame
  24. class VideoWriter:
  25. def __init__(self, path, frame_rate, bit_rate=1000000):
  26. self.container = av.open(path, mode='w')
  27. self.stream = self.container.add_stream('h264', rate=f'{frame_rate:.4f}')
  28. self.stream.pix_fmt = 'yuv420p'
  29. self.stream.bit_rate = bit_rate
  30. def write(self, frames):
  31. # frames: [T, C, H, W]
  32. self.stream.width = frames.size(3)
  33. self.stream.height = frames.size(2)
  34. if frames.size(1) == 1:
  35. frames = frames.repeat(1, 3, 1, 1) # convert grayscale to RGB
  36. frames = frames.mul(255).byte().cpu().permute(0, 2, 3, 1).numpy()
  37. for t in range(frames.shape[0]):
  38. frame = frames[t]
  39. frame = av.VideoFrame.from_ndarray(frame, format='rgb24')
  40. self.container.mux(self.stream.encode(frame))
  41. def close(self):
  42. self.container.mux(self.stream.encode())
  43. self.container.close()
  44. class ImageSequenceReader(Dataset):
  45. def __init__(self, path, transform=None):
  46. self.path = path
  47. self.files = sorted(os.listdir(path))
  48. self.transform = transform
  49. def __len__(self):
  50. return len(self.files)
  51. def __getitem__(self, idx):
  52. with Image.open(os.path.join(self.path, self.files[idx])) as img:
  53. img.load()
  54. if self.transform is not None:
  55. return self.transform(img)
  56. return img
  57. class ImageSequenceWriter:
  58. def __init__(self, path, extension='jpg'):
  59. self.path = path
  60. self.extension = extension
  61. self.counter = 0
  62. os.makedirs(path, exist_ok=True)
  63. def write(self, frames):
  64. # frames: [T, C, H, W]
  65. for t in range(frames.shape[0]):
  66. to_pil_image(frames[t]).save(os.path.join(
  67. self.path, str(self.counter).zfill(4) + '.' + self.extension))
  68. self.counter += 1
  69. def close(self):
  70. pass