File size: 5,652 Bytes
0cebe35
 
 
 
 
 
 
 
7ae005b
 
9682a49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f7bccf
 
 
 
 
 
 
9682a49
 
 
 
 
 
 
 
 
 
 
 
 
 
0cebe35
 
 
716abff
0cebe35
 
716abff
 
0cebe35
 
 
 
716abff
0cebe35
 
 
 
 
 
 
 
 
 
 
 
716abff
0cebe35
 
 
 
 
716abff
 
 
 
 
 
 
 
 
 
 
0cebe35
716abff
0cebe35
716abff
 
 
0cebe35
 
716abff
0cebe35
716abff
 
 
 
 
 
 
 
0cebe35
 
716abff
 
 
 
 
 
 
 
7ae005b
716abff
 
0cebe35
716abff
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import gradio as gr
import pandas as pd
from scripts.inference import (
    zero_shot_inference,
    few_shot_inference,
    base_model_inference,
    fine_tuned_inference
)

#Lire le README.md
readme_content = """# 📰 AG News Text Classification Demo



Ce projet présente une application de classification de textes d’actualité basée sur le dataset **AG News**.  

L'objectif est de comparer plusieurs stratégies d'inférence de modèles Transformers pour la classification de texte.



---



## 🚀 Démo en ligne



L’application est disponible ici :  

[**Lien vers la démo Hugging Face Space**](https://huggingface.co/spaces/Merwan611/classification-text) 



---



## 📂 Organisation du projet



- `app.py` : interface Gradio avec deux onglets (`Demo` + `Documentation`)

- `scripts/inference.py` : fonctions d’inférence pour 4 types de modèles

- `scripts/train.py` : script d’entraînement du modèle BERT fine-tuné sur AG News

- `scripts/utils.py` : calcul des métriques d’évaluation (accuracy, F1, etc.)

- `requirements.txt` : liste des dépendances Python



---



## 🧠 Description des modèles utilisés



1. **Base model**  

   Modèle BERT préentraîné `textattack/bert-base-uncased-ag-news` utilisé directement sans fine-tuning.



2. **Zero-shot**  

   Modèle `facebook/bart-large-mnli` utilisé pour classification zero-shot via pipeline Hugging Face.



3. **Few-shot**  

   Approche zero-shot avec exemples dans le prompt (prompt engineering).



4. **Fine-tuned model**  

   Modèle BERT `bert-base-uncased` entraîné sur un sous-ensemble équilibré du dataset AG News (3000 exemples par classe), sauvegardé sur Hugging Face Hub sous `Merwan611/agnews-finetuned-bert`.



---



## 📊 Données et entraînement



- **Dataset** : AG News (4 classes : World, Sports, Business, Sci/Tech)

- **Préprocessing** : tokenisation avec `AutoTokenizer` BERT

- **Entraînement** : 3 epochs, batch size 32, métrique optimisée : accuracy

- **Sous-échantillonnage** pour accélérer l’entraînement : 3000 exemples par classe pour le train, 1000 par classe pour le test



---



## 📈 Performances



| **Model**        | **Accuracy** | **F1 Score** | **Precision** | **Recall** | **Loss** |

| ---------------- | ------------ | ------------ | ------------- | ---------- | -------- |

| **Fine-tune**    | 0.92         | 0.92         | 0.92          | 0.92       | 0.28     |

| **Base model**   | 0.92         | 0.92         | 0.92          | 0.92       | 0.32     |

| **Zero-shot**    | 0.68         | 0.68         | 0.69          | 0.68       | 0.87     |

| **Few-shot**     | 0.87         | 0.87         | 0.87          | 0.87       | 4.74     |





Le modèle fine-tuné atteint généralement une meilleure précision que le modèle de base ou les approches zero-shot.



---



## ⚙️ Lancer l’application localement



1. Cloner le repo  

2. Créer un environnement virtuel Python  

3. Installer les dépendances :  

   ```bash

   pip install -r requirements.txt

4. Lancer python app.py"""


def predict_with_model(text, model_type):
    """

    Applique une stratégie de classification sur un texte donné.



    Args:

        text (str): Le texte d’actualité à analyser.

        model_type (str): Le modèle choisi ("Zero-shot", "Few-shot", etc.).



    Returns:

        tuple:

            - str: Catégorie prédite.

            - pandas.DataFrame: Score de confiance par classe.

    """
    if model_type == "Zero-shot":
        prediction, scores = zero_shot_inference(text)
    elif model_type == "Few-shot":
        prediction, scores = few_shot_inference(text)
    elif model_type == "Fine-tuned":
        prediction, scores = fine_tuned_inference(text)
    elif model_type == "Base model":
        prediction, scores = base_model_inference(text)
    else:
        return "Modèle inconnu", pd.DataFrame()

    # Convertir le dictionnaire des scores en DataFrame pour affichage
    scores_df = pd.DataFrame([
        {"Classe": label, "Score": score} for label, score in scores.items()
    ])
    return prediction, scores_df

# === Interface Gradio avec deux onglets ===
with gr.Blocks(title="Classification AG News (4 stratégies)") as app:

    gr.Markdown("# 📰 Classification de textes AG News")
    gr.Markdown("Cette application compare plusieurs approches NLP pour classer des actualités.")

    with gr.Tab("🧠 Inférence"):
        gr.Markdown("### ✍️ Entrez une phrase d'actualité à analyser")

        # Entrées utilisateur
        input_text = gr.Textbox(
            lines=4,
            placeholder="Ex: Apple lance un nouveau produit basé sur l'intelligence artificielle...",
            label="Texte à classifier"
        )

        model_choice = gr.Radio(
            choices=["Base model", "Zero-shot", "Few-shot", "Fine-tuned"],
            label="Choisir le modèle",
            value="Base model"
        )

        predict_button = gr.Button("🔍 Prédire")

        # Sorties
        label_output = gr.Label(label="🧾 Catégorie prédite")
        scores_output = gr.BarPlot(
            label="📊 Scores de confiance",
            x="Classe", y="Score", color="Classe"
        )

        # Action sur clic bouton
        predict_button.click(
            fn=predict_with_model,
            inputs=[input_text, model_choice],
            outputs=[label_output, scores_output]
        )

    with gr.Tab("📄 Documentation"):
        gr.Markdown(readme_content)

# Lancer l'app
if __name__ == "__main__":
    app.launch()