Spaces:
Running
on
A10G
Running
on
A10G
import io | |
from flask import Flask, Response, send_from_directory, jsonify, request, abort | |
import os | |
from flask_cors import CORS | |
from multiprocessing import Queue | |
import base64 | |
from typing import Any, Dict, Tuple | |
from multiprocessing import Queue | |
import logging | |
import sys | |
from threading import Lock | |
from multiprocessing import Manager | |
import torch | |
from server.AudioTranscriber import AudioTranscriber | |
from server.ActionProcessor import ActionProcessor | |
from server.StandaloneApplication import StandaloneApplication | |
from server.TextFilterer import TextFilterer | |
# Logging setup: INFO and above, written to stdout with a timestamp, the
# logger name and the level in front of every message.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger used by the code in this file.
logger = logging.getLogger(__name__)
# Directory the static front-end files are served from: an "html" folder under
# the current working directory when DEBUG is exactly "True", otherwise the
# fixed in-container path.
if os.getenv("DEBUG") == "True":
    STATIC_DIR = os.path.join(os.getcwd(), "html")
else:
    STATIC_DIR = "/app/server/static"
# Work queues connecting the processing stages. Each item is a tuple of
# (data, token).
audio_queue: "Queue[Tuple[io.BytesIO, str]]" = Queue()
text_queue: "Queue[Tuple[str, str]]" = Queue()
filtered_text_queue: "Queue[Tuple[str, str]]" = Queue()
action_queue: "Queue[Tuple[Dict[str, Any], str]]" = Queue()

# Pending actions keyed by session token, held in a manager so the dictionary
# is shared across processes.
manager = Manager()
action_storage = manager.dict()

# The lock guarding action_storage must be shared across processes as well.
# A threading.Lock only excludes threads inside one process, so it would not
# serialise the read-modify-write cycles that different processes perform on
# the shared dictionary. A manager lock is a proxy to a single lock object and
# still works with the ``with`` statement.
action_storage_lock = manager.Lock()
# Flask application; static files are served from STATIC_DIR.
app = Flask(__name__, static_folder=STATIC_DIR)

# Enable CORS for every origin, limited to the listed methods and request
# headers. The returned extension object is not needed afterwards.
_ = CORS(
    app,
    allow_headers=["Content-Type", "Authorization"],
    methods=["GET", "POST", "OPTIONS"],
    origins=["*"],
)
def add_header(response: Response):
    """Stamp permissive CORS and cross-origin isolation headers on a response.

    NOTE(review): no decorator is visible in this view; presumably this is
    registered as an after-request hook - confirm.

    Args:
        response: The outgoing response; its headers are modified in place.

    Returns:
        The same response object.
    """
    extra_headers = {
        # Permissive CORS headers
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
        "Access-Control-Allow-Headers": "*",  # Allow all headers
        # Cross-origin isolation headers
        "Cross-Origin-Embedder-Policy": "require-corp",
        "Cross-Origin-Opener-Policy": "same-origin",
        "Cross-Origin-Resource-Policy": "cross-origin",
    }
    for header_name, header_value in extra_headers.items():
        response.headers[header_name] = header_value
    return response
def serve_index():
    """Serve ``index.html`` from the static folder.

    The response carries the Cross-Origin-Opener-Policy and
    Cross-Origin-Embedder-Policy headers. If ``FileNotFoundError`` is raised
    while building it, the request is aborted with a 404 that names the static
    folder.
    """
    isolation_headers = (
        ("Cross-Origin-Opener-Policy", "same-origin"),
        ("Cross-Origin-Embedder-Policy", "require-corp"),
    )
    try:
        index_response = send_from_directory(app.static_folder, "index.html")
        for header_name, header_value in isolation_headers:
            index_response.headers[header_name] = header_value
        return index_response
    except FileNotFoundError:
        abort(
            404,
            description=f"Static folder or index.html not found. Static folder: {app.static_folder}",
        )
def get_data():
    """Return a fixed JSON payload reporting success."""
    payload = {"status": "success"}
    return jsonify(payload)
def post_order() -> Tuple[Response, int]:
    """Accept a text action and enqueue it on ``text_queue``.

    Expects a JSON object body with a string ``action`` field and a ``token``
    query parameter.

    Returns:
        A ``(response, status)`` tuple: 200 on success, 400 when the body is
        not a JSON object, ``action`` is missing or not a string, or the token
        is missing, and 500 on any unexpected failure.
    """
    try:
        # silent=True turns a malformed or non-JSON body into None, so it is
        # reported as a client error (400) below instead of being caught by
        # the generic handler and returned as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or "action" not in data:
            return (
                jsonify({"error": "Missing 'action' in request", "status": "error"}),
                400,
            )

        action_text = data["action"]
        if not isinstance(action_text, str):
            # The text is sliced below; reject other JSON types up front.
            return (
                jsonify({"error": "'action' must be a string", "status": "error"}),
                400,
            )

        token = request.args.get("token")
        if not token:
            return jsonify({"error": "Missing token parameter", "status": "error"}), 400

        mid_split = len(action_text) // 2
        # Enqueue the first half, the full text, then the second half, all
        # tagged with the caller's token.
        # NOTE(review): the reason for sending the two halves alongside the
        # full text is not visible here - confirm it is intended.
        text_queue.put((action_text[:mid_split], token))
        text_queue.put((action_text, token))
        text_queue.put((action_text[mid_split:], token))

        return jsonify({"status": "success"}), 200
    except Exception as e:
        return (
            jsonify(
                {"error": f"Failed to process request: {str(e)}", "status": "error"}
            ),
            500,
        )
def process_data():
    """Accept a base64-encoded audio chunk and enqueue it on ``audio_queue``.

    The chunk is read from the ``audio_chunk`` field of a JSON or multipart
    body, or from the raw request body for any other content type. A ``token``
    query parameter is required.

    Returns:
        A JSON success response, a ``(response, 400)`` tuple when the token or
        the chunk is missing or cannot be decoded, or a ``(response, 500)``
        tuple on any unexpected failure.
    """
    try:
        token = request.args.get("token")
        if not token:
            return jsonify({"error": "Missing token parameter", "status": "error"}), 400

        content_type = request.headers.get("Content-Type", "")

        # Handle different content types
        if "application/json" in content_type:
            # silent=True yields None for malformed JSON; that case and any
            # non-object body fall through to the "missing" 400 below rather
            # than raising and being reported as a 500.
            data = request.get_json(silent=True)
            audio_base64 = data.get("audio_chunk") if isinstance(data, dict) else None
        elif "multipart/form-data" in content_type:
            audio_base64 = request.form.get("audio_chunk")
        else:
            # Try to get raw data
            try:
                audio_base64 = request.get_data().decode("utf-8")
            except UnicodeDecodeError as e:
                # A body that is not valid UTF-8 is a client error.
                return (
                    jsonify(
                        {
                            "error": f"Failed to decode audio chunk: {str(e)}",
                            "status": "error",
                        }
                    ),
                    400,
                )

        # Validate the incoming data
        if not audio_base64:
            return (
                jsonify({"error": "Missing audio_chunk in request", "status": "error"}),
                400,
            )

        # Decode the base64 audio chunk
        try:
            audio_chunk = base64.b64decode(audio_base64)
        except Exception as e:
            return (
                jsonify(
                    {
                        "error": f"Failed to decode audio chunk: {str(e)}",
                        "status": "error",
                    }
                ),
                400,
            )

        # Put the audio chunk in the queue for processing
        audio_queue.put((io.BytesIO(audio_chunk), token))

        return jsonify(
            {
                "status": "success",
            }
        )
    except Exception as e:
        return (
            jsonify(
                {"error": f"Failed to process request: {str(e)}", "status": "error"}
            ),
            500,
        )
def get_actions() -> Tuple[Response, int]:
    """Retrieve and clear all pending actions for the current session.

    The session is identified by the ``token`` query parameter.

    Returns:
        ``(response, 200)`` with the pending actions (possibly empty), or
        ``(response, 400)`` when the token is missing.
    """
    token = request.args.get("token")
    if not token:
        return (
            jsonify(
                {"actions": [], "error": "Missing token parameter", "status": "error"}
            ),
            400,
        )

    with action_storage_lock:
        actions = action_storage.get(token, [])
        # Only write back when there was something to clear. Writing an empty
        # list unconditionally would add an entry for every token ever polled
        # (including unknown ones) and cost an extra write per poll.
        if actions:
            action_storage[token] = []

    return jsonify({"actions": actions, "status": "success"}), 200
def serve_static(path: str):
    """Serve the file at ``path`` from the application's static folder.

    Args:
        path: Path of the requested file, relative to the static folder.

    Aborts the request with a 404 if ``FileNotFoundError`` is raised while
    building the response.
    """
    try:
        static_response = send_from_directory(app.static_folder, path)
    except FileNotFoundError:
        abort(404, description=f"File {path} not found in static folder")
    else:
        return static_response
class ActionConsumer:
    """Background consumer that moves actions from a queue into ``action_storage``.

    Each queue item is an ``(action, token)`` tuple; the action is appended to
    the list stored under its token.
    """

    def __init__(self, action_queue: Queue, poll_interval: float = 0.5):
        """Create a consumer for ``action_queue``.

        Args:
            action_queue: Queue yielding ``(action, token)`` tuples.
            poll_interval: Seconds to wait for an item before re-checking the
                ``running`` flag.
        """
        self.action_queue = action_queue
        self.poll_interval = poll_interval
        self.running = True

    def start(self):
        """Start the consumer loop in a daemon thread."""
        import threading

        self.thread = threading.Thread(target=self.run, daemon=True)
        self.thread.start()

    def stop(self):
        """Ask the consumer loop to exit after its current wait."""
        self.running = False

    def run(self):
        """Consume queue items until ``running`` becomes false."""
        import queue

        while self.running:
            try:
                # Wait with a timeout: a blocking get() would never re-check
                # self.running, making the flag ineffective.
                action, token = self.action_queue.get(timeout=self.poll_interval)
            except queue.Empty:
                continue
            except Exception as e:
                logger.exception(f"Error in ActionConsumer: {e}")
                continue

            try:
                with action_storage_lock:
                    if token not in action_storage:
                        logger.info(f"Creating new action storage for token: {token}")
                        action_storage[token] = []

                    # Read, modify and re-assign: the re-assignment is what
                    # writes the updated list back into the shared storage.
                    current_actions = action_storage[token]
                    current_actions.append(action)
                    action_storage[token] = current_actions
            except Exception as e:
                # Keep the loop alive, but record the traceback.
                logger.exception(f"Error in ActionConsumer: {e}")
if __name__ == "__main__":
    # Validate configuration before any worker is started, so a missing key
    # fails immediately instead of after transcribers are already running.
    MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
    if not MISTRAL_API_KEY:
        raise ValueError("MISTRAL_API_KEY is not set")

    debug_mode = os.getenv("DEBUG") == "True"

    if os.path.exists(app.static_folder):
        logger.info(f"Static folder contents: {os.listdir(app.static_folder)}")

    os.makedirs(app.static_folder, exist_ok=True)

    num_devices = torch.cuda.device_count()
    if num_devices == 0:
        # Device 0 is queried below and device indices are taken modulo
        # num_devices, so fail with a clear message when there is no device.
        raise RuntimeError("No CUDA device available; at least one is required")

    device_vram_gb: float = float(
        torch.cuda.get_device_properties(0).total_memory / (1024**3)
    )
    # One transcriber is budgeted per 3 GB of VRAM on device 0.
    num_3gb_units = int(device_vram_gb) // 3
    logger.info(
        f"Device 0 has {device_vram_gb:.1f}GB VRAM, equivalent to {num_3gb_units} units of Whisper"
    )

    # Start the audio transcribers, spread round-robin over the devices.
    num_transcribers = 4 if debug_mode else num_3gb_units * num_devices
    transcribers = [
        AudioTranscriber(audio_queue, text_queue, device_index=i % num_devices)
        for i in range(num_transcribers)
    ]
    for transcriber in transcribers:
        transcriber.start()

    # Start the action consumer thread
    action_consumer = ActionConsumer(action_queue)
    action_consumer.start()

    # Start the text filterer between text_queue and filtered_text_queue.
    filterer = TextFilterer(text_queue, filtered_text_queue)
    filterer.start()

    # Start the action processors reading from filtered_text_queue.
    actions_processors = [
        ActionProcessor(filtered_text_queue, action_queue, MISTRAL_API_KEY)
        for _ in range(4 if debug_mode else 16)
    ]
    for actions_processor in actions_processors:
        actions_processor.start()

    options: Any = {
        "bind": "0.0.0.0:7860",
        "workers": 3,
        "worker_class": "sync",
        "timeout": 120,
        "forwarded_allow_ips": "*",
        "accesslog": None,  # Disable access logging
        "errorlog": "-",  # Keep error logging to stderr
        "capture_output": True,
        "enable_stdio_inheritance": True,
    }

    StandaloneApplication(app, options).run()