import json import os import random import shutil from collections import defaultdict import time from datetime import timedelta from pathlib import Path import cv2 import numpy as np import torch from data.dataset import FolderDataset from models.model import VATr from util.loading import load_checkpoint, load_generator from util.misc import FakeArgs from util.text import TextGenerator from util.vision import detect_text_bounds def get_long_tail_chars(): with open(f"files/longtail.txt", 'r') as f: chars = [c.rstrip() for c in f] chars.remove('') return chars class Writer: def __init__(self, checkpoint_path, args, only_generator: bool = False): self.model = VATr(args) checkpoint = torch.load(checkpoint_path, map_location=args.device) load_checkpoint(self.model, checkpoint) if not only_generator else load_generator(self.model, checkpoint) self.model.eval() self.style_dataset = None def set_style_folder(self, style_folder, num_examples=15): word_lengths = None if os.path.exists(os.path.join(style_folder, "word_lengths.txt")): word_lengths = {} with open(os.path.join(style_folder, "word_lengths.txt"), 'r') as f: for line in f: word, length = line.rstrip().split(",") word_lengths[word] = int(length) self.style_dataset = FolderDataset(style_folder, num_examples=num_examples, word_lengths=word_lengths) @torch.no_grad() def generate(self, texts, align_words: bool = False, at_once: bool = False): if isinstance(texts, str): texts = [texts] if self.style_dataset is None: raise Exception('Style is not set') fakes = [] for i, text in enumerate(texts, 1): print(f'[{i}/{len(texts)}] Generating for text: {text}') style = self.style_dataset.sample_style() style_images = style['simg'].unsqueeze(0).to(self.model.args.device) fake = self.create_fake_sentence(style_images, text, align_words, at_once) fakes.append(fake) return fakes @torch.no_grad() def create_fake_sentence(self, style_images, text, align_words=False, at_once=False): text = "".join([c for c in text if c in self.model.args.alphabet]) text = text.split() if not at_once else [text] gap = np.ones((32, 16)) text_encode, len_text, encode_pos = self.model.netconverter.encode(text) text_encode = text_encode.to(self.model.args.device).unsqueeze(0) fake = self.model._generate_fakes(style_images, text_encode, len_text) if not at_once: if align_words: fake = self.stitch_words(fake, show_lines=False) else: fake = np.concatenate(sum([[img, gap] for img in fake], []), axis=1)[:, :-16] else: fake = fake[0] fake = (fake * 255).astype(np.uint8) return fake @torch.no_grad() def generate_authors(self, text, dataset, align_words: bool = False, at_once: bool = False): fakes = [] author_ids = [] style = [] for item in dataset: print(f"Generating author {item['wcl']}") style_images = item['simg'].to(self.model.args.device).unsqueeze(0) generated_lines = [self.create_fake_sentence(style_images, line, align_words, at_once) for line in text] fakes.append(generated_lines) author_ids.append(item['author_id']) style.append((((item['simg'].numpy() + 1.0) / 2.0) * 255).astype(np.uint8)) return fakes, author_ids, style @torch.no_grad() def generate_characters(self, dataset, characters: str): """ Generate each of the given characters for each of the authors in the dataset. """ fakes = [] text_encode, len_text, encode_pos = self.model.netconverter.encode([c for c in characters]) text_encode = text_encode.to(self.model.args.device).unsqueeze(0) for item in dataset: print(f"Generating author {item['wcl']}") style_images = item['simg'].to(self.model.args.device).unsqueeze(0) fake = self.model.netG.evaluate(style_images, text_encode) fakes.append(fake) return fakes @torch.no_grad() def generate_batch(self, style_imgs, text): """ Given a batch of style images and text, generate images using the model """ device = self.model.args.device text_encode, _, _ = self.model.netconverter.encode(text) fakes, _ = self.model.netG(style_imgs.to(device), text_encode.to(device)) return fakes @torch.no_grad() def generate_ocr(self, dataset, number: int, output_folder: str = 'saved_images/ocr', interpolate_style: bool = False, text_generator: TextGenerator = None, long_tail: bool = False): def create_and_write(style, text, interpolated=False): nonlocal image_counter, annotations text_encode, len_text, encode_pos = self.model.netconverter.encode([text]) text_encode = text_encode.to(self.model.args.device) fake = self.model.netG.generate(style, text_encode) fake = (fake + 1) / 2 fake = fake.cpu().numpy() fake = np.squeeze((fake * 255).astype(np.uint8)) image_filename = f"{image_counter}.png" if not interpolated else f"{image_counter}_i.png" cv2.imwrite(os.path.join(output_folder, "generated", image_filename), fake) annotations.append((image_filename, text)) image_counter += 1 image_counter = 0 annotations = [] previous_style = None long_tail_chars = get_long_tail_chars() os.mkdir(os.path.join(output_folder, "generated")) if text_generator is None: os.mkdir(os.path.join(output_folder, "reference")) while image_counter < number: author_index = random.randint(0, len(dataset) - 1) item = dataset[author_index] style_images = item['simg'].to(self.model.args.device).unsqueeze(0) style = self.model.netG.compute_style(style_images) if interpolate_style and previous_style is not None: factor = float(np.clip(random.gauss(0.5, 0.15), 0.0, 1.0)) intermediate_style = torch.lerp(previous_style, style, factor) text = text_generator.generate() create_and_write(intermediate_style, text, interpolated=True) if text_generator is not None: text = text_generator.generate() else: text = str(item['label'].decode()) if long_tail and not any(c in long_tail_chars for c in text): continue fake = (item['img'] + 1) / 2 fake = fake.cpu().numpy() fake = np.squeeze((fake * 255).astype(np.uint8)) image_filename = f"{image_counter}.png" cv2.imwrite(os.path.join(output_folder, "reference", image_filename), fake) create_and_write(style, text) previous_style = style if text_generator is None: with open(os.path.join(output_folder, "reference", "labels.csv"), 'w') as fr: fr.write(f"filename,words\n") for annotation in annotations: fr.write(f"{annotation[0]},{annotation[1]}\n") with open(os.path.join(output_folder, "generated", "labels.csv"), 'w') as fg: fg.write(f"filename,words\n") for annotation in annotations: fg.write(f"{annotation[0]},{annotation[1]}\n") @staticmethod def stitch_words(words: list, show_lines: bool = False, scale_words: bool = False): gap_width = 16 bottom_lines = [] top_lines = [] for i in range(len(words)): b, t = detect_text_bounds(words[i]) bottom_lines.append(b) top_lines.append(t) if show_lines: words[i] = cv2.line(words[i], (0, b), (words[i].shape[1], b), (0, 0, 1.0)) words[i] = cv2.line(words[i], (0, t), (words[i].shape[1], t), (1.0, 0, 0)) bottom_lines = np.array(bottom_lines, dtype=float) if scale_words: top_lines = np.array(top_lines, dtype=float) gaps = bottom_lines - top_lines target_gap = np.mean(gaps) scales = target_gap / gaps bottom_lines *= scales top_lines *= scales words = [cv2.resize(word, None, fx=scale, fy=scale) for word, scale in zip(words, scales)] highest = np.max(bottom_lines) offsets = highest - bottom_lines height = np.max(offsets + [word.shape[0] for word in words]) result = np.ones((int(height), gap_width * len(words) + sum([w.shape[1] for w in words]))) x_pos = 0 for bottom_line, word in zip(bottom_lines, words): offset = int(highest - bottom_line) result[offset:offset + word.shape[0], x_pos:x_pos+word.shape[1]] = word x_pos += word.shape[1] + gap_width return result @torch.no_grad() def generate_fid(self, path, loader, model_tag, split='train', fake_only=False, long_tail_only=False): if not isinstance(path, Path): path = Path(path) path.mkdir(exist_ok=True, parents=True) appendix = f"{split}" if not long_tail_only else f"{split}_lt" real_base = path / f'real_{appendix}' fake_base = path / model_tag / f'fake_{appendix}' if real_base.exists() and not fake_only: shutil.rmtree(real_base) if fake_base.exists(): shutil.rmtree(fake_base) real_base.mkdir(exist_ok=True) fake_base.mkdir(exist_ok=True, parents=True) print('Saving images...') print(' Saving images on {}'.format(str(real_base))) print(' Saving images on {}'.format(str(fake_base))) long_tail_chars = get_long_tail_chars() counter = 0 ann = defaultdict(lambda: {}) start_time = time.time() for step, data in enumerate(loader): style_images = data['simg'].to(self.model.args.device) texts = [l.decode('utf-8') for l in data['label']] texts = [t.encode('utf-8') for t in texts] eval_text_encode, eval_len_text, _ = self.model.netconverter.encode(texts) eval_text_encode = eval_text_encode.to(self.model.args.device).unsqueeze(1) vis_style = np.vstack(style_images[0].detach().cpu().numpy()) vis_style = ((vis_style + 1) / 2) * 255 fakes = self.model.netG.evaluate(style_images, eval_text_encode) fake_images = torch.cat(fakes, 1).detach().cpu().numpy() real_images = data['img'].detach().cpu().numpy() writer_ids = data['wcl'].int().tolist() for i, (fake, real, wid, lb, img_id) in enumerate(zip(fake_images, real_images, writer_ids, data['label'], data['idx'])): lb = lb.decode() ann[f"{wid:03d}"][f'{img_id:05d}'] = lb img_id = f'{img_id:05d}.png' is_long_tail = any(c in long_tail_chars for c in lb) if long_tail_only and not is_long_tail: continue fake_img_path = fake_base / f"{wid:03d}" / img_id fake_img_path.parent.mkdir(exist_ok=True, parents=True) cv2.imwrite(str(fake_img_path), 255 * ((fake.squeeze() + 1) / 2)) if not fake_only: real_img_path = real_base / f"{wid:03d}" / img_id real_img_path.parent.mkdir(exist_ok=True, parents=True) cv2.imwrite(str(real_img_path), 255 * ((real.squeeze() + 1) / 2)) counter += 1 eta = (time.time() - start_time) / (step + 1) * (len(loader) - step - 1) eta = str(timedelta(seconds=eta)) if step % 100 == 0: print(f'[{(step + 1) / len(loader) * 100:.02f}%][{counter:05d}] ETA {eta}') with open(path / 'ann.json', 'w') as f: json.dump(ann, f)