Spaces:
Sleeping
Sleeping
customization
Browse files
app.py
CHANGED
@@ -4,7 +4,7 @@ import torch
|
|
4 |
from transformers import (
|
5 |
Wav2Vec2ForCTC, Wav2Vec2Processor,
|
6 |
MarianMTModel, MarianTokenizer,
|
7 |
-
BertForSequenceClassification,
|
8 |
)
|
9 |
|
10 |
# Detect device
|
@@ -24,12 +24,13 @@ translation_model = MarianMTModel.from_pretrained(translation_model_name).to(dev
|
|
24 |
# AraBERT for Darija topic classification
|
25 |
arabert_model_name = "aubmindlab/bert-base-arabert"
|
26 |
arabert_tokenizer = AutoTokenizer.from_pretrained(arabert_model_name)
|
27 |
-
arabert_model = BertForSequenceClassification.from_pretrained(arabert_model_name, num_labels=
|
|
|
28 |
|
29 |
# BERT for English topic classification
|
30 |
bert_model_name = "bert-base-uncased"
|
31 |
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
|
32 |
-
bert_model = BertForSequenceClassification.from_pretrained(bert_model_name, num_labels=
|
33 |
|
34 |
# Libellés en Darija (Arabe et Latin)
|
35 |
darija_topic_labels = [
|
@@ -61,6 +62,31 @@ english_topic_labels = [
|
|
61 |
"Other"
|
62 |
]
|
63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
def transcribe_audio(audio):
|
65 |
"""Convert audio to text, translate it, and classify topics in both Darija and English."""
|
66 |
try:
|
@@ -77,14 +103,18 @@ def transcribe_audio(audio):
|
|
77 |
# Translate to English
|
78 |
translation = translate_text(transcription)
|
79 |
|
80 |
-
# Classify topics
|
81 |
darija_topic = classify_topic(transcription, arabert_tokenizer, arabert_model, darija_topic_labels)
|
82 |
english_topic = classify_topic(translation, bert_tokenizer, bert_model, english_topic_labels)
|
83 |
|
84 |
-
|
|
|
|
|
|
|
|
|
85 |
|
86 |
except Exception as e:
|
87 |
-
return f"Error processing audio: {str(e)}", "", "", ""
|
88 |
|
89 |
def translate_text(text):
|
90 |
"""Translate Arabic text to English."""
|
@@ -111,11 +141,15 @@ with gr.Blocks() as demo:
|
|
111 |
|
112 |
transcription_output = gr.Textbox(label="Transcription (Darija)")
|
113 |
translation_output = gr.Textbox(label="Translation (English)")
|
114 |
-
darija_topic_output = gr.Textbox(label="Darija Topic Classification")
|
115 |
-
english_topic_output = gr.Textbox(label="English Topic Classification")
|
|
|
|
|
116 |
|
117 |
submit_button.click(transcribe_audio,
|
118 |
inputs=[audio_input],
|
119 |
-
outputs=[transcription_output, translation_output,
|
|
|
|
|
120 |
|
121 |
demo.launch()
|
|
|
4 |
from transformers import (
|
5 |
Wav2Vec2ForCTC, Wav2Vec2Processor,
|
6 |
MarianMTModel, MarianTokenizer,
|
7 |
+
BertForSequenceClassification, AutoModel, AutoTokenizer
|
8 |
)
|
9 |
|
10 |
# Detect device
|
|
|
24 |
# AraBERT for Darija topic classification
|
25 |
arabert_model_name = "aubmindlab/bert-base-arabert"
|
26 |
arabert_tokenizer = AutoTokenizer.from_pretrained(arabert_model_name)
|
27 |
+
arabert_model = BertForSequenceClassification.from_pretrained(arabert_model_name, num_labels=2).to(device)
|
28 |
+
|
29 |
|
30 |
# BERT for English topic classification
|
31 |
bert_model_name = "bert-base-uncased"
|
32 |
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
|
33 |
+
bert_model = BertForSequenceClassification.from_pretrained(bert_model_name, num_labels=3).to(device)
|
34 |
|
35 |
# Libellés en Darija (Arabe et Latin)
|
36 |
darija_topic_labels = [
|
|
|
62 |
"Other"
|
63 |
]
|
64 |
|
65 |
+
# New Function to Classify Topics by Keywords
|
66 |
+
def classify_topic_by_keywords(text, topic_labels):
|
67 |
+
# Dictionnaire de mots-clés pour chaque topic
|
68 |
+
keywords = {
|
69 |
+
"خدمة العملاء": ["خدمة", "استفسار", "مساعدة", "دعم", "سؤال", "استفسار"],
|
70 |
+
"خدمة الاحتفاظ": ["احتفاظ", "تجديد", "خصم", "عرض", "العرض"],
|
71 |
+
"مشكلة في الفاتورة": ["فاتورة", "دفع", "مشكلة", "خطأ", "مبلغ"]
|
72 |
+
}
|
73 |
+
|
74 |
+
# Convertir le texte en minuscule pour éviter les incohérences
|
75 |
+
text = text.lower()
|
76 |
+
|
77 |
+
# Vérification de la présence des mots-clés dans le texte
|
78 |
+
topic_scores = {label: 0 for label in topic_labels} # Initialiser le score des topics
|
79 |
+
|
80 |
+
for topic, words in keywords.items():
|
81 |
+
for word in words:
|
82 |
+
if word in text:
|
83 |
+
topic_scores[topic] += 1 # Incrémenter le score pour chaque mot trouvé
|
84 |
+
|
85 |
+
# Retourner le topic avec le score le plus élevé
|
86 |
+
best_topic = max(topic_scores, key=topic_scores.get)
|
87 |
+
return best_topic
|
88 |
+
|
89 |
+
|
90 |
def transcribe_audio(audio):
|
91 |
"""Convert audio to text, translate it, and classify topics in both Darija and English."""
|
92 |
try:
|
|
|
103 |
# Translate to English
|
104 |
translation = translate_text(transcription)
|
105 |
|
106 |
+
# Classify topics using BERT models
|
107 |
darija_topic = classify_topic(transcription, arabert_tokenizer, arabert_model, darija_topic_labels)
|
108 |
english_topic = classify_topic(translation, bert_tokenizer, bert_model, english_topic_labels)
|
109 |
|
110 |
+
# Classify topics using keywords-based classification
|
111 |
+
darija_keyword_topic = classify_topic_by_keywords(transcription, darija_topic_labels)
|
112 |
+
english_keyword_topic = classify_topic_by_keywords(translation, english_topic_labels)
|
113 |
+
|
114 |
+
return transcription, translation, darija_topic, english_topic, darija_keyword_topic, english_keyword_topic
|
115 |
|
116 |
except Exception as e:
|
117 |
+
return f"Error processing audio: {str(e)}", "", "", "", "", ""
|
118 |
|
119 |
def translate_text(text):
|
120 |
"""Translate Arabic text to English."""
|
|
|
141 |
|
142 |
transcription_output = gr.Textbox(label="Transcription (Darija)")
|
143 |
translation_output = gr.Textbox(label="Translation (English)")
|
144 |
+
darija_topic_output = gr.Textbox(label="Darija Topic Classification (BERT)")
|
145 |
+
english_topic_output = gr.Textbox(label="English Topic Classification (BERT)")
|
146 |
+
darija_keyword_topic_output = gr.Textbox(label="Darija Topic Classification (Keywords)")
|
147 |
+
english_keyword_topic_output = gr.Textbox(label="English Topic Classification (Keywords)")
|
148 |
|
149 |
submit_button.click(transcribe_audio,
|
150 |
inputs=[audio_input],
|
151 |
+
outputs=[transcription_output, translation_output,
|
152 |
+
darija_topic_output, english_topic_output,
|
153 |
+
darija_keyword_topic_output, english_keyword_topic_output])
|
154 |
|
155 |
demo.launch()
|