Merwan6 committed on
Commit eaf4ff4 · 1 Parent(s): 18af9e8

push metrics

Files changed (1)
  1. scripts/metric.py +86 -0
scripts/metric.py ADDED
@@ -0,0 +1,86 @@
+ import time
+ import numpy as np
+ import pandas as pd
+ from tqdm import tqdm  # ✅ progress bars for the evaluation loops
+ from datasets import load_dataset
+ from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, log_loss
+ from inference import (
+     zero_shot_inference,
+     few_shot_inference,
+     base_model_inference,
+ )
+
+ # Inference functions to evaluate
+ models_to_evaluate = {
+     "Base model": base_model_inference,
+     "Zero-shot": zero_shot_inference,
+     "Few-shot": few_shot_inference,
+ }
+
+ label_map = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}
+
+ # Load a subset of the AG News test set
+ dataset = load_dataset("ag_news", split="test[:10%]")
+
+ def evaluate_model(name, inference_func):
+     print(f"\n🔍 Evaluating model: {name}")
+     true_labels = []
+     pred_labels = []
+     all_probs = []
+
+     start = time.time()
+
+     for example in tqdm(dataset, desc=f"Model: {name}"):
+         text = example["text"]
+         true_label = example["label"]
+         true_class = label_map[true_label]
+
+         try:
+             pred_class, scores = inference_func(text)
+         except Exception as e:
+             print(f"⚠️ Error on example: {e}")
+             continue
+
+         # Probabilities for the 4 classes, in label_map order
+         prob_dist = [scores.get(c, 0.0) for c in label_map.values()]
+         pred_index = list(label_map.values()).index(pred_class)
+
+         pred_labels.append(pred_index)
+         true_labels.append(true_label)
+         all_probs.append(prob_dist)
+
+     end = time.time()
+     runtime = round(end - start, 2)
+
+     acc = accuracy_score(true_labels, pred_labels)
+     f1 = f1_score(true_labels, pred_labels, average='weighted')
+     prec = precision_score(true_labels, pred_labels, average='weighted')
+     rec = recall_score(true_labels, pred_labels, average='weighted')
+     loss = log_loss(true_labels, all_probs)
+
+     print(f"✅ Results for {name}:")
+     print(f"- Accuracy : {acc:.4f}")
+     print(f"- F1 Score : {f1:.4f}")
+     print(f"- Precision : {prec:.4f}")
+     print(f"- Recall : {rec:.4f}")
+     print(f"- Log Loss : {loss:.4f}")
+     print(f"- Runtime : {runtime:.2f} sec\n")
+
+     return {
+         "model": name,
+         "accuracy": acc,
+         "f1_score": f1,
+         "precision": prec,
+         "recall": rec,
+         "loss": loss,
+         "runtime": runtime
+     }
+
+ # Evaluate all models
+ results = []
+ for name, func in models_to_evaluate.items():
+     results.append(evaluate_model(name, func))
+
+ # Summary table
+ df = pd.DataFrame(results)
+ print(df)
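
Note: metric.py relies on an implicit contract for the functions imported from inference.py: each one takes an article's text and returns a (pred_class, scores) pair, where pred_class is one of the label_map strings and scores maps every label string to a probability (so prob_dist forms a valid distribution for log_loss). The actual implementations live in inference.py; a minimal hypothetical sketch of that assumed interface, with made-up scores for illustration only:

# Hypothetical sketch of the interface assumed by metric.py; not the real
# implementation from inference.py in this repo.
def zero_shot_inference(text):
    # Placeholder scores for illustration; a real implementation would come
    # from the underlying model's output over the four AG News labels.
    scores = {"World": 0.1, "Sports": 0.7, "Business": 0.1, "Sci/Tech": 0.1}
    pred_class = max(scores, key=scores.get)
    return pred_class, scores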