Merwan6
Commit initial
0cebe35
raw
history blame
4.49 kB
import torch
import torch.nn.functional as F
from dotenv import load_dotenv
import os
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
#Mapping entre les ID des classes et les labels textuels
id2label = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}
def zero_shot_inference(text):
"""
Effectue une classification zero-shot à l'aide du modèle BART MNLI.
Args:
text (str): Texte à classifier.
Returns:
tuple:
- str: Label prédit.
- dict: Dictionnaire {label: score} pour chaque classe.
"""
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
candidate_labels = list(id2label.values())
result = classifier(text, candidate_labels)
prediction = result["labels"][0]
# Formatage des scores avec 4 décimales
scores = {label: float(f"{score:.4f}") for label, score in zip(result["labels"], result["scores"])}
return prediction, scores
def few_shot_inference(text):
"""
Simule un few-shot learning en injectant des exemples dans le prompt (type prompt engineering).
Args:
text (str): Texte à classifier.
Returns:
tuple:
- str: Label prédit.
- dict: Scores pour chaque classe.
"""
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
#Exemples donnés au modèle pour le guider (prompt engineering)
examples = [
("The president met the UN delegation to discuss global peace.", "World"),
("The football team won their match last night.", "Sports"),
("The company reported a big profit this quarter.", "Business"),
("New research in AI shows promising results.", "Sci/Tech")
]
#Construction du prompt avec des exemples
prompt = ""
for example_text, example_label in examples:
prompt += f"Text: {example_text}\nLabel: {example_label}\n\n"
prompt += f"Text: {text}\nLabel:"
candidate_labels = list(id2label.values())
result = classifier(prompt, candidate_labels)
prediction = result["labels"][0]
scores = {label: float(f"{score:.4f}") for label, score in zip(result["labels"], result["scores"])}
return prediction, scores
def base_model_inference(text):
"""
Utilise un modèle BERT préentraîné sur AG News (sans fine-tuning personnalisé).
Args:
text (str): Texte à classifier.
Returns:
tuple:
- str: Label prédit.
- dict: Scores softmax par classe.
"""
model_name = "textattack/bert-base-uncased-ag-news"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
#Encodage du texte
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
#Prédiction sans calcul de gradients
with torch.no_grad():
outputs = model(**inputs)
#Calcul des probabilités avec softmax
probs = F.softmax(outputs.logits, dim=-1)[0].cpu().numpy()
pred_id = probs.argmax()
prediction = id2label[pred_id]
scores = {id2label[i]: float(f"{p:.4f}") for i, p in enumerate(probs)}
return prediction, scores
def fine_tuned_inference(text, model_path="Merwan611/agnews-finetuned-bert"):
"""
Utilise un modèle BERT fine-tuné personnalisé sur AG News, avec authentification Hugging Face si nécessaire.
Args:
text (str): Texte à classifier.
model_path (str): Nom du modèle Hugging Face ou chemin local.
Returns:
tuple:
- str: Label prédit.
- dict: Scores softmax par classe.
"""
#Récupération du token d'auth depuis les variables d'environnement
token = os.getenv("CLE")
tokenizer = AutoTokenizer.from_pretrained(model_path, use_auth_token=token)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
with torch.no_grad():
outputs = model(**inputs)
probs = F.softmax(outputs.logits, dim=-1)[0].cpu().numpy()
pred_id = probs.argmax()
prediction = id2label[pred_id]
scores = {id2label[i]: float(f"{p:.4f}") for i, p in enumerate(probs)}
return prediction, scores