from typing import Optional
import torch
from dataclasses import dataclass
from torch import nn, Tensor
from transformers import PretrainedConfig, PreTrainedModel
from transformers.utils import ModelOutput

# Example: register, publish, and download the model via the Hugging Face Hub.

#from huggingface_hub import notebook_login
#from transformers import AutoConfig, AutoModel
#from autoencoder_model.modeling_autoencoder import AutoEncoder, AutoEncoderConfig

#notebook_login()

# Register Huggingface Model
#AutoEncoderConfig.register_for_auto_class()
#AutoEncoder.register_for_auto_class("AutoModel")

#AutoConfig.register("autoencoder", AutoEncoderConfig)
#AutoModel.register(AutoEncoderConfig, AutoEncoder)

# Create Model
#autoencoder = AutoEncoder(AutoEncoderConfig())
#autoencoder.push_to_hub("autoencoder")

# Download Model
#config = AutoConfig.from_pretrained("amaye15/autoencoder", trust_remote_code = True)
#autoencoder = AutoModel.from_config(config, trust_remote_code = True)



# Structure
    # Example
    # Model Outputs
    # Model Configuration
    # Model Layers
    # Model


##########################################################################################
#################################### Outputs #############################################
##########################################################################################

@dataclass
class AutoencoderModelOutput(ModelOutput):
    """
    Represents the output of an autoencoder model. This class holds various
    important tensors that are the result of passing data through an autoencoder.

    Attributes:
        logits (torch.FloatTensor, optional): The reconstructed output from the autoencoder.
            This is typically the direct output of the decoder part of the model.
        labels (torch.FloatTensor, optional): The true labels associated with the input data,
            if available. Useful for supervised training scenarios or evaluation.
        hidden_state (torch.FloatTensor, optional): The encoded representation of the input data.
            This is the output of the encoder part of the model and serves as a compressed
            representation of the input data.
        loss (torch.FloatTensor, optional): The computed loss value when comparing the reconstructed
            output to the original input data. This is essential for training and evaluating the model's performance.
    """
    logits: Optional[torch.FloatTensor] = None
    labels: Optional[torch.FloatTensor] = None
    hidden_state: Optional[torch.FloatTensor] = None
    loss: Optional[torch.FloatTensor] = None
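
# Illustrative (commented) example of reading the output fields. The shapes
# assume a linear autoencoder with the default input_dim=128 and latent_dim=64.
#
#outputs = autoencoder(torch.randn(8, 128))
#outputs.logits.shape        # torch.Size([8, 128]) - the reconstruction
#outputs.hidden_state.shape  # torch.Size([8, 64])  - the latent code
#outputs.loss                # scalar MSE loss against the inputs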

##########################################################################################
################################# Configuration ##########################################
##########################################################################################

class AutoEncoderConfig(PretrainedConfig):
    """
    Configuration class for AutoEncoder. This class stores the parameters for the autoencoder model.
    
    Attributes:
        input_dim (int): The dimensionality of the input data. Default is 128.
        latent_dim (int): The dimensionality of the latent representation. Default is 64.
        layer_types (str): The type of layers used, e.g., 'linear', 'lstm', 'gru', 'rnn'. Default is 'linear'.
        dropout_rate (float): The dropout rate applied after each layer (except for the last layer). Default is 0.1.
        num_layers (int): The number of layers in the encoder/decoder. Default is 3.
        compression_rate (float): Factor by which to compress the dimensions through layers. Default is 0.5.
        bidirectional (bool): Whether the sequence layers should be bidirectional. Default is False.
        embed (bool): Whether to use embedding for input data. If True, `vocab_size` and `max_position` must be specified. Default is False.
        vocab_size (int, optional): The size of the vocabulary. Required if `embed` is True.
        max_position (int, optional): The maximum sequence length for positional embeddings. Required if `embed` is True.

    Raises:
        ValueError: If `embed` is True and either `vocab_size` or `max_position` is not provided as an integer.
    """
    model_type = "autoencoder"

    def __init__(
        self, 
        input_dim: int = 128, 
        latent_dim: int = 64, 
        layer_types: str = 'linear', 
        dropout_rate: float = 0.1, 
        num_layers: int = 3, 
        compression_rate: float = 0.5, 
        bidirectional: bool = False,
        embed: bool = False,
        vocab_size: Optional[int] = None,
        max_position: Optional[int] = None,
        pad_token_id: int = 0,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.layer_types = layer_types
        self.dropout_rate = dropout_rate
        self.num_layers = num_layers
        self.compression_rate = compression_rate
        self.bidirectional = bidirectional
        self.embed = embed
        self.vocab_size = vocab_size
        self.max_position = max_position
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

        if self.embed:
            if not isinstance(self.vocab_size, int) or isinstance(self.vocab_size, bool):
                raise ValueError("vocab_size must be an integer when embed is True, e.g. AutoEncoderConfig(embed=True, vocab_size=10_000, max_position=512)")
            if not isinstance(self.max_position, int) or isinstance(self.max_position, bool):
                raise ValueError("max_position must be an integer when embed is True, e.g. AutoEncoderConfig(embed=True, vocab_size=10_000, max_position=512)")

##########################################################################################
############################# Block/Encoder/Decoder ######################################
##########################################################################################

def create_layers(
    model_section: str, 
    layer_types: str, 
    input_dim: int, 
    latent_dim: int, 
    num_layers: int, 
    dropout_rate: float, 
    compression_rate: float, 
    bidirectional: bool,
    classes: Optional[int] = None
) -> nn.Sequential:
    """
    Creates a sequence of layers for the encoder or decoder part of the autoencoder.

    Args:
        model_section (str): A string indicating whether this is for 'encoder' or 'decoder'.
        layer_types (str): The type of layers to include in the sequence.
        input_dim (int): The input dimension for the first layer.
        latent_dim (int): The target dimension for the latent representation.
        num_layers (int): The number of layers to create.
        dropout_rate (float): The dropout rate to apply between layers.
        compression_rate (float): The compression rate for reducing dimensions through layers.
        bidirectional (bool): Whether the RNN layers should be bidirectional.
        classes (int, optional): If provided, sets the output dimension of the final decoder layer
                            (e.g. a vocabulary size for token reconstruction). Ignored for the encoder.

    Returns:
        A nn.Sequential module containing the created layers. The configuration of these layers is determined by the arguments provided.

    Raises:
        ValueError: If certain layer type conditions are not met or if required parameters for specific configurations are missing.
    """
  
    layers = []  # Initialize an empty list to store the layers.
    current_dim = input_dim  # Start with the initial input dimension.

    # Lists to store input and output dimensions for each layer.
    input_dimensions = []
    output_dimensions = []

    # Calculate input and output dimensions for each layer.
    for _ in range(num_layers):
        input_dimensions.append(current_dim)  # Store current dimension.
        next_dim = max(int(current_dim * compression_rate), latent_dim)  # Calculate next dimension with compression.
        current_dim = next_dim  # Update current dimension.
        output_dimensions.append(current_dim)  # Store output dimension.

    # Ensure the last layer's output dimension is the latent dimension.
    output_dimensions[num_layers - 1] = latent_dim

    # Adjust dimensions for decoder configuration.
    if model_section == "decoder":
        # Swap input and output dimensions for decoder.
        input_dimensions, output_dimensions = output_dimensions, input_dimensions
        input_dimensions.reverse()  # Reverse the order for decoder stack.
        output_dimensions.reverse()

        # Set the final layer's dimension to classes if specified.
        # When bidirectional, the halved value is doubled again below.
        if classes is not None:
            if bidirectional:
                output_dimensions[-1] = classes // 2
            else:
                output_dimensions[-1] = classes

        # Adjust dimensions for bidirectional RNN layers.
        if bidirectional and (layer_types in ['lstm', 'rnn', 'gru']):
            output_dimensions = [2 * value for value in output_dimensions]

    # Construct layers based on the specified layer type.
    for idx, (in_dim, out_dim) in enumerate(zip(input_dimensions, output_dimensions)):
        # Add layers according to the specified type.
        if layer_types == 'linear':
            layers.append(nn.Linear(in_dim, out_dim))
        elif layer_types in ['lstm', 'rnn', 'gru']:
            rnn_layer = getattr(nn, layer_types.upper())  # Dynamically get the RNN layer class.
            # A bidirectional layer outputs twice its hidden size, so halve it here.
            hidden_dim = out_dim // (2 if bidirectional else 1)
            if model_section == "decoder" and idx > 0 and bidirectional:
                # Layers after the first receive the doubled output of the
                # preceding bidirectional layer.
                layers.append(rnn_layer(in_dim * 2, hidden_dim, batch_first=True, bidirectional=bidirectional))
            else:
                layers.append(rnn_layer(in_dim, hidden_dim, batch_first=True, bidirectional=bidirectional))
        # Add a dropout layer between layers, except after the last layer.
        if (idx != num_layers - 1) and (dropout_rate is not None):
            layers.append(nn.Dropout(dropout_rate))

    # Return the sequence of layers as an nn.Sequential module.
    return nn.Sequential(*layers)
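
# Illustrative (commented) example of the dimension schedule create_layers builds;
# the numbers assume input_dim=128, latent_dim=32, num_layers=3, compression_rate=0.5.
#
#encoder = create_layers("encoder", "linear", 128, 32, 3, 0.1, 0.5, False)
# -> Linear(128, 64), Dropout, Linear(64, 32), Dropout, Linear(32, 32)
#
#decoder = create_layers("decoder", "linear", 128, 32, 3, 0.1, 0.5, False)
# -> Linear(32, 32), Dropout, Linear(32, 64), Dropout, Linear(64, 128)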

##########################################################################################
##################################### Model ##############################################
##########################################################################################

class AutoEncoder(PreTrainedModel):
    """
    AutoEncoder model for creating an encoder-decoder architecture.
    
    Inherits from PreTrainedModel to utilize its pretrained model features from the Hugging Face library.
    
    Args:
        config (AutoEncoderConfig): The configuration instance with all model parameters.
    """
    config_class = AutoEncoderConfig
    
    def __init__(self, config: AutoEncoderConfig):
        super().__init__(config)

        # Embeddings
        if config.embed:
            # Word Embeddings
            self.word_embeddings = nn.Embedding(config.vocab_size,
                                                config.input_dim,
                                                config.pad_token_id)
            # Positional Embeddings
            self.position_embeddings = nn.Embedding(config.max_position,
                                                    config.input_dim)
        # Encoder
        self.encoder = create_layers("encoder", 
                                     config.layer_types, 
                                     config.input_dim, 
                                     config.latent_dim,
                                     config.num_layers, 
                                     config.dropout_rate,
                                     config.compression_rate, 
                                     config.bidirectional,)
        # Decoder (assumed symmetric with the encoder). When embeddings are used,
        # the final layer projects back onto the vocabulary.
        self.decoder = create_layers("decoder",
                                     config.layer_types,
                                     config.input_dim,
                                     config.latent_dim,
                                     config.num_layers,
                                     config.dropout_rate,
                                     config.compression_rate,
                                     config.bidirectional,
                                     config.vocab_size if config.embed else None)


    def forward(self, input_ids: Tensor, position_ids: Optional[Tensor] = None, labels: Optional[Tensor] = None) -> AutoencoderModelOutput:
        """
        Runs the full encode-decode pass and computes a reconstruction loss.

        Args:
            input_ids (Tensor): Input features, or token ids when `config.embed` is True.
            position_ids (Tensor, optional): Positions for the positional embeddings.
                Defaults to 0..seq_length-1.
            labels (Tensor, optional): Targets for the loss. Defaults to `input_ids`.

        Returns:
            AutoencoderModelOutput with logits, labels, hidden_state, and loss.
        """

        # Define Data Class
        outputs = AutoencoderModelOutput()

        outputs.labels = labels if labels is not None else input_ids
        
        # Embeddings
        if self.config.embed:
            # Word Embeddings
            input_embeddings = self.word_embeddings(input_ids)
            
            # Positional Embeddings
            seq_length = input_ids.size(1)
            if position_ids is None:
                position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
                position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
            position_embeddings = self.position_embeddings(position_ids)
            
            # Combine Embeddings
            input_ids = input_embeddings + position_embeddings

        # Non-Linear Encoding & Decoding
        if self.config.layer_types in ['lstm', 'rnn', 'gru']:
            # Encoding
            for layer in self.encoder:
                if isinstance(layer, (nn.LSTM, nn.GRU, nn.RNN)):
                    input_ids, _ = layer(input_ids)  # Discard the hidden state(s).
                else:
                    input_ids = layer(input_ids)  # e.g. Dropout.
            # Hidden Vector
            outputs.hidden_state = input_ids
            # Decoding
            for layer in self.decoder:
                if isinstance(layer, (nn.LSTM, nn.GRU, nn.RNN)):
                    input_ids, _ = layer(input_ids)  # Discard the hidden state(s).
                else:
                    input_ids = layer(input_ids)  # e.g. Dropout.

        # Linear Encoding & Decoding
        else:
            # Encoding
            input_ids = self.encoder(input_ids)
            # Hidden Vector
            outputs.hidden_state = input_ids
            # Decoding
            input_ids = self.decoder(input_ids)
        
        outputs.logits = input_ids
        
        # Choose the loss function based on the label dtype: mean-squared error
        # for continuous targets, cross-entropy for integer token targets.
        if torch.is_floating_point(outputs.labels):
            loss_fn = nn.MSELoss()
            outputs.loss = loss_fn(outputs.logits.view(-1), outputs.labels.view(-1))
        elif not torch.is_complex(outputs.labels):
            loss_fn = nn.CrossEntropyLoss()
            outputs.loss = loss_fn(outputs.logits.reshape(-1, self.config.vocab_size), outputs.labels.view(-1))
        else:
            raise ValueError("Unsupported label dtype for the available loss functions")

        return outputs
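
# Illustrative (commented) end-to-end usage; shapes and hyperparameters are assumptions.
#
# Continuous reconstruction with linear layers:
#model = AutoEncoder(AutoEncoderConfig(input_dim=128, latent_dim=32))
#out = model(torch.randn(8, 128))
#out.loss.backward()  # MSE against the inputs, since no labels were passed
#
# Token reconstruction with embeddings and an LSTM decoder over the vocabulary:
#config = AutoEncoderConfig(input_dim=256, latent_dim=64, layer_types="lstm",
#                           embed=True, vocab_size=10_000, max_position=512)
#model = AutoEncoder(config)
#out = model(torch.randint(0, 10_000, (8, 32)))  # cross-entropy loss over the vocab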