asritha22bce commited on
Commit
5964e76
·
verified ·
1 Parent(s): 920dc8b

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +26 -0
handler.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
2
+ import torch
3
+
4
+ class ModelHandler:
5
+ def __init__(self):
6
+ self.model_path = "asritha22bce/bart-positive-tone" # Change if needed
7
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_path)
8
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
9
+
10
+ def preprocess(self, text):
11
+ return self.tokenizer(text, return_tensors="pt", truncation=True, padding=True)
12
+
13
+ def inference(self, inputs):
14
+ with torch.no_grad():
15
+ output_ids = self.model.generate(**inputs, max_length=50)
16
+ return self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
17
+
18
+ def postprocess(self, output):
19
+ return {"positive_headline": output}
20
+
21
+ handler = ModelHandler()
22
+
23
+ def handle_request(text):
24
+ inputs = handler.preprocess(text)
25
+ output = handler.inference(inputs)
26
+ return handler.postprocess(output)