# Author: CodyBontecou
# Commit b38d2da: "output serialization"
from typing import Any, Dict
from transformers import AutoTokenizer, AutoModel
import torch
import logging
# Configure root logging at INFO level, then obtain this module's named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
class EndpointHandler:
    """Inference-endpoint handler that returns mean-pooled model embeddings.

    The model and tokenizer are loaded once in ``__init__``. Each call
    tokenizes ``data["inputs"]``, runs the model without gradient tracking,
    averages ``last_hidden_state`` over dim 1, and returns the result as
    plain nested Python lists so it is JSON-serializable.
    """

    def __init__(self, model_dir: str, **kwargs: Any) -> None:
        """Load the model (in bfloat16, eval mode) and its tokenizer.

        Args:
            model_dir: Path or Hub id passed to ``from_pretrained``.
            **kwargs: Accepted for compatibility with the endpoint runtime;
                not used.
        """
        self.model = AutoModel.from_pretrained(
            model_dir,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        ).eval()
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_dir, trust_remote_code=True, use_fast=False
        )

    @staticmethod
    def _pool_embeddings(
        last_hidden_state: torch.Tensor,
        attention_mask: "torch.Tensor | None" = None,
    ) -> list:
        """Average hidden states over dim 1 and return them as nested lists.

        Args:
            last_hidden_state: Hidden-state tensor; dim 1 is averaged (the
                token axis, assuming a (batch, seq_len, hidden) layout).
            attention_mask: Optional (batch, seq_len) mask; positions where it
                is 0 are excluded from the average. When ``None``, all
                positions are averaged.

        Returns:
            A list with one list of Python floats per row of dim 0.
        """
        # Upcast first: NumPy has no bfloat16 dtype, so calling .numpy() on the
        # bf16 model output raises TypeError; float32 also averages more
        # accurately than bf16.
        hidden = last_hidden_state.float()
        if attention_mask is None:
            pooled = hidden.mean(dim=1)
        else:
            mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
            # clamp avoids division by zero for a fully-masked row.
            pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        # Tensor.tolist() yields plain Python floats without a NumPy round-trip.
        return pooled.cpu().tolist()

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Embed ``data["inputs"]`` and return it alongside its embedding.

        Args:
            data: Request payload; the text to embed is read from the
                ``"inputs"`` key (missing key is treated as empty).

        Returns:
            ``[{"input_text": <inputs>, "embedding": <nested list of floats>}]``.
            For empty or missing input, ``[{"generated_text": ""}]``.
        """
        # Lazy %-formatting: payload is only stringified if INFO is enabled.
        logger.info("Received incoming request with %s", data)

        input_text = data.get("inputs", "")
        if not input_text:
            logger.warning("No input text provided")
            # NOTE(review): this shape differs from the embedding response
            # below; kept unchanged for existing callers — confirm which
            # format clients expect.
            return [{"generated_text": ""}]

        inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        # The mask is all ones for a single unpadded string, so this matches a
        # plain mean there, while padded batches exclude pad tokens.
        embedding = self._pool_embeddings(
            outputs.last_hidden_state, inputs.get("attention_mask")
        )
        return [{"input_text": input_text, "embedding": embedding}]
if __name__ == "__main__":
    # Smoke check: constructing the handler loads the model and tokenizer.
    model_id = "GSAI-ML/LLaDA-8B-Instruct"
    handler = EndpointHandler(model_dir=model_id)