Spaces:
Sleeping
Sleeping
import gradio as gr | |
import pandas as pd | |
from scripts.inference import ( | |
zero_shot_inference, | |
few_shot_inference, | |
base_model_inference, | |
fine_tuned_inference | |
) | |
#Lire le README.md | |
readme_content = """# 📰 AG News Text Classification Demo | |
Ce projet présente une application de classification de textes d’actualité basée sur le dataset **AG News**. | |
L'objectif est de comparer plusieurs stratégies d'inférence de modèles Transformers pour la classification de texte. | |
--- | |
## 🚀 Démo en ligne | |
L’application est disponible ici : | |
[**Lien vers la démo Hugging Face Space**](https://huggingface.co/spaces/Merwan611/classification-text) | |
--- | |
## 📂 Organisation du projet | |
- `app.py` : interface Gradio avec deux onglets (`Demo` + `Documentation`) | |
- `scripts/inference.py` : fonctions d’inférence pour 4 types de modèles | |
- `scripts/train.py` : script d’entraînement du modèle BERT fine-tuné sur AG News | |
- `scripts/utils.py` : calcul des métriques d’évaluation (accuracy, F1, etc.) | |
- `requirements.txt` : liste des dépendances Python | |
--- | |
## 🧠 Description des modèles utilisés | |
Base model | |
Modèle BERT préentraîné textattack/bert-base-uncased-ag-news. | |
Il est utilisé directement sans réentraînement. Le texte est tokenisé avec AutoTokenizer puis passé au modèle pour obtenir une distribution de probabilité via softmax. | |
Zero-shot | |
Modèle facebook/bart-large-mnli utilisé via la pipeline zero-shot-classification de Hugging Face. | |
Le texte est comparé à une liste de labels cibles (World, Sports, Business, Sci/Tech) sans aucun entraînement préalable sur AG News. Ce modèle s’appuie sur la reconnaissance d’implications textuelles pour inférer la classe la plus probable. | |
Few-shot | |
Basé sur le modèle google/flan-t5-small avec la pipeline text2text-generation. | |
Le prompt inclut quelques exemples de classification manuelle (prompt engineering). Le modèle génère ensuite une réponse textuelle correspondant à la catégorie. Les sorties sont nettoyées et validées par correspondance avec les labels autorisés. | |
Fine-tuned model | |
Modèle bert-base-uncased fine-tuné sur un sous-ensemble équilibré du dataset AG News (3000 exemples par classe) puis hébergé sur le Hugging Face Hub sous Merwan611/agnews-finetuned-bert. | |
La prédiction utilise également AutoTokenizer et applique une couche softmax sur les logits du modèle. L’accès au modèle peut nécessiter un token d’authentification via une variable d’environnement CLE. | |
--- | |
## 📊 Données et entraînement | |
- **Dataset** : AG News (4 classes : World, Sports, Business, Sci/Tech) | |
- **Préprocessing** : tokenisation avec `AutoTokenizer` BERT | |
- **Entraînement** : 3 epochs, batch size 32, métrique optimisée : accuracy | |
- **Sous-échantillonnage** pour accélérer l’entraînement : 3000 exemples par classe pour le train, 1000 par classe pour le test | |
--- | |
## 📈 Performances | |
| **Model** | **Accuracy** | **F1 Score** | **Precision** | **Recall** | **Loss** | | |
| ---------------- | ------------ | ------------ | ------------- | ---------- | -------- | | |
| **Fine-tune** | 0.92 | 0.92 | 0.92 | 0.92 | 0.28 | | |
| **Base model** | 0.92 | 0.92 | 0.92 | 0.92 | 0.32 | | |
| **Zero-shot** | 0.68 | 0.68 | 0.69 | 0.68 | 0.87 | | |
| **Few-shot** | 0.87 | 0.87 | 0.87 | 0.87 | 4.74 | | |
Le modèle fine-tuné atteint généralement une meilleure précision que le modèle de base ou les approches zero-shot. | |
--- | |
## ⚙️ Lancer l’application localement | |
1. Cloner le repo | |
2. Créer un environnement virtuel Python | |
3. Installer les dépendances : | |
```bash | |
pip install -r requirements.txt | |
4. Lancer python app.py""" | |
def predict_with_model(text, model_type): | |
""" | |
Applique une stratégie de classification sur un texte donné. | |
Args: | |
text (str): Le texte d’actualité à analyser. | |
model_type (str): Le modèle choisi ("Zero-shot", "Few-shot", etc.). | |
Returns: | |
tuple: | |
- str: Catégorie prédite. | |
- pandas.DataFrame: Score de confiance par classe. | |
""" | |
if model_type == "Zero-shot": | |
prediction, scores = zero_shot_inference(text) | |
elif model_type == "Few-shot": | |
prediction, scores = few_shot_inference(text) | |
elif model_type == "Fine-tuned": | |
prediction, scores = fine_tuned_inference(text) | |
elif model_type == "Base model": | |
prediction, scores = base_model_inference(text) | |
else: | |
return "Modèle inconnu", pd.DataFrame() | |
# Convertir le dictionnaire des scores en DataFrame pour affichage | |
scores_df = pd.DataFrame([ | |
{"Classe": label, "Score": score} for label, score in scores.items() | |
]) | |
return prediction, scores_df | |
# === Interface Gradio avec deux onglets === | |
with gr.Blocks(title="Classification AG News (4 stratégies)") as app: | |
gr.Markdown("# 📰 Classification de textes AG News") | |
gr.Markdown("Cette application compare plusieurs approches NLP pour classer des actualités.") | |
with gr.Tab("🧠 Inférence"): | |
gr.Markdown("### ✍️ Entrez une phrase d'actualité à analyser") | |
# Entrées utilisateur | |
input_text = gr.Textbox( | |
lines=4, | |
placeholder="Ex: Apple lance un nouveau produit basé sur l'intelligence artificielle...", | |
label="Texte à classifier" | |
) | |
model_choice = gr.Radio( | |
choices=["Base model", "Zero-shot", "Few-shot", "Fine-tuned"], | |
label="Choisir le modèle", | |
value="Base model" | |
) | |
predict_button = gr.Button("🔍 Prédire") | |
# Sorties | |
label_output = gr.Label(label="🧾 Catégorie prédite") | |
scores_output = gr.BarPlot( | |
label="📊 Scores de confiance", | |
x="Classe", y="Score", color="Classe" | |
) | |
# Action sur clic bouton | |
predict_button.click( | |
fn=predict_with_model, | |
inputs=[input_text, model_choice], | |
outputs=[label_output, scores_output] | |
) | |
with gr.Tab("📄 Documentation"): | |
gr.Markdown(readme_content) | |
# Lancer l'app | |
if __name__ == "__main__": | |
app.launch() | |