# NOTE(review): removed scraped web-UI artifacts (file size, commit hashes,
# line-number gutter) that were not part of the Python source.
from typing import Any, Dict
from transformers import AutoTokenizer, AutoModel
import torch
import logging
# Configure root logging once at import time so endpoint logs are emitted.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module (standard logging convention).
logger = logging.getLogger(__name__)
class EndpointHandler:
    """Hugging Face inference-endpoint handler that serves text embeddings.

    Loads the model in bfloat16 with remote code trusted, and answers each
    request with the mean-pooled last hidden state of the input text.
    """

    def __init__(self, model_dir: str, **kwargs: Any) -> None:
        """Load model and tokenizer from *model_dir* (local path or hub id).

        ``**kwargs`` is accepted (and ignored) for endpoint-framework
        compatibility.
        """
        self.model = AutoModel.from_pretrained(
            model_dir,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        ).eval()  # eval mode: disables dropout etc. for inference
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_dir, trust_remote_code=True, use_fast=False
        )

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Handle one embedding request.

        Args:
            data: Request payload; the text to embed is read from
                ``data["inputs"]``.

        Returns:
            ``[{"input_text": <str>, "embedding": <nested list of floats>}]``
            where the embedding is the mean of the last hidden state over the
            sequence dimension, converted to plain Python lists for JSON
            serialization. An empty/missing input yields an empty embedding.
        """
        log = logging.getLogger(__name__)
        # Lazy %-formatting: the payload is only rendered if INFO is enabled.
        log.info("Received incoming request with %s", data)

        input_text = data.get("inputs", "")
        if not input_text:
            log.warning("No input text provided")
            # Keep the empty-input response in the same schema as the success
            # path (previously this returned a {"generated_text": ""} record
            # that did not match the embedding schema below).
            return [{"input_text": "", "embedding": []}]

        # Tokenize and move the input tensors to the model's device.
        inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)

        # Forward pass without gradient tracking.
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Mean-pool over the sequence dimension (dim=1) and convert directly
        # to Python lists. Going through .numpy() would raise a TypeError:
        # the model runs in bfloat16 and NumPy has no bfloat16 dtype.
        # NOTE(review): pooling ignores the attention mask, so padded/batched
        # inputs would average in pad tokens — fine for a single string;
        # verify before sending batched inputs.
        embedding = outputs.last_hidden_state.mean(dim=1).cpu().tolist()
        return [{"input_text": input_text, "embedding": embedding}]
if __name__ == "__main__":
    # Manual smoke test: instantiating the handler downloads and loads the
    # LLaDA-8B-Instruct weights from the Hugging Face hub (network + GPU/RAM
    # heavy); no request is actually served here.
    handler = EndpointHandler(model_dir="GSAI-ML/LLaDA-8B-Instruct")