# STTDARIJAAPI / app.py
# Last commit: 3f0375d "Update app.py" by Mohssinibra (verified)
import logging
import re

import gradio as gr
import librosa
import torch
from transformers import (
    AutoModel,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BertForSequenceClassification,
    MarianMTModel,
    MarianTokenizer,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
)
# Run every model on the GPU when PyTorch can see one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

### 🔹 Load Models & Tokenizers Once ###
# Hub checkpoint identifiers.
wav2vec_model_name = "boumehdi/wav2vec2-large-xlsr-moroccan-darija"   # Darija speech recognition
translation_model_name = "Helsinki-NLP/opus-mt-ar-en"                  # Arabic -> English MT
arabert_model_name = "aubmindlab/bert-base-arabert"                    # Darija topic classifier
bert_model_name = "bert-base-uncased"                                  # English topic classifier

# Speech recognition: the processor prepares raw audio and decodes CTC token ids.
processor = Wav2Vec2Processor.from_pretrained(wav2vec_model_name)
wav2vec_model = Wav2Vec2ForCTC.from_pretrained(wav2vec_model_name).to(device)

# Machine translation (MarianMT).
translation_tokenizer = MarianTokenizer.from_pretrained(translation_model_name)
translation_model = MarianMTModel.from_pretrained(translation_model_name).to(device)


def _load_sequence_classifier(model_name, num_labels):
    """Load a tokenizer and a BERT sequence classifier (moved to ``device``) for *model_name*."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = BertForSequenceClassification.from_pretrained(model_name, num_labels=num_labels).to(device)
    return tokenizer, model


# NOTE(review): num_labels here (2 and 3) is much smaller than the 11 topic
# labels defined further down, so only the first 2 Darija / first 3 English
# labels can ever be predicted. These also look like base (not fine-tuned)
# checkpoints, in which case the classification head is freshly initialized
# and its predictions are not meaningful until fine-tuned — confirm.
arabert_tokenizer, arabert_model = _load_sequence_classifier(arabert_model_name, num_labels=2)
bert_tokenizer, bert_model = _load_sequence_classifier(bert_model_name, num_labels=3)
# Darija sentiment classifier: tokenizer, model and the names of its output classes.
sentiment_model_name = "BenhamdaneNawfal/sentiment-analysis-darija"
sentiment_tokenizer = AutoTokenizer.from_pretrained(sentiment_model_name)
sentiment_model = AutoModelForSequenceClassification.from_pretrained(
    sentiment_model_name,
    num_labels=3,
    # NOTE(review): if the checkpoint's head does not have 3 outputs, this flag
    # makes transformers re-initialize the head instead of raising, which would
    # silently yield untrained predictions — confirm the checkpoint has 3 labels.
    ignore_mismatched_sizes=True,
).to(device)

# Class index -> label. This must follow the checkpoint's own label order
# (the original author noted it should be adjusted to the model used).
sentiment_labels = ["Négatif", "Neutre", "Positif"]
# Topic label table: one row per topic, as
# (Darija in Arabic script, Darija in Latin script, English).
# Both label lists below are derived from it, so index i names the same topic in each.
_TOPIC_LABEL_TABLE = [
    ("مشكيل ف الشبكة", "Mochkil f réseau", "Network Issue"),
    ("مشكيل ف الانترنت", "Mochkil f internet", "Internet Issue"),
    ("مشكيل ف الفاتورة", "Mochkil f l'factura", "Billing & Payment Issue"),
    ("مشكيل ف التعبئة", "Mochkil f l'recharge", "Recharge & Plans"),
    ("مشكيل ف التطبيق", "Mochkil f l'application", "App Issue"),  # e.g. the "Orange et Moi" app
    ("مشكيل ف بطاقة SIM", "Mochkil f carte SIM", "SIM Card Issue"),
    ("مساعدة تقنية", "Mosa3ada technique", "Technical Support"),
    ("العروض والتخفيضات", "Offres w promotions", "Offers & Promotions"),
    ("طلب معلومات", "Talab l'ma3loumat", "General Inquiry"),
    ("شكاية", "Chikaya", "Complaint"),
    ("حاجة أخرى", "Chi haja okhra", "Other"),
]

# Darija labels rendered as "<Arabic script> (<Latin script>)".
darija_topic_labels = [f"{arabic} ({latin})" for arabic, latin, _ in _TOPIC_LABEL_TABLE]
# English labels, in the same order.
english_topic_labels = [english for _, _, english in _TOPIC_LABEL_TABLE]
# Keyword tables for the rule-based topic classifier (topic -> keywords).
# Every keyword found in the text adds one point to its topic.
_ARABIC_TOPIC_KEYWORDS = {
    "Customer Service": ("خدمة", "استفسار", "مساعدة", "دعم", "سؤال"),
    "résiliation Service": ("نوقف", "تجديد", "خصم", "عرض", "نحي"),
    "Billing Issue": ("فاتورة", "دفع", "مشكلة", "خطأ", "مبلغ"),
    "Other": ("شيء آخر", "غير ذلك", "أخرى"),
}
_ENGLISH_TOPIC_KEYWORDS = {
    "Customer Service": ("service", "inquiry", "help", "support", "question", "assistance"),
    # "stop" also matches "stopped", so that word scores twice for this topic.
    "résiliation Service": ("retain", "cut", "discount", "stopped", "promotion", "stop"),
    "Billing Issue": ("bill", "payment", "problem", "error", "amount"),
    "Other": ("other", "none of the above", "something else"),
}


def _contains_arabic_keyword(text, keyword):
    """Plain substring match, so attached prefixes (e.g. "ال" in "الفاتورة") still match."""
    return keyword in text


def _contains_english_keyword(text, keyword):
    """Match *keyword* only at the start of a word ("cut" hits "cutting" but not "execute")."""
    return re.search(r"\b" + re.escape(keyword), text) is not None


def classify_topic_by_keywords(text, language='ar'):
    """Classify *text* into a topic by counting keyword hits.

    Args:
        text: Text to classify. ``None`` or ``""`` is classified as ``"Other"``.
        language: ``'ar'`` for the Arabic keyword table or ``'en'`` for the
            English one.

    Returns:
        The topic with the most keyword hits. Ties go to the topic listed first
        in the keyword table. Returns ``"Other"`` when no keyword matches.

    Raises:
        ValueError: If *language* is neither ``'ar'`` nor ``'en'``.
    """
    if language == 'ar':
        keywords, matches = _ARABIC_TOPIC_KEYWORDS, _contains_arabic_keyword
    elif language == 'en':
        keywords, matches = _ENGLISH_TOPIC_KEYWORDS, _contains_english_keyword
    else:
        raise ValueError("Invalid language specified. Use 'ar' for Arabic or 'en' for English.")

    if not text:
        return "Other"
    # Lowercase so English keywords match regardless of capitalization.
    text = text.lower()

    topic_scores = {
        topic: sum(1 for word in words if matches(text, word))
        for topic, words in keywords.items()
    }
    if not any(topic_scores.values()):
        return "Other"
    # max() keeps the first maximal key, i.e. dict (insertion) order breaks ties.
    return max(topic_scores, key=topic_scores.get)
def _transcribe_darija(audio_path):
    """Transcribe the audio file at *audio_path* with the wav2vec2 CTC model (greedy decoding)."""
    # The recognizer is fed 16 kHz audio; librosa resamples while loading.
    audio_array, _ = librosa.load(audio_path, sr=16000)
    input_values = processor(
        audio_array, sampling_rate=16000, return_tensors="pt", padding=True
    ).input_values.to(device)
    with torch.no_grad():
        logits = wav2vec_model(input_values).logits
    # Greedy CTC: take the most likely token at every frame, then let the processor decode.
    token_ids = torch.argmax(logits, dim=-1)
    return processor.decode(token_ids[0])


def transcribe_audio(audio):
    """Transcribe a Darija recording, translate it, and classify it.

    This is the Gradio callback for the "Process" button.

    Args:
        audio: Path to the uploaded or recorded audio file, or ``None`` when the
            user submitted without providing audio.

    Returns:
        A 7-tuple of strings, in this order: transcription, English translation,
        Darija topic (model), English topic (model), Darija topic (keywords),
        English topic (keywords), sentiment. On failure the first element holds
        an ``"Error processing audio: ..."`` message and the rest are empty.
    """
    if audio is None:
        return "Error processing audio: no audio provided", "", "", "", "", "", ""
    try:
        transcription = _transcribe_darija(audio)
        translation = translate_text(transcription)

        # Model-based topic classification of the Darija text and of its translation.
        darija_topic = classify_topic(transcription, arabert_tokenizer, arabert_model, darija_topic_labels)
        english_topic = classify_topic(translation, bert_tokenizer, bert_model, english_topic_labels)

        # Rule-based (keyword) topic classification.
        darija_keyword_topic = classify_topic_by_keywords(transcription, language='ar')
        english_keyword_topic = classify_topic_by_keywords(translation, language='en')

        sentiment = analyze_sentiment(transcription)
    except Exception as e:
        # UI boundary: show the failure in the first output box instead of
        # breaking the callback, but keep the full traceback in the logs.
        logging.getLogger(__name__).exception("Failed to process audio %r", audio)
        return f"Error processing audio: {e}", "", "", "", "", "", ""
    return (
        transcription,
        translation,
        darija_topic,
        english_topic,
        darija_keyword_topic,
        english_keyword_topic,
        sentiment,
    )
def translate_text(text):
    """Translate Arabic text to English with the MarianMT model.

    Input longer than the tokenizer's maximum length is truncated. Only the
    first generated sequence is returned, with special tokens stripped.
    """
    batch = translation_tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    batch = batch.to(device)
    with torch.no_grad():
        generated = translation_model.generate(**batch)
    first_sequence = generated[0]
    return translation_tokenizer.decode(first_sequence, skip_special_tokens=True)
def classify_topic(text, tokenizer, model, topic_labels):
    """Return the label a sequence-classification model predicts for *text*.

    Args:
        text: Text to classify (truncated to 512 tokens).
        tokenizer: Tokenizer matching *model*.
        model: Sequence-classification model whose ``logits`` are scored.
        topic_labels: Label names indexed by class id.

    Returns:
        ``topic_labels[i]`` for the highest-scoring class ``i``, or ``"Other"``
        when ``i`` has no entry in *topic_labels*.
    """
    encoded = tokenizer(
        text, return_tensors="pt", truncation=True, padding=True, max_length=512
    ).to(device)
    with torch.no_grad():
        scores = model(**encoded).logits
    best_index = scores.argmax(dim=1).item()
    if best_index >= len(topic_labels):
        return "Other"
    return topic_labels[best_index]
def analyze_sentiment(text):
    """Classify the sentiment of Darija text.

    The text is truncated to 512 tokens before scoring.

    Returns:
        The entry of ``sentiment_labels`` for the highest-scoring class, or
        ``"Inconnu"`` ("unknown") when that class index has no label.
    """
    encoded = sentiment_tokenizer(
        text, return_tensors="pt", truncation=True, padding=True, max_length=512
    ).to(device)
    with torch.no_grad():
        logits = sentiment_model(**encoded).logits
    label_index = logits.argmax(dim=1).item()
    if label_index < len(sentiment_labels):
        return sentiment_labels[label_index]
    return "Inconnu"
# 🔹 Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# 🎙️ Speech-to-Text, Translation & Topic Classification")
    audio_input = gr.Audio(type="filepath", label="Upload Audio or Record")
    submit_button = gr.Button("Process")

    transcription_output = gr.Textbox(label="Transcription (Darija)")
    translation_output = gr.Textbox(label="Translation (English)")
    darija_topic_output = gr.Textbox(label="Darija Topic Classification (BERT)")
    english_topic_output = gr.Textbox(label="English Topic Classification (BERT)")
    darija_keyword_topic_output = gr.Textbox(label="Darija Topic Classification (Keywords)")
    english_keyword_topic_output = gr.Textbox(label="English Topic Classification (Keywords)")
    sentiment_output = gr.Textbox(label="Sentiment (Darija)")

    # The outputs list must follow the order of the 7-tuple returned by transcribe_audio.
    submit_button.click(
        transcribe_audio,
        inputs=[audio_input],
        outputs=[
            transcription_output,
            translation_output,
            darija_topic_output,
            english_topic_output,
            darija_keyword_topic_output,
            english_keyword_topic_output,
            sentiment_output,
        ],
    )

# Start the web server only when run as a script (e.g. `python app.py`), so
# importing this module does not launch the app as a side effect.
if __name__ == "__main__":
    demo.launch()