Spaces:
Runtime error
Runtime error
Commit
·
649e38c
1
Parent(s):
27a2b18
Update app.py
Browse files
app.py
CHANGED
@@ -4,19 +4,18 @@ import nltk
|
|
4 |
import math
|
5 |
import torch
|
6 |
|
7 |
-
model_name = "
|
8 |
max_input_length = 512
|
9 |
|
10 |
-
st.header("Generate
|
11 |
|
12 |
-
st_model_load = st.text('Loading
|
13 |
|
14 |
@st.cache(allow_output_mutation=True)
|
15 |
def load_model():
|
16 |
print("Loading model...")
|
17 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
18 |
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
19 |
-
nltk.download('punkt')
|
20 |
print("Model loaded!")
|
21 |
return tokenizer, model
|
22 |
|
@@ -81,14 +80,14 @@ def generate_title():
|
|
81 |
}
|
82 |
|
83 |
# compute predictions
|
84 |
-
outputs = model.generate(**inputs, do_sample=True, temperature=temperature)
|
85 |
decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
86 |
predicted_titles = [nltk.sent_tokenize(decoded_output.strip())[0] for decoded_output in decoded_outputs]
|
87 |
|
88 |
st.session_state.titles = predicted_titles
|
89 |
|
90 |
# generate title button
|
91 |
-
st_generate_button = st.button('Generate
|
92 |
|
93 |
# title generation labels
|
94 |
if 'titles' not in st.session_state:
|
@@ -96,6 +95,6 @@ if 'titles' not in st.session_state:
|
|
96 |
|
97 |
if len(st.session_state.titles) > 0:
|
98 |
with st.container():
|
99 |
-
st.subheader("Generated
|
100 |
for title in st.session_state.titles:
|
101 |
st.markdown("__" + title + "__")
|
|
|
4 |
import math
|
5 |
import torch
|
6 |
|
7 |
+
model_name = "Nopphakorn/mt5-small-thaisum-512-title"
|
8 |
max_input_length = 512
|
9 |
|
10 |
+
st.header("Generate headline titles for Thai news")
|
11 |
|
12 |
+
st_model_load = st.text('Loading headlines summarizer model...')
|
13 |
|
14 |
@st.cache(allow_output_mutation=True)
|
15 |
def load_model():
|
16 |
print("Loading model...")
|
17 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
18 |
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
|
|
19 |
print("Model loaded!")
|
20 |
return tokenizer, model
|
21 |
|
|
|
80 |
}
|
81 |
|
82 |
# compute predictions
|
83 |
+
outputs = model.generate(**inputs, do_sample=True, temperature=temperature, max_new_tokens=64)
|
84 |
decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
85 |
predicted_titles = [nltk.sent_tokenize(decoded_output.strip())[0] for decoded_output in decoded_outputs]
|
86 |
|
87 |
st.session_state.titles = predicted_titles
|
88 |
|
89 |
# generate title button
|
90 |
+
st_generate_button = st.button('Generate headlines', on_click=generate_title)
|
91 |
|
92 |
# title generation labels
|
93 |
if 'titles' not in st.session_state:
|
|
|
95 |
|
96 |
if len(st.session_state.titles) > 0:
|
97 |
with st.container():
|
98 |
+
st.subheader("Generated headlines")
|
99 |
for title in st.session_state.titles:
|
100 |
st.markdown("__" + title + "__")
|