Spaces:
Sleeping
Sleeping
File size: 4,936 Bytes
0cebe35 11204e4 0cebe35 11204e4 0cebe35 11204e4 0cebe35 11204e4 0cebe35 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import torch
import torch.nn.functional as F
from dotenv import load_dotenv
import os
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
#Mapping entre les ID des classes et les labels textuels
id2label = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}
def zero_shot_inference(text):
"""
Effectue une classification zero-shot à l'aide du modèle BART MNLI.
Args:
text (str): Texte à classifier.
Returns:
tuple:
- str: Label prédit.
- dict: Dictionnaire {label: score} pour chaque classe.
"""
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
candidate_labels = list(id2label.values())
result = classifier(text, candidate_labels)
prediction = result["labels"][0]
# Formatage des scores avec 4 décimales
scores = {label: float(f"{score:.4f}") for label, score in zip(result["labels"], result["scores"])}
return prediction, scores
def few_shot_inference(text):
"""
Classification few-shot avec FLAN-T5 : génère uniquement le label (World, Sports, etc.).
Args:
text (str): Texte à classifier.
Returns:
tuple:
- str: Label prédit (nettoyé et validé).
- dict: Détails du texte généré brut.
"""
model_name = "google/flan-t5-small"
classifier = pipeline("text2text-generation", model=model_name, max_new_tokens=10)
examples = [
("The president met the UN delegation to discuss global peace.", "World"),
("The football team won their match last night.", "Sports"),
("The company reported a big profit this quarter.", "Business"),
("New research in AI shows promising results.", "Sci/Tech")
]
# Prompt few-shot
prompt = "Classify the following text into one of the following categories: World, Sports, Business, Sci/Tech.\n\n"
for ex_text, ex_label in examples:
prompt += f"Text: {ex_text}\nCategory: {ex_label}\n\n"
prompt += f"Text: {text}\nCategory:"
# Génération
output = classifier(prompt)[0]["generated_text"].strip()
# Nettoyage du label
output_clean = output.split()[0].rstrip(".").capitalize() # ex : "sci/tech." → "Sci/tech"
# Mapping pour être sûr que ça correspond à une catégorie connue
candidate_labels = ["World", "Sports", "Business", "Sci/Tech"]
prediction = next((label for label in candidate_labels if label.lower() in output_clean.lower()), "Unknown")
# Fausse distribution (1.0 pour la classe prédite, 0.0 pour les autres)
scores = {label: 1.0 if label == prediction else 0.0 for label in candidate_labels}
return prediction, scores
def base_model_inference(text):
"""
Utilise un modèle BERT préentraîné sur AG News (sans fine-tuning personnalisé).
Args:
text (str): Texte à classifier.
Returns:
tuple:
- str: Label prédit.
- dict: Scores softmax par classe.
"""
model_name = "textattack/bert-base-uncased-ag-news"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
#Encodage du texte
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
#Prédiction sans calcul de gradients
with torch.no_grad():
outputs = model(**inputs)
#Calcul des probabilités avec softmax
probs = F.softmax(outputs.logits, dim=-1)[0].cpu().numpy()
pred_id = probs.argmax()
prediction = id2label[pred_id]
scores = {id2label[i]: float(f"{p:.4f}") for i, p in enumerate(probs)}
return prediction, scores
def fine_tuned_inference(text, model_path="Merwan611/agnews-finetuned-bert"):
"""
Utilise un modèle BERT fine-tuné personnalisé sur AG News, avec authentification Hugging Face si nécessaire.
Args:
text (str): Texte à classifier.
model_path (str): Nom du modèle Hugging Face ou chemin local.
Returns:
tuple:
- str: Label prédit.
- dict: Scores softmax par classe.
"""
#Récupération du token d'auth depuis les variables d'environnement
token = os.getenv("CLE")
tokenizer = AutoTokenizer.from_pretrained(model_path, use_auth_token=token)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
with torch.no_grad():
outputs = model(**inputs)
probs = F.softmax(outputs.logits, dim=-1)[0].cpu().numpy()
pred_id = probs.argmax()
prediction = id2label[pred_id]
scores = {id2label[i]: float(f"{p:.4f}") for i, p in enumerate(probs)}
return prediction, scores
|