from fastapi import APIRouter
from datetime import datetime
from datasets import load_dataset
from sklearn.metrics import accuracy_score

from .data.data_loaders import TextDataLoader
from .models.text_classifiers import BaselineModel
from .utils.evaluation import TextEvaluationRequest
from .utils.emissions import get_tracker, clean_emissions_data, get_space_info, EmissionsData

# --- Model construction ---------------------------------------------------
from .models.text_classifiers import ModelFactory

# Lightweight baseline: classical ML model on top of text embeddings.
embedding_ml_model = ModelFactory.create_model({"model_type": "embeddingML"})

# Fine-tuned DistilBERT checkpoint, identified by its training-run name.
distilbert_model = ModelFactory.create_model({
    "model_type": "distilbert-pretrained",
    "model_name": "2025-01-27_17-00-47_DistilBERT_Model_fined-tuned_from_distilbert-base-uncased",
})

# The model whose description is exposed on the endpoint below.
model_to_evaluate = distilbert_model

# --- Router setup ---------------------------------------------------------
ROUTE = "/text"
DESCRIPTION = model_to_evaluate.description
router = APIRouter()

@router.post(ROUTE, tags=["Text Task"], 
             description=DESCRIPTION)
async def evaluate_text(request: TextEvaluationRequest,
                        track_emissions: bool = True,
                        model = distilbert_model,
                        light_dataset: bool = False) -> dict:
    """
    Evaluate text classification for climate disinformation detection.

    Runs the given model over the test split of the requested dataset,
    optionally measuring the energy/emissions cost of the inference phase,
    and returns an evaluation-results payload.

    Parameters:
    -----------
    request: TextEvaluationRequest
        The request object containing the dataset configuration.

    track_emissions: bool
        Whether to track emissions or not (default: True).

    model: TextClassifier
        The model to use for inference (default: the DistilBERT model
        defined at module level).

    light_dataset: bool
        Whether to use a light (reduced-size) dataset or not.

    Returns:
    --------
    dict
        A dictionary containing the evaluation results (accuracy,
        emissions figures, dataset configuration and submitter info).
    """
    # Identify the submitting user / space; both are echoed in the results.
    username, space_url = get_space_info()

    # Load the test split according to the request's dataset configuration.
    test_dataset = TextDataLoader(request, light=light_dataset).get_test_dataset()

    # Start tracking emissions for the inference phase only.
    if track_emissions:
        tracker = get_tracker()
        tracker.start()
        tracker.start_task("inference")

    try:
        # Model inference over every quote in the test set.
        predictions = [model.predict(quote) for quote in test_dataset["quote"]]
    finally:
        # Always stop the tracker task, even if inference raises, so the
        # emissions-tracking task is not left running (resource leak).
        if track_emissions:
            emissions_data = tracker.stop_task()
        else:
            # Zeroed placeholder so the results payload keeps its shape.
            emissions_data = EmissionsData(0, 0)

    # Calculate accuracy against the ground-truth labels.
    true_labels = test_dataset["label"]
    accuracy = accuracy_score(true_labels, predictions)

    # Prepare results dictionary.
    # NOTE: report the description of the model actually used for inference,
    # not the module-level DESCRIPTION (they only coincide for the default).
    results = {
        "username": username,
        "space_url": space_url,
        "submission_timestamp": datetime.now().isoformat(),
        "model_description": model.description,
        "accuracy": float(accuracy),
        # codecarbon reports kWh / kgCO2eq; convert to Wh / gCO2eq.
        "energy_consumed_wh": emissions_data.energy_consumed * 1000,
        "emissions_gco2eq": emissions_data.emissions * 1000,
        "emissions_data": clean_emissions_data(emissions_data),
        "api_route": ROUTE,
        "dataset_config": {
            "dataset_name": request.dataset_name,
            "test_size": request.test_size,
            "test_seed": request.test_seed
        }
    }

    return results