"""Hugging Face `transformers` wrapper for the VATr++ handwritten text generation model."""

import os

import cv2
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from torchvision.transforms.functional import to_pil_image
from transformers import AutoConfig, AutoModel, PreTrainedModel

from .configuration_vatrpp import VATrPPConfig
from .data.dataset import FolderDataset
from .models.model import VATr
from .models.util.vision import detect_text_bounds


def get_long_tail_chars():
    # Load the list of long-tail (rare) characters; the path is resolved
    # relative to the current working directory.
    with open("files/longtail.txt", 'r') as f:
        chars = [c.rstrip() for c in f]
    chars.remove('')
    return chars


class VATrPP(PreTrainedModel):
    config_class = VATrPPConfig

    def __init__(self, config: VATrPPConfig) -> None:
        super().__init__(config)
        # Resolve the auxiliary text resources to local paths by downloading
        # them from the model repository on the Hub.
        config.english_words_path = hf_hub_download(
            repo_id="blowing-up-groundhogs/vatrpp", filename=config.english_words_path
        )
        config.mytext_path = hf_hub_download(
            repo_id="blowing-up-groundhogs/vatrpp", filename="mytext.txt"
        )
        self.model = VATr(config)
        self.model.eval()

    def set_style_folder(self, style_folder, num_examples=15):
        # An optional "word_lengths.txt" file ("word,length" per line) maps
        # each style word to its character count.
        word_lengths = None
        if os.path.exists(os.path.join(style_folder, "word_lengths.txt")):
            word_lengths = {}
            with open(os.path.join(style_folder, "word_lengths.txt"), 'r') as f:
                for line in f:
                    word, length = line.rstrip().split(",")
                    word_lengths[word] = int(length)
        self.style_dataset = FolderDataset(style_folder, num_examples=num_examples, word_lengths=word_lengths)

    @torch.no_grad()
    def generate(self, gen_text, style_imgs, align_words: bool = False, at_once: bool = False):
        # Add a batch dimension to the style images, render the text, and
        # return the result as a PIL image.
        style_images = style_imgs.unsqueeze(0).to(self.model.args.device)
        fake = self.create_fake_sentence(style_images, gen_text, align_words, at_once)
        return to_pil_image(fake)

    @torch.no_grad()
    def create_fake_sentence(self, style_images, text, align_words=False, at_once=False):
        # Drop characters the model cannot render, then generate either one
        # image per word or the whole sentence at once.
        text = "".join([c for c in text if c in self.model.args.alphabet])
        text = text.split() if not at_once else [text]
        gap = np.ones((32, 16))  # white spacer placed between word images

        text_encode, len_text, encode_pos = self.model.netconverter.encode(text)
        text_encode = text_encode.to(self.model.args.device).unsqueeze(0)

        fake = self.model._generate_fakes(style_images, text_encode, len_text)

        if not at_once:
            if align_words:
                fake = self.stitch_words(fake, show_lines=False)
            else:
                # Interleave words with spacers and trim the trailing gap.
                fake = np.concatenate(sum([[img, gap] for img in fake], []), axis=1)[:, :-16]
        else:
            fake = fake[0]

        fake = (fake * 255).astype(np.uint8)
        return fake

    @torch.no_grad()
    def generate_batch(self, style_imgs, text):
        """Given a batch of style images and text, generate images using the model."""
        device = self.model.args.device
        text_encode, _, _ = self.model.netconverter.encode(text)
        fakes, _ = self.model.netG(style_imgs.to(device), text_encode.to(device))
        return fakes

    @staticmethod
    def stitch_words(words: list, show_lines: bool = False, scale_words: bool = False):
        gap_width = 16

        # Detect the bottom (baseline) and top text lines of every word image.
        bottom_lines = []
        top_lines = []
        for i in range(len(words)):
            b, t = detect_text_bounds(words[i])
            bottom_lines.append(b)
            top_lines.append(t)
            if show_lines:
                # Draw the detected lines onto the word for visual debugging.
                words[i] = cv2.line(words[i], (0, b), (words[i].shape[1], b), (0, 0, 1.0))
                words[i] = cv2.line(words[i], (0, t), (words[i].shape[1], t), (1.0, 0, 0))

        bottom_lines = np.array(bottom_lines, dtype=float)
        if scale_words:
            # Rescale each word so the distance between its top and bottom
            # lines matches the mean over all words.
            top_lines = np.array(top_lines, dtype=float)
            gaps = bottom_lines - top_lines
            target_gap = np.mean(gaps)
            scales = target_gap / gaps
            bottom_lines *= scales
            top_lines *= scales
            words = [cv2.resize(word, None, fx=scale, fy=scale) for word, scale in zip(words, scales)]

        # Shift every word vertically so that all baselines coincide.
        highest = np.max(bottom_lines)
        offsets = highest - bottom_lines
        height = np.max(offsets + [word.shape[0] for word in words])

        result = np.ones((int(height), gap_width * len(words) + sum([w.shape[1] for w in words])))
        x_pos = 0
        for bottom_line, word in zip(bottom_lines, words):
            offset = int(highest - bottom_line)
            result[offset:offset + word.shape[0], x_pos:x_pos + word.shape[1]] = word
            x_pos += word.shape[1] + gap_width

        return result


AutoConfig.register("vatrpp", VATrPPConfig)
AutoModel.register(VATrPPConfig, VATrPP)
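
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not executed on import). It assumes the
# checkpoint is published on the Hub with this module as remote code, that
# "path/to/style_folder" contains style word images readable by FolderDataset,
# and that indexing the dataset yields the tensor of style images expected by
# `generate`; the exact item format returned by FolderDataset is an assumption.
#
#   from transformers import AutoModel
#
#   model = AutoModel.from_pretrained("blowing-up-groundhogs/vatrpp",
#                                     trust_remote_code=True)
#   model.set_style_folder("path/to/style_folder")  # hypothetical folder
#   style_imgs = model.style_dataset[0]             # item format assumed
#   image = model.generate("the quick brown fox", style_imgs, align_words=True)
#   image.save("generated.png")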