Merwan6
modif doc
719631b
import gradio as gr
import pandas as pd
from scripts.inference import (
zero_shot_inference,
few_shot_inference,
base_model_inference,
fine_tuned_inference
)
#Lire le README.md
readme_content = """# 📰 AG News Text Classification Demo
Ce projet présente une application de classification de textes d’actualité basée sur le dataset **AG News**.
L'objectif est de comparer plusieurs stratégies d'inférence de modèles Transformers pour la classification de texte.
---
## 🚀 Démo en ligne
L’application est disponible ici :
[**Lien vers la démo Hugging Face Space**](https://huggingface.co/spaces/Merwan611/classification-text)
---
## 📂 Organisation du projet
- `app.py` : interface Gradio avec deux onglets (`Demo` + `Documentation`)
- `scripts/inference.py` : fonctions d’inférence pour 4 types de modèles
- `scripts/train.py` : script d’entraînement du modèle BERT fine-tuné sur AG News
- `scripts/utils.py` : calcul des métriques d’évaluation (accuracy, F1, etc.)
- `requirements.txt` : liste des dépendances Python
---
## 🧠 Description des modèles utilisés
Base model
Modèle BERT préentraîné textattack/bert-base-uncased-ag-news.
Il est utilisé directement sans réentraînement. Le texte est tokenisé avec AutoTokenizer puis passé au modèle pour obtenir une distribution de probabilité via softmax.
Zero-shot
Modèle facebook/bart-large-mnli utilisé via la pipeline zero-shot-classification de Hugging Face.
Le texte est comparé à une liste de labels cibles (World, Sports, Business, Sci/Tech) sans aucun entraînement préalable sur AG News. Ce modèle s’appuie sur la reconnaissance d’implications textuelles pour inférer la classe la plus probable.
Few-shot
Basé sur le modèle google/flan-t5-small avec la pipeline text2text-generation.
Le prompt inclut quelques exemples de classification manuelle (prompt engineering). Le modèle génère ensuite une réponse textuelle correspondant à la catégorie. Les sorties sont nettoyées et validées par correspondance avec les labels autorisés.
Fine-tuned model
Modèle bert-base-uncased fine-tuné sur un sous-ensemble équilibré du dataset AG News (3000 exemples par classe) puis hébergé sur le Hugging Face Hub sous Merwan611/agnews-finetuned-bert.
La prédiction utilise également AutoTokenizer et applique une couche softmax sur les logits du modèle. L’accès au modèle peut nécessiter un token d’authentification via une variable d’environnement CLE.
---
## 📊 Données et entraînement
- **Dataset** : AG News (4 classes : World, Sports, Business, Sci/Tech)
- **Préprocessing** : tokenisation avec `AutoTokenizer` BERT
- **Entraînement** : 3 epochs, batch size 32, métrique optimisée : accuracy
- **Sous-échantillonnage** pour accélérer l’entraînement : 3000 exemples par classe pour le train, 1000 par classe pour le test
---
## 📈 Performances
| **Model** | **Accuracy** | **F1 Score** | **Precision** | **Recall** | **Loss** |
| ---------------- | ------------ | ------------ | ------------- | ---------- | -------- |
| **Fine-tune** | 0.92 | 0.92 | 0.92 | 0.92 | 0.28 |
| **Base model** | 0.92 | 0.92 | 0.92 | 0.92 | 0.32 |
| **Zero-shot** | 0.68 | 0.68 | 0.69 | 0.68 | 0.87 |
| **Few-shot** | 0.87 | 0.87 | 0.87 | 0.87 | 4.74 |
Le modèle fine-tuné atteint généralement une meilleure précision que le modèle de base ou les approches zero-shot.
---
## ⚙️ Lancer l’application localement
1. Cloner le repo
2. Créer un environnement virtuel Python
3. Installer les dépendances :
```bash
pip install -r requirements.txt
4. Lancer python app.py"""
def predict_with_model(text, model_type):
"""
Applique une stratégie de classification sur un texte donné.
Args:
text (str): Le texte d’actualité à analyser.
model_type (str): Le modèle choisi ("Zero-shot", "Few-shot", etc.).
Returns:
tuple:
- str: Catégorie prédite.
- pandas.DataFrame: Score de confiance par classe.
"""
if model_type == "Zero-shot":
prediction, scores = zero_shot_inference(text)
elif model_type == "Few-shot":
prediction, scores = few_shot_inference(text)
elif model_type == "Fine-tuned":
prediction, scores = fine_tuned_inference(text)
elif model_type == "Base model":
prediction, scores = base_model_inference(text)
else:
return "Modèle inconnu", pd.DataFrame()
# Convertir le dictionnaire des scores en DataFrame pour affichage
scores_df = pd.DataFrame([
{"Classe": label, "Score": score} for label, score in scores.items()
])
return prediction, scores_df
# === Interface Gradio avec deux onglets ===
with gr.Blocks(title="Classification AG News (4 stratégies)") as app:
gr.Markdown("# 📰 Classification de textes AG News")
gr.Markdown("Cette application compare plusieurs approches NLP pour classer des actualités.")
with gr.Tab("🧠 Inférence"):
gr.Markdown("### ✍️ Entrez une phrase d'actualité à analyser")
# Entrées utilisateur
input_text = gr.Textbox(
lines=4,
placeholder="Ex: Apple lance un nouveau produit basé sur l'intelligence artificielle...",
label="Texte à classifier"
)
model_choice = gr.Radio(
choices=["Base model", "Zero-shot", "Few-shot", "Fine-tuned"],
label="Choisir le modèle",
value="Base model"
)
predict_button = gr.Button("🔍 Prédire")
# Sorties
label_output = gr.Label(label="🧾 Catégorie prédite")
scores_output = gr.BarPlot(
label="📊 Scores de confiance",
x="Classe", y="Score", color="Classe"
)
# Action sur clic bouton
predict_button.click(
fn=predict_with_model,
inputs=[input_text, model_choice],
outputs=[label_output, scores_output]
)
with gr.Tab("📄 Documentation"):
gr.Markdown(readme_content)
# Lancer l'app
if __name__ == "__main__":
app.launch()