import torch
import torchaudio
import gradio as gr
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
from datetime import date

# device setup
device = "cuda" if torch.cuda.is_available() else "cpu"

# load model + processor
model_name = "ibm-granite/granite-speech-3.3-8b"
processor = AutoProcessor.from_pretrained(model_name)
tokenizer = processor.tokenizer
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_name, device_map=device, torch_dtype=torch.bfloat16
)


def transcribe(audio_file, user_prompt):
    # load wav file
    wav, sr = torchaudio.load(audio_file, normalize=True)
    if wav.shape[0] != 1 or sr != 16000:
        # resample + convert to mono if needed
        wav = torch.mean(wav, dim=0, keepdim=True)  # mono
        wav = torchaudio.functional.resample(wav, sr, 16000)
        sr = 16000

    today_str = date.today().strftime("%B %d, %Y")
    system_prompt = (
        "Knowledge Cutoff Date: April 2024.\n"
        f"Today's Date: {today_str}.\n"
        "You are Granite, developed by IBM. You are a helpful AI assistant."
    )
    chat = [
        dict(role="system", content=system_prompt),
        dict(role="user", content=f"<|audio|>{user_prompt}"),
    ]
    prompt = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )

    # run model
    model_inputs = processor(
        prompt, wav, device=device, return_tensors="pt"
    ).to(device)
    model_outputs = model.generate(
        **model_inputs, max_new_tokens=512, do_sample=False, num_beams=1
    )

    # strip prompt tokens
    num_input_tokens = model_inputs["input_ids"].shape[-1]
    new_tokens = torch.unsqueeze(model_outputs[0, num_input_tokens:], dim=0)
    output_text = tokenizer.batch_decode(
        new_tokens, add_special_tokens=False, skip_special_tokens=True
    )
    return output_text[0].strip()


# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## Granite 3.3 Speech-to-Text")
    gr.Markdown(
        "Upload an audio file and Granite Speech 3.3 8b will transcribe it into text. "
        "You can also edit the prompt below to customize what Granite should do "
        "with the audio, like translation."
    )
    with gr.Row():
        audio_input = gr.Audio(type="filepath", label="Upload Audio (16kHz mono preferred)")
        output_text = gr.Textbox(label="Transcription", lines=5)
    user_prompt = gr.Textbox(
        label="User Prompt",
        value="Can you transcribe the speech into a written format?",
        lines=2,
    )
    transcribe_btn = gr.Button("Transcribe")
    transcribe_btn.click(
        fn=transcribe,
        inputs=[audio_input, user_prompt],
        outputs=output_text,
    )

demo.launch()