Mohssinibra committed on
Commit f2ecb6e · verified · 1 Parent(s): 9078685

classificationV1

Files changed (1):
  app.py  +43 -11
app.py CHANGED
@@ -1,7 +1,7 @@
 import gradio as gr
 import librosa
 import torch
-from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, MarianMTModel, MarianTokenizer
+from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, MarianMTModel, MarianTokenizer, BertForSequenceClassification, AutoModel, AutoTokenizer
 
 # Load the transcription model for Darija
 model = Wav2Vec2ForCTC.from_pretrained("boumehdi/wav2vec2-large-xlsr-moroccan-darija")
@@ -12,23 +12,39 @@ translation_model_name = "Helsinki-NLP/opus-mt-ar-en"
 translation_model = MarianMTModel.from_pretrained(translation_model_name)
 translation_tokenizer = MarianTokenizer.from_pretrained(translation_model_name)
 
+
+
+# Load AraBERT for Darija topic classification
+arabert_model_name = "aubmindlab/bert-base-arabert"
+arabert_tokenizer = AutoTokenizer.from_pretrained(arabert_model_name)
+arabert_model = AutoModel.from_pretrained(arabert_model_name)
+
+# Load BERT for English topic classification
+bert_model_name = "bert-base-uncased"
+bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
+bert_model = BertForSequenceClassification.from_pretrained(bert_model_name, num_labels=3)  # Adjust labels as needed
+
+darija_topic_labels = ["Customer Service", "Retention Service", "Billing Issue"]  # Adjust for Darija topics
+english_topic_labels = ["Support Request", "Subscription Issue", "Payment Dispute"]  # Adjust for English topics
+
+
 def transcribe_audio(audio):
-    """Convert the audio to text and translate it into English"""
-    # Load and preprocess the audio
+    """Convert audio to text, translate it, and classify topics in both Darija and English"""
     audio_array, sr = librosa.load(audio, sr=16000)
     input_values = processor(audio_array, return_tensors="pt", padding=True).input_values
 
-    # Get the model predictions
     logits = model(input_values).logits
     tokens = torch.argmax(logits, axis=-1)
 
-    # Decode the Darija transcription
     transcription = processor.decode(tokens[0])
-
-    # Translate into English
     translation = translate_text(transcription)
 
-    return transcription, translation
+    # Classify topics for both Darija and English
+    darija_topic = classify_topic(transcription, arabert_tokenizer, arabert_model, darija_topic_labels)
+    english_topic = classify_topic(translation, bert_tokenizer, bert_model, english_topic_labels)
+
+    return transcription, translation, darija_topic, english_topic
+
 
 def translate_text(text):
     """Translate the text from Arabic to English"""
@@ -37,15 +53,31 @@ def translate_text(text):
     translated_text = translation_tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
     return translated_text
 
+def classify_topic(text, tokenizer, model, topic_labels):
+    """Classify topic using BERT-based models"""
+    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
+    with torch.no_grad():
+        outputs = model(**inputs)
+    predicted_class = torch.argmax(outputs.logits, dim=1).item()
+
+    return topic_labels[predicted_class] if predicted_class < len(topic_labels) else "Other"
+
+
 # Gradio user interface
 with gr.Blocks() as demo:
-    gr.Markdown("# 🎙️ Speech-to-Text & Translation")
+    gr.Markdown("# 🎙️ Speech-to-Text, Translation & Topic Classification")
 
     audio_input = gr.Audio(type="filepath", label="Upload Audio or Record")
-    submit_button = gr.Button("Transcribe & Translate")
+    submit_button = gr.Button("Process")
+
     transcription_output = gr.Textbox(label="Transcription (Darija)")
     translation_output = gr.Textbox(label="Translation (English)")
+    darija_topic_output = gr.Textbox(label="Darija Topic Classification")
+    english_topic_output = gr.Textbox(label="English Topic Classification")
 
-    submit_button.click(transcribe_audio, inputs=[audio_input], outputs=[transcription_output, translation_output])
+    submit_button.click(transcribe_audio,
+                        inputs=[audio_input],
+                        outputs=[transcription_output, translation_output, darija_topic_output, english_topic_output])
 
 demo.launch()
+
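
A note on the new classification path: classify_topic() reads outputs.logits, but the Darija branch loads AraBERT through AutoModel, which returns the bare encoder without a classification head, so no logits attribute exists there; the English branch attaches a fresh, untrained 3-label head to bert-base-uncased. Below is a minimal sketch, assuming the intent is to keep classify_topic() exactly as committed, of loading the two checkpoints named in the diff with sequence-classification heads sized to the label lists; the head sizes and the example call are assumptions for illustration, not part of this commit.

# Sketch under the assumptions above -- the classification heads are randomly
# initialised until fine-tuned on labelled Darija/English call transcripts.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

darija_topic_labels = ["Customer Service", "Retention Service", "Billing Issue"]
english_topic_labels = ["Support Request", "Subscription Issue", "Payment Dispute"]

# AraBERT with a classification head sized to the Darija label list
arabert_tokenizer = AutoTokenizer.from_pretrained("aubmindlab/bert-base-arabert")
arabert_model = AutoModelForSequenceClassification.from_pretrained(
    "aubmindlab/bert-base-arabert", num_labels=len(darija_topic_labels)
)

# English BERT with a classification head sized to the English label list
bert_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
bert_model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=len(english_topic_labels)
)

def classify_topic(text, tokenizer, model, topic_labels):
    """Same contract as the committed function: return the highest-scoring label."""
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)  # SequenceClassifierOutput, so .logits exists for both models
    predicted_class = torch.argmax(outputs.logits, dim=1).item()
    return topic_labels[predicted_class] if predicted_class < len(topic_labels) else "Other"

# Illustrative call with a placeholder transcription (hypothetical input)
print(classify_topic("شكرا على المساعدة", arabert_tokenizer, arabert_model, darija_topic_labels))

Loaded this way, both models expose .logits and the committed classify_topic() runs on both branches, but the predicted topics remain arbitrary until each head is fine-tuned on labelled data.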