|
from typing import Any, Dict |
|
from transformers import AutoTokenizer, AutoModel |
|
import torch |
|
import logging |
|
|
|
|
|
# Module-level logger plus a default INFO configuration so log records are
# visible even when the hosting runtime does not configure logging itself.
logger = logging.getLogger(__name__)

logging.basicConfig(level=logging.INFO)
|
|
|
|
|
class EndpointHandler:
    """Inference Endpoints handler that serves mean-pooled text embeddings.

    Loads a model and tokenizer from ``model_dir`` (a local path or Hub repo
    id) and, per request, returns the mean of the model's last hidden state
    over the sequence dimension as the embedding of the input text.
    """

    def __init__(self, model_dir: str, **kwargs: Any) -> None:
        """Load model and tokenizer from *model_dir*.

        Args:
            model_dir: Local path or Hub repo id of the checkpoint.
            **kwargs: Ignored; accepted for endpoint-runtime compatibility.
        """
        # bfloat16 halves memory vs fp32; .eval() disables dropout and other
        # train-time behavior. trust_remote_code is required for repos that
        # ship custom modeling code.
        self.model = AutoModel.from_pretrained(
            model_dir,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        ).eval()

        # use_fast=False: presumably the repo's custom tokenizer has no fast
        # (Rust) implementation — TODO confirm for the target checkpoint.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_dir, trust_remote_code=True, use_fast=False
        )

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Embed the text found under ``data["inputs"]``.

        Returns:
            A single-element list: ``[{"input_text": ..., "embedding": ...}]``
            where the embedding is a nested Python list of floats.
        """
        # Lazy %-formatting: the payload is only stringified when INFO-level
        # logging is actually enabled (the f-string paid that cost always).
        logger.info("Received incoming request with %s", data)

        input_text = data.get("inputs", "")
        if not input_text:
            logger.warning("No input text provided")
            # NOTE(review): key differs from the success path ("generated_text"
            # vs "embedding") — kept byte-identical for backward compatibility
            # with existing callers.
            return [{"generated_text": ""}]

        # Tokenize on CPU, then move the tensors to wherever the model lives.
        inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)

        # inference_mode() is no_grad() plus disabled version tracking —
        # slightly cheaper for pure inference, identical results here since
        # the output never re-enters autograd.
        with torch.inference_mode():
            outputs = self.model(**inputs)

        last_hidden_state = outputs.last_hidden_state

        # Mean-pool over the sequence dimension (dim=1). No attention mask is
        # applied: a single unpadded input has only real tokens, so the plain
        # mean is correct. Batched/padded inputs would need mask-weighted
        # pooling.
        embedding = last_hidden_state.mean(dim=1).cpu().numpy().tolist()

        return [{"input_text": input_text, "embedding": embedding}]
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: constructing the handler exercises model and
    # tokenizer loading end-to-end against the public Hub checkpoint.
    endpoint_handler = EndpointHandler(model_dir="GSAI-ML/LLaDA-8B-Instruct")
|
|