GiorgiSekhniashvili commited on
Commit
bb024f6
·
1 Parent(s): eacd0a8

using gradio

Browse files
Files changed (2) hide show
  1. app.py +63 -20
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,31 +1,74 @@
 
 
1
  import gradio as gr
2
- from transformers.pipelines.audio_utils import ffmpeg_read
 
 
 
 
 
 
3
 
4
- from transformers import WhisperForConditionalGeneration, AutoProcessor
 
 
5
 
6
- model_name = "GiorgiSekhniashvili/whisper-tiny-ka-01"
7
 
8
- processor = AutoProcessor.from_pretrained(model_name)
9
- model = WhisperForConditionalGeneration.from_pretrained(model_name)
10
- forced_decoder_ids = processor.get_decoder_prompt_ids(
11
- language="Georgian", task="transcribe"
 
 
 
 
12
  )
13
 
14
 
15
- def predict(audio_path):
16
- if audio_path:
17
- with open(audio_path, "rb") as f:
18
- waveform = ffmpeg_read(f.read(), sampling_rate=16_000)
19
- input_values = processor(waveform, sampling_rate=16_000, return_tensors="pt")
20
- res = model.generate(
21
- input_values["input_features"],
22
- forced_decoder_ids=forced_decoder_ids,
23
- max_new_tokens=448,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  )
25
- return processor.batch_decode(res, skip_special_tokens=True)[0]
 
 
 
 
 
 
 
 
26
 
27
 
28
- mic = gr.Audio(source="microphone", type="filepath", label="Speak here...")
29
- demo = gr.Interface(predict, mic, "text")
 
 
 
 
 
30
 
31
- demo.launch()
 
1
+ from pathlib import Path
2
+
3
  import gradio as gr
4
+ import torch
5
+ import torchaudio
6
+ from transformers import (
7
+ WhisperFeatureExtractor,
8
+ WhisperForConditionalGeneration,
9
+ WhisperTokenizerFast,
10
+ )
11
 
12
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+ DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
14
+ SAMPLING_RATE = 16_000
15
 
 
16
 
17
+ model = WhisperForConditionalGeneration.from_pretrained(
18
+ "../data/jobs/whisper-tiny-ka-09", torch_dtype=DTYPE
19
+ )
20
+ feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")
21
+ tokenizer = WhisperTokenizerFast.from_pretrained("openai/whisper-tiny")
22
+
23
+ forced_decoder_ids = tokenizer.get_decoder_prompt_ids(
24
+ language="georgian", task="transcribe"
25
  )
26
 
27
 
28
+ def load_audio(audio_path: Path, target_sr: int):
29
+ waveform, sr = torchaudio.load(audio_path, backend="soundfile")
30
+
31
+ if waveform.shape[0] > 1:
32
+ waveform = waveform.mean(dim=0, keepdim=True)
33
+
34
+ if sr != target_sr:
35
+ waveform = torchaudio.functional.resample(
36
+ waveform, orig_freq=sr, new_freq=target_sr
37
+ )
38
+
39
+ return waveform
40
+
41
+
42
+ model.generation_config.forced_decoder_ids = forced_decoder_ids
43
+ model.to(DEVICE)
44
+
45
+
46
+ def transcribe(audio):
47
+ try:
48
+ waveform = load_audio(audio, target_sr=SAMPLING_RATE)
49
+ except Exception as e:
50
+ return str(e)
51
+
52
+ input_values = feature_extractor(
53
+ waveform[0], sampling_rate=SAMPLING_RATE, return_tensors="pt"
54
  )
55
+ input_features = input_values.input_features.to(DEVICE, dtype=DTYPE)
56
+ with torch.no_grad():
57
+ outputs = model.generate(
58
+ input_features,
59
+ forced_decoder_ids=forced_decoder_ids,
60
+ max_new_tokens=444,
61
+ )
62
+ transcriptions = tokenizer.batch_decode(outputs, skip_special_tokens=False)
63
+ return transcriptions[0]
64
 
65
 
66
+ iface = gr.Interface(
67
+ fn=transcribe,
68
+ inputs=gr.Audio(sources=["microphone", "upload"], type="filepath"),
69
+ outputs="text",
70
+ title="Whisper Geo",
71
+ description="Realtime demo for Georgian speech recognition using a fine-tuned Whisper model.",
72
+ )
73
 
74
+ iface.launch()
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  transformers
2
  torch
3
  torchvision
4
- torchaudio
 
 
1
  transformers
2
  torch
3
  torchvision
4
+ torchaudio
5
+ gradio