File size: 4,488 Bytes
0cebe35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import torch
import torch.nn.functional as F
from dotenv import load_dotenv
import os
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification

#Mapping entre les ID des classes et les labels textuels
id2label = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}


def zero_shot_inference(text):
    """

    Effectue une classification zero-shot à l'aide du modèle BART MNLI.

    

    Args:

        text (str): Texte à classifier.

    

    Returns:

        tuple:

            - str: Label prédit.

            - dict: Dictionnaire {label: score} pour chaque classe.

    """
    classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
    candidate_labels = list(id2label.values())
    result = classifier(text, candidate_labels)
    prediction = result["labels"][0]
    # Formatage des scores avec 4 décimales
    scores = {label: float(f"{score:.4f}") for label, score in zip(result["labels"], result["scores"])}
    return prediction, scores


def few_shot_inference(text):
    """

    Simule un few-shot learning en injectant des exemples dans le prompt (type prompt engineering).

    

    Args:

        text (str): Texte à classifier.

    

    Returns:

        tuple:

            - str: Label prédit.

            - dict: Scores pour chaque classe.

    """
    classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
    
    #Exemples donnés au modèle pour le guider (prompt engineering)
    examples = [
        ("The president met the UN delegation to discuss global peace.", "World"),
        ("The football team won their match last night.", "Sports"),
        ("The company reported a big profit this quarter.", "Business"),
        ("New research in AI shows promising results.", "Sci/Tech")
    ]
    
    #Construction du prompt avec des exemples
    prompt = ""
    for example_text, example_label in examples:
        prompt += f"Text: {example_text}\nLabel: {example_label}\n\n"
    prompt += f"Text: {text}\nLabel:"

    candidate_labels = list(id2label.values())
    result = classifier(prompt, candidate_labels)
    prediction = result["labels"][0]
    scores = {label: float(f"{score:.4f}") for label, score in zip(result["labels"], result["scores"])}
    return prediction, scores


def base_model_inference(text):
    """

    Utilise un modèle BERT préentraîné sur AG News (sans fine-tuning personnalisé).

    

    Args:

        text (str): Texte à classifier.

    

    Returns:

        tuple:

            - str: Label prédit.

            - dict: Scores softmax par classe.

    """
    model_name = "textattack/bert-base-uncased-ag-news"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)

    #Encodage du texte
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)

    #Prédiction sans calcul de gradients
    with torch.no_grad():
        outputs = model(**inputs)

    #Calcul des probabilités avec softmax
    probs = F.softmax(outputs.logits, dim=-1)[0].cpu().numpy()

    pred_id = probs.argmax()
    prediction = id2label[pred_id]
    scores = {id2label[i]: float(f"{p:.4f}") for i, p in enumerate(probs)}
    return prediction, scores


def fine_tuned_inference(text, model_path="Merwan611/agnews-finetuned-bert"):
    """

    Utilise un modèle BERT fine-tuné personnalisé sur AG News, avec authentification Hugging Face si nécessaire.



    Args:

        text (str): Texte à classifier.

        model_path (str): Nom du modèle Hugging Face ou chemin local.



    Returns:

        tuple:

            - str: Label prédit.

            - dict: Scores softmax par classe.

    """

    #Récupération du token d'auth depuis les variables d'environnement
    token = os.getenv("CLE")

    tokenizer = AutoTokenizer.from_pretrained(model_path, use_auth_token=token)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)

    probs = F.softmax(outputs.logits, dim=-1)[0].cpu().numpy()
    pred_id = probs.argmax()
    prediction = id2label[pred_id]
    scores = {id2label[i]: float(f"{p:.4f}") for i, p in enumerate(probs)}
    return prediction, scores