from fastapi import APIRouter
from datetime import datetime
from datasets import load_dataset
from sklearn.metrics import accuracy_score
import pandas as pd
import numpy as np
import pickle
import xgboost

import os

from .utils.evaluation import AudioEvaluationRequest
from .utils.emissions import tracker, clean_emissions_data, get_space_info
from .utils.preprocess import resample_audio, create_mel_spectrogram

from dotenv import load_dotenv
load_dotenv()
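# load_dotenv() reads HF_TOKEN (needed below to download the gated dataset) from a local .env file.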

router = APIRouter()

DESCRIPTION = "XGBoost on flattened mel-spectrogram features"
ROUTE = "/audio"



@router.post(ROUTE, tags=["Audio Task"],
             description=DESCRIPTION)
async def evaluate_audio(request: AudioEvaluationRequest):
    """
    Evaluate audio classification for rainforest sound detection.
    
    Current Model: Random Baseline
    - Makes random predictions from the label space (0-1)
    - Used as a baseline for comparison
    """
    # Get space info
    username, space_url = get_space_info()
    
    # Define the label mapping
    LABEL_MAPPING = {
        "chainsaw": 0,
        "environment": 1
    }
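    # LABEL_MAPPING is not used by the inference path below; it documents the
    # encoding the dataset's `label` column is assumed to follow.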
    # Load and prepare the dataset
    # Because the dataset is gated, we need to use the HF_TOKEN environment variable to authenticate
    dataset = load_dataset(request.dataset_name, token=os.getenv("HF_TOKEN"))
    
    # Select the test split; the train split is not needed at inference time
    test = dataset["test"]

    # Preprocess: resample every clip to a common sampling rate
    target_sr = 12000
    test_df = pd.DataFrame(test)
    test_df["array"] = test_df["audio"].apply(lambda x: x['array'])
    test_df["sampling_rate"] = test_df["audio"].apply(lambda x: x['sampling_rate'])   
    test_df["resampled_array"] = test_df.apply(
        lambda row: resample_audio(row["array"], row["sampling_rate"], target_sr=target_sr), axis=1
    )
    test_df["sampling_rate"] = target_sr

    # Compute a mel spectrogram for each resampled clip
    features = []
    for idx, row in test_df.iterrows():
        features.append(create_mel_spectrogram(row["resampled_array"], row["sampling_rate"]))

    # Store the 2D spectrogram arrays in a DataFrame column
    test_df["basic_melspect"] = features
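    # A minimal sketch of what create_mel_spectrogram is assumed to do
    # (hypothetical; n_mels and the dB conversion are guesses, the real helper
    # lives in .utils.preprocess):
    #
    #   import librosa
    #   def create_mel_spectrogram(audio, sr, n_mels=64):
    #       mel = librosa.feature.melspectrogram(
    #           y=np.asarray(audio, dtype=np.float32), sr=sr, n_mels=n_mels)
    #       return librosa.power_to_db(mel, ref=np.max)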

    # Keep only samples whose mel spectrogram has the expected width (71 frames),
    # so every flattened feature vector has the same length
    test_df["shape"] = test_df["basic_melspect"].apply(lambda x: x.shape[1])
    test_df = test_df[test_df["shape"] == 71]
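    # Assuming the spectrogram helper uses librosa defaults (hop_length=512,
    # center=True) and the clips are 3-second recordings, the expected width is
    # 1 + floor(3 * 12000 / 512) = 71 frames. Clips of any other duration are
    # dropped here, so the accuracy below is computed only on retained samples.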

    
    # Start tracking emissions
    tracker.start()
    tracker.start_task("inference")
    
    #--------------------------------------------------------------------------------------------
    # MODEL INFERENCE STARTS HERE
    # Everything in this block runs inside the tracked "inference" task, so its
    # energy consumption and emissions are attributed to the model.
    #--------------------------------------------------------------------------------------------
    
    # Load the trained XGBoost model
    with open("./train_models/xgboost_audio_model.pkl", "rb") as f:
        loaded_model = pickle.load(f)
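    # Assumption: the pickled object is an xgboost.Booster trained with the
    # "binary:logistic" objective, so predict() on a DMatrix returns P(label == 1).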

    # Flatten Mel Spectrograms into 1D Features
    test_df["flattened_mel"] = test_df["basic_melspect"].apply(lambda x: x.flatten())

    # Convert to NumPy arrays
    X = np.stack(test_df["flattened_mel"].values)  # Features
    y = test_df["label"].values  # Labels (0: chainsaw, 1: environment)

    dtest = xgboost.DMatrix(X, label=y)
    # Make Predictions
    y_pred_probs = loaded_model.predict(dtest)
    y_pred = (y_pred_probs > 0.5).astype(int)  # Convert probabilities to binary labels
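    # For reference, a hypothetical sketch of how a compatible model could have
    # been trained and pickled (assumed; the actual training script is not shown):
    #
    #   dtrain = xgboost.DMatrix(X_train, label=y_train)
    #   params = {"objective": "binary:logistic", "eval_metric": "logloss"}
    #   booster = xgboost.train(params, dtrain, num_boost_round=100)
    #   with open("./train_models/xgboost_audio_model.pkl", "wb") as f:
    #       pickle.dump(booster, f)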

    #--------------------------------------------------------------------------------------------
    # MODEL INFERENCE ENDS HERE
    #--------------------------------------------------------------------------------------------
    
    # Stop tracking emissions
    emissions_data = tracker.stop_task()
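    # CodeCarbon reports energy_consumed in kWh and emissions in kg CO2eq,
    # hence the x1000 conversions to Wh and gCO2eq in the results below.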
    
    # Calculate accuracy
    accuracy = accuracy_score(y, y_pred)
    
    # Prepare results dictionary
    results = {
        "username": username,
        "space_url": space_url,
        "submission_timestamp": datetime.now().isoformat(),
        "model_description": DESCRIPTION,
        "accuracy": float(accuracy),
        "energy_consumed_wh": emissions_data.energy_consumed * 1000,
        "emissions_gco2eq": emissions_data.emissions * 1000,
        "emissions_data": clean_emissions_data(emissions_data),
        "api_route": ROUTE,
        "dataset_config": {
            "dataset_name": request.dataset_name,
            "test_size": request.test_size,
            "test_seed": request.test_seed
        }
    }
    
    return results