matsuap committed on
Commit c339523 · verified · 1 Parent(s): 0ebb77f

Update app.py

Files changed (1)
  1. app.py +427 -427
app.py CHANGED
@@ -1,427 +1,427 @@
import os
import io
import requests
import argparse
import asyncio
import numpy as np
import ffmpeg
from time import time

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware

from src.whisper_streaming.whisper_online import backend_factory, online_factory, add_shared_args

import logging
import logging.config
from transformers import pipeline
from huggingface_hub import login

HUGGING_FACE_TOKEN = os.environ['HUGGING_FACE_TOKEN']
login(HUGGING_FACE_TOKEN)

# os.environ['HF_HOME'] = './.cache'

MODEL_NAME = 'Helsinki-NLP/opus-tatoeba-en-ja'
TRANSLATOR = pipeline('translation', model=MODEL_NAME, device='cuda')
TRANSLATOR('Warming up!')

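# Translation dispatcher: 'deepl' calls the DeepL REST API, 'marianmt' runs the
# local MarianMT pipeline loaded above, and 'google' calls the Google Cloud
# Translation v2 REST API.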
def translator_wrapper(source_text, translation_target_lang, mode):
    if mode == 'deepl':
        params = {
            'auth_key': os.environ['DEEPL_API_KEY'],
            'text': source_text,
            'source_lang': 'EN',  # source language
            "target_lang": 'JA',  # target language
        }

        # Send the request
        try:
            response = requests.post("https://api-free.deepl.com/v2/translate", data=params, timeout=5)  # NOTE: the URL differs between the paid and free tiers
            result = response.json()['translations'][0]['text']
        except requests.exceptions.Timeout:
            result = "(timed out)"
        return result

    elif mode == 'marianmt':
        return TRANSLATOR(source_text)[0]['translation_text']

    elif mode == 'google':
        # https://www.eyoucms.com/news/ziliao/other/29445.html
        language_type = ""
        url = "https://translation.googleapis.com/language/translate/v2"
        data = {
            'key': "AIzaSyCX0-Wdxl_rgvcZzklNjnqJ1W9YiKjcHUs",  # authentication: API key
            'source': language_type,
            'target': translation_target_lang,
            'q': source_text,
            'format': "text"
        }
        # headers = {'X-HTTP-Method-Override': 'GET'}
        # response = requests.post(url, data=data, headers=headers)
        response = requests.post(url, data)
        # print(response.json())
        print(response)
        res = response.json()
        print(res["data"]["translations"][0]["translatedText"])
        result = res["data"]["translations"][0]["translatedText"]
        print(result)
        return result


def setup_logging():
    logging_config = {
        'version': 1,
        'disable_existing_loggers': False,
        'formatters': {
            'standard': {
                'format': '%(asctime)s %(levelname)s [%(name)s]: %(message)s',
            },
        },
        'handlers': {
            'console': {
                'level': 'INFO',
                'class': 'logging.StreamHandler',
                'formatter': 'standard',
            },
        },
        'root': {
            'handlers': ['console'],
            'level': 'DEBUG',
        },
        'loggers': {
            'uvicorn': {
                'handlers': ['console'],
                'level': 'INFO',
                'propagate': False,
            },
            'uvicorn.error': {
                'level': 'INFO',
            },
            'uvicorn.access': {
                'level': 'INFO',
            },
            'src.whisper_streaming.online_asr': {  # Add your specific module here
                'handlers': ['console'],
                'level': 'DEBUG',
                'propagate': False,
            },
            'src.whisper_streaming.whisper_streaming': {  # Add your specific module here
                'handlers': ['console'],
                'level': 'DEBUG',
                'propagate': False,
            },
        },
    }

    logging.config.dictConfig(logging_config)

setup_logging()
logger = logging.getLogger(__name__)

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

parser = argparse.ArgumentParser(description="Whisper FastAPI Online Server")
parser.add_argument(
    "--host",
    type=str,
    default="localhost",
    help="The host address to bind the server to.",
)
parser.add_argument(
    "--port", type=int, default=8000, help="The port number to bind the server to."
)
parser.add_argument(
    "--warmup-file",
    type=str,
    dest="warmup_file",
    help="The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast. It can be e.g. https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav .",
)
parser.add_argument(
    "--diarization",
    action="store_true",  # type=bool is a footgun: argparse treats any non-empty string as True
    help="Whether to enable speaker diarization.",
)
parser.add_argument(
    "--generate-audio",
    action="store_true",  # see note on --diarization
    help="Whether to generate translation audio.",
)


add_shared_args(parser)
args = parser.parse_args()
# args.model = 'medium'

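# Translate into the opposite language of the ASR input. Note that --lan values
# other than 'en'/'ja' would leave translation_target_lang undefined.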
if args.lan == 'ja':
    translation_target_lang = 'en'
elif args.lan == 'en':
    translation_target_lang = 'ja'

asr, tokenizer = backend_factory(args)

if args.diarization:
    from src.diarization.diarization_online import DiartDiarization


# Load demo HTML for the root endpoint
with open("src/web/live_transcription.html", "r", encoding="utf-8") as f:
    html = f.read()


@app.get("/")
async def get():
    return HTMLResponse(html)


SAMPLE_RATE = 16000
CHANNELS = 1
SAMPLES_PER_SEC = int(SAMPLE_RATE * args.min_chunk_size)
BYTES_PER_SAMPLE = 2  # s16le = 2 bytes per sample
BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE
print('SAMPLE_RATE', SAMPLE_RATE)
print('CHANNELS', CHANNELS)
print('SAMPLES_PER_SEC', SAMPLES_PER_SEC)
print('BYTES_PER_SAMPLE', BYTES_PER_SAMPLE)
print('BYTES_PER_SEC', BYTES_PER_SEC)
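# NOTE: despite their names, SAMPLES_PER_SEC and BYTES_PER_SEC describe one chunk
# of --min-chunk-size seconds: 16000 samples/s * 2 bytes/sample * min_chunk_size.
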
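
# The URL below is a project-specific Hugging Face Inference Endpoint for
# text-to-speech; the JSON response is assumed to carry the audio payload that
# the client plays back ('' is returned on error).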
def generate_audio(japanese_text, speed=1.0):
    api_url = "https://j6im8slpwcevr7g0.us-east-1.aws.endpoints.huggingface.cloud"
    headers = {
        "Accept": "application/json",
        "Authorization": f"Bearer {HUGGING_FACE_TOKEN}",
        "Content-Type": "application/json"
    }

    payload = {
        "inputs": japanese_text,
        "speed": speed,
    }

    response = requests.post(api_url, headers=headers, json=payload).json()
    if 'error' in response:
        print(response)
        return ''
    return response


async def start_ffmpeg_decoder():
    """
    Start an FFmpeg process in async streaming mode that reads WebM from stdin
    and outputs raw s16le PCM on stdout. Returns the process object.
    """
    process = (
        ffmpeg
        .input("pipe:0", format="webm")
        .output(
            "pipe:1",
            format="s16le",
            acodec="pcm_s16le",
            ac=CHANNELS,
            ar=str(SAMPLE_RATE),
            # fflags='nobuffer',
        )
        .global_args('-loglevel', 'quiet')
        .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=False, quiet=True)
    )
    return process

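# The pipeline above is roughly equivalent to running:
#   ffmpeg -loglevel quiet -f webm -i pipe:0 -f s16le -acodec pcm_s16le -ac 1 -ar 16000 pipe:1
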
import queue
import threading

@app.websocket("/asr")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    print("WebSocket connection opened.")

    ffmpeg_process = await start_ffmpeg_decoder()
    pcm_buffer = bytearray()
    print("Loading online.")
    online = online_factory(args, asr, tokenizer)
    print("Online loaded.")

    if args.diarization:
        diarization = DiartDiarization(SAMPLE_RATE)

    # Continuously read decoded PCM from ffmpeg stdout in a background task
    async def ffmpeg_stdout_reader():
        nonlocal pcm_buffer
        loop = asyncio.get_event_loop()
        full_transcription = ""
        beg = time()

        chunk_history = []  # Will store dicts: {beg, end, text, speaker}

        buffers = [{'speaker': '0', 'text': '', 'translation': None, 'audio_url': None}]
        buffer_line = ''

        # Create a queue to hold the chunks
        chunk_queue = queue.Queue()

        # Function to read from ffmpeg stdout in a separate thread
        def read_ffmpeg_stdout():
            while True:
                try:
                    chunk = ffmpeg_process.stdout.read(BYTES_PER_SEC)
                    if not chunk:
                        break
                    chunk_queue.put(chunk)
                except Exception as e:
                    print(f"Exception in read_ffmpeg_stdout: {e}")
                    break
            chunk_queue.put(b"")  # sentinel: unblock the consumer so it can exit on EOF

        # Start the thread
        threading.Thread(target=read_ffmpeg_stdout, daemon=True).start()

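        # Bridge the blocking reads into asyncio: chunk_queue.get runs in a
        # worker thread via run_in_executor so the event loop stays responsive.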
        while True:
            try:
                # Get the chunk from the queue
                chunk = await loop.run_in_executor(None, chunk_queue.get)
                if not chunk:
                    print("FFmpeg stdout closed.")
                    break

                pcm_buffer.extend(chunk)
                print('len(pcm_buffer): ', len(pcm_buffer))
                print('BYTES_PER_SEC: ', BYTES_PER_SEC)

                if len(pcm_buffer) >= BYTES_PER_SEC:
                    # Convert int16 -> float32
                    pcm_array = (np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0)
                    pcm_buffer = bytearray()  # Reset the PCM buffer
                    online.insert_audio_chunk(pcm_array)
                    beg_trans, end_trans, trans = online.process_iter()

                    if trans:
                        chunk_history.append({
                            "beg": beg_trans,
                            "end": end_trans,
                            "text": trans,
                            "speaker": "0"
                        })
                        full_transcription += trans

                    # ----------------
                    # Process buffer
                    # ----------------
                    if args.vac:
                        # We need to access the underlying online object to get the buffer
                        buffer_text = online.online.concatenate_tsw(online.online.transcript_buffer.buffer)[2]
                    else:
                        buffer_text = online.concatenate_tsw(online.transcript_buffer.buffer)[2]

                    if buffer_text in full_transcription:  # With VAC, the buffer is not updated until the next chunk is processed
                        buffer_text = ""

                    buffer_line += buffer_text

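                    # Heuristic segmentation: wait until the accumulated text
                    # contains a punctuation mark or conjunction, then flush
                    # everything up to the last such token for translation.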
                    punctuations = (',', '.', '?', '!', 'and', 'or', 'but', 'however')
                    if not any(punctuation in buffer_line for punctuation in punctuations):
                        continue

                    last_punctuation_index = max((buffer_line.rfind(p) + len(p) + 1) for p in punctuations if p in buffer_line)
                    extracted_text = buffer_line[:last_punctuation_index]
                    buffer_line = buffer_line[last_punctuation_index:]
                    buffer = {'speaker': '0', 'text': extracted_text, 'translation': None}

                    translation = translator_wrapper(buffer['text'], translation_target_lang, mode='google')

                    buffer['translation'] = translation
                    buffer['text'] += ('|' + translation)
-                    buffer['audio_url'] = generate_audio(translation, speed=2.5) if args.generate_audio else ''
+                    buffer['audio_url'] = generate_audio(translation, speed=1.5) if args.generate_audio else ''
                    buffers.append(buffer)

                    # ----------------
                    # Process lines
                    # ----------------
                    '''
                    print('Process lines')
                    lines = [{"speaker": "0", "text": ""}]

                    if args.diarization:
                        await diarization.diarize(pcm_array)
                        # diarization.assign_speakers_to_chunks(chunk_history)
                        chunk_history = diarization.assign_speakers_to_chunks(chunk_history)

                    for ch in chunk_history:
                        if args.diarization and ch["speaker"] and ch["speaker"][-1] != lines[-1]["speaker"]:
                            lines.append({"speaker": ch["speaker"], "text": ch['text']})
                        else:
                            lines.append({"speaker": ch["speaker"], "text": ch['text']})

                    for i, line in enumerate(lines):
                        if line['text'].strip() == '':
                            continue
                        # translation = translator(line['text'])[0]['translation_text']
                        # translation = translation.replace(' ', '')
                        # lines[i]['text'] = line['text'] + translation
                        lines[i]['text'] = line['text']
                    '''

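                    # Send each flushed segment to the client as soon as it is
                    # translated; the 'buffer' field of the response is left empty.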
                    print('Before making response')
                    response = {'line': buffer, 'buffer': ''}
                    print(response)
                    await websocket.send_json(response)

            except Exception as e:
                print(f"Exception in ffmpeg_stdout_reader: {e}")
                break

        print("Exiting ffmpeg_stdout_reader...")

    stdout_reader_task = asyncio.create_task(ffmpeg_stdout_reader())

    try:
        while True:
            # Receive incoming WebM audio chunks from the client
            message = await websocket.receive_bytes()
            # Pass them to ffmpeg via stdin
            ffmpeg_process.stdin.write(message)
            ffmpeg_process.stdin.flush()

    except WebSocketDisconnect:
        print("WebSocket connection closed.")
    except Exception as e:
        print(f"Error in websocket loop: {e}")
    finally:
        # Clean up ffmpeg and the reader task
        try:
            ffmpeg_process.stdin.close()
        except Exception:
            pass
        stdout_reader_task.cancel()

        try:
            ffmpeg_process.stdout.close()
        except Exception:
            pass

        ffmpeg_process.wait()
        del online

        if args.diarization:
            # Stop Diart
            diarization.close()

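# Example launch (additional flags such as --lan come from add_shared_args in
# src.whisper_streaming.whisper_online):
#   python app.py --host 0.0.0.0 --port 8000 --lan en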

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "app:app", host=args.host, port=args.port, reload=True,
        log_level="info"
    )