Canstralian committed (verified)
Commit f678e87 · 1 parent: c52c7f3

Create src/model_inference.py

Files changed (1)
  src/model_inference.py  +23 -0
src/model_inference.py ADDED
@@ -0,0 +1,23 @@
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+import torch
+
+def load_model(model_path):
+    """Load the trained model from the specified path."""
+    model = AutoModelForSequenceClassification.from_pretrained(model_path)
+    return model
+
+def load_tokenizer(model_path):
+    """Load the tokenizer from the specified path."""
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    return tokenizer
+
+def predict(model, tokenizer, text, device='cpu'):
+    """Predict the class of the input text."""
+    model.to(device)
+    inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
+    inputs = {key: value.to(device) for key, value in inputs.items()}
+    with torch.no_grad():
+        outputs = model(**inputs)
+    logits = outputs.logits
+    predicted_class = torch.argmax(logits, dim=-1).item()
+    return predicted_class
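
A minimal usage sketch of the added module (the model path "models/text-classifier" and the sample text are placeholders for illustration, not part of this commit):

from src.model_inference import load_model, load_tokenizer, predict

# Hypothetical local checkpoint directory or Hub model ID
model = load_model("models/text-classifier")
tokenizer = load_tokenizer("models/text-classifier")

# Returns the integer index of the predicted class
label_id = predict(model, tokenizer, "Example input text", device="cpu")
print(label_id)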