from fastapi import APIRouter
from datetime import datetime
from datasets import load_dataset
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score
import random

from .utils.evaluation import TextEvaluationRequest
from .utils.emissions import tracker, clean_emissions_data, get_space_info

from transformers import AutoTokenizer,BertForSequenceClassification,AutoModelForSequenceClassification,Trainer, TrainingArguments,DataCollatorWithPadding
from datasets import Dataset
import torch
import numpy as np


# Model name shown as the endpoint description and reported in the results.
DESCRIPTION = "modernBERT_final_original"
# Path of the text-task evaluation endpoint registered on ``router``.
ROUTE = "/text"

router = APIRouter()

def _predict_labels(test_dataset, batch_size=16):
    """Run the fine-tuned ModernBERT classifier over ``test_dataset``.

    Args:
        test_dataset: Hugging Face ``Dataset`` with a ``"quote"`` text column.
        batch_size: Number of examples per forward pass.

    Returns:
        List of predicted class ids, in the same order as ``test_dataset``.
    """
    path_model = "MatthiasPicard/modernBERT_final_original"
    # The fine-tuned checkpoint is paired with the base ModernBERT tokenizer.
    path_tokenizer = "answerdotai/ModernBERT-base"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = AutoModelForSequenceClassification.from_pretrained(path_model)
    # Use half precision only on GPU: many CPU kernels have no float16
    # implementation, so an fp16 model on CPU can fail at inference time.
    if device.type == "cuda":
        model = model.half()
    model.to(device)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(path_tokenizer)

    def preprocess_function(batch):
        # Truncate only. Padding is done per batch by the collator, so each
        # batch is padded to its own longest sequence.
        return tokenizer(batch["quote"], truncation=True)

    tokenized_test = test_dataset.map(preprocess_function, batched=True)
    tokenized_test.set_format(type="torch", columns=["input_ids", "attention_mask"])
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    # No shuffling: predictions must stay aligned with test_dataset["label"].
    test_loader = DataLoader(tokenized_test, batch_size=batch_size, collate_fn=data_collator)

    predictions = []
    with torch.no_grad():
        for batch in test_loader:
            logits = model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
            ).logits
            # Move the results back to the CPU for accumulation.
            predictions.extend(torch.argmax(logits, dim=-1).cpu().tolist())
    return predictions


@router.post(ROUTE, tags=["Text Task"],
             description=DESCRIPTION)
async def evaluate_text(request: TextEvaluationRequest):
    """
    Evaluate text classification for climate disinformation detection.

    Current Model: fine-tuned ModernBERT (MatthiasPicard/modernBERT_final_original)
    - Predicts one of the 8 label ids (0-7) in LABEL_MAPPING for each quote
      of the ``test`` split of ``request.dataset_name``.
    - Model loading, tokenization and inference run inside the tracked
      "inference" task, so they count toward the reported energy/emissions.

    Args:
        request: Evaluation request. ``dataset_name`` selects the dataset;
            ``test_size`` and ``test_seed`` are only echoed back in the
            results. The dataset's predefined ``test`` split is used as-is.

    Returns:
        dict with submission metadata, accuracy, energy (Wh), emissions
        (gCO2eq), the cleaned emissions record and the dataset config.
    """
    # Get space info
    username, space_url = get_space_info()

    # Map the dataset's string labels to integer class ids.
    LABEL_MAPPING = {
        "0_not_relevant": 0,
        "1_not_happening": 1,
        "2_not_human": 2,
        "3_not_bad": 3,
        "4_solutions_harmful_unnecessary": 4,
        "5_science_unreliable": 5,
        "6_proponents_biased": 6,
        "7_fossil_fuels_needed": 7
    }

    # Load the dataset and convert labels for the test split only; the other
    # splits are not used by this endpoint.
    dataset = load_dataset(request.dataset_name)
    test_dataset = dataset["test"].map(lambda x: {"label": LABEL_MAPPING[x["label"]]})

    # Start tracking emissions
    tracker.start()
    tracker.start_task("inference")

    #--------------------------------------------------------------------------------------------
    # YOUR MODEL INFERENCE CODE HERE
    #--------------------------------------------------------------------------------------------

    true_labels = test_dataset["label"]
    try:
        predictions = _predict_labels(test_dataset)
    except Exception:
        # Stop the tracked task before propagating the error, so the task is
        # not left open across requests.
        tracker.stop_task()
        raise

    #--------------------------------------------------------------------------------------------
    # YOUR MODEL INFERENCE STOPS HERE
    #--------------------------------------------------------------------------------------------

    # Stop tracking emissions
    emissions_data = tracker.stop_task()

    # Calculate accuracy
    accuracy = accuracy_score(true_labels, predictions)

    # Prepare results dictionary
    results = {
        "username": username,
        "space_url": space_url,
        "submission_timestamp": datetime.now().isoformat(),
        "model_description": DESCRIPTION,
        "accuracy": float(accuracy),
        "energy_consumed_wh": emissions_data.energy_consumed * 1000,
        "emissions_gco2eq": emissions_data.emissions * 1000,
        "emissions_data": clean_emissions_data(emissions_data),
        "api_route": ROUTE,
        "dataset_config": {
            "dataset_name": request.dataset_name,
            "test_size": request.test_size,
            "test_seed": request.test_seed
        }
    }

    return results