Commit 487ed33 by JacobLinCool · 1 Parent(s): ff0fd39

feat: add phi

Files changed (4)
  1. README.md +1 -1
  2. app.py +129 -96
  3. model.py +188 -18
  4. requirements.txt +3 -2
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🐠
 colorFrom: red
 colorTo: pink
 sdk: gradio
-sdk_version: 5.4.0
+sdk_version: 5.20.1
 app_file: app.py
 pinned: false
 license: mit
app.py CHANGED
@@ -1,109 +1,142 @@
-import gradio as gr
-from huggingface_hub.utils import get_token
-import requests
-import base64
-from model import model_id, transcribe_audio_local
-
-token = get_token()
-
-
-def read_file_as_base64(file_path: str) -> str:
-    with open(file_path, "rb") as f:
-        return base64.b64encode(f.read()).decode()
-
-
-def transcribe_audio(audio: str) -> str:
-    print(f"{audio=}")
-
-    if audio is None:
-        raise gr.Error(
-            "Please wait a moment for the audio to be uploaded, then click the button again."
-        )
-
-    # resample to 16k mono to reduce file size
-    import subprocess
-    import os
-
-    audio_resampled = audio.replace(".mp3", "_resampled.mp3")
-    subprocess.run(
-        [
-            "ffmpeg",
-            "-i",
-            audio,
-            "-ac",
-            "1",
-            "-ar",
-            "16000",
-            audio_resampled,
-            "-y",
-        ],
-        check=True,
-    )
-
-    b64 = read_file_as_base64(audio_resampled)
-    url = f"https://api-inference.huggingface.co/models/{model_id}"
-    headers = {
-        "Authorization": f"Bearer {token}",
-        "Content-Type": "application/json",
-        "x-wait-for-model": "true",
-    }
-    data = {
-        "inputs": b64,
-        "parameters": {
-            "generate_kwargs": {
-                "return_timestamps": True,
-            }
-        },
-    }
-    response = requests.post(url, headers=headers, json=data)
-    print(f"{response.text=}")
-    out = response.json()
-    print(f"{out=}")
-
-    return out["text"]
-
-
-with gr.Blocks() as demo:
-    gr.Markdown("# TWASR: Chinese (Taiwan) Automatic Speech Recognition.")
-    gr.Markdown("Upload an audio file or record your voice to transcribe it to text.")
-    gr.Markdown(
-        "First load may take a while to initialize the model, following requests will be faster."
-    )
-
-    with gr.Row():
-        audio_input = gr.Audio(
-            label="Audio", type="filepath", show_download_button=True
-        )
-        text_output = gr.Textbox(label="Transcription")
-
-    transcribe_local_button = gr.Button(
-        "Transcribe with Transformers", variant="primary"
-    )
-    transcribe_button = gr.Button("Transcribe with Inference API", variant="secondary")
-
-    transcribe_local_button.click(
-        fn=transcribe_audio_local, inputs=[audio_input], outputs=[text_output]
-    )
-    transcribe_button.click(
-        fn=transcribe_audio, inputs=[audio_input], outputs=[text_output]
-    )
-
-    gr.Examples(
-        [
-            ["./examples/audio1.mp3"],
-            ["./examples/audio2.mp3"],
-        ],
-        inputs=[audio_input],
-        outputs=[text_output],
-        fn=transcribe_audio_local,
-        cache_examples=True,
-        cache_mode="lazy",
-        run_on_click=True,
-    )
-
-    gr.Markdown(
-        f"Current model: {model_id}. For more information, visit the [model hub](https://huggingface.co/{model_id})."
-    )
-
-if __name__ == "__main__":
-    demo.launch()
+import spaces
+import gradio as gr
+import logging
+from pathlib import Path
+import base64
+
+from model import (
+    MODEL_ID as WHISPER_MODEL_ID,
+    PHI_MODEL_ID,
+    transcribe_audio_local,
+    transcribe_audio_phi,
+    preload_models,
+)
+
+# Set up logging
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+# Constants
+EXAMPLES_DIR = Path("./examples")
+MODEL_CHOICES = {WHISPER_MODEL_ID: "Whisper Model", PHI_MODEL_ID: "Phi-4 Model"}
+EXAMPLE_FILES = [
+    (str(EXAMPLES_DIR / "audio1.mp3"), WHISPER_MODEL_ID),
+    (str(EXAMPLES_DIR / "audio2.mp3"), WHISPER_MODEL_ID),
+]
+
+
+def read_file_as_base64(file_path: str) -> str:
+    """
+    Read a file and encode it as base64.
+
+    Args:
+        file_path: Path to the file to read
+
+    Returns:
+        Base64 encoded string of file contents
+    """
+    try:
+        with open(file_path, "rb") as f:
+            return base64.b64encode(f.read()).decode()
+    except Exception as e:
+        logger.error(f"Failed to read file {file_path}: {str(e)}")
+        raise
+
+
+def combined_transcription(audio: str, model_choice: str) -> str:
+    """
+    Transcribe audio using the selected model.
+
+    Args:
+        audio: Path to audio file
+        model_choice: Full model ID to use for transcription
+
+    Returns:
+        Transcription text
+    """
+    if not audio:
+        return "Please provide an audio file to transcribe."
+
+    try:
+        if model_choice == PHI_MODEL_ID:
+            return transcribe_audio_phi(audio)
+        elif model_choice == WHISPER_MODEL_ID:
+            return transcribe_audio_local(audio)
+        else:
+            logger.error(f"Unknown model choice: {model_choice}")
+            return f"Error: Unknown model {model_choice}"
+    except Exception as e:
+        logger.error(f"Transcription failed: {str(e)}")
+        return f"Error during transcription: {str(e)}"
+
+
+def create_demo() -> gr.Blocks:
+    """Create and configure the Gradio demo interface"""
+    with gr.Blocks(theme=gr.themes.Soft()) as demo:
+        gr.Markdown("# TWASR: Chinese (Taiwan) Automatic Speech Recognition")
+        gr.Markdown(
+            "Upload an audio file or record your voice to transcribe it to text."
+        )
+        gr.Markdown(
+            "⚠️ First load may take a while to initialize the model, following requests will be faster."
+        )
+
+        with gr.Row():
+            audio_input = gr.Audio(
+                label="Audio Input", type="filepath", show_download_button=True
+            )
+            with gr.Column():
+                model_choice = gr.Dropdown(
+                    label="Select Model",
+                    choices=list(MODEL_CHOICES.keys()),
+                    value=WHISPER_MODEL_ID,
+                    info="Select the model for transcription",
+                )
+                text_output = gr.Textbox(label="Transcription Output", lines=5)
+
+        with gr.Row():
+            transcribe_button = gr.Button("🎯 Transcribe", variant="primary")
+            clear_button = gr.Button("🧹 Clear")
+
+        transcribe_button.click(
+            fn=combined_transcription,
+            inputs=[audio_input, model_choice],
+            outputs=[text_output],
+            show_progress=True,
+        )
+
+        clear_button.click(
+            fn=lambda: (None, ""),
+            inputs=[],
+            outputs=[audio_input, text_output],
+        )
+
+        gr.Examples(
+            EXAMPLE_FILES,
+            inputs=[audio_input, model_choice],
+            outputs=[text_output],
+            fn=combined_transcription,
+            cache_examples=True,
+            cache_mode="lazy",
+            run_on_click=True,
+        )
+
+        gr.Markdown("### Model Information")
+        with gr.Accordion("Model Details", open=False):
+            for model_id, model_name in MODEL_CHOICES.items():
+                gr.Markdown(
+                    f"**{model_name}:** [{model_id}](https://huggingface.co/{model_id})"
+                )
+
+    return demo
+
+
+if __name__ == "__main__":
+    # Preload models before starting the app to reduce cold start time
+    logger.info("Preloading models to reduce cold start time")
+    preload_models()
+
+    demo = create_demo()
+    demo.launch(share=False)
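
The new combined_transcription function is the single entry point the UI calls; a minimal way to exercise it outside Gradio is the sketch below. It is not part of this commit: it assumes the dependencies from requirements.txt are installed, that examples/audio1.mp3 exists locally, and that the first call will be slow while the checkpoints download.

    # smoke_test.py (hypothetical helper, not included in this commit)
    from app import combined_transcription, WHISPER_MODEL_ID, PHI_MODEL_ID

    clip = "examples/audio1.mp3"
    print(combined_transcription(clip, WHISPER_MODEL_ID))  # Whisper path via transcribe_audio_local
    print(combined_transcription(clip, PHI_MODEL_ID))      # Phi-4 path via transcribe_audio_phi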
model.py CHANGED
@@ -1,35 +1,205 @@
-from transformers import pipeline
-from accelerate import Accelerator
-import spaces
-import librosa
-
-model_id = "JacobLinCool/whisper-large-v3-turbo-common_voice_19_0-zh-TW"
-
-pipe = None
-
-
-def load_model():
-    global pipe
-    device = Accelerator().device
-    pipe = pipeline("automatic-speech-recognition", model=model_id, device=device)
-
-
-def get_gpu_duration(audio: str) -> int:
-    y, sr = librosa.load(audio)
-    duration = librosa.get_duration(y=y, sr=sr) / 60.0
-    gpu_duration = max(1.0, (duration + 59.0) // 60.0) * 60.0
-    print(f"{duration=}, {gpu_duration=}")
-    return int(gpu_duration)
-
-
-@spaces.GPU(duration=get_gpu_duration)
-def transcribe_audio_local(audio: str) -> str:
-    print(f"{audio=}")
-
-    if pipe is None:
-        load_model()
-
-    out = pipe(audio, return_timestamps=True)
-    print(f"{out=}")
-
-    return out["text"]
+import spaces
+from typing import Optional
+import logging
+import time
+import threading
+
+import torch
+import librosa
+from transformers import pipeline, AutoProcessor, AutoModelForCausalLM, Pipeline
+from accelerate import Accelerator
+
+# Set up logging
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+# Model constants
+MODEL_ID = "JacobLinCool/whisper-large-v3-turbo-common_voice_19_0-zh-TW"
+PHI_MODEL_ID = "JacobLinCool/Phi-4-multimodal-instruct-commonvoice-zh-tw"
+USE_FA = torch.cuda.is_available()  # Use Flash Attention if CUDA is available
+
+# Model instances (initialized lazily)
+pipe: Optional[Pipeline] = None
+phi_model = None
+phi_processor = None
+
+# Lock for thread-safe model loading
+model_loading_lock = threading.Lock()
+
+
+def load_model() -> None:
+    """
+    Load the Whisper model for transcription.
+    Uses GPU if available.
+    """
+    global pipe
+    if pipe is not None:
+        return  # Model already loaded
+
+    try:
+        start_time = time.time()
+        logger.info(f"Loading Whisper model {MODEL_ID}...")
+        device = Accelerator().device
+        pipe = pipeline("automatic-speech-recognition", model=MODEL_ID, device=device)
+        logger.info(
+            f"Model loaded successfully in {time.time() - start_time:.2f} seconds"
+        )
+    except Exception as e:
+        logger.error(f"Failed to load Whisper model: {str(e)}")
+        raise
+
+
+def get_gpu_duration(audio: str) -> int:
+    """
+    Calculate required GPU allocation time based on audio duration.
+
+    Args:
+        audio: Path to audio file
+
+    Returns:
+        GPU allocation time in seconds
+    """
+    try:
+        y, sr = librosa.load(audio)
+        duration = librosa.get_duration(y=y, sr=sr) / 60.0
+        gpu_duration = max(1.0, (duration + 59.0) // 60.0) * 60.0
+        logger.info(
+            f"Audio duration: {duration:.2f} min, Allocated GPU time: {gpu_duration:.2f} min"
+        )
+        return int(gpu_duration)
+    except Exception as e:
+        logger.error(f"Failed to calculate GPU duration: {str(e)}")
+        return 60  # Default to 1 minute if calculation fails
+
+
+@spaces.GPU(duration=get_gpu_duration)
+def transcribe_audio_local(audio: str) -> str:
+    """
+    Transcribe audio using the Whisper model.
+
+    Args:
+        audio: Path to audio file
+
+    Returns:
+        Transcribed text
+    """
+    try:
+        logger.info(f"Transcribing audio with Whisper: {audio}")
+        if pipe is None:
+            load_model()
+
+        out = pipe(audio, return_timestamps=True)
+        return out.get("text", "No transcription generated")
+    except Exception as e:
+        logger.error(f"Whisper transcription error: {str(e)}")
+        raise
+
+
+def load_phi_model() -> None:
+    """
+    Load the Phi-4 model and processor.
+    Uses GPU with Flash Attention if available.
+    """
+    global phi_model, phi_processor
+    if phi_model is not None and phi_processor is not None:
+        return  # Model already loaded
+
+    try:
+        start_time = time.time()
+        logger.info(f"Loading Phi-4 model {PHI_MODEL_ID}...")
+
+        phi_processor = AutoProcessor.from_pretrained(
+            PHI_MODEL_ID, trust_remote_code=True
+        )
+
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        dtype = torch.bfloat16 if USE_FA else torch.float32
+        attn_implementation = "flash_attention_2" if USE_FA else "sdpa"
+
+        phi_model = AutoModelForCausalLM.from_pretrained(
+            PHI_MODEL_ID,
+            torch_dtype=dtype,
+            _attn_implementation=attn_implementation,
+            trust_remote_code=True,
+        ).to(device)
+
+        logger.info(
+            f"Phi-4 model loaded successfully in {time.time() - start_time:.2f} seconds"
+        )
+    except Exception as e:
+        logger.error(f"Failed to load Phi-4 model: {str(e)}")
+        raise
+
+
+def transcribe_audio_phi(audio: str) -> str:
+    """
+    Transcribe audio using the Phi-4 model.
+
+    Args:
+        audio: Path to audio file
+
+    Returns:
+        Transcribed text
+    """
+    try:
+        logger.info(f"Transcribing audio with Phi-4: {audio}")
+        load_phi_model()
+
+        # Load and resample audio to 16kHz
+        y, sr = librosa.load(audio, sr=16000)
+
+        # Prepare the user message and generate the prompt
+        user_message = {
+            "role": "user",
+            "content": "<|audio_1|> Transcribe the audio clip into text.",
+        }
+        prompt = phi_processor.tokenizer.apply_chat_template(
+            [user_message], tokenize=False, add_generation_prompt=True
+        )
+
+        # Build inputs for the model
+        inputs = phi_processor(text=prompt, audios=[(y, sr)], return_tensors="pt")
+        inputs = {
+            k: v.to(phi_model.device) if hasattr(v, "to") else v
+            for k, v in inputs.items()
+        }
+
+        # Generate transcription without gradients
+        with torch.no_grad():
+            generated_ids = phi_model.generate(
+                **inputs,
+                eos_token_id=phi_processor.tokenizer.eos_token_id,
+                max_new_tokens=256,  # Increased for longer transcriptions
+                do_sample=False,
+            )
+
+        # Decode the generated token IDs into text
+        transcription = phi_processor.decode(
+            generated_ids[0, inputs["input_ids"].shape[1] :],
+            skip_special_tokens=True,
+            clean_up_tokenization_spaces=False,
+        )
+
+        logger.info(f"Phi-4 transcription completed successfully")
+        return transcription
+    except Exception as e:
+        logger.error(f"Phi-4 transcription error: {str(e)}")
+        raise
+
+
+def preload_models() -> None:
+    """
+    Preload models into memory to reduce cold start time.
+    This function can be called at application startup.
+    """
+    try:
+        logger.info("Preloading models to reduce cold start time")
+        # Load Whisper model first as it's the default
+        load_model()
+        # Then load Phi model
+        load_phi_model()
+        logger.info("All models preloaded successfully")
+    except Exception as e:
+        logger.error(f"Error during model preloading: {str(e)}")
requirements.txt CHANGED
@@ -1,6 +1,7 @@
-gradio==5.4.0
-huggingface_hub==0.26.2
+gradio==5.20.1
+huggingface_hub
 transformers
 accelerate
 spaces
 librosa
+flash-attn
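
flash-attn is added because load_phi_model requests _attn_implementation="flash_attention_2" whenever CUDA is available. If the wheel fails to build in a given environment, one possible guard (an assumption, not something this commit does) is to also check that the package is importable before enabling it:

    # Hypothetical fallback: only request flash_attention_2 when flash_attn is importable.
    import importlib.util
    import torch

    has_flash_attn = importlib.util.find_spec("flash_attn") is not None
    use_fa = torch.cuda.is_available() and has_flash_attn
    attn_implementation = "flash_attention_2" if use_fa else "sdpa"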