from dataclasses import dataclass
from typing import Optional

from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import ModelOutput
import torch
import torch.nn as nn
from models.model.transformer import Transformer
from models.model.sparse_autoencoder import SparseAutoencoder
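
# `BaseModelOutput` only exposes `last_hidden_state`, `hidden_states` and
# `attentions`, so it cannot carry the classifier logits or the autoencoder's
# sparsity penalty. The dataclass below is a minimal `ModelOutput` subclass
# (the class name and field layout are our own choice) used as the return
# type of `CustomModel.forward`.
@dataclass
class CustomModelOutput(ModelOutput):
    last_hidden_state: Optional[torch.FloatTensor] = None
    act_output: Optional[torch.FloatTensor] = None
    emotion_output: Optional[torch.FloatTensor] = None
    kl_div: Optional[torch.FloatTensor] = None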

class CustomConfig(PretrainedConfig):
    model_type = "custom_model"

    def __init__(
        self,
        hidden_size=768,
        num_attention_heads=12,
        num_hidden_layers=12,
        intermediate_size=3072,
        hidden_dropout_prob=0.1,
        num_act_classes=5,
        num_emotion_classes=7,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.num_act_classes = num_act_classes
        self.num_emotion_classes = num_emotion_classes

class CustomModel(PreTrainedModel):
    config_class = CustomConfig

    def __init__(self, config):
        super().__init__(config)
        # Custom encoder-decoder Transformer. The hard-coded values assume a
        # BERT-style tokenizer: vocabulary size 30522, [PAD] id 0, [CLS]/SOS id 101,
        # and a maximum sequence length of 128.
        self.transformer = Transformer(
            src_pad_idx=0,
            trg_pad_idx=0,
            trg_sos_idx=101,
            enc_voc_size=30522,
            dec_voc_size=30522,
            d_model=config.hidden_size,
            max_len=128,
            ffn_hidden=config.intermediate_size,
            n_head=config.num_attention_heads,
            n_layers=config.num_hidden_layers,
            drop_prob=config.hidden_dropout_prob,
            device='cuda' if torch.cuda.is_available() else 'cpu'
        )
        
        # Normalisation, dropout, and the two per-task classification heads.
        self.batch_norm = nn.BatchNorm1d(config.hidden_size)
        self.dropout = nn.Dropout(p=config.hidden_dropout_prob)
        self.act_classifier = nn.Linear(config.hidden_size, config.num_act_classes)
        self.emotion_classifier = nn.Linear(config.hidden_size, config.num_emotion_classes)
        # Sparse autoencoder over the encoder states; its forward pass returns
        # the reconstruction, a KL sparsity penalty, and the sparse code.
        self.sparse_autoencoder = SparseAutoencoder(
            input_size=config.hidden_size,
            hidden_size=config.hidden_size // 2,
            sparsity_param=0.05,
            beta=3
        )
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        # Encoder output has shape (batch_size, seq_len, hidden_size).
        hidden_states = self.transformer.encoder(input_ids, attention_mask)
        batch_size, seq_len, hidden_size = hidden_states.size()

        # BatchNorm1d expects (N, C): flatten the batch and sequence dimensions,
        # normalise, apply dropout, then restore the 3-D shape.
        hidden_states = self.batch_norm(hidden_states.view(-1, hidden_size))
        hidden_states = self.dropout(hidden_states).view(batch_size, seq_len, hidden_size)

        # Sparse autoencoder returns (reconstruction, KL sparsity penalty, sparse code).
        reconstructed, kl_div, encoded = self.sparse_autoencoder(hidden_states)

        # Dialogue-act and emotion logits are computed from the first ([CLS]) position.
        cls_output = reconstructed[:, 0, :]
        act_output = self.act_classifier(cls_output)
        emotion_output = self.emotion_classifier(cls_output)

        return CustomModelOutput(last_hidden_state=cls_output, act_output=act_output,
                                 emotion_output=emotion_output, kl_div=kl_div)
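
# --- Usage sketch ------------------------------------------------------------
# A minimal smoke test, not part of the original module: it assumes the custom
# Transformer encoder accepts (input_ids, attention_mask) and that sequences
# stay within the hard-coded max_len of 128. Registering the classes with the
# Auto* factories (supported in recent transformers releases) is optional but
# lets the model be resolved through AutoConfig/AutoModel.
if __name__ == "__main__":
    from transformers import AutoConfig, AutoModel

    AutoConfig.register("custom_model", CustomConfig)
    AutoModel.register(CustomConfig, CustomModel)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    config = CustomConfig()
    model = CustomModel(config).to(device).eval()

    # Dummy batch of 2 sequences with 16 BERT-vocabulary token ids each.
    input_ids = torch.randint(0, 30522, (2, 16), device=device)
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    print(outputs.act_output.shape)      # expected: (2, config.num_act_classes)
    print(outputs.emotion_output.shape)  # expected: (2, config.num_emotion_classes)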