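"""Custom handler for a Hugging Face Inference Endpoint.

Inference Endpoints load an ``EndpointHandler`` class from ``handler.py`` and
call it with each request payload. Instead of generating text, this handler
returns a mean-pooled embedding computed from the model's last hidden state.
"""
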
from typing import Any, Dict, List
from transformers import AutoTokenizer, AutoModel
import torch
import logging

# Initialize logger
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class EndpointHandler:
    def __init__(self, model_dir: str, **kwargs: Any) -> None:
        # Load the model in bfloat16; trust_remote_code is needed for models
        # that ship custom modeling code on the Hub (such as LLaDA).
        self.model = AutoModel.from_pretrained(
            model_dir,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        ).eval()

        # from_pretrained leaves the model on CPU by default; move it to the
        # GPU when one is available so the endpoint actually uses it.
        if torch.cuda.is_available():
            self.model = self.model.to("cuda")

        # use_fast=False opts into the slow (Python) tokenizer, which some
        # custom-code models require.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_dir, trust_remote_code=True, use_fast=False
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        logger.info(f"Received request payload: {data}")

        # Extract input text from the request data
        input_text = data.get("inputs", "")
        if not input_text:
            logger.warning("No input text provided")
            # Return an empty result in the same shape as the success path
            # so clients can parse both uniformly.
            return [{"input_text": "", "embedding": []}]

        # Tokenize the input and move the tensors to the model's device
        inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)

        # Forward pass without gradients to obtain the hidden states
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Extract the last hidden state: shape (batch, seq_len, hidden_size)
        last_hidden_state = outputs.last_hidden_state

        # Mean-pool over the sequence dimension as a simple sentence embedding.
        # bfloat16 tensors cannot be converted to NumPy directly, so cast to
        # float32 first, then to a nested Python list for JSON serialization.
        embedding = last_hidden_state.mean(dim=1).float().cpu().numpy().tolist()

        return [{"input_text": input_text, "embedding": embedding}]
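
    @staticmethod
    def _masked_mean(
        last_hidden_state: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        # Illustrative alternative (not used above, not part of the original
        # handler): a mask-aware mean that excludes padding tokens. Plain
        # .mean(dim=1) is fine for a single, unpadded input, but would average
        # over pad tokens if this handler were extended to padded batches.
        mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
        return (last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)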


if __name__ == "__main__":
    handler = EndpointHandler(model_dir="GSAI-ML/LLaDA-8B-Instruct")
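    # Minimal local smoke test (assumes the checkpoint can be downloaded and
    # that there is enough memory for an 8B-parameter model).
    result = handler({"inputs": "Hello, world!"})
    logger.info(f"Embedding vector length: {len(result[0]['embedding'][0])}")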