pavlyhalim commited on
Commit
98ae6d3
·
1 Parent(s): 9f5f3fe

Add application file

Browse files
Files changed (1) hide show
  1. app.py +54 -0
app.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ # Load the tokenizer and model from HuggingFace Hub
7
+ @st.cache_resource(show_spinner=False)
8
+ def load_model():
9
+ model_name = "pavlyhalim/BERT_ALL_README"
10
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
11
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
12
+ model.eval()
13
+ return tokenizer, model
14
+
15
+ tokenizer, model = load_model()
16
+
17
+ st.title("Readability Score Predictor based on BERT")
18
+
19
+ st.write("""
20
+ Enter a sentence, and the model will predict its readability score (from 1 to 6).
21
+ """)
22
+
23
+ user_input = st.text_area("Enter your sentence here:", height=100)
24
+
25
+ if st.button("Predict Readability Score"):
26
+ if user_input.strip() == "":
27
+ st.warning("Please enter a sentence.")
28
+ else:
29
+ with st.spinner('Predicting...'):
30
+ inputs = tokenizer(
31
+ user_input,
32
+ return_tensors="pt",
33
+ padding=True,
34
+ truncation=True,
35
+ max_length=128
36
+ )
37
+
38
+ with torch.no_grad():
39
+ outputs = model(**inputs)
40
+ logits = outputs.logits
41
+
42
+ probabilities = F.softmax(logits, dim=1)
43
+ predicted_class = torch.argmax(probabilities, dim=1).item()
44
+ predicted_probability = probabilities[0][predicted_class].item()
45
+
46
+ predicted_label = predicted_class + 1
47
+
48
+ st.success(f"Predicted Readability Score: **{predicted_label}**")
49
+ st.write(f"Confidence: **{predicted_probability * 100:.2f}%**")
50
+
51
+ st.write("### Class Probabilities:")
52
+ for i, prob in enumerate(probabilities[0]):
53
+ label = i + 1
54
+ st.write(f"Score {label}: {prob.item() * 100:.2f}%")