CCLAP / hist_loss.py
RobinWZQ's picture
Upload 6 files
c8c90c7
"""
Copyright 2021 Mahmoud Afifi.
Mahmoud Afifi, Marcus A. Brubaker, and Michael S. Brown. "HistoGAN:
Controlling Colors of GAN-Generated and Real Images via Color Histograms."
In CVPR, 2021.
@inproceedings{afifi2021histogan,
title={Histo{GAN}: Controlling Colors of {GAN}-Generated and Real Images via
Color Histograms},
author={Afifi, Mahmoud and Brubaker, Marcus A. and Brown, Michael S.},
booktitle={CVPR},
year={2021}
}
"""
import torch
import torch.nn as nn
from PIL import Image
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision.transforms as transforms
import numpy as np
# Small constant added inside log()/sqrt() and to the histogram normalization
# denominator so that zero-valued pixels / empty histograms do not produce
# log(0) or a division by zero.
EPS = 1e-6
class RGBuvHistBlock(nn.Module):
    """Differentiable RGB-uv histogram feature (HistoGAN, Afifi et al. 2021).

    For every image the block builds three 2-D histograms, one per projection
    channel c in {R, G, B}. Each pixel contributes to histogram c at
    u = log(I_c) - log(I_o1) and v = log(I_c) - log(I_o2), where o1 < o2 are
    the two other channels. The bin centres are `h` evenly spaced values in
    [-3, 3]. A pixel's contribution to a bin is a kernel of its distance to
    the bin centre, optionally weighted by the pixel intensity I_y. The three
    histograms of an image are jointly normalized to sum to one.
    """

    # For each projection channel c (0=R, 1=G, 2=B): the channels used for u
    # and v respectively.
    _UV_CHANNELS = ((1, 2), (0, 2), (0, 1))

    def __init__(self, h=64, insz=150, resizing='interpolation',
                 method='inverse-quadratic', sigma=0.02, intensity_scale=True,
                 device='cuda'):
        """Configure the histogram block.

        Args:
            h: histogram dimension size (scalar). The default value is 64.
            insz: maximum size of the input image; if either spatial side is
                larger than this, the image is resized to insz x insz
                (scalar). Default value is 150 (i.e., 150 x 150 pixels).
            resizing: resizing method if applicable. Options are:
                'interpolation' (bilinear) or 'sampling' (picks insz evenly
                spaced rows and columns). Default is 'interpolation'.
            method: kernel used to count the contribution of each pixel to
                each bin. Options are: 'thresholding', 'RBF' (radial basis
                function), or 'inverse-quadratic'. Default value is
                'inverse-quadratic'. Unless the method is 'thresholding', the
                forward pass is differentiable and can be used inside a loss;
                the paper found 'inverse-quadratic' more stable than 'RBF'.
            sigma: kernel width used by 'RBF' and 'inverse-quadratic'. The
                default value is 0.02.
            intensity_scale: whether to weight each pixel by its intensity
                I_y (Equation 2 of the paper). Default value is True.
            device: kept for backward compatibility and stored as
                `self.device`; the computation itself runs on the device of
                the input tensor.
        """
        super().__init__()
        self.h = h
        self.insz = insz
        self.device = device
        self.resizing = resizing
        self.method = method
        self.intensity_scale = intensity_scale
        if self.method == 'thresholding':
            # Bin centres span [-3, 3] (width 6); a pixel counts for a bin when
            # its distance to the centre is at most eps / 2.
            self.eps = 6.0 / h
        else:
            self.sigma = sigma

    def forward(self, x):
        """Compute the normalized RGB-uv histograms of a batch of images.

        Args:
            x: image batch indexed as (batch, channel, height, width) with at
                least 3 channels; only the first three are used (presumably
                RGB). Values are clamped to [0, 1].

        Returns:
            Tensor of shape (batch, 3, h, h) on x's device, where the three
            histograms of each image together sum to (approximately) one.

        Raises:
            ValueError: if x has fewer than 3 channels.
            Exception: if `resizing` (when a resize is needed) or `method` is
                not one of the supported options.
        """
        x = torch.clamp(x, 0, 1)
        x_sampled = self._resize(x)
        if x_sampled.shape[1] < 3:
            raise ValueError(
                f'Expected an image with at least 3 channels, but got '
                f'{x_sampled.shape[1]}.')
        x_sampled = x_sampled[:, :3, :, :]
        device = x_sampled.device
        # Bin centres, shape (1, h). Built once per call (not per channel) in
        # float64, as np.linspace produces.
        bins = torch.unsqueeze(
            torch.from_numpy(np.linspace(-3, 3, num=self.h)).to(device), dim=0)
        per_image = [self._image_hist(img, bins)
                     for img in torch.unbind(x_sampled, dim=0)]
        if not per_image:
            # Empty batch: torch.stack cannot stack an empty list.
            return torch.zeros((0, 3, self.h, self.h), device=device)
        hists = torch.stack(per_image, dim=0)
        # Normalize so that the three histograms of each image sum to one.
        return hists / (hists.sum(dim=(1, 2, 3), keepdim=True) + EPS)

    def _resize(self, x):
        """Shrink x to insz x insz when either spatial side exceeds insz."""
        height, width = x.shape[2], x.shape[3]
        if height <= self.insz and width <= self.insz:
            return x
        if self.resizing == 'interpolation':
            return F.interpolate(x, size=(self.insz, self.insz),
                                 mode='bilinear', align_corners=False)
        if self.resizing == 'sampling':
            rows = self._sample_indices(height, x.device)
            cols = self._sample_indices(width, x.device)
            return x.index_select(2, rows).index_select(3, cols)
        raise Exception(
            f'Wrong resizing method. It should be: interpolation or sampling. '
            f'But the given value is {self.resizing}.')

    def _sample_indices(self, size, device):
        """Return `insz` evenly spaced integer indices floor(i * size / insz)."""
        # Exact integer arithmetic avoids float truncation errors of linspace.
        idx = np.arange(self.insz, dtype=np.int64) * size // self.insz
        return torch.from_numpy(idx).to(device)

    def _image_hist(self, img, bins):
        """Return the unnormalized (3, h, h) histograms of one (3, H, W) image."""
        pixels = torch.t(torch.reshape(img, (3, -1)))  # (num_pixels, 3)
        log_pixels = torch.log(pixels + EPS)
        if self.intensity_scale:
            sq = torch.pow(pixels, 2)
            intensity = torch.sqrt(sq[:, 0] + sq[:, 1] + sq[:, 2] + EPS)
            # float32 so torch.mm below never mixes dtypes (e.g. float64 input).
            intensity = torch.unsqueeze(intensity, dim=1).to(torch.float32)
        else:
            intensity = 1
        channel_hists = []
        for c, (o1, o2) in enumerate(self._UV_CHANNELS):
            u = torch.unsqueeze(log_pixels[:, c] - log_pixels[:, o1], dim=1)
            v = torch.unsqueeze(log_pixels[:, c] - log_pixels[:, o2], dim=1)
            k_u = self._kernel(torch.abs(u - bins))  # (num_pixels, h)
            k_v = self._kernel(torch.abs(v - bins))  # (num_pixels, h)
            # hist[a, b] = sum_n I_y[n] * k_u[n, a] * k_v[n, b]
            channel_hists.append(torch.mm(torch.t(intensity * k_u), k_v))
        return torch.stack(channel_hists, dim=0)

    def _kernel(self, dist):
        """Map pixel-to-bin-centre distances to float32 bin weights."""
        if self.method == 'thresholding':
            weights = dist <= self.eps / 2
        elif self.method == 'RBF':
            weights = torch.exp(-(torch.pow(dist, 2) / self.sigma ** 2))
        elif self.method == 'inverse-quadratic':
            weights = 1 / (1 + torch.pow(dist, 2) / self.sigma ** 2)
        else:
            raise Exception(
                f'Wrong kernel method. It should be either thresholding, RBF,'
                f' inverse-quadratic. But the given value is {self.method}.')
        return weights.to(torch.float32)