import gradio as gr
import httpx
import os
import json
import inspect
import aiofiles
import asyncio
import tempfile
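# Shared async HTTP client, reused across all requests so connections are
# pooled instead of being re-established on every call.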
limits = httpx.Limits(max_connections=100, max_keepalive_connections=20)
HTTP_CLIENT = httpx.AsyncClient(
    http2=False,
    limits=limits,
    timeout=httpx.Timeout(30.0, pool=10.0)
)
# --- SRT Generation Functions ---
def format_srt_time(total_seconds):
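    """Format a duration in seconds as an SRT timestamp, e.g. 3723.456 -> '01:02:03,456'."""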
    # Work in whole milliseconds so the fractional part can never round up to 1000,
    # which would produce an invalid timestamp like '00:00:01,1000'.
    total_ms = round(total_seconds * 1000)
    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    seconds, milliseconds = divmod(remainder, 1000)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}"
def generate_srt_file(json_data):
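    """Build an SRT subtitle file from transcription JSON containing 'segments'.

    Returns the path to a temporary .srt file, or None if there is nothing
    to convert or the file cannot be written.
    """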
    if not json_data or "segments" not in json_data or not json_data["segments"]:
        print("No segments to convert to SRT.")
        return None
    srt_content = ""
    for index, segment in enumerate(json_data["segments"]):
        sequence = index + 1
        start_time = format_srt_time(segment['start'])
        end_time = format_srt_time(segment['end'])
        text = segment['text'].strip()
        srt_content += f"{sequence}\n{start_time} --> {end_time}\n{text}\n\n"
    try:
        with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.srt', encoding='utf-8') as temp_file:
            temp_file.write(srt_content)
            print(f"Temporary SRT file created at: {temp_file.name}")
            return temp_file.name
    except Exception as e:
        print(f"Error creating SRT file: {e}")
        return None
# --- Tools for Chatbot ---
def get_city_info(city: str):
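    """Demo tool for the chatbot's function-calling flow; returns canned city facts as JSON."""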
if "paris" in city.lower(): | |
return json.dumps({"population": "2.1 million", "monument": "Eiffel Tower", "fact": "Paris is known as the 'City of Light'."}) | |
elif "tokyo" in city.lower(): | |
return json.dumps({"population": "14 million", "monument": "Tokyo Tower", "fact": "Tokyo is the largest metropolitan area in the world."}) | |
else: | |
return json.dumps({"error": f"Sorry, I don't have information about {city}."}) | |
available_tools = {"get_city_info": get_city_info}
tools_schema = [
    {
        "type": "function",
        "function": {
            "name": "get_city_info",
            "description": "Get information about a specific city.",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {"type": "string", "description": "The name of the city, e.g., 'Paris'."}
                },
                "required": ["city"]
            }
        }
    }
]
async def upload_file_to_public_service(filepath: str):
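    """Upload a local file to uguu.se and return its public URL, or None on failure.

    The chat API fetches audio by URL, so local files are first pushed to a
    temporary public file host. Note that uploads there are publicly accessible.
    """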
    if not filepath:
        return None
    url = "https://uguu.se/upload"
    try:
        async with aiofiles.open(filepath, 'rb') as f:
            content = await f.read()
        files = {'files[]': (os.path.basename(filepath), content)}
        response = await HTTP_CLIENT.post(url, files=files, timeout=30.0)
        response.raise_for_status()
        result = response.json()
        if "files" in result and result["files"] and "url" in result["files"][0]:
            full_url = result["files"][0]["url"]
            # print(f"File successfully uploaded: {full_url}")
            return full_url
        else:
            print(f"Upload API response error: {result}")
            return None
    except httpx.HTTPStatusError as e:
        print(f"HTTP error during upload: {e.response.status_code} - {e.response.text}")
        return None
    except httpx.RequestError as e:
        print(f"Connection error during upload: {e}")
        return None
    except (IOError, FileNotFoundError) as e:
        print(f"File read error: {e}")
        return None
    except (KeyError, IndexError) as e:
        print(f"Unexpected JSON response structure: {e}")
        return None
async def handle_api_call(api_key, messages, model_chat):
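    """POST one chat-completions request to the Mistral API with the tool schema attached."""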
headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"} | |
api_url = "https://api.mistral.ai/v1/chat/completions" | |
json_data = { | |
"model": model_chat, | |
"messages": messages, | |
"tools": tools_schema, | |
"tool_choice": "auto" | |
} | |
return await HTTP_CLIENT.post(api_url, headers=headers, json=json_data, timeout=60.0) | |
async def handle_chat_submission(api_key, user_message, chat_history_api, audio_url, model_chat):
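    """Send one user turn (text and/or an audio URL) to the chat API and append the reply.

    Handles a single round of tool calling: if the model requests a tool, run it,
    feed the result back, and ask the model for a final answer.
    """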
    if not api_key:
        user_content = [{"type": "text", "text": user_message}] if user_message else []
        updated_history = chat_history_api + [{"role": "user", "content": user_content}, {"role": "assistant", "content": "Error: API key not configured."}]
        return updated_history, ""
    current_user_content = []
    if audio_url:
        current_user_content.append({"type": "input_audio", "input_audio": {"data": audio_url, "format": "mp3"}})
    if user_message:
        current_user_content.append({"type": "text", "text": user_message})
    if not current_user_content:
        return chat_history_api, ""
    chat_history_api.append({"role": "user", "content": current_user_content})
    try:
        response = await handle_api_call(api_key, chat_history_api, model_chat)
        if response.status_code != 200:
            error_msg = response.json().get("message", response.text)
            chat_history_api.append({"role": "assistant", "content": f"API error: {error_msg}"})
            return chat_history_api, ""
        assistant_message = response.json()['choices'][0]['message']
    except httpx.HTTPStatusError as e:
        error_msg = e.response.json().get("message", e.response.text)
        chat_history_api.append({"role": "assistant", "content": f"API error: {error_msg}"})
        return chat_history_api, ""
    except httpx.RequestError as e:
        chat_history_api.append({"role": "assistant", "content": f"Connection error: {e}"})
        return chat_history_api, ""
    if assistant_message.get("tool_calls"):
        chat_history_api.append(assistant_message)
        tool_call = assistant_message["tool_calls"][0]
        function_name = tool_call['function']['name']
        function_args = json.loads(tool_call['function']['arguments'])
        if function_name in available_tools:
            tool_call_id = tool_call['id']
            function_to_call = available_tools[function_name]
            # Tools may be sync or async; await only coroutine functions.
            if inspect.iscoroutinefunction(function_to_call):
                tool_output = await function_to_call(**function_args)
            else:
                tool_output = function_to_call(**function_args)
            chat_history_api.append({
                "role": "tool",
                "tool_call_id": tool_call_id,
                "content": tool_output
            })
            try:
                second_response = await handle_api_call(api_key, chat_history_api, model_chat)
                if second_response.status_code != 200:
                    error_msg = second_response.json().get("message", second_response.text)
                    chat_history_api.append({"role": "assistant", "content": f"API error after tool call: {error_msg}"})
                else:
                    chat_history_api.append(second_response.json()['choices'][0]['message'])
            except httpx.HTTPStatusError as e:
                error_msg = e.response.json().get("message", e.response.text)
                chat_history_api.append({"role": "assistant", "content": f"API error after tool call: {error_msg}"})
            except httpx.RequestError as e:
                chat_history_api.append({"role": "assistant", "content": f"Connection error after tool call: {e}"})
        else:
            chat_history_api.append({"role": "assistant", "content": f"Error: Unknown tool '{function_name}'."})
    else:
        chat_history_api.append(assistant_message)
    return chat_history_api, ""
def format_history_for_display(api_history):
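    """Convert API-format history (content parts, tool calls) into Gradio 'messages' entries."""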
    display_messages = []
    for msg in api_history:
        if msg['role'] == 'user':
            text_content = ""
            has_audio = False
            if isinstance(msg.get('content'), list):
                text_part = next((part['text'] for part in msg['content'] if part['type'] == 'text'), None)
                if text_part:
                    text_content = text_part
                if any(part['type'] == 'input_audio' for part in msg['content']):
                    has_audio = True
            elif isinstance(msg.get('content'), str):
                text_content = msg['content']
            display_content = f"🎤 {text_content}" if has_audio else text_content
            if display_content:
                display_messages.append({"role": "user", "content": display_content})
        elif msg['role'] == 'assistant':
            if msg.get("tool_calls"):
                tool_call = msg["tool_calls"][0]
                func_name = tool_call['function']['name']
                func_args = tool_call['function']['arguments']
                tool_display_content = f"⚙️ *Calling tool `{func_name}` with arguments: `{func_args}`...*"
                display_messages.append({"role": "assistant", "content": tool_display_content})
            elif msg.get('content'):
                display_messages.append({"role": "assistant", "content": msg['content']})
    return display_messages
async def transcribe_audio(api_key, source_type, audio_file_path, audio_url, add_timestamps, model_transcription):
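    """Call the Mistral /v1/audio/transcriptions endpoint with an uploaded file or a URL.

    Returns the parsed JSON response, or a dict with an 'error' key on failure.
    """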
    if not api_key:
        return {"error": "Please enter your API key first."}
    headers = {"Authorization": f"Bearer {api_key}"}
    api_url = "https://api.mistral.ai/v1/audio/transcriptions"
    try:
        # Multipart form fields: a (None, value) tuple marks a plain field rather than a file.
        payload = {'model': (None, model_transcription)}
        if add_timestamps:
            payload['timestamp_granularities'] = (None, 'segment')
        if source_type == "Upload a file":
            if not audio_file_path:
                return {"error": "Please upload an audio file."}
            async with aiofiles.open(audio_file_path, "rb") as f:
                content = await f.read()
            payload['file'] = (os.path.basename(audio_file_path), content, "audio/mpeg")
            response = await HTTP_CLIENT.post(api_url, headers=headers, files=payload, timeout=120.0)
        elif source_type == "Use a URL":
            if not audio_url:
                return {"error": "Please provide an audio URL."}
            payload['file_url'] = (None, audio_url)
            response = await HTTP_CLIENT.post(api_url, headers=headers, files=payload, timeout=120.0)
        else:
            return {"error": "Invalid source type."}
        response.raise_for_status()
        return response.json()
    except httpx.HTTPStatusError as e:
        try:
            await e.response.aread()
            details = e.response.json().get("message", e.response.text)
        except Exception:
            details = e.response.text
        return {"error": f"API Error {e.response.status_code}", "details": details}
    except httpx.RequestError as e:
        return {"error": "Connection error", "details": str(e)}
    except IOError as e:
        return {"error": "File reading error", "details": str(e)}
async def run_transcription_and_update_ui(api_key, source, file_path, url, timestamps, model_transcription):
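    """Async generator wired to the transcription button.

    Yields dicts of gr.update() values so the UI shows a loading state first,
    then the final result (success, partial text, or error).
    """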
    yield {
        transcription_button: gr.update(value="⏳ Transcription in progress...", interactive=False),
        transcription_status: gr.update(value="*Starting transcription...*", visible=True),
        transcription_output: gr.update(visible=False),
        download_zone: gr.update(visible=False),
        download_file_output: gr.update(value=None)
    }
    json_result = await transcribe_audio(api_key, source, file_path, url, timestamps, model_transcription)
    has_segments = isinstance(json_result, dict) and json_result.get("segments")
    is_error = isinstance(json_result, dict) and "error" in json_result
    if is_error:
        error_title = json_result.get("error", "Unknown error")
        error_details = json_result.get("details", "No details available.")
        yield {
            transcription_status: gr.update(value=f"### ❌ {error_title}\n\n*_{error_details}_*", visible=True),
            transcription_button: gr.update(value="▶️ Start transcription", interactive=True)
        }
    elif has_segments:
        yield {
            transcription_status: gr.update(value="### ✔️ Transcription complete!", visible=True),
            transcription_output: gr.update(value=json_result, visible=True),
            download_zone: gr.update(visible=True),
            transcription_button: gr.update(value="▶️ Start transcription", interactive=True)
        }
    else:
        text_result = json_result.get('text', "No text detected.")
        yield {
            transcription_status: gr.update(value=f"### ⚠️ Partial result\n\n_{text_result}_", visible=True),
            transcription_output: gr.update(value=json_result, visible=True),
            download_zone: gr.update(visible=False),
            transcription_button: gr.update(value="▶️ Start transcription", interactive=True)
        }
theme = gr.themes.Origin(
    primary_hue="orange",
    secondary_hue="gray",
    neutral_hue="zinc",
    text_size="md",
    spacing_size="md",
    radius_size="xxl",
    font=("Inter", "IBM Plex Sans", "ui-sans-serif", "system-ui", "sans-serif"),
).set(
    body_background_fill="#f7f8fa",
    block_background_fill="#fff",
    block_shadow="0 4px 24px 0 #0001, 0 1.5px 4px 0 #0001",
    block_border_width="1px",
    block_border_color="#ececec",
    button_primary_background_fill="#223a5e",
    button_primary_background_fill_hover="#1a2c47",
    button_primary_text_color="#fff",
    input_border_color="#e5e7eb",
    input_border_color_focus="#223a5e",
    input_background_fill="#fafbfc",
    input_shadow="0 0 0 2px #223a5e22",
)
custom_css = """ | |
.gradio-container label, | |
.gradio-container .gr-button, | |
.gradio-container .gr-button span, | |
.gradio-container a { | |
color: #FF6F3C !important; | |
} | |
.gradio-container .gr-button { | |
border-color: #FF6F3C !important; | |
background: #FF6F3C !important; | |
color: #fff !important; | |
} | |
.gradio-container .gr-button:not([disabled]):hover { | |
background: #fff !important; | |
color: #FF6F3C !important; | |
border: 2px solid #FF6F3C !important; | |
} | |
.gradio-container .gr-box, .gradio-container .gr-block { | |
background: #f7f8fa !important; | |
} | |
.gradio-container .gr-input, .gradio-container .gr-textbox, .gradio-container .gr-text-input, .gradio-container .gr-file, .gradio-container .gr-audio { | |
background: #f7f8fa !important; | |
color: #223a5e !important; | |
border: 1.2px solid #FF6F3C !important; | |
} | |
.gradio-container .gr-input::placeholder, .gradio-container .gr-textbox::placeholder, .gradio-container .gr-text-input::placeholder { | |
color: #888 !important; | |
} | |
.gradio-container .gr-markdown, .gradio-container .gr-markdown p { | |
color: #888 !important; | |
} | |
""" | |
with gr.Blocks(theme=theme, title="Voxtral Pro", css=custom_css) as demo:
    gr.Markdown("""
    <div style='text-align:center; margin-bottom:1.5em;'>
        <h1 style='margin-bottom:0.2em; color:#FF6F3C;'>Voxtral Pro</h1>
        <div style='font-size:1.1em; color:#555;'>The all-in-one AI assistant for audio, text & productivity.<br>
        <b style='color:#FF6F3C;'>Fast and powerful.</b></div>
        <div style='width:60px; height:4px; background:#FF6F3C; margin:18px auto 0 auto; border-radius:2px;'></div>
    </div>
    """)
    api_history_state = gr.State([])
    api_key_state = gr.State()
    # Dropdowns for model selection
    model_choices = ["voxtral-mini-2507", "voxtral-small-2507"]
    chat_model_state = gr.State("voxtral-mini-2507")
    transcription_model_state = gr.State("voxtral-mini-2507")
    with gr.Accordion("🔑 API Key Configuration", open=True):
        with gr.Row():
            api_key_input = gr.Textbox(label="Mistral API Key", placeholder="Enter your API key here...", type="password", scale=6)
            chat_model_dropdown = gr.Dropdown(choices=model_choices, value="voxtral-mini-2507", label="Chat Model", scale=2)
            transcription_model_dropdown = gr.Dropdown(choices=model_choices, value="voxtral-mini-2507", label="Transcription Model", scale=2)
            save_api_key_button = gr.Button("Save Key", scale=1)
        api_key_status = gr.Markdown(value="*Please save your API key to use the application.*")
        gr.Markdown(
            "<span style='font-size: 0.95em; color: #888;'>🔒 <b>Security:</b> Your API key is stored only in your browser session memory and is never sent to any server except Mistral's API. It is not saved or shared anywhere else.</span>",
            elem_id="api-key-security-info"
        )
    with gr.Tabs():
        with gr.TabItem("💬 Multimodal Chatbot"):
            gr.Markdown("### Chat with text and audio files at any time.")
            chatbot_display = gr.Chatbot(
                label="Conversation",
                height=500,
                type="messages"
            )
            with gr.Row():
                audio_input_files = gr.File(
                    label="Drag and drop your audio files here",
                    file_count="multiple",
                    file_types=["audio"],
                    elem_id="upload-box",
                    scale=2,
                    height=100
                )
                user_textbox = gr.Textbox(
                    label="Your message",
                    placeholder="Type your message here...",
                    lines=2,
                    scale=6,
                    elem_id="user-message-box",
                )
                mic_input = gr.Audio(
                    label="Voice recording",
                    sources=["microphone"],
                    type="filepath",
                    elem_classes="voice-recorder",
                    scale=2,
                )
            send_button = gr.Button("Send", variant="primary")
            clear_button = gr.Button("🗑️ Clear conversation", variant="secondary")
with gr.TabItem("🎙️ Audio Transcription"): | |
gr.Markdown("### Transcribe an audio file and export the result.") | |
with gr.Row(variant="panel"): | |
with gr.Column(scale=1): | |
gr.Markdown("#### 1. Audio Source") | |
source_type_transcription = gr.Radio(["Upload a file", "Use a URL"], label="Source type", value="Upload a file") | |
audio_file_input = gr.Audio(type="filepath", label="Audio file", visible=True) | |
audio_url_input = gr.Textbox(label="Audio file URL", placeholder="https://.../audio.mp3", visible=False) | |
gr.Markdown("#### 2. Options") | |
timestamp_checkbox = gr.Checkbox(label="Include timestamps (for .SRT)", value=True) | |
transcription_button = gr.Button("▶️ Start transcription", variant="primary") | |
with gr.Column(scale=2): | |
gr.Markdown("#### 3. Results") | |
transcription_status = gr.Markdown(visible=False) | |
transcription_output = gr.JSON(label="Raw transcription data", visible=False) | |
with gr.Group(visible=False) as download_zone: | |
download_srt_button = gr.Button("💾 Download .srt file", variant="secondary") | |
download_file_output = gr.File(label="Your file is ready:", interactive=False) | |
    def save_key(api_key):
        if not api_key:
            return None, "⚠️ Please enter an API key before saving."
        return api_key, "✅ API key saved."
    save_api_key_button.click(fn=save_key, inputs=[api_key_input], outputs=[api_key_state, api_key_status])
    # Dropdown logic: update State when a dropdown changes
    def update_chat_model(model):
        return model
    def update_transcription_model(model):
        return model
    chat_model_dropdown.change(fn=update_chat_model, inputs=[chat_model_dropdown], outputs=[chat_model_state])
    transcription_model_dropdown.change(fn=update_transcription_model, inputs=[transcription_model_dropdown], outputs=[transcription_model_state])
    async def on_submit(api_key, user_msg, api_history, uploaded_files, mic_file, chat_model):
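        """Main chat handler: upload audio, build the API message, call the model, update the UI.

        Yields intermediate states so the chatbot shows an upload indicator while
        files are pushed to the public host in parallel.
        """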
        # 1. Check for API key
        if not api_key:
            api_history.append({"role": "user", "content": user_msg or "..."})
            api_history.append({"role": "assistant", "content": "Error: Please configure your API key."})
            yield api_history, format_history_for_display(api_history), "", None, None
            return
        # 2. Collect all audio file paths
        all_filepaths = []
        if uploaded_files:
            # gr.File may hand back plain path strings or tempfile wrappers depending on the Gradio version.
            all_filepaths.extend(getattr(p, "name", p) for p in uploaded_files)
        if mic_file:
            all_filepaths.append(mic_file)
        # 3. Upload files in parallel and show a loading state
        audio_urls_to_send = []
        if all_filepaths:
            audio_count = len(all_filepaths)
            api_history.append({"role": "user", "content": user_msg or ""})  # Placeholder for display
            api_history.append({"role": "assistant", "content": f"⏳ *Uploading {audio_count} audio file{'s' if audio_count > 1 else ''}...*"})
            yield api_history, format_history_for_display(api_history), user_msg, None, None
            upload_tasks = [upload_file_to_public_service(path) for path in all_filepaths]
            uploaded_urls = await asyncio.gather(*upload_tasks)
            audio_urls_to_send = [url for url in uploaded_urls if url]
            api_history.pop()  # Remove the loading message
            if len(audio_urls_to_send) != audio_count:
                api_history.append({"role": "assistant", "content": f"Error: Failed to upload {audio_count - len(audio_urls_to_send)} file(s)."})
                yield api_history, format_history_for_display(api_history), user_msg, None, None
                return
        # 4. Construct the user message for the API
        current_user_content = []
        for url in audio_urls_to_send:
            current_user_content.append({"type": "input_audio", "input_audio": {"data": url}})
        if user_msg:
            current_user_content.append({"type": "text", "text": user_msg})
        if not current_user_content:
            yield api_history, format_history_for_display(api_history), "", None, None
            return
        # If we had a placeholder, replace it. Otherwise, append.
        if all_filepaths:
            api_history[-1] = {"role": "user", "content": current_user_content}
        else:
            api_history.append({"role": "user", "content": current_user_content})
        # 5. Call the API and handle tool calls
        try:
            response = await handle_api_call(api_key, api_history, chat_model)
            response.raise_for_status()
            assistant_message = response.json()['choices'][0]['message']
            api_history.append(assistant_message)
            if assistant_message.get("tool_calls"):
                tool_call = assistant_message["tool_calls"][0]
                function_name = tool_call['function']['name']
                if function_name in available_tools:
                    function_args = json.loads(tool_call['function']['arguments'])
                    function_to_call = available_tools[function_name]
                    # Tools may be sync or async; await only coroutine functions.
                    if inspect.iscoroutinefunction(function_to_call):
                        tool_output = await function_to_call(**function_args)
                    else:
                        tool_output = function_to_call(**function_args)
                    api_history.append({"role": "tool", "tool_call_id": tool_call['id'], "content": tool_output})
                    second_response = await handle_api_call(api_key, api_history, chat_model)
                    second_response.raise_for_status()
                    final_message = second_response.json()['choices'][0]['message']
                    api_history.append(final_message)
        except Exception as e:
            error_details = str(e)
            if getattr(e, 'response', None) is not None:
                error_details = e.response.text
            api_history.append({"role": "assistant", "content": f"API Error: {error_details}"})
        # 6. Final UI update
        yield api_history, format_history_for_display(api_history), "", None, None
    chat_inputs = [api_key_state, user_textbox, api_history_state, audio_input_files, mic_input, chat_model_state]
    chat_outputs = [api_history_state, chatbot_display, user_textbox, audio_input_files, mic_input]
    send_button.click(
        fn=on_submit,
        inputs=chat_inputs,
        outputs=chat_outputs
    )
    user_textbox.submit(
        fn=on_submit,
        inputs=chat_inputs,
        outputs=chat_outputs
    )
    def clear_chat():
        return [], [], "", None, None
    clear_button.click(fn=clear_chat, outputs=chat_outputs)
    all_transcription_outputs = [
        transcription_button,
        transcription_status,
        transcription_output,
        download_zone,
        download_file_output
    ]
    transcription_button.click(
        fn=run_transcription_and_update_ui,
        inputs=[api_key_state, source_type_transcription, audio_file_input, audio_url_input, timestamp_checkbox, transcription_model_state],
        outputs=all_transcription_outputs
    )
    download_srt_button.click(
        fn=generate_srt_file,
        inputs=[transcription_output],
        outputs=[download_file_output]
    )
    def toggle_transcription_inputs(source_type):
        return gr.update(visible=source_type == "Upload a file"), gr.update(visible=source_type == "Use a URL")
    source_type_transcription.change(fn=toggle_transcription_inputs, inputs=source_type_transcription, outputs=[audio_file_input, audio_url_input])
if __name__ == "__main__":
    demo.queue(default_concurrency_limit=20, max_size=40)
    demo.launch(debug=False, max_threads=20)