Dantinob committed
Commit cd6f314 · verified · 1 Parent(s): 36284ba

Update handler.py

Files changed (1)
  1. handler.py +28 -19
handler.py CHANGED
@@ -1,25 +1,34 @@
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 
-class ModelHandler:
-    def __init__(self):
-        self.tokenizer = None
-        self.model = None
-
-    def initialize(self, model_dir):
-        # Load the tokenizer and model
+class EndpointHandler:
+    def __init__(self, model_dir):
+        # Load tokenizer and model during initialization
         self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
         self.model = AutoModelForCausalLM.from_pretrained(model_dir)
 
-    def preprocess(self, inputs):
-        # Preprocess the input prompt
-        return self.tokenizer(inputs, return_tensors="pt", padding=True)
-
-    def inference(self, inputs):
-        # Generate text from the model
-        input_ids = inputs["input_ids"]
-        outputs = self.model.generate(input_ids, max_length=200, temperature=0.7)
-        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-    def postprocess(self, outputs):
-        return {"generated_text": outputs}
+    def __call__(self, data):
+        """
+        This method processes input data and generates output.
+        :param data: Input data, usually a dictionary with 'inputs' key.
+        """
+        # Extract input prompt
+        inputs = data.get("inputs", "")
+        if not inputs:
+            return {"error": "No input provided"}
+
+        # Preprocess input
+        encoded_inputs = self.tokenizer(inputs, return_tensors="pt", padding=True)
+
+        # Generate output
+        with torch.no_grad():
+            outputs = self.model.generate(
+                **encoded_inputs,
+                max_length=200,
+                temperature=0.7,
+                do_sample=True
+            )
+
+        # Decode and return response
+        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+        return {"generated_text": response}