stuckdavis committed on
Commit
c200376
·
verified ·
1 Parent(s): c9f7808

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +42 -40
handler.py CHANGED
@@ -1,43 +1,45 @@
1
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
  import torch
3
 
4
# Module-level setup: executed a single time on import so that every later
# call to predict() reuses the same tokenizer and model objects.
model_name = "open-paws/text_performance_prediction_longform"

# Use the GPU when one is visible to torch, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(model_name)

model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval()  # inference only
model.to(device)
13
-
14
def predict(inputs):
    """
    Score one text or a list of texts with the module-level model.

    Per the original author's note, Hugging Face Inference Endpoints call
    this function. `inputs` may be a single string or a list of strings.

    Returns a list with one ``{"score": float}`` dict per text, where the
    score is the model's logit clamped to [0.0, 1.0] and rounded to four
    decimal places.
    """
    # Normalise the single-string form to a one-element list.
    texts = [inputs] if isinstance(inputs, str) else inputs

    scored = []
    for item in texts:
        batch = tokenizer(
            item,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=4096,
        )
        # Move every tensor the tokenizer produced onto the model's device.
        on_device = {name: tensor.to(device) for name, tensor in batch.items()}

        with torch.no_grad():
            logits = model(**on_device).logits

        # .item() requires exactly one element; presumably the model emits a
        # single logit per text -- confirm against the model config.
        value = logits.squeeze().item()
        floored = max(value, 0.0)
        bounded = min(floored, 1.0)

        scored.append({"score": round(bounded, 4)})

    return scored
 
 
 
1
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
  import torch
3
 
4
class EndpointHandler:
    """Callable request handler that scores text with a sequence classifier.

    The tokenizer and model are loaded once in ``__init__``; each request is
    then served by calling the instance with the request payload.
    """

    def __init__(self, path=""):
        """Load the tokenizer and model from ``path`` and move the model to a device.

        Args:
            path: Location passed to ``from_pretrained`` for both the
                tokenizer and the model (the repo path, per the original
                author's comment).
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSequenceClassification.from_pretrained(path)
        self.model.eval()
        # Use the GPU when torch can see one, otherwise fall back to the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def __call__(self, data):
        """Handle one request.

        Expected input: ``{"inputs": "some string"}`` or
        ``{"inputs": ["a", "b", ...]}``.

        Args:
            data: The request payload; must be a dict with an ``"inputs"`` key.

        Returns:
            On success, a list with one ``{"score": float}`` dict per input
            text (an empty list for an empty ``inputs`` list). The score is
            the model's logit clamped to [0.0, 1.0] and rounded to four
            decimal places. On a malformed payload, a single
            ``{"error": message}`` dict.
        """
        # Reject malformed payloads with the same {"error": ...} shape the
        # handler already uses, instead of failing with an opaque
        # AttributeError/TypeError deeper in the tokenizer.
        if not isinstance(data, dict):
            return {"error": "Request body must be a JSON object"}

        inputs = data.get("inputs", None)
        if inputs is None:
            return {"error": "No input provided"}

        if isinstance(inputs, str):
            inputs = [inputs]
        elif not isinstance(inputs, (list, tuple)):
            return {"error": "'inputs' must be a string or a list of strings"}

        if not all(isinstance(text, str) for text in inputs):
            return {"error": "'inputs' must be a string or a list of strings"}

        results = []
        # Gradients are never needed here; one no_grad scope covers the loop.
        with torch.no_grad():
            for text in inputs:
                # NOTE(review): padding="max_length" pads every text to 4096
                # tokens regardless of its length; kept as-is to preserve the
                # existing behaviour -- confirm whether it is required.
                encoded = self.tokenizer(
                    text,
                    return_tensors="pt",
                    truncation=True,
                    padding="max_length",
                    max_length=4096,
                )
                encoded = {k: v.to(self.device) for k, v in encoded.items()}

                outputs = self.model(**encoded)

                # .item() requires exactly one element; presumably the model
                # emits a single logit per text -- confirm against the config.
                raw_score = outputs.logits.squeeze().item()
                clipped_score = min(max(raw_score, 0.0), 1.0)

                results.append({"score": round(clipped_score, 4)})

        return results