from typing import Dict, Any, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the model in FP16 to reduce memory usage while retaining
        # performance, and place it on the GPU when one is available.
        # FP16 assumes GPU hardware; on a CPU-only host you would typically
        # load in float32 instead.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = AutoModelForCausalLM.from_pretrained(
            path, torch_dtype=torch.float16
        ).to(self.device)
        self.model.eval()
        self.tokenizer = AutoTokenizer.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (str): the text prompt for the model
        Returns:
            A list containing the generated response.
        """
        # Extract the input text from the request payload.
        inputs = data.get("inputs", "")
        if not inputs:
            return [{"error": "No input provided"}]

        # Tokenize the prompt and move the token tensors to the model's
        # device. Token IDs must remain integer tensors; casting them to
        # float16 (as the original code did) would break the embedding lookup.
        tokens = self.tokenizer(inputs, return_tensors="pt").to(self.device)

        # Generate a continuation of the prompt with default generation settings.
        with torch.no_grad():
            output_tokens = self.model.generate(**tokens)

        # Decode the generated tokens back to text.
        output_text = self.tokenizer.decode(output_tokens[0], skip_special_tokens=True)

        # Return the generated response as a list of dicts, the format
        # expected by Hugging Face Inference Endpoints.
        return [{"generated_text": output_text}]
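

# A minimal local smoke test, assuming the model files live in the current
# directory. The "." path and the example prompt below are assumptions for
# illustration; in production the Inference Endpoints runtime instantiates
# EndpointHandler itself and supplies the model path.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    # The payload mirrors the JSON body the deployed endpoint receives.
    response = handler({"inputs": "Hello, my name is"})
    print(response)  # e.g. [{"generated_text": "..."}]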