|
import os |
|
import io |
|
import requests |
|
import argparse |
|
import asyncio |
|
import numpy as np |
|
import ffmpeg |
|
from time import time |
|
|
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect |
|
from fastapi.responses import HTMLResponse |
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
|
from src.whisper_streaming.whisper_online import backend_factory, online_factory, add_shared_args |
|
|
|
import logging |
|
import logging.config |
|
from transformers import pipeline |
|
from huggingface_hub import login |
|
|
|
# Hugging Face token: used to log in to the Hub and as the bearer token for the TTS endpoint.
try:
    HUGGING_FACE_TOKEN = os.environ['HUGGING_FACE_TOKEN']
except KeyError:
    # Fail with an actionable message instead of a bare KeyError at import time.
    raise RuntimeError(
        "The HUGGING_FACE_TOKEN environment variable must be set before starting the server."
    ) from None

login(HUGGING_FACE_TOKEN)

# Local MarianMT English->Japanese model, used by translator_wrapper(mode='marianmt').
MODEL_NAME = 'Helsinki-NLP/opus-tatoeba-en-ja'

# Device for the local translation model. Defaults to 'cuda' as before; set
# TRANSLATOR_DEVICE=cpu on machines without a GPU.
TRANSLATOR = pipeline(
    'translation',
    model=MODEL_NAME,
    device=os.environ.get('TRANSLATOR_DEVICE', 'cuda'),
)

# One throwaway translation so the first real request does not pay the warm-up cost.
TRANSLATOR('Warming up!')
|
|
|
def translator_wrapper(source_text, translation_target_lang, mode, *, timeout=5):
    """Translate ``source_text`` into ``translation_target_lang`` with the chosen backend.

    Args:
        source_text: Text to translate.
        translation_target_lang: Target language code such as ``'ja'`` or ``'en'``.
            Honoured by the ``'deepl'`` and ``'google'`` backends; ``'marianmt'``
            always uses the fixed language pair of ``MODEL_NAME``.
        mode: ``'deepl'``, ``'marianmt'`` or ``'google'``.
        timeout: Seconds to wait for a remote backend before giving up.

    Returns:
        The translated text, or ``"(timed out)"`` when a remote backend does not
        answer within ``timeout`` seconds.

    Raises:
        KeyError: the API-key environment variable of the chosen remote backend
            (``DEEPL_API_KEY`` / ``GOOGLE_TRANSLATE_API_KEY``) is not set.
        requests.HTTPError: the remote backend answered with an error status.
        ValueError: ``mode`` is not a supported backend.
    """
    if mode == 'marianmt':
        # Local model; no network involved.
        return TRANSLATOR(source_text)[0]['translation_text']

    if mode == 'deepl':
        url = "https://api-free.deepl.com/v2/translate"
        # Header-based authentication, as documented by the DeepL API.
        headers = {'Authorization': f"DeepL-Auth-Key {os.environ['DEEPL_API_KEY']}"}
        # DeepL expects upper-case codes and a regional variant for English targets.
        if translation_target_lang.lower() == 'en':
            target = 'EN-US'
        else:
            target = translation_target_lang.upper()
        # source_lang is omitted so DeepL auto-detects it (the input may be 'en' or 'ja').
        data = {'text': source_text, 'target_lang': target}
    elif mode == 'google':
        url = "https://translation.googleapis.com/language/translate/v2"
        headers = None
        # Credentials come from the environment, never from source code.
        # NOTE(review): an API key used to be committed in this file — revoke/rotate it.
        data = {
            'key': os.environ['GOOGLE_TRANSLATE_API_KEY'],
            'target': translation_target_lang,
            'q': source_text,
            'format': 'text',
        }
        # 'source' is omitted so Google auto-detects the input language.
    else:
        raise ValueError(f"Unsupported translation mode: {mode!r}")

    try:
        # Always bounded: this is called from the live transcription path.
        response = requests.post(url, headers=headers, data=data, timeout=timeout)
    except requests.exceptions.Timeout:
        return "(timed out)"
    response.raise_for_status()
    payload = response.json()

    if mode == 'deepl':
        return payload['translations'][0]['text']
    return payload['data']['translations'][0]['translatedText']
|
|
|
|
|
def setup_logging():
    """Install the process-wide logging configuration with ``logging.config.dictConfig``.

    One ``logging.StreamHandler`` named ``console`` (threshold INFO, ``standard``
    formatter) serves the root logger, ``uvicorn`` and the two whisper-streaming
    loggers.
    - ``uvicorn`` and the whisper-streaming loggers do not propagate to root.
    - ``uvicorn.error`` and ``uvicorn.access`` only get a level and propagate as usual.
    - Loggers that already exist are left enabled.

    NOTE: root and the whisper-streaming loggers are set to DEBUG, but the console
    handler's INFO threshold still filters out their DEBUG records.
    """
    console_only = ['console']

    def own_console(level):
        # Logger with its own console handler that does not bubble up to root.
        return {'handlers': list(console_only), 'level': level, 'propagate': False}

    logger_table = {
        'uvicorn': own_console('INFO'),
        'uvicorn.error': {'level': 'INFO'},
        'uvicorn.access': {'level': 'INFO'},
    }
    for asr_logger in (
        'src.whisper_streaming.online_asr',
        'src.whisper_streaming.whisper_streaming',
    ):
        logger_table[asr_logger] = own_console('DEBUG')

    config = dict(
        version=1,
        disable_existing_loggers=False,
        formatters={
            'standard': {'format': '%(asctime)s %(levelname)s [%(name)s]: %(message)s'},
        },
        handlers={
            'console': {
                'level': 'INFO',
                'class': 'logging.StreamHandler',
                'formatter': 'standard',
            },
        },
        root={'handlers': list(console_only), 'level': 'DEBUG'},
        loggers=logger_table,
    )
    logging.config.dictConfig(config)
|
|
|
setup_logging()
logger = logging.getLogger(__name__)

# CORS policy for the HTTP routes.
# NOTE(review): wildcard origins combined with allow_credentials=True may let any
# web page make credentialed cross-origin requests; tighten if auth is ever added.
_CORS_SETTINGS = {
    'allow_origins': ["*"],
    'allow_credentials': True,
    'allow_methods': ["*"],
    'allow_headers': ["*"],
}

app = FastAPI()
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
|
|
|
def _str2bool(value):
    """Parse a boolean command-line value ('true'/'false', 'yes'/'no', '1'/'0', 'on'/'off').

    ``type=bool`` must not be used for flags: ``bool("False")`` is ``True``, so any
    explicit value would switch the feature on.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    # '' stays False, as bool('') was under the previous type=bool.
    if normalized in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser(description="Whisper FastAPI Online Server")
parser.add_argument(
    "--host",
    type=str,
    default="localhost",
    help="The host address to bind the server to.",
)
parser.add_argument(
    "--port", type=int, default=8000, help="The port number to bind the server to."
)
# NOTE(review): warmup_file is parsed but not read anywhere in this file — TODO wire up or drop.
parser.add_argument(
    "--warmup-file",
    type=str,
    dest="warmup_file",
    help="The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast. It can be e.g. https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav .",
)
# Boolean flags accept `--flag`, `--flag true` and `--flag false` (the old
# `--flag True` form keeps working; `--flag False` now really disables it).
parser.add_argument(
    "--diarization",
    type=_str2bool,
    nargs="?",
    const=True,
    default=False,
    help="Whether to enable speaker diarization.",
)
parser.add_argument(
    "--generate-audio",
    type=_str2bool,
    nargs="?",
    const=True,
    default=False,
    help="Whether to generate translation audio.",
)

# Adds the Whisper streaming options (e.g. --lan, --min-chunk-size, --vac) read below.
add_shared_args(parser)
args = parser.parse_args()
|
|
|
|
|
# The server translates in the opposite direction of the spoken language.
_TRANSLATION_TARGETS = {'ja': 'en', 'en': 'ja'}
if args.lan not in _TRANSLATION_TARGETS:
    # Fail at startup rather than with a NameError on the first translated segment.
    parser.error(
        f"--lan must be one of {sorted(_TRANSLATION_TARGETS)} to pick a translation "
        f"target, got {args.lan!r}"
    )
translation_target_lang = _TRANSLATION_TARGETS[args.lan]

# Shared ASR backend and tokenizer; each websocket connection builds its own
# online processor from them.
asr, tokenizer = backend_factory(args)

if args.diarization:
    # Imported lazily so the diarization module is only loaded when enabled.
    from src.diarization.diarization_online import DiartDiarization
|
|
|
|
|
|
|
# Resolve the page next to this module so the server also starts when launched
# from a different working directory.
_HTML_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "src", "web", "live_transcription.html"
)
with open(_HTML_PATH, "r", encoding="utf-8") as f:
    html = f.read()


@app.get("/")
async def get():
    """Serve the live-transcription web client page that was read at startup."""
    return HTMLResponse(html)
|
|
|
|
|
# Raw PCM format produced by the ffmpeg decoder: 16 kHz, mono, signed 16-bit.
SAMPLE_RATE = 16000
CHANNELS = 1
BYTES_PER_SAMPLE = 2

# NOTE: despite the names, these are per *chunk* of args.min_chunk_size seconds,
# not per second; they size each ffmpeg read and each ASR insert.
SAMPLES_PER_SEC = int(SAMPLE_RATE * args.min_chunk_size)
BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE

for _label, _value in (
    ('SAMPLE_RATE', SAMPLE_RATE),
    ('CHANNELS', CHANNELS),
    ('SAMPLES_PER_SEC', SAMPLES_PER_SEC),
    ('BYTES_PER_SAMPLE', BYTES_PER_SAMPLE),
    ('BYTES_PER_SEC', BYTES_PER_SEC),
):
    print(_label, _value)
|
|
|
|
|
# Hugging Face inference endpoint that synthesizes speech from text.
_TTS_ENDPOINT_URL = "https://j6im8slpwcevr7g0.us-east-1.aws.endpoints.huggingface.cloud"


def generate_audio(japanese_text, speed=1.0, *, timeout=30.0, api_url=_TTS_ENDPOINT_URL):
    """Request synthesized speech for ``japanese_text`` from the TTS endpoint.

    Args:
        japanese_text: Text to voice. Empty or ``None`` returns ``''`` without a request.
        speed: Playback speed forwarded to the endpoint.
        timeout: Seconds to wait for the endpoint; the call is otherwise unbounded.
        api_url: Endpoint URL (defaults to the project's inference endpoint).

    Returns:
        The endpoint's decoded JSON response on success, or ``''`` when the endpoint
        reports an error, the request fails, or the reply is not JSON.
    """
    log = logging.getLogger(__name__)
    if not japanese_text:
        return ''

    headers = {
        "Accept": "application/json",
        "Authorization": f"Bearer {HUGGING_FACE_TOKEN}",
        "Content-Type": "application/json",
    }
    payload = {
        "inputs": japanese_text,
        "speed": speed,
    }

    try:
        response = requests.post(api_url, headers=headers, json=payload, timeout=timeout)
        body = response.json()
    except (requests.exceptions.RequestException, ValueError) as exc:
        # Treat transport/decoding failures like endpoint errors: no audio, keep streaming.
        log.warning("TTS request failed: %s", exc)
        return ''

    if isinstance(body, dict) and 'error' in body:
        log.warning("TTS endpoint returned an error: %s", body)
        return ''
    return body
|
|
|
|
|
async def start_ffmpeg_decoder():
    """Start an FFmpeg subprocess that decodes WebM from stdin to raw PCM on stdout.

    Output is signed 16-bit little-endian PCM with ``CHANNELS`` channels at
    ``SAMPLE_RATE`` Hz. The caller writes WebM bytes to ``process.stdin`` and
    reads PCM from ``process.stdout``.

    Returns:
        The ``subprocess.Popen`` object created by ffmpeg-python's ``run_async``.
    """
    process = (
        ffmpeg
        .input("pipe:0", format="webm")
        .output(
            "pipe:1",
            format="s16le",
            acodec="pcm_s16le",
            ac=CHANNELS,
            ar=str(SAMPLE_RATE),
        )
        # ffmpeg itself is silenced, so stderr can be left un-piped.
        .global_args('-loglevel', 'quiet')
        # No quiet=True: in ffmpeg-python that pipes (or redirects) stderr/stdout,
        # and a pipe nobody reads can stall ffmpeg once its buffer fills.
        .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=False)
    )
    return process
|
|
|
import queue |
|
import threading |
|
|
|
# Clause boundaries at which the pending transcript is cut into a segment for
# translation: sentence punctuation, or one of a few conjunctions. Conjunctions
# are matched as whole words so e.g. "or" does not fire inside "for" or "more".
_SEGMENT_BOUNDARY_PATTERN = r"[,.?!]|\b(?:and|or|but|however)\b"


def _split_at_last_boundary(text):
    """Split ``text`` just after its last clause boundary.

    One extra character after the boundary (normally the following space) stays
    with the head segment.

    Returns:
        ``(head, tail)``, or ``None`` when ``text`` contains no boundary.
    """
    import re

    last_end = None
    for match in re.finditer(_SEGMENT_BOUNDARY_PATTERN, text):
        last_end = match.end()
    if last_end is None:
        return None
    cut = last_end + 1
    return text[:cut], text[cut:]


@app.websocket("/asr")
async def websocket_endpoint(websocket: WebSocket):
    """Transcribe and translate audio streamed over a websocket.

    Each binary message from the client is written into an ffmpeg decoder.
    A background thread reads the decoded PCM and feeds it to a per-connection
    online ASR processor. Whenever the pending transcript reaches a clause
    boundary, that segment is translated (and optionally voiced). The result is
    then sent to the client as ``{'line': {...}, 'buffer': ''}``.
    """
    await websocket.accept()
    print("WebSocket connection opened.")

    ffmpeg_process = await start_ffmpeg_decoder()
    pcm_buffer = bytearray()
    print("Loading online.")
    online = online_factory(args, asr, tokenizer)
    print("Online loaded.")

    if args.diarization:
        # NOTE(review): created and closed here, but nothing in this handler feeds
        # audio to it or applies its speaker labels — TODO wire it up.
        diarization = DiartDiarization(SAMPLE_RATE)

    async def ffmpeg_stdout_reader():
        nonlocal pcm_buffer
        loop = asyncio.get_running_loop()
        full_transcription = ""  # all committed ASR text, used to de-duplicate the tentative buffer
        buffer_line = ''         # text not yet cut into a translated segment
        chunk_queue = queue.Queue()

        def read_ffmpeg_stdout():
            # Blocking pipe reads run on a dedicated thread, off the event loop.
            try:
                while True:
                    chunk = ffmpeg_process.stdout.read(BYTES_PER_SEC)
                    if not chunk:
                        break
                    chunk_queue.put(chunk)
            except Exception as e:
                print(f"Exception in read_ffmpeg_stdout: {e}")
            finally:
                # Sentinel: wakes the pending chunk_queue.get() in the executor so
                # that worker thread is released when the stream ends or fails.
                chunk_queue.put(None)

        threading.Thread(target=read_ffmpeg_stdout, daemon=True).start()

        while True:
            try:
                chunk = await loop.run_in_executor(None, chunk_queue.get)
                if not chunk:
                    print("FFmpeg stdout closed.")
                    break

                pcm_buffer.extend(chunk)
                logger.debug("len(pcm_buffer): %d, BYTES_PER_SEC: %d", len(pcm_buffer), BYTES_PER_SEC)
                if len(pcm_buffer) < BYTES_PER_SEC:
                    continue

                # int16 PCM -> float32 in [-1, 1).
                pcm_array = np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
                pcm_buffer = bytearray()
                online.insert_audio_chunk(pcm_array)
                # NOTE: inference runs synchronously on the event loop. That stalls other
                # connections, but it also serializes use of the shared `asr` backend
                # across connections — keep that in mind before moving it to a thread.
                _, _, trans = online.process_iter()
                if trans:
                    full_transcription += trans

                if args.vac:
                    # With VAC the processor is wrapped; the transcript buffer lives on `.online`.
                    buffer_text = online.online.concatenate_tsw(online.online.transcript_buffer.buffer)[2]
                else:
                    buffer_text = online.concatenate_tsw(online.transcript_buffer.buffer)[2]
                if buffer_text in full_transcription:
                    buffer_text = ""
                # NOTE(review): only the tentative buffer is appended (committed `trans` never
                # enters buffer_line), and a buffer that persists across iterations is appended
                # again each time. Confirm this is the intended segmentation.
                buffer_line += buffer_text

                split = _split_at_last_boundary(buffer_line)
                if split is None:
                    continue
                extracted_text, buffer_line = split
                buffer = {'speaker': '0', 'text': extracted_text, 'translation': None}

                # Translation and TTS are blocking HTTP calls: run them in the default
                # executor so the event loop keeps serving other connections meanwhile.
                translation = await loop.run_in_executor(
                    None, translator_wrapper, buffer['text'], translation_target_lang, 'google'
                )
                buffer['translation'] = translation
                buffer['text'] += ('|' + translation)
                if args.generate_audio:
                    buffer['audio_url'] = await loop.run_in_executor(None, generate_audio, translation, 1.5)
                else:
                    buffer['audio_url'] = ''

                response = {'line': buffer, 'buffer': ''}
                logger.debug("Sending response: %s", response)
                await websocket.send_json(response)

            except Exception as e:
                print(f"Exception in ffmpeg_stdout_reader: {e}")
                break

        print("Exiting ffmpeg_stdout_reader...")

    stdout_reader_task = asyncio.create_task(ffmpeg_stdout_reader())

    try:
        while True:
            message = await websocket.receive_bytes()
            ffmpeg_process.stdin.write(message)
            ffmpeg_process.stdin.flush()
    except WebSocketDisconnect:
        print("WebSocket connection closed.")
    except Exception as e:
        print(f"Error in websocket loop: {e}")
    finally:
        # Best-effort teardown: every step must run even if an earlier one fails.
        try:
            # Signals end of input to ffmpeg.
            ffmpeg_process.stdin.close()
        except Exception:
            pass
        stdout_reader_task.cancel()
        try:
            ffmpeg_process.stdout.close()
        except Exception:
            pass
        ffmpeg_process.wait()
        del online
        if args.diarization:
            diarization.close()
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # NOTE(review): reload=True makes uvicorn import "app:app" again in a separate
    # worker process, so the import-time setup in this module (HF login, model
    # loading, backend_factory) presumably runs more than once — confirm, and
    # consider uvicorn.run(app, ...) without reload outside development.
    uvicorn_options = {
        "host": args.host,
        "port": args.port,
        "reload": True,
        "log_level": "info",
    }
    uvicorn.run("app:app", **uvicorn_options)
|
|