Morgan Funtowicz committed
Commit 8550385 · 1 Parent(s): 69fb91d

misc(sdk): use endpoint config parser

handler.py: +43 -31
handler.py CHANGED

@@ -1,11 +1,12 @@
 import asyncio
-import os
 import zlib
 from functools import lru_cache
 from io import BytesIO
+from pathlib import Path
 from typing import Sequence, List, Tuple, Generator, Iterable, TYPE_CHECKING

 import numpy as np
+from hfendpoints.errors.config import UnsupportedModelArchitecture
 from hfendpoints.openai import Context, run
 from hfendpoints.openai.audio import (
     AutomaticSpeechRecognitionEndpoint,
@@ -19,22 +20,25 @@ from hfendpoints.openai.audio import (
 )
 from librosa import load as load_audio, get_duration
 from loguru import logger
+from transformers import AutoConfig
 from vllm import (
     AsyncEngineArgs,
     AsyncLLMEngine,
     SamplingParams,
 )

-from hfendpoints import Handler
+from hfendpoints import EndpointConfig, Handler, ensure_supported_architectures

 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer
     from vllm import CompletionOutput, RequestOutput
     from vllm.sequence import SampleLogprobs

+SUPPORTED_MODEL_ARCHITECTURES = ["WhisperForConditionalGeneration"]
+

 def chunk_audio_with_duration(
-
+    audio: np.ndarray, maximum_duration_sec: int, sampling_rate: int
 ) -> Sequence[np.ndarray]:
     """
     Chunk a mono audio timeseries so that each chunk is as long as `maximum_duration_sec`.
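For context on the `chunk_audio_with_duration` signature above, fixed-duration chunking of a mono waveform usually amounts to slicing the sample array every `maximum_duration_sec * sampling_rate` samples. A minimal sketch of that idea, assuming plain numpy (illustrative names only, not the handler's actual body):

import numpy as np


def chunk_audio_with_duration_sketch(
    audio: np.ndarray, maximum_duration_sec: int, sampling_rate: int
) -> list[np.ndarray]:
    # One chunk holds at most maximum_duration_sec seconds of samples.
    samples_per_chunk = maximum_duration_sec * sampling_rate
    # Slice the 1-D waveform; the final chunk keeps the shorter remainder.
    return [
        audio[start : start + samples_per_chunk]
        for start in range(0, len(audio), samples_per_chunk)
    ]

With, for example, 30-second chunks at 16 kHz, each chunk covers at most 480,000 samples, which matches Whisper's usual decoding window.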
@@ -63,10 +67,10 @@ def compression_ratio(text: str) -> float:


 def create_prompt(
-
-
-
-
+    audio: np.ndarray,
+    sampling_rate: int,
+    language: int,
+    timestamp_marker: int,
 ):
     """
     Generate the right prompt with the specific parameters to submit for inference over Whisper
@@ -93,7 +97,7 @@ def create_prompt(


 def create_params(
-
+    max_tokens: int, temperature: float, is_verbose: bool
 ) -> "SamplingParams":
     """
     Create sampling parameters to submit for inference through vLLM `generate`
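The hunk context above references `compression_ratio(text: str) -> float`, which together with the `zlib` import points at the usual Whisper-style repetition heuristic: text that compresses too well is likely degenerate output. A sketch of that heuristic as popularized by OpenAI's Whisper reference code (not necessarily this handler's exact body):

import zlib


def compression_ratio_sketch(text: str) -> float:
    # Highly repetitive (often hallucinated) transcripts compress very well,
    # so a large ratio can be used to reject or re-decode a segment.
    text_bytes = text.encode("utf-8")
    return len(text_bytes) / len(zlib.compress(text_bytes))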
@@ -123,12 +127,12 @@ def get_avg_logprob(logprobs: "SampleLogprobs") -> float:


 def process_chunk(
-
-
-
-
-
-
+    tokenizer: "PreTrainedTokenizer",
+    ids: np.ndarray,
+    logprobs: "SampleLogprobs",
+    request: TranscriptionRequest,
+    segment_offset: int,
+    timestamp_offset: int,
 ) -> Generator:
     """
     Decode a single transcribed audio chunk and generates all the segments associated
@@ -198,9 +202,9 @@ def process_chunk(


 def process_chunks(
-
-
-
+    tokenizer: "PreTrainedTokenizer",
+    chunks: List["RequestOutput"],
+    request: TranscriptionRequest,
 ) -> Tuple[List[Segment], str]:
     """
     Iterate over all the audio chunk's outputs and consolidates outputs as segment(s) whether the response is verbose or not
@@ -223,7 +227,7 @@ def process_chunks(
         logprobs = generation.logprobs

         for segment, _is_continuation in process_chunk(
-
+            tokenizer, ids, logprobs, request, segment_offset, time_offset
         ):
             materialized_segments.append(segment)

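Similarly, the `get_avg_logprob` helper named in the hunk context above presumably averages the log-probability of the sampled tokens. A sketch, under the assumption that vLLM's `SampleLogprobs` is a per-step list of dicts mapping token id to a `Logprob` object exposing `.logprob` (illustrative, not the handler's exact code):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from vllm.sequence import SampleLogprobs


def get_avg_logprob_sketch(logprobs: "SampleLogprobs") -> float:
    # Keep the highest log-probability entry per decoding step; with
    # max_logprobs=1 there is effectively one candidate per step anyway.
    per_step = [
        max(step.values(), key=lambda lp: lp.logprob).logprob for step in logprobs
    ]
    return sum(per_step) / max(len(per_step), 1)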
@@ -258,17 +262,17 @@ class WhisperHandler(Handler[TranscriptionRequest, TranscriptionResponse]):
                 enforce_eager=False,
                 enable_prefix_caching=True,
                 max_logprobs=1, # TODO(mfuntowicz) : Set from config?
-                disable_log_requests=True
+                disable_log_requests=True,
             )
         )

     async def transcribe(
-
-
-
-
-
-
+        self,
+        ctx: Context,
+        request: TranscriptionRequest,
+        tokenizer: "PreTrainedTokenizer",
+        audio_chunks: Iterable[np.ndarray],
+        params: "SamplingParams",
     ) -> (List[Segment], str):
         async def __agenerate__(request_id: str, prompt, params):
             """
@@ -319,14 +323,14 @@ class WhisperHandler(Handler[TranscriptionRequest, TranscriptionResponse]):
         return segments, text

     async def __call__(
-
+        self, request: TranscriptionRequest, ctx: Context
     ) -> TranscriptionResponse:
         with logger.contextualize(request_id=ctx.request_id):
             with memoryview(request) as audio:

                 # Check if we need to enable the verbose path
                 is_verbose = (
-
+                    request.response_kind == TranscriptionResponseKind.VERBOSE_JSON
                 )

                 # Retrieve the tokenizer and model config asynchronously while we decode audio
@@ -375,14 +379,22 @@ class WhisperHandler(Handler[TranscriptionRequest, TranscriptionResponse]):


 def entrypoint():
-
-
+    # Retrieve endpoint configuration
+    endpoint_config = EndpointConfig.from_env()
+
+    # Ensure the model is compatible is pre-downloaded
+    if (model_local_path := Path(endpoint_config.model_id)).exists():
+        if (config_local_path := (model_local_path / "config.json")).exists():
+            config = AutoConfig.from_pretrained(config_local_path)
+            ensure_supported_architectures(config, SUPPORTED_MODEL_ARCHITECTURES)

+    # Initialize the endpoint
     endpoint = AutomaticSpeechRecognitionEndpoint(
-        WhisperHandler(
+        WhisperHandler(endpoint_config.model_id)
     )

-
+    # Serve the model
+    run(endpoint, endpoint_config.interface, endpoint_config.port)


 if __name__ == "__main__":
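The point of the commit, per the new `entrypoint()`, is that the model id, network interface, and port are read from the environment by the SDK's `EndpointConfig.from_env()` instead of being hard-coded in the handler. A minimal illustration of that pattern with hypothetical variable names and defaults (the real parser lives in the hfendpoints SDK; this is not its implementation):

import os
from dataclasses import dataclass


@dataclass
class EndpointConfigSketch:
    model_id: str
    interface: str
    port: int

    @classmethod
    def from_env(cls) -> "EndpointConfigSketch":
        # Hypothetical environment variable names and defaults, for illustration only.
        return cls(
            model_id=os.environ.get("MODEL_ID", "openai/whisper-large-v3"),
            interface=os.environ.get("INTERFACE", "0.0.0.0"),
            port=int(os.environ.get("PORT", "80")),
        )

The architecture guard in `entrypoint()` then only runs when the model directory is already on disk with a `config.json`, presumably raising `UnsupportedModelArchitecture` (imported above) for anything other than `WhisperForConditionalGeneration`.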