import torch
import torch.nn as nn
from .Modules.conformer import ConformerEncoder, ConformerDecoder
from .Modules.mhsa_pro import RotaryEmbedding, ContinuousRotaryEmbedding
from .kan.fasterkan import FasterKAN
import numpy as np
import xgboost as xgb
import pandas as pd



class Sine(nn.Module):
    """Element-wise sine activation with a fixed frequency factor.

    Computes ``sin(w0 * x)``. ``w0`` is stored as a plain attribute, so it
    is not a learnable parameter.
    """

    def __init__(self, w0=1.0):
        """Store the frequency factor ``w0`` applied before the sine."""
        super().__init__()
        self.w0 = w0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``sin(w0 * x)``, with the same shape as ``x``."""
        scaled = self.w0 * x
        return scaled.sin()


class MLPEncoder(nn.Module):
    """Stack of ``Linear -> BatchNorm1d -> SiLU [-> Dropout]`` stages.

    Attributes:
        model (nn.Sequential): The stacked stages. It is empty, and
            therefore passes its input through unchanged, when
            ``args.hidden_dims`` is empty.
        output_dim (int): Width of the last stage, or ``args.input_dim``
            when there are no stages.
    """

    def __init__(self, args):
        """
        Build the MLP from a configuration object.

        Args:
            args: Configuration object exposing the attributes
                ``input_dim`` (int): dimension of the input features;
                ``hidden_dims`` (iterable of int): width of each hidden
                layer, in order;
                ``dropout`` (float): dropout probability. A Dropout layer
                is appended after each activation only when this is > 0.
        """
        super(MLPEncoder, self).__init__()

        layers = []
        prev_dim = args.input_dim

        # One Linear/BatchNorm/SiLU(/Dropout) stage per requested width.
        for hidden_dim in args.hidden_dims:
            layers.append(nn.Linear(prev_dim, hidden_dim))
            layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.SiLU())
            if args.dropout > 0.0:
                layers.append(nn.Dropout(args.dropout))
            prev_dim = hidden_dim
        self.model = nn.Sequential(*layers)
        # `prev_dim` holds the width after the last stage and falls back to
        # the input width when `hidden_dims` is empty. Reading the loop
        # variable here would fail in that case because it is never bound.
        self.output_dim = prev_dim

    def forward(self, x):
        """Run ``x`` through the stacked stages and return the result."""
        return self.model(x)
class ConvBlock(nn.Module):
    """One ``Conv1d -> BatchNorm1d -> activation`` stage of the CNN encoder.

    The convolution uses stride 1, ``padding='same'`` and no bias.

    Args:
        args: Configuration object. Reads ``activation`` (``'silu'``,
            ``'sine'``; any other value selects ReLU), ``sine_w0`` (only
            when ``activation == 'sine'``), ``encoder_dims`` (sequence of
            channel widths) and ``kernel_size``.
        num_layer (int): Index into ``args.encoder_dims``. While it is
            smaller than ``len(encoder_dims)`` the block maps
            ``encoder_dims[num_layer - 1]`` channels to
            ``encoder_dims[num_layer]``; otherwise both sides use the last
            entry.
    """

    def __init__(self, args, num_layer) -> None:
        super().__init__()

        # Pick the non-linearity; unknown names fall back to ReLU.
        act_name = args.activation
        if act_name == 'sine':
            self.activation = Sine(w0=args.sine_w0)
        elif act_name == 'silu':
            self.activation = nn.SiLU()
        else:
            self.activation = nn.ReLU()

        # Resolve the channel widths for this depth.
        widths = args.encoder_dims
        if num_layer < len(widths):
            in_channels = widths[num_layer - 1]
            out_channels = widths[num_layer]
        else:
            in_channels = out_channels = widths[-1]

        conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=args.kernel_size,
            stride=1,
            padding='same',
            bias=False,
        )
        norm = nn.BatchNorm1d(num_features=out_channels)
        self.layers = nn.Sequential(conv, norm, self.activation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply convolution, batch normalisation and the activation to ``x``."""
        return self.layers(x)

class CNNEncoder(nn.Module):
    """1-D convolutional encoder: a stem convolution followed by ConvBlocks.

    After every ConvBlock the sequence is halved with ``MaxPool1d(2)`` as
    long as its length is still greater than ``min_seq_len``.

    Args:
        args: Configuration object. Reads ``activation`` (``'silu'``,
            ``'sine'``; any other value selects ReLU), ``sine_w0`` (only
            for ``'sine'``), ``in_channels``, ``encoder_dims``,
            ``num_layers`` and ``avg_output``. It is also handed to every
            ``ConvBlock``.

    Attributes:
        output_dim (int): ``args.encoder_dims[-1]``.
        min_seq_len (int): Pooling is skipped once the length is at or
            below this value (fixed at 2).
        avg_output (bool): When true, ``forward`` averages over the last
            dimension before returning.
    """

    def __init__(self, args) -> None:
        super().__init__()
        print("Using CNN encoder wit activation: ", args.activation, 'args avg_output: ', args.avg_output)

        # Pick the non-linearity; unknown names fall back to ReLU.
        act_name = args.activation
        if act_name == 'sine':
            self.activation = Sine(w0=args.sine_w0)
        elif act_name == 'silu':
            self.activation = nn.SiLU()
        else:
            self.activation = nn.ReLU()

        # Stem: lift the raw input channels to the first encoder width.
        stem_width = args.encoder_dims[0]
        stem_conv = nn.Conv1d(
            in_channels=args.in_channels,
            out_channels=stem_width,
            kernel_size=3,
            stride=1,
            padding='same',
            bias=False,
        )
        self.embedding = nn.Sequential(
            stem_conv,
            nn.BatchNorm1d(stem_width),
            self.activation,
        )

        # ConvBlocks are numbered from 1 so that block k reads
        # encoder_dims[k - 1] channels.
        blocks = []
        for depth in range(1, args.num_layers + 1):
            blocks.append(ConvBlock(args, depth))
        self.layers = nn.ModuleList(blocks)

        self.pool = nn.MaxPool1d(2)
        self.output_dim = args.encoder_dims[-1]
        self.min_seq_len = 2
        self.avg_output = args.avg_output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x``.

        A 2-D input gets a channel axis inserted at position 1. A 3-D
        input whose last dimension is 1 has its last two axes swapped
        before the stem.

        Returns:
            The feature map after the last block, averaged over its last
            dimension when ``self.avg_output`` is true.
        """
        if x.ndim == 2:
            x = x.unsqueeze(1)
        if x.ndim == 3 and x.shape[-1] == 1:
            x = x.transpose(1, 2)

        hidden = self.embedding(x)
        for block in self.layers:
            hidden = block(hidden)
            # Stop halving once the sequence is already short.
            if hidden.shape[-1] > self.min_seq_len:
                hidden = self.pool(hidden)

        if self.avg_output:
            return hidden.mean(dim=-1)
        return hidden


class MultiEncoder(nn.Module):
    """CNN backbone followed by a Conformer encoder with rotary embeddings.

    Args:
        args: Configuration for the ``CNNEncoder`` backbone; its
            ``avg_output`` flag is also used to decide how this module
            reduces the Conformer output.
        conformer_args: Configuration for ``ConformerEncoder``. Reads
            ``encoder_dim`` and ``num_heads`` here.

    Attributes:
        output_dim (int): ``conformer_args.encoder_dim``.
        avg_output (bool): Copied from ``args.avg_output``.
    """

    def __init__(self, args, conformer_args):
        super().__init__()
        self.backbone = CNNEncoder(args)
        # The backbone must keep its sequence axis so the Conformer has
        # something to attend over, whatever `args.avg_output` says.
        self.backbone.avg_output = False

        self.head_size = conformer_args.encoder_dim // conformer_args.num_heads
        # Rotary embedding width is half of the per-head width.
        self.rotary_ndims = int(self.head_size * 0.5)
        self.pe = RotaryEmbedding(self.rotary_ndims)
        self.encoder = ConformerEncoder(conformer_args)

        self.output_dim = conformer_args.encoder_dim
        self.avg_output = args.avg_output

    def forward(self, x):
        """Encode ``x``.

        Returns:
            tuple: ``(encoded, backbone_features)`` where
            ``backbone_features`` is the untouched backbone output. A 3-D
            Conformer output is either summed over dim 1
            (``avg_output`` true) or has its last two axes swapped back.
        """
        backbone_features = self.backbone(x)

        # Work on a copy so the returned backbone features stay untouched.
        if backbone_features.ndim == 2:
            tokens = backbone_features.unsqueeze(1).clone()
        else:
            tokens = backbone_features.permute(0, 2, 1).clone()

        rope = self.pe(tokens, tokens.shape[1])
        encoded = self.encoder(tokens, rope)

        if encoded.ndim != 3:
            return encoded, backbone_features

        if self.avg_output:
            # NOTE(review): despite the flag name this is a sum, not a
            # mean, over dim 1 -- confirm this is intended.
            encoded = encoded.sum(dim=1)
        else:
            encoded = encoded.permute(0, 2, 1)
        return encoded, backbone_features

class DualEncoder(nn.Module):
    """Regress one value from a CNN branch and a CNN+Conformer branch.

    Both branches receive the same input. Their outputs are concatenated
    along the last dimension and passed to a small MLP regressor.

    Args:
        args_x: Configuration for the plain ``CNNEncoder`` branch.
        args_f: Configuration for the backbone of the ``MultiEncoder``
            branch.
        conformer_args: Configuration for the Conformer inside the
            ``MultiEncoder`` branch.
    """

    def __init__(self, args_x, args_f, conformer_args) -> None:
        super().__init__()
        self.encoder_x = CNNEncoder(args_x)
        self.encoder_f = MultiEncoder(args_f, conformer_args)
        # Size the regressor from the widths the encoders themselves
        # declare. MultiEncoder reports `conformer_args.encoder_dim`, which
        # is not necessarily `args_f.encoder_dims[-1]`.
        total_output_dim = self.encoder_x.output_dim + self.encoder_f.output_dim
        # NOTE(review): the concatenation in `forward` is along the last
        # dim, so presumably both configs set avg_output=True -- confirm.
        self.regressor = nn.Sequential(
            nn.Linear(total_output_dim, total_output_dim // 2),
            nn.BatchNorm1d(total_output_dim // 2),
            nn.SiLU(),
            nn.Linear(total_output_dim // 2, 1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one prediction per sample.

        Only the trailing size-1 regressor dimension is removed, so a
        batch of one still yields a 1-element vector and not a 0-d scalar.
        """
        x1 = self.encoder_x(x)
        x2, _ = self.encoder_f(x)
        features = torch.cat([x1, x2], dim=-1)
        return self.regressor(features).squeeze(-1)

class CNNFeaturesEncoder(nn.Module):
    """Fuse CNN features with XGBoost leaf indices and regress one value.

    Args:
        xgb_model: Trained model exposing ``best_iteration`` and
            ``predict(DMatrix, pred_leaf=True)`` (presumably an
            ``xgboost.Booster`` -- confirm against the caller).
        args: Configuration for the ``CNNEncoder`` backbone; also reads
            ``encoder_dims[-1]`` to size the MLP input.
        mlp_hidden (int): Width of the two hidden MLP layers.
    """

    def __init__(self, xgb_model, args, mlp_hidden=64):
        super().__init__()
        self.xgb_model = xgb_model
        # Number of leaf-index columns kept from the XGBoost prediction.
        self.best_xgb_features = xgb_model.best_iteration + 1
        self.backbone = CNNEncoder(args)
        self.total_features = self.best_xgb_features + args.encoder_dims[-1]

        head = [
            nn.Linear(self.total_features, mlp_hidden),
            nn.BatchNorm1d(mlp_hidden),
            nn.SiLU(),
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.BatchNorm1d(mlp_hidden),
            nn.SiLU(),
            nn.Linear(mlp_hidden, 1),
        ]
        self.mlp = nn.Sequential(*head)

    def _create_features_data(self, features):
        """Turn a batch of per-sample feature dicts into a DataFrame.

        Each element of ``features`` is a mapping ``name -> value`` where
        ``value[0].item()`` yields a scalar. Column names are prefixed
        with ``frequency_domain_``. One row is produced per element.
        """
        rows = [
            {f"frequency_domain_{name}": value[0].item() for name, value in sample.items()}
            for sample in features
        ]
        return pd.DataFrame(rows)

    def forward(self, x: torch.Tensor, f) -> torch.Tensor:
        """Predict from the signal ``x`` and the per-sample feature dicts ``f``.

        The backbone output is averaged over its last dimension and then
        concatenated on dim 1 with the first ``best_xgb_features``
        leaf-index columns before the MLP.
        """
        # NOTE(review): the later concat on dim=1 needs a 2-D tensor here,
        # so the backbone presumably runs with avg_output=False -- confirm.
        pooled = self.backbone(x).mean(dim=-1)

        frame = self._create_features_data(f)
        dmatrix = xgb.DMatrix(frame)  # XGBoost's input container
        leaves = self.xgb_model.predict(dmatrix, pred_leaf=True).astype(np.float32)
        leaf_tensor = torch.tensor(leaves, dtype=torch.float32, device=pooled.device)

        fused = torch.cat([pooled, leaf_tensor[:, :self.best_xgb_features]], dim=1)
        return self.mlp(fused)

class CNNKan(nn.Module):
    """CNN backbone whose pooled output feeds a FasterKAN head.

    Args:
        args: Configuration for the ``CNNEncoder`` backbone.
        conformer_args: Accepted but not used by this class.
        kan_args (dict): Keyword arguments forwarded to ``FasterKAN``.
    """

    def __init__(self, args, conformer_args, kan_args):
        super().__init__()
        self.backbone = CNNEncoder(args)
        self.kan = FasterKAN(**kan_args)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x``, average the result over dim 1 and apply the KAN head."""
        # NOTE(review): this averages over dim 1, whereas
        # CNNFeaturesEncoder averages over the last dim -- confirm which
        # axis is intended.
        pooled = self.backbone(x).mean(dim=1)
        return self.kan(pooled)

class CNNKanFeaturesEncoder(nn.Module):
    """Fuse CNN features with XGBoost leaf indices and feed a FasterKAN head.

    Args:
        xgb_model: Trained model exposing ``best_iteration`` and
            ``predict(DMatrix, pred_leaf=True)`` (presumably an
            ``xgboost.Booster`` -- confirm against the caller).
        args: Configuration for the ``CNNEncoder`` backbone.
        kan_args (dict): Keyword arguments for ``FasterKAN``. The first
            entry of ``kan_args['layers_hidden']`` is the KAN input width
            *without* the XGBoost columns. This constructor widens it on a
            private copy and leaves the caller's dict untouched.
    """

    def __init__(self, xgb_model, args,  kan_args):
        super().__init__()
        self.xgb_model = xgb_model
        # Number of leaf-index columns kept from the XGBoost prediction.
        self.best_xgb_features = xgb_model.best_iteration + 1
        self.backbone = CNNEncoder(args)

        # Widen the KAN input on copies. Mutating the caller's dict/list in
        # place would widen it again every time the same `kan_args` is
        # reused to build another model.
        kan_kwargs = dict(kan_args)
        layers_hidden = list(kan_kwargs['layers_hidden'])
        layers_hidden[0] += self.best_xgb_features
        kan_kwargs['layers_hidden'] = layers_hidden
        self.kan = FasterKAN(**kan_kwargs)

    def _create_features_data(self, features):
        """Turn a batch of per-sample feature dicts into a DataFrame.

        Each element of ``features`` is a mapping ``name -> value`` where
        ``value[0].item()`` yields a scalar. Column names are the
        stringified keys. One row is produced per element.
        """
        rows = [
            {f"{name}": value[0].item() for name, value in sample.items()}
            for sample in features
        ]
        return pd.DataFrame(rows)

    def forward(self, x: torch.Tensor, f) -> torch.Tensor:
        """Predict from the signal ``x`` and the per-sample feature dicts ``f``.

        The backbone output is averaged over dim 1 and then concatenated on
        dim 1 with the first ``best_xgb_features`` leaf-index columns
        before the KAN head.
        """
        x = self.backbone(x)
        # NOTE(review): this averages over dim 1, whereas
        # CNNFeaturesEncoder averages over the last dim -- confirm which
        # axis is intended.
        x = x.mean(dim=1)
        f_np = self._create_features_data(f)
        dtest = xgb.DMatrix(f_np)  # XGBoost's input container
        xgb_features = self.xgb_model.predict(dtest, pred_leaf=True).astype(np.float32)
        xgb_features = torch.tensor(xgb_features, dtype=torch.float32, device=x.device)
        x_f = torch.cat([x, xgb_features[:, :self.best_xgb_features]], dim=1)
        return self.kan(x_f)

class KanEncoder(nn.Module):
    """Two parallel FasterKAN branches merged by a third FasterKAN.

    Args:
        args (dict): Keyword arguments used for *both* branch networks.
            ``args['layers_hidden'][-1]`` (the branch output width) is
            doubled to size the input of the merging network, whose layout
            is ``[2 * width, 8, 8, 1]``.
    """

    def __init__(self, args):
        super().__init__()
        self.kan_x = FasterKAN(**args)
        self.kan_f = FasterKAN(**args)
        branch_width = args['layers_hidden'][-1]
        merged_layout = [branch_width * 2, 8, 8, 1]
        self.kan_out = FasterKAN(layers_hidden=merged_layout)

    def forward(self, x: torch.Tensor, f: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and ``f`` separately, concatenate on the last dim, merge."""
        branch_x = self.kan_x(x)
        branch_f = self.kan_f(f)
        merged = torch.cat([branch_x, branch_f], dim=-1)
        return self.kan_out(merged)


class MultiGraph(nn.Module):
    """CNN encoder with a small regression head; holds a graph network too.

    The graph network is registered as a sub-module but is not used in
    ``forward``: only the CNN branch feeds the projection head.

    Args:
        graph_net: Graph network, stored as ``self.graph_net``.
        args: Configuration for the ``CNNEncoder``; ``encoder_dims[-1]``
            sizes the projection head.
    """

    def __init__(self, graph_net, args):
        super().__init__()
        self.graph_net = graph_net
        self.cnn = CNNEncoder(args)

        width = args.encoder_dims[-1]
        half_width = width // 2
        self.projection = nn.Sequential(
            nn.Linear(width, half_width),
            nn.BatchNorm1d(half_width),
            nn.SiLU(),
            nn.Linear(half_width, 1),
        )

    def forward(self, g: torch.Tensor, x:torch.Tensor) -> torch.Tensor:
        """Project the CNN encoding of ``x``; ``g`` is accepted but ignored."""
        # The graph branch is disabled: `g` never reaches `self.graph_net`.
        cnn_features = self.cnn(x)
        return self.projection(cnn_features)

class ImplicitEncoder(nn.Module):
    """Encode an input together with the parameters of a transform network.

    Args:
        transform_net: Module applied to the input after swapping its last
            two axes; its ``state_dict`` supplies the weights and biases.
        encoder_net: Callable invoked as
            ``encoder_net((weights, biases), transformed_x)``.
    """

    def __init__(self, transform_net, encoder_net):
        super().__init__()
        self.transform_net = transform_net
        self.encoder_net = encoder_net

    def get_weights_and_bises(self):
        """Collect reshaped weights and biases of ``transform_net``.

        The method name (including its spelling) is kept as-is for
        callers.

        Every ``state_dict`` entry whose key contains ``"weight"`` is
        transposed with ``permute(1, 0)`` and given a trailing and a
        leading singleton axis. Every entry whose key contains ``"bias"``
        gets the same two singleton axes.

        Returns:
            tuple: ``(weights, biases)``, each a tuple of tensors in
            ``state_dict`` order.
        """
        # NOTE(review): state_dict() is called with its defaults, which
        # per the PyTorch docs return tensors detached from autograd --
        # confirm that no gradient is expected to flow through these.
        weights = []
        biases = []
        for key, tensor in self.transform_net.state_dict().items():
            if "weight" in key:
                weights.append(tensor.permute(1, 0).unsqueeze(-1).unsqueeze(0))
            if "bias" in key:
                biases.append(tensor.unsqueeze(-1).unsqueeze(0))
        return tuple(weights), tuple(biases)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Transform ``x`` and encode it jointly with the transform parameters.

        ``x`` has its last two axes swapped before ``transform_net`` and
        swapped back afterwards.
        """
        swapped = x.permute(0, 2, 1)
        transformed_x = self.transform_net(swapped).permute(0, 2, 1)
        params = self.get_weights_and_bises()
        return self.encoder_net(params, transformed_x)