import base64
import os

import torch
from flask import Flask, Response, request, send_from_directory

from inference import OmniInference

# Allow TF32 matmuls on supported GPUs: faster inference at slightly
# reduced float32 precision
torch.set_float32_matmul_precision('high')

app = Flask(__name__)

# Initialize OmniInference once at startup
try:
    print("Initializing OmniInference...")
    omni = OmniInference()
    print("OmniInference initialized successfully.")
except Exception as e:
    print(f"Error initializing OmniInference: {str(e)}")
    raise

# Route paths are assumed ('/' for the demo page, '/chat' for audio requests);
# adjust them if the frontend posts elsewhere
@app.route('/')
def serve_html():
    return send_from_directory('.', 'webui/omni_html_demo.html')

@app.route('/chat', methods=['POST'])
def chat():
    try:
        audio_data = request.json.get('audio')
        if not audio_data:
            return "No audio data received", 400

        # Strip the data-URL prefix (e.g. "data:audio/wav;base64,") if present
        if ',' in audio_data:
            audio_data = audio_data.split(',')[1]
        audio_bytes = base64.b64decode(audio_data)

        # Write the audio to a temporary file for OmniInference to read
        # (a fixed filename is fine for a single-user demo, but it is not
        # safe under concurrent requests)
        temp_audio_path = 'temp_audio.wav'
        with open(temp_audio_path, 'wb') as f:
            f.write(audio_bytes)

        # Generate the spoken response with OmniInference
        try:
            response_generator = omni.run_AT_batch_stream(temp_audio_path)

            # Buffer all streamed audio chunks into a single response body
            all_audio = b''.join(response_generator)

            return Response(all_audio, mimetype='audio/wav')
        except Exception as inner_e:
            print(f"Error in OmniInference processing: {str(inner_e)}")
            return f"An error occurred during audio processing: {str(inner_e)}", 500
        finally:
            # Always remove the temporary file, even if an error occurred
            if os.path.exists(temp_audio_path):
                os.remove(temp_audio_path)
    except Exception as e:
        print(f"Error in chat endpoint: {str(e)}")
        return f"An error occurred: {str(e)}", 500
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)
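
# Minimal client sketch for manual testing (assumptions: the server is
# reachable at http://localhost:7860 and 'sample.wav' is a local recording):
#
#   import base64, requests
#   with open('sample.wav', 'rb') as f:
#       payload = {'audio': base64.b64encode(f.read()).decode()}
#   resp = requests.post('http://localhost:7860/chat', json=payload)
#   with open('reply.wav', 'wb') as f:
#       f.write(resp.content)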