Merwan6
Commit initial
0cebe35
raw
history blame
1.13 kB
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
def compute_metrics(preds, labels):
"""
Calcule les métriques de classification à partir des prédictions du modèle
et des labels de vérité terrain (vrais).
Args:
preds (array-like): Les classes prédites par le modèle (entiers).
labels (array-like): Les vraies classes associées aux exemples (entiers).
Returns:
dict: Dictionnaire contenant les métriques suivantes :
- "accuracy" : exactitude globale des prédictions
- "f1" : score F1 pondéré (par classe)
- "precision" : précision pondérée
- "recall" : rappel pondéré
"""
#Calcule précision, rappel et F1 pondérés selon la taille de chaque classe
precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
#Calcule l'accuracy brute
acc = accuracy_score(labels, preds)
return {
"accuracy": acc,
"f1": f1,
"precision": precision,
"recall": recall
}