from pathlib import Path
from typing import List, Optional, Dict, Any

from huggingface_hub import hf_hub_download
from huggingface_hub.errors import EntryNotFoundError
from loguru import logger
from vllm import (
    AsyncLLMEngine, AsyncEngineArgs,
    PoolingParams, EmbeddingRequestOutput,
)

from hfendpoints import EndpointConfig, Handler, __version__
from hfendpoints.http import Context, run
from hfendpoints.tasks import Usage
from hfendpoints.tasks.embedding import EmbeddingRequest, EmbeddingResponse

def get_sentence_transformers_config(config: EndpointConfig) -> Optional[Dict[str, Any]]:
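    """Load the optional `config_sentence_transformers.json` for the deployed model.

    Outside of debug mode the file is first looked up in the local repository; when it is
    missing (or in debug mode) it is fetched from the Hub. Returns the parsed JSON content,
    or None if the model does not ship such a file.
    """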
    st_config_path = None
    if not config.is_debug:
        st_config_path = Path(config.repository) / "config_sentence_transformers.json"

    if not st_config_path or not st_config_path.exists():
        try:
            st_config_path = hf_hub_download(config.model_id, filename="config_sentence_transformers.json")
        except EntryNotFoundError:
            logger.info(f"Sentence Transformers config not found on {config.model_id}")
            return None

    with open(st_config_path, "r", encoding="utf-8") as config_f:
        from json import load
        return load(config_f)


class VllmEmbeddingHandler(Handler):
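    """Embedding handler backed by an async vLLM engine running in pooling ("embed") mode."""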
    __slots__ = ("_engine", "_sentence_transformers_config",)

    def __init__(self, config: EndpointConfig):
        self._sentence_transformers_config = get_sentence_transformers_config(config)
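        # Spin up the async vLLM engine configured for the pooling/embedding task.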
        self._engine = AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(
                str(config.repository),
                task="embed",
                device="auto",
                dtype="bfloat16",
                kv_cache_dtype="auto",
                enforce_eager=False,
                enable_prefix_caching=True,
                disable_log_requests=True,
            )
        )

    async def embeds(
        self,
        prompts: str,
        pooling: PoolingParams,
        request_id: str
    ) -> List[EmbeddingRequestOutput]:
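        """Encode a single prompt and collect the embedding outputs streamed by the engine."""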
        outputs = []
        async for item in self._engine.encode(
            prompts,
            pooling_params=pooling,
            request_id=request_id,
            lora_request=None,
        ):
            outputs.append(EmbeddingRequestOutput.from_base(item))

        return outputs

    async def __call__(self, request: EmbeddingRequest, ctx: Context) -> EmbeddingResponse:
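        """Embed a single input or a batch of inputs.

        When a `prompt_name` parameter is provided and the model ships a sentence-transformers
        config, the corresponding prompt is prepended to each input before encoding.
        """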
if "dimension" in request.parameters: |
|
pooling_params = PoolingParams(dimensions=request.parameters["dimension"]) |
|
else: |
|
pooling_params = None |
|
|
|
if "prompt_name" in request.parameters and self._sentence_transformers_config: |
|
prompt_name = request.parameters["prompt_name"] |
|
tokenizer = await self._engine.get_tokenizer() |
|
prompt = self._sentence_transformers_config.get("prompts", {}).get(prompt_name, None) |
|
num_prompt_tokens = len(tokenizer.tokenize(prompt)) if prompt else 0 |
|
else: |
|
prompt = None |
|
num_prompt_tokens = 0 |
|
|
|
        if request.is_batched:
            embeddings = []
            num_tokens = 0
            for idx, document in enumerate(request.inputs):
                # Prepend the resolved prompt, if any, directly to each document.
                text = f"{prompt}{document}" if prompt else document

                output = await self.embeds(text, pooling_params, f"{ctx.request_id}-{idx}")
                num_tokens += len(output[0].prompt_token_ids)
                embeddings += [output[0].outputs.embedding]
        else:
            text = f"{prompt}{request.inputs}" if prompt else request.inputs

            output = await self.embeds(text, pooling_params, ctx.request_id)
            num_tokens = len(output[0].prompt_token_ids)
            embeddings = output[0].outputs.embedding

        return EmbeddingResponse(embeddings, prompt_tokens=num_prompt_tokens, num_tokens=num_tokens)



def entrypoint():
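    """Read the endpoint configuration from the environment and serve the embedding endpoint."""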
    config = EndpointConfig.from_env()

    logger.info(f"[Hugging Face Endpoint v{__version__}] Serving: {config.model_id}")

    handler = VllmEmbeddingHandler(config)

    from hfendpoints.openai.embedding import EmbeddingEndpoint
    endpoint = EmbeddingEndpoint(handler)
    run(endpoint, config.interface, config.port)


if __name__ == "__main__":
    entrypoint()