import os |
import io |
import requests |
import argparse |
import asyncio |
import numpy as np |
import ffmpeg |
from time import time |
from fastapi import FastAPI, WebSocket, WebSocketDisconnect |
from fastapi.responses import HTMLResponse |
from fastapi.middleware.cors import CORSMiddleware |
from src.whisper_streaming.whisper_online import backend_factory, online_factory, add_shared_args |
import logging |
import logging.config |
from transformers import pipeline |
from huggingface_hub import login |
# Hugging Face checkpoint used by the local 'marianmt' translation mode.
MODEL_NAME = 'Helsinki-NLP/opus-tatoeba-en-ja'

# The translation pipeline is built once, at import time, on the CUDA device.
TRANSLATOR = pipeline(task='translation', model=MODEL_NAME, device='cuda')

# One throwaway call at import time; its result is discarded.
TRANSLATOR('Warming up!')
def translator_wrapper(source_text, translation_target_lang, mode):
    """Translate ``source_text`` with the backend selected by ``mode``.

    Args:
        source_text: Text to translate.
        translation_target_lang: Target language code. Only the ``'google'``
            backend uses it; ``'deepl'`` always requests EN -> JA and
            ``'marianmt'`` uses the fixed module-level ``TRANSLATOR`` model.
        mode: One of ``'deepl'``, ``'marianmt'`` or ``'google'``.

    Returns:
        The translated string, ``"(timed out)"`` when an HTTP backend does
        not answer within 5 seconds, or ``None`` for an unrecognised mode.

    Raises:
        KeyError: If the API key of the selected HTTP backend is not set in
            the environment (``DEEPL_API_KEY`` / ``GOOGLE_API_KEY``), or if
            the service response lacks the expected fields.
    """
    # NOTE: `requests` must be resolved as the module-level import. A
    # function-scope `import requests` turns the name into a local for the
    # whole function, which made the 'deepl' branch fail with
    # UnboundLocalError before it could send anything.
    if mode == 'deepl':
        params = {
            'auth_key': os.environ['DEEPL_API_KEY'],
            'text': source_text,
            'source_lang': 'EN',
            'target_lang': 'JA',
        }
        try:
            response = requests.post(
                "https://api-free.deepl.com/v2/translate",
                data=params,
                timeout=5,
            )
            return response.json()['translations'][0]['text']
        except requests.exceptions.Timeout:
            return "(timed out)"

    if mode == 'marianmt':
        return TRANSLATOR(source_text)[0]['translation_text']

    if mode == 'google':
        # SECURITY: the API key is read from the environment instead of being
        # embedded in the source. Any key that was ever committed to this
        # file should be treated as leaked and rotated.
        payload = {
            'key': os.environ['GOOGLE_API_KEY'],
            # An empty source language is sent as-is; presumably the service
            # auto-detects it -- TODO confirm.
            'source': "",
            'target': translation_target_lang,
            'q': source_text,
            'format': "text",
        }
        try:
            response = requests.post(
                "https://translation.googleapis.com/language/translate/v2",
                data=payload,
                timeout=5,
            )
        except requests.exceptions.Timeout:
            # Same fallback text as the DeepL branch.
            return "(timed out)"
        result = response.json()["data"]["translations"][0]["translatedText"]
        logging.getLogger(__name__).debug(
            "Google translation (%s): %s", response.status_code, result
        )
        return result

    # Unknown mode: keep the historical behaviour of returning None.
    return None
def setup_logging():
    """Apply the process-wide logging configuration through ``dictConfig``.

    A single console ``StreamHandler`` (level INFO, "standard" format) is
    attached to the root logger (level DEBUG). The ``uvicorn`` logger and the
    two ``src.whisper_streaming`` loggers get the same handler directly and
    do not propagate; ``uvicorn.error`` and ``uvicorn.access`` only have
    their level set to INFO.
    """
    console_only = ['console']

    def dedicated(level):
        # Logger that writes to the console itself and does not propagate.
        return {
            'handlers': list(console_only),
            'level': level,
            'propagate': False,
        }

    formatters = {
        'standard': {
            'format': '%(asctime)s %(levelname)s [%(name)s]: %(message)s',
        },
    }
    handlers = {
        'console': {
            'level': 'INFO',
            'class': 'logging.StreamHandler',
            'formatter': 'standard',
        },
    }
    loggers = {
        'uvicorn': dedicated('INFO'),
        'uvicorn.error': {'level': 'INFO'},
        'uvicorn.access': {'level': 'INFO'},
        'src.whisper_streaming.online_asr': dedicated('DEBUG'),
        'src.whisper_streaming.whisper_streaming': dedicated('DEBUG'),
    }

    logging.config.dictConfig({
        'version': 1,
        'disable_existing_loggers': False,
        'formatters': formatters,
        'handlers': handlers,
        'root': {
            'handlers': list(console_only),
            'level': 'DEBUG',
        },
        'loggers': loggers,
    })
# Logging is configured first so the module logger below uses it.
setup_logging()
logger = logging.getLogger(__name__)

app = FastAPI()

# NOTE(review): CORS is fully open here (any origin, method and header) while
# credentials are allowed -- confirm this is intended outside development.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
def _str2bool(value):
    """Convert a command-line boolean spelling to ``bool``.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is
    ``True`` because every non-empty string is truthy, so ``--flag False``
    would silently enable the flag.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {'true', 't', 'yes', 'y', 'on', '1'}:
        return True
    if normalized in {'false', 'f', 'no', 'n', 'off', '0'}:
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value (true/false), got {value!r}"
    )


parser = argparse.ArgumentParser(description="Whisper FastAPI Online Server")
parser.add_argument(
    "--host",
    type=str,
    default="localhost",
    help="The host address to bind the server to.",
)
parser.add_argument(
    "--port", type=int, default=8000, help="The port number to bind the server to."
)
parser.add_argument(
    "--warmup-file",
    type=str,
    dest="warmup_file",
    help="The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast. It can be e.g. https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav .",
)
# Both flags still take an explicit value (e.g. `--diarization True`), so
# existing command lines keep working; `False` now really means disabled.
parser.add_argument(
    "--diarization",
    type=_str2bool,
    default=False,
    help="Whether to enable speaker diarization.",
)
parser.add_argument(
    "--generate-audio",
    type=_str2bool,
    default=False,
    help="Whether to generate translation audio.",
)
# Options shared with the whisper_streaming command-line tools are added by
# the project helper before the command line is parsed.
add_shared_args(parser)
args = parser.parse_args()

# Translate into the other language of the EN/JA pair. Any source language
# other than English (including Japanese) is translated into English, so the
# name is always defined; previously it was left unset for values other than
# 'ja'/'en' and the first translation failed with NameError.
translation_target_lang = 'ja' if args.lan == 'en' else 'en'

asr, tokenizer = backend_factory(args)

# The diarization module is imported only when the feature is enabled.
if args.diarization:
    from src.diarization.diarization_online import DiartDiarization

# Page served at "/"; the path is relative to the current working directory.
with open("src/web/live_transcription.html", "r", encoding="utf-8") as f:
    html = f.read()
@app.get("/")
async def get():
    """Serve the live-transcription page held in the module-level ``html``."""
    page = HTMLResponse(html)
    return page
# Sample rate in Hz; also passed to FFmpeg as the output sample rate.
SAMPLE_RATE = 16000

# Despite the name, this is the number of samples in one processing chunk
# (SAMPLE_RATE scaled by args.min_chunk_size, presumably given in seconds).
SAMPLES_PER_SEC = int(SAMPLE_RATE * args.min_chunk_size)

# The decoder emits s16le PCM and the reader parses it as int16: 2 bytes per
# sample.
BYTES_PER_SAMPLE = 2

# Size in bytes of one processing chunk. The WebSocket reader uses this name
# both as its read size and as its processing threshold; it was previously
# never defined, which raised NameError in the reader thread.
BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE
def generate_audio(japanese_text, speed=1.0):
    """Request synthesized speech for ``japanese_text`` from the TTS endpoint.

    The bearer token is read from the ``HUGGING_FACE_TOKEN`` environment
    variable, falling back to ``HF_TOKEN``.

    Args:
        japanese_text: Text sent to the endpoint as ``inputs``.
        speed: Value sent to the endpoint as ``speed``.

    Returns:
        The decoded JSON response on success, or ``''`` when no token is
        configured, the request fails, the body is not JSON, or the response
        contains an ``'error'`` entry.
    """
    log = logging.getLogger(__name__)

    # The token used to be read from a module global that was never defined,
    # so every call raised NameError.
    token = os.environ.get("HUGGING_FACE_TOKEN") or os.environ.get("HF_TOKEN")
    if not token:
        log.warning(
            "Audio generation skipped: set HUGGING_FACE_TOKEN (or HF_TOKEN)."
        )
        return ''

    api_url = "https://j6im8slpwcevr7g0.us-east-1.aws.endpoints.huggingface.cloud"
    headers = {
        "Accept": "application/json",
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }
    payload = {
        "inputs": japanese_text,
        "speed": speed,
    }

    try:
        # A timeout keeps a stalled endpoint from blocking the caller forever.
        response = requests.post(
            api_url, headers=headers, json=payload, timeout=30
        ).json()
    except (requests.exceptions.RequestException, ValueError) as exc:
        # Best effort: audio is optional, so failures yield the same empty
        # result as an error reported by the endpoint.
        log.warning("Audio generation request failed: %s", exc)
        return ''

    if 'error' in response:
        log.warning("Audio generation returned an error: %s", response)
        return ''
    return response
async def start_ffmpeg_decoder():
    """
    Start an FFmpeg process in async streaming mode that reads WebM from stdin
    and outputs raw mono s16le PCM at SAMPLE_RATE on stdout.

    Returns:
        The process object returned by ``run_async``, with stdin and stdout
        piped.
    """
    webm_input = ffmpeg.input("pipe:0", format="webm")
    pcm_output = webm_input.output(
        "pipe:1",
        format="s16le",
        acodec="pcm_s16le",
        # The consumer parses stdout as one flat int16 stream, so the output
        # is forced to a single channel whatever the client sends.
        ac=1,
        ar=str(SAMPLE_RATE),
    )
    # FFmpeg's own logging is silenced on the command line. stderr is left
    # unpiped because nothing in this module reads it.
    process = (
        pcm_output
        .global_args('-loglevel', 'quiet')
        .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=False)
    )
    return process
import queue |
import threading |
@app.websocket("/asr")
async def websocket_endpoint(websocket: WebSocket):
    """Transcribe and translate one client's audio stream over a WebSocket.

    Binary messages received from the client are written to an FFmpeg
    process that decodes them to 16-bit PCM. A background task reads the PCM
    back, feeds it to the online ASR processor and, each time the pending
    text contains a sentence boundary, sends
    ``{'line': {...}, 'buffer': ''}`` to the client. ``line`` holds
    ``speaker``, ``text`` (source text and translation joined by ``'|'``),
    ``translation`` and ``audio_url``.

    On disconnect or error the FFmpeg pipes are closed, the reader task is
    cancelled and awaited, and the per-connection ASR state is released.
    """
    import re  # function-scope import: `re` is not among the module imports

    await websocket.accept()
    logger.info("WebSocket connection opened.")

    ffmpeg_process = await start_ffmpeg_decoder()
    pcm_buffer = bytearray()

    logger.info("Loading online.")
    online = online_factory(args, asr, tokenizer)
    logger.info("Online loaded.")

    # NOTE: diarization is not wired into the processing loop below; the
    # object is only created here and closed on exit, and every line is
    # reported with speaker '0'.
    if args.diarization:
        diarization = DiartDiarization(SAMPLE_RATE)

    # A segment ends at punctuation or at a whole-word conjunction. Word
    # boundaries keep 'or'/'and'/'but' from matching inside words such as
    # "for" or "button"; trailing whitespace is consumed with the boundary.
    boundary = re.compile(r'(?:[,.?!、。？！]|\b(?:and|or|but|however)\b)\s*')

    async def ffmpeg_stdout_reader():
        nonlocal pcm_buffer
        loop = asyncio.get_running_loop()
        full_transcription = ""
        buffer_line = ''
        chunk_queue = queue.Queue()

        def read_ffmpeg_stdout():
            # Runs in a daemon thread: the blocking pipe reads must not run
            # on the event loop.
            try:
                while True:
                    chunk = ffmpeg_process.stdout.read(BYTES_PER_SEC)
                    if not chunk:
                        break
                    chunk_queue.put(chunk)
            except Exception as e:
                logger.error("Exception in read_ffmpeg_stdout: %s", e)
            finally:
                # Always enqueue an empty sentinel so the consumer wakes up
                # and exits instead of blocking on the queue forever.
                chunk_queue.put(b'')

        threading.Thread(target=read_ffmpeg_stdout, daemon=True).start()

        while True:
            try:
                chunk = await loop.run_in_executor(None, chunk_queue.get)
                if not chunk:
                    logger.info("FFmpeg stdout closed.")
                    break

                pcm_buffer.extend(chunk)
                logger.debug("len(pcm_buffer): %d", len(pcm_buffer))
                if len(pcm_buffer) < BYTES_PER_SEC:
                    continue

                # int16 PCM -> float32 scaled by 1/32768.
                pcm_array = (
                    np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32)
                    / 32768.0
                )
                pcm_buffer = bytearray()

                online.insert_audio_chunk(pcm_array)
                _beg_trans, _end_trans, trans = online.process_iter()
                if trans:
                    full_transcription += trans

                # Text still held in the processor's transcript buffer.
                if args.vac:
                    buffer_text = online.online.concatenate_tsw(
                        online.online.transcript_buffer.buffer
                    )[2]
                else:
                    buffer_text = online.concatenate_tsw(
                        online.transcript_buffer.buffer
                    )[2]
                if buffer_text in full_transcription:
                    buffer_text = ""

                # NOTE(review): buffer_text is appended on every chunk; if the
                # same buffered text persists across chunks it is duplicated
                # in buffer_line -- confirm this is intended.
                buffer_line += buffer_text

                last_boundary = None
                for last_boundary in boundary.finditer(buffer_line):
                    pass
                if last_boundary is None:
                    # No complete segment yet; keep accumulating.
                    continue

                cut = last_boundary.end()
                extracted_text = buffer_line[:cut]
                buffer_line = buffer_line[cut:]

                # The translation and audio helpers do blocking HTTP calls;
                # run them in the executor so the event loop keeps receiving
                # audio meanwhile.
                translation = await loop.run_in_executor(
                    None,
                    translator_wrapper,
                    extracted_text,
                    translation_target_lang,
                    'google',
                )
                if args.generate_audio:
                    audio_url = await loop.run_in_executor(
                        None, generate_audio, translation, 1.5
                    )
                else:
                    audio_url = ''

                line = {
                    'speaker': '0',
                    'text': extracted_text + '|' + translation,
                    'translation': translation,
                    'audio_url': audio_url,
                }
                response = {'line': line, 'buffer': ''}
                logger.debug("Sending response: %s", response)
                await websocket.send_json(response)

            except Exception as e:
                logger.error("Exception in ffmpeg_stdout_reader: %s", e)
                break

        logger.info("Exiting ffmpeg_stdout_reader...")

    stdout_reader_task = asyncio.create_task(ffmpeg_stdout_reader())

    try:
        while True:
            message = await websocket.receive_bytes()
            ffmpeg_process.stdin.write(message)
            ffmpeg_process.stdin.flush()
    except WebSocketDisconnect:
        logger.info("WebSocket connection closed.")
    except Exception as e:
        logger.error("Error in websocket loop: %s", e)
    finally:
        # Closing stdin signals end of input to FFmpeg.
        try:
            ffmpeg_process.stdin.close()
        except Exception as e:
            logger.debug("Ignoring error while closing FFmpeg stdin: %s", e)

        # Wait for the reader to actually stop before releasing `online`,
        # which the reader closure still refers to.
        stdout_reader_task.cancel()
        try:
            await stdout_reader_task
        except asyncio.CancelledError:
            pass

        try:
            ffmpeg_process.stdout.close()
        except Exception as e:
            logger.debug("Ignoring error while closing FFmpeg stdout: %s", e)

        ffmpeg_process.wait()
        del online

        if args.diarization:
            diarization.close()
if __name__ == "__main__":
    # uvicorn is only needed when the module is executed as a script.
    import uvicorn

    # The application is referenced by the import string "app:app"
    # (presumably this file is app.py -- confirm), with auto-reload enabled.
    server_options = {
        "host": args.host,
        "port": args.port,
        "reload": True,
        "log_level": "info",
    }
    uvicorn.run("app:app", **server_options)