from transformers import PreTrainedModel, AutoModel, AutoConfig

from .configuration_vatrpp import VATrPPConfig

import os

import cv2
import numpy as np
import torch
from torchvision.transforms.functional import to_pil_image

from data.dataset import FolderDataset
from models.model import VATr
from util.vision import detect_text_bounds


def get_long_tail_chars():
    """Read the list of long-tail (rare) characters from files/longtail.txt."""
    with open("files/longtail.txt", "r") as f:
        chars = [c.rstrip() for c in f]
    chars.remove('')

    return chars


class VATrPP(PreTrainedModel):
    """Hugging Face wrapper around the VATr++ handwritten-text generator."""

    config_class = VATrPPConfig

    def __init__(self, config: VATrPPConfig) -> None:
        super().__init__(config)
        self.model = VATr(config)
        self.model.eval()

    def set_style_folder(self, style_folder, num_examples=15):
        """Load style example images (and optional per-word lengths) from a folder."""
        word_lengths = None
        if os.path.exists(os.path.join(style_folder, "word_lengths.txt")):
            word_lengths = {}
            with open(os.path.join(style_folder, "word_lengths.txt"), "r") as f:
                for line in f:
                    word, length = line.rstrip().split(",")
                    word_lengths[word] = int(length)

        self.style_dataset = FolderDataset(style_folder, num_examples=num_examples, word_lengths=word_lengths)

    @torch.no_grad()
    def generate(self, gen_text, style_imgs, align_words: bool = False, at_once: bool = False):
        """Render `gen_text` in the handwriting style of `style_imgs` and return a PIL image."""
        style_images = style_imgs.unsqueeze(0).to(self.model.args.device)
        fake = self.create_fake_sentence(style_images, gen_text, align_words, at_once)

        return to_pil_image(fake)

    @torch.no_grad()
    def create_fake_sentence(self, style_images, text, align_words=False, at_once=False):
        """Generate a sentence image as a uint8 array, word by word unless `at_once` is set."""
        # Drop characters the model cannot render, then split into words
        # (or keep the whole string as one unit when generating at once).
        text = "".join([c for c in text if c in self.model.args.alphabet])
        text = text.split() if not at_once else [text]

        gap = np.ones((32, 16))  # blank 16-pixel-wide spacer between words

        text_encode, len_text, encode_pos = self.model.netconverter.encode(text)
        text_encode = text_encode.to(self.model.args.device).unsqueeze(0)

        fake = self.model._generate_fakes(style_images, text_encode, len_text)

        if not at_once:
            if align_words:
                fake = self.stitch_words(fake, show_lines=False)
            else:
                # Interleave word images with spacers and drop the trailing gap.
                fake = np.concatenate(sum([[img, gap] for img in fake], []), axis=1)[:, :-16]
        else:
            fake = fake[0]

        fake = (fake * 255).astype(np.uint8)

        return fake

    @torch.no_grad()
    def generate_batch(self, style_imgs, text):
        """Given a batch of style images and text, generate images using the model."""
        device = self.model.args.device

        text_encode, _, _ = self.model.netconverter.encode(text)
        fakes, _ = self.model.netG(style_imgs.to(device), text_encode.to(device))

        return fakes

    @staticmethod
    def stitch_words(words: list, show_lines: bool = False, scale_words: bool = False):
        """Align generated word images on a common baseline and concatenate them with fixed gaps."""
        gap_width = 16

        # Detect the bottom (baseline) and top line of each word image.
        bottom_lines = []
        top_lines = []
        for i in range(len(words)):
            b, t = detect_text_bounds(words[i])
            bottom_lines.append(b)
            top_lines.append(t)
            if show_lines:
                words[i] = cv2.line(words[i], (0, b), (words[i].shape[1], b), (0, 0, 1.0))
                words[i] = cv2.line(words[i], (0, t), (words[i].shape[1], t), (1.0, 0, 0))

        bottom_lines = np.array(bottom_lines, dtype=float)

        if scale_words:
            # Rescale every word so its line height matches the average.
            top_lines = np.array(top_lines, dtype=float)
            gaps = bottom_lines - top_lines
            target_gap = np.mean(gaps)
            scales = target_gap / gaps
            bottom_lines *= scales
            top_lines *= scales
            words = [cv2.resize(word, None, fx=scale, fy=scale) for word, scale in zip(words, scales)]

        # Shift each word vertically so that all baselines coincide.
        highest = np.max(bottom_lines)
        offsets = highest - bottom_lines
        height = np.max(offsets + [word.shape[0] for word in words])

        result = np.ones((int(height), gap_width * len(words) + sum([w.shape[1] for w in words])))
        # Paste each word onto the blank canvas at its baseline offset, leaving a
        # fixed gap between consecutive words.
        x_pos = 0
        for bottom_line, word in zip(bottom_lines, words):
            offset = int(highest - bottom_line)
            result[offset:offset + word.shape[0], x_pos:x_pos + word.shape[1]] = word
            x_pos += word.shape[1] + gap_width

        return result


AutoConfig.register("vatrpp", VATrPPConfig)
AutoModel.register(VATrPPConfig, VATrPP)
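
# Minimal usage sketch (not part of the module API): since the classes are
# registered with AutoConfig/AutoModel above, a checkpoint shipping this file
# can be loaded through the transformers Auto classes with remote code enabled.
# The checkpoint identifier and style folder below are placeholders, and indexing
# `style_dataset` to obtain a stacked style-image tensor is an assumption about
# FolderDataset's interface, which is defined outside this file.
#
#     model = AutoModel.from_pretrained("path/to/vatrpp-checkpoint", trust_remote_code=True)
#     model.set_style_folder("path/to/style_word_images", num_examples=15)
#     style_imgs = model.style_dataset[0]  # assumed: one writer's style images as a single tensor
#     image = model.generate("the quick brown fox", style_imgs)
#     image.save("generated.png")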