File size: 790 Bytes
226aae3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
# inference.py
import torch
from transformers import BertTokenizer
from model.luna_model import LunaAI
def predict(text, model_path='./luna_ai_model'):
    """Classify ``text`` with a saved LunaAI checkpoint.

    Args:
        text: The string to classify.
        model_path: Directory containing ``pytorch_model.bin`` and the
            files ``BertTokenizer.from_pretrained`` loads.

    Returns:
        The index of the largest value along dim 1 of the model output,
        as a Python ``int``. ``.item()`` requires a single-element result,
        so this assumes the model returns one row of scores per input
        text -- TODO confirm against ``LunaAI.forward``.
    """
    model = LunaAI(num_classes=2)

    # NOTE(security): torch.load unpickles the checkpoint file, so
    # model_path must only ever point at trusted checkpoints.
    # map_location="cpu" allows a checkpoint that was saved from a GPU to
    # be loaded on a machine without CUDA; nothing in this function moves
    # the model or the inputs off the CPU.
    state_dict = torch.load(
        f"{model_path}/pytorch_model.bin", map_location="cpu"
    )
    model.load_state_dict(state_dict)
    model.eval()

    tokenizer = BertTokenizer.from_pretrained(model_path)
    # truncation=True clips over-long text to the tokenizer's configured
    # maximum length rather than passing the model a sequence that may be
    # longer than it supports.
    encoding = tokenizer(text, return_tensors='pt', truncation=True)
    input_ids = encoding['input_ids']
    attention_mask = encoding['attention_mask']

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        output = model(input_ids, attention_mask)

    return torch.argmax(output, dim=1).item()
if __name__ == "__main__":
    # Manual smoke run: classify one fixed sentence with the default
    # model directory and print the predicted class index.
    prediction = predict("Sample text to classify")
    print(f"Prediction: {prediction}")
|