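# Gradio demo: evaluate Mistral's Voxtral-Mini-3B on audio transcription,
# translation, and question answering. Uses the `spaces` ZeroGPU decorator,
# so it is meant to run as a Hugging Face Space.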
import gradio as gr
import torch
from transformers import AutoProcessor, VoxtralForConditionalGeneration
import spaces

#### Functions

@spaces.GPU
def process_transcript(language: str, audio_path: str) -> str:
    """Process the audio file to return its transcription.

    Args:
        language: The language of the audio.
        audio_path: The path to the audio file.

    Returns:
        The transcribed text of the audio.
    """

    if audio_path is None:
        return "Please provide some input audio: either upload an audio file or use the microphone."

    id_language = dict_languages[language]
    # Note: "apply_transcrition_request" (sic) is the method name as spelled
    # in the transformers Voxtral processor API.
    inputs = processor.apply_transcrition_request(language=id_language, audio=audio_path, model_id=model_name)
    inputs = inputs.to(device, dtype=torch.bfloat16)
    outputs = model.generate(**inputs, max_new_tokens=MAX_TOKENS)
    decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)

    return decoded_outputs[0]
###

@spaces.GPU
def process_translate(language: str, audio_path: str) -> str:
    """Translate the audio file content into the selected language.

    Args:
        language: The target language for the translation.
        audio_path: The path to the audio file.

    Returns:
        The translated text of the audio.
    """

    if audio_path is None:
        return "Please provide some input audio: either upload an audio file or use the microphone."

    conversation = [
        {
            "role": "user",
            "content": [
                {
                    "type": "audio",
                    "path": audio_path,
                },
                {"type": "text", "text": "Translate this into " + language},
            ],
        }
    ]

    inputs = processor.apply_chat_template(conversation)
    inputs = inputs.to(device, dtype=torch.bfloat16)

    outputs = model.generate(**inputs, max_new_tokens=MAX_TOKENS)
    decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)

    # Return the decoded string (not the one-element list) so it displays cleanly.
    return decoded_outputs[0]


def disable_buttons():
    return gr.update(interactive=False), gr.update(interactive=False)

def enable_buttons():
    return gr.update(interactive=True), gr.update(interactive=True)
###
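
# The "Ask audio file" button below is not wired to any function in the
# original file. This is a minimal, assumed handler that reuses the
# chat-template pattern from process_translate; the name process_chat and the
# free-form question input are assumptions, not part of the original app.
@spaces.GPU
def process_chat(question: str, audio_path: str) -> str:
    """Answer a free-form question about the audio file."""

    if audio_path is None:
        return "Please provide some input audio: either upload an audio file or use the microphone."

    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "audio", "path": audio_path},
                {"type": "text", "text": question},
            ],
        }
    ]

    inputs = processor.apply_chat_template(conversation)
    inputs = inputs.to(device, dtype=torch.bfloat16)

    outputs = model.generate(**inputs, max_new_tokens=MAX_TOKENS)
    decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)

    return decoded_outputs[0]
###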

### Initializations

MAX_TOKENS = 32000

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"*** Device: {device}")
model_name = 'mistralai/Voxtral-Mini-3B-2507'

processor = AutoProcessor.from_pretrained(model_name)
model = VoxtralForConditionalGeneration.from_pretrained(model_name,
                                                        torch_dtype=torch.bfloat16,
                                                        device_map=device)
# Supported languages
dict_languages = {"English": "en",
                  "French": "fr",
                  "German": "de",
                  "Spanish": "es",
                  "Italian": "it",
                  "Portuguese": "pt",
                  "Dutch": "nl",
                  "Hindi": "hi"}


#### Gradio interface
with gr.Blocks(title="Transcription") as audio:
    gr.Markdown("# Voxtral Mini Evaluation")
    gr.Markdown("#### Choose the language of the audio and set an audio file to process it.")
    gr.Markdown("##### *(Voxtral handles audios up to 30 minutes for transcription)*")

    with gr.Row():
        with gr.Column():
            sel_language = gr.Dropdown(
                choices=list(dict_languages.keys()),
                value="English",
                label="Select the language of the audio file:"
            )

        with gr.Column():
            sel_audio = gr.Audio(sources=["upload", "microphone"], type="filepath", 
                                 label="Upload an audio file, record via microphone, or select a demo file:")

            example = [["mapo_tofu.mp3"]]
            gr.Examples(
                examples=example,
                inputs=sel_audio,
                outputs=None,
                fn=None,
                cache_examples=False,
                run_on_click=False
            )
    
    with gr.Row():
        with gr.Column():
            submit_transcript = gr.Button("Extract transcription", variant="primary")
            text_transcript = gr.Textbox(label="Generated transcription", lines=10)

        with gr.Column():
            sel_translate_language = gr.Dropdown(
                choices=list(dict_languages.keys()),
                value="English",
                label="Select the language for translation:"
            )

            submit_translate = gr.Button("Translate audio file", variant="primary")
            text_translate = gr.Textbox(label="Generated translation", lines=10)

        with gr.Column():
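            # A question input is assumed here (the original layout has only a
            # button and an answer box); it feeds the process_chat sketch above.
            text_question = gr.Textbox(label="Question about the audio file",
                                       value="What is this audio about?", lines=2)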
            submit_chat = gr.Button("Ask audio file", variant="primary")
            text_chat = gr.Textbox(label="Model answer", lines=10)
            
### Processing
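    # Each button chains three events via .then(): disable all buttons, run the
    # model, then re-enable the buttons, so clicks during generation are ignored.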
    
    # Transcription
    submit_transcript.click(
        disable_buttons,
        outputs=[submit_transcript, submit_translate, submit_chat],
        trigger_mode="once",
    ).then(
        fn=process_transcript,
        inputs=[sel_language, sel_audio],
        outputs=text_transcript
    ).then(
        enable_buttons,
        outputs=[submit_transcript, submit_translate, submit_chat],
    )

    # Translation
    submit_translate.click(
        disable_buttons,
        outputs=[submit_transcript, submit_translate, submit_chat],
        trigger_mode="once",
    ).then(
        fn=process_translate,
        inputs=[sel_translate_language, sel_audio],
        outputs=text_translate
    ).then(
        enable_buttons,
        outputs=[submit_transcript, submit_translate, submit_chat],
    )
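
    # Chat (sketch): wiring for the otherwise-unused "Ask audio file" button,
    # using the assumed process_chat and text_question defined above.
    submit_chat.click(
        disable_buttons,
        outputs=[submit_transcript, submit_translate, submit_chat],
        trigger_mode="once",
    ).then(
        fn=process_chat,
        inputs=[text_question, sel_audio],
        outputs=text_chat
    ).then(
        enable_buttons,
        outputs=[submit_transcript, submit_translate, submit_chat],
    )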


### Launch the app

if __name__ == "__main__":
    audio.launch()