File size: 370 Bytes
daf0288
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
from typing import Sequence, Union

import numpy as np
from torchvision import transforms


def normalize_image_for_visualization(
    mean: Union[float, Sequence[float]], std: Union[float, Sequence[float]]
):
    """Build a transform that undoes ``transforms.Normalize(mean, std)``.

    ``transforms.Normalize`` maps each channel as ``(x - mean) / std``. The
    transform returned here applies the inverse, ``x * std + mean``, so that a
    normalized image tensor can be mapped back to its original value range for
    display. Despite the function's name, it *de*-normalizes.

    Args:
        mean: Per-channel means used by the original normalization, or a
            single value shared by every channel.
        std: Per-channel standard deviations used by the original
            normalization, or a single value shared by every channel.

    Returns:
        A ``transforms.Compose`` that inverts the normalization.

    Raises:
        ValueError: If ``mean`` and ``std`` have lengths that cannot be
            broadcast against each other.
    """
    mean_arr = np.atleast_1d(np.asarray(mean, dtype=np.float64))
    std_arr = np.atleast_1d(np.asarray(std, dtype=np.float64))
    # Pair a scalar on one side with a per-channel sequence on the other;
    # mismatched lengths raise ValueError here rather than at apply time.
    mean_arr, std_arr = np.broadcast_arrays(mean_arr, std_arr)
    # Derive the channel count from the inputs instead of assuming 3. Scalar
    # inputs give length-1 parameters, which broadcast over any channel count.
    num_channels = mean_arr.shape[0]

    zeros = [0.0] * num_channels
    ones = [1.0] * num_channels
    inv_std = (1.0 / std_arr).tolist()
    neg_mean = (-mean_arr).tolist()

    inv_normalization = transforms.Compose(
        [
            # Normalize computes (x - m) / s; m=0, s=1/std yields x * std.
            transforms.Normalize(mean=zeros, std=inv_std),
            # m=-mean, s=1 yields (x * std) + mean.
            transforms.Normalize(mean=neg_mean, std=ones),
        ]
    )

    return inv_normalization