GiftMark committed
Commit 2f6398c · verified · 1 Parent(s): 8ce026d

Upload 2 files

Files changed (2):
  1. app.py +12 -8
  2. requirements.txt +2 -0
app.py CHANGED
@@ -2,7 +2,9 @@ import gradio as gr
 from transformers import WhisperProcessor, WhisperForConditionalGeneration
 import torch
 import numpy as np
+from scipy.signal import resample
 
+# app.py
 # Load your model (use_auth_token for private models)
 import os
 hf_token = os.getenv("HF_TOKEN")
@@ -14,16 +16,19 @@ def transcribe(audio):
     try:
         if audio is None:
             return "No audio provided."
-        # Gradio gives audio as (sampling_rate, data)
         sampling_rate, data = audio
-        # Whisper expects float32 numpy arrays
-        if not isinstance(data, np.ndarray):
-            data = np.array(data)
-        data = data.astype(np.float32)
-
-        # Some Gradio versions return mono, others stereo; Whisper needs 1D
+        data = np.array(data).astype(np.float32)
+        # Ensure mono
         if len(data.shape) > 1:
             data = data[:, 0]
+        # Resample if needed
+        target_sr = 16000
+        if sampling_rate != target_sr:
+            # Calculate number of samples after resampling
+            duration = data.shape[0] / sampling_rate
+            new_length = int(duration * target_sr)
+            data = resample(data, new_length)
+            sampling_rate = target_sr
 
         inputs = processor(
             data, sampling_rate=sampling_rate, return_tensors="pt"
@@ -34,7 +39,6 @@ def transcribe(audio):
         transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
         return transcription
     except Exception as e:
-        # Print error to logs and return to user
         print("Error during transcription:", e)
         return f"Error: {e}"
 
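Read together, the three hunks leave transcribe looking roughly as below. This is a sketch assembled from the diff, not the full file: the lines that load the model fall outside the shown hunks, so the checkpoint name (MODEL_ID), the from_pretrained calls, and the gr.Interface wiring here are assumptions.

```python
import os

import gradio as gr
import numpy as np
import torch
from scipy.signal import resample
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# Load your model (use_auth_token for private models; newer transformers uses token=)
hf_token = os.getenv("HF_TOKEN")
MODEL_ID = "openai/whisper-small"  # assumption: the actual checkpoint is not shown in the diff
processor = WhisperProcessor.from_pretrained(MODEL_ID, use_auth_token=hf_token)
model = WhisperForConditionalGeneration.from_pretrained(MODEL_ID, use_auth_token=hf_token)


def transcribe(audio):
    try:
        if audio is None:
            return "No audio provided."
        sampling_rate, data = audio  # Gradio numpy audio input: (rate, samples)
        data = np.array(data).astype(np.float32)
        # Ensure mono: stereo input arrives as (samples, channels)
        if len(data.shape) > 1:
            data = data[:, 0]
        # Whisper's feature extractor expects 16 kHz audio
        target_sr = 16000
        if sampling_rate != target_sr:
            duration = data.shape[0] / sampling_rate
            new_length = int(duration * target_sr)
            data = resample(data, new_length)  # FFT-based resampling
            sampling_rate = target_sr

        inputs = processor(data, sampling_rate=sampling_rate, return_tensors="pt")
        with torch.no_grad():
            predicted_ids = model.generate(inputs.input_features)
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        return transcription
    except Exception as e:
        print("Error during transcription:", e)
        return f"Error: {e}"


# Assumed UI wiring; not visible in this commit. type="numpy" yields the
# (sampling_rate, data) tuple that transcribe() unpacks.
demo = gr.Interface(fn=transcribe, inputs=gr.Audio(type="numpy"), outputs="text")
demo.launch()
```

scipy.signal.resample does FFT-based resampling, which is adequate for short speech clips; torchaudio.transforms.Resample or librosa would be the usual alternatives if polyphase filtering is preferred.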
requirements.txt CHANGED
@@ -1,2 +1,4 @@
 transformers
 torch
+scipy
+
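The new scipy dependency exists solely for scipy.signal.resample; the trailing + is an added empty line. (gradio itself is not listed, which works on a Gradio Space where the SDK provides it.) A quick sanity check of the length arithmetic used in app.py, with hypothetical input values:

```python
import numpy as np
from scipy.signal import resample

# One second of a 440 Hz tone at 44.1 kHz, a typical browser capture rate
sr_in, target_sr = 44100, 16000
t = np.linspace(0, 1, sr_in, endpoint=False)
tone = np.sin(2 * np.pi * 440 * t).astype(np.float32)

# Same arithmetic as in transcribe(): duration * target rate
new_length = int(tone.shape[0] / sr_in * target_sr)
resampled = resample(tone, new_length)
print(resampled.shape)  # (16000,) -- still one second of audio, now at 16 kHz
```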