"""Custom Hugging Face Inference Endpoints handler.

Loads a transformer model and returns a mean-pooled sentence embedding
for the text supplied under the ``"inputs"`` key of the request payload.
"""

import logging
from typing import Any, Dict, List

import torch
from transformers import AutoModel, AutoTokenizer

# Module-level logger; basicConfig so logs surface even when the hosting
# runtime has not configured logging itself.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class EndpointHandler:
    """Inference Endpoints handler producing mean-pooled embeddings.

    The handler is instantiated once per worker with the model directory,
    then called with each request payload.
    """

    def __init__(self, model_dir: str, **kwargs: Any) -> None:
        """Load the model (bfloat16, eval mode) and its slow tokenizer.

        Args:
            model_dir: Local path or Hub repo id of the checkpoint.
            **kwargs: Accepted for runtime compatibility; currently unused.

        NOTE(review): ``trust_remote_code=True`` executes code shipped with
        the checkpoint -- only point this handler at trusted repositories.
        """
        self.model = AutoModel.from_pretrained(
            model_dir,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        ).eval()
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_dir, trust_remote_code=True, use_fast=False
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Compute a mean-pooled embedding for ``data["inputs"]``.

        Args:
            data: Request payload; the text to embed is read from the
                ``"inputs"`` key (missing/empty text yields an empty result).

        Returns:
            A one-element list of dicts so the payload matches the usual
            Inference Endpoints response shape. On success the dict holds
            ``"input_text"`` and a JSON-serializable ``"embedding"``.
        """
        # Lazy %-formatting: the message is only built if INFO is enabled.
        logger.info("Received incoming request with %s", data)

        input_text = data.get("inputs", "")
        if not input_text:
            logger.warning("No input text provided")
            return [{"generated_text": ""}]  # empty result, but valid format

        inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)

        # Inference only -- no gradients needed.
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Mean-pool the last hidden state over the sequence dimension.
        # BUGFIX: the model runs in bfloat16 and NumPy has no bfloat16 dtype,
        # so ``.cpu().numpy()`` raised TypeError here. Upcast to float32 and
        # use Tensor.tolist() directly to get JSON-serializable nested lists.
        last_hidden_state = outputs.last_hidden_state
        embedding = last_hidden_state.mean(dim=1).float().cpu().tolist()

        return [{"input_text": input_text, "embedding": embedding}]


if __name__ == "__main__":
    handler = EndpointHandler(model_dir="GSAI-ML/LLaDA-8B-Instruct")