Mohssinibra committed on
Commit
4a98e32
·
verified ·
1 Parent(s): 6349c25

customization

Browse files
Files changed (1) hide show
  1. app.py +43 -9
app.py CHANGED
@@ -4,7 +4,7 @@ import torch
4
  from transformers import (
5
  Wav2Vec2ForCTC, Wav2Vec2Processor,
6
  MarianMTModel, MarianTokenizer,
7
- BertForSequenceClassification, AutoTokenizer, AutoModel
8
  )
9
 
10
  # Detect device
@@ -24,12 +24,13 @@ translation_model = MarianMTModel.from_pretrained(translation_model_name).to(dev
24
  # AraBERT for Darija topic classification
25
  arabert_model_name = "aubmindlab/bert-base-arabert"
26
  arabert_tokenizer = AutoTokenizer.from_pretrained(arabert_model_name)
27
- arabert_model = BertForSequenceClassification.from_pretrained(arabert_model_name, num_labels=11).to(device) # Adjusted to 11 labels for Darija
 
28
 
29
  # BERT for English topic classification
30
  bert_model_name = "bert-base-uncased"
31
  bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
32
- bert_model = BertForSequenceClassification.from_pretrained(bert_model_name, num_labels=11).to(device) # Adjusted to 11 labels for English
33
 
34
  # Libellés en Darija (Arabe et Latin)
35
  darija_topic_labels = [
@@ -61,6 +62,31 @@ english_topic_labels = [
61
  "Other"
62
  ]
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  def transcribe_audio(audio):
65
  """Convert audio to text, translate it, and classify topics in both Darija and English."""
66
  try:
@@ -77,14 +103,18 @@ def transcribe_audio(audio):
77
  # Translate to English
78
  translation = translate_text(transcription)
79
 
80
- # Classify topics
81
  darija_topic = classify_topic(transcription, arabert_tokenizer, arabert_model, darija_topic_labels)
82
  english_topic = classify_topic(translation, bert_tokenizer, bert_model, english_topic_labels)
83
 
84
- return transcription, translation, darija_topic, english_topic
 
 
 
 
85
 
86
  except Exception as e:
87
- return f"Error processing audio: {str(e)}", "", "", ""
88
 
89
  def translate_text(text):
90
  """Translate Arabic text to English."""
@@ -111,11 +141,15 @@ with gr.Blocks() as demo:
111
 
112
  transcription_output = gr.Textbox(label="Transcription (Darija)")
113
  translation_output = gr.Textbox(label="Translation (English)")
114
- darija_topic_output = gr.Textbox(label="Darija Topic Classification")
115
- english_topic_output = gr.Textbox(label="English Topic Classification")
 
 
116
 
117
  submit_button.click(transcribe_audio,
118
  inputs=[audio_input],
119
- outputs=[transcription_output, translation_output, darija_topic_output, english_topic_output])
 
 
120
 
121
  demo.launch()
 
4
  from transformers import (
5
  Wav2Vec2ForCTC, Wav2Vec2Processor,
6
  MarianMTModel, MarianTokenizer,
7
+ BertForSequenceClassification, AutoModel, AutoTokenizer
8
  )
9
 
10
  # Detect device
 
24
  # AraBERT for Darija topic classification
25
  arabert_model_name = "aubmindlab/bert-base-arabert"
26
  arabert_tokenizer = AutoTokenizer.from_pretrained(arabert_model_name)
27
+ arabert_model = BertForSequenceClassification.from_pretrained(arabert_model_name, num_labels=2).to(device)
28
+
29
 
30
  # BERT for English topic classification
31
  bert_model_name = "bert-base-uncased"
32
  bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
33
+ bert_model = BertForSequenceClassification.from_pretrained(bert_model_name, num_labels=3).to(device)
34
 
35
  # Libellés en Darija (Arabe et Latin)
36
  darija_topic_labels = [
 
62
  "Other"
63
  ]
64
 
65
# Keyword-based topic classification (complements the BERT classifiers).
def classify_topic_by_keywords(text, topic_labels):
    """Return the label from *topic_labels* whose keywords best match *text*.

    Each topic in the internal keyword table earns one point per distinct
    keyword that occurs as a substring of the lower-cased text. Only topics
    that are present in *topic_labels* are scored, so the result is always
    one of the supplied labels.

    Args:
        text: Text to classify. ``None`` is treated as an empty string.
        topic_labels: Candidate labels; their order decides ties.

    Returns:
        The highest-scoring label. When several labels tie (including the
        case where no keyword matched at all), the first of them in
        *topic_labels* is returned.

    Raises:
        ValueError: If *topic_labels* is empty.
    """
    # Keyword table for each topic; keys are the Arabic topic labels.
    keywords = {
        "خدمة العملاء": ["خدمة", "استفسار", "مساعدة", "دعم", "سؤال"],
        "خدمة الاحتفاظ": ["احتفاظ", "تجديد", "خصم", "عرض", "العرض"],
        "مشكلة في الفاتورة": ["فاتورة", "دفع", "مشكلة", "خطأ", "مبلغ"],
    }

    # Lower-case the text so Latin-script input is matched consistently;
    # tolerate a missing transcription/translation.
    text = (text or "").lower()

    # Every candidate label starts at zero.
    topic_scores = {label: 0 for label in topic_labels}
    if not topic_scores:
        raise ValueError("topic_labels must not be empty")

    for topic, words in keywords.items():
        # Skip keyword topics the caller did not ask about; indexing them
        # would raise KeyError (e.g. Arabic topics vs. English labels).
        if topic not in topic_scores:
            continue
        # set() guards against a keyword listed twice being counted twice.
        topic_scores[topic] += sum(1 for word in set(words) if word in text)

    # max() returns the first maximal key, i.e. ties follow label order.
    return max(topic_scores, key=topic_scores.get)
88
+
89
+
90
  def transcribe_audio(audio):
91
  """Convert audio to text, translate it, and classify topics in both Darija and English."""
92
  try:
 
103
  # Translate to English
104
  translation = translate_text(transcription)
105
 
106
+ # Classify topics using BERT models
107
  darija_topic = classify_topic(transcription, arabert_tokenizer, arabert_model, darija_topic_labels)
108
  english_topic = classify_topic(translation, bert_tokenizer, bert_model, english_topic_labels)
109
 
110
+ # Classify topics using keywords-based classification
111
+ darija_keyword_topic = classify_topic_by_keywords(transcription, darija_topic_labels)
112
+ english_keyword_topic = classify_topic_by_keywords(translation, english_topic_labels)
113
+
114
+ return transcription, translation, darija_topic, english_topic, darija_keyword_topic, english_keyword_topic
115
 
116
  except Exception as e:
117
+ return f"Error processing audio: {str(e)}", "", "", "", "", ""
118
 
119
  def translate_text(text):
120
  """Translate Arabic text to English."""
 
141
 
142
  transcription_output = gr.Textbox(label="Transcription (Darija)")
143
  translation_output = gr.Textbox(label="Translation (English)")
144
+ darija_topic_output = gr.Textbox(label="Darija Topic Classification (BERT)")
145
+ english_topic_output = gr.Textbox(label="English Topic Classification (BERT)")
146
+ darija_keyword_topic_output = gr.Textbox(label="Darija Topic Classification (Keywords)")
147
+ english_keyword_topic_output = gr.Textbox(label="English Topic Classification (Keywords)")
148
 
149
  submit_button.click(transcribe_audio,
150
  inputs=[audio_input],
151
+ outputs=[transcription_output, translation_output,
152
+ darija_topic_output, english_topic_output,
153
+ darija_keyword_topic_output, english_keyword_topic_output])
154
 
155
  demo.launch()