import random
from typing import Sequence, Tuple, Union

import torchvision.transforms.functional as F
from PIL import Image
from torch import Tensor
from torchvision import transforms
from torchvision.transforms.transforms import _setup_size

_PIL_INTERPOLATION = {
    "bilinear": Image.BILINEAR,
    "bicubic": Image.BICUBIC,
    "lanczos": Image.LANCZOS,
    "hamming": Image.HAMMING,
}


def get_interpolation(method: str) -> int:
    """Return the PIL resampling constant for `method`, falling back to bilinear."""
    return _PIL_INTERPOLATION.get(method, Image.BILINEAR)


class RandomResizedCropAndInterpolationWithTwoPic(transforms.RandomResizedCrop):
    """Crop one region and resize it twice, so the visual-encoder view and the
    VQ-VAE view share the same crop (scale and aspect ratio) while using their
    own output sizes and interpolation methods."""

    def __init__(
        self,
        size: Union[int, Tuple[int, int]],  # output size for the visual encoder
        second_size: Union[int, Tuple[int, int]],  # output size for the vqvae
        scale: Tuple[float, float] = (0.08, 1.0),
        ratio: Tuple[float, float] = (3.0 / 4.0, 4.0 / 3.0),
        interpolation: str = "bilinear",
        second_interpolation: str = "lanczos",
    ):
        self.second_size = _setup_size(
            second_size,
            error_msg="Please provide only two dimensions (h, w) for second size.",
        )
        if interpolation == "random":
            # Note: the random choice is made once at construction, not per image.
            interpolation = random.choice(
                [get_interpolation("bilinear"), get_interpolation("bicubic")]
            )
        else:
            interpolation = get_interpolation(interpolation)
        self.second_interpolation = get_interpolation(second_interpolation)
        super().__init__(
            size=size, scale=scale, ratio=ratio, interpolation=interpolation
        )

    def forward(self, img: Image.Image) -> Tuple[Image.Image, Image.Image]:
        # Sample a single crop window, then produce both views from it.
        i, j, h, w = self.get_params(img, self.scale, self.ratio)
        out = F.resized_crop(img, i, j, h, w, self.size, self.interpolation)
        out_second = F.resized_crop(
            img, i, j, h, w, self.second_size, self.second_interpolation
        )
        return out, out_second


class AugmentationForMIM:
    """Produce the (visual-encoder, vqvae) image pair used for masked image modeling."""

    def __init__(
        self,
        mean: Sequence[float],
        std: Sequence[float],
        trans_size: Union[int, Tuple[int, int]],
        vqvae_size: Union[int, Tuple[int, int]],
        trans_interpolation: str,
        vqvae_interpolation: str,
    ) -> None:
        # Shared augmentations applied before the paired crop/resize.
        self.common_transform = transforms.Compose(
            [
                transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
                transforms.RandomHorizontalFlip(p=0.5),
                RandomResizedCropAndInterpolationWithTwoPic(
                    size=trans_size,
                    second_size=vqvae_size,
                    interpolation=trans_interpolation,
                    second_interpolation=vqvae_interpolation,
                ),
            ]
        )
        # The visual-encoder view is normalized; the vqvae view is only converted to a tensor.
        self.trans_transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]
        )
        self.vqvae_transform = transforms.ToTensor()

    def __call__(self, img: Image.Image) -> Tuple[Tensor, Tensor]:
        trans_img, vqvae_img = self.common_transform(img)
        trans_img = self.trans_transform(trans_img)
        vqvae_img = self.vqvae_transform(vqvae_img)
        return trans_img, vqvae_img


if __name__ == "__main__":
    mean = [240.380, 240.390, 240.486]
    std = [45.735, 45.785, 45.756]
    T = RandomResizedCropAndInterpolationWithTwoPic(
        size=(256, 256),
        second_size=(256, 256),
        interpolation="bicubic",
        second_interpolation="lanczos",
    )
    print(T)
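
    # Minimal usage sketch of the full MIM pipeline (not part of the original
    # script): it reuses the mean/std above and assumes 256x256 outputs, the same
    # interpolation choices as the demo transform, and a torchvision version that
    # still accepts PIL integer interpolation constants. Adjust these to your own
    # encoder/tokenizer configuration.
    aug = AugmentationForMIM(
        mean=mean,
        std=std,
        trans_size=(256, 256),
        vqvae_size=(256, 256),
        trans_interpolation="bicubic",
        vqvae_interpolation="lanczos",
    )
    dummy = Image.new("RGB", (512, 512))  # placeholder image for a smoke test
    trans_img, vqvae_img = aug(dummy)
    print(trans_img.shape, vqvae_img.shape)  # expected: two 3x256x256 tensors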