import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import ASTConfig, ASTFeatureExtractor, ASTModel

# Default sampling rate (Hz) that birdast_seq_preprocess hands to the extractor.
DEFAULT_SR = 16_000
# Default pretrained AST backbone (Hugging Face model id).
DEFAULT_BACKBONE = "MIT/ast-finetuned-audioset-10-10-0.4593"
# Default classifier-head settings used by birdast_seq_inference.
DEFAULT_N_CLASSES = 728
DEFAULT_ACTIVATION = "silu"
DEFAULT_N_MLP_LAYERS = 1

# Shared feature extractor: constructed once at import time with the library's
# default settings and reused by every birdast_seq_preprocess call.
BirdAST_FEATURE_EXTRACTOR = ASTFeatureExtractor()


def birdast_seq_preprocess(audio_array, sr=DEFAULT_SR):
    """
    Preprocess a raw audio array into a BirdAST input spectrogram.

    audio_array: np.array, mono audio samples, shape (n_samples,)
    sr: int, sampling rate of audio_array (default: 16_000)

    Returns:
    spectrogram: torch.Tensor, the feature extractor's ``input_values`` with the
        leading batch dimension squeezed away, i.e. (n_frames, n_mels)
        (presumably (1024, 128) with the default ASTFeatureExtractor - TODO confirm).
        birdast_seq_inference accepts this 2D form directly.

    Note:
    1. The audio array should be normalized to [-1, 1].
    2. The audio length should be 10 seconds (or 10.24 seconds). Longer audio will be truncated.
    """
    features = BirdAST_FEATURE_EXTRACTOR(
        audio_array, sampling_rate=sr, padding="max_length", return_tensors="pt"
    )

    # With return_tensors="pt" the values are already a tensor. torch.tensor()
    # would copy it and warn about copy-constructing from a tensor;
    # torch.as_tensor reuses it as-is while still accepting array/list input.
    spectrogram = torch.as_tensor(features['input_values']).squeeze(0)

    return spectrogram


def birdast_seq_inference(
    model_weights,
    spectrogram,
    device = 'cpu',
    backbone_name=DEFAULT_BACKBONE,
    n_classes=DEFAULT_N_CLASSES,
    activation=DEFAULT_ACTIVATION,
    n_mlp_layers=DEFAULT_N_MLP_LAYERS
    ):
    """
    Run an ensemble of BirdAST_Seq checkpoints on a spectrogram and collect
    per-model class probabilities.

    model_weights: list, non-empty list of state-dict checkpoints (anything
        torch.load accepts, e.g. file paths); one forward pass is run per entry
    spectrogram: torch.Tensor, shape (batch_size, n_frames, n_mels); a 2D
        (n_frames, n_mels) tensor is treated as a batch of one
    device: str or torch.device, device to run inference on (default: 'cpu')
    backbone_name: str, name of the backbone model (default: 'MIT/ast-finetuned-audioset-10-10-0.4593')
    n_classes: int, number of classes (default: 728)
    activation: str, activation function (default: 'silu')
    n_mlp_layers: int, number of MLP layers (default: 1)

    Returns:
    predictions: np.ndarray, softmax probabilities of all models concatenated
        along the first axis, shape (n_models * batch_size, n_classes).
        Rows [i * batch_size:(i + 1) * batch_size] come from model_weights[i];
        for a batch of one this is simply (n_models, n_classes).
    """
    model = BirdAST_Seq(
        backbone_name=backbone_name,
        n_classes=n_classes,
        n_mlp_layers=n_mlp_layers,
        activation=activation
    )
    # .to() is a no-op when already on `device`, so no CPU special case is
    # needed; eval() persists across load_state_dict, so set it once.
    model.to(device)
    model.eval()

    # The input is identical for every checkpoint: move/reshape it once.
    spectrogram = spectrogram.to(device)
    if spectrogram.dim() == 2:
        spectrogram = spectrogram.unsqueeze(0)  # -> (batch_size, n_frames, n_mels)

    predict_collects = []
    with torch.no_grad():
        for weight in model_weights:
            # NOTE(review): torch.load unpickles the file; only load trusted
            # checkpoints (consider weights_only=True where supported).
            state_dict = torch.load(weight, map_location=device)
            model.load_state_dict(state_dict)

            logits = model(spectrogram)['logits']
            # Bring each result back to CPU so the final .numpy() works.
            predict_collects.append(F.softmax(logits, dim=1).cpu())

    return torch.cat(predict_collects, dim=0).numpy()

      
class SelfAttentionPooling(nn.Module):
    """
    Attention-weighted pooling over the sequence axis.

    A single linear layer scores every time step, the scores are
    softmax-normalised across the sequence, and the output is the resulting
    weighted sum of the frame vectors.

    Reference: "Self-Attention Encoding and Pooling for Speaker Recognition"
    https://arxiv.org/pdf/2008.01077v1.pdf
    """

    def __init__(self, input_dim):
        super().__init__()
        # Attribute names form the state-dict keys ("W.weight", "W.bias"),
        # so they must stay as they are for saved checkpoints to load.
        self.W = nn.Linear(input_dim, 1)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, batch_rep):
        """
        batch_rep: tensor of size (N, T, H) - batch size, sequence length,
            hidden dimension.

        Returns the pooled representation of size (N, H).
        """
        scores = self.W(batch_rep).squeeze(-1)        # (N, T)
        weights = self.softmax(scores).unsqueeze(-1)  # (N, T, 1), sums to 1 over T
        weighted_frames = batch_rep * weights         # (N, T, H)
        return weighted_frames.sum(dim=1)


class BirdAST_Seq(nn.Module):
    """
    BirdAST sequence classifier: AST backbone -> self-attention pooling over
    the backbone's last hidden states -> MLP head producing class logits.

    backbone_name: str, Hugging Face name/path of the pretrained AST backbone
    n_classes: int, size of the output layer
    n_mlp_layers: int, number of hidden (Linear + activation) layers placed
        before the output Linear layer (default: 1)
    activation: str, one of 'relu', 'silu', 'gelu' (default: 'silu')

    Raises:
    ValueError: if `activation` is not one of the supported names.
    """

    # Supported activation names -> module constructors.
    _ACTIVATIONS = {
        'relu': nn.ReLU,
        'silu': nn.SiLU,
        'gelu': nn.GELU,
    }

    def __init__(self, backbone_name, n_classes, n_mlp_layers=1, activation='silu'):
        super().__init__()

        # Validate the activation *before* loading the backbone, so a bad name
        # fails immediately instead of after a (possibly remote) model load.
        try:
            activation_cls = self._ACTIVATIONS[activation]
        except (KeyError, TypeError):
            raise ValueError("Unsupported activation function. Choose 'relu', 'silu' or 'gelu'") from None

        # pre-trained backbone
        backbone_config = ASTConfig.from_pretrained(backbone_name)
        self.ast = ASTModel.from_pretrained(backbone_name, config=backbone_config)
        self.hidden_size = backbone_config.hidden_size

        # Registered after the backbone to keep the original submodule order.
        self.activation = activation_cls()

        # self-attention pooling over the backbone's sequence axis
        self.sa_pool = SelfAttentionPooling(self.hidden_size)

        # MLP head. The activation module holds no parameters, so sharing one
        # instance between layers does not affect state-dict keys.
        layers = []
        for _ in range(n_mlp_layers):
            layers.append(nn.Linear(self.hidden_size, self.hidden_size))
            layers.append(self.activation)
        layers.append(nn.Linear(self.hidden_size, n_classes))
        self.mlp = nn.Sequential(*layers)

    def forward(self, spectrogram):
        """
        spectrogram: torch.Tensor, (batch_size, n_frames, n_mels) - the
            (batch, max_length, num_mel_bins) layout of AST `input_values`.

        Returns:
        dict with key 'logits': tensor of shape (batch_size, n_classes).
        """
        ast_output = self.ast(spectrogram, output_hidden_states=False)
        hidden_state = ast_output.last_hidden_state
        pool_output = self.sa_pool(hidden_state)
        logits = self.mlp(pool_output)

        return {'logits': logits}
    
    
    
if __name__ == '__main__':

    import numpy as np
    import matplotlib.pyplot as plt

    # Example usage of BirdAST_Seq.
    # Random 10-second clip at DEFAULT_SR, drawn uniformly from [-1, 1] to match
    # the normalization and length birdast_seq_preprocess expects.
    audio_array = np.random.uniform(-1.0, 1.0, size=DEFAULT_SR * 10).astype(np.float32)

    # Preprocess audio array -> (n_frames, n_mels)
    spectrogram = birdast_seq_preprocess(audio_array)

    model_weights_dir = '/workspace/voice_of_jungle/training_logs'

    # Checkpoint paths of the 5 cross-validation folds
    model_weights = [f'{model_weights_dir}/BirdAST_SeqPool_GroupKFold_fold_{i}.pth' for i in range(5)]

    # Perform inference on a batch of one: birdast_seq_inference concatenates
    # the per-model outputs, so each row is one model's class probabilities.
    predictions = birdast_seq_inference(model_weights, spectrogram.unsqueeze(0))

    # Plot predictions
    fig, ax = plt.subplots()
    for i, pred in enumerate(predictions):
        ax.plot(pred, label=f'model_{i}')
    ax.set_xlabel('class index')
    ax.set_ylabel('probability')
    ax.legend()
    fig.savefig('test_BirdAST_Seq.png')
    plt.close(fig)

    print("Inference completed successfully!")