Nopphakorn commited on
Commit
649e38c
·
1 Parent(s): 27a2b18

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -7
app.py CHANGED
@@ -4,19 +4,18 @@ import nltk
4
  import math
5
  import torch
6
 
7
- model_name = "fabiochiu/t5-base-medium-title-generation"
8
  max_input_length = 512
9
 
10
- st.header("Generate candidate titles for articles")
11
 
12
- st_model_load = st.text('Loading title generator model...')
13
 
14
  @st.cache(allow_output_mutation=True)
15
  def load_model():
16
  print("Loading model...")
17
  tokenizer = AutoTokenizer.from_pretrained(model_name)
18
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
19
- nltk.download('punkt')
20
  print("Model loaded!")
21
  return tokenizer, model
22
 
@@ -81,14 +80,14 @@ def generate_title():
81
  }
82
 
83
  # compute predictions
84
- outputs = model.generate(**inputs, do_sample=True, temperature=temperature)
85
  decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
86
  predicted_titles = [nltk.sent_tokenize(decoded_output.strip())[0] for decoded_output in decoded_outputs]
87
 
88
  st.session_state.titles = predicted_titles
89
 
90
  # generate title button
91
- st_generate_button = st.button('Generate title', on_click=generate_title)
92
 
93
  # title generation labels
94
  if 'titles' not in st.session_state:
@@ -96,6 +95,6 @@ if 'titles' not in st.session_state:
96
 
97
  if len(st.session_state.titles) > 0:
98
  with st.container():
99
- st.subheader("Generated titles")
100
  for title in st.session_state.titles:
101
  st.markdown("__" + title + "__")
 
4
  import math
5
  import torch
6
 
7
+ model_name = "Nopphakorn/mt5-small-thaisum-512-title"
8
  max_input_length = 512
9
 
10
+ st.header("Generate headline titles for Thai news")
11
 
12
+ st_model_load = st.text('Loading headlines summarizer model...')
13
 
14
  @st.cache(allow_output_mutation=True)
15
  def load_model():
16
  print("Loading model...")
17
  tokenizer = AutoTokenizer.from_pretrained(model_name)
18
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
 
19
  print("Model loaded!")
20
  return tokenizer, model
21
 
 
80
  }
81
 
82
  # compute predictions
83
+ outputs = model.generate(**inputs, do_sample=True, temperature=temperature, max_new_tokens=64)
84
  decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
85
  predicted_titles = [nltk.sent_tokenize(decoded_output.strip())[0] for decoded_output in decoded_outputs]
86
 
87
  st.session_state.titles = predicted_titles
88
 
89
  # generate title button
90
+ st_generate_button = st.button('Generate headlines', on_click=generate_title)
91
 
92
  # title generation labels
93
  if 'titles' not in st.session_state:
 
95
 
96
  if len(st.session_state.titles) > 0:
97
  with st.container():
98
+ st.subheader("Generated headlines")
99
  for title in st.session_state.titles:
100
  st.markdown("__" + title + "__")