# vatrpp/modeling_vatrpp.py
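"""VATr++ handwriting-generation model wrapped for the `transformers` Auto* API.

Defines `VATrPP`, a `PreTrainedModel` around the VATr generator, and registers
it with `AutoConfig`/`AutoModel` at the bottom of the file.
"""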
import os

import cv2
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from torchvision.transforms.functional import to_pil_image
from transformers import AutoConfig, AutoModel, PreTrainedModel

from .configuration_vatrpp import VATrPPConfig
from .data.dataset import FolderDataset
from .models.model import VATr
from .models.util.vision import detect_text_bounds

def get_long_tail_chars():
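    """Return the list of long-tail (rare) characters read from files/longtail.txt."""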
    with open("files/longtail.txt", 'r') as f:
        chars = [c.rstrip() for c in f]
    if '' in chars:  # remove('') raises ValueError when the file has no blank line
        chars.remove('')
    return chars

class VATrPP(PreTrainedModel):
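    """`PreTrainedModel` wrapper around the VATr handwriting generator.

    Style examples are passed directly to `generate` or loaded from disk via
    `set_style_folder`.
    """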
config_class = VATrPPConfig
def __init__(self, config: VATrPPConfig) -> None:
super().__init__(config)
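        # Download the auxiliary word/text files from the Hub repo so the
        # relative paths stored in the config resolve to local files.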
config.english_words_path = hf_hub_download(repo_id="blowing-up-groundhogs/vatrpp", filename=config.english_words_path)
config.mytext_path = hf_hub_download(repo_id="blowing-up-groundhogs/vatrpp", filename='mytext.txt')
self.model = VATr(config)
self.model.eval()
def set_style_folder(self, style_folder, num_examples=15):
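        """Build a `FolderDataset` of style examples from `style_folder`.

        An optional `word_lengths.txt` in the folder (one `word,length` pair
        per line) is parsed and forwarded to the dataset.
        """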
word_lengths = None
if os.path.exists(os.path.join(style_folder, "word_lengths.txt")):
word_lengths = {}
with open(os.path.join(style_folder, "word_lengths.txt"), 'r') as f:
for line in f:
word, length = line.rstrip().split(",")
word_lengths[word] = int(length)
self.style_dataset = FolderDataset(style_folder, num_examples=num_examples, word_lengths=word_lengths)
@torch.no_grad()
def generate(self, gen_text, style_imgs, align_words: bool = False, at_once: bool = False):
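        """Render `gen_text` in the style of `style_imgs` and return it as a PIL image."""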
style_images = style_imgs.unsqueeze(0).to(self.model.args.device)
fake = self.create_fake_sentence(style_images, gen_text, align_words, at_once)
return to_pil_image(fake)
@torch.no_grad()
def create_fake_sentence(self, style_images, text, align_words=False, at_once=False):
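        """Render `text` as a grayscale uint8 numpy image in the given style.

        Characters outside the model alphabet are dropped; unless `at_once` is
        set, the text is generated word by word and then stitched together.
        """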
text = "".join([c for c in text if c in self.model.args.alphabet])
text = text.split() if not at_once else [text]
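        # White 32x16 spacer inserted between word images (images are in [0, 1]).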
gap = np.ones((32, 16))
text_encode, len_text, encode_pos = self.model.netconverter.encode(text)
text_encode = text_encode.to(self.model.args.device).unsqueeze(0)
fake = self.model._generate_fakes(style_images, text_encode, len_text)
if not at_once:
if align_words:
fake = self.stitch_words(fake, show_lines=False)
else:
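                # Interleave each word with a gap, then trim the trailing gap.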
fake = np.concatenate(sum([[img, gap] for img in fake], []), axis=1)[:, :-16]
else:
fake = fake[0]
fake = (fake * 255).astype(np.uint8)
return fake
@torch.no_grad()
def generate_batch(self, style_imgs, text):
"""
Given a batch of style images and text, generate images using the model
"""
device = self.model.args.device
text_encode, _, _ = self.model.netconverter.encode(text)
fakes, _ = self.model.netG(style_imgs.to(device), text_encode.to(device))
return fakes
@staticmethod
def stitch_words(words: list, show_lines: bool = False, scale_words: bool = False):
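        """Stitch word images into a single line, aligned on their detected baselines.

        With `show_lines`, the detected top/bottom bounds are drawn for debugging;
        with `scale_words`, each word is rescaled so text bodies share one height.
        """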
gap_width = 16
bottom_lines = []
top_lines = []
for i in range(len(words)):
b, t = detect_text_bounds(words[i])
bottom_lines.append(b)
top_lines.append(t)
if show_lines:
words[i] = cv2.line(words[i], (0, b), (words[i].shape[1], b), (0, 0, 1.0))
words[i] = cv2.line(words[i], (0, t), (words[i].shape[1], t), (1.0, 0, 0))
bottom_lines = np.array(bottom_lines, dtype=float)
if scale_words:
top_lines = np.array(top_lines, dtype=float)
gaps = bottom_lines - top_lines
target_gap = np.mean(gaps)
scales = target_gap / gaps
bottom_lines *= scales
top_lines *= scales
words = [cv2.resize(word, None, fx=scale, fy=scale) for word, scale in zip(words, scales)]
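        # Vertically offset every word so all baselines line up with the lowest one.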
highest = np.max(bottom_lines)
offsets = highest - bottom_lines
height = np.max(offsets + [word.shape[0] for word in words])
result = np.ones((int(height), gap_width * len(words) + sum([w.shape[1] for w in words])))
x_pos = 0
for bottom_line, word in zip(bottom_lines, words):
offset = int(highest - bottom_line)
result[offset:offset + word.shape[0], x_pos:x_pos+word.shape[1]] = word
x_pos += word.shape[1] + gap_width
return result

# Register with the Auto* factories so `AutoModel.from_pretrained` can build the model.
AutoConfig.register("vatrpp", VATrPPConfig)
AutoModel.register(VATrPPConfig, VATrPP)
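

# Minimal usage sketch (illustrative, not part of the original file). It assumes
# the Hub checkpoint "blowing-up-groundhogs/vatrpp" referenced above, and that
# `style_imgs` has the (num_examples, 1, H, W) shape produced by FolderDataset;
# the random tensor below is a stand-in for real preprocessed style crops.
if __name__ == "__main__":
    model = AutoModel.from_pretrained("blowing-up-groundhogs/vatrpp", trust_remote_code=True)
    style_imgs = torch.rand(15, 1, 32, 192)  # placeholder: replace with real style images
    pil_image = model.generate("hello world", style_imgs)
    pil_image.save("generated.png")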