# autotrain-v2pzm-t7w84 / handler.py
# Author: AgentisLabs — commit 867fc5e ("Create handler.py"), verified
from typing import Dict, Any
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for a causal language model.

    Loads a tokenizer and model from ``path`` at startup and, per request,
    generates a text continuation for the prompt supplied under ``"inputs"``.
    """

    def __init__(self, path: str = ""):
        """Load the tokenizer and model from *path* (local dir or Hub repo id)."""
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(path)
        # Many causal-LM tokenizers (GPT-2/LLaMA style) ship without a pad
        # token; fall back to EOS so padded/batched generation doesn't warn
        # or fail.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate text for one request.

        Args:
            data: Request payload. ``"inputs"`` holds the prompt string;
                an optional ``"parameters"`` dict is forwarded to
                ``model.generate`` (e.g. ``max_new_tokens``, ``temperature``).

        Returns:
            ``{"generated_text": ...}`` on success, or
            ``{"error": ...}`` when no input was provided.
        """
        input_text = data.get("inputs", "")
        if not input_text:
            return {"error": "No input provided"}

        # Optional per-request generation overrides, per the standard
        # endpoint payload shape {"inputs": ..., "parameters": {...}}.
        params = data.get("parameters") or {}

        inputs = self.tokenizer(input_text, return_tensors="pt")
        # Keep input tensors on the same device as the model weights;
        # without this the handler breaks on GPU-backed endpoints.
        device = next(self.model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Use max_new_tokens, not max_length: max_length counts the prompt
        # tokens too, so prompts >= 100 tokens would generate nothing.
        gen_kwargs: Dict[str, Any] = {"max_new_tokens": 100}
        gen_kwargs.update(params)

        with torch.no_grad():
            outputs = self.model.generate(**inputs, **gen_kwargs)

        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"generated_text": generated_text}