Mohssinibra committed on
Commit d04bf8d · verified · 1 Parent(s): d0830a2
Files changed (1): app.py (+25 -5)
app.py CHANGED
@@ -1,12 +1,16 @@
 import gradio as gr
 import librosa
 import torch
-from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
+from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, MBartForConditionalGeneration, MBart50Tokenizer
 
 # Load pre-trained model and processor directly from Hugging Face Hub
 model = Wav2Vec2ForCTC.from_pretrained("boumehdi/wav2vec2-large-xlsr-moroccan-darija")
 processor = Wav2Vec2Processor.from_pretrained("boumehdi/wav2vec2-large-xlsr-moroccan-darija")
 
+# Load translation model
+translation_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
+translation_tokenizer = MBart50Tokenizer.from_pretrained("facebook/mbart-large-50-many-to-many-mmt", src_lang="ar")
+
 def transcribe_audio(audio):
     # Load the audio file from Gradio interface
     audio_array, sr = librosa.load(audio, sr=16000)
@@ -20,15 +24,31 @@ def transcribe_audio(audio):
     # Get the predicted tokens
     tokens = torch.argmax(logits, axis=-1)
 
-    # Decode the tokens into text
+    # Decode the tokens into text (Darija transcription)
     transcription = processor.decode(tokens[0])
 
-    return transcription
+    # Translate the transcription to English
+    translation = translate_text(transcription)
+
+    return transcription, translation
+
+def translate_text(text):
+    # Tokenize the text to translate
+    inputs = translation_tokenizer(text, return_tensors="pt")
+
+    # Generate translated tokens (from Darija to English)
+    translated_tokens = translation_model.generate(**inputs, forced_bos_token_id=translation_tokenizer.lang_code_to_id["en"])
+
+    # Decode the translated tokens into text
+    translated_text = translation_tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
+
+    return translated_text
 
 # Create a Gradio interface for uploading audio or recording from the browser
 demo = gr.Interface(fn=transcribe_audio,
                     inputs=gr.Audio(type="filepath"),  # Corrected input component
-                    outputs="text")
+                    outputs=["text", "text"],  # Both transcription and translation outputs
+                    live=True)
 
 demo.launch()
-demo.launch(api=True,share=True)
+demo.launch(api=True, share=True)
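The gap between the two hunks (file lines 17–23 are not shown) covers feature extraction and the forward pass. A minimal sketch of what that step typically looks like for a wav2vec2 CTC model, using the same checkpoint as app.py; the exact elided lines are an assumption, and "sample.wav" is a hypothetical input path:

import librosa
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

model = Wav2Vec2ForCTC.from_pretrained("boumehdi/wav2vec2-large-xlsr-moroccan-darija")
processor = Wav2Vec2Processor.from_pretrained("boumehdi/wav2vec2-large-xlsr-moroccan-darija")

# Hypothetical reconstruction of the elided inference step, not the committed lines
audio_array, sr = librosa.load("sample.wav", sr=16000)  # "sample.wav" is a placeholder

# Convert the raw waveform into model inputs
input_values = processor(audio_array, sampling_rate=16000, return_tensors="pt").input_values

# Forward pass without gradient tracking
with torch.no_grad():
    logits = model(input_values).logits

# Greedy CTC decoding: most likely token at each frame
tokens = torch.argmax(logits, dim=-1)
transcription = processor.decode(tokens[0])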
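One caveat on the new translation path: the facebook/mbart-large-50-many-to-many-mmt tokenizer uses FAIRSEQ-style language codes such as "ar_AR" and "en_XX", so the committed src_lang="ar" and lang_code_to_id["en"] would likely fail with a KeyError at generation time. A corrected sketch of translate_text under that assumption:

from transformers import MBartForConditionalGeneration, MBart50Tokenizer

translation_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
# FAIRSEQ-style codes: "ar_AR" for Arabic input, "en_XX" for English output
translation_tokenizer = MBart50Tokenizer.from_pretrained(
    "facebook/mbart-large-50-many-to-many-mmt", src_lang="ar_AR"
)

def translate_text(text):
    # Tokenize the source text
    inputs = translation_tokenizer(text, return_tensors="pt")
    # Force the decoder to start generating in English
    translated_tokens = translation_model.generate(
        **inputs,
        forced_bos_token_id=translation_tokenizer.lang_code_to_id["en_XX"],
    )
    return translation_tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]

Two smaller points: mBART-50 was trained on Modern Standard Arabic, so Darija transcriptions may translate imperfectly; and the trailing demo.launch(api=True, share=True) only runs after the first launch() returns, with api not appearing to be a launch() keyword (show_api is the current flag), so a single demo.launch(share=True) would likely suffice.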