# Loss functions for AnimeGAN-style training (generator, discriminator, summaries).
import torch
import torch.nn.functional as F
import torch.nn as nn
from models.vgg import Vgg19
from utils.image_processing import gram
def to_gray_scale(image):
    """Convert an RGB image tensor in [-1, 1] to a 3-channel grayscale tensor.

    Uses the ITU-R BT.601 luma weights, mirroring torchvision's
    rgb_to_grayscale. Output has the same shape, dtype and [-1, 1] range
    as the input; all three channels hold the same luma values.
    """
    # Shift [-1, 1] into [0, 1] before applying the luma weights.
    rgb = (image + 1.0) / 2.0
    red, green, blue = rgb.unbind(dim=-3)
    luma = red.mul(0.2989).add_(green, alpha=0.587).add_(blue, alpha=0.114)
    luma = luma.unsqueeze(dim=-3).to(rgb.dtype).expand(rgb.shape)
    # Map [0, 1] back to [-1, 1].
    return luma / 0.5 - 1.0
class ColorLoss(nn.Module):
    """Color-consistency loss computed in YUV space.

    Applies L1 on the luminance channel (Y) and Smooth L1 (Huber) on the
    chrominance channels (U, V) between two images given in [-1, 1] RGB,
    channel-first layout.
    """

    def __init__(self):
        super(ColorLoss, self).__init__()
        self.l1 = nn.L1Loss()
        self.huber = nn.SmoothL1Loss()
        # Registered as a (non-persistent) buffer so that nn.Module.to(),
        # .cuda(), .half(), etc. move/cast the kernel together with the
        # module. The previous hand-rolled `to(device)` override only
        # covered `.to(device)` and relied on Module.to being in-place.
        # persistent=False keeps the state_dict unchanged for old checkpoints.
        self.register_buffer(
            '_rgb_to_yuv_kernel',
            torch.tensor([
                [0.299, 0.587, 0.114],
                [-0.14714119, -0.28886916, 0.43601035],
                [0.61497538, -0.51496512, -0.10001026],
            ]).float(),
            persistent=False,
        )

    def rgb_to_yuv(self, image):
        '''
        Convert an RGB image in [-1, 1] (channel-first) to YUV.
        https://en.wikipedia.org/wiki/YUV

        @Returns: Image of shape (N, H, W, C) (channel last)
        '''
        # [-1, 1] -> [0, 1]
        image = (image + 1.0) / 2.0
        image = image.permute(0, 2, 3, 1)  # To channel last
        yuv_img = image @ self._rgb_to_yuv_kernel.T
        return yuv_img

    def forward(self, image, image_g):
        """Color loss between `image` and `image_g` (both [-1, 1] RGB)."""
        image = self.rgb_to_yuv(image)
        image_g = self.rgb_to_yuv(image_g)

        # After convert to yuv, both images have channel last:
        # index 0 = Y (luminance), 1 = U, 2 = V (chrominance).
        return (
            self.l1(image[:, :, :, 0], image_g[:, :, :, 0])
            + self.huber(image[:, :, :, 1], image_g[:, :, :, 1])
            + self.huber(image[:, :, :, 2], image_g[:, :, :, 2])
        )
class AnimeGanLoss:
    """Loss bundle for AnimeGAN-style training.

    Provides the generator losses (adversarial, VGG content, Gram/texture,
    color, total variation) and the discriminator adversarial loss.
    Loss weights and the adversarial flavor come from `args`.
    """

    def __init__(self, args, device, gray_adv=False):
        """
        @Args:
            - args: config namespace providing the weights wadvg, wadvd,
              wcon, wgra, wcol, wtvar and gan_loss ('hinge'|'lsgan'|'bce')
            - device: torch.device or device string
            - gray_adv: if True, grayscale anime images are treated as the
              "real" class when computing the discriminator loss
        """
        if isinstance(device, str):
            device = torch.device(device)
        self.content_loss = nn.L1Loss().to(device)
        self.gram_loss = nn.L1Loss().to(device)
        self.color_loss = ColorLoss().to(device)
        self.wadvg = args.wadvg
        self.wadvd = args.wadvd
        self.wcon = args.wcon
        self.wgra = args.wgra
        self.wcol = args.wcol
        self.wtvar = args.wtvar
        # If true, use gray scale image to calculate adversarial loss
        self.gray_adv = gray_adv
        self.vgg19 = Vgg19().to(device).eval()
        self.adv_type = args.gan_loss
        self.bce_loss = nn.BCEWithLogitsLoss()

    def compute_loss_G(self, fake_img, img, fake_logit, anime_gray):
        '''
        Compute loss for Generator

        @Args:
            - fake_img: generated image
            - img: real image
            - fake_logit: output of Discriminator given fake image
            - anime_gray: grayscale of anime image

        @Returns:
            - Adversarial Loss of fake logits
            - Content loss between real and fake features (vgg19)
            - Gram loss between anime and fake features (Vgg19)
            - Color loss between image and fake image
            - Total variation loss of fake image
        '''
        fake_feat = self.vgg19(fake_img)
        gray_feat = self.vgg19(anime_gray)
        img_feat = self.vgg19(img)

        return [
            # Want to be real image.
            self.wadvg * self.adv_loss_g(fake_logit),
            self.wcon * self.content_loss(img_feat, fake_feat),
            self.wgra * self.gram_loss(gram(gray_feat), gram(fake_feat)),
            self.wcol * self.color_loss(img, fake_img),
            self.wtvar * self.total_variation_loss(fake_img),
        ]

    def compute_loss_D(
        self,
        fake_img_d,
        real_anime_d,
        real_anime_gray_d,
        real_anime_smooth_gray_d=None
    ):
        """Adversarial loss for the Discriminator.

        When `gray_adv` is set, `real_anime_smooth_gray_d` must be provided:
        grayscale anime is treated as real and the smoothed grayscale as an
        extra, down-weighted fake class.
        """
        if self.gray_adv:
            # Treat gray scale image as real
            return (
                self.adv_loss_d_real(real_anime_gray_d)
                + self.adv_loss_d_fake(fake_img_d)
                + 0.3 * self.adv_loss_d_fake(real_anime_smooth_gray_d)
            )
        # Classify real anime as real, generated image as fake.
        # NOTE(review): extra fake penalties on gray / smoothed-gray inputs
        # were previously experimented with here and are intentionally off.
        return (
            self.adv_loss_d_real(real_anime_d)
            + self.adv_loss_d_fake(fake_img_d)
        )

    def total_variation_loss(self, fake_img):
        """
        A smooth loss in fact. Like the smooth prior in MRF.
        V(y) = || y_{n+1} - y_n ||_2

        Averaged squared differences between neighboring pixels along
        height and width.
        """
        # Channel first -> channel last
        fake_img = fake_img.permute(0, 2, 3, 1)

        def _l2(x):
            # sum(t ** 2) / 2
            return torch.sum(x ** 2) / 2

        dh = fake_img[:, :-1, ...] - fake_img[:, 1:, ...]
        dw = fake_img[:, :, :-1, ...] - fake_img[:, :, 1:, ...]
        return _l2(dh) / dh.numel() + _l2(dw) / dw.numel()

    def content_loss_vgg(self, image, recontruction):
        """VGG feature-space reconstruction loss between `image` and its
        reconstruction.

        A pixel-space L1 term used to be computed here too but its result
        was discarded (the `+ 0.5 * content_loss` was commented out); the
        wasted computation has been removed.
        """
        feat = self.vgg19(image)
        re_feat = self.vgg19(recontruction)
        return self.content_loss(feat, re_feat)

    def adv_loss_d_real(self, pred):
        """Push pred to class 1 (real)"""
        if self.adv_type == 'hinge':
            return torch.mean(F.relu(1.0 - pred))
        elif self.adv_type == 'lsgan':
            return torch.mean(torch.square(pred - 1.0))
        elif self.adv_type == 'bce':
            return self.bce_loss(pred, torch.ones_like(pred))
        raise ValueError(f'Do not support loss type {self.adv_type}')

    def adv_loss_d_fake(self, pred):
        """Push pred to class 0 (fake)"""
        if self.adv_type == 'hinge':
            return torch.mean(F.relu(1.0 + pred))
        elif self.adv_type == 'lsgan':
            return torch.mean(torch.square(pred))
        elif self.adv_type == 'bce':
            return self.bce_loss(pred, torch.zeros_like(pred))
        raise ValueError(f'Do not support loss type {self.adv_type}')

    def adv_loss_g(self, pred):
        """Push pred to class 1 (real)"""
        if self.adv_type == 'hinge':
            return -torch.mean(pred)
        elif self.adv_type == 'lsgan':
            return torch.mean(torch.square(pred - 1.0))
        elif self.adv_type == 'bce':
            return self.bce_loss(pred, torch.ones_like(pred))
        raise ValueError(f'Do not support loss type {self.adv_type}')
class LossSummary:
    """Accumulates per-batch loss scalars and reports their running averages."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated losses (e.g. at the start of an epoch)."""
        self.loss_g_adv = []
        self.loss_content = []
        self.loss_gram = []
        self.loss_color = []
        self.loss_d_adv = []

    def update_loss_G(self, adv, gram, color, content):
        """Record one batch of generator losses (scalar tensors)."""
        self.loss_g_adv.append(adv.cpu().detach().numpy())
        self.loss_gram.append(gram.cpu().detach().numpy())
        self.loss_color.append(color.cpu().detach().numpy())
        self.loss_content.append(content.cpu().detach().numpy())

    def update_loss_D(self, loss):
        """Record one batch of discriminator loss (scalar tensor)."""
        self.loss_d_adv.append(loss.cpu().detach().numpy())

    def avg_loss_G(self):
        """Return the (adv, gram, color, content) averages over all updates."""
        return (
            self._avg(self.loss_g_adv),
            self._avg(self.loss_gram),
            self._avg(self.loss_color),
            self._avg(self.loss_content),
        )

    def avg_loss_D(self):
        """Return the average discriminator loss over all updates."""
        return self._avg(self.loss_d_adv)

    def get_loss_description(self):
        """One-line progress string with all averaged losses.

        Format spec fixed from `:2f` (width 2, default 6-digit precision)
        to the intended `:.2f` (two decimal places).
        """
        avg_adv, avg_gram, avg_color, avg_content = self.avg_loss_G()
        avg_adv_d = self.avg_loss_D()
        return (
            f'loss G: adv {avg_adv:.2f} con {avg_content:.2f} '
            f'gram {avg_gram:.2f} color {avg_color:.2f} / loss D: {avg_adv_d:.2f}'
        )

    @staticmethod
    def _avg(losses):
        # Bug fix: this was a plain `def _avg(losses)` with no `self`, so
        # every `self._avg(...)` call raised TypeError (two positional
        # arguments passed to a one-parameter function).
        return sum(losses) / len(losses)