import gradio as gr
import torch
from transformers import AutoProcessor, VoxtralForConditionalGeneration
import spaces

#### Functions
@spaces.GPU  # run on ZeroGPU hardware (the `spaces` import is only useful with this decorator)
def process_transcript(language: str, audio_path: str) -> str:
    """Process the audio file to return its transcription.

    Args:
        language: The language of the audio.
        audio_path: The path to the audio file.

    Returns:
        The transcribed text of the audio.
    """
    if audio_path is None:
        return "Please provide some input audio: either upload an audio file or use the microphone."
    else:
        id_language = dict_languages[language]
        # Note: the method name below is spelled this way in the transformers Voxtral processor API.
        inputs = processor.apply_transcrition_request(language=id_language, audio=audio_path, model_id=model_name)
        inputs = inputs.to(device, dtype=torch.bfloat16)
        outputs = model.generate(**inputs, max_new_tokens=MAX_TOKENS)
        decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)
        return decoded_outputs[0]
###
@spaces.GPU
def process_translate(language: str, audio_path: str) -> str:
    """Translate the audio file into the selected language."""
    if audio_path is None:
        return "Please provide some input audio: either upload an audio file or use the microphone."
    conversation = [
        {
            "role": "user",
            "content": [
                {
                    "type": "audio",
                    "path": audio_path,
                },
                {"type": "text", "text": "Translate this into " + language},
            ],
        }
    ]
    inputs = processor.apply_chat_template(conversation)
    inputs = inputs.to(device, dtype=torch.bfloat16)
    outputs = model.generate(**inputs, max_new_tokens=MAX_TOKENS)
    decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)
    return decoded_outputs[0]
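###
# The "Ask audio file" button defined in the interface below has no handler in the
# original listing. The function that follows is a hedged sketch of the missing piece:
# it assumes a free-text question box (question_chat, added in the chat column below)
# and reuses Voxtral's chat template exactly as process_translate does.
@spaces.GPU
def process_chat(question: str, audio_path: str) -> str:
    """Answer a free-text question about the audio file (assumed behaviour)."""
    if audio_path is None:
        return "Please provide some input audio: either upload an audio file or use the microphone."
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "audio", "path": audio_path},
                {"type": "text", "text": question},
            ],
        }
    ]
    inputs = processor.apply_chat_template(conversation)
    inputs = inputs.to(device, dtype=torch.bfloat16)
    outputs = model.generate(**inputs, max_new_tokens=MAX_TOKENS)
    decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)
    return decoded_outputs[0]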
def disable_buttons():
    # One update per button listed in the click handlers' outputs below.
    return gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False)

def enable_buttons():
    return gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True)
###

### Initializations
MAX_TOKENS = 32000

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"*** Device: {device}")

model_name = 'mistralai/Voxtral-Mini-3B-2507'
processor = AutoProcessor.from_pretrained(model_name)
model = VoxtralForConditionalGeneration.from_pretrained(model_name,
                                                        torch_dtype=torch.bfloat16,
                                                        device_map=device)
# Supported languages
dict_languages = {"English": "en",
                  "French": "fr",
                  "German": "de",
                  "Spanish": "es",
                  "Italian": "it",
                  "Portuguese": "pt",
                  "Dutch": "nl",
                  "Hindi": "hi"}
#### Gradio interface
with gr.Blocks(title="Transcription") as audio:
    gr.Markdown("# Voxtral Mini Evaluation")
    gr.Markdown("#### Choose the language of the audio and set an audio file to process it.")
    gr.Markdown("##### *(Voxtral handles audios up to 30 minutes for transcription)*")
    with gr.Row():
        with gr.Column():
            sel_language = gr.Dropdown(
                choices=list(dict_languages.keys()),
                value="English",
                label="Select the language of the audio file:"
            )
        with gr.Column():
            sel_audio = gr.Audio(sources=["upload", "microphone"], type="filepath",
                                 label="Upload an audio file, record via microphone, or select a demo file:")
            example = [["mapo_tofu.mp3"]]
            gr.Examples(
                examples=example,
                inputs=sel_audio,
                outputs=None,
                fn=None,
                cache_examples=False,
                run_on_click=False
            )
    with gr.Row():
        with gr.Column():
            submit_transcript = gr.Button("Extract transcription", variant="primary")
            text_transcript = gr.Textbox(label="Generated transcription", lines=10)
        with gr.Column():
            sel_translate_language = gr.Dropdown(
                choices=list(dict_languages.keys()),
                value="English",
                label="Select the language for translation:"
            )
            submit_translate = gr.Button("Translate audio file", variant="primary")
            text_translate = gr.Textbox(label="Generated translation", lines=10)
        with gr.Column():
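            # Assumed input for the chat feature: the original listing has an "Ask audio file"
            # button and a "Model answer" box but no question field, so this widget is a sketch.
            question_chat = gr.Textbox(label="Question about the audio file:",
                                       value="Describe this audio.", lines=2)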
            submit_chat = gr.Button("Ask audio file", variant="primary")
            text_chat = gr.Textbox(label="Model answer", lines=10)
    ### Processing
    # Transcription
    submit_transcript.click(
        disable_buttons,
        outputs=[submit_transcript, submit_translate, submit_chat],
        trigger_mode="once",
    ).then(
        fn=process_transcript,
        inputs=[sel_language, sel_audio],
        outputs=text_transcript
    ).then(
        enable_buttons,
        outputs=[submit_transcript, submit_translate, submit_chat],
    )
    # Translation
    submit_translate.click(
        disable_buttons,
        outputs=[submit_transcript, submit_translate, submit_chat],
        trigger_mode="once",
    ).then(
        fn=process_translate,
        inputs=[sel_translate_language, sel_audio],
        outputs=text_translate
    ).then(
        enable_buttons,
        outputs=[submit_transcript, submit_translate, submit_chat],
    )
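    # Chat (assumed wiring for the otherwise unused "Ask audio file" button,
    # using the hypothetical question_chat box and the process_chat sketch above)
    submit_chat.click(
        disable_buttons,
        outputs=[submit_transcript, submit_translate, submit_chat],
        trigger_mode="once",
    ).then(
        fn=process_chat,
        inputs=[question_chat, sel_audio],
        outputs=text_chat
    ).then(
        enable_buttons,
        outputs=[submit_transcript, submit_translate, submit_chat],
    )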
### Launch the app
if __name__ == "__main__":
    audio.launch()