MjolnirThor committed on
Commit
0026ff3
·
0 Parent(s):

Initial commit: Add FLAN-T5 custom handler

Browse files
Files changed (3) hide show
  1. handler.py +50 -0
  2. requirements.txt +3 -0
  3. test_handler.py +23 -0
handler.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List
2
+ import torch
3
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
4
+
5
class EndpointHandler():
    """Custom Hugging Face inference-endpoint handler for a FLAN-T5 model.

    Loads a seq2seq model + tokenizer once at startup and serves batched
    text-generation requests via ``__call__``.
    """

    def __init__(self, path=""):
        # Prefer the locally deployed model directory (inference endpoints
        # pass it as `path`); fall back to the hub id so existing callers
        # that pass no path keep working. Previously `path` was ignored.
        self.model_name = path or "google/flan-t5-large"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)

        # Run on GPU when available; CPU generation is very slow.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

        # Disable dropout etc. for deterministic-as-possible inference.
        self.model.eval()

    def __call__(self, data: Dict) -> List[Dict]:
        """Generate text for one prompt or a batch of prompts.

        Args:
            data: Request payload. ``data["inputs"]`` is a string or a list
                of strings; if the key is absent the payload itself is used.

        Returns:
            One dict per input prompt, each ``{"generated_text": str}``.
        """
        inputs = data.pop("inputs", data)

        # Normalize a single prompt to a one-element batch.
        if isinstance(inputs, str):
            inputs = [inputs]

        # Robustness: an empty batch would crash the tokenizer — short-circuit.
        if not inputs:
            return []

        # Tokenize with padding so ragged batches share one tensor.
        tokenized = self.tokenizer(
            inputs,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        ).to(self.device)

        # Sample one continuation per prompt.
        with torch.no_grad():
            outputs = self.model.generate(
                tokenized.input_ids,
                # FIX: pass the attention mask so padded positions are
                # ignored; without it batched generation attends to padding
                # and degrades output quality.
                attention_mask=tokenized.attention_mask,
                max_length=512,
                min_length=50,
                temperature=0.9,
                top_p=0.95,
                top_k=50,
                do_sample=True,
                num_return_sequences=1
            )

        # Decode the generated token ids back to text.
        responses = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

        # Format output in the standard text-generation response shape.
        return [{"generated_text": response} for response in responses]
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch>=2.0.0
2
+ transformers>=4.30.0
3
+ huggingface-hub>=0.19.0
test_handler.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from handler import EndpointHandler
2
+
3
def test_flan_t5():
    """Smoke-test the handler against a few representative prompts."""
    handler = EndpointHandler()

    # Representative prompts covering explanation, translation, and creativity.
    prompts = (
        "Explain quantum computing in simple terms",
        "Translate 'Hello, how are you?' to French",
        "Write a short story about a magical forest",
    )

    separator = "=" * 50
    for prompt in prompts:
        print("\n" + separator)
        print(f"Input text: {prompt}")
        response = handler({"inputs": prompt})
        print("Generated response:", response)
        print(separator)


if __name__ == "__main__":
    test_flan_t5()