import torch
import random
import numbers
import numpy as np
from torchvision.transforms import RandomCrop, RandomResizedCrop
from PIL import Image


def _is_tensor_video_clip(clip):
    if not torch.is_tensor(clip):
        raise TypeError("clip should be Tensor. Got %s" % type(clip))
    if not clip.ndimension() == 4:
        raise ValueError("clip should be 4D. Got %dD" % clip.dim())
    return True


def center_crop_arr(pil_image, image_size):
    """
    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
    """
    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )

    scale = image_size / min(*pil_image.size)
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
    )

    arr = np.array(pil_image)
    crop_y = (arr.shape[0] - image_size) // 2
    crop_x = (arr.shape[1] - image_size) // 2
    return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])


def crop(clip, i, j, h, w):
    """
    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
    """
    if len(clip.size()) != 4:
        raise ValueError("clip should be a 4D tensor")
    return clip[..., i : i + h, j : j + w]


def resize(clip, target_size, interpolation_mode):
    if len(target_size) != 2:
        raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
    return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)


def resize_scale(clip, target_size, interpolation_mode):
    if len(target_size) != 2:
        raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
    H, W = clip.size(-2), clip.size(-1)
    scale_ = target_size[0] / min(H, W)
    return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)


def resize_with_scale_factor(clip, scale_factor, interpolation_mode):
    return torch.nn.functional.interpolate(clip, scale_factor=scale_factor, mode=interpolation_mode, align_corners=False)


def resize_scale_with_height(clip, target_size, interpolation_mode):
    # Scale both spatial dims so the height matches target_size.
    H, W = clip.size(-2), clip.size(-1)
    scale_ = target_size / H
    return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)


def resize_scale_with_weight(clip, target_size, interpolation_mode):
    # Scale both spatial dims so the width matches target_size
    # (the function name is kept as-is; "weight" here refers to the width).
    H, W = clip.size(-2), clip.size(-1)
    scale_ = target_size / W
    return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)


def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
    """
    Do spatial cropping and resizing on the video clip.
    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        i (int): i in (i, j), i.e. row coordinate of the upper left corner.
        j (int): j in (i, j), i.e. column coordinate of the upper left corner.
        h (int): Height of the cropped region.
        w (int): Width of the cropped region.
        size (tuple(int, int)): height and width of resized clip
    Returns:
        clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    clip = crop(clip, i, j, h, w)
    clip = resize(clip, size, interpolation_mode)
    return clip


def center_crop(clip, crop_size):
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    h, w = clip.size(-2), clip.size(-1)
    th, tw = crop_size
    if h < th or w < tw:
        raise ValueError(f"height {h} and width {w} must be no smaller than crop_size {crop_size}")

    i = int(round((h - th) / 2.0))
    j = int(round((w - tw) / 2.0))
    return crop(clip, i, j, th, tw)


def center_crop_using_short_edge(clip):
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    h, w = clip.size(-2), clip.size(-1)
    if h < w:
        th, tw = h, h
        i = 0
        j = int(round((w - tw) / 2.0))
    else:
        th, tw = w, w
        i = int(round((h - th) / 2.0))
        j = 0
    return crop(clip, i, j, th, tw)


def random_shift_crop(clip):
    '''
    Slide along the long edge, with the short edge as crop size
    '''
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    h, w = clip.size(-2), clip.size(-1)
    if h <= w:
        long_edge = w
        short_edge = h
    else:
        long_edge = h
        short_edge = w

    th, tw = short_edge, short_edge
    i = torch.randint(0, h - th + 1, size=(1,)).item()
    j = torch.randint(0, w - tw + 1, size=(1,)).item()
    return crop(clip, i, j, th, tw)


def to_tensor(clip):
    """
    Convert tensor data type from uint8 to float and divide the values by 255.0.
    Args:
        clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
    Return:
        clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
    """
    _is_tensor_video_clip(clip)
    if not clip.dtype == torch.uint8:
        raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
    return clip.float() / 255.0


def normalize(clip, mean, std, inplace=False):
    """
    Args:
        clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
        mean (tuple): pixel RGB mean. Size is (3)
        std (tuple): pixel standard deviation. Size is (3)
    Returns:
        normalized clip (torch.tensor): Size is (C, T, H, W)
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    if not inplace:
        clip = clip.clone()
    mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
    std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
    clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
    return clip


def hflip(clip):
    """
    Args:
        clip (torch.tensor): Video clip to be flipped. Size is (T, C, H, W)
    Returns:
        flipped clip (torch.tensor): Size is (T, C, H, W)
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    return clip.flip(-1)


class RandomCropVideo:
    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: randomly cropped video clip.
                size is (T, C, OH, OW)
        """
        i, j, h, w = self.get_params(clip)
        return crop(clip, i, j, h, w)

    def get_params(self, clip):
        h, w = clip.shape[-2:]
        th, tw = self.size

        if h < th or w < tw:
            raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")

        if w == tw and h == th:
            return 0, 0, h, w

        i = torch.randint(0, h - th + 1, size=(1,)).item()
        j = torch.randint(0, w - tw + 1, size=(1,)).item()

        return i, j, th, tw

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size})"


class CenterCropResizeVideo:
    '''
    First use the short side as the crop length to center crop the video,
    then resize to the specified size
    '''

    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = size
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: scale resized / center cropped video clip.
                size is (T, C, crop_size, crop_size)
        """
        clip_center_crop = center_crop_using_short_edge(clip)
        clip_center_crop_resize = resize(clip_center_crop, target_size=self.size, interpolation_mode=self.interpolation_mode)
        return clip_center_crop_resize

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"


class WebVideo320512:
    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = size
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: scale resized / center cropped video clip.
                size is (T, C, crop_size, crop_size)
        """
        # Add one additional pixel to avoid errors in center crop.
        h, w = clip.size(-2), clip.size(-1)
        if h < 320:
            clip = resize_scale_with_height(clip=clip, target_size=321, interpolation_mode=self.interpolation_mode)
        if w < 512:
            clip = resize_scale_with_weight(clip=clip, target_size=513, interpolation_mode=self.interpolation_mode)
        clip_center_crop = center_crop(clip, self.size)
        return clip_center_crop

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"


class UCFCenterCropVideo:
    '''
    First scale the short edge proportionally to the specified size,
    then center crop
    '''

    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = size
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: scale resized / center cropped video clip.
                size is (T, C, crop_size, crop_size)
        """
        clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
        clip_center_crop = center_crop(clip_resize, self.size)
        return clip_center_crop

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"


class CenterCropVideo:
    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = size
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: center cropped video clip.
                size is (T, C, crop_size, crop_size)
        """
        clip_center_crop = center_crop(clip, self.size)
        return clip_center_crop

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"


class NormalizeVideo:
    """
    Normalize the video clip by mean subtraction and division by standard deviation
    Args:
        mean (3-tuple): pixel RGB mean
        std (3-tuple): pixel RGB standard deviation
        inplace (boolean): whether to do in-place normalization
    """

    def __init__(self, mean, std, inplace=False):
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): video clip to be normalized. Size is (C, T, H, W)
        """
        return normalize(clip, self.mean, self.std, self.inplace)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"


class ToTensorVideo:
    """
    Convert tensor data type from uint8 to float and divide the values by 255.0
    """

    def __init__(self):
        pass

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
        Return:
            clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
        """
        return to_tensor(clip)

    def __repr__(self) -> str:
        return self.__class__.__name__


class ResizeVideo:
    '''
    Resize the video clip to the specified size
    '''

    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = size
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be resized. Size is (T, C, H, W)
        Returns:
            torch.tensor: resized video clip.
                size is (T, C, size, size)
        """
        clip_resize = resize(clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
        return clip_resize

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"


# ------------------------------------------------------------
# ---------------------  Sampling  ---------------------------
# ------------------------------------------------------------
class TemporalRandomCrop(object):
    """Temporally crop the given frame indices at a random location.
    Args:
        size (int): Desired number of frames to be seen by the model.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, total_frames):
        rand_end = max(0, total_frames - self.size - 1)
        begin_index = random.randint(0, rand_end)
        end_index = min(begin_index + self.size, total_frames)
        return begin_index, end_index
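

# ------------------------------------------------------------
# ---------------------  Usage sketch  -----------------------
# ------------------------------------------------------------
# Minimal, illustrative example of how the transforms above can be
# combined. It is not part of the original module: the 16-frame window,
# the 256x256 target size, and the dummy tensors are assumptions made
# purely for demonstration.
if __name__ == "__main__":
    from torchvision import transforms

    # Sample a 16-frame temporal window from a 100-frame video.
    temporal_sample = TemporalRandomCrop(size=16)
    begin, end = temporal_sample(total_frames=100)

    # Dummy decoded frames for that window: (T, C, H, W), uint8 in [0, 255].
    clip = torch.randint(0, 256, (end - begin, 3, 360, 640), dtype=torch.uint8)

    # Spatial pipeline: uint8 -> float in [0, 1], short-edge center crop,
    # then resize to 256x256.
    spatial_transform = transforms.Compose([
        ToTensorVideo(),
        CenterCropResizeVideo(256),
    ])
    clip = spatial_transform(clip)
    print(clip.shape)  # torch.Size([16, 3, 256, 256])

    # center_crop_arr operates on single PIL images (ADM-style preprocessing).
    image = Image.fromarray(np.zeros((480, 640, 3), dtype=np.uint8))
    print(center_crop_arr(image, 256).size)  # (256, 256)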