from typing import Dict, List

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the FLAN-T5 model and tokenizer from the Hub. `path` is the
        # endpoint repository directory; it is unused here because the
        # weights are pulled directly from the Hub checkpoint.
        self.model_name = "google/flan-t5-large"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)

        # Run on GPU when available, and switch to evaluation mode.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, data: Dict) -> List[Dict]:
        # Get the input text; accept either {"inputs": ...} or a bare payload.
        inputs = data.pop("inputs", data)

        # Ensure inputs is a list so single strings and batches share one path.
        if isinstance(inputs, str):
            inputs = [inputs]

        # Tokenize the batch, padding to the longest prompt and truncating
        # anything beyond the 512-token encoder limit.
        tokenized = self.tokenizer(
            inputs,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        ).to(self.device)

        # Perform inference. Passing the attention mask keeps padding tokens
        # in batched requests from influencing generation.
        with torch.no_grad():
            outputs = self.model.generate(
                tokenized.input_ids,
                attention_mask=tokenized.attention_mask,
                max_length=512,
                min_length=50,
                temperature=0.9,
                top_p=0.95,
                top_k=50,
                do_sample=True,
                num_return_sequences=1,
            )

        # Decode the generated token IDs back into text.
        responses = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

        # Format the output as the list of dicts Inference Endpoints expects.
        return [{"generated_text": response} for response in responses]
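
# A minimal local smoke test (an addition for illustration, not part of the
# Inference Endpoints contract): calling the handler directly mimics what the
# endpoint runtime does, so the payload and response shapes can be checked
# before deploying. The example prompt is arbitrary.
if __name__ == "__main__":
    handler = EndpointHandler()
    result = handler(
        {"inputs": "Summarize: The quick brown fox jumps over the lazy dog."}
    )
    print(result)  # e.g. [{"generated_text": "..."}]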