Mohssinibra commited on
Commit
2785b50
·
verified ·
1 Parent(s): 5f09d24
Files changed (1) hide show
  1. app.py +38 -47
app.py CHANGED
@@ -1,90 +1,82 @@
1
  import gradio as gr
2
  import librosa
3
  import torch
4
- from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, MarianMTModel, MarianTokenizer, BertForSequenceClassification, AutoModel, AutoTokenizer
5
-
6
- # Charger le modèle de transcription pour le Darija
7
- model = Wav2Vec2ForCTC.from_pretrained("boumehdi/wav2vec2-large-xlsr-moroccan-darija")
8
- processor = Wav2Vec2Processor.from_pretrained("boumehdi/wav2vec2-large-xlsr-moroccan-darija")
9
-
10
- # Charger le modèle de traduction Arabe -> Anglais
 
 
 
 
 
 
 
 
 
11
  translation_model_name = "Helsinki-NLP/opus-mt-ar-en"
12
- translation_model = MarianMTModel.from_pretrained(translation_model_name)
13
  translation_tokenizer = MarianTokenizer.from_pretrained(translation_model_name)
 
14
 
15
-
16
-
17
- # Load AraBERT for Darija topic classification
18
  arabert_model_name = "aubmindlab/bert-base-arabert"
19
  arabert_tokenizer = AutoTokenizer.from_pretrained(arabert_model_name)
20
- arabert_model = AutoModel.from_pretrained(arabert_model_name)
21
 
22
- # Load BERT for English topic classification
23
  bert_model_name = "bert-base-uncased"
24
  bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
25
- bert_model = BertForSequenceClassification.from_pretrained(bert_model_name, num_labels=2) # Adjust labels as needed
26
 
27
- darija_topic_labels = ["Customer Service", "Retention Service", "Billing Issue"] # Adjust for Darija topics
28
- english_topic_labels = ["Support Request", "Subscription Issue", "Payment Dispute"] # Adjust for English topics
29
-
30
-
31
- import torch
32
 
33
  def transcribe_audio(audio):
34
- """Convert audio to text, translate it, and classify topics in both Darija and English"""
35
  try:
36
  # Load and preprocess audio
37
  audio_array, sr = librosa.load(audio, sr=16000)
38
-
39
- # Ensure correct sampling rate
40
- input_values = processor(audio_array, sampling_rate=16000, return_tensors="pt", padding=True).input_values
41
 
42
- # Move to GPU if available
43
- device = "cuda" if torch.cuda.is_available() else "cpu"
44
- model.to(device)
45
- input_values = input_values.to(device)
46
-
47
- # Get predictions from Wav2Vec2 model
48
  with torch.no_grad():
49
- logits = model(input_values).logits
50
  tokens = torch.argmax(logits, axis=-1)
51
-
52
- # Decode transcription (Darija)
53
  transcription = processor.decode(tokens[0])
54
 
55
  # Translate to English
56
  translation = translate_text(transcription)
57
 
58
- # Classify topics for Darija and English
59
  darija_topic = classify_topic(transcription, arabert_tokenizer, arabert_model, darija_topic_labels)
60
  english_topic = classify_topic(translation, bert_tokenizer, bert_model, english_topic_labels)
61
 
62
  return transcription, translation, darija_topic, english_topic
63
 
64
  except Exception as e:
65
- print(f"Error in transcription: {e}")
66
- return "Error processing audio", "", "", ""
67
-
68
-
69
 
70
  def translate_text(text):
71
- """Traduire le texte de l'arabe vers l'anglais"""
72
- inputs = translation_tokenizer(text, return_tensors="pt", padding=True, truncation=True)
73
- translated_tokens = translation_model.generate(**inputs)
74
- translated_text = translation_tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
75
- return translated_text
76
 
77
  def classify_topic(text, tokenizer, model, topic_labels):
78
- """Classify topic using BERT-based models"""
79
- inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
80
  with torch.no_grad():
81
  outputs = model(**inputs)
82
  predicted_class = torch.argmax(outputs.logits, dim=1).item()
83
 
84
  return topic_labels[predicted_class] if predicted_class < len(topic_labels) else "Other"
85
 
86
-
87
- # Interface utilisateur avec Gradio
88
  with gr.Blocks() as demo:
89
  gr.Markdown("# 🎙️ Speech-to-Text, Translation & Topic Classification")
90
 
@@ -101,4 +93,3 @@ with gr.Blocks() as demo:
101
  outputs=[transcription_output, translation_output, darija_topic_output, english_topic_output])
102
 
103
  demo.launch()
104
-
 
1
  import gradio as gr
2
  import librosa
3
  import torch
4
+ from transformers import (
5
+ Wav2Vec2ForCTC, Wav2Vec2Processor,
6
+ MarianMTModel, MarianTokenizer,
7
+ BertForSequenceClassification, AutoModel, AutoTokenizer
8
+ )
9
+
10
+ # Detect device
11
+ device = "cuda" if torch.cuda.is_available() else "cpu"
12
+
13
+ ### 🔹 Load Models & Tokenizers Once ###
14
+ # Wav2Vec2 for Darija transcription
15
+ wav2vec_model_name = "boumehdi/wav2vec2-large-xlsr-moroccan-darija"
16
+ processor = Wav2Vec2Processor.from_pretrained(wav2vec_model_name)
17
+ wav2vec_model = Wav2Vec2ForCTC.from_pretrained(wav2vec_model_name).to(device)
18
+
19
+ # MarianMT for translation (Arabic → English)
20
  translation_model_name = "Helsinki-NLP/opus-mt-ar-en"
 
21
  translation_tokenizer = MarianTokenizer.from_pretrained(translation_model_name)
22
+ translation_model = MarianMTModel.from_pretrained(translation_model_name).to(device)
23
 
24
+ # AraBERT for Darija topic classification
 
 
25
  arabert_model_name = "aubmindlab/bert-base-arabert"
26
  arabert_tokenizer = AutoTokenizer.from_pretrained(arabert_model_name)
27
+ arabert_model = AutoModel.from_pretrained(arabert_model_name).to(device)
28
 
29
+ # BERT for English topic classification
30
  bert_model_name = "bert-base-uncased"
31
  bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
32
+ bert_model = BertForSequenceClassification.from_pretrained(bert_model_name, num_labels=3).to(device)
33
 
34
+ # Define Topic Labels
35
+ darija_topic_labels = ["Customer Service", "Retention Service", "Billing Issue"]
36
+ english_topic_labels = ["Support Request", "Subscription Issue", "Payment Dispute"]
 
 
37
 
38
  def transcribe_audio(audio):
39
+ """Convert audio to text, translate it, and classify topics in both Darija and English."""
40
  try:
41
  # Load and preprocess audio
42
  audio_array, sr = librosa.load(audio, sr=16000)
43
+ input_values = processor(audio_array, sampling_rate=16000, return_tensors="pt", padding=True).input_values.to(device)
 
 
44
 
45
+ # Transcription (Darija)
 
 
 
 
 
46
  with torch.no_grad():
47
+ logits = wav2vec_model(input_values).logits
48
  tokens = torch.argmax(logits, axis=-1)
 
 
49
  transcription = processor.decode(tokens[0])
50
 
51
  # Translate to English
52
  translation = translate_text(transcription)
53
 
54
+ # Classify topics
55
  darija_topic = classify_topic(transcription, arabert_tokenizer, arabert_model, darija_topic_labels)
56
  english_topic = classify_topic(translation, bert_tokenizer, bert_model, english_topic_labels)
57
 
58
  return transcription, translation, darija_topic, english_topic
59
 
60
  except Exception as e:
61
+ return f"Error processing audio: {str(e)}", "", "", ""
 
 
 
62
 
63
  def translate_text(text):
64
+ """Translate Arabic text to English."""
65
+ inputs = translation_tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
66
+ with torch.no_grad():
67
+ translated_tokens = translation_model.generate(**inputs)
68
+ return translation_tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
69
 
70
  def classify_topic(text, tokenizer, model, topic_labels):
71
+ """Classify topic using BERT-based models."""
72
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512).to(device)
73
  with torch.no_grad():
74
  outputs = model(**inputs)
75
  predicted_class = torch.argmax(outputs.logits, dim=1).item()
76
 
77
  return topic_labels[predicted_class] if predicted_class < len(topic_labels) else "Other"
78
 
79
+ # 🔹 Gradio Interface
 
80
  with gr.Blocks() as demo:
81
  gr.Markdown("# 🎙️ Speech-to-Text, Translation & Topic Classification")
82
 
 
93
  outputs=[transcription_output, translation_output, darija_topic_output, english_topic_output])
94
 
95
  demo.launch()