from typing import Any, Dict, List

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the FLAN-T5 model and tokenizer.
        # Note: `path` (the model directory the endpoint runtime passes in) is
        # unused here; the model is always pulled from the Hub by name.
        self.model_name = "google/flan-t5-large"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)

        # Run on GPU when available, otherwise fall back to CPU
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

        # Enable evaluation mode (disables dropout)
        self.model.eval()
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        # Get the input text; fall back to the raw payload if "inputs" is absent
        inputs = data.pop("inputs", data)

        # Ensure inputs is a list so single strings and batches share one path
        if isinstance(inputs, str):
            inputs = [inputs]

        # Tokenize inputs, padding to the longest sequence in the batch
        tokenized = self.tokenizer(
            inputs,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        ).to(self.device)
        # Perform inference; pass the attention mask so padding tokens
        # in batched inputs are ignored during generation
        with torch.no_grad():
            outputs = self.model.generate(
                tokenized.input_ids,
                attention_mask=tokenized.attention_mask,
                max_length=512,
                min_length=50,
                temperature=0.9,
                top_p=0.95,
                top_k=50,
                do_sample=True,
                num_return_sequences=1,
            )
        # Decode the generated responses
        responses = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

        # Format output as one dict per generated sequence
        results = [{"generated_text": response} for response in responses]
        return results
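

# A minimal local smoke test (an illustrative sketch, not part of the handler
# contract). It assumes the Hugging Face Inference Toolkit payload convention
# of {"inputs": ...}; adjust to match your deployment's actual request schema.
if __name__ == "__main__":
    handler = EndpointHandler()
    payload = {"inputs": "Summarize: FLAN-T5 is an instruction-tuned model."}
    print(handler(payload))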