from transformers import AutoConfig, AutoModel, PretrainedConfig, CLIPTextConfig, CLIPVisionConfig, PreTrainedModel, CLIPTextModelWithProjection, CLIPVisionModelWithProjection
from transformers.utils import ModelOutput
import torch
import open_clip
from dataclasses import dataclass
import safetensors.torch
from peft import get_peft_config, get_peft_model, LoraConfig, TaskType
import os

# File name under which OpenCLIPVisionEncoderOnly.save_pretrained stores its safetensors weights.
HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors"
# File name for the prior model's weights; within this file it is only referenced from the
# commented-out export code in CustomPriorModel.save_pretrained.
HF_SAFE_WEIGHTS_NAME_PRIOR = "prior_model.safetensors"

@dataclass
class PriorTransformerOutput(ModelOutput):
    """
    Output container for a prior model; in this file it is returned by `CustomPriorModel.forward`.

    Args:
        predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
            The predicted CLIP image embedding conditioned on the CLIP text embedding input.
    """

    # Required field: there is no default, so it must always be supplied.
    predicted_image_embedding: torch.FloatTensor

@dataclass
class TextEncoderOutput(ModelOutput):
    """
    Output class used by the text encoders in this file (`CLIPTextEncoderOnly` and
    `CustomTextEncoderOnly`) to return their results in a Hugging Face transformer style.

    Attributes:
        text_embeds (torch.FloatTensor): The text embedding produced by the encoder for the input prompts.
        last_hidden_state (torch.FloatTensor): The last hidden state reported by the encoder.
    """
    # Both fields default to None so the container can be built with only some of them set.
    text_embeds: torch.FloatTensor = None
    last_hidden_state: torch.FloatTensor = None

class CLIPTextEncoderOnlyConfig(CLIPTextConfig):
    """
    Configuration for `CLIPTextEncoderOnly`: a `CLIPTextConfig` extended with the fields that
    select the base checkpoint and the optional LoRA settings.
    """
    model_type = "clip_custom_text_model"

    def __init__(self, model_name: str = None, pretrained: bool = True, frozen: bool = False, lora: dict = None, **kwargs):
        """
        :param model_name: Name or path of the base CLIP checkpoint; `CLIPTextEncoderOnly` passes it to `from_pretrained`.
        :param pretrained: Whether `CLIPTextEncoderOnly` loads the checkpoint weights or only its configuration.
        :param frozen: Stored on the config; `CLIPTextEncoderOnly` in this file does not read it.
        :param lora: Optional LoRA settings; `CLIPTextEncoderOnly` reads the `lora_r`, `lora_alpha` and `lora_dropout` entries.
        :param kwargs: Forwarded unchanged to `CLIPTextConfig.__init__`.
        """
        # The custom fields are assigned before delegating to the parent constructor.
        self.model_name = model_name
        self.pretrained = pretrained
        self.frozen = frozen
        self.lora = lora
        super().__init__(**kwargs)

class CLIPTextEncoderOnly(PreTrainedModel):
    """
    CLIP text tower (with projection head) exposed as a Hugging Face ``PreTrainedModel``.

    The wrapped ``CLIPTextModelWithProjection`` is stored in ``self.model`` and is optionally
    wrapped with LoRA adapters.
    """

    config_class = CLIPTextEncoderOnlyConfig

    def __init__(self, config):
        """
        Build the text encoder described by ``config``.

        :param config: A ``CLIPTextEncoderOnlyConfig``. ``config.model_name`` is the name or path
            of the base CLIP checkpoint; ``config.pretrained`` selects between loading its weights
            and a fresh initialisation from its configuration; ``config.lora`` (optional) carries
            ``lora_r``, ``lora_alpha`` and ``lora_dropout`` either as mapping keys or as attributes.
        """
        super().__init__(config)

        if config.pretrained:
            self.model = CLIPTextModelWithProjection.from_pretrained(config.model_name)
        else:
            base_cfg = CLIPTextConfig.from_pretrained(config.model_name)
            self.model = CLIPTextModelWithProjection(base_cfg)

        # NOTE(review): ``config.frozen`` is not consulted here, so the flag has no effect on
        # this class; confirm whether freezing is handled by the caller.

        if config.lora:
            l_config = LoraConfig(
                r=self._lora_option(config.lora, "lora_r"),
                lora_alpha=self._lora_option(config.lora, "lora_alpha"),
                target_modules=[
                    "k_proj",
                    "v_proj",
                    "q_proj",
                    "out_proj",
                    "fc1",
                    "fc2",
                    "visual_projection",
                    "text_projection",
                ],
                lora_dropout=self._lora_option(config.lora, "lora_dropout"),
                bias="lora_only",
            )
            self.model = get_peft_model(self.model, l_config)

    @staticmethod
    def _lora_option(lora, key):
        """Read ``key`` from ``lora`` whether it is a plain dict or an attribute-style object."""
        # A config reloaded from JSON carries ``lora`` as a plain dict, which has no attribute access.
        if isinstance(lora, dict):
            return lora[key]
        return getattr(lora, key)

    def forward(self, input_ids, attention_mask=None, position_ids=None):
        """
        Forward pass of the model.

        :param input_ids: Indices of input sequence tokens in the vocabulary.
        :param attention_mask: Mask to avoid performing attention on padding token indices.
        :param position_ids: Optional position indices, passed through to the wrapped CLIP text model.
        :return: A ``TextEncoderOutput`` holding the wrapped model's ``text_embeds`` and ``last_hidden_state``.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_hidden_states=True,
        )
        return TextEncoderOutput(text_embeds=outputs.text_embeds, last_hidden_state=outputs.last_hidden_state)
    
class CustomTextEncoderOnly(PreTrainedModel):
    """
    Generic Hugging Face text encoder (any ``AutoModel`` backbone) with linear projection heads
    that map the backbone's hidden size to ``output_hidden_size``.
    """

    def __init__(self, model_name: str, output_hidden_size: int, pretrained: bool = True, frozen: bool = True, last_hidden_state: bool = False, lora: dict = None):
        """
        Build the backbone, the projection head(s) and the optional LoRA adapters.

        :param model_name: The name or path of the backbone checkpoint.
        :param output_hidden_size: Output width of the projection layers.
        :param pretrained: Whether to load the pretrained weights; otherwise the backbone is
            created from the checkpoint's configuration only.
        :param frozen: Whether to disable gradients for the backbone (only applied when
            ``pretrained`` is true).
        :param last_hidden_state: Whether to also project the per-token hidden states with ``fc2``.
        :param lora: Optional LoRA settings carrying ``lora_r``, ``lora_alpha`` and
            ``lora_dropout`` either as mapping keys or as attributes.
        """
        # Only the configuration is needed here; loading the whole model just to read
        # ``.config`` would load the weights twice.
        config = AutoConfig.from_pretrained(model_name)
        super().__init__(config)
        self.last_hidden_state = last_hidden_state

        if pretrained:
            self.model = AutoModel.from_pretrained(model_name)
            if frozen:
                for param in self.model.parameters():
                    param.requires_grad = False
        else:
            # ``AutoModel`` cannot be instantiated directly; ``from_config`` builds the
            # architecture without loading weights.
            self.model = AutoModel.from_config(config)

        self.fc1 = torch.nn.Linear(self.model.config.hidden_size, output_hidden_size)
        if last_hidden_state:
            self.fc2 = torch.nn.Linear(self.model.config.hidden_size, output_hidden_size)

        if lora:
            l_config = LoraConfig(
                task_type=TaskType.FEATURE_EXTRACTION,
                r=self._lora_option(lora, "lora_r"),
                lora_alpha=self._lora_option(lora, "lora_alpha"),
                lora_dropout=self._lora_option(lora, "lora_dropout"),
                bias="lora_only",
            )
            self.model = get_peft_model(self.model, l_config)

    @staticmethod
    def _lora_option(lora, key):
        """Read ``key`` from ``lora`` whether it is a plain dict or an attribute-style object."""
        if isinstance(lora, dict):
            return lora[key]
        return getattr(lora, key)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """
        Forward pass of the model.

        :param input_ids: Indices of input sequence tokens in the vocabulary.
        :param attention_mask: Mask to avoid performing attention on padding token indices.
        :param token_type_ids: Segment token indices to indicate first and second portions of the inputs.
        :return: A ``TextEncoderOutput``; ``text_embeds`` is ``fc1`` applied to the backbone's
            second output, ``last_hidden_state`` is the backbone's first output, projected by
            ``fc2`` when the model was built with ``last_hidden_state=True``.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=True,
        )
        # outputs[1] is presumably the pooled sentence representation - depends on the backbone.
        text_embeds = self.fc1(outputs[1])
        if self.last_hidden_state:
            last_hidden_state = self.fc2(outputs[0])
        else:
            last_hidden_state = outputs[0]
        return TextEncoderOutput(text_embeds=text_embeds, last_hidden_state=last_hidden_state)

class CLIPVisionEncoderOnlyConfig(PretrainedConfig):
    """
    Configuration for `CLIPVisionEncoderOnly`: a `PretrainedConfig` carrying the fields that
    select the base checkpoint and the optional LoRA settings.
    """
    model_type = "clip_custom_vision_model"

    def __init__(self, model_name: str = None, pretrained: bool = True, frozen: bool = False, lora: dict = None, **kwargs):
        """
        :param model_name: Name or path of the base CLIP checkpoint; `CLIPVisionEncoderOnly` passes it to `from_pretrained`.
        :param pretrained: Whether `CLIPVisionEncoderOnly` loads the checkpoint weights or only its configuration.
        :param frozen: Stored on the config; `CLIPVisionEncoderOnly` in this file does not read it.
        :param lora: Optional LoRA settings; `CLIPVisionEncoderOnly` reads the `lora_r`, `lora_alpha` and `lora_dropout` entries.
        :param kwargs: Forwarded unchanged to `PretrainedConfig.__init__`.
        """
        # The custom fields are assigned before delegating to the parent constructor.
        self.model_name = model_name
        self.pretrained = pretrained
        self.frozen = frozen
        self.lora = lora
        super().__init__(**kwargs)

class CLIPVisionEncoderOnly(PreTrainedModel):
    """
    CLIP vision tower (with projection head) exposed as a Hugging Face ``PreTrainedModel``.

    The wrapped ``CLIPVisionModelWithProjection`` is stored in ``self.model`` and is optionally
    wrapped with LoRA adapters.
    """

    config_class = CLIPVisionEncoderOnlyConfig

    def __init__(self, config):
        """
        Build the vision encoder described by ``config``.

        :param config: A ``CLIPVisionEncoderOnlyConfig``. ``config.model_name`` is the name or
            path of the base CLIP checkpoint; ``config.pretrained`` selects between loading its
            weights and a fresh initialisation from its configuration; ``config.lora`` (optional)
            carries ``lora_r``, ``lora_alpha`` and ``lora_dropout`` either as mapping keys or as
            attributes.
        """
        super().__init__(config)

        if config.pretrained:
            self.model = CLIPVisionModelWithProjection.from_pretrained(config.model_name)
        else:
            base_cfg = CLIPVisionConfig.from_pretrained(config.model_name)
            self.model = CLIPVisionModelWithProjection(base_cfg)

        # NOTE(review): ``config.frozen`` is not consulted here, so the flag has no effect on
        # this class; confirm whether freezing is handled by the caller.

        if config.lora:
            l_config = LoraConfig(
                r=self._lora_option(config.lora, "lora_r"),
                lora_alpha=self._lora_option(config.lora, "lora_alpha"),
                target_modules=[
                    "k_proj",
                    "v_proj",
                    "q_proj",
                    "out_proj",
                    "fc1",
                    "fc2",
                    "visual_projection",
                    "text_projection",
                ],
                lora_dropout=self._lora_option(config.lora, "lora_dropout"),
                bias="lora_only",
            )
            self.model = get_peft_model(self.model, l_config)

    @staticmethod
    def _lora_option(lora, key):
        """Read ``key`` from ``lora`` whether it is a plain dict or an attribute-style object."""
        # A config reloaded from JSON carries ``lora`` as a plain dict, which has no attribute access.
        if isinstance(lora, dict):
            return lora[key]
        return getattr(lora, key)

    def forward(self, data):
        """
        Forward pass of the model.

        :param data: Mapping of keyword arguments unpacked into the wrapped CLIP vision model's
            call (presumably the image processor output, e.g. ``pixel_values`` - confirm with caller).
        :return: The ``image_embeds`` field of the wrapped model's output.
        """
        return self.model(**data).image_embeds

    def parameters(self, recurse: bool = True):
        """
        Iterate over the wrapped model's parameters.

        :param recurse: Forwarded to the wrapped model's ``parameters``; accepted so this
            override stays call-compatible with ``torch.nn.Module.parameters``.
        """
        return self.model.parameters(recurse=recurse)


class OpenCLIPVisionEncoderOnly(torch.nn.Module):
    """
    Vision tower of an open_clip model loaded from the Hugging Face Hub, optionally wrapped
    with LoRA adapters.
    """

    def __init__(self, model_name: str, pretrained: bool = True, frozen: bool = False, lora: dict = None):
        """
        Load the open_clip model ``hf-hub:{model_name}`` and keep only its ``visual`` sub-module.

        :param model_name: Hub repository id of the open_clip checkpoint (without the ``hf-hub:`` prefix).
        :param pretrained: Must be true; building an untrained model is not implemented.
        :param frozen: Accepted for signature compatibility; NOTE(review): not used by this class.
        :param lora: Optional LoRA settings carrying ``lora_r``, ``lora_alpha`` and
            ``lora_dropout`` either as mapping keys or as attributes.
        :raises NotImplementedError: If ``pretrained`` is false.
        """
        super().__init__()
        if not pretrained:
            # ``NotImplemented`` is a comparison sentinel, not an exception class; raising it
            # would surface as an unrelated TypeError.
            raise NotImplementedError("OpenCLIPVisionEncoderOnly only supports pretrained=True")

        model, _ = open_clip.create_model_from_pretrained(f"hf-hub:{model_name}")
        self.model = model.visual

        if lora:
            l_config = LoraConfig(
                r=self._lora_option(lora, "lora_r"),
                lora_alpha=self._lora_option(lora, "lora_alpha"),
                target_modules=[
                    "k_proj",
                    "v_proj",
                    "q_proj",
                    "out_proj",
                    "fc1",
                    "fc2",
                    "visual_projection",
                    "text_projection",
                ],
                lora_dropout=self._lora_option(lora, "lora_dropout"),
                bias="lora_only",
            )
            self.model = get_peft_model(self.model, l_config)

    @staticmethod
    def _lora_option(lora, key):
        """Read ``key`` from ``lora`` whether it is a plain dict or an attribute-style object."""
        if isinstance(lora, dict):
            return lora[key]
        return getattr(lora, key)

    def forward(self, image):
        """
        Forward pass of the model.

        :param image: Input passed unchanged to the wrapped vision tower.
        :return: Whatever the wrapped vision tower returns for ``image``.
        """
        return self.model(image)

    def save_pretrained(self, save_dir):
        """
        Write the wrapped model's state dict to ``save_dir`` as a safetensors file.

        :param save_dir: Target directory, given as ``str`` or ``os.PathLike``; created if missing.
        """
        os.makedirs(save_dir, exist_ok=True)
        tensors = self.model.state_dict()
        # os.path.join accepts both str and Path, unlike the ``/`` operator.
        safetensors.torch.save_file(tensors, os.path.join(save_dir, HF_SAFE_WEIGHTS_NAME))

class CustomPriorModel(torch.nn.Module):
    """Small MLP prior: ``Linear -> ReLU -> Linear``."""

    def __init__(self, in_hidden_state, out_hidden_state):
        """
        Create the two linear layers of the prior.

        :param in_hidden_state: Base input width; the first layer consumes ``2 * in_hidden_state`` features.
        :param out_hidden_state: Width of the predicted embedding.
        """
        super().__init__()
        # The hidden layer is as wide as the larger of the two sizes.
        hidden_width = max(in_hidden_state, out_hidden_state)

        self.fc1 = torch.nn.Linear(2 * in_hidden_state, hidden_width)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(hidden_width, out_hidden_state)

    def reinitialize_model(self):
        """
        Re-draw every trainable parameter in place.

        Tensors with more than one dimension get Xavier-uniform values; one-dimensional
        tensors are drawn from a standard normal when their name contains ``weight`` and
        are zeroed otherwise. Parameters with ``requires_grad=False`` are left untouched.
        """
        for param_name, tensor in self.named_parameters():
            if not tensor.requires_grad:
                continue
            if tensor.dim() > 1:
                torch.nn.init.xavier_uniform_(tensor)
            elif 'weight' in param_name:
                torch.nn.init.normal_(tensor)
            else:
                torch.nn.init.zeros_(tensor)

    def forward(self, feats):
        """
        Forward pass of the model.

        :param feats: Input features fed to the first linear layer.
        :return: A ``PriorTransformerOutput`` whose ``predicted_image_embedding`` is the MLP output.
        """
        hidden = self.relu(self.fc1(feats))
        prediction = self.fc2(hidden)
        return PriorTransformerOutput(predicted_image_embedding=prediction)

    def save_pretrained(self, save_dir):
        """Currently a no-op: the safetensors export below is disabled."""
        # tensors = self.state_dict()
        # safetensors.torch.save_file(tensors, os.path.join(save_dir, HF_SAFE_WEIGHTS_NAME_PRIOR))
        return None


def test_text_model(register=False, upload=False):
    """
    Manual smoke test for the custom CLIP text encoder.

    :param register: When true, register the config/model pair with ``AutoConfig`` /
        ``AutoModel`` and mark both classes for auto-class export.
    :param upload: When true, build a pretrained encoder without LoRA, push it to the Hub and
        download it again.
    """
    if register:
        # Make the custom classes resolvable through the Auto* factories.
        AutoConfig.register("clip_custom_text_model", CLIPTextEncoderOnlyConfig)
        AutoModel.register(CLIPTextEncoderOnlyConfig, CLIPTextEncoderOnly)
        CLIPTextEncoderOnlyConfig.register_for_auto_class()
        CLIPTextEncoderOnly.register_for_auto_class("AutoModel")

    if not upload:
        return

    encoder_cfg = CLIPTextEncoderOnlyConfig(
        model_name="openai/clip-vit-base-patch32",
        pretrained=True,
        lora=None,
    )
    CLIPTextEncoderOnly(encoder_cfg).push_to_hub("test-text-hf-upload")

    # Round trip: fetch the uploaded checkpoint again, bypassing the local cache.
    CLIPTextEncoderOnly.from_pretrained("mpatel57/test-text-hf-upload", force_download=True)

def test_vision_model(register=False, upload=False):
    """
    Manual smoke test for the custom CLIP vision encoder.

    :param register: When true, register the config/model pair with ``AutoConfig`` /
        ``AutoModel`` and mark both classes for auto-class export.
    :param upload: When true, build a pretrained encoder without LoRA, push it to the Hub and
        download it again.
    """
    if register:
        # Make the custom classes resolvable through the Auto* factories.
        AutoConfig.register("clip_custom_vision_model", CLIPVisionEncoderOnlyConfig)
        AutoModel.register(CLIPVisionEncoderOnlyConfig, CLIPVisionEncoderOnly)
        CLIPVisionEncoderOnlyConfig.register_for_auto_class()
        CLIPVisionEncoderOnly.register_for_auto_class("AutoModel")

    if not upload:
        return

    encoder_cfg = CLIPVisionEncoderOnlyConfig(
        model_name="openai/clip-vit-base-patch32",
        pretrained=True,
        lora=None,
    )
    CLIPVisionEncoderOnly(encoder_cfg).push_to_hub("test-vision-hf-upload")

    # Round trip: fetch the uploaded checkpoint again, bypassing the local cache.
    CLIPVisionEncoderOnly.from_pretrained("mpatel57/test-vision-hf-upload", force_download=True)


# Running this module as a script executes the upload path of both smoke tests
# (push to the Hub, then re-download) without registering the Auto* classes.
if __name__ == "__main__":
    test_text_model(register=False, upload=True)
    test_vision_model(register=False, upload=True)