import json
import os
import random
import shutil
import time
from collections import defaultdict
from datetime import timedelta
from pathlib import Path
from typing import Optional

import cv2
import numpy as np
import torch

from data.dataset import FolderDataset
from models.model import VATr
from util.loading import load_checkpoint, load_generator
from util.misc import FakeArgs
from util.text import TextGenerator
from util.vision import detect_text_bounds


def get_long_tail_chars():
    """Read the set of rare ("long-tail") characters, listed one per line."""
    with open("files/longtail.txt", 'r') as f:
        chars = [c.rstrip() for c in f]

    # Drop empty entries, e.g. from a trailing blank line.
    return [c for c in chars if c]


class Writer:
    def __init__(self, checkpoint_path, args, only_generator: bool = False):
        self.model = VATr(args)
        checkpoint = torch.load(checkpoint_path, map_location=args.device)
        if only_generator:
            load_generator(self.model, checkpoint)
        else:
            load_checkpoint(self.model, checkpoint)
        self.model.eval()
        self.style_dataset = None

    def set_style_folder(self, style_folder, num_examples=15):
        # An optional word_lengths.txt provides one "word,length" pair per line.
        word_lengths = None
        if os.path.exists(os.path.join(style_folder, "word_lengths.txt")):
            word_lengths = {}
            with open(os.path.join(style_folder, "word_lengths.txt"), 'r') as f:
                for line in f:
                    word, length = line.rstrip().split(",")
                    word_lengths[word] = int(length)

        self.style_dataset = FolderDataset(style_folder, num_examples=num_examples, word_lengths=word_lengths)

    @torch.no_grad()
    def generate(self, texts, align_words: bool = False, at_once: bool = False):
        if isinstance(texts, str):
            texts = [texts]
        if self.style_dataset is None:
            raise RuntimeError('Style is not set; call set_style_folder first')

        fakes = []
        for i, text in enumerate(texts, 1):
            print(f'[{i}/{len(texts)}] Generating for text: {text}')
            # Each text is rendered with a freshly sampled set of style images.
            style = self.style_dataset.sample_style()
            style_images = style['simg'].unsqueeze(0).to(self.model.args.device)

            fake = self.create_fake_sentence(style_images, text, align_words, at_once)

            fakes.append(fake)
        return fakes

    @torch.no_grad()
    def create_fake_sentence(self, style_images, text, align_words=False, at_once=False):
        # Drop any character the model was not trained on.
        text = "".join([c for c in text if c in self.model.args.alphabet])

        # Either generate word by word or render the whole sentence at once.
        text = text.split() if not at_once else [text]
        gap = np.ones((32, 16))  # 16 px of white space between words

        text_encode, len_text, _ = self.model.netconverter.encode(text)
        text_encode = text_encode.to(self.model.args.device).unsqueeze(0)

        fake = self.model._generate_fakes(style_images, text_encode, len_text)
        if not at_once:
            if align_words:
                fake = self.stitch_words(fake, show_lines=False)
            else:
                # Interleave words with gaps, then trim the trailing gap.
                fake = np.concatenate(sum([[img, gap] for img in fake], []), axis=1)[:, :-16]
        else:
            fake = fake[0]
        fake = (fake * 255).astype(np.uint8)

        return fake

    @torch.no_grad()
    def generate_authors(self, text, dataset, align_words: bool = False, at_once: bool = False):
        """Generate every line of the given text in the style of each author in the dataset."""
        fakes = []
        author_ids = []
        style = []

        for item in dataset:
            print(f"Generating author {item['wcl']}")
            style_images = item['simg'].to(self.model.args.device).unsqueeze(0)

            generated_lines = [self.create_fake_sentence(style_images, line, align_words, at_once) for line in text]

            fakes.append(generated_lines)
            author_ids.append(item['author_id'])
            # Map the style images from [-1, 1] back to [0, 255] for visualisation.
            style.append((((item['simg'].numpy() + 1.0) / 2.0) * 255).astype(np.uint8))

        return fakes, author_ids, style

    @torch.no_grad()
    def generate_characters(self, dataset, characters: str):
        """Generate each of the given characters for each of the authors in the dataset."""
        fakes = []

        text_encode, _, _ = self.model.netconverter.encode(list(characters))
        text_encode = text_encode.to(self.model.args.device).unsqueeze(0)

        for item in dataset:
            print(f"Generating author {item['wcl']}")
            style_images = item['simg'].to(self.model.args.device).unsqueeze(0)
            fake = self.model.netG.evaluate(style_images, text_encode)

            fakes.append(fake)

        return fakes

    @torch.no_grad()
    def generate_batch(self, style_imgs, text):
        """Given a batch of style images and texts, generate images using the model."""
        device = self.model.args.device
        text_encode, _, _ = self.model.netconverter.encode(text)
        fakes, _ = self.model.netG(style_imgs.to(device), text_encode.to(device))
        return fakes

    @torch.no_grad()
    def generate_ocr(self, dataset, number: int, output_folder: str = 'saved_images/ocr', interpolate_style: bool = False, text_generator: Optional[TextGenerator] = None, long_tail: bool = False):
        """Generate a synthetic OCR dataset of the given size, with labels.csv annotations."""
        def create_and_write(style, text, interpolated=False):
            nonlocal image_counter, annotations

            text_encode, _, _ = self.model.netconverter.encode([text])
            text_encode = text_encode.to(self.model.args.device)

            fake = self.model.netG.generate(style, text_encode)

            # Map the generator output from [-1, 1] to an 8-bit image.
            fake = (fake + 1) / 2
            fake = fake.cpu().numpy()
            fake = np.squeeze((fake * 255).astype(np.uint8))

            image_filename = f"{image_counter}.png" if not interpolated else f"{image_counter}_i.png"

            cv2.imwrite(os.path.join(output_folder, "generated", image_filename), fake)

            annotations.append((image_filename, text))

            image_counter += 1

        image_counter = 0
        annotations = []
        previous_style = None
        long_tail_chars = get_long_tail_chars()

        os.makedirs(os.path.join(output_folder, "generated"), exist_ok=True)
        if text_generator is None:
            os.makedirs(os.path.join(output_folder, "reference"), exist_ok=True)

        while image_counter < number:
            author_index = random.randint(0, len(dataset) - 1)
            item = dataset[author_index]

            style_images = item['simg'].to(self.model.args.device).unsqueeze(0)
            style = self.model.netG.compute_style(style_images)

            # Style interpolation blends the previous and current style vectors
            # by a random factor; it needs generated text, since the blended
            # style matches no real reference image.
            if interpolate_style and previous_style is not None and text_generator is not None:
                factor = float(np.clip(random.gauss(0.5, 0.15), 0.0, 1.0))
                intermediate_style = torch.lerp(previous_style, style, factor)
                text = text_generator.generate()

                create_and_write(intermediate_style, text, interpolated=True)

            if text_generator is not None:
                text = text_generator.generate()
            else:
                text = str(item['label'].decode())

            # In long-tail mode, skip texts without any rare character.
            if long_tail and not any(c in long_tail_chars for c in text):
                continue

            if text_generator is None:
                # Save the real dataset image as a reference alongside the fake.
                reference = (item['img'] + 1) / 2
                reference = reference.cpu().numpy()
                reference = np.squeeze((reference * 255).astype(np.uint8))

                cv2.imwrite(os.path.join(output_folder, "reference", f"{image_counter}.png"), reference)

            create_and_write(style, text)

            previous_style = style

        if text_generator is None:
            with open(os.path.join(output_folder, "reference", "labels.csv"), 'w') as fr:
                fr.write("filename,words\n")
                for filename, words in annotations:
                    fr.write(f"{filename},{words}\n")

        with open(os.path.join(output_folder, "generated", "labels.csv"), 'w') as fg:
            fg.write("filename,words\n")
            for filename, words in annotations:
                fg.write(f"{filename},{words}\n")

    @staticmethod
    def stitch_words(words: list, show_lines: bool = False, scale_words: bool = False):
        """Stitch word images into one line, aligning the words on their detected bottom lines."""
        gap_width = 16

        bottom_lines = []
        top_lines = []
        for i in range(len(words)):
            b, t = detect_text_bounds(words[i])
            bottom_lines.append(b)
            top_lines.append(t)
            if show_lines:
                # Debug overlay showing the detected bounds.
                words[i] = cv2.line(words[i], (0, b), (words[i].shape[1], b), (0, 0, 1.0))
                words[i] = cv2.line(words[i], (0, t), (words[i].shape[1], t), (1.0, 0, 0))

        bottom_lines = np.array(bottom_lines, dtype=float)

        if scale_words:
            # Rescale each word so the distance between its top and bottom
            # lines matches the mean distance over all words.
            top_lines = np.array(top_lines, dtype=float)
            gaps = bottom_lines - top_lines
            target_gap = np.mean(gaps)
            scales = target_gap / gaps

            bottom_lines *= scales
            top_lines *= scales
            words = [cv2.resize(word, None, fx=scale, fy=scale) for word, scale in zip(words, scales)]

        # Shift each word down so that all bottom lines coincide.
        highest = np.max(bottom_lines)
        offsets = highest - bottom_lines
        height = np.max(offsets + [word.shape[0] for word in words])

        # White canvas wide enough for all words plus a gap after each one.
        result = np.ones((int(height), gap_width * len(words) + sum([w.shape[1] for w in words])))

        x_pos = 0
        for bottom_line, word in zip(bottom_lines, words):
            offset = int(highest - bottom_line)

            result[offset:offset + word.shape[0], x_pos:x_pos + word.shape[1]] = word

            x_pos += word.shape[1] + gap_width

        return result
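
    # For example, with bottom lines detected at rows 20 and 26, `highest` is 26
    # and the offsets are 6 and 0: the first word is shifted down by 6 rows so
    # that both bottom lines land on the same row of the output canvas.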

    @torch.no_grad()
    def generate_fid(self, path, loader, model_tag, split='train', fake_only=False, long_tail_only=False):
        """Write paired real/fake images to disk in the layout expected for FID evaluation."""
        if not isinstance(path, Path):
            path = Path(path)

        path.mkdir(exist_ok=True, parents=True)

        appendix = split if not long_tail_only else f"{split}_lt"

        real_base = path / f'real_{appendix}'
        fake_base = path / model_tag / f'fake_{appendix}'

        if real_base.exists() and not fake_only:
            shutil.rmtree(real_base)

        if fake_base.exists():
            shutil.rmtree(fake_base)

        real_base.mkdir(exist_ok=True)
        fake_base.mkdir(exist_ok=True, parents=True)

        print('Saving images...')
        print(f'  Saving real images to {real_base}')
        print(f'  Saving fake images to {fake_base}')

        long_tail_chars = get_long_tail_chars()
        counter = 0
        ann = defaultdict(dict)
        start_time = time.time()
        for step, data in enumerate(loader):
            style_images = data['simg'].to(self.model.args.device)

            texts = [l.decode('utf-8') for l in data['label']]
            texts = [t.encode('utf-8') for t in texts]
            eval_text_encode, _, _ = self.model.netconverter.encode(texts)
            eval_text_encode = eval_text_encode.to(self.model.args.device).unsqueeze(1)

            fakes = self.model.netG.evaluate(style_images, eval_text_encode)
            fake_images = torch.cat(fakes, 1).detach().cpu().numpy()
            real_images = data['img'].detach().cpu().numpy()
            writer_ids = data['wcl'].int().tolist()

            for fake, real, wid, lb, img_id in zip(fake_images, real_images, writer_ids, data['label'], data['idx']):
                lb = lb.decode()
                ann[f"{wid:03d}"][f'{img_id:05d}'] = lb
                img_id = f'{img_id:05d}.png'

                is_long_tail = any(c in long_tail_chars for c in lb)
                if long_tail_only and not is_long_tail:
                    continue

                # Images are in [-1, 1]; map them to 8-bit before writing.
                fake_img_path = fake_base / f"{wid:03d}" / img_id
                fake_img_path.parent.mkdir(exist_ok=True, parents=True)
                cv2.imwrite(str(fake_img_path), (255 * ((fake.squeeze() + 1) / 2)).astype(np.uint8))

                if not fake_only:
                    real_img_path = real_base / f"{wid:03d}" / img_id
                    real_img_path.parent.mkdir(exist_ok=True, parents=True)
                    cv2.imwrite(str(real_img_path), (255 * ((real.squeeze() + 1) / 2)).astype(np.uint8))

                counter += 1

            if step % 100 == 0:
                eta = (time.time() - start_time) / (step + 1) * (len(loader) - step - 1)
                eta = str(timedelta(seconds=eta))
                print(f'[{(step + 1) / len(loader) * 100:.02f}%][{counter:05d}] ETA {eta}')

        with open(path / 'ann.json', 'w') as f:
            json.dump(ann, f)
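

# A minimal usage sketch: it assumes FakeArgs supplies the default model
# arguments and that a trained checkpoint and a style folder exist at the
# paths below (both are placeholders; adjust them to your setup).
if __name__ == '__main__':
    args = FakeArgs()
    writer = Writer('files/vatr.pth', args)
    writer.set_style_folder('files/style_samples/00')
    # generate() returns one uint8 grayscale image per input text.
    images = writer.generate(['the quick brown fox'], align_words=True)
    cv2.imwrite('generated.png', images[0])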
|