Spaces:
Sleeping
Sleeping
File size: 5,652 Bytes
0cebe35 7ae005b 9682a49 8f7bccf 9682a49 0cebe35 716abff 0cebe35 716abff 0cebe35 716abff 0cebe35 716abff 0cebe35 716abff 0cebe35 716abff 0cebe35 716abff 0cebe35 716abff 0cebe35 716abff 0cebe35 716abff 7ae005b 716abff 0cebe35 716abff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
import gradio as gr
import pandas as pd
from scripts.inference import (
zero_shot_inference,
few_shot_inference,
base_model_inference,
fine_tuned_inference
)
#Lire le README.md
readme_content = """# 📰 AG News Text Classification Demo
Ce projet présente une application de classification de textes d’actualité basée sur le dataset **AG News**.
L'objectif est de comparer plusieurs stratégies d'inférence de modèles Transformers pour la classification de texte.
---
## 🚀 Démo en ligne
L’application est disponible ici :
[**Lien vers la démo Hugging Face Space**](https://huggingface.co/spaces/Merwan611/classification-text)
---
## 📂 Organisation du projet
- `app.py` : interface Gradio avec deux onglets (`Demo` + `Documentation`)
- `scripts/inference.py` : fonctions d’inférence pour 4 types de modèles
- `scripts/train.py` : script d’entraînement du modèle BERT fine-tuné sur AG News
- `scripts/utils.py` : calcul des métriques d’évaluation (accuracy, F1, etc.)
- `requirements.txt` : liste des dépendances Python
---
## 🧠 Description des modèles utilisés
1. **Base model**
Modèle BERT préentraîné `textattack/bert-base-uncased-ag-news` utilisé directement sans fine-tuning.
2. **Zero-shot**
Modèle `facebook/bart-large-mnli` utilisé pour classification zero-shot via pipeline Hugging Face.
3. **Few-shot**
Approche zero-shot avec exemples dans le prompt (prompt engineering).
4. **Fine-tuned model**
Modèle BERT `bert-base-uncased` entraîné sur un sous-ensemble équilibré du dataset AG News (3000 exemples par classe), sauvegardé sur Hugging Face Hub sous `Merwan611/agnews-finetuned-bert`.
---
## 📊 Données et entraînement
- **Dataset** : AG News (4 classes : World, Sports, Business, Sci/Tech)
- **Préprocessing** : tokenisation avec `AutoTokenizer` BERT
- **Entraînement** : 3 epochs, batch size 32, métrique optimisée : accuracy
- **Sous-échantillonnage** pour accélérer l’entraînement : 3000 exemples par classe pour le train, 1000 par classe pour le test
---
## 📈 Performances
| **Model** | **Accuracy** | **F1 Score** | **Precision** | **Recall** | **Loss** |
| ---------------- | ------------ | ------------ | ------------- | ---------- | -------- |
| **Fine-tune** | 0.92 | 0.92 | 0.92 | 0.92 | 0.28 |
| **Base model** | 0.92 | 0.92 | 0.92 | 0.92 | 0.32 |
| **Zero-shot** | 0.68 | 0.68 | 0.69 | 0.68 | 0.87 |
| **Few-shot** | 0.87 | 0.87 | 0.87 | 0.87 | 4.74 |
Le modèle fine-tuné atteint généralement une meilleure précision que le modèle de base ou les approches zero-shot.
---
## ⚙️ Lancer l’application localement
1. Cloner le repo
2. Créer un environnement virtuel Python
3. Installer les dépendances :
```bash
pip install -r requirements.txt
4. Lancer python app.py"""
def predict_with_model(text, model_type):
"""
Applique une stratégie de classification sur un texte donné.
Args:
text (str): Le texte d’actualité à analyser.
model_type (str): Le modèle choisi ("Zero-shot", "Few-shot", etc.).
Returns:
tuple:
- str: Catégorie prédite.
- pandas.DataFrame: Score de confiance par classe.
"""
if model_type == "Zero-shot":
prediction, scores = zero_shot_inference(text)
elif model_type == "Few-shot":
prediction, scores = few_shot_inference(text)
elif model_type == "Fine-tuned":
prediction, scores = fine_tuned_inference(text)
elif model_type == "Base model":
prediction, scores = base_model_inference(text)
else:
return "Modèle inconnu", pd.DataFrame()
# Convertir le dictionnaire des scores en DataFrame pour affichage
scores_df = pd.DataFrame([
{"Classe": label, "Score": score} for label, score in scores.items()
])
return prediction, scores_df
# === Interface Gradio avec deux onglets ===
with gr.Blocks(title="Classification AG News (4 stratégies)") as app:
gr.Markdown("# 📰 Classification de textes AG News")
gr.Markdown("Cette application compare plusieurs approches NLP pour classer des actualités.")
with gr.Tab("🧠 Inférence"):
gr.Markdown("### ✍️ Entrez une phrase d'actualité à analyser")
# Entrées utilisateur
input_text = gr.Textbox(
lines=4,
placeholder="Ex: Apple lance un nouveau produit basé sur l'intelligence artificielle...",
label="Texte à classifier"
)
model_choice = gr.Radio(
choices=["Base model", "Zero-shot", "Few-shot", "Fine-tuned"],
label="Choisir le modèle",
value="Base model"
)
predict_button = gr.Button("🔍 Prédire")
# Sorties
label_output = gr.Label(label="🧾 Catégorie prédite")
scores_output = gr.BarPlot(
label="📊 Scores de confiance",
x="Classe", y="Score", color="Classe"
)
# Action sur clic bouton
predict_button.click(
fn=predict_with_model,
inputs=[input_text, model_choice],
outputs=[label_output, scores_output]
)
with gr.Tab("📄 Documentation"):
gr.Markdown(readme_content)
# Lancer l'app
if __name__ == "__main__":
app.launch()
|