Merwan6 committed
Commit 11204e4 · 1 Parent(s): eaf4ff4
Files changed (3)
  1. .DS_Store +0 -0
  2. scripts/inference.py +25 -16
  3. scripts/metric.py +2 -3
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
scripts/inference.py CHANGED
@@ -31,36 +31,45 @@ def zero_shot_inference(text):
 
 
 def few_shot_inference(text):
     """
-    Simulates few-shot learning by injecting examples into the prompt (prompt engineering).
+    Few-shot classification with FLAN-T5: generates only the label (World, Sports, etc.).
 
     Args:
         text (str): Text to classify.
 
     Returns:
         tuple:
-            - str: Predicted label.
-            - dict: Scores for each class.
+            - str: Predicted label (cleaned and validated).
+            - dict: Scores per class (1.0 for the predicted label, 0.0 for the others).
     """
-    classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
-
-    # Examples given to the model to guide it (prompt engineering)
+    model_name = "google/flan-t5-small"
+    classifier = pipeline("text2text-generation", model=model_name, max_new_tokens=10)
+
     examples = [
         ("The president met the UN delegation to discuss global peace.", "World"),
         ("The football team won their match last night.", "Sports"),
         ("The company reported a big profit this quarter.", "Business"),
         ("New research in AI shows promising results.", "Sci/Tech")
     ]
-
-    # Build the prompt from the examples
-    prompt = ""
-    for example_text, example_label in examples:
-        prompt += f"Text: {example_text}\nLabel: {example_label}\n\n"
-    prompt += f"Text: {text}\nLabel:"
 
-    candidate_labels = list(id2label.values())
-    result = classifier(prompt, candidate_labels)
-    prediction = result["labels"][0]
-    scores = {label: float(f"{score:.4f}") for label, score in zip(result["labels"], result["scores"])}
+    # Few-shot prompt
+    prompt = "Classify the following text into one of the following categories: World, Sports, Business, Sci/Tech.\n\n"
+    for ex_text, ex_label in examples:
+        prompt += f"Text: {ex_text}\nCategory: {ex_label}\n\n"
+    prompt += f"Text: {text}\nCategory:"
+
+    # Generation
+    output = classifier(prompt)[0]["generated_text"].strip()
+
+    # Clean up the label
+    output_clean = output.split()[0].rstrip(".").capitalize()  # e.g. "sci/tech." → "Sci/tech"
+
+    # Map back onto a known category to make sure the prediction is valid
+    candidate_labels = ["World", "Sports", "Business", "Sci/Tech"]
+    prediction = next((label for label in candidate_labels if label.lower() in output_clean.lower()), "Unknown")
+
+    # Fake distribution (1.0 for the predicted class, 0.0 for the others)
+    scores = {label: 1.0 if label == prediction else 0.0 for label in candidate_labels}
+
     return prediction, scores
 
 
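For context, a minimal standalone sketch of the new FLAN-T5 few-shot path, assuming only that the transformers library is installed; the shortened prompt, the sample sentence, and the expected output are illustrative assumptions, not taken from the repository.

    # Minimal usage sketch of the text2text-generation few-shot prompt.
    # The final sample sentence is invented for illustration.
    from transformers import pipeline

    classifier = pipeline("text2text-generation", model="google/flan-t5-small", max_new_tokens=10)

    prompt = (
        "Classify the following text into one of the following categories: "
        "World, Sports, Business, Sci/Tech.\n\n"
        "Text: The football team won their match last night.\nCategory: Sports\n\n"
        "Text: Stock markets rallied after the central bank kept rates unchanged.\nCategory:"
    )

    output = classifier(prompt)[0]["generated_text"].strip()
    print(output)  # a well-behaved run should print something like "Business"
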
scripts/metric.py CHANGED
@@ -1,5 +1,4 @@
 import time
-import numpy as np
 import pandas as pd
 from tqdm import tqdm  # ✅ Added here
 from datasets import load_dataset
@@ -20,7 +19,7 @@ models_to_evaluate = {
 label_map = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}
 
 # Load a subset of the AG News test set
-dataset = load_dataset("ag_news", split="test[:10%]")
+dataset = load_dataset("ag_news", split="test[:3]")
 
 def evaluate_model(name, inference_func):
     print(f"\n🔍 Evaluating model: {name}")
@@ -56,7 +55,7 @@ def evaluate_model(name, inference_func):
     f1 = f1_score(true_labels, pred_labels, average='weighted')
     prec = precision_score(true_labels, pred_labels, average='weighted')
     rec = recall_score(true_labels, pred_labels, average='weighted')
-    loss = log_loss(true_labels, all_probs)
+    loss = log_loss(true_labels, all_probs, labels=[0, 1, 2, 3])
 
     print(f"✅ Results for {name}:")
     print(f"- Accuracy : {acc:.4f}")
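The explicit labels=[0, 1, 2, 3] argument matters because a three-example test slice will usually not contain all four AG News classes, and scikit-learn's log_loss otherwise infers the class set from the ground-truth labels, which then no longer matches the four-column probability arrays. A minimal sketch with made-up probabilities:

    # With only two distinct classes in y_true, log_loss needs the full label
    # set spelled out so it lines up with the four probability columns.
    from sklearn.metrics import log_loss

    true_labels = [1, 2, 2]          # only Sports and Business appear in this tiny slice
    all_probs = [
        [0.10, 0.70, 0.10, 0.10],
        [0.20, 0.10, 0.60, 0.10],
        [0.25, 0.25, 0.25, 0.25],
    ]

    # log_loss(true_labels, all_probs) would raise a ValueError here, because the
    # classes inferred from y_true (2) do not match the 4 columns of the probabilities.
    loss = log_loss(true_labels, all_probs, labels=[0, 1, 2, 3])
    print(f"log loss: {loss:.4f}")
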