Merwan6 committed · commit 11204e4
1 parent: eaf4ff4

modif
Files changed:
- .DS_Store +0 -0
- scripts/inference.py +25 -16
- scripts/metric.py +2 -3
.DS_Store
CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
scripts/inference.py
CHANGED
@@ -31,36 +31,45 @@ def zero_shot_inference(text):
 
 
 def few_shot_inference(text):
     """
-
+    Few-shot classification with FLAN-T5: generates only the label (World, Sports, etc.).
 
     Args:
         text (str): Text to classify.
 
     Returns:
         tuple:
-            - str: Predicted label.
-            - dict:
+            - str: Predicted label (cleaned and validated).
+            - dict: Details of the raw generated text.
     """
-
-
-
+    model_name = "google/flan-t5-small"
+    classifier = pipeline("text2text-generation", model=model_name, max_new_tokens=10)
+
     examples = [
         ("The president met the UN delegation to discuss global peace.", "World"),
         ("The football team won their match last night.", "Sports"),
         ("The company reported a big profit this quarter.", "Business"),
         ("New research in AI shows promising results.", "Sci/Tech")
     ]
-
-    # Build the prompt from the examples
-    prompt = ""
-    for example_text, example_label in examples:
-        prompt += f"Text: {example_text}\nLabel: {example_label}\n\n"
-    prompt += f"Text: {text}\nLabel:"
 
-
-
-
-
+    # Few-shot prompt
+    prompt = "Classify the following text into one of the following categories: World, Sports, Business, Sci/Tech.\n\n"
+    for ex_text, ex_label in examples:
+        prompt += f"Text: {ex_text}\nCategory: {ex_label}\n\n"
+    prompt += f"Text: {text}\nCategory:"
+
+    # Generation
+    output = classifier(prompt)[0]["generated_text"].strip()
+
+    # Clean up the label
+    output_clean = output.split()[0].rstrip(".").capitalize()  # e.g. "sci/tech." → "Sci/tech"
+
+    # Map to a known category to make sure the output matches one of them
+    candidate_labels = ["World", "Sports", "Business", "Sci/Tech"]
+    prediction = next((label for label in candidate_labels if label.lower() in output_clean.lower()), "Unknown")
+
+    # Fake distribution (1.0 for the predicted class, 0.0 for the others)
+    scores = {label: 1.0 if label == prediction else 0.0 for label in candidate_labels}
+
     return prediction, scores
 
 
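For context, here is a minimal, self-contained sketch of the few-shot flow that the updated few_shot_inference implements, assuming the transformers package is installed. The module-level pipeline, the few_shot_predict name, and the __main__ demo are illustrative assumptions rather than repository code; keeping the pipeline at module level simply avoids re-instantiating the model on every call, whereas the committed function builds it inside the function body.

# Minimal sketch of the few-shot flow above (not the committed code).
# Assumes `transformers` is installed; model name, prompt, labels and the
# one-hot "scores" dict follow the diff, everything else is illustrative.
from transformers import pipeline

classifier = pipeline("text2text-generation", model="google/flan-t5-small", max_new_tokens=10)

CANDIDATE_LABELS = ["World", "Sports", "Business", "Sci/Tech"]
EXAMPLES = [
    ("The president met the UN delegation to discuss global peace.", "World"),
    ("The football team won their match last night.", "Sports"),
    ("The company reported a big profit this quarter.", "Business"),
    ("New research in AI shows promising results.", "Sci/Tech"),
]

def few_shot_predict(text: str):
    # Build the few-shot prompt: instruction, labelled examples, then the query text
    prompt = "Classify the following text into one of the following categories: World, Sports, Business, Sci/Tech.\n\n"
    for ex_text, ex_label in EXAMPLES:
        prompt += f"Text: {ex_text}\nCategory: {ex_label}\n\n"
    prompt += f"Text: {text}\nCategory:"

    # Generate a short completion and keep only the first whitespace-separated chunk
    output = classifier(prompt)[0]["generated_text"].strip()
    output_clean = output.split()[0].rstrip(".") if output else ""

    # Map the raw output back to a known label; fall back to "Unknown"
    prediction = next(
        (label for label in CANDIDATE_LABELS if label.lower() in output_clean.lower()),
        "Unknown",
    )
    # Degenerate one-hot "distribution", as in the diff
    scores = {label: 1.0 if label == prediction else 0.0 for label in CANDIDATE_LABELS}
    return prediction, scores

if __name__ == "__main__":
    print(few_shot_predict("Stocks rallied after the earnings report."))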
scripts/metric.py
CHANGED
@@ -1,5 +1,4 @@
 import time
-import numpy as np
 import pandas as pd
 from tqdm import tqdm  # ✅ Added here
 from datasets import load_dataset
@@ -20,7 +19,7 @@ models_to_evaluate = {
 label_map = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}
 
 # Load a subset of the AG News test set
-dataset = load_dataset("ag_news", split="test[:
+dataset = load_dataset("ag_news", split="test[:3]")
 
 def evaluate_model(name, inference_func):
     print(f"\n🔍 Evaluating model: {name}")
@@ -56,7 +55,7 @@ def evaluate_model(name, inference_func):
     f1 = f1_score(true_labels, pred_labels, average='weighted')
     prec = precision_score(true_labels, pred_labels, average='weighted')
     rec = recall_score(true_labels, pred_labels, average='weighted')
-    loss = log_loss(true_labels, all_probs)
+    loss = log_loss(true_labels, all_probs, labels=[0, 1, 2, 3])
 
     print(f"✅ Results for {name}:")
     print(f"- Accuracy : {acc:.4f}")