File size: 1,295 Bytes
351aa3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from typing import Dict, Any, List
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

class EndpointHandler:
    """Custom inference handler for a causal language model.

    The model and tokenizer are loaded once at construction time; each call
    generates text for the prompt(s) contained in the request payload.
    """

    def __init__(self, path=""):
        """Load the model and tokenizer from *path*.

        Args:
            path: Directory (or hub id) passed to ``from_pretrained``.
        """
        # FP16 halves memory use on GPU, but many FP16 kernels are missing or
        # very slow on CPU, so fall back to FP32 when no GPU is available.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        dtype = torch.float16 if device.type == "cuda" else torch.float32
        self.model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=dtype)
        self.model.to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (str | list[str]): A single prompt or a list of prompts.
            parameters (dict, optional): Extra keyword arguments forwarded
                unchanged to ``model.generate`` (e.g. ``max_new_tokens``).
        Return:
            A list with one ``{"generated_text": ...}`` dict per prompt, or
            ``[{"error": "No input provided"}]`` when ``inputs`` is missing
            or empty.
        """
        # Extract the input text from the request
        inputs = data.get("inputs", "")
        if not inputs:
            return [{"error": "No input provided"}]
        # `or {}` also covers an explicit `"parameters": null` in the payload.
        parameters = data.get("parameters") or {}

        # Generate each prompt separately: batching prompts of different
        # lengths would require padding configuration the tokenizer may lack.
        prompts = [inputs] if isinstance(inputs, str) else list(inputs)

        # Return the generated response as a list (required format)
        return [
            {"generated_text": self._generate(prompt, parameters)}
            for prompt in prompts
        ]

    def _generate(self, prompt: str, parameters: Dict[str, Any]) -> str:
        """Run generation for one prompt and return the decoded text."""
        # Token ids must remain integer tensors; only move them to the device
        # the model lives on (casting them to a float dtype is invalid).
        tokens = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        output_tokens = self.model.generate(**tokens, **parameters)
        # For a causal LM the output sequence starts with the prompt tokens,
        # so the decoded text includes the prompt.
        return self.tokenizer.decode(output_tokens[0], skip_special_tokens=True)