import collections
import dataclasses
import types

import pytorch_lightning as pl
import torch.utils.data
import transformers

from data import (
    generate_annotated_images,
    get_annotation_ground_truth_str,
    DataItem,
    get_extra_tokens,
    Batch,
    Split,
    BatchCollateFunction,
)
from utils import load_pickle_or_build_object_and_save


@dataclasses.dataclass
class Model:
    """Bundle of every object needed to train or run the Donut pipeline.

    Instances are assembled by ``build_model``.
    """

    # Image preprocessing; its ``tokenizer`` attribute is the tokenizer below.
    processor: transformers.DonutProcessor
    # Text <-> token-id conversion (extended with project/dataset tokens).
    tokenizer: transformers.XLMRobertaTokenizerFast
    # The vision encoder / text decoder network itself.
    encoder_decoder: transformers.VisionEncoderDecoderModel
    # Turns a list of data items into a model-ready batch.
    batch_collate_function: BatchCollateFunction
    # Run configuration the bundle was built from.
    config: types.SimpleNamespace


def add_unknown_tokens_to_tokenizer(
    tokenizer, encoder_decoder, unknown_tokens: list[str]
):
    """Register ``unknown_tokens`` with ``tokenizer`` and grow the decoder to match.

    After the tokens are added, the decoder's token-embedding matrix is resized
    to the tokenizer's new length, so every token id has an embedding row.
    Returns ``None``.
    """
    tokenizer.add_tokens(unknown_tokens)
    new_vocabulary_size = len(tokenizer)
    encoder_decoder.decoder.resize_token_embeddings(new_vocabulary_size)


def find_unknown_tokens_for_tokenizer(tokenizer) -> collections.Counter:
    """Count, across all annotated images, the token strings ``tokenizer`` maps to ``unk``.

    Each annotation's ground-truth string is encoded to ids and, separately,
    tokenized into token strings (special tokens included in both, so the two
    sequences line up; ``zip(..., strict=True)`` raises if they do not). Every
    token string whose id equals ``tokenizer.unk_token_id`` is counted.

    Returns:
        A ``Counter`` mapping unknown token string -> number of occurrences.
    """
    unknown_counts = collections.Counter()
    unknown_id = tokenizer.unk_token_id

    for annotated_image in generate_annotated_images():
        text = get_annotation_ground_truth_str(annotated_image.annotation)
        ids = tokenizer(text).input_ids
        pieces = tokenizer.tokenize(text, add_special_tokens=True)
        unknown_counts.update(
            piece
            for piece_id, piece in zip(ids, pieces, strict=True)
            if piece_id == unknown_id
        )

    return unknown_counts


def replace_pad_token_id_with_negative_hundred_for_hf_transformers_automatic_batch_transformation(
    tokenizer, token_ids
):
    """Overwrite every pad id in ``token_ids`` with ``-100`` and return the same object.

    ``-100`` is the default ``ignore_index`` of ``torch.nn.CrossEntropyLoss``,
    so padded label positions are skipped by a loss computed that way.

    Note: ``token_ids`` is modified in place; the return value is that very
    object, not a copy.
    """
    is_padding = token_ids == tokenizer.pad_token_id
    token_ids[is_padding] = -100
    return token_ids


@dataclasses.dataclass
class BatchCollateFunction:
    """Callable that turns a list of ``DataItem`` into a ``Batch``.

    Fields:
        processor: converts the items' images into a pixel-value tensor.
        tokenizer: encodes each item's ``target_string`` into label ids.
        decoder_sequence_max_length: labels are padded/truncated to exactly
            this many tokens.

    NOTE(review): this definition shadows the ``BatchCollateFunction``
    imported from ``data`` at the top of the module — confirm which one is
    intended to be used.
    """

    processor: transformers.DonutProcessor
    tokenizer: transformers.XLMRobertaTokenizerFast
    decoder_sequence_max_length: int

    def __call__(self, batch: list[DataItem], split: Split) -> Batch:
        """Collate ``batch``; random image padding is applied only for ``Split.train``."""
        use_random_padding = split == Split.train
        pixel_values = self.processor(
            [item.image for item in batch],
            random_padding=use_random_padding,
            return_tensors="pt",
        ).pixel_values

        # Targets are encoded as-is (no special tokens added by the tokenizer)
        # and padded to a fixed length; pad ids are then replaced by -100.
        encoded_targets = self.tokenizer(
            [item.target_string for item in batch],
            add_special_tokens=False,
            max_length=self.decoder_sequence_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        labels = replace_pad_token_id_with_negative_hundred_for_hf_transformers_automatic_batch_transformation(
            self.tokenizer, encoded_targets.input_ids
        )

        return Batch(
            images=pixel_values,
            labels=labels,
            data_indices=[item.data_index for item in batch],
        )


def build_model(config: types.SimpleNamespace) -> Model:
    """Assemble the Donut processor, tokenizer, encoder-decoder and collate function.

    Attributes read from ``config``: ``pretrained_model_name``, ``image_width``,
    ``image_height``, ``unknown_tokens_for_tokenizer_path`` and
    ``decoder_sequence_max_length``; ``config`` itself is stored on the result.

    The tokenizer vocabulary is extended in two steps: first the project's
    extra tokens, then every dataset token the tokenizer maps to ``unk`` (a
    list cached via ``load_pickle_or_build_object_and_save``). The decoder
    embeddings are resized each time. The prompt / prompt-end token ids are
    looked up only *after* these additions, so they resolve to the real new
    tokens rather than to ``unk``.

    Returns:
        A ``Model`` bundling processor, tokenizer, network, collate function
        and config.
    """
    donut_processor = transformers.DonutProcessor.from_pretrained(
        config.pretrained_model_name
    )
    donut_processor.image_processor.size = dict(
        width=config.image_width, height=config.image_height
    )
    # Keep images in their given orientation (no long-axis alignment/rotation).
    donut_processor.image_processor.do_align_long_axis = False

    tokenizer = donut_processor.tokenizer

    encoder_decoder_config = transformers.VisionEncoderDecoderConfig.from_pretrained(
        config.pretrained_model_name
    )
    # The Donut (Swin) encoder's image_size is (height, width) — the same
    # order as the processor's size above — not (width, height).
    encoder_decoder_config.encoder.image_size = (
        config.image_height,
        config.image_width,
    )

    encoder_decoder = transformers.VisionEncoderDecoderModel.from_pretrained(
        config.pretrained_model_name, config=encoder_decoder_config
    )

    # Add the custom tokens BEFORE resolving any of their ids; otherwise
    # convert_tokens_to_ids() returns unk_token_id for the prompt tokens.
    extra_tokens = get_extra_tokens()
    add_unknown_tokens_to_tokenizer(
        tokenizer, encoder_decoder, list(extra_tokens.__dict__.values())
    )
    unknown_dataset_tokens = load_pickle_or_build_object_and_save(
        config.unknown_tokens_for_tokenizer_path,
        lambda: list(find_unknown_tokens_for_tokenizer(tokenizer).keys()),
    )
    add_unknown_tokens_to_tokenizer(tokenizer, encoder_decoder, unknown_dataset_tokens)

    # Write the special ids onto the model's OWN config. Depending on the
    # transformers version, from_pretrained() may copy the config object it
    # was given, so mutating `encoder_decoder_config` here is not guaranteed
    # to reach the model (which needs pad/decoder_start ids to build decoder
    # inputs from labels).
    model_config = encoder_decoder.config
    model_config.pad_token_id = tokenizer.pad_token_id
    model_config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(
        extra_tokens.benetech_prompt
    )
    model_config.bos_token_id = model_config.decoder_start_token_id
    model_config.eos_token_id = tokenizer.convert_tokens_to_ids(
        extra_tokens.benetech_prompt_end
    )
    # Make the tokenizer treat the prompt-end token as its end-of-sequence token.
    tokenizer.eos_token_id = model_config.eos_token_id

    batch_collate_function = BatchCollateFunction(
        processor=donut_processor,
        tokenizer=tokenizer,
        decoder_sequence_max_length=config.decoder_sequence_max_length,
    )

    return Model(
        processor=donut_processor,
        tokenizer=tokenizer,
        encoder_decoder=encoder_decoder,
        batch_collate_function=batch_collate_function,
        config=config,
    )


def generate_token_strings(
    model: Model, images: torch.Tensor, skip_special_tokens=True
) -> list[str]:
    """Autoregressively decode a batch of images into strings.

    Args:
        model: bundle from ``build_model``. ``config.debug`` caps generation at
            10 tokens; otherwise ``config.decoder_sequence_max_length`` is used.
        images: pixel values, e.g. ``model.processor(...).pixel_values``. They
            are moved to the network's device and dtype before generation.
        skip_special_tokens: forwarded to ``tokenizer.batch_decode``.

    Returns:
        One decoded string per image, in batch order.
    """
    encoder_decoder = model.encoder_decoder
    # Processor outputs (e.g. from predict_string) are CPU float32 tensors;
    # without this, generation fails once the network lives on a GPU or in
    # reduced precision. `.to()` is a no-op when nothing needs to change.
    images = images.to(device=encoder_decoder.device, dtype=encoder_decoder.dtype)

    # Very short generations in debug mode keep iterations fast.
    max_length = (
        10 if model.config.debug else model.config.decoder_sequence_max_length
    )
    decoder_output = encoder_decoder.generate(
        images,
        max_length=max_length,
        eos_token_id=model.tokenizer.eos_token_id,
        return_dict_in_generate=True,
    )
    return model.tokenizer.batch_decode(
        decoder_output.sequences, skip_special_tokens=skip_special_tokens
    )


def predict_string(image, model: Model):
    """Run the full inference pipeline on a single image and return the decoded string.

    The image is preprocessed with ``random_padding=False`` and decoded with
    ``generate_token_strings`` (default ``skip_special_tokens=True``); the
    first (only) result is returned.
    """
    pixel_values = model.processor(
        image, random_padding=False, return_tensors="pt"
    ).pixel_values
    predictions = generate_token_strings(model, pixel_values)
    return predictions[0]


class LightningModule(pl.LightningModule):
    """Lightning wrapper that trains the encoder-decoder produced by ``build_model``."""

    def __init__(self, config):
        super().__init__()
        # Records `config` in self.hparams (read by configure_optimizers).
        self.save_hyperparameters()
        self.model = build_model(config)
        # `Model` is a plain dataclass, not an nn.Module. Assigning the network
        # to an attribute registers it as a submodule, so self.parameters()
        # and device moves include it.
        self.encoder_decoder = self.model.encoder_decoder

    def training_step(self, batch: Batch, batch_idx: int) -> torch.Tensor:
        loss = self.compute_loss(batch)
        # batch_size is explicit because Lightning cannot unambiguously infer
        # it from the custom Batch container (several tensors/lists inside).
        self.log("train_loss", loss, batch_size=len(batch.data_indices))
        return loss

    def validation_step(self, batch: Batch, batch_idx: int, dataset_idx: int = 0):
        loss = self.compute_loss(batch)
        # Epoch-level val_loss is a batch-size-weighted mean, so pass the real size.
        self.log("val_loss", loss, batch_size=len(batch.data_indices))

    def compute_loss(self, batch: Batch) -> torch.Tensor:
        """Return the network's loss for ``batch.images`` against ``batch.labels``."""
        outputs = self.encoder_decoder(pixel_values=batch.images, labels=batch.labels)
        return outputs.loss

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Adam over all parameters, with the learning rate taken from the saved config."""
        return torch.optim.Adam(
            self.parameters(), lr=self.hparams["config"].learning_rate
        )