mikr committed
Commit 7df6e8c · 1 Parent(s): 86b5f8f
Files changed (2):
  1. app.py +13 -1
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,4 +1,5 @@
 import gradio as gr
+import soundfile as sf
 import torch
 from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, pipeline
 
@@ -23,9 +24,20 @@ def transcribe(file_upload):
     text = pipe(file)["text"]
     return warn_output + text
 
+
+def readwav(a_f):
+    wav, sr = sf.read(a_f, dtype=np.float32)
+    if len(wav.shape) == 2:
+        wav = wav.mean(1)
+    if sr != 16000:
+        wlen = int(wav.shape[0] / sr * 16000)
+        wav = signal.resample(wav, wlen)
+    return wav
+
 def transcribe2(file_upload):
+    wav = readwav(file_upload)
     with torch.inference_mode():
-        input_values = processor(file_upload, sampling_rate=16000).input_values[0]
+        input_values = processor(wav, sampling_rate=16000).input_values[0]
         input_values = torch.tensor(input_values, device=device).unsqueeze(0)
         logits = model(input_values).logits
         pred_ids = torch.argmax(logits, dim=-1)
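
Note: as committed, readwav references np and signal without importing them in either hunk, so app.py presumably imports numpy and scipy.signal elsewhere. Below is a minimal self-contained sketch of the changed pipeline with those imports made explicit; the checkpoint name is a hypothetical placeholder (the real one is set elsewhere in app.py), and the final batch_decode step is an assumption, since the hunk ends at pred_ids.

import numpy as np
import soundfile as sf
import torch
from scipy import signal
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "facebook/wav2vec2-base-960h"  # hypothetical placeholder checkpoint
processor = Wav2Vec2Processor.from_pretrained(model_id)
model = Wav2Vec2ForCTC.from_pretrained(model_id).to(device)

def readwav(a_f):
    # Decode the uploaded file to float32 PCM; sf.read returns (samples, samplerate).
    wav, sr = sf.read(a_f, dtype=np.float32)
    # Downmix stereo to mono by averaging the channels.
    if len(wav.shape) == 2:
        wav = wav.mean(1)
    # Resample to the 16 kHz rate the Wav2Vec2 processor expects.
    if sr != 16000:
        wlen = int(wav.shape[0] / sr * 16000)
        wav = signal.resample(wav, wlen)
    return wav

def transcribe2(file_upload):
    wav = readwav(file_upload)
    with torch.inference_mode():
        input_values = processor(wav, sampling_rate=16000).input_values[0]
        input_values = torch.tensor(input_values, device=device).unsqueeze(0)
        logits = model(input_values).logits
        pred_ids = torch.argmax(logits, dim=-1)
    # The diff stops at pred_ids; a greedy CTC decode would typically follow.
    return processor.batch_decode(pred_ids)[0]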
requirements.txt CHANGED
@@ -1,2 +1,3 @@
 git+https://github.com/huggingface/transformers
 torch
+soundfile
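
Note: the commit adds soundfile for sf.read, but readwav also uses np.float32 and signal.resample. numpy arrives transitively with transformers, while scipy is not a declared dependency of either existing pin, so a scipy line would presumably also be needed for the resampling path to work in a fresh environment.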