# vatrpp/models/model.py
import os
import sys

import numpy as np
import torch
import torch.utils.data
import torch.nn as nn
from torch.nn import CTCLoss
from torch.nn.utils import clip_grad_norm_
import torchvision.models as models
from models.inception import InceptionV3
from models.transformer import *
from util.augmentations import OCRAugment
from util.misc import SmoothedValue
from util.text import get_generator, AugmentedGenerator
from .BigGAN_networks import *
from .OCR_network import *
from models.blocks import Conv2dBlock, ResBlocks
from util.util import loss_hinge_dis, loss_hinge_gen, make_one_hot
import models.config as config
from .positional_encodings import PositionalEncoding1D
from models.unifont_module import UnifontModule
from PIL import Image
def get_rgb(x):
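    """Map a scalar in [0, 1] to an RGB colour: red at 0.0, magenta at 0.5, blue at 1.0 (G is always 0)."""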
R = 255 - int(int(x > 0.5) * 255 * (x - 0.5) / 0.5)
G = 0
B = 255 + int(int(x < 0.5) * 255 * (x - 0.5) / 0.5)
return R, G, B
def get_page_from_words(word_lists, MAX_IMG_WIDTH=800):
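    """Assemble a list of word images (H x W arrays in [0, 1]) into a single page.

    Words are laid out left to right with a 16 px gap and wrapped whenever a line
    would exceed MAX_IMG_WIDTH; lines are stacked with a 16 px horizontal gap.
    The result is scaled to [0, 255].
    """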
line_all = []
line_t = []
width_t = 0
for i in word_lists:
width_t = width_t + i.shape[1] + 16
if width_t > MAX_IMG_WIDTH:
line_all.append(np.concatenate(line_t, 1))
line_t = []
width_t = i.shape[1] + 16
line_t.append(i)
line_t.append(np.ones((i.shape[0], 16)))
    if len(line_t) > 0:
        # flush the last, possibly partial, line instead of dropping it
        line_all.append(np.concatenate(line_t, 1))
max_lin_widths = MAX_IMG_WIDTH # max([i.shape[1] for i in line_all])
gap_h = np.ones([16, max_lin_widths])
page_ = []
for l in line_all:
pad_ = np.ones([l.shape[0], max_lin_widths - l.shape[1]])
page_.append(np.concatenate([l, pad_], 1))
page_.append(gap_h)
page = np.concatenate(page_, 0)
return page * 255
class FCNDecoder(nn.Module):
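    """Fully-convolutional decoder: a stack of residual blocks followed by
    `ups` nearest-neighbour 2x upsamplings (halving the channels each time)
    and a final tanh convolution producing `out_dim`-channel images."""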
def __init__(self, ups=3, n_res=2, dim=512, out_dim=1, res_norm='adain', activ='relu', pad_type='reflect'):
super(FCNDecoder, self).__init__()
self.model = []
self.model += [ResBlocks(n_res, dim, res_norm,
activ, pad_type=pad_type)]
for i in range(ups):
self.model += [nn.Upsample(scale_factor=2),
Conv2dBlock(dim, dim // 2, 5, 1, 2,
norm='in',
activation=activ,
pad_type=pad_type)]
dim //= 2
self.model += [Conv2dBlock(dim, out_dim, 7, 1, 3,
norm='none',
activation='tanh',
pad_type=pad_type)]
self.model = nn.Sequential(*self.model)
def forward(self, x):
y = self.model(x)
return y
class Generator(nn.Module):
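    """Few-shot style transformer generator.

    Style images are encoded by a ResNet-18 backbone and a transformer encoder;
    a transformer decoder, queried with Unifont embeddings of the target text,
    produces features that the FCN decoder renders as a handwriting image."""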
def __init__(self, args):
super(Generator, self).__init__()
self.args = args
INP_CHANNEL = 1
encoder_layer = TransformerEncoderLayer(config.tn_hidden_dim, config.tn_nheads,
config.tn_dim_feedforward,
config.tn_dropout, "relu", True)
        encoder_norm = nn.LayerNorm(config.tn_hidden_dim)
self.encoder = TransformerEncoder(encoder_layer, config.tn_enc_layers, encoder_norm)
decoder_layer = TransformerDecoderLayer(config.tn_hidden_dim, config.tn_nheads,
config.tn_dim_feedforward,
config.tn_dropout, "relu", True)
decoder_norm = nn.LayerNorm(config.tn_hidden_dim)
self.decoder = TransformerDecoder(decoder_layer, config.tn_dec_layers, decoder_norm,
return_intermediate=True)
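        # ResNet-18 backbone adapted to single-channel input; fc and avgpool are
        # replaced by Identity so it acts as a spatial feature extractor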
self.Feat_Encoder = models.resnet18(weights='ResNet18_Weights.DEFAULT')
self.Feat_Encoder.conv1 = nn.Conv2d(INP_CHANNEL, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.Feat_Encoder.fc = nn.Identity()
self.Feat_Encoder.avgpool = nn.Identity()
# self.query_embed = nn.Embedding(self.args.vocab_size, self.args.tn_hidden_dim)
self.query_embed = UnifontModule(
config.tn_dim_feedforward,
self.args.alphabet + self.args.special_alphabet,
input_type=self.args.query_input,
device=self.args.device
)
self.pos_encoder = PositionalEncoding1D(config.tn_hidden_dim)
self.linear_q = nn.Linear(config.tn_dim_feedforward, config.tn_dim_feedforward * 8)
self.DEC = FCNDecoder(res_norm='in', dim=config.tn_hidden_dim)
self.noise = torch.distributions.Normal(loc=torch.tensor([0.]), scale=torch.tensor([1.0]))
def evaluate(self, style_images, queries):
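        """Generate one image per query word (inference): the style is computed
        once and reused for every word in `queries`."""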
style = self.compute_style(style_images)
results = []
for i in range(queries.shape[1]):
query = queries[:, i, :]
h = self.generate(style, query)
results.append(h.detach())
return results
def compute_style(self, style_images):
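        """Encode a batch of style images (B x N x H x W) into a transformer memory."""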
B, N, R, C = style_images.shape
FEAT_ST = self.Feat_Encoder(style_images.view(B * N, 1, R, C))
FEAT_ST = FEAT_ST.view(B, 512, 1, -1)
FEAT_ST_ENC = FEAT_ST.flatten(2).permute(2, 0, 1)
memory = self.encoder(FEAT_ST_ENC)
return memory
def generate(self, style_vector, query):
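        """Decode a content query against the style memory and render it to an image."""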
query_embed = self.query_embed(query).permute(1, 0, 2)
tgt = torch.zeros_like(query_embed)
hs = self.decoder(tgt, style_vector, query_pos=query_embed)
h = hs.transpose(1, 2)[-1]
if self.args.add_noise:
h = h + self.noise.sample(h.size()).squeeze(-1).to(self.args.device)
h = self.linear_q(h)
h = h.contiguous()
h = h.view(h.size(0), h.shape[1] * 2, 4, -1)
h = h.permute(0, 3, 2, 1)
h = self.DEC(h)
return h
def forward(self, style_images, query):
enc_attn_weights, dec_attn_weights = [], []
        self.hooks = [
            self.encoder.layers[-1].self_attn.register_forward_hook(
                lambda module, input, output: enc_attn_weights.append(output[1])
            ),
            self.decoder.layers[-1].multihead_attn.register_forward_hook(
                lambda module, input, output: dec_attn_weights.append(output[1])
            ),
        ]
style = self.compute_style(style_images)
h = self.generate(style, query)
self.dec_attn_weights = dec_attn_weights[-1].detach()
self.enc_attn_weights = enc_attn_weights[-1].detach()
for hook in self.hooks:
hook.remove()
return h, style
class VATr(nn.Module):
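    """Full VATr training wrapper: generator, discriminator, writer classifier
    and auxiliary OCR network, each with its own Adam optimizer.

    A typical training iteration (inferred from the method names here, so a
    sketch rather than the canonical recipe) looks like:

        model.optimize_D_OCR()       # forward + D/OCR backward
        model.optimize_D_OCR_step()  # D/OCR optimizer step
        model.optimize_G_only()      # forward + G backward (adv + OCR)
        model.optimize_G_step()      # G optimizer step
    """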
def __init__(self, args):
super(VATr, self).__init__()
self.args = args
self.args.vocab_size = len(args.alphabet)
self.epsilon = 1e-7
self.netG = Generator(self.args).to(self.args.device)
self.netD = Discriminator(
resolution=self.args.resolution, crop_size=args.d_crop_size,
).to(self.args.device)
self.netW = WDiscriminator(resolution=self.args.resolution, n_classes=self.args.vocab_size, output_dim=self.args.num_writers)
self.netW = self.netW.to(self.args.device)
self.netconverter = strLabelConverter(self.args.alphabet + self.args.special_alphabet)
self.netOCR = CRNN(self.args).to(self.args.device)
self.ocr_augmenter = OCRAugment(prob=0.5, no=3)
self.OCR_criterion = CTCLoss(zero_infinity=True, reduction='none')
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
self.inception = InceptionV3([block_idx]).to(self.args.device)
self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
lr=self.args.g_lr, betas=(0.0, 0.999), weight_decay=0, eps=1e-8)
self.optimizer_OCR = torch.optim.Adam(self.netOCR.parameters(),
lr=self.args.ocr_lr, betas=(0.0, 0.999), weight_decay=0, eps=1e-8)
self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
lr=self.args.d_lr, betas=(0.0, 0.999), weight_decay=0, eps=1e-8)
self.optimizer_wl = torch.optim.Adam(self.netW.parameters(),
lr=self.args.w_lr, betas=(0.0, 0.999), weight_decay=0, eps=1e-8)
self.optimizers = [self.optimizer_G, self.optimizer_OCR, self.optimizer_D, self.optimizer_wl]
self.optimizer_G.zero_grad()
self.optimizer_OCR.zero_grad()
self.optimizer_D.zero_grad()
self.optimizer_wl.zero_grad()
self.loss_G = 0
self.loss_D = 0
self.loss_Dfake = 0
self.loss_Dreal = 0
self.loss_OCR_fake = 0
self.loss_OCR_real = 0
self.loss_w_fake = 0
self.loss_w_real = 0
self.Lcycle = 0
self.d_acc = SmoothedValue()
self.word_generator = get_generator(args)
self.epoch = 0
        with open('mytext.txt', 'r', encoding='utf-8') as f:
            self.text = f.read().replace('\n', ' ')
        # keep only characters the converter can encode (avoids problems with the font dataset)
        self.text = ''.join(c for c in self.text if c in (self.args.alphabet + self.args.special_alphabet))
        self.text = [word.encode() for word in self.text.split()]  # [:args.num_examples]
self.eval_text_encode, self.eval_len_text, self.eval_encode_pos = self.netconverter.encode(self.text)
self.eval_text_encode = self.eval_text_encode.to(self.args.device).repeat(self.args.batch_size, 1, 1)
self.rv_sample_size = 64 * 4
self.last_fakes = []
def update_last_fakes(self, fakes):
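        """Keep a rolling buffer of the most recent `rv_sample_size` generated images."""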
for fake in fakes:
self.last_fakes.append(fake)
self.last_fakes = self.last_fakes[-self.rv_sample_size:]
def update_acc(self, pred_real, pred_fake):
correct = (pred_real >= 0.5).float().sum() + (pred_fake < 0.5).float().sum()
self.d_acc.update(correct / (len(pred_real) + len(pred_fake)))
def set_text_aug_strength(self, strength):
if not isinstance(self.word_generator, AugmentedGenerator):
print("WARNING: Text generator is not augmented, strength cannot be set")
else:
self.word_generator.set_strength(strength)
def get_text_aug_strength(self):
if isinstance(self.word_generator, AugmentedGenerator):
return self.word_generator.strength
else:
return 0.0
def update_parameters(self, epoch: int):
self.epoch = epoch
self.netD.update_parameters(epoch)
self.netW.update_parameters(epoch)
def get_text_sample(self, size: int) -> list:
return [self.word_generator.generate() for _ in range(size)]
def _generate_fakes(self, ST, eval_text_encode=None, eval_len_text=None):
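        """Generate word images for the evaluation text and return them as numpy arrays in [0, 1]."""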
        if eval_text_encode is None:
            eval_text_encode = self.eval_text_encode
        if eval_len_text is None:
            eval_len_text = self.eval_len_text
self.fakes = self.netG.evaluate(ST, eval_text_encode)
np_fakes = []
for batch_idx in range(self.fakes[0].shape[0]):
for idx, fake in enumerate(self.fakes):
fake = fake[batch_idx, 0, :, :eval_len_text[idx] * self.args.resolution]
fake = (fake + 1) / 2
np_fakes.append(fake.cpu().numpy())
return np_fakes
def _generate_page(self, ST, SLEN, eval_text_encode=None, eval_len_text=None, eval_encode_pos=None, lwidth=260, rwidth=980):
        """Render a page: the style examples (ST, with widths SLEN) on the left,
        the generated words plus the special-symbol legend on the right."""
        if eval_text_encode is None:
            eval_text_encode = self.eval_text_encode
        if eval_len_text is None:
            eval_len_text = self.eval_len_text
        if eval_encode_pos is None:
            eval_encode_pos = self.eval_encode_pos
text_encode, text_len, _ = self.netconverter.encode(self.args.special_alphabet)
symbols = self.netG.query_embed.symbols[text_encode].reshape(-1, 16, 16).cpu().numpy()
imgs = [Image.fromarray(s).resize((32, 32), resample=0) for s in symbols]
special_examples = 1 - np.concatenate([np.array(i) for i in imgs], axis=-1)
self.fakes = self.netG.evaluate(ST, eval_text_encode)
page1s = []
page2s = []
for batch_idx in range(ST.shape[0]):
word_t = []
word_l = []
gap = np.ones([self.args.img_height, 16])
line_wids = []
for idx, fake_ in enumerate(self.fakes):
word_t.append((fake_[batch_idx, 0, :, :eval_len_text[idx] * self.args.resolution].cpu().numpy() + 1) / 2)
word_t.append(gap)
if sum(t.shape[-1] for t in word_t) >= rwidth or idx == len(self.fakes) - 1 or (len(self.fakes) - len(self.args.special_alphabet) - 1) == idx:
line_ = np.concatenate(word_t, -1)
word_l.append(line_)
line_wids.append(line_.shape[1])
word_t = []
# add the examples from the UnifontModules
word_l.append(special_examples)
line_wids.append(special_examples.shape[1])
gap_h = np.ones([16, max(line_wids)])
page_ = []
for l in word_l:
pad_ = np.ones([self.args.img_height, max(line_wids) - l.shape[1]])
page_.append(np.concatenate([l, pad_], 1))
page_.append(gap_h)
page1 = np.concatenate(page_, 0)
word_t = []
word_l = []
line_wids = []
sdata_ = [i.unsqueeze(1) for i in torch.unbind(ST, 1)]
gap = np.ones([sdata_[0].shape[-2], 16])
            for idx, st in enumerate(sdata_):
word_t.append((st[batch_idx, 0, :, :int(SLEN.cpu().numpy()[batch_idx][idx])].cpu().numpy() + 1) / 2)
# word_t.append((st[batch_idx, 0, :, :].cpu().numpy() + 1) / 2)
word_t.append(gap)
if sum(t.shape[-1] for t in word_t) >= lwidth or idx == len(sdata_) - 1:
line_ = np.concatenate(word_t, -1)
word_l.append(line_)
line_wids.append(line_.shape[1])
word_t = []
gap_h = np.ones([16, max(line_wids)])
page_ = []
for l in word_l:
pad_ = np.ones([sdata_[0].shape[-2], max(line_wids) - l.shape[1]])
page_.append(np.concatenate([l, pad_], 1))
page_.append(gap_h)
page2 = np.concatenate(page_, 0)
merge_w_size = max(page1.shape[0], page2.shape[0])
if page1.shape[0] != merge_w_size:
page1 = np.concatenate([page1, np.ones([merge_w_size - page1.shape[0], page1.shape[1]])], 0)
if page2.shape[0] != merge_w_size:
page2 = np.concatenate([page2, np.ones([merge_w_size - page2.shape[0], page2.shape[1]])], 0)
page1s.append(page1)
page2s.append(page2)
# page = np.concatenate([page2, page1], 1)
page1s_ = np.concatenate(page1s, 0)
max_wid = max([i.shape[1] for i in page2s])
padded_page2s = []
for para in page2s:
padded_page2s.append(np.concatenate([para, np.ones([para.shape[0], max_wid - para.shape[1]])], 1))
padded_page2s_ = np.concatenate(padded_page2s, 0)
return np.concatenate([padded_page2s_, page1s_], 1)
def get_current_losses(self):
losses = {}
losses['G'] = self.loss_G
losses['D'] = self.loss_D
losses['Dfake'] = self.loss_Dfake
losses['Dreal'] = self.loss_Dreal
losses['OCR_fake'] = self.loss_OCR_fake
losses['OCR_real'] = self.loss_OCR_real
losses['w_fake'] = self.loss_w_fake
losses['w_real'] = self.loss_w_real
losses['cycle'] = self.Lcycle
return losses
def _set_input(self, input):
self.input = input
self.real = self.input['img'].to(self.args.device)
self.label = self.input['label']
self.set_ocr_data(self.input['img'], self.input['label'])
self.sdata = self.input['simg'].to(self.args.device)
self.slabels = self.input['slabels']
self.ST_LEN = self.input['swids']
def set_requires_grad(self, nets, requires_grad=False):
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations
Parameters:
nets (network list) -- a list of networks
requires_grad (bool) -- whether the networks require gradients or not
"""
if not isinstance(nets, list):
nets = [nets]
for net in nets:
if net is not None:
for param in net.parameters():
param.requires_grad = requires_grad
def forward(self):
self.text_encode, self.len_text, self.encode_pos = self.netconverter.encode(self.label)
self.text_encode = self.text_encode.to(self.args.device).detach()
self.len_text = self.len_text.detach()
self.words = [self.word_generator.generate().encode('utf-8') for _ in range(self.args.batch_size)]
self.text_encode_fake, self.len_text_fake, self.encode_pos_fake = self.netconverter.encode(self.words)
self.text_encode_fake = self.text_encode_fake.to(self.args.device)
self.one_hot_fake = make_one_hot(self.text_encode_fake, self.len_text_fake, self.args.vocab_size).to(
self.args.device)
self.fake, self.style = self.netG(self.sdata, self.text_encode_fake)
self.update_last_fakes(self.fake)
def pad_width(self, t, new_width):
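        """Right-pad a (B, C, H, W) tensor with ones (white) up to `new_width`."""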
result = torch.ones((t.size(0), t.size(1), t.size(2), new_width), device=t.device)
result[:,:,:,:t.size(-1)] = t
return result
    def compute_real_ocr_loss(self, ocr_network=None):
        network = ocr_network if ocr_network is not None else self.netOCR
        input_images = self.ocr_images.detach()
        input_labels = self.ocr_labels
        if self.ocr_augmenter is not None:
            input_images = self.ocr_augmenter(input_images)
        pred_real = network(input_images)
        preds_size = torch.IntTensor([pred_real.size(0)] * len(input_labels)).detach()
        text_encode, len_text, _ = self.netconverter.encode(input_labels)
        loss = self.OCR_criterion(pred_real, text_encode.detach(), preds_size, len_text.detach())
        # drop NaN entries that CTC can produce for degenerate alignments
        return torch.mean(loss[~torch.isnan(loss)])
    def compute_fake_ocr_loss(self, ocr_network=None):
network = ocr_network if ocr_network is not None else self.netOCR
pred_fake_OCR = network(self.fake)
preds_size = torch.IntTensor([pred_fake_OCR.size(0)] * self.args.batch_size).detach()
loss_OCR_fake = self.OCR_criterion(pred_fake_OCR, self.text_encode_fake.detach(), preds_size,
self.len_text_fake.detach())
return torch.mean(loss_OCR_fake[~torch.isnan(loss_OCR_fake)])
def set_ocr_data(self, images, labels):
self.ocr_images = images.to(self.args.device)
self.ocr_labels = labels
    def backward_D_OCR(self):
        pred_real = self.netD(self.real.detach())
        pred_fake = self.netD(x=self.fake.detach())
self.update_acc(pred_real, pred_fake)
self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(pred_fake, pred_real, self.len_text_fake.detach(),
self.len_text.detach(), True)
self.loss_D = self.loss_Dreal + self.loss_Dfake
if not self.args.no_ocr_loss:
self.loss_OCR_real = self.compute_real_ocr_loss()
loss_total = self.loss_D + self.loss_OCR_real
else:
loss_total = self.loss_D
# backward
loss_total.backward()
if not self.args.no_ocr_loss:
self.clean_grad(self.netOCR.parameters())
return loss_total
    def clean_grad(self, params):
        """Zero out NaN/Inf gradient entries so the optimizer step stays finite."""
        for param in params:
            if param.grad is None:
                continue
            param.grad[torch.isnan(param.grad)] = 0
            param.grad[torch.isinf(param.grad)] = 0
def backward_D_WL(self):
# Real
pred_real = self.netD(self.real.detach())
        pred_fake = self.netD(x=self.fake.detach())
self.update_acc(pred_real, pred_fake)
self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(pred_fake, pred_real, self.len_text_fake.detach(),
self.len_text.detach(), True)
self.loss_D = self.loss_Dreal + self.loss_Dfake
if not self.args.no_writer_loss:
self.loss_w_real = self.netW(self.real.detach(), self.input['wcl'].to(self.args.device)).mean()
# total loss
loss_total = self.loss_D + self.loss_w_real * self.args.writer_loss_weight
else:
loss_total = self.loss_D
# backward
loss_total.backward()
return loss_total
def optimize_D_WL(self):
self.forward()
self.set_requires_grad([self.netD], True)
self.set_requires_grad([self.netOCR], False)
self.set_requires_grad([self.netW], True)
self.optimizer_D.zero_grad()
self.optimizer_wl.zero_grad()
self.backward_D_WL()
def optimize_D_WL_step(self):
self.optimizer_D.step()
self.optimizer_wl.step()
self.optimizer_D.zero_grad()
self.optimizer_wl.zero_grad()
def compute_cycle_loss(self):
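        """Cycle consistency: re-encode the generated image as if it were a style
        example and penalize the L1 distance to the original style vector. The
        hard-coded repeat(1, 15, 1, 1) assumes 15 style examples per batch item."""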
fake_input = torch.ones_like(self.sdata)
width = min(self.sdata.size(-1), self.fake.size(-1))
fake_input[:, :, :, :width] = self.fake.repeat(1, 15, 1, 1)[:, :, :, :width]
with torch.no_grad():
fake_style = self.netG.compute_style(fake_input)
return torch.sum(torch.abs(self.style.detach() - fake_style), dim=1).mean()
def backward_G_only(self):
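        """Generator backward pass with gradient balancing: the OCR loss is rescaled
        by a = gb_alpha * std(grad_adv) / std(grad_ocr), both gradients taken
        w.r.t. the fake image, so neither term dominates the update."""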
self.gb_alpha = 0.7
if self.args.is_cycle:
self.Lcycle = self.compute_cycle_loss()
self.loss_G = loss_hinge_gen(self.netD(**{'x': self.fake}), self.len_text_fake.detach(), True).mean()
compute_ocr = not self.args.no_ocr_loss
if compute_ocr:
self.loss_OCR_fake = self.compute_fake_ocr_loss()
self.loss_G = self.loss_G + self.Lcycle
if compute_ocr:
self.loss_T = self.loss_G + self.loss_OCR_fake
else:
self.loss_T = self.loss_G
if compute_ocr:
grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, retain_graph=True)[0]
self.loss_grad_fake_OCR = 10 ** 6 * torch.mean(grad_fake_OCR ** 2)
grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, retain_graph=True)[0]
self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2)
self.loss_T.backward(retain_graph=True)
if compute_ocr:
grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=True, retain_graph=True)[0]
grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=True, retain_graph=True)[0]
a = self.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon + torch.std(grad_fake_OCR))
self.loss_OCR_fake = a.detach() * self.loss_OCR_fake
self.loss_T = self.loss_G + self.loss_OCR_fake
else:
grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=True, retain_graph=True)[0]
a = 1
self.loss_T = self.loss_G
if a is None:
print(self.loss_OCR_fake, self.loss_G, torch.std(grad_fake_adv))
if a > 1000 or a < 0.0001:
print(f'WARNING: alpha > 1000 or alpha < 0.0001 - alpha={a.item()}')
self.loss_T.backward(retain_graph=True)
if compute_ocr:
grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=False, retain_graph=True)[0]
self.loss_grad_fake_OCR = 10 ** 6 * torch.mean(grad_fake_OCR ** 2)
grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=False, retain_graph=True)[0]
self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2)
with torch.no_grad():
self.loss_T.backward()
if compute_ocr:
            if torch.isnan(self.loss_OCR_fake).any() or torch.isnan(self.loss_G):
print('loss OCR fake: ', self.loss_OCR_fake, ' loss_G: ', self.loss_G, ' words: ', self.words)
sys.exit()
def backward_G_WL(self):
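        """Generator backward pass with the writer-classifier loss, balanced against
        the adversarial loss with the same gradient-std rescaling as backward_G_only."""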
self.gb_alpha = 0.7
if self.args.is_cycle:
self.Lcycle = self.compute_cycle_loss()
self.loss_G = loss_hinge_gen(self.netD(**{'x': self.fake}), self.len_text_fake.detach(), True).mean()
if not self.args.no_writer_loss:
self.loss_w_fake = self.netW(self.fake, self.input['wcl'].to(self.args.device)).mean()
self.loss_G = self.loss_G + self.Lcycle
if not self.args.no_writer_loss:
self.loss_T = self.loss_G + self.loss_w_fake * self.args.writer_loss_weight
else:
self.loss_T = self.loss_G
self.loss_T.backward(retain_graph=True)
if not self.args.no_writer_loss:
grad_fake_WL = torch.autograd.grad(self.loss_w_fake, self.fake, create_graph=True, retain_graph=True)[0]
grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=True, retain_graph=True)[0]
a = self.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon + torch.std(grad_fake_WL))
self.loss_w_fake = a.detach() * self.loss_w_fake
self.loss_T = self.loss_G + self.loss_w_fake
else:
grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=True, retain_graph=True)[0]
a = 1
self.loss_T = self.loss_G
if a is None:
print(self.loss_w_fake, self.loss_G, torch.std(grad_fake_adv))
if a > 1000 or a < 0.0001:
print(f'WARNING: alpha > 1000 or alpha < 0.0001 - alpha={a.item()}')
self.loss_T.backward(retain_graph=True)
if not self.args.no_writer_loss:
grad_fake_WL = torch.autograd.grad(self.loss_w_fake, self.fake, create_graph=False, retain_graph=True)[0]
self.loss_grad_fake_WL = 10 ** 6 * torch.mean(grad_fake_WL ** 2)
grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=False, retain_graph=True)[0]
self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2)
with torch.no_grad():
self.loss_T.backward()
def backward_G(self):
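        # NOTE: this looks like a legacy training path; it references attributes
        # (self.opt, self.z, self.wcl) that are never set in this class.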
self.opt.gb_alpha = 0.7
self.loss_G = loss_hinge_gen(self.netD(**{'x': self.fake, 'z': self.z}), self.len_text_fake.detach(),
self.opt.mask_loss)
# OCR loss on real data
compute_ocr = not self.args.no_ocr_loss
if compute_ocr:
self.loss_OCR_fake = self.compute_fake_ocr_loss()
else:
self.loss_OCR_fake = 0.0
self.loss_w_fake = self.netW(self.fake, self.wcl)
# self.loss_OCR_fake = self.loss_OCR_fake + self.loss_w_fake
# total loss
# l1 = self.params[0]*self.loss_G
# l2 = self.params[0]*self.loss_OCR_fake
# l3 = self.params[0]*self.loss_w_fake
self.loss_G_ = 10 * self.loss_G + self.loss_w_fake
self.loss_T = self.loss_G_ + self.loss_OCR_fake
grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, retain_graph=True)[0]
self.loss_grad_fake_OCR = 10 ** 6 * torch.mean(grad_fake_OCR ** 2)
grad_fake_adv = torch.autograd.grad(self.loss_G_, self.fake, retain_graph=True)[0]
self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2)
        self.loss_T.backward(retain_graph=True)
        grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=True, retain_graph=True)[0]
        grad_fake_adv = torch.autograd.grad(self.loss_G_, self.fake, create_graph=True, retain_graph=True)[0]
        # grad_fake_wl = torch.autograd.grad(self.loss_w_fake, self.fake, create_graph=True, retain_graph=True)[0]
        a = self.opt.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon + torch.std(grad_fake_OCR))
        # a0 = self.opt.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon + torch.std(grad_fake_wl))
        if a is None:
            print(self.loss_OCR_fake, self.loss_G_, torch.std(grad_fake_adv), torch.std(grad_fake_OCR))
        if a > 1000 or a < 0.0001:
            print(f'WARNING: alpha > 1000 or alpha < 0.0001 - alpha={a.item()}')
        b = self.opt.gb_alpha * (torch.mean(grad_fake_adv) -
                                 torch.div(torch.std(grad_fake_adv), self.epsilon + torch.std(grad_fake_OCR)) *
                                 torch.mean(grad_fake_OCR))
        # self.loss_OCR_fake = a.detach() * self.loss_OCR_fake + b.detach() * torch.sum(self.fake)
        self.loss_OCR_fake = a.detach() * self.loss_OCR_fake
        # self.loss_w_fake = a0.detach() * self.loss_w_fake
        self.loss_T = (1 - 1 * self.opt.onlyOCR) * self.loss_G_ + self.loss_OCR_fake  # + self.loss_w_fake
        self.loss_T.backward(retain_graph=True)
        grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=False, retain_graph=True)[0]
        grad_fake_adv = torch.autograd.grad(self.loss_G_, self.fake, create_graph=False, retain_graph=True)[0]
        self.loss_grad_fake_OCR = 10 ** 6 * torch.mean(grad_fake_OCR ** 2)
        self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2)
        with torch.no_grad():
            self.loss_T.backward()
        if self.opt.clip_grad > 0:
            clip_grad_norm_(self.netG.parameters(), self.opt.clip_grad)
        if torch.isnan(self.loss_OCR_fake).any() or torch.isnan(self.loss_G_):
            print('loss OCR fake: ', self.loss_OCR_fake, ' loss_G: ', self.loss_G, ' words: ', self.words)
            sys.exit()
def optimize_D_OCR(self):
self.forward()
self.set_requires_grad([self.netD], True)
self.set_requires_grad([self.netOCR], True)
self.optimizer_D.zero_grad()
# if self.opt.OCR_init in ['glorot', 'xavier', 'ortho', 'N02']:
self.optimizer_OCR.zero_grad()
self.backward_D_OCR()
def optimize_D_OCR_step(self):
self.optimizer_D.step()
self.optimizer_OCR.step()
self.optimizer_D.zero_grad()
self.optimizer_OCR.zero_grad()
def optimize_G_WL(self):
self.forward()
self.set_requires_grad([self.netD], False)
self.set_requires_grad([self.netOCR], False)
self.set_requires_grad([self.netW], False)
self.backward_G_WL()
def optimize_G_only(self):
self.forward()
self.set_requires_grad([self.netD], False)
self.set_requires_grad([self.netOCR], False)
self.set_requires_grad([self.netW], False)
self.backward_G_only()
def optimize_G_step(self):
self.optimizer_G.step()
self.optimizer_G.zero_grad()
def save_networks(self, epoch, save_dir):
"""Save all the networks to the disk.
Parameters:
epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
"""
for name in self.model_names:
if isinstance(name, str):
save_filename = '%s_net_%s.pth' % (epoch, name)
save_path = os.path.join(save_dir, save_filename)
net = getattr(self, 'net' + name)
if len(self.gpu_ids) > 0 and torch.cuda.is_available():
# torch.save(net.module.cpu().state_dict(), save_path)
if len(self.gpu_ids) > 1:
torch.save(net.module.cpu().state_dict(), save_path)
else:
torch.save(net.cpu().state_dict(), save_path)
net.cuda(self.gpu_ids[0])
else:
torch.save(net.cpu().state_dict(), save_path)
def compute_d_scores(self, data_loader: torch.utils.data.DataLoader, amount: int = None):
scores = []
words = []
amount = len(data_loader) if amount is None else amount // data_loader.batch_size
        data_iter = iter(data_loader)  # single iterator, so each step yields a fresh batch
        with torch.no_grad():
            for i in range(amount):
                data = next(data_iter)
words.extend([d.decode() for d in data['label']])
scores.extend(list(self.netD(data['img'].to(self.args.device)).squeeze().detach().cpu().numpy()))
return scores, words
def compute_d_scores_fake(self, data_loader: torch.utils.data.DataLoader, amount: int = None):
scores = []
words = []
amount = len(data_loader) if amount is None else amount // data_loader.batch_size
        data_iter = iter(data_loader)
        with torch.no_grad():
            for i in range(amount):
                data = next(data_iter)
to_generate = [self.word_generator.generate().encode('utf-8') for _ in range(data_loader.batch_size)]
text_encode_fake, len_text_fake, encode_pos_fake = self.netconverter.encode(to_generate)
fake, _ = self.netG(data['simg'].to(self.args.device), text_encode_fake.to(self.args.device))
words.extend([d.decode() for d in to_generate])
scores.extend(list(self.netD(fake).squeeze().detach().cpu().numpy()))
return scores, words
def compute_d_stats(self, train_loader: torch.utils.data.DataLoader, val_loader: torch.utils.data.DataLoader):
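        """Return the mean discriminator score on train batches, validation batches,
        and the most recently generated samples."""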
train_values = []
val_values = []
fake_values = []
        train_iter, val_iter = iter(train_loader), iter(val_loader)
        with torch.no_grad():
            for i in range(self.rv_sample_size // train_loader.batch_size):
                data = next(train_iter)
                train_values.append(self.netD(data['img'].to(self.args.device)).squeeze().detach().cpu().numpy())
            for i in range(self.rv_sample_size // val_loader.batch_size):
                data = next(val_iter)
                val_values.append(self.netD(data['img'].to(self.args.device)).squeeze().detach().cpu().numpy())
            for data in self.last_fakes[:self.rv_sample_size]:
fake_values.append(self.netD(data.unsqueeze(0)).squeeze().detach().cpu().numpy())
return np.mean(train_values), np.mean(val_values), np.mean(fake_values)