matsuap committed
Commit f28e619 · verified · 1 Parent(s): 3be8f00

Upload 2 files

Files changed (2)
  1. Dockerfile +33 -33
  2. app.py +427 -397
Dockerfile CHANGED
@@ -1,33 +1,33 @@
- # Specify the base image
- FROM python:3.11-slim
-
- RUN apt-get update -y && apt-get upgrade -y && apt-get install -y ffmpeg
- RUN apt-get update -y && apt-get upgrade -y && apt-get install --reinstall -y wget
- RUN wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.0-1_all.deb && dpkg -i cuda-keyring_1.0-1_all.deb && apt-get update -y && apt-get upgrade -y && apt-get install -y libcudnn8 libcudnn8-dev
-
- # Set up a new user named "user" with user ID 1000
- RUN useradd -m -u 1000 user
-
- # Switch to the "user" user
- USER user
-
- # Set home to the user's home directory
- ENV HOME=/home/user \
-     PATH=/home/user/.local/bin:$PATH
-
- # Set the working directory
- WORKDIR $HOME/app
-
- # Install the required packages
- COPY requirements.txt .
- RUN pip install --no-cache-dir -r requirements.txt
-
- # Copy the application source code
- COPY --chown=user . $HOME/app
-
- # Create the .cache directory
- RUN mkdir -p .cache
-
- # Specify the command to start the server
- CMD ["python", "app.py", "--host", "0.0.0.0", "--port", "7860", "--lan", "en", "--model", "AtPeak/whisper-medium-finance-faster-float16"]
- # CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
 
+ # Specify the base image
+ FROM python:3.11-slim
+
+ RUN apt-get update -y && apt-get upgrade -y && apt-get install -y ffmpeg
+ RUN apt-get update -y && apt-get upgrade -y && apt-get install --reinstall -y wget
+ RUN wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.0-1_all.deb && dpkg -i cuda-keyring_1.0-1_all.deb && apt-get update -y && apt-get upgrade -y && apt-get install -y libcudnn8 libcudnn8-dev
+
+ # Set up a new user named "user" with user ID 1000
+ RUN useradd -m -u 1000 user
+
+ # Switch to the "user" user
+ USER user
+
+ # Set home to the user's home directory
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:$PATH
+
+ # Set the working directory
+ WORKDIR $HOME/app
+
+ # Install the required packages
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy the application source code
+ COPY --chown=user . $HOME/app
+
+ # Create the .cache directory
+ RUN mkdir -p .cache
+
+ # Specify the command to start the server
+ CMD ["python", "app.py", "--host", "0.0.0.0", "--port", "7860", "--lan", "en", "--model", "tiny", "--min-chunk-size", "1.5", "--generate-audio", "True"]
+ # CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
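A caveat with the new CMD: it passes "--generate-audio True", but app.py registers that flag with type=bool, and argparse simply calls bool() on the raw string, so any non-empty value (including "False") parses as True. The command above still behaves as intended, but a stricter converter avoids the trap. A minimal sketch, not part of this commit (str2bool is a hypothetical helper):

```python
import argparse

def str2bool(value):
    # argparse's type=bool calls bool("...") on the raw string, so even
    # "--generate-audio False" would come out True. Interpret the common
    # spellings explicitly instead.
    if value.lower() in ("true", "1", "yes"):
        return True
    if value.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")

parser = argparse.ArgumentParser()
parser.add_argument("--generate-audio", type=str2bool, default=False)

print(parser.parse_args(["--generate-audio", "True"]).generate_audio)   # True
print(parser.parse_args(["--generate-audio", "False"]).generate_audio)  # False
```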
app.py CHANGED
@@ -1,397 +1,427 @@
- import os
- import io
- import requests
- import argparse
- import asyncio
- import numpy as np
- import ffmpeg
- from time import time
-
- from fastapi import FastAPI, WebSocket, WebSocketDisconnect
- from fastapi.responses import HTMLResponse
- from fastapi.middleware.cors import CORSMiddleware
-
- from src.whisper_streaming.whisper_online import backend_factory, online_factory, add_shared_args
-
- import logging
- import logging.config
- from transformers import pipeline
- from huggingface_hub import login
-
- HUGGING_FACE_TOKEN = os.environ['HUGGING_FACE_TOKEN']
- login(HUGGING_FACE_TOKEN)
-
- os.environ['HF_HOME'] = './.cache'
-
- MODEL_NAME = 'Helsinki-NLP/opus-tatoeba-en-ja'
- TRANSLATOR = pipeline('translation', model=MODEL_NAME, device='cuda')
- TRANSLATOR('Warming up!')
-
- DEEPL_API_KEY = os.environ['DEEPL_API_KEY']
-
- SOURCE_LANG = 'EN'
- TARGET_LANG = 'JA'
-
- def translator_wrapper(source_text, mode='deepl'):
-     if mode == 'deepl':
-         params = {
-             'auth_key' : DEEPL_API_KEY,
-             'text' : source_text,
-             'source_lang' : SOURCE_LANG, # language to translate from
-             "target_lang": TARGET_LANG # language to translate into
-         }
-
-         # Send the request
-         try:
-             request = requests.post("https://api-free.deepl.com/v2/translate", data=params, timeout=5) # note: the URI differs between the paid and free plans
-             result = request.json()['translations'][0]['text']
-         except requests.exceptions.Timeout:
-             result = "(timed out)"
-         return result
-
-     elif mode == 'marianmt':
-         return TRANSLATOR(source_text)[0]['translation_text']
-
-     elif mode == 'google':
-         import requests
-
-         # https://www.eyoucms.com/news/ziliao/other/29445.html
-         language_type = ""
-         target = 'ja-jp'
-         url = "https://translation.googleapis.com/language/translate/v2"
-         data = {
-             'key':"AIzaSyCX0-Wdxl_rgvcZzklNjnqJ1W9YiKjcHUs", # authentication: API key
-             'source': language_type,
-             'target': target,
-             'q': source_text,
-             'format': "text"
-         }
-         #headers = {'X-HTTP-Method-Override': 'GET'}
-         #response = requests.post(url, data=data, headers=headers)
-         response = requests.post(url, data)
-         # print(response.json())
-         print(response)
-         res = response.json()
-         print(res["data"]["translations"][0]["translatedText"])
-         result = res["data"]["translations"][0]["translatedText"]
-         print(result)
-         return result
-
-
- def setup_logging():
-     logging_config = {
-         'version': 1,
-         'disable_existing_loggers': False,
-         'formatters': {
-             'standard': {
-                 'format': '%(asctime)s %(levelname)s [%(name)s]: %(message)s',
-             },
-         },
-         'handlers': {
-             'console': {
-                 'level': 'INFO',
-                 'class': 'logging.StreamHandler',
-                 'formatter': 'standard',
-             },
-         },
-         'root': {
-             'handlers': ['console'],
-             'level': 'DEBUG',
-         },
-         'loggers': {
-             'uvicorn': {
-                 'handlers': ['console'],
-                 'level': 'INFO',
-                 'propagate': False,
-             },
-             'uvicorn.error': {
-                 'level': 'INFO',
-             },
-             'uvicorn.access': {
-                 'level': 'INFO',
-             },
-             'src.whisper_streaming.online_asr': { # Add your specific module here
-                 'handlers': ['console'],
-                 'level': 'DEBUG',
-                 'propagate': False,
-             },
-             'src.whisper_streaming.whisper_streaming': { # Add your specific module here
-                 'handlers': ['console'],
-                 'level': 'DEBUG',
-                 'propagate': False,
-             },
-         },
-     }
-
-     logging.config.dictConfig(logging_config)
-
- setup_logging()
- logger = logging.getLogger(__name__)
-
-
-
-
-
-
- app = FastAPI()
- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=["*"],
-     allow_credentials=True,
-     allow_methods=["*"],
-     allow_headers=["*"],
- )
-
-
- parser = argparse.ArgumentParser(description="Whisper FastAPI Online Server")
- parser.add_argument(
-     "--host",
-     type=str,
-     default="localhost",
-     help="The host address to bind the server to.",
- )
- parser.add_argument(
-     "--port", type=int, default=8000, help="The port number to bind the server to."
- )
- parser.add_argument(
-     "--warmup-file",
-     type=str,
-     dest="warmup_file",
-     help="The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast. It can be e.g. https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav .",
- )
-
- parser.add_argument(
-     "--diarization",
-     type=bool,
-     default=False,
-     help="Whether to enable speaker diarization.",
- )
-
-
- add_shared_args(parser)
- args = parser.parse_args()
- # args.model = 'medium'
-
- asr, tokenizer = backend_factory(args)
-
- if args.diarization:
-     from src.diarization.diarization_online import DiartDiarization
-
-
- # Load demo HTML for the root endpoint
- with open("src/web/live_transcription.html", "r", encoding="utf-8") as f:
-     html = f.read()
-
-
- @app.get("/")
- async def get():
-     return HTMLResponse(html)
-
-
- SAMPLE_RATE = 16000
- CHANNELS = 1
- SAMPLES_PER_SEC = SAMPLE_RATE * int(args.min_chunk_size)
- BYTES_PER_SAMPLE = 2 # s16le = 2 bytes per sample
- BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE
-
-
- async def start_ffmpeg_decoder():
-     """
-     Start an FFmpeg process in async streaming mode that reads WebM from stdin
-     and outputs raw s16le PCM on stdout. Returns the process object.
-     """
-     process = (
-         ffmpeg.input("pipe:0", format="webm")
-         .output(
-             "pipe:1",
-             format="s16le",
-             acodec="pcm_s16le",
-             ac=CHANNELS,
-             ar=str(SAMPLE_RATE),
-         )
-         .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True)
-     )
-     return process
-
-
-
- @app.websocket("/asr")
- async def websocket_endpoint(websocket: WebSocket):
-     await websocket.accept()
-     print("WebSocket connection opened.")
-
-     ffmpeg_process = await start_ffmpeg_decoder()
-     pcm_buffer = bytearray()
-     print("Loading online.")
-     online = online_factory(args, asr, tokenizer)
-     print("Online loaded.")
-
-     if args.diarization:
-         diarization = DiartDiarization(SAMPLE_RATE)
-
-     # Continuously read decoded PCM from ffmpeg stdout in a background task
-     async def ffmpeg_stdout_reader():
-         nonlocal pcm_buffer
-         loop = asyncio.get_event_loop()
-         full_transcription = ""
-         beg = time()
-
-         chunk_history = [] # Will store dicts: {beg, end, text, speaker}
-
-         buffers = [{'speaker': '0', 'text': '', 'translation': None}]
-         buffer_line = ''
-
-         while True:
-             print('in while')
-             try:
-                 print('try in while')
-                 elapsed_time = int(time() - beg)
-                 beg = time()
-                 print('before await loop.run_in_executor()')
-                 chunk = await loop.run_in_executor(None, ffmpeg_process.stdout.read, 32000 * elapsed_time)
-
-                 print('before if not chunk')
-                 if not chunk: # The first chunk will be almost empty, FFmpeg is still starting up
-                     chunk = await loop.run_in_executor(None, ffmpeg_process.stdout.read, 4096)
-                     if not chunk: # FFmpeg might have closed
-                         print("FFmpeg stdout closed.")
-                         break
-
-                 pcm_buffer.extend(chunk)
-
-                 print('before if len(pcm_buffer)')
-                 if len(pcm_buffer) >= BYTES_PER_SEC:
-                     print('in if len(pcm_buffer)')
-                     # Convert int16 -> float32
-                     pcm_array = (np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0)
-                     pcm_buffer = bytearray() # Initialize the PCM buffer
-                     online.insert_audio_chunk(pcm_array)
-                     beg_trans, end_trans, trans = online.process_iter()
-
-                     if trans:
-                         chunk_history.append({
-                             "beg": beg_trans,
-                             "end": end_trans,
-                             "text": trans,
-                             "speaker": "0"
-                         })
-                         full_transcription += trans
-
-                     # ----------------
-                     # Process buffer
-                     # ----------------
-                     if args.vac:
-                         # We need to access the underlying online object to get the buffer
-                         buffer = online.online.concatenate_tsw(online.online.transcript_buffer.buffer)[2]
-                     else:
-                         buffer = online.concatenate_tsw(online.transcript_buffer.buffer)[2]
-
-                     if buffer in full_transcription: # With VAC, the buffer is not updated until the next chunk is processed
-                         buffer = ""
-
-                     buffer_line += buffer
-
-                     punctuations = (',', '.', '?', '!', 'and', 'or', 'but', 'however')
-                     if any(punctuation in buffer_line for punctuation in punctuations):
-                         last_punctuation_index = max((buffer_line.rfind(p) + len(p) + 1) for p in punctuations if p in buffer_line)
-                         extracted_text = buffer_line[:last_punctuation_index]
-                         buffer_line = buffer_line[last_punctuation_index:]
-                         buffers.append({'speaker': '0', 'text': extracted_text, 'translation': None})
-
-                     # Translation loop
-                     print('buffers for loop')
-                     for i, buffer in enumerate(buffers):
-                         print(i, buffer)
-                         if buffer['translation'] is not None:
-                             continue
-                         if buffer['text'] == '':
-                             continue
-
-                         transcription = buffer['text']
-                         buffers[i]['translation'] = translator_wrapper(transcription, mode='google')
-                         buffers[i]['text'] += ('|' + buffers[i]['translation'])
-
-                     # ----------------
-                     # Process lines
-                     # ----------------
-                     print('Process lines')
-                     lines = [{"speaker": "0", "text": ""}]
-
-                     if args.diarization:
-                         await diarization.diarize(pcm_array)
-                         # diarization.assign_speakers_to_chunks(chunk_history)
-                         chunk_history = diarization.assign_speakers_to_chunks(chunk_history)
-
-                     for ch in chunk_history:
-                         if args.diarization and ch["speaker"] and ch["speaker"][-1] != lines[-1]["speaker"]:
-                             lines.append({"speaker": ch["speaker"], "text": ch['text']})
-
-                         else:
-                             lines.append({"speaker": ch["speaker"], "text": ch['text']})
-
-                     for i, line in enumerate(lines):
-                         if line['text'].strip() == '':
-                             continue
-                         # translation = translator(line['text'])[0]['translation_text']
-                         # translation = translation.replace(' ', '')
-                         # lines[i]['text'] = line['text'] + translation
-                         lines[i]['text'] = line['text']
-
-                     # translation = translator(buffer)[0]['translation_text']
-                     # translation = translation.replace(' ', '')
-                     # buffer += translation
-
-                 print('Before making response')
-                 response = {"lines": buffers, "buffer": ''}
-                 await websocket.send_json(response)
-
-             except Exception as e:
-                 print(f"Exception in ffmpeg_stdout_reader: {e}")
-                 break
-
-         print("Exiting ffmpeg_stdout_reader...")
-
-     stdout_reader_task = asyncio.create_task(ffmpeg_stdout_reader())
-
-     try:
-         while True:
-             # Receive incoming WebM audio chunks from the client
-             message = await websocket.receive_bytes()
-             # Pass them to ffmpeg via stdin
-             ffmpeg_process.stdin.write(message)
-             ffmpeg_process.stdin.flush()
-
-     except WebSocketDisconnect:
-         print("WebSocket connection closed.")
-     except Exception as e:
-         print(f"Error in websocket loop: {e}")
-     finally:
-         # Clean up ffmpeg and the reader task
-         try:
-             ffmpeg_process.stdin.close()
-         except:
-             pass
-         stdout_reader_task.cancel()
-
-         try:
-             ffmpeg_process.stdout.close()
-         except:
-             pass
-
-         ffmpeg_process.wait()
-         del online
-
-     if args.diarization:
-         # Stop Diart
-         diarization.close()
-
-
-
-
- if __name__ == "__main__":
-     import uvicorn
-
-     uvicorn.run(
-         "app:app", host=args.host, port=args.port, reload=True,
-         log_level="info"
-     )
 
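The central change to app.py: the old version above sized its FFmpeg stdout reads by elapsed wall-clock time (32000 * elapsed_time), while the new version below moves the blocking ffmpeg_process.stdout.read() into a dedicated thread that feeds fixed-size chunks into a queue.Queue, drained from the async loop via run_in_executor. A minimal sketch of that thread-plus-queue pattern in isolation, assuming a generic blocking byte source (names here are illustrative, not from the commit):

```python
import asyncio
import io
import queue
import threading

def pump(read, chunk_queue, chunk_size):
    # Blocking reader: runs in its own thread so the event loop never stalls.
    while True:
        chunk = read(chunk_size)
        if not chunk:
            chunk_queue.put(b"")  # sentinel so the consumer can detect EOF
            break
        chunk_queue.put(chunk)

async def consume(chunk_queue):
    loop = asyncio.get_event_loop()
    while True:
        # queue.Queue.get blocks, so hand it off to the default executor.
        chunk = await loop.run_in_executor(None, chunk_queue.get)
        if not chunk:
            break
        print(f"got {len(chunk)} bytes")

async def main():
    chunk_queue = queue.Queue()
    source = io.BytesIO(b"\x00" * 100_000)  # stand-in for ffmpeg_process.stdout
    threading.Thread(target=pump, args=(source.read, chunk_queue, 32000), daemon=True).start()
    await consume(chunk_queue)

asyncio.run(main())
```

One difference from the commit: its reader thread exits on EOF without enqueuing anything, so the consumer's "if not chunk" check never fires and chunk_queue.get can block indefinitely once FFmpeg's stdout closes; the sentinel above lets the consumer exit cleanly.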
+ import os
+ import io
+ import requests
+ import argparse
+ import asyncio
+ import numpy as np
+ import ffmpeg
+ from time import time
+
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
+ from fastapi.responses import HTMLResponse
+ from fastapi.middleware.cors import CORSMiddleware
+
+ from src.whisper_streaming.whisper_online import backend_factory, online_factory, add_shared_args
+
+ import logging
+ import logging.config
+ from transformers import pipeline
+ from huggingface_hub import login
+
+ HUGGING_FACE_TOKEN = os.environ['HUGGING_FACE_TOKEN']
+ login(HUGGING_FACE_TOKEN)
+
+ # os.environ['HF_HOME'] = './.cache'
+
+ MODEL_NAME = 'Helsinki-NLP/opus-tatoeba-en-ja'
+ TRANSLATOR = pipeline('translation', model=MODEL_NAME, device='cuda')
+ TRANSLATOR('Warming up!')
+
+ def translator_wrapper(source_text, translation_target_lang, mode):
+     if mode == 'deepl':
+         params = {
+             'auth_key' : os.environ['DEEPL_API_KEY'],
+             'text' : source_text,
+             'source_lang' : 'EN', # language to translate from
+             "target_lang": 'JA', # language to translate into
+         }
+
+         # Send the request
+         try:
+             request = requests.post("https://api-free.deepl.com/v2/translate", data=params, timeout=5) # note: the URI differs between the paid and free plans
+             result = request.json()['translations'][0]['text']
+         except requests.exceptions.Timeout:
+             result = "(timed out)"
+         return result
+
+     elif mode == 'marianmt':
+         return TRANSLATOR(source_text)[0]['translation_text']
+
+     elif mode == 'google':
+         import requests
+
+         # https://www.eyoucms.com/news/ziliao/other/29445.html
+         language_type = ""
+         url = "https://translation.googleapis.com/language/translate/v2"
+         data = {
+             'key':"AIzaSyCX0-Wdxl_rgvcZzklNjnqJ1W9YiKjcHUs", # authentication: API key
+             'source': language_type,
+             'target': translation_target_lang,
+             'q': source_text,
+             'format': "text"
+         }
+         #headers = {'X-HTTP-Method-Override': 'GET'}
+         #response = requests.post(url, data=data, headers=headers)
+         response = requests.post(url, data)
+         # print(response.json())
+         print(response)
+         res = response.json()
+         print(res["data"]["translations"][0]["translatedText"])
+         result = res["data"]["translations"][0]["translatedText"]
+         print(result)
+         return result
+
+
+ def setup_logging():
+     logging_config = {
+         'version': 1,
+         'disable_existing_loggers': False,
+         'formatters': {
+             'standard': {
+                 'format': '%(asctime)s %(levelname)s [%(name)s]: %(message)s',
+             },
+         },
+         'handlers': {
+             'console': {
+                 'level': 'INFO',
+                 'class': 'logging.StreamHandler',
+                 'formatter': 'standard',
+             },
+         },
+         'root': {
+             'handlers': ['console'],
+             'level': 'DEBUG',
+         },
+         'loggers': {
+             'uvicorn': {
+                 'handlers': ['console'],
+                 'level': 'INFO',
+                 'propagate': False,
+             },
+             'uvicorn.error': {
+                 'level': 'INFO',
+             },
+             'uvicorn.access': {
+                 'level': 'INFO',
+             },
+             'src.whisper_streaming.online_asr': { # Add your specific module here
+                 'handlers': ['console'],
+                 'level': 'DEBUG',
+                 'propagate': False,
+             },
+             'src.whisper_streaming.whisper_streaming': { # Add your specific module here
+                 'handlers': ['console'],
+                 'level': 'DEBUG',
+                 'propagate': False,
+             },
+         },
+     }
+
+     logging.config.dictConfig(logging_config)
+
+ setup_logging()
+ logger = logging.getLogger(__name__)
+
+ app = FastAPI()
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ parser = argparse.ArgumentParser(description="Whisper FastAPI Online Server")
+ parser.add_argument(
+     "--host",
+     type=str,
+     default="localhost",
+     help="The host address to bind the server to.",
+ )
+ parser.add_argument(
+     "--port", type=int, default=8000, help="The port number to bind the server to."
+ )
+ parser.add_argument(
+     "--warmup-file",
+     type=str,
+     dest="warmup_file",
+     help="The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast. It can be e.g. https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav .",
+ )
+ parser.add_argument(
+     "--diarization",
+     type=bool,
+     default=False,
+     help="Whether to enable speaker diarization.",
+ )
+ parser.add_argument(
+     "--generate-audio",
+     type=bool,
+     default=False,
+     help="Whether to generate translation audio.",
+ )
+
+
+ add_shared_args(parser)
+ args = parser.parse_args()
+ # args.model = 'medium'
+
+ if args.lan == 'ja':
+     translation_target_lang = 'en'
+ elif args.lan == 'en':
+     translation_target_lang = 'ja'
+
+ asr, tokenizer = backend_factory(args)
+
+ if args.diarization:
+     from src.diarization.diarization_online import DiartDiarization
+
+
+ # Load demo HTML for the root endpoint
+ with open("src/web/live_transcription.html", "r", encoding="utf-8") as f:
+     html = f.read()
+
+
+ @app.get("/")
+ async def get():
+     return HTMLResponse(html)
+
+
+ SAMPLE_RATE = 16000
+ CHANNELS = 1
+ SAMPLES_PER_SEC = int(SAMPLE_RATE * args.min_chunk_size)
+ BYTES_PER_SAMPLE = 2 # s16le = 2 bytes per sample
+ BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE
+ print('SAMPLE_RATE', SAMPLE_RATE)
+ print('CHANNELS', CHANNELS)
+ print('SAMPLES_PER_SEC', SAMPLES_PER_SEC)
+ print('BYTES_PER_SAMPLE', BYTES_PER_SAMPLE)
+ print('BYTES_PER_SEC', BYTES_PER_SEC)
+
+
+ def generate_audio(japanese_text, speed=1.0):
+     api_url = "https://j6im8slpwcevr7g0.us-east-1.aws.endpoints.huggingface.cloud"
+     headers = {
+         "Accept" : "application/json",
+         "Authorization": f"Bearer {HUGGING_FACE_TOKEN}",
+         "Content-Type": "application/json"
+     }
+
+     payload = {
+         "inputs": japanese_text,
+         "speed": speed,
+     }
+
+     response = requests.post(api_url, headers=headers, json=payload).json()
+     if 'error' in response:
+         print(response)
+         return ''
+     return response
+
+
+ async def start_ffmpeg_decoder():
+     """
+     Start an FFmpeg process in async streaming mode that reads WebM from stdin
+     and outputs raw s16le PCM on stdout. Returns the process object.
+     """
+     process = (
+         ffmpeg
+         .input("pipe:0", format="webm")
+         .output(
+             "pipe:1",
+             format="s16le",
+             acodec="pcm_s16le",
+             ac=CHANNELS,
+             ar=str(SAMPLE_RATE),
+             # fflags='nobuffer',
+         )
+         .global_args('-loglevel', 'quiet')
+         .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=False, quiet=True)
+     )
+     return process
+
+ import queue
+ import threading
+
+ @app.websocket("/asr")
+ async def websocket_endpoint(websocket: WebSocket):
+     await websocket.accept()
+     print("WebSocket connection opened.")
+
+     ffmpeg_process = await start_ffmpeg_decoder()
+     pcm_buffer = bytearray()
+     print("Loading online.")
+     online = online_factory(args, asr, tokenizer)
+     print("Online loaded.")
+
+     if args.diarization:
+         diarization = DiartDiarization(SAMPLE_RATE)
+
+     # Continuously read decoded PCM from ffmpeg stdout in a background task
+     async def ffmpeg_stdout_reader():
+         nonlocal pcm_buffer
+         loop = asyncio.get_event_loop()
+         full_transcription = ""
+         beg = time()
+
+         chunk_history = [] # Will store dicts: {beg, end, text, speaker}
+
+         buffers = [{'speaker': '0', 'text': '', 'translation': None, 'audio_url': None}]
+         buffer_line = ''
+
+         # Create a queue to hold the chunks
+         chunk_queue = queue.Queue()
+
+         # Function to read from ffmpeg stdout in a separate thread
+         def read_ffmpeg_stdout():
+             while True:
+                 try:
+                     chunk = ffmpeg_process.stdout.read(BYTES_PER_SEC)
+                     if not chunk:
+                         break
+                     chunk_queue.put(chunk)
+                 except Exception as e:
+                     print(f"Exception in read_ffmpeg_stdout: {e}")
+                     break
+
+         # Start the thread
+         threading.Thread(target=read_ffmpeg_stdout, daemon=True).start()
+
+         while True:
+             try:
+                 # Get the chunk from the queue
+                 chunk = await loop.run_in_executor(None, chunk_queue.get)
+                 if not chunk:
+                     print("FFmpeg stdout closed.")
+                     break
+
+                 pcm_buffer.extend(chunk)
+                 print('len(pcm_buffer): ', len(pcm_buffer))
+                 print('BYTES_PER_SEC: ', BYTES_PER_SEC)
+
+                 if len(pcm_buffer) >= BYTES_PER_SEC:
+                     # Convert int16 -> float32
+                     pcm_array = (np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0)
+                     pcm_buffer = bytearray() # Initialize the PCM buffer
+                     online.insert_audio_chunk(pcm_array)
+                     beg_trans, end_trans, trans = online.process_iter()
+
+                     if trans:
+                         chunk_history.append({
+                             "beg": beg_trans,
+                             "end": end_trans,
+                             "text": trans,
+                             "speaker": "0"
+                         })
+                         full_transcription += trans
+
+                     # ----------------
+                     # Process buffer
+                     # ----------------
+                     if args.vac:
+                         # We need to access the underlying online object to get the buffer
+                         buffer_text = online.online.concatenate_tsw(online.online.transcript_buffer.buffer)[2]
+                     else:
+                         buffer_text = online.concatenate_tsw(online.transcript_buffer.buffer)[2]
+
+                     if buffer_text in full_transcription: # With VAC, the buffer is not updated until the next chunk is processed
+                         buffer_text = ""
+
+                     buffer_line += buffer_text
+
+                     punctuations = (',', '.', '?', '!', 'and', 'or', 'but', 'however')
+                     if not any(punctuation in buffer_line for punctuation in punctuations):
+                         continue
+
+                     last_punctuation_index = max((buffer_line.rfind(p) + len(p) + 1) for p in punctuations if p in buffer_line)
+                     extracted_text = buffer_line[:last_punctuation_index]
+                     buffer_line = buffer_line[last_punctuation_index:]
+                     buffer = {'speaker': '0', 'text': extracted_text, 'translation': None}
+
+                     translation = translator_wrapper(buffer['text'], translation_target_lang, mode='google')
+
+                     buffer['translation'] = translation
+                     buffer['text'] += ('|' + translation)
+                     buffer['audio_url'] = generate_audio(translation, speed=2.5) if args.generate_audio else ''
+                     buffers.append(buffer)
+
+                     # ----------------
+                     # Process lines
+                     # ----------------
+                     '''
+                     print('Process lines')
+                     lines = [{"speaker": "0", "text": ""}]
+
+                     if args.diarization:
+                         await diarization.diarize(pcm_array)
+                         # diarization.assign_speakers_to_chunks(chunk_history)
+                         chunk_history = diarization.assign_speakers_to_chunks(chunk_history)
+
+                     for ch in chunk_history:
+                         if args.diarization and ch["speaker"] and ch["speaker"][-1] != lines[-1]["speaker"]:
+                             lines.append({"speaker": ch["speaker"], "text": ch['text']})
+
+                         else:
+                             lines.append({"speaker": ch["speaker"], "text": ch['text']})
+
+                     for i, line in enumerate(lines):
+                         if line['text'].strip() == '':
+                             continue
+                         # translation = translator(line['text'])[0]['translation_text']
+                         # translation = translation.replace(' ', '')
+                         # lines[i]['text'] = line['text'] + translation
+                         lines[i]['text'] = line['text']
+                     '''
+
+                     print('Before making response')
+                     response = {'line': buffer, 'buffer': ''}
+                     print(response)
+                     await websocket.send_json(response)
+
+             except Exception as e:
+                 print(f"Exception in ffmpeg_stdout_reader: {e}")
+                 break
+
+         print("Exiting ffmpeg_stdout_reader...")
+
+     stdout_reader_task = asyncio.create_task(ffmpeg_stdout_reader())
+
+     try:
+         while True:
+             # Receive incoming WebM audio chunks from the client
+             message = await websocket.receive_bytes()
+             # Pass them to ffmpeg via stdin
+             ffmpeg_process.stdin.write(message)
+             ffmpeg_process.stdin.flush()
+
+     except WebSocketDisconnect:
+         print("WebSocket connection closed.")
+     except Exception as e:
+         print(f"Error in websocket loop: {e}")
+     finally:
+         # Clean up ffmpeg and the reader task
+         try:
+             ffmpeg_process.stdin.close()
+         except:
+             pass
+         stdout_reader_task.cancel()
+
+         try:
+             ffmpeg_process.stdout.close()
+         except:
+             pass
+
+         ffmpeg_process.wait()
+         del online
+
+     if args.diarization:
+         # Stop Diart
+         diarization.close()
+
+
+ if __name__ == "__main__":
+     import uvicorn
+
+     uvicorn.run(
+         "app:app", host=args.host, port=args.port, reload=True,
+         log_level="info"
+     )
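For a quick end-to-end test of the updated endpoint, a client only needs to stream binary WebM frames to /asr and read the JSON messages ({'line': ..., 'buffer': ''}) the server pushes back. A minimal client sketch using the third-party websockets package, assuming a server on port 7860 (the Dockerfile's port) and a pre-recorded audio.webm file (both assumptions for illustration, not part of the commit):

```python
import asyncio
import json

import websockets  # pip install websockets

async def stream_file(path="audio.webm", url="ws://localhost:7860/asr"):
    async with websockets.connect(url) as ws:

        async def sender():
            # Feed the file in small binary frames, roughly pacing it the
            # way a live browser recorder would.
            with open(path, "rb") as f:
                while chunk := f.read(16000):
                    await ws.send(chunk)
                    await asyncio.sleep(0.5)

        send_task = asyncio.create_task(sender())
        try:
            async for message in ws:  # one JSON response per processed segment
                print(json.loads(message).get("line"))
        finally:
            send_task.cancel()

asyncio.run(stream_file())
```

The receive loop keeps printing transcription/translation lines until the server closes the socket or the script is interrupted.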