from json import load
from pathlib import Path
from typing import Any, Dict, List, Optional

from huggingface_hub import hf_hub_download
from huggingface_hub.errors import EntryNotFoundError
from loguru import logger
from vllm import (
    AsyncEngineArgs,
    AsyncLLMEngine,
    EmbeddingRequestOutput,
    PoolingParams,
)

from hfendpoints import EndpointConfig, Handler, __version__
from hfendpoints.http import Context, run
from hfendpoints.tasks import Usage
from hfendpoints.tasks.embedding import EmbeddingRequest, EmbeddingResponse


def get_sentence_transformers_config(config: EndpointConfig) -> Optional[Dict[str, Any]]:
    """Load `config_sentence_transformers.json` from the local repository when present,
    falling back to the Hub. Returns None if the model ships no such file."""
    st_config_path = None
    if not config.is_debug:
        st_config_path = Path(config.repository) / "config_sentence_transformers.json"

    if not st_config_path or not st_config_path.exists():
        try:
            st_config_path = hf_hub_download(
                config.model_id, filename="config_sentence_transformers.json"
            )
        except EntryNotFoundError:
            logger.info(f"Sentence Transformers config not found on {config.model_id}")
            return None

    with open(st_config_path, "r", encoding="utf-8") as config_f:
        return load(config_f)


class VllmEmbeddingHandler(Handler):
    __slots__ = ("_engine", "_sentence_transformers_config")

    def __init__(self, config: EndpointConfig):
        self._sentence_transformers_config = get_sentence_transformers_config(config)
        self._engine = AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(
                str(config.repository),
                task="embed",
                device="auto",
                dtype="bfloat16",
                kv_cache_dtype="auto",
                enforce_eager=False,
                enable_prefix_caching=True,
                disable_log_requests=True,
            )
        )

    async def embeds(
        self, prompts: str, pooling: PoolingParams, request_id: str
    ) -> List[EmbeddingRequestOutput]:
        # Drain the async generator; each item carries the pooled embedding output.
        outputs = []
        async for item in self._engine.encode(
            prompts,
            pooling_params=pooling,
            request_id=request_id,
            lora_request=None,
        ):
            outputs.append(EmbeddingRequestOutput.from_base(item))
        return outputs

    async def __call__(self, request: EmbeddingRequest, ctx: Context) -> EmbeddingResponse:
        # Optional Matryoshka-style truncation of the output embedding dimension.
        if "dimension" in request.parameters:
            pooling_params = PoolingParams(dimensions=request.parameters["dimension"])
        else:
            pooling_params = None

        # Resolve a named prompt prefix from the Sentence Transformers config, if any.
        if "prompt_name" in request.parameters and self._sentence_transformers_config:
            prompt_name = request.parameters["prompt_name"]
            tokenizer = await self._engine.get_tokenizer()
            prompt = self._sentence_transformers_config.get("prompts", {}).get(prompt_name)
            num_prompt_tokens = len(tokenizer.tokenize(prompt)) if prompt else 0
        else:
            prompt = None
            num_prompt_tokens = 0

        if request.is_batched:
            # Embed each document independently, deriving a unique request id per item.
            embeddings = []
            num_tokens = 0
            for idx, document in enumerate(request.inputs):
                text = f"{prompt}{document}" if prompt else document
                output = await self.embeds(text, pooling_params, f"{ctx.request_id}-{idx}")
                num_tokens += len(output[0].prompt_token_ids)
                embeddings.append(output[0].outputs.embedding)
        else:
            text = f"{prompt}{request.inputs}" if prompt else request.inputs
            output = await self.embeds(text, pooling_params, ctx.request_id)
            num_tokens = len(output[0].prompt_token_ids)
            embeddings = output[0].outputs.embedding

        return EmbeddingResponse(embeddings, prompt_tokens=num_prompt_tokens, num_tokens=num_tokens)


def entrypoint():
    # Read the endpoint configuration from the provided environment variables
    config = EndpointConfig.from_env()

    logger.info(f"[Hugging Face Endpoint v{__version__}] Serving: {config.model_id}")

    # Allocate handler
    handler = VllmEmbeddingHandler(config)

    # Allocate endpoint
    from hfendpoints.openai.embedding import EmbeddingEndpoint

    endpoint = EmbeddingEndpoint(handler)
    run(endpoint, config.interface, config.port)


if __name__ == "__main__":
    entrypoint()