233-Yorozuya committed on
Commit
01bac95
·
verified ·
1 Parent(s): 5b90a2c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -17
app.py CHANGED
@@ -1,13 +1,13 @@
1
  import streamlit as st
2
  from transformers import WhisperProcessor, WhisperForConditionalGeneration
3
  import torchaudio
 
4
 
5
  # Load the model
6
  @st.cache_resource
7
  def load_model():
8
- processor = WhisperProcessor.from_pretrained("233-Yorozuya/dl_whisper_model", use_auth_token=True)
9
- model = WhisperForConditionalGeneration.from_pretrained("233-Yorozuya/dl_whisper_model", use_auth_token=True)
10
-
11
  return processor, model
12
 
13
  processor, model = load_model()
@@ -19,17 +19,26 @@ st.write("Upload an audio file for transcription:")
19
  audio_file = st.file_uploader("Choose an audio file", type=["wav", "mp3", "ogg"])
20
 
21
  if audio_file:
22
- # Load and preprocess audio
23
- audio, rate = torchaudio.load(audio_file)
24
- audio = torchaudio.transforms.Resample(orig_freq=rate, new_freq=16000)(audio)
25
- inputs = processor(audio[0].numpy(), sampling_rate=16000, return_tensors="pt")
26
-
27
- # Perform inference
28
- with st.spinner("Transcribing..."):
29
- predicted_ids = model.generate(inputs.input_features)
30
- transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
31
-
32
- # Display result
33
- st.subheader("Transcription")
34
- st.write(transcription)
35
-
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  from transformers import WhisperProcessor, WhisperForConditionalGeneration
3
  import torchaudio
4
+ from io import BytesIO
5
 
6
# Load the model once and cache it across Streamlit reruns.
@st.cache_resource
def load_model(model_name: str = "233-Yorozuya/dl_twi_asr"):
    """Load the Whisper processor and model from the Hugging Face Hub.

    Cached with ``st.cache_resource`` so the checkpoint is downloaded and
    instantiated only once per server process, not on every rerun.

    Args:
        model_name: Hub repo id of the fine-tuned Whisper checkpoint.
            Defaults to the original hard-coded repo, so existing callers
            are unaffected.

    Returns:
        A ``(processor, model)`` tuple ready for inference.
    """
    processor = WhisperProcessor.from_pretrained(model_name)
    model = WhisperForConditionalGeneration.from_pretrained(model_name)
    return processor, model
12
 
13
processor, model = load_model()

audio_file = st.file_uploader("Choose an audio file", type=["wav", "mp3", "ogg"])

if audio_file:
    try:
        # Read the upload into an in-memory, seekable buffer so
        # torchaudio.load can decode it like a regular file.
        audio_bytes = BytesIO(audio_file.read())
        audio, rate = torchaudio.load(audio_bytes)

        # Whisper feature extraction expects 16 kHz audio.
        audio = torchaudio.transforms.Resample(orig_freq=rate, new_freq=16000)(audio)

        # Preprocess channel 0 only — assumes mono (or that the first
        # channel is representative); TODO confirm uploads are mono.
        inputs = processor(audio[0].numpy(), sampling_rate=16000, return_tensors="pt")

        # Clear any forced decoder prompt baked into the config so the
        # fine-tuned checkpoint chooses the language/task tokens itself.
        # (The original comment claimed this *sets* the language; it does
        # the opposite — it removes the forced language.)
        model.config.forced_decoder_ids = None

        # Perform inference.
        with st.spinner("Transcribing..."):
            predicted_ids = model.generate(inputs.input_features)
            transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

        # Display the result.
        st.subheader("Transcription")
        st.write(transcription)
    except Exception as e:
        # Top-level UI boundary: surface any decode/inference failure to
        # the user instead of crashing the Streamlit script.
        st.error(f"An error occurred: {e}")