from typing import Dict, List
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

class EndpointHandler:
    def __init__(self, path=""):
        # Load the FLAN-T5 model and tokenizer from the Hub.
        # Note: `path` (the local checkout that Inference Endpoints passes in)
        # is not used here; the model is always fetched by name.
        self.model_name = "google/flan-t5-large"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)

        # Run on GPU when available, otherwise fall back to CPU
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

        # Enable evaluation mode (disables dropout)
        self.model.eval()

    def __call__(self, data: Dict) -> List[Dict]:
        # Get input text
        inputs = data.pop("inputs", data)
        
        # Ensure inputs is a list
        if isinstance(inputs, str):
            inputs = [inputs]
            
        # Tokenize inputs and move the tensors to the model's device
        tokenized = self.tokenizer(
            inputs,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt"
        ).to(self.device)
        
        # Perform inference; pass the attention mask so that padded
        # positions are ignored during generation
        with torch.no_grad():
            outputs = self.model.generate(
                tokenized.input_ids,
                attention_mask=tokenized.attention_mask,
                max_length=512,
                min_length=50,
                temperature=0.9,
                top_p=0.95,
                top_k=50,
                do_sample=True,
                num_return_sequences=1
            )
            
        # Decode the generated responses
        responses = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
            
        # Format output
        results = [{"generated_text": response} for response in responses]
        return results
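
# --- Usage sketch (illustrative; not part of the deployed handler) ---
# A minimal local smoke test, assuming this file is run directly.
# The payload shape follows the {"inputs": ...} convention that Hugging Face
# Inference Endpoints sends to custom handlers.
if __name__ == "__main__":
    handler = EndpointHandler()
    result = handler({"inputs": "Translate English to German: How old are you?"})
    print(result)  # e.g. [{"generated_text": "..."}]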