# Hugging Face Space (status when this page was captured: Runtime error)
# Install the required dependency first (notebook syntax shown):
# !pip install mistral-common
import os
import tempfile
from typing import List, Tuple

import gradio as gr
import torch
from transformers import VoxtralForConditionalGeneration, AutoProcessor
# Run on the GPU when one is visible to torch, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
repo_id = "mistralai/Voxtral-Mini-3B-2507"

# The processor builds model inputs from chat-style conversations and decodes
# generated token ids back to text (see `respond`).
processor = AutoProcessor.from_pretrained(repo_id)

# Weights are loaded in bfloat16 and placed directly on the selected device.
model = VoxtralForConditionalGeneration.from_pretrained(
    repo_id,
    torch_dtype=torch.bfloat16,
    device_map=device,
)
def respond(audio_files: List[str], question: str) -> Tuple[str, List[str]]:
    """Answer a text question about one or more audio files.

    Args:
        audio_files: Paths (or URLs) of the audio clips to attach to the
            prompt. A single path string is also accepted and is treated as
            a one-element list.
        question: The user's question about the audio.

    Returns:
        A ``(answer, audio_files)`` tuple. ``answer`` is the decoded model
        reply, or a short validation message when no audio or no question
        was supplied. ``audio_files`` is the list of audio paths that were
        passed in (empty when none were given).
    """
    if not audio_files:
        return "Please upload at least one audio file.", []
    if isinstance(audio_files, str):
        # A bare string would otherwise be iterated character by character.
        audio_files = [audio_files]

    # Do not run generation for a missing or whitespace-only question.
    if not (question or "").strip():
        return "Please enter a question about the audio.", audio_files

    # One user turn: every audio clip first, then the text question.
    audio_parts = [{"type": "audio", "path": path} for path in audio_files]
    conversation = [
        {
            "role": "user",
            "content": audio_parts + [{"type": "text", "text": question}],
        }
    ]

    inputs = processor.apply_chat_template(conversation)
    # Move to the model's device and cast to the dtype the model was loaded in.
    inputs = inputs.to(device, dtype=torch.bfloat16)

    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=500)

    # Skip the prompt positions so only the newly generated tokens are decoded.
    prompt_length = inputs.input_ids.shape[1]
    decoded = processor.batch_decode(
        outputs[:, prompt_length:],
        skip_special_tokens=True,
    )
    return decoded[0], audio_files
# Gradio UI: several audio files plus a question in, the model's answer plus
# the list of submitted files out.
demo = gr.Interface(
    fn=respond,
    inputs=[
        # gr.Audio takes a single clip and has no `file_count` argument
        # (passing it raises TypeError at startup), so multi-file upload is
        # done with gr.File restricted to audio file types.
        gr.File(
            label="Audio files",
            file_count="multiple",
            file_types=["audio"],
            type="filepath",
        ),
        gr.Textbox(lines=2, placeholder="Ask something about the audio(s)...", label="Question"),
    ],
    outputs=[
        gr.Textbox(label="Answer"),
        # gr.Gallery is an image/video component; gr.File can list the
        # audio paths returned by `respond`.
        gr.File(label="Uploaded audio files"),
    ],
    title="Voxtral-Mini-3B-2507 Audio Q&A",
    description="Upload one or more audio files and ask any question about them.",
    examples=[
        [
            [
                "https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/mary_had_lamb.mp3",
                "https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/winning_call.mp3",
            ],
            "What sport and what nursery rhyme are referenced?",
        ]
    ],
    # Examples are not pre-computed, so the model is not run at startup.
    cache_examples=False,
)
# Start the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()