|
|
|
import functools
import os

import torch
from transformers import BertTokenizer

from model.luna_model import LunaAI
|
|
|
@functools.lru_cache(maxsize=2)
def _load_model_and_tokenizer(model_path, num_classes):
    """Load the LunaAI weights and the BertTokenizer saved in *model_path*.

    Results are cached per ``(model_path, num_classes)`` so repeated calls to
    ``predict`` do not re-read the checkpoint and tokenizer from disk. If the
    files on disk are overwritten, call
    ``_load_model_and_tokenizer.cache_clear()`` to force a reload.
    """
    model = LunaAI(num_classes=num_classes)
    # map_location="cpu": inference in predict() feeds CPU tensors, and without
    # it a checkpoint saved from a GPU cannot be loaded on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file -- only load trusted
    # checkpoints (consider weights_only=True on torch >= 1.13).
    state_dict = torch.load(
        os.path.join(model_path, "pytorch_model.bin"), map_location="cpu"
    )
    model.load_state_dict(state_dict)
    model.eval()  # inference mode (e.g. disables dropout)
    tokenizer = BertTokenizer.from_pretrained(model_path)
    return model, tokenizer


def predict(text, model_path='./luna_ai_model', *, num_classes=2, max_length=None):
    """Classify *text* with the LunaAI model saved in *model_path*.

    Args:
        text: Input string to classify.
        model_path: Directory containing ``pytorch_model.bin`` (a state dict
            for ``LunaAI``) and the files for ``BertTokenizer.from_pretrained``.
        num_classes: Number of output classes the checkpoint was trained with.
        max_length: Maximum number of tokens passed to the model; longer input
            is truncated. ``None`` uses the tokenizer's ``model_max_length``.

    Returns:
        The index (int) of the highest-scoring class.
    """
    model, tokenizer = _load_model_and_tokenizer(model_path, num_classes)

    # Truncate so long inputs cannot overflow the encoder's maximum sequence length.
    encoding = tokenizer(
        text,
        truncation=True,
        max_length=max_length,
        return_tensors='pt',
    )

    with torch.no_grad():
        logits = model(encoding['input_ids'], encoding['attention_mask'])

    # Assumes the model returns scores shaped (batch=1, num_classes) -- TODO confirm
    # against LunaAI; .item() requires exactly one element after the argmax.
    return torch.argmax(logits, dim=1).item()
|
|
|
if __name__ == "__main__":
    # Manual smoke test: classify one fixed sentence with the default model
    # directory and print the predicted class index.
    result = predict("Sample text to classify")
    print(f"Prediction: {result}")
|
|