NightPrince committed · Commit 98c9824 (verified) · 1 Parent(s): ee196d5

Update app.py

Files changed (1)
  1. app.py +8 -0
app.py CHANGED
@@ -1,6 +1,7 @@
 import gradio as gr
 from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
 import torch
+import numpy as np
 
 # Load the pre-trained model and processor
 model_name = "facebook/s2t-wav2vec2-large-en-ar"
@@ -9,13 +10,20 @@ processor = Wav2Vec2Processor.from_pretrained(model_name)
 
 # Define a function for the ASR model
 def transcribe(audio):
+    # Convert the audio into a format compatible with the processor
+    if isinstance(audio, np.ndarray):
+        audio = audio.flatten()  # Ensure it's a 1D array
+
     # Process the audio
     inputs = processor(audio, return_tensors="pt", sampling_rate=16000)
+
     # Get the model's predictions
     logits = model(input_values=inputs.input_values).logits
+
     # Decode the predicted text
     predicted_ids = logits.argmax(dim=-1)
     transcription = processor.decode(predicted_ids[0])
+
     return transcription
 
 # Define the Gradio interface
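
The diff ends right at the "# Define the Gradio interface" comment, so the interface wiring itself is not part of this commit. Below is a minimal sketch of how the patched transcribe function could be hooked up; the gradio_transcribe wrapper name, the gr.Audio(type="numpy") input, and the text output are assumptions for illustration, not the repository's actual code.

```python
import gradio as gr
import numpy as np

# Illustrative wrapper (not part of the commit): with type="numpy", Gradio passes
# the callback a (sample_rate, data) tuple rather than a bare array, so it is
# unpacked here before reusing transcribe() from app.py above.
def gradio_transcribe(audio):
    sample_rate, data = audio
    # The processor expects floating-point samples; recorded audio is often int16
    data = np.asarray(data, dtype=np.float32)
    return transcribe(data)

demo = gr.Interface(
    fn=gradio_transcribe,
    inputs=gr.Audio(type="numpy"),
    outputs="text",
)

if __name__ == "__main__":
    demo.launch()
```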