Update app.py
app.py CHANGED
```diff
@@ -1,14 +1,25 @@
 import streamlit as st
-import torch
-from
-
+import torch
+from prediction_sinhala import MDFEND, TokenizerFromPreTrained
+
+
+
+# Set constants for model and tokenizer paths
+MODEL_SAVE_PATH = "models/last-epoch-model-2024-03-08-15_34_03_6.pth"
+BERT_MODEL_NAME = 'sinhala-nlp/sinbert-sold-si'
+DOMAIN_NUM = 3
+MAX_LEN = 160
+BATCH_SIZE = 100
 
 # Load model and tokenizer
 @st.cache(allow_output_mutation=True)
 def load_model():
-
-    tokenizer =
-    model
+    # Load the tokenizer from the pre-trained model name
+    tokenizer = TokenizerFromPreTrained(MAX_LEN, BERT_MODEL_NAME)
+    # Initialize and load the custom model from saved state
+    model = MDFEND(BERT_MODEL_NAME, DOMAIN_NUM, expert_num=18, mlp_dims=[5080, 4020, 3010, 2024, 1012, 606, 400])
+    model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=torch.device('cpu')))
+    model.eval()  # Set the model to evaluation mode
     return model, tokenizer
 
 model, tokenizer = load_model()
@@ -18,8 +29,20 @@ text_input = st.text_area("Enter text here:")
 
 # Prediction
 if st.button("Predict"):
-
-
-
-
-
+    if text_input:  # Check if input is not empty
+        # Process the input text through the custom tokenizer
+        inputs = tokenizer.tokenize(text_input)
+
+        # Convert to tensor, add batch dimension, and send to same device as model
+        inputs = torch.tensor(inputs).unsqueeze(0).to(model.device)
+
+        with torch.no_grad():  # No gradient computation
+            # Get model prediction
+            output_prob = model.predict(inputs)
+
+        # Interpret the output probability
+        prediction = 1 if output_prob >= 0.5 else 0
+        result = "offensive" if prediction == 1 else "not offensive"
+        st.write(f"Prediction: {result}")
+    else:
+        st.error("Please enter some text to predict.")
```
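
Note that `st.cache(allow_output_mutation=True)` is deprecated in Streamlit 1.18+ in favor of `st.cache_resource`, which is intended for unhashable shared objects such as models. A minimal sketch of the same loader under the newer API, reusing the constants and classes from the diff above:

```python
# Sketch: load_model with st.cache_resource (Streamlit >= 1.18), which replaces
# the deprecated st.cache(allow_output_mutation=True) for objects, such as
# models, that should be created once and shared across reruns.
import streamlit as st
import torch
from prediction_sinhala import MDFEND, TokenizerFromPreTrained

MODEL_SAVE_PATH = "models/last-epoch-model-2024-03-08-15_34_03_6.pth"
BERT_MODEL_NAME = 'sinhala-nlp/sinbert-sold-si'
DOMAIN_NUM = 3
MAX_LEN = 160

@st.cache_resource
def load_model():
    tokenizer = TokenizerFromPreTrained(MAX_LEN, BERT_MODEL_NAME)
    model = MDFEND(BERT_MODEL_NAME, DOMAIN_NUM, expert_num=18,
                   mlp_dims=[5080, 4020, 3010, 2024, 1012, 606, 400])
    model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=torch.device('cpu')))
    model.eval()
    return model, tokenizer
```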
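
The call `inputs.to(model.device)` assumes `MDFEND` exposes a `device` attribute; a plain `torch.nn.Module` does not define one. If the class lacks it, a common fallback is to read the device off the parameters, as in this sketch (`module_device` is a hypothetical helper, not part of the app):

```python
import torch

def module_device(model: torch.nn.Module) -> torch.device:
    # nn.Module has no .device attribute; infer it from the first parameter.
    return next(model.parameters()).device

# e.g. inputs = torch.tensor(inputs).unsqueeze(0).to(module_device(model))
```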
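
The 0.5 cut-off that maps `output_prob` to a label can also be isolated into a small helper for easier testing. This sketch assumes, as the diff does, that `model.predict` returns a single probability (a plain float or a one-element tensor); `label_from_prob` is a hypothetical name:

```python
import torch

def label_from_prob(output_prob, threshold: float = 0.5) -> str:
    # Accept either a plain float or a one-element tensor from model.predict.
    if isinstance(output_prob, torch.Tensor):
        output_prob = output_prob.item()
    return "offensive" if output_prob >= threshold else "not offensive"

assert label_from_prob(0.73) == "offensive"
assert label_from_prob(torch.tensor([0.12])) == "not offensive"
```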