"""Custom handler implementing the Hugging Face Inference Endpoints
``EndpointHandler`` interface: load a causal language model once at
startup, then generate text for each incoming request."""

from typing import Any, Dict

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the tokenizer and model once at startup
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(path)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle a single inference request.

        Args:
            data: A dictionary with the key 'inputs' containing the input text.

        Returns:
            A dictionary with the generated text under the key
            'generated_text', or an 'error' key if no input was provided.
        """
        # Extract the input text
        input_text = data.get("inputs", "")
        if not input_text:
            return {"error": "No input provided"}

        # Tokenize the input
        inputs = self.tokenizer(input_text, return_tensors="pt")

        # Generate text; max_length caps the *total* sequence length
        # (prompt tokens plus generated tokens)
        with torch.no_grad():
            outputs = self.model.generate(**inputs, max_length=100)

        # Decode the generated tokens back into a string
        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        return {"generated_text": generated_text}
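

# ---------------------------------------------------------------------------
# Minimal local smoke test: a sketch of how the handler is invoked, not part
# of the Inference Endpoints request path. The model identifier "gpt2" below
# is an assumption for illustration; substitute the model repository this
# handler is actually deployed with.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler(path="gpt2")  # "gpt2" is a hypothetical choice
    result = handler({"inputs": "Once upon a time"})
    # Prints the generated text, or the error dict if input was rejected
    print(result.get("generated_text", result))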