File size: 4,936 Bytes
0cebe35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11204e4
0cebe35
 
 
 
 
 
11204e4
 
0cebe35
11204e4
 
 
0cebe35
 
 
 
 
 
 
11204e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0cebe35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import torch
import torch.nn.functional as F
from dotenv import load_dotenv
import os
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification

#Mapping entre les ID des classes et les labels textuels
id2label = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}


def zero_shot_inference(text):
    """

    Effectue une classification zero-shot à l'aide du modèle BART MNLI.

    

    Args:

        text (str): Texte à classifier.

    

    Returns:

        tuple:

            - str: Label prédit.

            - dict: Dictionnaire {label: score} pour chaque classe.

    """
    classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
    candidate_labels = list(id2label.values())
    result = classifier(text, candidate_labels)
    prediction = result["labels"][0]
    # Formatage des scores avec 4 décimales
    scores = {label: float(f"{score:.4f}") for label, score in zip(result["labels"], result["scores"])}
    return prediction, scores


def few_shot_inference(text):
    """

    Classification few-shot avec FLAN-T5 : génère uniquement le label (World, Sports, etc.).

    

    Args:

        text (str): Texte à classifier.

    

    Returns:

        tuple:

            - str: Label prédit (nettoyé et validé).

            - dict: Détails du texte généré brut.

    """
    model_name = "google/flan-t5-small"
    classifier = pipeline("text2text-generation", model=model_name, max_new_tokens=10)

    examples = [
        ("The president met the UN delegation to discuss global peace.", "World"),
        ("The football team won their match last night.", "Sports"),
        ("The company reported a big profit this quarter.", "Business"),
        ("New research in AI shows promising results.", "Sci/Tech")
    ]

    # Prompt few-shot
    prompt = "Classify the following text into one of the following categories: World, Sports, Business, Sci/Tech.\n\n"
    for ex_text, ex_label in examples:
        prompt += f"Text: {ex_text}\nCategory: {ex_label}\n\n"
    prompt += f"Text: {text}\nCategory:"

    # Génération
    output = classifier(prompt)[0]["generated_text"].strip()

    # Nettoyage du label
    output_clean = output.split()[0].rstrip(".").capitalize()  # ex : "sci/tech." → "Sci/tech"

    # Mapping pour être sûr que ça correspond à une catégorie connue
    candidate_labels = ["World", "Sports", "Business", "Sci/Tech"]
    prediction = next((label for label in candidate_labels if label.lower() in output_clean.lower()), "Unknown")

    # Fausse distribution (1.0 pour la classe prédite, 0.0 pour les autres)
    scores = {label: 1.0 if label == prediction else 0.0 for label in candidate_labels}

    return prediction, scores


def base_model_inference(text):
    """

    Utilise un modèle BERT préentraîné sur AG News (sans fine-tuning personnalisé).

    

    Args:

        text (str): Texte à classifier.

    

    Returns:

        tuple:

            - str: Label prédit.

            - dict: Scores softmax par classe.

    """
    model_name = "textattack/bert-base-uncased-ag-news"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)

    #Encodage du texte
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)

    #Prédiction sans calcul de gradients
    with torch.no_grad():
        outputs = model(**inputs)

    #Calcul des probabilités avec softmax
    probs = F.softmax(outputs.logits, dim=-1)[0].cpu().numpy()

    pred_id = probs.argmax()
    prediction = id2label[pred_id]
    scores = {id2label[i]: float(f"{p:.4f}") for i, p in enumerate(probs)}
    return prediction, scores


def fine_tuned_inference(text, model_path="Merwan611/agnews-finetuned-bert"):
    """

    Utilise un modèle BERT fine-tuné personnalisé sur AG News, avec authentification Hugging Face si nécessaire.



    Args:

        text (str): Texte à classifier.

        model_path (str): Nom du modèle Hugging Face ou chemin local.



    Returns:

        tuple:

            - str: Label prédit.

            - dict: Scores softmax par classe.

    """

    #Récupération du token d'auth depuis les variables d'environnement
    token = os.getenv("CLE")

    tokenizer = AutoTokenizer.from_pretrained(model_path, use_auth_token=token)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)

    probs = F.softmax(outputs.logits, dim=-1)[0].cpu().numpy()
    pred_id = probs.argmax()
    prediction = id2label[pred_id]
    scores = {id2label[i]: float(f"{p:.4f}") for i, p in enumerate(probs)}
    return prediction, scores