File size: 2,415 Bytes
223d932
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# ------------------------------------------------------------------------------
# OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
# Copyright (c) 2024 Borui Zhang. All Rights Reserved.
# Licensed under the MIT License [see LICENSE for details]
# ------------------------------------------------------------------------------

from typing import Optional
import numpy as np

from torchvision import transforms
import albumentations as A

BICUBIC = transforms.InterpolationMode.BICUBIC

# Per-channel normalization statistics, keyed by preset name.
# A torchvision ``Normalize(mean, std)`` maps x -> (x - mean) / std.
#   "plain": maps inputs in [0, 1] to [-1, 1] (mean = std = 0.5, broadcast over channels).
#   "cnn":   the commonly used ImageNet RGB channel statistics.
#   "clip":  these values appear to be the CLIP preprocessing statistics (confirm).
normalize_params = {
    "plain": {"mean": (0.5,), "std": (0.5,)},
    "cnn": {"mean": (0.485, 0.456, 0.406), "std": (0.229, 0.224, 0.225)},
    "clip": {"mean": (0.48145466, 0.4578275, 0.40821073), "std": (0.26862954, 0.26130258, 0.27577711)}
}


def _make_recover(mean, std):
    """Return a ``Normalize`` that undoes ``Normalize(mean, std)``.

    ``Normalize(-m / s, 1 / s)`` computes (x + m / s) * s = x * s + m,
    which is exactly the inverse of (x - m) / s.
    """
    return transforms.Normalize(
        mean=tuple(-m / s for m, s in zip(mean, std)),
        std=tuple(1 / s for s in std),
    )


# Inverse transforms, derived from ``normalize_params`` so the two tables can
# never drift apart. The keys are the same as in ``normalize_params``.
recover_map_dict = {
    name: _make_recover(**params) for name, params in normalize_params.items()
}

def get_recover_map(name: str):
    """Return the transform that undoes the normalization preset ``name``.

    Args:
        name: a key of ``recover_map_dict`` (e.g. "plain", "cnn", "clip").

    Returns:
        The transform stored in ``recover_map_dict`` under ``name``.

    Raises:
        KeyError: if ``name`` is not a registered preset. The message lists the
            valid choices.
    """
    try:
        return recover_map_dict[name]
    except KeyError:
        # Keep the exception type so existing ``except KeyError`` callers still
        # work; only the message is made actionable.
        raise KeyError(
            f"unknown normalization {name!r}; expected one of {sorted(recover_map_dict)}"
        ) from None

###########################################
# Preprocessor
###########################################

def plain_preprocessor(resize: Optional[int] = 32):
    """Build a simple torchvision pipeline: to-tensor, normalize to [-1, 1], resize.

    Args:
        resize: size handed to ``transforms.Resize``. ``None`` skips resizing
            and keeps the input resolution.

    Returns:
        A ``transforms.Compose`` applying ``ToTensor``, then
        ``Normalize((0.5,), (0.5,))``, then (optionally) ``Resize(resize)``.
    """
    steps = [
        transforms.ToTensor(),
        # (x - 0.5) / 0.5: maps ToTensor output from [0, 1] to [-1, 1].
        transforms.Normalize((0.5,), (0.5,)),
    ]
    # ``resize`` is Optional, but ``transforms.Resize(None)`` fails at
    # construction, so a None size means "do not resize".
    if resize is not None:
        steps.append(transforms.Resize(resize))
    return transforms.Compose(steps)

def imagenet_preprocessor(resize: Optional[int] = 256, is_train: bool = True):
    """Build an albumentations pipeline producing ``resize`` x ``resize`` crops.

    Both modes first rescale the image so its shorter side equals ``resize``.

    Training mode (``is_train=True``) then applies, in order:
        - RandomResizedCrop to ``resize`` x ``resize`` with scale in (0.2, 1.0);
        - ColorJitter (brightness/contrast/saturation 0.4, hue 0.1) with p=0.8;
        - GaussianBlur (blur_limit=7) with p=0.5;
        - HorizontalFlip with p=0.5.

    Evaluation mode instead takes a deterministic center crop of
    ``resize`` x ``resize``.

    Args:
        resize: target side length. NOTE(review): it is used directly as a crop
            height/width, so it presumably must be an int despite the Optional hint.
        is_train: choose the augmenting training pipeline or the eval pipeline.

    Returns:
        An albumentations ``Compose`` (not a torchvision one). It performs no
        tensor conversion or normalization.
    """
    pipeline = [A.SmallestMaxSize(max_size=resize)]

    if not is_train:
        pipeline.append(A.CenterCrop(height=resize, width=resize))
        return A.Compose(pipeline)

    # Augmentation "v2". An earlier "v1" training pipeline used only
    # SmallestMaxSize -> RandomCrop(resize, resize) -> HorizontalFlip(p=0.5).
    pipeline.extend([
        A.RandomResizedCrop(width=resize, height=resize, scale=(0.2, 1.0)),
        A.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=0.8),
        A.GaussianBlur(blur_limit=7, p=0.5),
        A.HorizontalFlip(p=0.5),
    ])
    return A.Compose(pipeline)