NightPrince committed
Commit 2726627 · verified · 1 Parent(s): c137890

Update app.py

Files changed (1)
  1. app.py +19 -9
app.py CHANGED
@@ -2,23 +2,28 @@ import gradio as gr
 from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
 import torch
 import numpy as np
+import librosa
 
 # Load the pre-trained model and processor
 model_name = "facebook/s2t-wav2vec2-large-en-ar"
 model = Wav2Vec2ForCTC.from_pretrained(model_name)
 processor = Wav2Vec2Processor.from_pretrained(model_name)
 
-# Define a function for the ASR model
+# Function to transcribe audio using the model
 def transcribe(audio):
-    # Convert the audio into a format compatible with the processor
-    if isinstance(audio, np.ndarray):
-        audio = audio.flatten()  # Ensure it's a 1D array
+    # Resample the audio to 16kHz if necessary
+    if audio.ndim > 1:  # If audio is stereo
+        audio = audio.mean(axis=1)  # Convert to mono
 
-    # Process the audio
+    # Resample audio to 16kHz if it's not already
+    audio = librosa.resample(audio, orig_sr=audio.shape[0] / len(audio), target_sr=16000)
+
+    # Process the audio to match the model's input format
     inputs = processor(audio, return_tensors="pt", sampling_rate=16000)
 
     # Get the model's predictions
-    logits = model(input_values=inputs.input_values).logits
+    with torch.no_grad():
+        logits = model(input_values=inputs.input_values).logits
 
     # Decode the predicted text
     predicted_ids = logits.argmax(dim=-1)
@@ -26,8 +31,13 @@ def transcribe(audio):
 
     return transcription
 
-# Define the Gradio interface
-interface = gr.Interface(fn=transcribe, inputs=gr.Audio(type="numpy"), outputs="text")
+# Create the Gradio interface
+interface = gr.Interface(
+    fn=transcribe,
+    inputs=gr.Audio(type="numpy"),  # Take the audio input as numpy array
+    outputs="text",  # Output transcribed text
+    live=True  # Optional: live transcribing as you speak
+)
 
-# Launch the Gradio interface
+# Launch the interface
 interface.launch()
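
As committed, transcribe still treats its argument as a bare NumPy array, but gr.Audio(type="numpy") hands the function a (sample_rate, data) tuple, and orig_sr=audio.shape[0] / len(audio) evaluates to 1.0 for any 1-D array, so the librosa.resample call never sees the real input rate. A minimal corrected sketch, assuming the model and processor globals loaded at the top of app.py and that the decode line elided between the two hunks uses processor.batch_decode:

import numpy as np
import torch
import librosa

def transcribe(audio):
    # gr.Audio(type="numpy") passes a (sample_rate, data) tuple, not a bare array
    sample_rate, data = audio

    # Gradio delivers integer PCM; convert to float32 in [-1, 1]
    data = data.astype(np.float32)
    if data.size and np.abs(data).max() > 1.0:
        data = data / 32768.0  # assumes int16 input samples

    # Stereo -> mono
    if data.ndim > 1:
        data = data.mean(axis=1)

    # Resample with the rate Gradio reports rather than deriving it from the array shape
    if sample_rate != 16000:
        data = librosa.resample(data, orig_sr=sample_rate, target_sr=16000)

    inputs = processor(data, sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        logits = model(input_values=inputs.input_values).logits
    predicted_ids = logits.argmax(dim=-1)
    # Assumed decode step, mirroring the line elided between the hunks
    return processor.batch_decode(predicted_ids)[0]

The torch.no_grad() wrapper added in this commit is kept: it skips gradient tracking during inference, which saves memory without changing the output.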
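
Separately, the checkpoint facebook/s2t-wav2vec2-large-en-ar is published as a speech translation model (English audio in, Arabic text out) built from a wav2vec2 encoder and an autoregressive text decoder, so loading it through Wav2Vec2ForCTC / Wav2Vec2Processor is unlikely to behave as intended. If this checkpoint really is the target, a hedged generation-based sketch, assuming the SpeechEncoderDecoderModel / Speech2Text2Processor classes its model card uses:

import torch
from transformers import SpeechEncoderDecoderModel, Speech2Text2Processor

model_name = "facebook/s2t-wav2vec2-large-en-ar"
model = SpeechEncoderDecoderModel.from_pretrained(model_name)
processor = Speech2Text2Processor.from_pretrained(model_name)

def translate_speech(waveform_16k):
    # waveform_16k: 1-D float32 array already resampled to 16 kHz
    inputs = processor(waveform_16k, sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        # Encoder-decoder checkpoints are driven through generate(), not a CTC argmax
        generated_ids = model.generate(inputs=inputs.input_values)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

Note that with this checkpoint the output is Arabic text, i.e. a translation rather than a transcription of the English speech.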