submission / tasks /text.py
pierre-loic's picture
update content with the text model from Thomas repository https://huggingface.co/spaces/tombou/frugal-ai-challenge
42b7ac6
from fastapi import APIRouter
from datetime import datetime
from datasets import load_dataset
from sklearn.metrics import accuracy_score
from .data.data_loaders import TextDataLoader
from .models.text_classifiers import BaselineModel
from .utils.evaluation import TextEvaluationRequest
from .utils.emissions import get_tracker, clean_emissions_data, get_space_info, EmissionsData
# define models
from .models.text_classifiers import ModelFactory
embedding_ml_model = ModelFactory.create_model({"model_type": "embeddingML"})
distilbert_model = ModelFactory.create_model({"model_type":
"distilbert-pretrained",
"model_name":
"2025-01-27_17-00-47_DistilBERT_Model_fined-tuned_from_distilbert-base-uncased"
})
model_to_evaluate = distilbert_model
# define router
router = APIRouter()
DESCRIPTION = model_to_evaluate.description
ROUTE = "/text"
@router.post(ROUTE, tags=["Text Task"],
description=DESCRIPTION)
async def evaluate_text(request: TextEvaluationRequest,
track_emissions: bool = True,
model = distilbert_model,
light_dataset: bool = False) -> dict:
"""
Evaluate text classification for climate disinformation detection.
Parameters:
-----------
request: TextEvaluationRequest
The request object containing the dataset configuration.
track_emissions: bool
Whether to track emissions or not.
model: TextClassifier
The model to use for inference.
light_dataset: bool
Whether to use a light dataset or not.
Returns:
--------
dict
A dictionary containing the evaluation results.
"""
# Get space info
username, space_url = get_space_info()
# Load the dataset
test_dataset = TextDataLoader(request, light=light_dataset).get_test_dataset()
# Start tracking emissions
if track_emissions:
tracker = get_tracker()
tracker.start()
tracker.start_task("inference")
# model inference
predictions = [model.predict(quote) for quote in test_dataset["quote"]]
# Stop tracking emissions
if track_emissions:
emissions_data = tracker.stop_task()
else:
emissions_data = EmissionsData(0, 0)
# Calculate accuracy
true_labels = test_dataset["label"]
accuracy = accuracy_score(true_labels, predictions)
# Prepare results dictionary
results = {
"username": username,
"space_url": space_url,
"submission_timestamp": datetime.now().isoformat(),
"model_description": DESCRIPTION,
"accuracy": float(accuracy),
"energy_consumed_wh": emissions_data.energy_consumed * 1000,
"emissions_gco2eq": emissions_data.emissions * 1000,
"emissions_data": clean_emissions_data(emissions_data),
"api_route": ROUTE,
"dataset_config": {
"dataset_name": request.dataset_name,
"test_size": request.test_size,
"test_seed": request.test_seed
}
}
return results