import streamlit as st
import torch

from prediction_sinhala import MDFEND, TokenizerFromPreTrained

# Saved MDFEND checkpoint and the pre-trained SinBERT model used for tokenization.
MODEL_SAVE_PATH = "models/last-epoch-model-2024-03-08-15_34_03_6.pth"
BERT_MODEL_NAME = 'sinhala-nlp/sinbert-sold-si'

# Model and tokenizer settings; these must match the configuration used during training.
DOMAIN_NUM = 3
MAX_LEN = 160
BATCH_SIZE = 100


@st.cache_resource  # cache the loaded model and tokenizer so they are built only once
def load_model():
    # Build the tokenizer from the pre-trained model with the configured maximum length.
    tokenizer = TokenizerFromPreTrained(MAX_LEN, BERT_MODEL_NAME)

    # Recreate the MDFEND architecture with the training-time hyperparameters,
    # then load the fine-tuned weights onto the CPU.
    model = MDFEND(BERT_MODEL_NAME, DOMAIN_NUM, expert_num=18,
                   mlp_dims=[5080, 4020, 3010, 2024, 1012, 606, 400])
    model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=torch.device('cpu')))
    model.eval()  # inference mode: disables dropout and other training-only behaviour

    return model, tokenizer


model, tokenizer = load_model()

# Simple UI: a text box for the input and a button that triggers the prediction.
text_input = st.text_area("Enter text here:")

if st.button("Predict"):
    if text_input:
        # Convert the raw text to token ids and add a batch dimension.
        inputs = tokenizer.tokenize(text_input)
        inputs = torch.tensor(inputs).unsqueeze(0).to(next(model.parameters()).device)

        # Run the forward pass without tracking gradients.
        with torch.no_grad():
            output_prob = model.predict(inputs)

        # Threshold the predicted probability at 0.5 to get a binary label.
        prediction = 1 if output_prob >= 0.5 else 0
        result = "offensive" if prediction == 1 else "not offensive"
        st.write(f"Prediction: {result}")
    else:
        st.error("Please enter some text to predict.")