|
""" |
|
Copyright 2021 Mahmoud Afifi. |
|
Mahmoud Afifi, Marcus A. Brubaker, and Michael S. Brown. "HistoGAN: |
|
Controlling Colors of GAN-Generated and Real Images via Color Histograms." |
|
In CVPR, 2021. |
|
|
|
@inproceedings{afifi2021histogan, |
|
title={Histo{GAN}: Controlling Colors of {GAN}-Generated and Real Images via |
|
Color Histograms}, |
|
author={Afifi, Mahmoud and Brubaker, Marcus A. and Brown, Michael S.}, |
|
booktitle={CVPR}, |
|
year={2021} |
|
} |
|
""" |
|
|
|
import torch |
|
import torch.nn as nn |
|
from PIL import Image |
|
import matplotlib.pyplot as plt |
|
import torch.nn.functional as F |
|
import torchvision.transforms as transforms |
|
import numpy as np |
|
|
|
EPS = 1e-6  # keeps log(), sqrt() and the final division away from zero


class RGBuvHistBlock(nn.Module):
  """Computes the RGB-uv histogram feature of a batch of images.

  For every image three h x h histograms are produced, one per reference
  colour channel. Each histogram is built over the log-chroma coordinates
  (u, v) of the pixels, i.e. the log of the reference channel minus the log
  of each of the two remaining channels, binned over [-3, 3].
  """

  # (reference channel, channel subtracted for u, channel subtracted for v).
  # Histogram 0 uses R against (G, B), 1 uses G against (R, B) and 2 uses B
  # against (R, G).
  _CHANNEL_TRIPLETS = ((0, 1, 2), (1, 0, 2), (2, 0, 1))

  def __init__(self, h=64, insz=150, resizing='interpolation',
               method='inverse-quadratic', sigma=0.02, intensity_scale=True,
               device='cuda'):
    """ Computes the RGB-uv histogram feature of a given image.
    Args:
      h: histogram dimension size (scalar). The default value is 64.
      insz: maximum size of the input image; if it is larger than this size, the
        image will be resized (scalar). Default value is 150 (i.e., 150 x 150
        pixels).
      resizing: resizing method if applicable. Options are: 'interpolation' or
        'sampling'. Default is 'interpolation'.
      method: the method used to count the number of pixels for each bin in the
        histogram feature. Options are: 'thresholding', 'RBF' (radial basis
        function), or 'inverse-quadratic'. Default value is 'inverse-quadratic'.
      sigma: if the method value is 'RBF' or 'inverse-quadratic', then this is
        the sigma parameter of the kernel function. The default value is 0.02.
      intensity_scale: boolean variable to use the intensity scale (I_y in
        Equation 2). Default value is True.
      device: stored on the instance for backward compatibility only. The
        forward pass allocates its tensors on the device of the input, so the
        block works for CPU and GPU inputs regardless of this value.

    Methods:
      forward: accepts input image and returns its histogram feature. Note that
        unless the method is 'thresholding', this is a differentiable function
        and can be easily integrated with the loss function. As mentioned in the
        paper, the 'inverse-quadratic' was found more stable than 'RBF' in our
        training.
    """
    super(RGBuvHistBlock, self).__init__()
    self.h = h
    self.insz = insz
    self.device = device
    self.resizing = resizing
    self.method = method
    self.intensity_scale = intensity_scale
    if self.method == 'thresholding':
      # Width of one histogram bin over the [-3, 3] range.
      self.eps = 6.0 / h
    else:
      self.sigma = sigma

  def forward(self, x):
    """Returns the normalised RGB-uv histograms of a batch of images.

    Args:
      x: 4-D tensor indexed as (batch, channel, height, width) with at least
        three channels; values are clamped to [0, 1] and only the first three
        channels are used.

    Returns:
      A float tensor of shape (batch, 3, h, h) on the same device as `x`.
      Each sample is divided by the sum of all its entries (plus EPS).

    Raises:
      ValueError: if `x` has fewer than three channels, if the resizing
        option is unknown (only checked when resizing is required), or if the
        kernel method is unknown.
    """
    x = torch.clamp(x, 0, 1)
    x_sampled = self._resize(x)

    if x_sampled.shape[1] < 3:
      raise ValueError(
        f'Expected an input with at least 3 channels. '
        f'But the given input has {x_sampled.shape[1]}.')
    if x_sampled.shape[1] > 3:
      x_sampled = x_sampled[:, :3, :, :]

    # Follow the input's device instead of the constructor argument so that
    # the bins, the output buffer and the pixels can never be mismatched.
    device = x_sampled.device

    # Bin centres are loop-invariant; build them once. They are kept in
    # float64 (NumPy's default) so the kernel is evaluated in double precision.
    bins = torch.unsqueeze(torch.tensor(np.linspace(-3, 3, num=self.h)),
                           dim=0).to(device)

    hists = torch.zeros((x_sampled.shape[0], 3, self.h, self.h),
                        device=device)

    for idx, image in enumerate(torch.unbind(x_sampled, dim=0)):
      # One row per pixel, one column per colour channel.
      pixels = torch.t(torch.reshape(image, (3, -1)))
      log_pixels = torch.log(pixels + EPS)

      if self.intensity_scale:
        squared = torch.pow(pixels, 2)
        intensity = torch.unsqueeze(
          torch.sqrt(squared[:, 0] + squared[:, 1] + squared[:, 2] + EPS),
          dim=1)
      else:
        intensity = 1

      for ref, u_other, v_other in self._CHANNEL_TRIPLETS:
        u = torch.unsqueeze(log_pixels[:, ref] - log_pixels[:, u_other], dim=1)
        v = torch.unsqueeze(log_pixels[:, ref] - log_pixels[:, v_other], dim=1)
        kernel_u = self._kernel(u, bins)
        kernel_v = self._kernel(v, bins)
        # Cast after weighting so non-float32 inputs do not break torch.mm.
        weighted_u = (intensity * kernel_u).type(torch.float32)
        # (h x pixels) @ (pixels x h): accumulates every pixel into the
        # 2-D (u, v) histogram.
        hists[idx, ref, :, :] = torch.mm(torch.t(weighted_u), kernel_v)

    total = hists.sum(dim=(1, 2, 3), keepdim=True)
    hists_normalized = hists / (total + EPS)

    return hists_normalized

  def _resize(self, x):
    """Shrinks `x` if either spatial side exceeds `insz`; otherwise returns it."""
    if x.shape[2] <= self.insz and x.shape[3] <= self.insz:
      return x

    if self.resizing == 'interpolation':
      return F.interpolate(x, size=(self.insz, self.insz),
                           mode='bilinear', align_corners=False)

    if self.resizing == 'sampling':
      # NOTE(review): the number of sampled rows/columns is `h`, not `insz`.
      # This looks unintended but is preserved to keep the existing behaviour.
      rows = torch.as_tensor(
        np.linspace(0, x.shape[2], self.h, endpoint=False).astype(np.int64),
        device=x.device)
      cols = torch.as_tensor(
        np.linspace(0, x.shape[3], self.h, endpoint=False).astype(np.int64),
        device=x.device)
      return x.index_select(2, rows).index_select(3, cols)

    raise ValueError(
      f'Wrong resizing method. It should be: interpolation or sampling. '
      f'But the given value is {self.resizing}.')

  def _kernel(self, log_chroma, bins):
    """Returns the float32 (pixels x h) response of each pixel to each bin."""
    diff = torch.abs(log_chroma - bins)
    if self.method == 'thresholding':
      # Hard assignment: a pixel counts for a bin when it lies within half a
      # bin width of the bin centre.
      response = diff <= self.eps / 2
    elif self.method == 'RBF':
      response = torch.exp(-(torch.pow(diff, 2) / self.sigma ** 2))
    elif self.method == 'inverse-quadratic':
      response = 1 / (1 + torch.pow(diff, 2) / self.sigma ** 2)
    else:
      raise ValueError(
        f'Wrong kernel method. It should be either thresholding, RBF,'
        f' inverse-quadratic. But the given value is {self.method}.')
    return response.type(torch.float32)