Ozgur98 committed on
Commit a35c2e5 · 1 Parent(s): d4f40ab

Create handler.py

Files changed (1)
handler.py +38 -0
handler.py ADDED
@@ -0,0 +1,38 @@
+ from typing import Any, Dict
+ import logging
+
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ LOGGER = logging.getLogger(__name__)
+
+ class EndpointHandler:
+     def __init__(self, path=""):
+         # Load the 8-bit quantized model and its tokenizer
+         self.model = AutoModelForCausalLM.from_pretrained("Ozgur98/pushed_model_mosaic_small", load_in_8bit=True, device_map="auto")
+         self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
+
+     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+         """
+         Args:
+             data (Dict): The payload with the text prompt and generation parameters.
+         """
+         LOGGER.info(f"Received data: {data}")
+         # Get inputs
+         prompt = data.pop("inputs", None)
+         parameters = data.pop("parameters", None)
+         if prompt is None:
+             raise ValueError("Missing prompt.")
+         # Preprocess: tokenize the prompt and move it to the model's device
+         input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(device)
+         # Forward: generate, passing through any user-supplied parameters
+         LOGGER.info("Start generation.")
+         if parameters is not None:
+             output = self.model.generate(input_ids=input_ids, **parameters)
+         else:
+             output = self.model.generate(input_ids=input_ids)
+         # Postprocess: decode the generated token ids back to text
+         prediction = self.tokenizer.decode(output[0])
+         LOGGER.info(f"Generated text: {prediction}")
+         return {"generated_text": prediction}
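For reference, a minimal local smoke test of the handler might look like the sketch below. It assumes the payload shape ("inputs" plus optional "parameters") that Hugging Face Inference Endpoints sends to custom handlers, and that bitsandbytes and a CUDA device are available for the 8-bit load; the prompt text and the max_new_tokens value are illustrative only, not part of the commit.

# Hypothetical local test for handler.py (not part of the commit).
from handler import EndpointHandler

handler = EndpointHandler()

payload = {
    "inputs": "Once upon a time",           # illustrative prompt
    "parameters": {"max_new_tokens": 64},   # forwarded verbatim to model.generate()
}

result = handler(payload)
print(result["generated_text"])

Omitting "parameters" from the payload is also valid: the handler then calls model.generate() with its defaults.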