import os

import cv2
import numpy as np
import torch
from torchvision.transforms.functional import to_pil_image
from transformers import PreTrainedModel, AutoModel, AutoConfig

from .configuration_vatrpp import VATrPPConfig
from .data.dataset import FolderDataset
from .models.model import VATr
from .util.vision import detect_text_bounds


def get_long_tail_chars():
    """Load the list of long-tail (rare) characters from files/longtail.txt."""
    with open("files/longtail.txt", 'r') as f:
        chars = [c.rstrip() for c in f]

    # Drop empty entries (e.g. from a trailing newline); list.remove('')
    # would raise ValueError when no empty string is present.
    return [c for c in chars if c]


class VATrPP(PreTrainedModel):
    config_class = VATrPPConfig

    def __init__(self, config: VATrPPConfig) -> None:
        super().__init__(config)
        self.model = VATr(config)
        # Inference-only wrapper: keep the generator in eval mode.
        self.model.eval()
        self.style_dataset = None

    def set_style_folder(self, style_folder, num_examples=15):
        """Load style example images from ``style_folder``.

        If a ``word_lengths.txt`` file is present in the folder, it is parsed
        as comma-separated ``word,length`` lines and passed to the dataset.
        """
        word_lengths = None
        lengths_file = os.path.join(style_folder, "word_lengths.txt")
        if os.path.exists(lengths_file):
            word_lengths = {}
            with open(lengths_file, 'r') as f:
                for line in f:
                    line = line.rstrip()
                    if not line:
                        continue
                    word, length = line.split(",")
                    word_lengths[word] = int(length)

        self.style_dataset = FolderDataset(style_folder, num_examples=num_examples, word_lengths=word_lengths)
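
    # Expected ``word_lengths.txt`` format, as parsed above: one ``word,length``
    # pair per line, where ``length`` is an integer. The concrete entries below
    # are illustrative only:
    #
    #     word1,12
    #     word2,7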

    @torch.no_grad()
    def generate(self, gen_text, style_imgs, align_words: bool = False, at_once: bool = False):
        """Render ``gen_text`` as handwriting in the style of ``style_imgs``.

        Returns the generated image as a ``PIL.Image``.
        """
        style_images = style_imgs.unsqueeze(0).to(self.model.args.device)

        fake = self.create_fake_sentence(style_images, gen_text, align_words, at_once)
        return to_pil_image(fake)

    @torch.no_grad()
    def create_fake_sentence(self, style_images, text, align_words=False, at_once=False):
        # Drop characters the model's alphabet cannot represent.
        text = "".join([c for c in text if c in self.model.args.alphabet])

        # Generate word by word unless the whole sentence is rendered at once.
        text = [text] if at_once else text.split()
        gap = np.ones((32, 16))  # white 16px spacer between words

        text_encode, len_text, _ = self.model.netconverter.encode(text)
        text_encode = text_encode.to(self.model.args.device).unsqueeze(0)

        fake = self.model._generate_fakes(style_images, text_encode, len_text)
        if not at_once:
            if align_words:
                fake = self.stitch_words(fake, show_lines=False)
            else:
                # Interleave a gap after every word, then trim the trailing one.
                fake = np.concatenate(sum([[img, gap] for img in fake], []), axis=1)[:, :-16]
        else:
            fake = fake[0]
        fake = (fake * 255).astype(np.uint8)

        return fake

    @torch.no_grad()
    def generate_batch(self, style_imgs, text):
        """
        Given a batch of style images and text, generate images using the model
        """
        device = self.model.args.device
        text_encode, _, _ = self.model.netconverter.encode(text)
        fakes, _ = self.model.netG(style_imgs.to(device), text_encode.to(device))
        return fakes

    @staticmethod
    def stitch_words(words: list, show_lines: bool = False, scale_words: bool = False):
        """Stitch word images into a single line, aligned on their baselines.

        The bottom and top text lines of each word are detected, and every word
        is shifted vertically so the baselines coincide. With ``scale_words``,
        words are additionally rescaled so their text heights match the mean
        height. With ``show_lines``, the detected lines are drawn for debugging.
        """
        gap_width = 16

        bottom_lines = []
        top_lines = []
        for i in range(len(words)):
            b, t = detect_text_bounds(words[i])
            bottom_lines.append(b)
            top_lines.append(t)
            if show_lines:
                words[i] = cv2.line(words[i], (0, b), (words[i].shape[1], b), (0, 0, 1.0))
                words[i] = cv2.line(words[i], (0, t), (words[i].shape[1], t), (1.0, 0, 0))

        bottom_lines = np.array(bottom_lines, dtype=float)

        if scale_words:
            top_lines = np.array(top_lines, dtype=float)
            gaps = bottom_lines - top_lines
            target_gap = np.mean(gaps)
            scales = target_gap / gaps

            # The detected line positions move with the resize, so scale them too.
            bottom_lines *= scales
            top_lines *= scales
            words = [cv2.resize(word, None, fx=scale, fy=scale) for word, scale in zip(words, scales)]

        # Shift every word down so all baselines sit on the lowest one.
        highest = np.max(bottom_lines)
        offsets = highest - bottom_lines
        height = np.max(offsets + np.array([word.shape[0] for word in words]))

        # White canvas wide enough for all words plus the gaps between them
        # (no trailing gap, matching the unaligned concatenation path).
        result = np.ones((int(height), gap_width * (len(words) - 1) + sum(w.shape[1] for w in words)))

        x_pos = 0
        for offset, word in zip(offsets, words):
            offset = int(offset)
            result[offset:offset + word.shape[0], x_pos:x_pos + word.shape[1]] = word

            x_pos += word.shape[1] + gap_width

        return result


AutoConfig.register("vatrpp", VATrPPConfig)
AutoModel.register(VATrPPConfig, VATrPP)
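
# Example usage (a minimal sketch; the checkpoint path, style folder, and the
# exact layout of the style tensor are assumptions, not part of this file):
#
#     from transformers import AutoModel
#
#     model = AutoModel.from_pretrained("path/to/vatrpp-checkpoint", trust_remote_code=True)
#     model.set_style_folder("path/to/style_images")  # optional: style examples from disk
#     # style_imgs: a tensor of style word images as expected by VATr
#     # (e.g. taken from model.style_dataset); `generate` adds the batch dim itself.
#     image = model.generate("hello world", style_imgs, align_words=True)
#     image.save("generated.png")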