import torch import torch.nn.functional as F from dotenv import load_dotenv import os from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification #Mapping entre les ID des classes et les labels textuels id2label = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"} def zero_shot_inference(text): """ Effectue une classification zero-shot à l'aide du modèle BART MNLI. Args: text (str): Texte à classifier. Returns: tuple: - str: Label prédit. - dict: Dictionnaire {label: score} pour chaque classe. """ classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli") candidate_labels = list(id2label.values()) result = classifier(text, candidate_labels) prediction = result["labels"][0] # Formatage des scores avec 4 décimales scores = {label: float(f"{score:.4f}") for label, score in zip(result["labels"], result["scores"])} return prediction, scores def few_shot_inference(text): """ Simule un few-shot learning en injectant des exemples dans le prompt (type prompt engineering). Args: text (str): Texte à classifier. Returns: tuple: - str: Label prédit. - dict: Scores pour chaque classe. """ classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli") #Exemples donnés au modèle pour le guider (prompt engineering) examples = [ ("The president met the UN delegation to discuss global peace.", "World"), ("The football team won their match last night.", "Sports"), ("The company reported a big profit this quarter.", "Business"), ("New research in AI shows promising results.", "Sci/Tech") ] #Construction du prompt avec des exemples prompt = "" for example_text, example_label in examples: prompt += f"Text: {example_text}\nLabel: {example_label}\n\n" prompt += f"Text: {text}\nLabel:" candidate_labels = list(id2label.values()) result = classifier(prompt, candidate_labels) prediction = result["labels"][0] scores = {label: float(f"{score:.4f}") for label, score in zip(result["labels"], result["scores"])} return prediction, scores def base_model_inference(text): """ Utilise un modèle BERT préentraîné sur AG News (sans fine-tuning personnalisé). Args: text (str): Texte à classifier. Returns: tuple: - str: Label prédit. - dict: Scores softmax par classe. """ model_name = "textattack/bert-base-uncased-ag-news" tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForSequenceClassification.from_pretrained(model_name) #Encodage du texte inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True) #Prédiction sans calcul de gradients with torch.no_grad(): outputs = model(**inputs) #Calcul des probabilités avec softmax probs = F.softmax(outputs.logits, dim=-1)[0].cpu().numpy() pred_id = probs.argmax() prediction = id2label[pred_id] scores = {id2label[i]: float(f"{p:.4f}") for i, p in enumerate(probs)} return prediction, scores def fine_tuned_inference(text, model_path="Merwan611/agnews-finetuned-bert"): """ Utilise un modèle BERT fine-tuné personnalisé sur AG News, avec authentification Hugging Face si nécessaire. Args: text (str): Texte à classifier. model_path (str): Nom du modèle Hugging Face ou chemin local. Returns: tuple: - str: Label prédit. - dict: Scores softmax par classe. """ #Récupération du token d'auth depuis les variables d'environnement token = os.getenv("CLE") tokenizer = AutoTokenizer.from_pretrained(model_path, use_auth_token=token) model = AutoModelForSequenceClassification.from_pretrained(model_path) inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True) with torch.no_grad(): outputs = model(**inputs) probs = F.softmax(outputs.logits, dim=-1)[0].cpu().numpy() pred_id = probs.argmax() prediction = id2label[pred_id] scores = {id2label[i]: float(f"{p:.4f}") for i, p in enumerate(probs)} return prediction, scores