from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch

class EndpointHandler:
    def __init__(self, path="google/flan-t5-large"):
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
        # Run on a GPU when one is available; fall back to CPU otherwise
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def __call__(self, data):
        """
        Args:
            data (dict): A dictionary with an "inputs" key containing the text
                to process, and optionally a "parameters" key with overrides
                for the generation settings below.
        """
        inputs = data.pop("inputs", data)
        # Default parameters for text generation
        parameters = {
            "max_length": 512,
            "min_length": 32,
            "temperature": 0.9,
            "top_p": 0.95,
            "top_k": 50,
            "do_sample": True,
            "num_return_sequences": 1,
        }
        # Merge any overrides sent with the request; the payload nests them
        # under a "parameters" key, so merging the raw payload itself would
        # leak non-generation keys into generate()
        parameters.update(data.pop("parameters", {}))
        # Tokenize the input and move it to the model's device
        input_ids = self.tokenizer(inputs, return_tensors="pt").input_ids.to(self.device)
        # Generate the response (no gradients needed at inference time)
        with torch.no_grad():
            outputs = self.model.generate(input_ids, **parameters)
        # Decode the first returned sequence, dropping special tokens
        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"generated_text": generated_text}