jiuuee committed
Commit 6160888 · verified · 1 Parent(s): c5a564e

Update app.py

Files changed (1)
  1. app.py +115 -145
app.py CHANGED
@@ -5,161 +5,131 @@ import os
  import soundfile as sf
  import tempfile
  import uuid
-
  import torch
-
  from nemo.collections.asr.models import ASRModel
  from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchMultiTaskAED
  from nemo.collections.asr.parts.utils.transcribe_utils import get_buffered_pred_feat_multitaskAED

- SAMPLE_RATE = 16000  # Hz
- MAX_AUDIO_MINUTES = 10  # wont try to transcribe if longer than this
-
- model = ASRModel.from_pretrained("nvidia/canary-1b")
- model.eval()

- # make sure beam size always 1 for consistency
- model.change_decoding_strategy(None)
- decoding_cfg = model.cfg.decoding
  decoding_cfg.beam.beam_size = 1
- model.change_decoding_strategy(decoding_cfg)
-
- # setup for buffered inference
- model.cfg.preprocessor.dither = 0.0
- model.cfg.preprocessor.pad_to = 0
-
- feature_stride = model.cfg.preprocessor['window_stride']
- model_stride_in_secs = feature_stride * 8  # 8 = model stride, which is 8 for FastConformer
-
  frame_asr = FrameBatchMultiTaskAED(
-     asr_model=model,
-     frame_len=40.0,
-     total_buffer=40.0,
-     batch_size=16,
  )

- amp_dtype = torch.float16
-
- def convert_audio(audio_filepath, tmpdir, utt_id):
-
-     data, sr = librosa.load(audio_filepath, sr=None, mono=True)
-
-     duration = librosa.get_duration(y=data, sr=sr)
-
-     if duration / 60.0 > MAX_AUDIO_MINUTES:
-         raise gr.Error(
-             f"This demo can transcribe up to {MAX_AUDIO_MINUTES} minutes of audio. "
-             "If you wish, you may trim the audio using the Audio viewer in Step 1 "
-             "(click on the scissors icon to start trimming audio)."
-         )
-
-     if sr != SAMPLE_RATE:
-         data = librosa.resample(data, orig_sr=sr, target_sr=SAMPLE_RATE)
-
-     out_filename = os.path.join(tmpdir, utt_id + '.wav')
-
-     # save output audio
-     sf.write(out_filename, data, SAMPLE_RATE)
-
-     return out_filename, duration


  def transcribe(audio_filepath):
-
-     if audio_filepath is None:
-         raise gr.Error("Please provide some input audio: either upload an audio file or use the microphone")
-
-     utt_id = uuid.uuid4()
-     with tempfile.TemporaryDirectory() as tmpdir:
-         converted_audio_filepath, duration = convert_audio(audio_filepath, tmpdir, str(utt_id))
-
-         # make manifest file and save
-         manifest_data = {
-             "audio_filepath": converted_audio_filepath,
-             "source_lang": "en",
-             "target_lang": "en",
-             "taskname": "asr",
-             "pnc": "no",
-             "answer": "predict",
-             "duration": str(duration),
-         }
-
-         manifest_filepath = os.path.join(tmpdir, f'{utt_id}.json')
-
-         with open(manifest_filepath, 'w') as fout:
-             line = json.dumps(manifest_data)
-             fout.write(line + '\n')
-
-         # call transcribe, passing in manifest filepath
-         if duration < 40:
-             output_text = model.transcribe(manifest_filepath)[0]
-         else:  # do buffered inference
-             with torch.cuda.amp.autocast(dtype=amp_dtype):  # TODO: make it work if no cuda
-                 with torch.no_grad():
-                     hyps = get_buffered_pred_feat_multitaskAED(
-                         frame_asr,
-                         model.cfg.preprocessor,
-                         model_stride_in_secs,
-                         model.device,
-                         manifest=manifest_filepath,
-                         filepaths=None,
-                     )
-
-             output_text = hyps[0].text
-
-     return output_text
-
- with gr.Blocks(
-     title="NeMo Canary Model",
-     css="""
-     textarea { font-size: 18px;}
-     #model_output_text_box span {
-         font-size: 18px;
-         font-weight: bold;
-     }
-     """,
-     theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)  # make text slightly bigger (default is text_md)
- ) as demo:
-
-     gr.HTML("<h1 style='text-align: center'>NeMo Canary model: Transcribe & Translate audio</h1>")
-
-     with gr.Row():
-         with gr.Column():
-             gr.HTML(
-                 "<p><b>Step 1:</b> Upload an audio file or record with your microphone.</p>"
-
-                 "<p style='color: #A0A0A0;'>This demo supports audio files up to 10 mins long. "
-                 "You can transcribe longer files locally with this NeMo "
-                 "<a href='https://github.com/NVIDIA/NeMo/blob/main/examples/asr/speech_multitask/speech_to_text_aed_chunked_infer.py'>script</a>.</p>"
-             )
-
-             audio_file = gr.Audio(sources=["microphone", "upload"], type="filepath")
-
-         with gr.Column():
-
-             gr.HTML("<p><b>Step 2:</b> Run the model.</p>")
-
-             go_button = gr.Button(
-                 value="Run model",
-                 variant="primary",  # make "primary" so it stands out (default is "secondary")
-             )
-
-             model_output_text_box = gr.Textbox(
-                 label="Model Output",
-                 elem_id="model_output_text_box",
-             )
-
-     go_button.click(
-         fn=transcribe,
-         inputs=[audio_file],
-         outputs=[model_output_text_box],
-     )
-
- demo.queue()
- demo.launch()

  import soundfile as sf
  import tempfile
  import uuid
  import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
  from nemo.collections.asr.models import ASRModel
  from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchMultiTaskAED
  from nemo.collections.asr.parts.utils.transcribe_utils import get_buffered_pred_feat_multitaskAED
+ from transformers import VitsTokenizer, VitsModel, set_seed
+ import scipy.io.wavfile as wav

+ # Constants
+ SAMPLE_RATE = 16000  # Hz

+ # Load ASR model
+ asr_model = ASRModel.from_pretrained("nvidia/canary-1b")
+ asr_model.eval()
+ asr_model.change_decoding_strategy(None)
+ decoding_cfg = asr_model.cfg.decoding
  decoding_cfg.beam.beam_size = 1
+ asr_model.change_decoding_strategy(decoding_cfg)
+ feature_stride = asr_model.cfg.preprocessor['window_stride']
+ model_stride_in_secs = feature_stride * 8
  frame_asr = FrameBatchMultiTaskAED(
+     asr_model=asr_model,
+     frame_len=40.0,
+     total_buffer=40.0,
+     batch_size=16,
  )

+ # Load LLM model
+ torch.random.manual_seed(0)
+ llm_model = AutoModelForCausalLM.from_pretrained(
+     "microsoft/Phi-3-mini-128k-instruct",
+     device_map="auto",
+     torch_dtype="auto",
+     trust_remote_code=True,
+ )
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
+ pipe = pipeline("text-generation", model=llm_model, tokenizer=tokenizer)

+ # Load TTS model
+ tts_tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng")
+ tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")

+ # Function to convert audio to text using ASR
  def transcribe(audio_filepath):
+     if audio_filepath is None:
+         raise gr.Error("Please provide some input audio.")
+
+     utt_id = uuid.uuid4()
+     with tempfile.TemporaryDirectory() as tmpdir:
+         # Convert to 16 kHz
+         data, sr = librosa.load(audio_filepath, sr=None, mono=True)
+         if sr != SAMPLE_RATE:
+             data = librosa.resample(data, orig_sr=sr, target_sr=SAMPLE_RATE)
+         converted_audio_filepath = os.path.join(tmpdir, f"{utt_id}.wav")
+         sf.write(converted_audio_filepath, data, SAMPLE_RATE)
+
+         # Transcribe audio
+         duration = len(data) / SAMPLE_RATE
+         manifest_data = {
+             "audio_filepath": converted_audio_filepath,
+             "source_lang": "en",
+             "target_lang": "en",
+             "taskname": "asr",
+             "pnc": "no",
+             "answer": "predict",
+             "duration": str(duration),
+         }
+         manifest_filepath = os.path.join(tmpdir, f"{utt_id}.json")
+         with open(manifest_filepath, 'w') as fout:
+             fout.write(json.dumps(manifest_data))
+
+         if duration < 40:
+             transcription = asr_model.transcribe(manifest_filepath)[0]
+         else:
+             transcription = get_buffered_pred_feat_multitaskAED(
+                 frame_asr,
+                 asr_model.cfg.preprocessor,
+                 model_stride_in_secs,
+                 asr_model.device,
+                 manifest=manifest_filepath,
+             )[0].text
+
+     return transcription
+
+ # Function to generate text using LLM
+ def generate_text(input_text):
+     generation_args = {
+         "max_new_tokens": 500,
+         "return_full_text": False,  # return only the newly generated reply string, not the full chat
+         "temperature": 0.0,
+         "do_sample": False,
+     }
+     generated_text = pipe(
+         [{"role": "user", "content": input_text}],
+         **generation_args
+     )[0]["generated_text"]
+     return generated_text
+
+ # Function to convert text to speech using TTS
+ def gen_speech(text):
+     set_seed(555)  # Make it deterministic
+     input_text = tts_tokenizer(text, return_tensors="pt")
+     with torch.no_grad():
+         outputs = tts_model(**input_text)
+     waveform_np = outputs.waveform[0].cpu().numpy()
+     output_file = f"{str(uuid.uuid4())}.wav"
+     wav.write(output_file, rate=tts_model.config.sampling_rate, data=waveform_np)
+     return output_file
+
+ # Combined function for Gradio interface
+ def process_audio(audio_filepath):
+     transcription = transcribe(audio_filepath)
+     generated_text = generate_text(transcription)
+     audio_output_filepath = gen_speech(generated_text)
+     return transcription, generated_text, audio_output_filepath
+
+ # Create Gradio interface
+ gr.Interface(
+     fn=process_audio,
+     inputs=[gr.Audio(sources=["microphone"], type="filepath", label="Input Audio")],
+     outputs=[
+         gr.Textbox(label="Transcription"),
+         gr.Textbox(label="Generated Text"),
+         gr.Audio(type="filepath", label="Generated Speech")
+     ],
+     title="ASR to LLM to TTS",
+     description="Transcribe audio with ASR, generate text with LLM, and convert it back to speech with TTS."
+ ).launch(inbrowser=True)
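
For a quick local sanity check of the new ASR → LLM → TTS chain without the Gradio UI, the three stages added in this commit can be called directly once they are defined in the current session (for example, by running app.py with the final `.launch()` line commented out). This is a minimal sketch, not part of the commit; "sample.wav" is a hypothetical short speech recording, not a file in this repo:

# Minimal smoke test for the chain defined above, assuming transcribe(),
# generate_text() and gen_speech() from app.py are already defined in this
# session and "sample.wav" (hypothetical) is a short mono speech recording.
transcription = transcribe("sample.wav")   # Canary ASR: audio -> text
reply = generate_text(transcription)       # Phi-3: transcription -> generated reply
speech_path = gen_speech(reply)            # MMS-TTS: reply -> WAV file on disk

print("Transcription:", transcription)
print("LLM reply:", reply)
print("Synthesized speech written to:", speech_path)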