from .isoformer_config import IsoformerConfig
from transformers import PreTrainedModel
from .modeling_esm import NTForMaskedLM, MultiHeadAttention
from .esm_config import NTConfig
from .modeling_esm_original import EsmForMaskedLM
from transformers.models.esm.configuration_esm import EsmConfig
from enformer_pytorch import Enformer, str_to_one_hot, EnformerConfig
import torch
from torch import nn

class Isoformer(PreTrainedModel):
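    """Multi-omics model fusing DNA, RNA and protein sequence representations.

    An Enformer model (loaded with pretrained weights) encodes the DNA
    sequence, an NT encoder the RNA sequence, and an ESM encoder the protein
    sequence. The DNA embeddings attend over the RNA and protein embeddings
    through two cross-attention layers, are mean-pooled over a configurable
    window, and fed to a two-layer head that outputs 30 gene expression
    predictions.
    """
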
    config_class = IsoformerConfig

    def __init__(self, config):
        super().__init__(config)


        self.esm_config = EsmConfig(
            vocab_size=config.esm_vocab_size,
            mask_token_id=config.esm_mask_token_id,
            pad_token_id=config.esm_pad_token_id,
            hidden_size=config.esm_hidden_size,
            num_hidden_layers=config.esm_num_hidden_layers,
            num_attention_heads=config.esm_num_attention_heads,
            intermediate_size=config.esm_intermediate_size,
            max_position_embeddings=config.esm_max_position_embeddings,
            token_dropout=config.esm_token_dropout,
            emb_layer_norm_before=config.esm_emb_layer_norm_before,
            attention_probs_dropout_prob=0.0,
            hidden_dropout_prob=0.0,
            use_cache=False,
            add_bias_fnn=config.esm_add_bias_fnn,
            position_embedding_type="rotary",
            tie_word_embeddings=False,
        )

        self.nt_config = NTConfig(
            vocab_size=config.nt_vocab_size,
            mask_token_id=config.nt_mask_token_id,
            pad_token_id=config.nt_pad_token_id,
            hidden_size=config.nt_hidden_size,
            num_hidden_layers=config.nt_num_hidden_layers,
            num_attention_heads=config.nt_num_attention_heads,
            intermediate_size=config.nt_intermediate_size,
            max_position_embeddings=config.nt_max_position_embeddings,
            token_dropout=config.nt_token_dropout,
            emb_layer_norm_before=config.nt_emb_layer_norm_before,
            attention_probs_dropout_prob=0.0,
            hidden_dropout_prob=0.0,
            use_cache=False,
            add_bias_fnn=config.nt_add_bias_fnn,
            position_embedding_type="rotary",
            tie_word_embeddings=False,
        )
        self.config = config

        # Sequence encoders: ESM for the protein sequence, NT for the RNA
        # sequence, and Enformer (loaded with pretrained weights) for DNA.
        self.esm_model = EsmForMaskedLM(self.esm_config)
        self.nt_model = NTForMaskedLM(self.nt_config)
        self.enformer_model = Enformer.from_pretrained("EleutherAI/enformer-official-rough")

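        # Cross-attention fusion layers: Enformer DNA embeddings (hidden size
        # 3072) act as queries, while the RNA encoder outputs (hidden size 768)
        # and the protein encoder outputs (hidden size 640) supply keys/values.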
        self.cross_attention_layer_rna = MultiHeadAttention(
            config=EsmConfig(
                num_attention_heads=config.num_heads_omics_cross_attention,
                attention_head_size=3072 // config.num_heads_omics_cross_attention,
                hidden_size=3072,
                attention_probs_dropout_prob=0,
                max_position_embeddings=0
            ),
            omics_of_interest_size=3072,
            other_omic_size=768
        )
        self.cross_attention_layer_protein = MultiHeadAttention(
            config=EsmConfig(
                num_attention_heads=config.num_heads_omics_cross_attention,
                attention_head_size=3072 // config.num_heads_omics_cross_attention,
                hidden_size=3072,
                attention_probs_dropout_prob=0,
                max_position_embeddings=0
            ),
            omics_of_interest_size=3072,
            other_omic_size=640
        )

        # Prediction head: pooled DNA embedding -> 2x expansion -> softplus -> 30 outputs.
        self.head_layer_1 = nn.Linear(3072, 2 * 3072)
        self.head_layer_2 = nn.Linear(2 * 3072, 30)

    def forward(
            self,
            tensor_dna,
            tensor_rna,
            tensor_protein,
            attention_mask_rna,
            attention_mask_protein
    ):
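        """Predict gene expression from DNA, RNA and protein token ids.

        Args:
            tensor_dna: DNA token ids; the leading CLS token is removed before
                the sequence is passed to Enformer.
            tensor_rna: RNA token ids for the NT encoder.
            tensor_protein: protein token ids for the ESM encoder.
            attention_mask_rna: attention mask for the RNA tokens.
            attention_mask_protein: attention mask for the protein tokens.

        Returns:
            A dict with "gene_expression_predictions" of shape (batch, 30) and
            "final_dna_embeddings", the cross-attended DNA embeddings.
        """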
        tensor_dna = tensor_dna[:, 1:] # remove CLS
        dna_embedding = self.enformer_model(
            tensor_dna,
            return_only_embeddings=True
            # attention_mask=attention_mask_dna,
            # encoder_attention_mask=attention_mask_dna,
            # output_hidden_states=True
        )
        protein_embedding = self.esm_model(
            tensor_protein,
            attention_mask=attention_mask_protein,
            encoder_attention_mask=attention_mask_protein,
            output_hidden_states=True
        )
        rna_embedding = self.nt_model(
            tensor_rna,
            attention_mask=attention_mask_rna,
            encoder_attention_mask=attention_mask_rna,
            output_hidden_states=True
        )

        # Key-padding mask for the RNA tokens (token id 1 is padding), broadcast
        # to shape (batch, 1, dna_len, rna_len).
        encoder_attention_mask = (tensor_rna != 1).unsqueeze(1).unsqueeze(2).repeat(
            1, 1, dna_embedding.shape[1], 1
        )
        rna_to_dna = self.cross_attention_layer_rna(
            hidden_states=dna_embedding,
            encoder_hidden_states=rna_embedding["hidden_states"][-1],
            encoder_attention_mask=encoder_attention_mask,
        )

        # Second fusion step: the RNA-conditioned DNA embeddings attend over the
        # protein encoder's final hidden states (no padding mask is applied here).
        final_dna_embeddings = self.cross_attention_layer_protein(
            hidden_states=rna_to_dna["embeddings"],
            encoder_hidden_states=protein_embedding["hidden_states"][-1],
        )["embeddings"]

        # Mean-pool the DNA embeddings over the configured genomic window,
        # keeping the mask on the same device and dtype as the embeddings.
        sequence_mask = torch.zeros(
            final_dna_embeddings.shape[1],
            device=final_dna_embeddings.device,
            dtype=final_dna_embeddings.dtype,
        )
        sequence_mask[self.config.pool_window_start:self.config.pool_window_end] = 1
        x = torch.sum(torch.einsum("ijk,j->ijk", final_dna_embeddings, sequence_mask), dim=1) / torch.sum(sequence_mask)
        x = self.head_layer_1(x)
        x = torch.nn.functional.softplus(x)
        x = self.head_layer_2(x)


        return {
            "gene_expression_predictions": x,
            "final_dna_embeddings": final_dna_embeddings,
        }
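
# Usage sketch (illustrative only; the variable names, tensor shapes, and
# config values below are assumptions, not taken from this file):
#
#   config = IsoformerConfig(...)  # populate with the checkpoint's settings
#   model = Isoformer(config)
#   outputs = model(
#       tensor_dna=dna_token_ids,              # (batch, dna_len + 1); CLS stripped in forward
#       tensor_rna=rna_token_ids,              # (batch, rna_len), pad token id 1
#       tensor_protein=protein_token_ids,      # (batch, protein_len)
#       attention_mask_rna=rna_attention_mask,
#       attention_mask_protein=protein_attention_mask,
#   )
#   outputs["gene_expression_predictions"]     # (batch, 30)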