from pathlib import Path
from typing import Tuple

import pyrallis
import torch
from accelerate import Accelerator
from torch import nn
from transformers import CLIPTokenizer

from models.neti_clip_text_encoder import NeTICLIPTextModel
from models.neti_mapper import NeTIMapper
from models.positional_encoding import NeTIPositionalEncoding, BasicEncoder
from config import RunConfig


class CheckpointHandler:

    def __init__(self, cfg: RunConfig, placeholder_token_string: str, placeholder_token_id: int, save_root: Path):
        self.cfg = cfg
        self.placeholder_token_string = placeholder_token_string
        self.placeholder_token_id = placeholder_token_id
        self.save_root = save_root

    def save_model(self, text_encoder: NeTICLIPTextModel,
                   accelerator: Accelerator,
                   embeds_save_name: str,
                   mapper_save_name: str):
        self.save_learned_embeds(text_encoder, accelerator, embeds_save_name)
        self.save_mapper(text_encoder, mapper_save_name)

    def save_learned_embeds(self, text_encoder: NeTICLIPTextModel, accelerator: Accelerator, save_name: str):
        """
        Save learned embeddings. This embedding isn't really learned, but we'll add it to the tokenizer at inference
        to take the place of our placeholder token.
        """
        # Extract the placeholder token's row from the unwrapped text encoder's input embedding matrix.
        learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[self.placeholder_token_id]
        learned_embeds = learned_embeds.detach().cpu()
        learned_embeds_dict = {self.placeholder_token_string: learned_embeds}
        torch.save(learned_embeds_dict, self.save_root / save_name)

    def save_mapper(self, text_encoder: NeTICLIPTextModel, save_name: str):
        """ Save the mapper and config to be used at inference. """
        # Shallow-copy the config so the saved checkpoint is decoupled from the live run configuration.
        cfg_ = RunConfig(**self.cfg.__dict__.copy())
        # Store the mapper weights, the pyrallis-encoded run config, and the positional encoder module together,
        # so that load_mapper can rebuild the mapper without access to the original run.
        state_dict = {
            "state_dict": text_encoder.text_model.embeddings.mapper.state_dict(),
            "cfg": pyrallis.encode(cfg_),
            "encoder": text_encoder.text_model.embeddings.mapper.encoder
        }
        torch.save(state_dict, self.save_root / save_name)

    @staticmethod
    def load_mapper(mapper_path: Path) -> Tuple[RunConfig, NeTIMapper]:
        # The checkpoint contains a pickled encoder module and a pyrallis-encoded config, so it must be
        # loaded with full unpickling (weights_only=False on recent PyTorch versions).
        mapper_ckpt = torch.load(mapper_path, map_location="cpu", weights_only=False)
        cfg = pyrallis.decode(RunConfig, mapper_ckpt['cfg'])
        # Rebuild the mapper with the hyperparameters recorded in the checkpoint config
        # (output_dim=768 matches the CLIP text encoder's hidden size).
        neti_mapper = NeTIMapper(output_dim=768,
                                 use_nested_dropout=cfg.model.use_nested_dropout,
                                 nested_dropout_prob=cfg.model.nested_dropout_prob,
                                 norm_scale=cfg.model.target_norm,
                                 use_positional_encoding=cfg.model.use_positional_encoding,
                                 num_pe_time_anchors=cfg.model.num_pe_time_anchors,
                                 pe_sigmas=cfg.model.pe_sigmas,
                                 output_bypass=cfg.model.output_bypass)
        neti_mapper.load_state_dict(mapper_ckpt['state_dict'], strict=True)
        # Restore the positional encoder that was pickled alongside the mapper and move its tensors to the GPU
        # (inference is assumed to run on a CUDA device).
        encoder = mapper_ckpt['encoder']
        if isinstance(encoder, NeTIPositionalEncoding):
            encoder.w = nn.Parameter(encoder.w.cuda())
        elif isinstance(encoder, BasicEncoder):
            encoder.normalized_timesteps = encoder.normalized_timesteps.cuda()
            encoder.normalized_unet_layers = encoder.normalized_unet_layers.cuda()
        neti_mapper.encoder = encoder.cuda()
        neti_mapper.cuda()
        neti_mapper.eval()
        return cfg, neti_mapper

    @staticmethod
    def load_learned_embed_in_clip(learned_embeds_path: Path,
                                   text_encoder: NeTICLIPTextModel,
                                   tokenizer: CLIPTokenizer) -> Tuple[str, int]:
        loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")

        # separate the tokens and their embeddings
        trained_tokens = list(loaded_learned_embeds.keys())
        embeds = list(loaded_learned_embeds.values())

        # cast to dtype of text_encoder
        dtype = text_encoder.get_input_embeddings().weight.dtype
        embeds = [e.to(dtype) for e in embeds]

        # add the tokens to the tokenizer
        num_added_tokens = tokenizer.add_tokens(trained_tokens)
        if num_added_tokens == 0:
            raise ValueError(f"The tokenizer already contains the token {trained_tokens[0]}. "
                             f"Please use a placeholder token that is not already in the tokenizer.")

        # resize the token embeddings
        text_encoder.resize_token_embeddings(len(tokenizer))

        # get the id for the token and assign the embeds
        placeholder_token_ids = [tokenizer.convert_tokens_to_ids(t) for t in trained_tokens]

        for token_id, embed in zip(placeholder_token_ids, embeds):
            text_encoder.get_input_embeddings().weight.data[token_id] = embed

        assert len(trained_tokens) == 1, "Only one placeholder token is supported"
        placeholder_token = trained_tokens[0]
        placeholder_token_id = placeholder_token_ids[0]
        return placeholder_token, placeholder_token_id
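

# Illustrative end-to-end usage (a sketch only; the run objects, paths, and token string below are
# hypothetical and would normally come from the training / inference scripts):
#
#   handler = CheckpointHandler(cfg=cfg,
#                               placeholder_token_string="<my-concept>",
#                               placeholder_token_id=placeholder_token_id,
#                               save_root=Path("outputs/example-run"))
#   handler.save_model(text_encoder, accelerator,
#                      embeds_save_name="learned_embeds.bin",
#                      mapper_save_name="mapper.pt")
#
#   # At inference time the two artifacts are restored independently:
#   cfg, neti_mapper = CheckpointHandler.load_mapper(Path("outputs/example-run/mapper.pt"))
#   token, token_id = CheckpointHandler.load_learned_embed_in_clip(
#       Path("outputs/example-run/learned_embeds.bin"), text_encoder, tokenizer)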