# Hugging Face Space status: Running
import gradio as gr | |
import librosa | |
import torch | |
from transformers import ( | |
Wav2Vec2ForCTC, Wav2Vec2Processor, | |
MarianMTModel, MarianTokenizer, | |
BertForSequenceClassification, AutoModel, AutoTokenizer,AutoModelForSequenceClassification | |
) | |
# Detect the device once; every model below is moved to it, and the inference
# helpers move their tokenized inputs to the same device.
device = "cuda" if torch.cuda.is_available() else "cpu"

### 🔹 Load Models & Tokenizers Once ###
# Everything is loaded at import time so each model is initialised once per
# process rather than on every request.

# Wav2Vec2 (CTC) for Darija speech-to-text
wav2vec_model_name = "boumehdi/wav2vec2-large-xlsr-moroccan-darija"
processor = Wav2Vec2Processor.from_pretrained(wav2vec_model_name)
wav2vec_model = Wav2Vec2ForCTC.from_pretrained(wav2vec_model_name).to(device)

# MarianMT for translation (Arabic → English)
translation_model_name = "Helsinki-NLP/opus-mt-ar-en"
translation_tokenizer = MarianTokenizer.from_pretrained(translation_model_name)
translation_model = MarianMTModel.from_pretrained(translation_model_name).to(device)

# AraBERT for Darija topic classification.
# NOTE(review): loaded with num_labels=2, while the Darija topic label list it
# is paired with has 11 entries, so only the first two topics can ever be
# predicted. "bert-base-arabert" looks like a base (pre-training) checkpoint,
# so the classification head is presumably randomly initialised and its
# predictions not meaningful until fine-tuned — confirm before relying on them.
arabert_model_name = "aubmindlab/bert-base-arabert"
arabert_tokenizer = AutoTokenizer.from_pretrained(arabert_model_name)
arabert_model = BertForSequenceClassification.from_pretrained(arabert_model_name, num_labels=2).to(device)

# BERT for English topic classification.
# NOTE(review): same caveat — num_labels=3 vs. 11 English topic labels, and
# "bert-base-uncased" is a pre-training checkpoint, so this head is presumably
# untrained.
bert_model_name = "bert-base-uncased"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = BertForSequenceClassification.from_pretrained(bert_model_name, num_labels=3).to(device)

# Load the Darija sentiment model and tokenizer.
# ignore_mismatched_sizes=True lets loading succeed even if the checkpoint's
# classification head does not have 3 outputs; in that case the head is
# re-initialised instead of loaded — TODO confirm the checkpoint has 3 classes.
sentiment_model_name = "BenhamdaneNawfal/sentiment-analysis-darija"
sentiment_tokenizer = AutoTokenizer.from_pretrained(sentiment_model_name)
sentiment_model = AutoModelForSequenceClassification.from_pretrained(
    sentiment_model_name, num_labels=3, ignore_mismatched_sizes=True
).to(device)  # same device as every other model, instead of re-deriving it

# Sentiment labels indexed by the model's class id (adjust to match the model
# in use): negative, neutral, positive.
sentiment_labels = ["Négatif", "Neutre", "Positif"]
# Topic labels as one (Darija, English) pair per topic. The Darija text is the
# Arabic script followed by a Latin-script transliteration in parentheses.
# Both public lists are derived from this single table so index i always
# refers to the same topic in each language.
_TOPIC_LABEL_PAIRS = [
    ("مشكيل ف الشبكة (Mochkil f réseau)", "Network Issue"),
    ("مشكيل ف الانترنت (Mochkil f internet)", "Internet Issue"),
    ("مشكيل ف الفاتورة (Mochkil f l'factura)", "Billing & Payment Issue"),
    ("مشكيل ف التعبئة (Mochkil f l'recharge)", "Recharge & Plans"),
    ("مشكيل ف التطبيق (Mochkil f l'application)", "App Issue"),  # e.g. the "Orange et Moi" app
    ("مشكيل ف بطاقة SIM (Mochkil f carte SIM)", "SIM Card Issue"),
    ("مساعدة تقنية (Mosa3ada technique)", "Technical Support"),
    ("العروض والتخفيضات (Offres w promotions)", "Offers & Promotions"),
    ("طلب معلومات (Talab l'ma3loumat)", "General Inquiry"),
    ("شكاية (Chikaya)", "Complaint"),
    ("حاجة أخرى (Chi haja okhra)", "Other"),
]

# Darija labels (Arabic script + Latin transliteration)
darija_topic_labels = [darija for darija, _english in _TOPIC_LABEL_PAIRS]
# English labels, aligned index-for-index with darija_topic_labels
english_topic_labels = [english for _darija, english in _TOPIC_LABEL_PAIRS]
# New Function to Classify Topics by Keywords | |
def classify_topic_by_keywords(text, language='ar'): | |
# Arabic keywords for each topic | |
arabic_keywords = { | |
"Customer Service": ["خدمة", "استفسار", "مساعدة", "دعم", "سؤال", "استفسار"], | |
"résiliation Service": ["نوقف", "تجديد", "خصم", "عرض", "نحي"], | |
"Billing Issue": ["فاتورة", "دفع", "مشكلة", "خطأ", "مبلغ"], | |
"Other": ["شيء آخر", "غير ذلك", "أخرى"] | |
} | |
# English keywords for each topic | |
english_keywords = { | |
"Customer Service": ["service", "inquiry", "help", "support", "question", "assistance"], | |
"résiliation Service": ["retain", "cut", "discount", "stopped", "promotion","stop"], | |
"Billing Issue": ["bill", "payment", "problem", "error", "amount"], | |
"Other": ["other", "none of the above", "something else"] | |
} | |
# Select the appropriate keywords based on the language | |
if language == 'ar': | |
keywords = arabic_keywords | |
elif language == 'en': | |
keywords = english_keywords | |
else: | |
raise ValueError("Invalid language specified. Use 'ar' for Arabic or 'en' for English.") | |
# Convert text to lowercase to avoid inconsistencies | |
text = text.lower() | |
# Check for keywords in the text and calculate the topic scores | |
topic_scores = {topic: 0 for topic in keywords} # Initialize topic scores | |
for topic, words in keywords.items(): | |
for word in words: | |
if word in text: | |
topic_scores[topic] += 1 # Increment score for each keyword found | |
# Check if no keywords are found, and in that case, return "Other" | |
if all(score == 0 for score in topic_scores.values()): | |
return "Other" | |
# Return the topic with the highest score | |
best_topic = max(topic_scores, key=topic_scores.get) | |
return best_topic | |
def transcribe_audio(audio):
    """Transcribe Darija audio, translate it, and classify topic and sentiment.

    Args:
        audio: Path to the audio file, as delivered by the
            ``gr.Audio(type="filepath")`` input; ``None`` when the user clicks
            Process without uploading or recording anything.

    Returns:
        A 7-tuple of strings, in the order the Gradio outputs expect:
        (transcription, English translation, Darija topic (AraBERT),
        English topic (BERT), Darija keyword topic, English keyword topic,
        sentiment). On failure the first element holds a message and the
        remaining six are empty strings, so the UI always receives 7 values.
    """
    # No audio supplied: tell the user instead of surfacing a loader error.
    if audio is None:
        return "Please upload or record an audio file first.", "", "", "", "", "", ""

    try:
        transcription = _transcribe_darija(audio)
        translation = translate_text(transcription)

        # Model-based topic classification
        darija_topic = classify_topic(transcription, arabert_tokenizer, arabert_model, darija_topic_labels)
        english_topic = classify_topic(translation, bert_tokenizer, bert_model, english_topic_labels)

        # Keyword-based topic classification
        darija_keyword_topic = classify_topic_by_keywords(transcription, language='ar')
        english_keyword_topic = classify_topic_by_keywords(translation, language='en')

        # Sentiment is computed on the original Darija transcription.
        sentiment = analyze_sentiment(transcription)

        return (
            transcription,
            translation,
            darija_topic,
            english_topic,
            darija_keyword_topic,
            english_keyword_topic,
            sentiment,
        )
    except Exception as e:
        # Top-level UI boundary: show the error in the interface instead of
        # crashing the handler, but keep the full traceback in the logs.
        import logging
        logging.getLogger(__name__).exception("Error processing audio")
        return f"Error processing audio: {str(e)}", "", "", "", "", "", ""


def _transcribe_darija(audio):
    """Load ``audio`` resampled to 16 kHz and return Wav2Vec2's greedy CTC transcription."""
    audio_array, _sample_rate = librosa.load(audio, sr=16000)
    input_values = processor(
        audio_array, sampling_rate=16000, return_tensors="pt", padding=True
    ).input_values.to(device)
    with torch.no_grad():
        logits = wav2vec_model(input_values).logits
    # Greedy decoding: most likely token per frame; decode() collapses CTC output.
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.decode(predicted_ids[0])
def translate_text(text):
    """Translate Arabic (Darija) text to English with the MarianMT model.

    Args:
        text: Source text. Input longer than the tokenizer's maximum length is
            truncated (``truncation=True``).

    Returns:
        The decoded English translation without special tokens. Returns ""
        for ``None`` or blank input without running the model. Generating
        from an empty source is wasted work and can yield spurious text.
    """
    if not text or not text.strip():
        return ""
    inputs = translation_tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        translated_tokens = translation_model.generate(**inputs)
    return translation_tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
def classify_topic(text, tokenizer, model, topic_labels):
    """Return the topic label predicted by a sequence-classification model.

    Args:
        text: Input text; truncated to at most 512 tokens.
        tokenizer: Tokenizer matching ``model``.
        model: Model whose output exposes ``.logits``; the highest-scoring
            class index along dim 1 selects the label.
        topic_labels: Labels indexed by class id.

    Returns:
        ``topic_labels[class_id]``, or ``"Other"`` when the predicted class id
        has no corresponding label.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    ).to(device)
    with torch.no_grad():
        scores = model(**encoded).logits
        class_id = scores.argmax(dim=1).item()
    if class_id >= len(topic_labels):
        return "Other"
    return topic_labels[class_id]
def analyze_sentiment(text):
    """Classify the sentiment of Darija text.

    Args:
        text: Darija text; truncated to at most 512 tokens.

    Returns:
        The entry of ``sentiment_labels`` for the predicted class, or
        ``"Inconnu"`` ("unknown") if the class index is outside that list.
    """
    # Use the module-level `device` rather than re-deriving and shadowing it;
    # it is the same device the sentiment model was loaded onto.
    inputs = sentiment_tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512).to(device)
    with torch.no_grad():
        outputs = sentiment_model(**inputs)
        predicted_class = torch.argmax(outputs.logits, dim=1).item()
    # Map the class index back to a label.
    return sentiment_labels[predicted_class] if predicted_class < len(sentiment_labels) else "Inconnu"
# 🔹 Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# 🎙️ Speech-to-Text, Translation & Topic Classification")
    audio_input = gr.Audio(type="filepath", label="Upload Audio or Record")
    submit_button = gr.Button("Process")

    transcription_output = gr.Textbox(label="Transcription (Darija)")
    translation_output = gr.Textbox(label="Translation (English)")
    darija_topic_output = gr.Textbox(label="Darija Topic Classification (BERT)")
    english_topic_output = gr.Textbox(label="English Topic Classification (BERT)")
    darija_keyword_topic_output = gr.Textbox(label="Darija Topic Classification (Keywords)")
    english_keyword_topic_output = gr.Textbox(label="English Topic Classification (Keywords)")
    sentiment_output = gr.Textbox(label="Sentiment (Darija)")

    # The outputs list must follow the order of the 7-tuple returned by
    # transcribe_audio.
    submit_button.click(
        transcribe_audio,
        inputs=[audio_input],
        outputs=[
            transcription_output,
            translation_output,
            darija_topic_output,
            english_topic_output,
            darija_keyword_topic_output,
            english_keyword_topic_output,
            sentiment_output,
        ],
    )

# Start the server only when this file is run as a script, so importing the
# module (e.g. from tests or tooling) does not launch it.
# NOTE(review): assumes the hosting Space runs this file directly
# (`python app.py`) — confirm.
if __name__ == "__main__":
    demo.launch()