Spaces:
Sleeping
Sleeping

Update content with the text model from Thomas's repository: https://huggingface.co/spaces/tombou/frugal-ai-challenge
42b7ac6
from fastapi import APIRouter | |
from datetime import datetime | |
from datasets import load_dataset | |
from sklearn.metrics import accuracy_score | |
from .data.data_loaders import TextDataLoader | |
from .models.text_classifiers import BaselineModel | |
from .utils.evaluation import TextEvaluationRequest | |
from .utils.emissions import get_tracker, clean_emissions_data, get_space_info, EmissionsData | |
# Model setup: both candidate classifiers are built once, at import time.
from .models.text_classifiers import ModelFactory

# Identifier passed to the factory as ``model_name`` for the DistilBERT model.
_DISTILBERT_MODEL_NAME = (
    "2025-01-27_17-00-47_DistilBERT_Model_fined-tuned_from_distilbert-base-uncased"
)

# Factory configurations, one per model.
_EMBEDDING_ML_CONFIG = {"model_type": "embeddingML"}
_DISTILBERT_CONFIG = {
    "model_type": "distilbert-pretrained",
    "model_name": _DISTILBERT_MODEL_NAME,
}

embedding_ml_model = ModelFactory.create_model(_EMBEDDING_ML_CONFIG)
distilbert_model = ModelFactory.create_model(_DISTILBERT_CONFIG)

# The model whose description is published through DESCRIPTION below.
model_to_evaluate = distilbert_model

# Router setup: route path and the description taken from the selected model.
router = APIRouter()
ROUTE = "/text"
DESCRIPTION = model_to_evaluate.description
async def evaluate_text(request: TextEvaluationRequest,
                        track_emissions: bool = True,
                        model = distilbert_model,
                        light_dataset: bool = False) -> dict:
    """
    Evaluate text classification for climate disinformation detection.

    Loads the test dataset described by ``request``, runs ``model.predict`` on
    every entry of its ``"quote"`` column, scores the predictions against the
    ``"label"`` column and returns the score together with emissions figures
    and submission metadata.

    Parameters:
    -----------
    request: TextEvaluationRequest
        The request object containing the dataset configuration
        (``dataset_name``, ``test_size`` and ``test_seed`` are echoed back
        in the result).
    track_emissions: bool
        Whether to track emissions or not. When False, a zero-valued
        ``EmissionsData(0, 0)`` placeholder is reported instead.
    model: TextClassifier
        The model to use for inference. It must expose ``predict(quote)``;
        its ``description`` attribute, when present, is reported in the
        result.
    light_dataset: bool
        Whether to use a light dataset or not (forwarded to
        ``TextDataLoader`` as ``light``).

    Returns:
    --------
    dict
        A dictionary containing the evaluation results: username, space URL,
        submission timestamp, model description, accuracy, energy and
        emissions figures, the API route and the dataset configuration.
    """
    # Get space info
    username, space_url = get_space_info()

    # Load the dataset
    test_dataset = TextDataLoader(request, light=light_dataset).get_test_dataset()

    # Start tracking emissions
    if track_emissions:
        tracker = get_tracker()
        tracker.start()
        tracker.start_task("inference")

    # Model inference: one prediction per quote, in dataset order
    predictions = [model.predict(quote) for quote in test_dataset["quote"]]

    # Stop tracking emissions
    if track_emissions:
        emissions_data = tracker.stop_task()
    else:
        emissions_data = EmissionsData(0, 0)

    # Calculate accuracy
    true_labels = test_dataset["label"]
    accuracy = accuracy_score(true_labels, predictions)

    # Report the description of the model that was actually evaluated. The
    # module-level DESCRIPTION belongs to ``model_to_evaluate`` and would
    # mislabel the results whenever a different ``model`` is passed in; it is
    # kept only as a fallback for models without a ``description`` attribute.
    model_description = getattr(model, "description", DESCRIPTION)

    # Prepare results dictionary
    results = {
        "username": username,
        "space_url": space_url,
        "submission_timestamp": datetime.now().isoformat(),
        "model_description": model_description,
        "accuracy": float(accuracy),
        # NOTE(review): the * 1000 factors presumably convert kWh -> Wh and
        # kg -> g CO2eq, going by the key names; confirm against the tracker.
        "energy_consumed_wh": emissions_data.energy_consumed * 1000,
        "emissions_gco2eq": emissions_data.emissions * 1000,
        "emissions_data": clean_emissions_data(emissions_data),
        "api_route": ROUTE,
        "dataset_config": {
            "dataset_name": request.dataset_name,
            "test_size": request.test_size,
            "test_seed": request.test_seed,
        },
    }
    return results