import time

import pandas as pd
from datasets import load_dataset
from sklearn.metrics import accuracy_score, f1_score, log_loss, precision_score, recall_score
from tqdm import tqdm

from inference import (
    base_model_inference,
    few_shot_inference,
    zero_shot_inference,
)

# Inference functions to evaluate.
models_to_evaluate = {
    "Base model": base_model_inference,
    "Zero-shot": zero_shot_inference,
    "Few-shot": few_shot_inference,
}

# AG News label ids mapped to their class names.
label_map = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}

# Load the AG News test set and keep a fixed random sample of 500 examples.
dataset = load_dataset("ag_news", split="test")
dataset = dataset.shuffle(seed=42).select(range(500))


def evaluate_model(name, inference_func):
    """Run `inference_func` over the sampled test set and report classification metrics."""
    print(f"\n🔍 Evaluating model: {name}")
    true_labels = []
    pred_labels = []
    all_probs = []
    start = time.time()

    for example in tqdm(dataset, desc=f"Model: {name}"):
        text = example["text"]
        true_label = example["label"]

        try:
            # Each inference function returns the predicted class name and a
            # dict of per-class scores keyed by the names in `label_map`.
            pred_class, scores = inference_func(text)
        except Exception as e:
            print(f"⚠️ Error on an example: {e}")
            continue

        if pred_class not in label_map.values():
            print(f"⚠️ Unknown predicted class: '{pred_class}', example skipped.")
            continue

        # Probability distribution ordered as [World, Sports, Business, Sci/Tech].
        prob_dist = [scores.get(c, 0.0) for c in label_map.values()]
        pred_index = list(label_map.values()).index(pred_class)

        pred_labels.append(pred_index)
        true_labels.append(true_label)
        all_probs.append(prob_dist)

    runtime = round(time.time() - start, 2)

    acc = accuracy_score(true_labels, pred_labels)
    f1 = f1_score(true_labels, pred_labels, average="weighted")
    prec = precision_score(true_labels, pred_labels, average="weighted")
    rec = recall_score(true_labels, pred_labels, average="weighted")
    loss = log_loss(true_labels, all_probs, labels=[0, 1, 2, 3])

    print(f"✅ Results for {name}:")
    print(f"- Accuracy : {acc:.2f}")
    print(f"- F1 Score : {f1:.2f}")
    print(f"- Precision: {prec:.2f}")
    print(f"- Recall   : {rec:.2f}")
    print(f"- Log Loss : {loss:.2f}")
    print(f"- Runtime  : {runtime:.2f} sec\n")

    return {
        "model": name,
        "accuracy": acc,
        "f1_score": f1,
        "precision": prec,
        "recall": rec,
        "loss": loss,
        "runtime": runtime,
    }


# Evaluate every model and collect the results.
results = []
for name, func in models_to_evaluate.items():
    results.append(evaluate_model(name, func))

# Summary table.
df = pd.DataFrame(results)
df["loss"] = df["loss"].round(4)
print(df)
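

# The script above assumes each function imported from `inference` follows the
# contract used in the evaluation loop: it takes the raw article text and
# returns a (predicted_class_name, score_dict) pair, where score_dict maps the
# class names in `label_map` to probabilities. The stub below is only a minimal
# sketch of that assumed contract, not the real implementation in inference.py;
# the name `dummy_inference` is hypothetical and is never called by the script.
def dummy_inference(text: str):
    """Toy stand-in illustrating the assumed (pred_class, scores) return shape."""
    scores = {"World": 0.25, "Sports": 0.25, "Business": 0.25, "Sci/Tech": 0.25}
    pred_class = max(scores, key=scores.get)  # ties resolve to the first key
    return pred_class, scores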