MjolnirThor committed
Commit f38a916 · 1 Parent(s): 0026ff3

Initial commit: Add FLAN-T5 custom handler

Files changed (2)
  1. handler.py +33 -44
  2. test_handler.py +12 -16
handler.py CHANGED
@@ -1,50 +1,39 @@
-from typing import Dict, List
-import torch
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+import torch
 
-class EndpointHandler():
-    def __init__(self, path=""):
-        # Load FLAN-T5 model and tokenizer
-        self.model_name = "google/flan-t5-large"
-        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
-        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
+class EndpointHandler:
+    def __init__(self, path="google/flan-t5-large"):
+        self.tokenizer = AutoTokenizer.from_pretrained(path)
+        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
 
-        # Enable evaluation mode
-        self.model.eval()
-
-    def __call__(self, data: Dict) -> List[Dict]:
-        # Get input text
+    def __call__(self, data):
+        """
+        Args:
+            data (dict): A dictionary with an "inputs" key containing the text to process.
+        """
         inputs = data.pop("inputs", data)
 
-        # Ensure inputs is a list
-        if isinstance(inputs, str):
-            inputs = [inputs]
-
-        # Tokenize inputs
-        tokenized = self.tokenizer(
-            inputs,
-            padding=True,
-            truncation=True,
-            max_length=512,
-            return_tensors="pt"
-        )
+        # Parameters for text generation
+        parameters = {
+            "max_length": 512,
+            "min_length": 32,
+            "temperature": 0.9,
+            "top_p": 0.95,
+            "top_k": 50,
+            "do_sample": True,
+            "num_return_sequences": 1
+        }
+
+        # Update parameters if provided in the request
+        parameters.update(data)
+
+        # Tokenize the input
+        input_ids = self.tokenizer(inputs, return_tensors="pt").input_ids
 
-        # Perform inference
-        with torch.no_grad():
-            outputs = self.model.generate(
-                tokenized.input_ids,
-                max_length=512,
-                min_length=50,
-                temperature=0.9,
-                top_p=0.95,
-                top_k=50,
-                do_sample=True,
-                num_return_sequences=1
-            )
-
-        # Decode the generated responses
-        responses = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
-
-        # Format output
-        results = [{"generated_text": response} for response in responses]
-        return results
+        # Generate the response
+        outputs = self.model.generate(input_ids, **parameters)
+
+        # Decode the response
+        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        return {"generated_text": generated_text}
test_handler.py CHANGED
@@ -1,23 +1,19 @@
 from handler import EndpointHandler
 
-def test_flan_t5():
-    # Initialize handler
+def test_handler():
+    # Initialize the handler
     handler = EndpointHandler()
 
-    # Test cases
-    test_inputs = [
-        "Explain quantum computing in simple terms",
-        "Translate 'Hello, how are you?' to French",
-        "Write a short story about a magical forest"
-    ]
+    # Test with a simple prompt
+    test_input = {
+        "inputs": "Explain quantum computing in simple terms"
+    }
 
-    # Test each input
-    for text in test_inputs:
-        print("\n" + "="*50)
-        print(f"Input text: {text}")
-        result = handler({"inputs": text})
-        print("Generated response:", result)
-        print("="*50)
+    # Get the response
+    response = handler(test_input)
+
+    print("Input:", test_input["inputs"])
+    print("Output:", response["generated_text"])
 
 if __name__ == "__main__":
-    test_flan_t5()
+    test_handler()
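
Once this handler is deployed as a Hugging Face Inference Endpoint, the same payload shape can be sent over HTTP. A minimal client sketch using the requests library; the endpoint URL and the HF_TOKEN environment variable are placeholders for your own deployment, not values from this repository:

import os

import requests

# Placeholder endpoint URL and token; substitute your own deployment values
API_URL = "https://your-endpoint.endpoints.huggingface.cloud"
headers = {"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}

# Same payload shape that EndpointHandler.__call__ expects
response = requests.post(
    API_URL,
    headers=headers,
    json={"inputs": "Explain quantum computing in simple terms"},
)
response.raise_for_status()
print(response.json()["generated_text"])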