# embeddings-qwen3 / handler.py
from pathlib import Path
from typing import List, Optional, Dict, Any
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import EntryNotFoundError
from loguru import logger
from vllm import (
AsyncLLMEngine, AsyncEngineArgs,
PoolingParams, EmbeddingRequestOutput,
)
from hfendpoints import EndpointConfig, Handler, __version__
from hfendpoints.http import Context, run
from hfendpoints.tasks import Usage
from hfendpoints.tasks.embedding import EmbeddingRequest, EmbeddingResponse


def get_sentence_transformers_config(config: EndpointConfig) -> Optional[Dict[str, Any]]:
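    """Load the model's `config_sentence_transformers.json`, if it exists.

    The file is first looked up in the local repository (when not running in
    debug mode), then downloaded from the Hugging Face Hub as a fallback.
    Returns the parsed JSON content, or None when the model ships no such file.
    """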
st_config_path = None
if not config.is_debug:
        st_config_path = Path(config.repository) / "config_sentence_transformers.json"
if not st_config_path or not st_config_path.exists():
try:
st_config_path = hf_hub_download(config.model_id, filename="config_sentence_transformers.json")
except EntryNotFoundError:
            logger.info(f"No Sentence Transformers config found for {config.model_id}")
return None
    from json import load
    with open(st_config_path, "r", encoding="utf-8") as config_f:
        return load(config_f)


class VllmEmbeddingHandler(Handler):
    __slots__ = ("_engine", "_sentence_transformers_config",)

def __init__(self, config: EndpointConfig):
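        # Load the optional Sentence Transformers config (named prompt prefixes)
        # and start the async vLLM engine in embedding ("embed") mode.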
self._sentence_transformers_config = get_sentence_transformers_config(config)
self._engine = AsyncLLMEngine.from_engine_args(
AsyncEngineArgs(
str(config.repository),
task="embed",
device="auto",
dtype="bfloat16",
kv_cache_dtype="auto",
enforce_eager=False,
enable_prefix_caching=True,
disable_log_requests=True,
)
)

    async def embeds(
self,
        prompt: str,
pooling: PoolingParams,
request_id: str
) -> List[EmbeddingRequestOutput]:
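        """Encode a single prompt with the engine and collect the streamed
        embedding outputs into a list."""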
outputs = []
async for item in self._engine.encode(
            prompt,
pooling_params=pooling,
request_id=request_id,
lora_request=None,
):
outputs.append(EmbeddingRequestOutput.from_base(item))
return outputs

    async def __call__(self, request: EmbeddingRequest, ctx: Context) -> EmbeddingResponse:
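        """Embed a single input or a batch of inputs.

        Honors two optional request parameters: `dimension`, which selects the
        output embedding size through vLLM's PoolingParams, and `prompt_name`,
        which prepends a named prompt from the Sentence Transformers config.
        """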
if "dimension" in request.parameters:
pooling_params = PoolingParams(dimensions=request.parameters["dimension"])
else:
pooling_params = None
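        # Sentence Transformers checkpoints may define named prompts (e.g. "query")
        # that are prepended to the raw input before encoding.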
if "prompt_name" in request.parameters and self._sentence_transformers_config:
prompt_name = request.parameters["prompt_name"]
tokenizer = await self._engine.get_tokenizer()
prompt = self._sentence_transformers_config.get("prompts", {}).get(prompt_name, None)
num_prompt_tokens = len(tokenizer.tokenize(prompt)) if prompt else 0
else:
prompt = None
num_prompt_tokens = 0
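        # Batched requests are encoded document by document, accumulating the
        # number of processed tokens across the whole batch.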
if request.is_batched:
embeddings = []
num_tokens = 0
            for idx, document in enumerate(request.inputs):
                text = f"{prompt}{document}" if prompt else document
                output = await self.embeds(text, pooling_params, f"{ctx.request_id}-{idx}")
                num_tokens += len(output[0].prompt_token_ids)
                embeddings.append(output[0].outputs.embedding)
        else:
            # Note: the prompt prefix is concatenated without a separating space,
            # matching the batched path above.
            text = f"{prompt}{request.inputs}" if prompt else request.inputs
            output = await self.embeds(text, pooling_params, ctx.request_id)
            num_tokens = len(output[0].prompt_token_ids)
            embeddings = output[0].outputs.embedding
return EmbeddingResponse(embeddings, prompt_tokens=num_prompt_tokens, num_tokens=num_tokens)


def entrypoint():
    # Read the endpoint configuration from environment variables
    config = EndpointConfig.from_env()
logger.info(f"[Hugging Face Endpoint v{__version__}] Serving: {config.model_id}")
# Allocate handler
handler = VllmEmbeddingHandler(config)
# Allocate endpoint
from hfendpoints.openai.embedding import EmbeddingEndpoint
endpoint = EmbeddingEndpoint(handler)
run(endpoint, config.interface, config.port)


if __name__ == "__main__":
entrypoint()
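
# A minimal sketch of how a client might call the running endpoint, assuming
# EmbeddingEndpoint exposes an OpenAI-compatible `/v1/embeddings` route (the
# route, port, and payload below are illustrative assumptions, not the
# documented hfendpoints API):
#
#   curl http://localhost:8000/v1/embeddings \
#     -H "Content-Type: application/json" \
#     -d '{"input": "What is Deep Learning?"}'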