|
import torch.utils.data |
|
from torch.nn import CTCLoss |
|
from torch.nn.utils import clip_grad_norm_ |
|
import os
import sys

import numpy as np
import torch
from torch import nn
|
import torchvision.models as models |
|
|
|
from models.inception import InceptionV3 |
|
from models.transformer import * |
|
from util.augmentations import OCRAugment |
|
from util.misc import SmoothedValue |
|
from util.text import get_generator, AugmentedGenerator |
|
from .BigGAN_networks import * |
|
from .OCR_network import * |
|
from models.blocks import Conv2dBlock, ResBlocks |
|
from util.util import loss_hinge_dis, loss_hinge_gen, make_one_hot |
|
|
|
import models.config as config |
|
from .positional_encodings import PositionalEncoding1D |
|
from models.unifont_module import UnifontModule |
|
from PIL import Image |
|
|
|
|
|
def get_rgb(x): |
|
R = 255 - int(int(x > 0.5) * 255 * (x - 0.5) / 0.5) |
|
G = 0 |
|
B = 255 + int(int(x < 0.5) * 255 * (x - 0.5) / 0.5) |
|
return R, G, B |
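
# get_rgb maps a scalar weight in [0, 1] onto a red-to-blue colormap:
# 0.0 -> (255, 0, 0), 0.5 -> (255, 0, 255), 1.0 -> (0, 0, 255).
# Quick sanity check (illustrative only, not part of the pipeline):
#   assert get_rgb(0.0) == (255, 0, 0)
#   assert get_rgb(1.0) == (0, 0, 255)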
|
|
|
|
|
def get_page_from_words(word_lists, MAX_IMG_WIDTH=800): |
|
line_all = [] |
|
line_t = [] |
|
|
|
width_t = 0 |
|
|
|
for i in word_lists: |
|
|
|
width_t = width_t + i.shape[1] + 16 |
|
|
|
if width_t > MAX_IMG_WIDTH: |
|
line_all.append(np.concatenate(line_t, 1)) |
|
|
|
line_t = [] |
|
|
|
width_t = i.shape[1] + 16 |
|
|
|
line_t.append(i) |
|
line_t.append(np.ones((i.shape[0], 16))) |
|
|
|
    # flush any words left over after the last completed line
    if len(line_t) > 0:
        line_all.append(np.concatenate(line_t, 1))

    max_line_width = MAX_IMG_WIDTH
    gap_h = np.ones([16, max_line_width])
|
|
|
page_ = [] |
|
|
|
for l in line_all: |
|
        pad_ = np.ones([l.shape[0], max_line_width - l.shape[1]])
|
|
|
page_.append(np.concatenate([l, pad_], 1)) |
|
page_.append(gap_h) |
|
|
|
page = np.concatenate(page_, 0) |
|
|
|
return page * 255 |
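
# Usage sketch (assumes each word image is a float array in [0, 1] with the
# same height, e.g. 32 px; shapes here are illustrative):
#   words = [np.random.rand(32, 100), np.random.rand(32, 60)]
#   page = get_page_from_words(words, MAX_IMG_WIDTH=800)  # float page in [0, 255]
#   Image.fromarray(page.astype('uint8')).save('page.png')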
|
|
|
|
|
class FCNDecoder(nn.Module): |
|
def __init__(self, ups=3, n_res=2, dim=512, out_dim=1, res_norm='adain', activ='relu', pad_type='reflect'): |
|
super(FCNDecoder, self).__init__() |
|
|
|
self.model = [] |
|
self.model += [ResBlocks(n_res, dim, res_norm, |
|
activ, pad_type=pad_type)] |
|
for i in range(ups): |
|
self.model += [nn.Upsample(scale_factor=2), |
|
Conv2dBlock(dim, dim // 2, 5, 1, 2, |
|
norm='in', |
|
activation=activ, |
|
pad_type=pad_type)] |
|
dim //= 2 |
|
self.model += [Conv2dBlock(dim, out_dim, 7, 1, 3, |
|
norm='none', |
|
activation='tanh', |
|
pad_type=pad_type)] |
|
self.model = nn.Sequential(*self.model) |
|
|
|
def forward(self, x): |
|
y = self.model(x) |
|
|
|
return y |
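
# Shape sketch for the default FCNDecoder (ups=3, dim=512, out_dim=1): an input
# feature map of shape (B, 512, 4, W) is upsampled 2x three times while the
# channel count halves at each step (512 -> 256 -> 128 -> 64), giving a
# (B, 1, 32, 8 * W) image in [-1, 1] after the final tanh block.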
|
|
|
|
|
class Generator(nn.Module): |
|
|
|
def __init__(self, args): |
|
super(Generator, self).__init__() |
|
self.args = args |
|
INP_CHANNEL = 1 |
|
|
|
encoder_layer = TransformerEncoderLayer(config.tn_hidden_dim, config.tn_nheads, |
|
config.tn_dim_feedforward, |
|
config.tn_dropout, "relu", True) |
|
        encoder_norm = nn.LayerNorm(config.tn_hidden_dim)
|
self.encoder = TransformerEncoder(encoder_layer, config.tn_enc_layers, encoder_norm) |
|
|
|
decoder_layer = TransformerDecoderLayer(config.tn_hidden_dim, config.tn_nheads, |
|
config.tn_dim_feedforward, |
|
config.tn_dropout, "relu", True) |
|
decoder_norm = nn.LayerNorm(config.tn_hidden_dim) |
|
self.decoder = TransformerDecoder(decoder_layer, config.tn_dec_layers, decoder_norm, |
|
return_intermediate=True) |
|
|
|
self.Feat_Encoder = models.resnet18(weights='ResNet18_Weights.DEFAULT') |
|
self.Feat_Encoder.conv1 = nn.Conv2d(INP_CHANNEL, 64, kernel_size=7, stride=2, padding=3, bias=False) |
|
self.Feat_Encoder.fc = nn.Identity() |
|
self.Feat_Encoder.avgpool = nn.Identity() |
|
|
|
|
|
self.query_embed = UnifontModule( |
|
config.tn_dim_feedforward, |
|
self.args.alphabet + self.args.special_alphabet, |
|
input_type=self.args.query_input, |
|
device=self.args.device |
|
) |
|
|
|
self.pos_encoder = PositionalEncoding1D(config.tn_hidden_dim) |
|
|
|
self.linear_q = nn.Linear(config.tn_dim_feedforward, config.tn_dim_feedforward * 8) |
|
|
|
self.DEC = FCNDecoder(res_norm='in', dim=config.tn_hidden_dim) |
|
|
|
self.noise = torch.distributions.Normal(loc=torch.tensor([0.]), scale=torch.tensor([1.0])) |
|
|
|
def evaluate(self, style_images, queries): |
|
style = self.compute_style(style_images) |
|
|
|
results = [] |
|
|
|
for i in range(queries.shape[1]): |
|
query = queries[:, i, :] |
|
h = self.generate(style, query) |
|
|
|
results.append(h.detach()) |
|
|
|
return results |
|
|
|
def compute_style(self, style_images): |
|
B, N, R, C = style_images.shape |
|
FEAT_ST = self.Feat_Encoder(style_images.view(B * N, 1, R, C)) |
|
FEAT_ST = FEAT_ST.view(B, 512, 1, -1) |
|
FEAT_ST_ENC = FEAT_ST.flatten(2).permute(2, 0, 1) |
|
memory = self.encoder(FEAT_ST_ENC) |
|
return memory |
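
    # Shape sketch for compute_style (illustrative; assumes R and C are
    # multiples of the ResNet-18 stride 32):
    #   style_images (B, N, R, C) -> backbone features (B*N, 512, R/32, C/32)
    #   -> reshaped to a sequence (S, B, 512) with S = N * (R/32) * (C/32)
    #   -> transformer encoder -> style memory of the same shape (S, B, 512)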
|
|
|
def generate(self, style_vector, query): |
|
query_embed = self.query_embed(query).permute(1, 0, 2) |
|
|
|
tgt = torch.zeros_like(query_embed) |
|
hs = self.decoder(tgt, style_vector, query_pos=query_embed) |
|
|
|
h = hs.transpose(1, 2)[-1] |
|
|
|
if self.args.add_noise: |
|
h = h + self.noise.sample(h.size()).squeeze(-1).to(self.args.device) |
|
|
|
h = self.linear_q(h) |
|
h = h.contiguous() |
|
|
|
h = h.view(h.size(0), h.shape[1] * 2, 4, -1) |
|
h = h.permute(0, 3, 2, 1) |
|
|
|
h = self.DEC(h) |
|
|
|
return h |
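
    # Shape sketch for generate (illustrative): for a query of T characters the
    # last decoder layer yields (B, T, D); linear_q expands the feature dim 8x,
    # the view/permute fold it into a (B, D, 4, 2*T) feature map, and FCNDecoder
    # (3 upsampling stages) renders a (B, 1, 32, 16*T) image in [-1, 1], i.e.
    # roughly 16 output pixels of width per query character.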
|
|
|
def forward(self, style_images, query): |
|
enc_attn_weights, dec_attn_weights = [], [] |
|
|
|
self.hooks = [ |
|
|
|
            self.encoder.layers[-1].self_attn.register_forward_hook(
                lambda module, inputs, outputs: enc_attn_weights.append(outputs[1])
            ),
            self.decoder.layers[-1].multihead_attn.register_forward_hook(
                lambda module, inputs, outputs: dec_attn_weights.append(outputs[1])
            ),
|
] |
|
|
|
style = self.compute_style(style_images) |
|
|
|
h = self.generate(style, query) |
|
|
|
self.dec_attn_weights = dec_attn_weights[-1].detach() |
|
self.enc_attn_weights = enc_attn_weights[-1].detach() |
|
|
|
for hook in self.hooks: |
|
hook.remove() |
|
|
|
return h, style |
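
    # Usage sketch (hypothetical tensors; shapes follow the notes above):
    #   G = Generator(args)
    #   style_imgs = torch.randn(B, 15, 32, 192)      # N=15 style words per writer
    #   fake, style = G(style_imgs, query_encoding)   # query_encoding from strLabelConverter.encode
    #   attn = G.dec_attn_weights                     # cross-attention maps captured by the hooks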
|
|
|
|
|
class VATr(nn.Module): |
|
|
|
def __init__(self, args): |
|
super(VATr, self).__init__() |
|
self.args = args |
|
self.args.vocab_size = len(args.alphabet) |
|
|
|
self.epsilon = 1e-7 |
|
self.netG = Generator(self.args).to(self.args.device) |
|
self.netD = Discriminator( |
|
resolution=self.args.resolution, crop_size=args.d_crop_size, |
|
).to(self.args.device) |
|
|
|
self.netW = WDiscriminator(resolution=self.args.resolution, n_classes=self.args.vocab_size, output_dim=self.args.num_writers) |
|
self.netW = self.netW.to(self.args.device) |
|
self.netconverter = strLabelConverter(self.args.alphabet + self.args.special_alphabet) |
|
|
|
self.netOCR = CRNN(self.args).to(self.args.device) |
|
|
|
self.ocr_augmenter = OCRAugment(prob=0.5, no=3) |
|
self.OCR_criterion = CTCLoss(zero_infinity=True, reduction='none') |
|
|
|
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048] |
|
self.inception = InceptionV3([block_idx]).to(self.args.device) |
|
|
|
self.optimizer_G = torch.optim.Adam(self.netG.parameters(), |
|
lr=self.args.g_lr, betas=(0.0, 0.999), weight_decay=0, eps=1e-8) |
|
|
|
self.optimizer_OCR = torch.optim.Adam(self.netOCR.parameters(), |
|
lr=self.args.ocr_lr, betas=(0.0, 0.999), weight_decay=0, eps=1e-8) |
|
|
|
self.optimizer_D = torch.optim.Adam(self.netD.parameters(), |
|
lr=self.args.d_lr, betas=(0.0, 0.999), weight_decay=0, eps=1e-8) |
|
|
|
self.optimizer_wl = torch.optim.Adam(self.netW.parameters(), |
|
lr=self.args.w_lr, betas=(0.0, 0.999), weight_decay=0, eps=1e-8) |
|
|
|
self.optimizers = [self.optimizer_G, self.optimizer_OCR, self.optimizer_D, self.optimizer_wl] |
|
|
|
self.optimizer_G.zero_grad() |
|
self.optimizer_OCR.zero_grad() |
|
self.optimizer_D.zero_grad() |
|
self.optimizer_wl.zero_grad() |
|
|
|
self.loss_G = 0 |
|
self.loss_D = 0 |
|
self.loss_Dfake = 0 |
|
self.loss_Dreal = 0 |
|
self.loss_OCR_fake = 0 |
|
self.loss_OCR_real = 0 |
|
self.loss_w_fake = 0 |
|
self.loss_w_real = 0 |
|
self.Lcycle = 0 |
|
self.d_acc = SmoothedValue() |
|
|
|
self.word_generator = get_generator(args) |
|
|
|
self.epoch = 0 |
|
|
|
with open('mytext.txt', 'r', encoding='utf-8') as f: |
|
self.text = f.read() |
|
        self.text = self.text.replace('\n', ' ')
|
self.text = ''.join(c for c in self.text if c in (self.args.alphabet + self.args.special_alphabet)) |
|
self.text = [word.encode() for word in self.text.split()] |
|
|
|
self.eval_text_encode, self.eval_len_text, self.eval_encode_pos = self.netconverter.encode(self.text) |
|
self.eval_text_encode = self.eval_text_encode.to(self.args.device).repeat(self.args.batch_size, 1, 1) |
|
|
|
self.rv_sample_size = 64 * 4 |
|
self.last_fakes = [] |
|
|
|
def update_last_fakes(self, fakes): |
|
for fake in fakes: |
|
self.last_fakes.append(fake) |
|
self.last_fakes = self.last_fakes[-self.rv_sample_size:] |
|
|
|
def update_acc(self, pred_real, pred_fake): |
|
correct = (pred_real >= 0.5).float().sum() + (pred_fake < 0.5).float().sum() |
|
self.d_acc.update(correct / (len(pred_real) + len(pred_fake))) |
|
|
|
def set_text_aug_strength(self, strength): |
|
if not isinstance(self.word_generator, AugmentedGenerator): |
|
print("WARNING: Text generator is not augmented, strength cannot be set") |
|
else: |
|
self.word_generator.set_strength(strength) |
|
|
|
def get_text_aug_strength(self): |
|
if isinstance(self.word_generator, AugmentedGenerator): |
|
return self.word_generator.strength |
|
else: |
|
return 0.0 |
|
|
|
def update_parameters(self, epoch: int): |
|
self.epoch = epoch |
|
self.netD.update_parameters(epoch) |
|
self.netW.update_parameters(epoch) |
|
|
|
def get_text_sample(self, size: int) -> list: |
|
return [self.word_generator.generate() for _ in range(size)] |
|
|
|
def _generate_fakes(self, ST, eval_text_encode=None, eval_len_text=None): |
|
        if eval_text_encode is None:
            eval_text_encode = self.eval_text_encode
        if eval_len_text is None:
            eval_len_text = self.eval_len_text
|
|
|
self.fakes = self.netG.evaluate(ST, eval_text_encode) |
|
|
|
np_fakes = [] |
|
for batch_idx in range(self.fakes[0].shape[0]): |
|
for idx, fake in enumerate(self.fakes): |
|
fake = fake[batch_idx, 0, :, :eval_len_text[idx] * self.args.resolution] |
|
fake = (fake + 1) / 2 |
|
np_fakes.append(fake.cpu().numpy()) |
|
return np_fakes |
|
|
|
def _generate_page(self, ST, SLEN, eval_text_encode=None, eval_len_text=None, eval_encode_pos=None, lwidth=260, rwidth=980): |
|
|
|
|
|
        if eval_text_encode is None:
            eval_text_encode = self.eval_text_encode
        if eval_len_text is None:
            eval_len_text = self.eval_len_text
|
if eval_encode_pos is None: |
|
eval_encode_pos = self.eval_encode_pos |
|
|
|
text_encode, text_len, _ = self.netconverter.encode(self.args.special_alphabet) |
|
symbols = self.netG.query_embed.symbols[text_encode].reshape(-1, 16, 16).cpu().numpy() |
|
        imgs = [Image.fromarray(s).resize((32, 32), resample=Image.NEAREST) for s in symbols]
|
special_examples = 1 - np.concatenate([np.array(i) for i in imgs], axis=-1) |
|
|
|
self.fakes = self.netG.evaluate(ST, eval_text_encode) |
|
|
|
page1s = [] |
|
page2s = [] |
|
|
|
for batch_idx in range(ST.shape[0]): |
|
|
|
word_t = [] |
|
word_l = [] |
|
|
|
gap = np.ones([self.args.img_height, 16]) |
|
|
|
line_wids = [] |
|
|
|
for idx, fake_ in enumerate(self.fakes): |
|
|
|
word_t.append((fake_[batch_idx, 0, :, :eval_len_text[idx] * self.args.resolution].cpu().numpy() + 1) / 2) |
|
|
|
word_t.append(gap) |
|
|
|
if sum(t.shape[-1] for t in word_t) >= rwidth or idx == len(self.fakes) - 1 or (len(self.fakes) - len(self.args.special_alphabet) - 1) == idx: |
|
line_ = np.concatenate(word_t, -1) |
|
|
|
word_l.append(line_) |
|
line_wids.append(line_.shape[1]) |
|
|
|
word_t = [] |
|
|
|
|
|
word_l.append(special_examples) |
|
line_wids.append(special_examples.shape[1]) |
|
|
|
gap_h = np.ones([16, max(line_wids)]) |
|
|
|
page_ = [] |
|
|
|
for l in word_l: |
|
pad_ = np.ones([self.args.img_height, max(line_wids) - l.shape[1]]) |
|
|
|
page_.append(np.concatenate([l, pad_], 1)) |
|
page_.append(gap_h) |
|
|
|
page1 = np.concatenate(page_, 0) |
|
|
|
word_t = [] |
|
word_l = [] |
|
|
|
|
|
line_wids = [] |
|
|
|
sdata_ = [i.unsqueeze(1) for i in torch.unbind(ST, 1)] |
|
gap = np.ones([sdata_[0].shape[-2], 16]) |
|
|
|
for idx, st in enumerate((sdata_)): |
|
|
|
word_t.append((st[batch_idx, 0, :, :int(SLEN.cpu().numpy()[batch_idx][idx])].cpu().numpy() + 1) / 2) |
|
|
|
|
|
word_t.append(gap) |
|
|
|
if sum(t.shape[-1] for t in word_t) >= lwidth or idx == len(sdata_) - 1: |
|
line_ = np.concatenate(word_t, -1) |
|
|
|
word_l.append(line_) |
|
line_wids.append(line_.shape[1]) |
|
|
|
word_t = [] |
|
|
|
gap_h = np.ones([16, max(line_wids)]) |
|
|
|
page_ = [] |
|
|
|
for l in word_l: |
|
pad_ = np.ones([sdata_[0].shape[-2], max(line_wids) - l.shape[1]]) |
|
|
|
page_.append(np.concatenate([l, pad_], 1)) |
|
page_.append(gap_h) |
|
|
|
page2 = np.concatenate(page_, 0) |
|
|
|
merge_w_size = max(page1.shape[0], page2.shape[0]) |
|
|
|
if page1.shape[0] != merge_w_size: |
|
page1 = np.concatenate([page1, np.ones([merge_w_size - page1.shape[0], page1.shape[1]])], 0) |
|
|
|
if page2.shape[0] != merge_w_size: |
|
page2 = np.concatenate([page2, np.ones([merge_w_size - page2.shape[0], page2.shape[1]])], 0) |
|
|
|
page1s.append(page1) |
|
page2s.append(page2) |
|
|
|
|
|
|
|
page1s_ = np.concatenate(page1s, 0) |
|
max_wid = max([i.shape[1] for i in page2s]) |
|
padded_page2s = [] |
|
|
|
for para in page2s: |
|
padded_page2s.append(np.concatenate([para, np.ones([para.shape[0], max_wid - para.shape[1]])], 1)) |
|
|
|
padded_page2s_ = np.concatenate(padded_page2s, 0) |
|
|
|
return np.concatenate([padded_page2s_, page1s_], 1) |
|
|
|
def get_current_losses(self): |
|
|
|
losses = {} |
|
|
|
losses['G'] = self.loss_G |
|
losses['D'] = self.loss_D |
|
losses['Dfake'] = self.loss_Dfake |
|
losses['Dreal'] = self.loss_Dreal |
|
losses['OCR_fake'] = self.loss_OCR_fake |
|
losses['OCR_real'] = self.loss_OCR_real |
|
losses['w_fake'] = self.loss_w_fake |
|
losses['w_real'] = self.loss_w_real |
|
losses['cycle'] = self.Lcycle |
|
|
|
return losses |
|
|
|
def _set_input(self, input): |
|
self.input = input |
|
|
|
self.real = self.input['img'].to(self.args.device) |
|
self.label = self.input['label'] |
|
|
|
self.set_ocr_data(self.input['img'], self.input['label']) |
|
|
|
self.sdata = self.input['simg'].to(self.args.device) |
|
self.slabels = self.input['slabels'] |
|
|
|
self.ST_LEN = self.input['swids'] |
|
|
|
def set_requires_grad(self, nets, requires_grad=False): |
|
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations |
|
Parameters: |
|
nets (network list) -- a list of networks |
|
requires_grad (bool) -- whether the networks require gradients or not |
|
""" |
|
if not isinstance(nets, list): |
|
nets = [nets] |
|
for net in nets: |
|
if net is not None: |
|
for param in net.parameters(): |
|
param.requires_grad = requires_grad |
|
|
|
def forward(self): |
|
self.text_encode, self.len_text, self.encode_pos = self.netconverter.encode(self.label) |
|
self.text_encode = self.text_encode.to(self.args.device).detach() |
|
self.len_text = self.len_text.detach() |
|
|
|
self.words = [self.word_generator.generate().encode('utf-8') for _ in range(self.args.batch_size)] |
|
self.text_encode_fake, self.len_text_fake, self.encode_pos_fake = self.netconverter.encode(self.words) |
|
self.text_encode_fake = self.text_encode_fake.to(self.args.device) |
|
self.one_hot_fake = make_one_hot(self.text_encode_fake, self.len_text_fake, self.args.vocab_size).to( |
|
self.args.device) |
|
|
|
self.fake, self.style = self.netG(self.sdata, self.text_encode_fake) |
|
|
|
self.update_last_fakes(self.fake) |
|
|
|
def pad_width(self, t, new_width): |
|
result = torch.ones((t.size(0), t.size(1), t.size(2), new_width), device=t.device) |
|
result[:,:,:,:t.size(-1)] = t |
|
|
|
return result |
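
    # e.g. pad_width on a (B, 1, 32, 100) batch with new_width=120 returns a
    # (B, 1, 32, 120) tensor whose extra columns are ones (white background
    # in the [-1, 1] image convention used here).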
|
|
|
    def compute_real_ocr_loss(self, ocr_network=None):
        network = ocr_network if ocr_network is not None else self.netOCR
        input_images = self.ocr_images.detach()
        input_labels = self.ocr_labels
|
|
|
if self.ocr_augmenter is not None: |
|
input_images = self.ocr_augmenter(input_images) |
|
|
|
pred_real = network(input_images) |
|
preds_size = torch.IntTensor([pred_real.size(0)] * len(input_labels)).detach() |
|
text_encode, len_text, _ = self.netconverter.encode(input_labels) |
|
|
|
loss = self.OCR_criterion(pred_real, text_encode.detach(), preds_size, len_text.detach()) |
|
|
|
return torch.mean(loss[~torch.isnan(loss)]) |
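
    # Note: the CTC loss is computed per sample (reduction='none') and can be
    # NaN for degenerate alignments; those entries are masked out before the
    # mean, while infinite values are already zeroed by zero_infinity=True.
    # compute_fake_ocr_loss below applies the same filtering.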
|
|
|
    def compute_fake_ocr_loss(self, ocr_network=None):
|
network = ocr_network if ocr_network is not None else self.netOCR |
|
|
|
pred_fake_OCR = network(self.fake) |
|
preds_size = torch.IntTensor([pred_fake_OCR.size(0)] * self.args.batch_size).detach() |
|
loss_OCR_fake = self.OCR_criterion(pred_fake_OCR, self.text_encode_fake.detach(), preds_size, |
|
self.len_text_fake.detach()) |
|
return torch.mean(loss_OCR_fake[~torch.isnan(loss_OCR_fake)]) |
|
|
|
def set_ocr_data(self, images, labels): |
|
self.ocr_images = images.to(self.args.device) |
|
self.ocr_labels = labels |
|
|
|
def backward_D_OCR(self): |
|
|
pred_real = self.netD(self.real.detach()) |
|
pred_fake = self.netD(**{'x': self.fake.detach()}) |
|
|
|
self.update_acc(pred_real, pred_fake) |
|
|
|
self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(pred_fake, pred_real, self.len_text_fake.detach(), |
|
self.len_text.detach(), True) |
|
|
|
self.loss_D = self.loss_Dreal + self.loss_Dfake |
|
|
|
if not self.args.no_ocr_loss: |
|
self.loss_OCR_real = self.compute_real_ocr_loss() |
|
loss_total = self.loss_D + self.loss_OCR_real |
|
else: |
|
loss_total = self.loss_D |
|
|
|
|
|
loss_total.backward() |
|
if not self.args.no_ocr_loss: |
|
self.clean_grad(self.netOCR.parameters()) |
|
|
|
return loss_total |
|
|
|
    def clean_grad(self, params):
        # zero out NaN/Inf gradient entries so a single bad batch cannot
        # corrupt the optimizer state (parameters without a grad are skipped)
        for param in params:
            if param.grad is not None:
                torch.nan_to_num_(param.grad, nan=0.0, posinf=0.0, neginf=0.0)
|
|
|
def backward_D_WL(self): |
|
|
|
pred_real = self.netD(self.real.detach()) |
|
|
|
pred_fake = self.netD(**{'x': self.fake.detach()}) |
|
|
|
self.update_acc(pred_real, pred_fake) |
|
|
|
self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(pred_fake, pred_real, self.len_text_fake.detach(), |
|
self.len_text.detach(), True) |
|
|
|
self.loss_D = self.loss_Dreal + self.loss_Dfake |
|
|
|
if not self.args.no_writer_loss: |
|
self.loss_w_real = self.netW(self.real.detach(), self.input['wcl'].to(self.args.device)).mean() |
|
|
|
loss_total = self.loss_D + self.loss_w_real * self.args.writer_loss_weight |
|
else: |
|
loss_total = self.loss_D |
|
|
|
|
|
loss_total.backward() |
|
|
|
return loss_total |
|
|
|
def optimize_D_WL(self): |
|
self.forward() |
|
self.set_requires_grad([self.netD], True) |
|
self.set_requires_grad([self.netOCR], False) |
|
        self.set_requires_grad([self.netW], True)
|
|
|
self.optimizer_D.zero_grad() |
|
self.optimizer_wl.zero_grad() |
|
|
|
self.backward_D_WL() |
|
|
|
def optimize_D_WL_step(self): |
|
self.optimizer_D.step() |
|
self.optimizer_wl.step() |
|
self.optimizer_D.zero_grad() |
|
self.optimizer_wl.zero_grad() |
|
|
|
def compute_cycle_loss(self): |
|
fake_input = torch.ones_like(self.sdata) |
|
width = min(self.sdata.size(-1), self.fake.size(-1)) |
|
fake_input[:, :, :, :width] = self.fake.repeat(1, 15, 1, 1)[:, :, :, :width] |
|
with torch.no_grad(): |
|
fake_style = self.netG.compute_style(fake_input) |
|
|
|
return torch.sum(torch.abs(self.style.detach() - fake_style), dim=1).mean() |
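
    # Note: fake_style is computed under no_grad and the reference style is
    # detached, so this L1 distance carries no gradient as written; it mainly
    # acts as a style-consistency measure added to the reported generator loss.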
|
|
|
def backward_G_only(self): |
|
self.gb_alpha = 0.7 |
|
if self.args.is_cycle: |
|
self.Lcycle = self.compute_cycle_loss() |
|
|
|
self.loss_G = loss_hinge_gen(self.netD(**{'x': self.fake}), self.len_text_fake.detach(), True).mean() |
|
|
|
compute_ocr = not self.args.no_ocr_loss |
|
|
|
if compute_ocr: |
|
self.loss_OCR_fake = self.compute_fake_ocr_loss() |
|
|
|
self.loss_G = self.loss_G + self.Lcycle |
|
|
|
if compute_ocr: |
|
self.loss_T = self.loss_G + self.loss_OCR_fake |
|
else: |
|
self.loss_T = self.loss_G |
|
|
|
if compute_ocr: |
|
grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, retain_graph=True)[0] |
|
self.loss_grad_fake_OCR = 10 ** 6 * torch.mean(grad_fake_OCR ** 2) |
|
|
|
grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, retain_graph=True)[0] |
|
self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2) |
|
|
|
self.loss_T.backward(retain_graph=True) |
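
        # Gradient-balancing trick: rescale the OCR loss by
        #   a = gb_alpha * std(grad_adv) / (eps + std(grad_OCR))
        # so that the OCR gradient on the fake image has roughly the same
        # magnitude as the adversarial gradient before the second backward pass.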
|
|
|
if compute_ocr: |
|
grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=True, retain_graph=True)[0] |
|
grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=True, retain_graph=True)[0] |
|
a = self.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon + torch.std(grad_fake_OCR)) |
|
self.loss_OCR_fake = a.detach() * self.loss_OCR_fake |
|
self.loss_T = self.loss_G + self.loss_OCR_fake |
|
else: |
|
grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=True, retain_graph=True)[0] |
|
a = 1 |
|
self.loss_T = self.loss_G |
|
|
|
if a is None: |
|
print(self.loss_OCR_fake, self.loss_G, torch.std(grad_fake_adv)) |
|
if a > 1000 or a < 0.0001: |
|
print(f'WARNING: alpha > 1000 or alpha < 0.0001 - alpha={a.item()}') |
|
|
|
self.loss_T.backward(retain_graph=True) |
|
if compute_ocr: |
|
grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=False, retain_graph=True)[0] |
|
self.loss_grad_fake_OCR = 10 ** 6 * torch.mean(grad_fake_OCR ** 2) |
|
grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=False, retain_graph=True)[0] |
|
self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2) |
|
|
|
with torch.no_grad(): |
|
self.loss_T.backward() |
|
if compute_ocr: |
|
            if torch.isnan(self.loss_OCR_fake) or torch.isnan(self.loss_G):
|
print('loss OCR fake: ', self.loss_OCR_fake, ' loss_G: ', self.loss_G, ' words: ', self.words) |
|
sys.exit() |
|
|
|
def backward_G_WL(self): |
|
self.gb_alpha = 0.7 |
|
if self.args.is_cycle: |
|
self.Lcycle = self.compute_cycle_loss() |
|
|
|
self.loss_G = loss_hinge_gen(self.netD(**{'x': self.fake}), self.len_text_fake.detach(), True).mean() |
|
|
|
if not self.args.no_writer_loss: |
|
self.loss_w_fake = self.netW(self.fake, self.input['wcl'].to(self.args.device)).mean() |
|
|
|
self.loss_G = self.loss_G + self.Lcycle |
|
|
|
if not self.args.no_writer_loss: |
|
self.loss_T = self.loss_G + self.loss_w_fake * self.args.writer_loss_weight |
|
else: |
|
self.loss_T = self.loss_G |
|
|
|
self.loss_T.backward(retain_graph=True) |
|
|
|
if not self.args.no_writer_loss: |
|
grad_fake_WL = torch.autograd.grad(self.loss_w_fake, self.fake, create_graph=True, retain_graph=True)[0] |
|
grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=True, retain_graph=True)[0] |
|
a = self.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon + torch.std(grad_fake_WL)) |
|
self.loss_w_fake = a.detach() * self.loss_w_fake |
|
self.loss_T = self.loss_G + self.loss_w_fake |
|
else: |
|
grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=True, retain_graph=True)[0] |
|
a = 1 |
|
self.loss_T = self.loss_G |
|
|
|
if a is None: |
|
print(self.loss_w_fake, self.loss_G, torch.std(grad_fake_adv)) |
|
if a > 1000 or a < 0.0001: |
|
print(f'WARNING: alpha > 1000 or alpha < 0.0001 - alpha={a.item()}') |
|
|
|
self.loss_T.backward(retain_graph=True) |
|
|
|
if not self.args.no_writer_loss: |
|
grad_fake_WL = torch.autograd.grad(self.loss_w_fake, self.fake, create_graph=False, retain_graph=True)[0] |
|
self.loss_grad_fake_WL = 10 ** 6 * torch.mean(grad_fake_WL ** 2) |
|
grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=False, retain_graph=True)[0] |
|
self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2) |
|
|
|
with torch.no_grad(): |
|
self.loss_T.backward() |
|
|
|
    def backward_G(self):
        # NOTE: legacy training path; it expects `self.opt`, `self.z` and
        # `self.wcl` to be provided externally (they are not set in __init__).
        self.opt.gb_alpha = 0.7
        self.loss_G = loss_hinge_gen(self.netD(**{'x': self.fake, 'z': self.z}), self.len_text_fake.detach(),
                                     self.opt.mask_loss)
|
|
|
compute_ocr = not self.args.no_ocr_loss |
|
|
|
if compute_ocr: |
|
self.loss_OCR_fake = self.compute_fake_ocr_loss() |
|
else: |
|
self.loss_OCR_fake = 0.0 |
|
|
|
self.loss_w_fake = self.netW(self.fake, self.wcl) |
|
|
|
|
|
|
|
|
|
|
|
|
|
self.loss_G_ = 10 * self.loss_G + self.loss_w_fake |
|
self.loss_T = self.loss_G_ + self.loss_OCR_fake |
|
|
|
grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, retain_graph=True)[0] |
|
|
|
self.loss_grad_fake_OCR = 10 ** 6 * torch.mean(grad_fake_OCR ** 2) |
|
grad_fake_adv = torch.autograd.grad(self.loss_G_, self.fake, retain_graph=True)[0] |
|
self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2) |
|
|
|
        self.loss_T.backward(retain_graph=True)

        grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=True, retain_graph=True)[0]
        grad_fake_adv = torch.autograd.grad(self.loss_G_, self.fake, create_graph=True, retain_graph=True)[0]

        a = self.opt.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon + torch.std(grad_fake_OCR))

        if a is None:
            print(self.loss_OCR_fake, self.loss_G_, torch.std(grad_fake_adv), torch.std(grad_fake_OCR))
        if a > 1000 or a < 0.0001:
            print(f'WARNING: alpha > 1000 or alpha < 0.0001 - alpha={a.item()}')
        b = self.opt.gb_alpha * (torch.mean(grad_fake_adv) -
                                 torch.div(torch.std(grad_fake_adv), self.epsilon + torch.std(grad_fake_OCR)) *
                                 torch.mean(grad_fake_OCR))

        self.loss_OCR_fake = a.detach() * self.loss_OCR_fake

        self.loss_T = (1 - 1 * self.opt.onlyOCR) * self.loss_G_ + self.loss_OCR_fake
        self.loss_T.backward(retain_graph=True)
        grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=False, retain_graph=True)[0]
        grad_fake_adv = torch.autograd.grad(self.loss_G_, self.fake, create_graph=False, retain_graph=True)[0]
        self.loss_grad_fake_OCR = 10 ** 6 * torch.mean(grad_fake_OCR ** 2)
        self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2)
        with torch.no_grad():
            self.loss_T.backward()
|
|
|
if self.opt.clip_grad > 0: |
|
clip_grad_norm_(self.netG.parameters(), self.opt.clip_grad) |
|
        if torch.isnan(torch.as_tensor(self.loss_OCR_fake)) or torch.isnan(self.loss_G_):
            print('loss OCR fake: ', self.loss_OCR_fake, ' loss_G: ', self.loss_G_, ' words: ', self.words)
|
sys.exit() |
|
|
|
def optimize_D_OCR(self): |
|
self.forward() |
|
self.set_requires_grad([self.netD], True) |
|
self.set_requires_grad([self.netOCR], True) |
|
self.optimizer_D.zero_grad() |
|
|
|
self.optimizer_OCR.zero_grad() |
|
self.backward_D_OCR() |
|
|
|
def optimize_D_OCR_step(self): |
|
self.optimizer_D.step() |
|
|
|
self.optimizer_OCR.step() |
|
self.optimizer_D.zero_grad() |
|
self.optimizer_OCR.zero_grad() |
|
|
|
def optimize_G_WL(self): |
|
self.forward() |
|
self.set_requires_grad([self.netD], False) |
|
self.set_requires_grad([self.netOCR], False) |
|
self.set_requires_grad([self.netW], False) |
|
self.backward_G_WL() |
|
|
|
def optimize_G_only(self): |
|
self.forward() |
|
self.set_requires_grad([self.netD], False) |
|
self.set_requires_grad([self.netOCR], False) |
|
self.set_requires_grad([self.netW], False) |
|
self.backward_G_only() |
|
|
|
def optimize_G_step(self): |
|
self.optimizer_G.step() |
|
self.optimizer_G.zero_grad() |
|
|
|
def save_networks(self, epoch, save_dir): |
|
"""Save all the networks to the disk. |
|
|
|
Parameters: |
|
epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) |
|
""" |
|
for name in self.model_names: |
|
if isinstance(name, str): |
|
save_filename = '%s_net_%s.pth' % (epoch, name) |
|
save_path = os.path.join(save_dir, save_filename) |
|
net = getattr(self, 'net' + name) |
|
|
|
if len(self.gpu_ids) > 0 and torch.cuda.is_available(): |
|
|
|
if len(self.gpu_ids) > 1: |
|
torch.save(net.module.cpu().state_dict(), save_path) |
|
else: |
|
torch.save(net.cpu().state_dict(), save_path) |
|
net.cuda(self.gpu_ids[0]) |
|
else: |
|
torch.save(net.cpu().state_dict(), save_path) |
|
|
|
def compute_d_scores(self, data_loader: torch.utils.data.DataLoader, amount: int = None): |
|
scores = [] |
|
words = [] |
|
amount = len(data_loader) if amount is None else amount // data_loader.batch_size |
|
|
|
        # create the iterator once; `next(iter(loader))` inside the loop would
        # restart the loader and keep returning the same first batch
        data_iter = iter(data_loader)
        with torch.no_grad():
            for _ in range(amount):
                data = next(data_iter)
                words.extend([d.decode() for d in data['label']])
                scores.extend(list(self.netD(data['img'].to(self.args.device)).squeeze().detach().cpu().numpy()))
|
|
|
return scores, words |
|
|
|
def compute_d_scores_fake(self, data_loader: torch.utils.data.DataLoader, amount: int = None): |
|
scores = [] |
|
words = [] |
|
amount = len(data_loader) if amount is None else amount // data_loader.batch_size |
|
|
|
        data_iter = iter(data_loader)
        with torch.no_grad():
            for _ in range(amount):
                data = next(data_iter)
                to_generate = [self.word_generator.generate().encode('utf-8') for _ in range(data_loader.batch_size)]
                text_encode_fake, len_text_fake, encode_pos_fake = self.netconverter.encode(to_generate)
                fake, _ = self.netG(data['simg'].to(self.args.device), text_encode_fake.to(self.args.device))

                words.extend([d.decode() for d in to_generate])
                scores.extend(list(self.netD(fake).squeeze().detach().cpu().numpy()))
|
|
|
return scores, words |
|
|
|
def compute_d_stats(self, train_loader: torch.utils.data.DataLoader, val_loader: torch.utils.data.DataLoader): |
|
train_values = [] |
|
val_values = [] |
|
fake_values = [] |
|
        with torch.no_grad():
            train_iter = iter(train_loader)
            for _ in range(self.rv_sample_size // train_loader.batch_size):
                data = next(train_iter)
                train_values.append(self.netD(data['img'].to(self.args.device)).squeeze().detach().cpu().numpy())

            val_iter = iter(val_loader)
            for _ in range(self.rv_sample_size // val_loader.batch_size):
                data = next(val_iter)
                val_values.append(self.netD(data['img'].to(self.args.device)).squeeze().detach().cpu().numpy())

            # iterate over however many fakes have been collected so far
            for fake in self.last_fakes:
                fake_values.append(self.netD(fake.unsqueeze(0)).squeeze().detach().cpu().numpy())
|
|
|
return np.mean(train_values), np.mean(val_values), np.mean(fake_values) |
|
|