Morgan Funtowicz committed
Commit 8550385 · 1 Parent(s): 69fb91d

misc(sdk): use endpoint config parser

Files changed (1)
  1. handler.py +43 -31
handler.py CHANGED
@@ -1,11 +1,12 @@
 import asyncio
-import os
 import zlib
 from functools import lru_cache
 from io import BytesIO
+from pathlib import Path
 from typing import Sequence, List, Tuple, Generator, Iterable, TYPE_CHECKING
 
 import numpy as np
+from hfendpoints.errors.config import UnsupportedModelArchitecture
 from hfendpoints.openai import Context, run
 from hfendpoints.openai.audio import (
     AutomaticSpeechRecognitionEndpoint,
@@ -19,22 +20,25 @@ from hfendpoints.openai.audio import (
 )
 from librosa import load as load_audio, get_duration
 from loguru import logger
+from transformers import AutoConfig
 from vllm import (
     AsyncEngineArgs,
     AsyncLLMEngine,
     SamplingParams,
 )
 
-from hfendpoints import Handler
+from hfendpoints import EndpointConfig, Handler, ensure_supported_architectures
 
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer
     from vllm import CompletionOutput, RequestOutput
     from vllm.sequence import SampleLogprobs
 
+SUPPORTED_MODEL_ARCHITECTURES = ["WhisperForConditionalGeneration"]
+
 
 def chunk_audio_with_duration(
-    audio: np.ndarray, maximum_duration_sec: int, sampling_rate: int
+    audio: np.ndarray, maximum_duration_sec: int, sampling_rate: int
 ) -> Sequence[np.ndarray]:
     """
     Chunk a mono audio timeseries so that each chunk is as long as `maximum_duration_sec`.
@@ -63,10 +67,10 @@ def compression_ratio(text: str) -> float:
 
 
 def create_prompt(
-    audio: np.ndarray,
-    sampling_rate: int,
-    language: int,
-    timestamp_marker: int,
+    audio: np.ndarray,
+    sampling_rate: int,
+    language: int,
+    timestamp_marker: int,
 ):
     """
     Generate the right prompt with the specific parameters to submit for inference over Whisper
@@ -93,7 +97,7 @@ def create_prompt(
 
 
 def create_params(
-    max_tokens: int, temperature: float, is_verbose: bool
+    max_tokens: int, temperature: float, is_verbose: bool
 ) -> "SamplingParams":
     """
     Create sampling parameters to submit for inference through vLLM `generate`
@@ -123,12 +127,12 @@ def get_avg_logprob(logprobs: "SampleLogprobs") -> float:
 
 
 def process_chunk(
-    tokenizer: "PreTrainedTokenizer",
-    ids: np.ndarray,
-    logprobs: "SampleLogprobs",
-    request: TranscriptionRequest,
-    segment_offset: int,
-    timestamp_offset: int,
+    tokenizer: "PreTrainedTokenizer",
+    ids: np.ndarray,
+    logprobs: "SampleLogprobs",
+    request: TranscriptionRequest,
+    segment_offset: int,
+    timestamp_offset: int,
 ) -> Generator:
     """
     Decode a single transcribed audio chunk and generates all the segments associated
@@ -198,9 +202,9 @@ def process_chunk(
 
 
 def process_chunks(
-    tokenizer: "PreTrainedTokenizer",
-    chunks: List["RequestOutput"],
-    request: TranscriptionRequest,
+    tokenizer: "PreTrainedTokenizer",
+    chunks: List["RequestOutput"],
+    request: TranscriptionRequest,
 ) -> Tuple[List[Segment], str]:
     """
     Iterate over all the audio chunk's outputs and consolidates outputs as segment(s) whether the response is verbose or not
@@ -223,7 +227,7 @@
         logprobs = generation.logprobs
 
         for segment, _is_continuation in process_chunk(
-            tokenizer, ids, logprobs, request, segment_offset, time_offset
+            tokenizer, ids, logprobs, request, segment_offset, time_offset
         ):
             materialized_segments.append(segment)
 
@@ -258,17 +262,17 @@ class WhisperHandler(Handler[TranscriptionRequest, TranscriptionResponse]):
                 enforce_eager=False,
                 enable_prefix_caching=True,
                 max_logprobs=1,  # TODO(mfuntowicz) : Set from config?
-                disable_log_requests=True
+                disable_log_requests=True,
             )
         )
 
     async def transcribe(
-        self,
-        ctx: Context,
-        request: TranscriptionRequest,
-        tokenizer: "PreTrainedTokenizer",
-        audio_chunks: Iterable[np.ndarray],
-        params: "SamplingParams",
+        self,
+        ctx: Context,
+        request: TranscriptionRequest,
+        tokenizer: "PreTrainedTokenizer",
+        audio_chunks: Iterable[np.ndarray],
+        params: "SamplingParams",
     ) -> (List[Segment], str):
         async def __agenerate__(request_id: str, prompt, params):
             """
@@ -319,14 +323,14 @@ class WhisperHandler(Handler[TranscriptionRequest, TranscriptionResponse]):
         return segments, text
 
     async def __call__(
-        self, request: TranscriptionRequest, ctx: Context
+        self, request: TranscriptionRequest, ctx: Context
    ) -> TranscriptionResponse:
        with logger.contextualize(request_id=ctx.request_id):
            with memoryview(request) as audio:
 
                # Check if we need to enable the verbose path
                is_verbose = (
-                    request.response_kind == TranscriptionResponseKind.VERBOSE_JSON
+                    request.response_kind == TranscriptionResponseKind.VERBOSE_JSON
                )
 
                # Retrieve the tokenizer and model config asynchronously while we decode audio
@@ -375,14 +379,22 @@
 
 
 def entrypoint():
-    interface = os.environ.get("HFENDPOINT_INTERFACE", "0.0.0.0")
-    port = int(os.environ.get("HFENDPOINT_PORT", "8000"))
+    # Retrieve endpoint configuration
+    endpoint_config = EndpointConfig.from_env()
+
+    # Ensure the model is compatible is pre-downloaded
+    if (model_local_path := Path(endpoint_config.model_id)).exists():
+        if (config_local_path := (model_local_path / "config.json")).exists():
+            config = AutoConfig.from_pretrained(config_local_path)
+            ensure_supported_architectures(config, SUPPORTED_MODEL_ARCHITECTURES)
 
+    # Initialize the endpoint
     endpoint = AutomaticSpeechRecognitionEndpoint(
-        WhisperHandler("openai/whisper-large-v3")
+        WhisperHandler(endpoint_config.model_id)
     )
 
-    run(endpoint, interface, port)
+    # Serve the model
+    run(endpoint, endpoint_config.interface, endpoint_config.port)
 
 
 if __name__ == "__main__":
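
The removed entrypoint read HFENDPOINT_INTERFACE and HFENDPOINT_PORT directly from the environment; after this commit that detail moves behind EndpointConfig.from_env(), which the new code expects to expose at least model_id, interface and port. Below is a minimal sketch, for illustration only, of what such a parser could look like, assuming the same two environment variables and defaults as the removed code plus a hypothetical HFENDPOINT_MODEL_ID variable; the real parser ships with the hfendpoints SDK and may differ.

# Illustrative sketch only -- not the actual hfendpoints implementation.
import os
from dataclasses import dataclass


@dataclass
class EndpointConfig:
    model_id: str
    interface: str
    port: int

    @classmethod
    def from_env(cls) -> "EndpointConfig":
        # HFENDPOINT_INTERFACE / HFENDPOINT_PORT and their defaults mirror the code
        # removed above; HFENDPOINT_MODEL_ID is a hypothetical name used here only
        # so the sketch is self-contained.
        return cls(
            model_id=os.environ.get("HFENDPOINT_MODEL_ID", "openai/whisper-large-v3"),
            interface=os.environ.get("HFENDPOINT_INTERFACE", "0.0.0.0"),
            port=int(os.environ.get("HFENDPOINT_PORT", "8000")),
        )

With a config object of this shape, entrypoint() no longer hardcodes the model id, and the new pre-flight check can reject a locally pre-downloaded checkpoint whose config.json does not list WhisperForConditionalGeneration, presumably by raising the newly imported UnsupportedModelArchitecture.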