Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn.functional as F | |
from dotenv import load_dotenv | |
import os | |
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification | |
# Mapping from AG News class ids to their text labels.
id2label = dict(enumerate(("World", "Sports", "Business", "Sci/Tech")))
# Checkpoint shared by the zero-shot and few-shot classifiers below.
_NLI_MODEL_NAME = "facebook/bart-large-mnli"

# Pipelines already built in this process, keyed by model name. Building a
# pipeline loads the whole checkpoint, so it is done once and then reused
# instead of being repeated on every call.
_NLI_PIPELINES = {}

# Labelled examples injected into the few-shot prompt (prompt engineering).
_FEW_SHOT_EXAMPLES = (
    ("The president met the UN delegation to discuss global peace.", "World"),
    ("The football team won their match last night.", "Sports"),
    ("The company reported a big profit this quarter.", "Business"),
    ("New research in AI shows promising results.", "Sci/Tech"),
)


def _get_nli_classifier(model_name=_NLI_MODEL_NAME):
    """Return the zero-shot-classification pipeline for *model_name*, building it on first use."""
    classifier = _NLI_PIPELINES.get(model_name)
    if classifier is None:
        classifier = pipeline("zero-shot-classification", model=model_name)
        _NLI_PIPELINES[model_name] = classifier
    return classifier


def _classify_against_labels(sequence):
    """Score *sequence* against every label of ``id2label`` with the NLI pipeline.

    Args:
        sequence (str): Text passed to the pipeline as-is.

    Returns:
        tuple:
            - str: First entry of the pipeline's ``labels`` output (the prediction).
            - dict: ``{label: score}`` rounded to 4 decimals, in the same order
              as the pipeline's ``labels`` output.
    """
    result = _get_nli_classifier()(sequence, list(id2label.values()))
    labels = result["labels"]
    scores = {label: round(score, 4) for label, score in zip(labels, result["scores"])}
    return labels[0], scores


def zero_shot_inference(text):
    """Classify *text* zero-shot with the BART MNLI model.

    Args:
        text (str): Text to classify.

    Returns:
        tuple:
            - str: Predicted label.
            - dict: ``{label: score}`` for each class, rounded to 4 decimals.
    """
    return _classify_against_labels(text)


def few_shot_inference(text):
    """Simulate few-shot learning by prepending labelled examples to the input.

    The examples are rendered as ``Text: ...`` / ``Label: ...`` pairs, followed
    by the text to classify and a trailing ``Label:``. The whole prompt is then
    scored with the same zero-shot pipeline as :func:`zero_shot_inference`.

    NOTE(review): the zero-shot-classification pipeline does not complete the
    trailing "Label:". It scores the entire prompt, examples included, against
    each candidate label. Every label appears verbatim in the examples, so this
    may bias the scores. Confirm it actually beats plain zero-shot.

    Args:
        text (str): Text to classify.

    Returns:
        tuple:
            - str: Predicted label.
            - dict: ``{label: score}`` for each class, rounded to 4 decimals.
    """
    prompt = "".join(
        f"Text: {example_text}\nLabel: {example_label}\n\n"
        for example_text, example_label in _FEW_SHOT_EXAMPLES
    )
    prompt += f"Text: {text}\nLabel:"
    return _classify_against_labels(prompt)
# (tokenizer, model) pairs already loaded in this process, keyed by
# (model name, auth token), so each checkpoint is loaded only once.
_SEQUENCE_CLASSIFIERS = {}


def _load_sequence_classifier(model_name, token=None):
    """Return a cached ``(tokenizer, model)`` pair for *model_name*.

    Args:
        model_name (str): Hugging Face repo id or local path.
        token (str | None): Optional Hugging Face access token. When given, it
            is passed to BOTH loaders, because a private repository requires
            authentication for the model weights as well as the tokenizer.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    key = (model_name, token)
    cached = _SEQUENCE_CLASSIFIERS.get(key)
    if cached is None:
        # Only pass ``token`` when one is set; anonymous loads then call
        # from_pretrained with exactly the same arguments as before.
        auth = {} if token is None else {"token": token}
        tokenizer = AutoTokenizer.from_pretrained(model_name, **auth)
        model = AutoModelForSequenceClassification.from_pretrained(model_name, **auth)
        cached = (tokenizer, model)
        _SEQUENCE_CLASSIFIERS[key] = cached
    return cached


def _predict(text, tokenizer, model):
    """Classify *text* with a sequence-classification model.

    Returns:
        tuple:
            - str: Label (via ``id2label``) of the most probable class.
            - dict: ``{label: softmax probability}`` rounded to 4 decimals.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    # Inference only: skip gradient bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits
    # Single input, so take row 0 of the batch as a list of Python floats.
    probs = F.softmax(logits, dim=-1)[0].cpu().tolist()
    pred_id = probs.index(max(probs))  # first maximum, same as argmax
    scores = {id2label[i]: round(p, 4) for i, p in enumerate(probs)}
    return id2label[pred_id], scores


def base_model_inference(text):
    """Classify *text* with a BERT model pre-trained on AG News (no custom fine-tuning).

    Args:
        text (str): Text to classify.

    Returns:
        tuple:
            - str: Predicted label.
            - dict: Softmax scores per class, rounded to 4 decimals.
    """
    tokenizer, model = _load_sequence_classifier("textattack/bert-base-uncased-ag-news")
    return _predict(text, tokenizer, model)


def fine_tuned_inference(text, model_path="Merwan611/agnews-finetuned-bert"):
    """Classify *text* with the custom BERT model fine-tuned on AG News.

    The Hugging Face access token is read from the ``CLE`` environment
    variable. A local ``.env`` file is used as a fallback. The token is
    required only if the repository is private.

    Args:
        text (str): Text to classify.
        model_path (str): Hugging Face repo id or local path of the model.

    Returns:
        tuple:
            - str: Predicted label.
            - dict: Softmax scores per class, rounded to 4 decimals.
    """
    token = os.getenv("CLE")
    if token is None:
        # Not set in the environment: try a local .env file (python-dotenv
        # does not override variables that are already set).
        load_dotenv()
        token = os.getenv("CLE")
    tokenizer, model = _load_sequence_classifier(model_path, token)
    return _predict(text, tokenizer, model)