|
from transformers import PreTrainedModel, AutoModel |
|
from .configuration_vatrpp import VATrPPConfig |
|
import json |
|
import os |
|
import random |
|
import shutil |
|
from collections import defaultdict |
|
import time |
|
from datetime import timedelta |
|
from pathlib import Path |
|
|
|
import cv2 |
|
import numpy as np |
|
import torch |
|
|
|
from data.dataset import FolderDataset |
|
from models.model import VATr |
|
from util.loading import load_checkpoint, load_generator |
|
from util.misc import FakeArgs |
|
from util.text import TextGenerator |
|
from util.vision import detect_text_bounds |
|
from torchvision.transforms.functional import to_pil_image |
|
|
|
|
|
def get_long_tail_chars(path="files/longtail.txt"):
    """Return the "long-tail" characters listed one per line in *path*.

    Each line has its trailing whitespace (including the newline) stripped, and
    lines that end up empty are skipped.

    Args:
        path: Text file to read. The default ``files/longtail.txt`` is resolved
            relative to the current working directory.

    Returns:
        list[str]: The non-empty stripped lines, in file order.
    """
    # The file presumably contains non-ASCII characters, so decode explicitly
    # instead of relying on the platform's default encoding.
    with open(path, "r", encoding="utf-8") as f:
        # Drop every blank line. The previous ``chars.remove('')`` raised
        # ValueError when there was no blank line and left extra '' entries
        # when there was more than one.
        return [c for c in (line.rstrip() for line in f) if c]
|
|
|
class VATrPP(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the VATr generator.

    The wrapped :class:`VATr` instance lives in ``self.model``. The helper
    methods turn text plus style images into a rendered image.
    """

    config_class = VATrPPConfig

    # Width, in pixels, of the white gap inserted between consecutive words.
    _WORD_GAP = 16

    def __init__(self, config: VATrPPConfig) -> None:
        """Build the wrapped VATr model from *config* and switch it to eval mode."""
        super().__init__(config)
        self.model = VATr(config)
        self.model.eval()

    def set_style_folder(self, style_folder, num_examples=15):
        """Load style examples from *style_folder* into ``self.style_dataset``.

        If the folder contains ``word_lengths.txt``, each non-blank line is
        parsed as ``<word>,<length>``. The resulting ``{word: int(length)}``
        mapping is passed to :class:`FolderDataset`; otherwise
        ``word_lengths=None`` is passed.

        Args:
            style_folder: Directory holding the style examples.
            num_examples: Forwarded to ``FolderDataset`` as ``num_examples``.
        """
        lengths_path = os.path.join(style_folder, "word_lengths.txt")
        word_lengths = None
        if os.path.exists(lengths_path):
            word_lengths = {}
            with open(lengths_path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.rstrip()
                    if not line:
                        continue  # tolerate blank / trailing lines
                    # Split on the LAST comma so that words which themselves
                    # contain a comma (e.g. "yes,") are parsed correctly.
                    word, length = line.rsplit(",", 1)
                    word_lengths[word] = int(length)

        self.style_dataset = FolderDataset(style_folder, num_examples=num_examples, word_lengths=word_lengths)

    @torch.no_grad()
    def generate(self, gen_text, style_imgs, align_words: bool = False, at_once: bool = False):
        """Render *gen_text* in the style of *style_imgs* and return a PIL image.

        Args:
            gen_text: Text to render. Characters not in the model alphabet are dropped.
            style_imgs: Style image tensor. A leading batch dimension is added
                with ``unsqueeze(0)`` before moving it to the model device
                (presumably passed without a batch dim — TODO confirm).
            align_words: Forwarded to :meth:`create_fake_sentence`.
            at_once: Forwarded to :meth:`create_fake_sentence`.

        Returns:
            PIL.Image.Image: The rendered image.
        """
        style_images = style_imgs.unsqueeze(0).to(self.model.args.device)
        fake = self.create_fake_sentence(style_images, gen_text, align_words, at_once)
        return to_pil_image(fake)

    @torch.no_grad()
    def create_fake_sentence(self, style_images, text, align_words=False, at_once=False):
        """Generate an image of *text* conditioned on *style_images*.

        Characters not in ``self.model.args.alphabet`` are removed first. What
        happens next depends on the flags:

        * *at_once*: the whole string is generated as a single item.
        * *align_words*: the text is split on whitespace, each word is generated
          separately, and the words are joined with :meth:`stitch_words`
          (baseline alignment).
        * neither flag: the words are concatenated left-to-right with a white
          gap of ``_WORD_GAP`` columns between neighbours.

        Returns:
            numpy.ndarray: ``uint8`` image, the generator output scaled by 255.

        Raises:
            ValueError: If, when splitting into words, nothing is left to
                generate after filtering.
        """
        alphabet = self.model.args.alphabet
        text = "".join(c for c in text if c in alphabet)
        words = [text] if at_once else text.split()
        if not words:
            # Previously this surfaced later as an opaque numpy ValueError.
            raise ValueError("nothing to generate: text is empty after removing characters outside the model alphabet")

        text_encode, len_text, _ = self.model.netconverter.encode(words)
        text_encode = text_encode.to(self.model.args.device).unsqueeze(0)

        fakes = self.model._generate_fakes(style_images, text_encode, len_text)
        if at_once:
            fake = fakes[0]
        elif align_words:
            fake = self.stitch_words(fakes, show_lines=False)
        else:
            # Insert a white gap only *between* words. The gap takes the
            # word's own height instead of a hard-coded 32 rows.
            pieces = []
            for i, img in enumerate(fakes):
                if i:
                    pieces.append(np.ones((img.shape[0], self._WORD_GAP)))
                pieces.append(img)
            fake = np.concatenate(pieces, axis=1)

        # Clip before casting: floats outside [0, 255] would otherwise wrap
        # around (or be undefined) when converted to uint8.
        return np.clip(fake * 255, 0, 255).astype(np.uint8)

    @torch.no_grad()
    def generate_batch(self, style_imgs, text):
        """Generate images for a batch of style images and text.

        *text* is encoded with ``self.model.netconverter``. Both inputs are then
        moved to ``self.model.args.device`` and passed to ``self.model.netG``.

        Returns:
            The first element of the two-item ``netG`` output.
        """
        device = self.model.args.device
        text_encode, _, _ = self.model.netconverter.encode(text)
        fakes, _ = self.model.netG(style_imgs.to(device), text_encode.to(device))
        return fakes

    @staticmethod
    def stitch_words(words: list, show_lines: bool = False, scale_words: bool = False,
                     gap_width: int = _WORD_GAP):
        """Join 2-D word images left-to-right so their detected bottom lines coincide.

        ``detect_text_bounds`` yields a ``(bottom, top)`` row pair per word.
        Each word is shifted down so that all bottom rows line up, and
        neighbouring words are separated by *gap_width* columns of ones. No
        gap follows the last word.

        Args:
            words: Sequence of 2-D word images (presumably in [0, 1] with
                white = 1 — TODO confirm). Neither the sequence nor the arrays
                are modified.
            show_lines: If true, draw the detected bottom and top rows onto
                copies of the words (debugging aid).
            scale_words: If true, resize each word so that its bottom-top
                distance matches the mean distance over all words.
            gap_width: Columns of background inserted between words.

        Returns:
            numpy.ndarray: 2-D float64 image.

        Raises:
            ValueError: If *words* is empty.
        """
        if len(words) == 0:
            raise ValueError("stitch_words() needs at least one word image")

        words = list(words)  # never mutate the caller's sequence
        bottom_lines, top_lines = [], []
        for i, word in enumerate(words):
            b, t = detect_text_bounds(word)
            bottom_lines.append(b)
            top_lines.append(t)
            if show_lines:
                # cv2.line draws in place, so draw on a copy.
                word = cv2.line(word.copy(), (0, b), (word.shape[1], b), (0, 0, 1.0))
                word = cv2.line(word, (0, t), (word.shape[1], t), (1.0, 0, 0))
                words[i] = word

        bottom_lines = np.array(bottom_lines, dtype=float)

        if scale_words:
            top_lines = np.array(top_lines, dtype=float)
            distances = bottom_lines - top_lines
            target = np.mean(distances)
            # Words with a non-positive bound distance keep scale 1 instead of
            # getting an infinite or negative scale factor.
            scales = np.ones_like(distances)
            np.divide(target, distances, out=scales, where=distances > 0)
            bottom_lines *= scales
            words = [cv2.resize(w, None, fx=s, fy=s) for w, s in zip(words, scales)]

        highest = np.max(bottom_lines)
        offsets = highest - bottom_lines
        height = np.max(offsets + [w.shape[0] for w in words])
        width = sum(w.shape[1] for w in words) + gap_width * (len(words) - 1)

        result = np.ones((int(height), width))
        x_pos = 0
        for offset, word in zip(offsets, words):
            y = int(offset)
            result[y:y + word.shape[0], x_pos:x_pos + word.shape[1]] = word
            x_pos += word.shape[1] + gap_width

        return result


# ``AutoModel.register`` is a plain classmethod taking ``(config_class, model_class)``.
# It cannot act as a one-argument decorator (that raised TypeError at import time),
# so register the class explicitly once it exists.
AutoModel.register(VATrPPConfig, VATrPP)
|
|