randomblock1 committed
Commit 26082df · verified · 1 Parent(s): 0342e22

Update app.py

Files changed (1):
  1. app.py +41 -20
app.py CHANGED
@@ -15,15 +15,8 @@ model = AutoModelForSpeechSeq2Seq.from_pretrained(
     model_name, device_map=device, torch_dtype=torch.bfloat16
 )
 
-today_str = date.today().strftime("%B %d, %Y")
 
-system_prompt = (
-    "Knowledge Cutoff Date: April 2024.\n"
-    f"Today's Date: {today_str}.\n"
-    "You are Granite, developed by IBM. You are a helpful AI assistant."
-)
-
-def transcribe(audio_file):
+def transcribe(audio_file, user_prompt):
     # load wav file
     wav, sr = torchaudio.load(audio_file, normalize=True)
     if wav.shape[0] != 1 or sr != 16000:
@@ -32,20 +25,31 @@ def transcribe(audio_file):
         wav = torchaudio.functional.resample(wav, sr, 16000)
         sr = 16000
 
-    # user prompt
-    user_prompt = "<|audio|>can you transcribe the speech into a written format?"
+    today_str = date.today().strftime("%B %d, %Y")
+
+    system_prompt = (
+        "Knowledge Cutoff Date: April 2024.\n"
+        f"Today's Date: {today_str}.\n"
+        "You are Granite, developed by IBM. You are a helpful AI assistant."
+    )
+
     chat = [
         dict(role="system", content=system_prompt),
-        dict(role="user", content=user_prompt),
+        dict(role="user", content=f"<|audio|>{user_prompt}"),
     ]
-    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
+    prompt = tokenizer.apply_chat_template(
+        chat, tokenize=False, add_generation_prompt=True)
 
     # run model
-    model_inputs = processor(prompt, wav, device=device, return_tensors="pt").to(device)
+    model_inputs = processor(
+        prompt,
+        wav,
+        device=device,
+        return_tensors="pt").to(device)
     model_outputs = model.generate(
-        **model_inputs,
-        max_new_tokens=200,
-        do_sample=False,
+        **model_inputs,
+        max_new_tokens=512,
+        do_sample=False,
         num_beams=1
     )
 
@@ -58,15 +62,32 @@ def transcribe(audio_file):
 
     return output_text[0].strip()
 
+
 # Gradio UI
 with gr.Blocks() as demo:
-    gr.Markdown("## Granite 3.3 Speech-to-Text Demo")
+    gr.Markdown("## Granite 3.3 Speech-to-Text")
+    gr.Markdown(
+        "Upload an audio file and Granite Speech 3.3 8b will transcribe it into text. "
+        "You can also edit the prompt below to customize what Granite should do with the audio, like translation."
+    )
 
     with gr.Row():
-        audio_input = gr.Audio(type="filepath", label="Upload Audio (16kHz mono preferred)")
+        audio_input = gr.Audio(type="filepath",
+                               label="Upload Audio (16kHz mono preferred)")
         output_text = gr.Textbox(label="Transcription", lines=5)
 
+    user_prompt = gr.Textbox(
+        label="User Prompt",
+        value="Can you transcribe the speech into a written format?",
+        lines=2
+    )
+
     transcribe_btn = gr.Button("Transcribe")
-    transcribe_btn.click(fn=transcribe, inputs=audio_input, outputs=output_text)
+    transcribe_btn.click(
+        fn=transcribe,
+        inputs=[
+            audio_input,
+            user_prompt],
+        outputs=output_text)
 
-demo.launch()
+demo.launch()
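
For readers trying the change locally: the hunks reference model_name, device, processor, tokenizer, date, torchaudio, and gr without showing where they come from. Below is a minimal sketch of that setup, assuming the model id from the "Granite Speech 3.3 8b" label in the UI text and the usual AutoProcessor wiring; every name not visible in the diff is an assumption, not part of the commit.

# Sketch of the setup the diff assumes (not part of this commit).
# The model id below is an assumption based on the UI description.
import torch
import torchaudio
import gradio as gr
from datetime import date
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

model_name = "ibm-granite/granite-speech-3.3-8b"  # assumed model id
device = "cuda" if torch.cuda.is_available() else "cpu"

processor = AutoProcessor.from_pretrained(model_name)
tokenizer = processor.tokenizer  # assumed: the chat template lives on the processor's tokenizer
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_name, device_map=device, torch_dtype=torch.bfloat16
)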
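Two stretches of transcribe() fall between the hunks and are not shown: the body of the resampling guard and the decoding of model_outputs into output_text. A hypothetical fill-in of those elided lines, for orientation only; the mono downmix and the prompt-token slicing are assumptions, not code from this commit.

# Hypothetical fill-in for the context lines the diff elides.
# Inside the `if wav.shape[0] != 1 or sr != 16000:` guard, a mono
# downmix presumably precedes the resample (assumed):
wav = torch.mean(wav, dim=0, keepdim=True)

# After model.generate(), the prompt tokens are presumably sliced off
# before decoding, so only the newly generated text is returned (assumed):
num_input_tokens = model_inputs["input_ids"].shape[-1]
output_text = tokenizer.batch_decode(
    model_outputs[:, num_input_tokens:], skip_special_tokens=True
)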