JunhuiJi committed
Commit 10fdae2 · 1 Parent(s): feb334e

Upload app.py

Files changed (1)
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
+ import streamlit as st
+ from model import GPT2LMHeadModel
+ from transformers import BertTokenizer
+ import argparse
+ import os
+ import torch
+ import time
+ from generate_title import predict_one_sample
+
+ st.set_page_config(page_title="Demo", initial_sidebar_state="auto", layout="wide")
+
+
+ # @st.cache_data(allow_output_mutation=True)
+ def get_model(device, vocab_path, model_path):
+     # Load the tokenizer and the fine-tuned GPT-2 model, move the model to the
+     # target device and switch it to evaluation mode.
+     tokenizer = BertTokenizer.from_pretrained(vocab_path, do_lower_case=True)
+     model = GPT2LMHeadModel.from_pretrained(model_path)
+     model.to(device)
+     model.eval()
+     return tokenizer, model
+
+
+ # Select the GPU to use; fall back to CPU when CUDA is unavailable.
+ device_ids = 0
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(device_ids)
+ device = torch.device("cuda:0" if torch.cuda.is_available() and int(device_ids) >= 0 else "cpu")
+ tokenizer, model = get_model(device, "vocab.txt", "checkpoint-55922")
+
+
+ def writer():
+     st.markdown(
+         """
+         ## Text Summary DEMO
+         """
+     )
+     # Sidebar controls for the generation hyperparameters.
+     st.sidebar.subheader("Configuration parameters")
+     batch_size = st.sidebar.slider("batch_size", min_value=0, max_value=10, value=3)
+     generate_max_len = st.sidebar.number_input("generate_max_len", min_value=0, max_value=64, value=32, step=1)
+     repetition_penalty = st.sidebar.number_input("repetition_penalty", min_value=0.0, max_value=10.0, value=1.2,
+                                                  step=0.1)
+     top_k = st.sidebar.slider("top_k", min_value=0, max_value=10, value=3, step=1)
+     top_p = st.sidebar.number_input("top_p", min_value=0.0, max_value=1.0, value=0.95, step=0.01)
+
+     # Pack the sidebar values into the argparse namespace expected by predict_one_sample.
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--batch_size', default=batch_size, type=int, help='number of titles to generate')
+     parser.add_argument('--generate_max_len', default=generate_max_len, type=int, help='maximum length of the generated title')
+     parser.add_argument('--repetition_penalty', default=repetition_penalty, type=float, help='repetition penalty')
+     parser.add_argument('--top_k', default=top_k, type=int, help='number of highest-probability tokens kept during decoding')
+     parser.add_argument('--top_p', default=top_p, type=float, help='keep tokens whose cumulative probability exceeds this value during decoding')
+     parser.add_argument('--max_len', type=int, default=512, help='maximum input length fed to the model; must be smaller than n_ctx in the config')
+     args = parser.parse_args()
+
+     content = st.text_area("Input text", max_chars=512)
+     if st.button("Generate summary"):
+         start_message = st.empty()
+         start_message.write("Extracting, please wait...")
+         start_time = time.time()
+         titles = predict_one_sample(model, tokenizer, device, args, content)
+         end_time = time.time()
+         start_message.write("Extraction finished, took {}s".format(end_time - start_time))
+         for i, title in enumerate(titles):
+             st.text_input("Result {}".format(i + 1), title)
+     else:
+         st.stop()
+
+
+ if __name__ == '__main__':
+     writer()
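
For reference, a minimal standalone sketch of exercising the same pieces outside the Streamlit UI. It assumes the repository's model.py, generate_title.py, vocab.txt and checkpoint-55922 are available exactly as referenced in app.py above, and that predict_one_sample accepts the namespace fields shown; the values mirror the sidebar defaults in writer(). This is illustrative only, not part of the commit.

# Hypothetical standalone usage sketch: summarize one text without the
# Streamlit UI, reusing the repository's modules as referenced above.
from argparse import Namespace

import torch
from transformers import BertTokenizer

from model import GPT2LMHeadModel
from generate_title import predict_one_sample

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizer.from_pretrained("vocab.txt", do_lower_case=True)
model = GPT2LMHeadModel.from_pretrained("checkpoint-55922").to(device).eval()

# Mirror the sidebar defaults used in writer().
args = Namespace(batch_size=3, generate_max_len=32, repetition_penalty=1.2,
                 top_k=3, top_p=0.95, max_len=512)

titles = predict_one_sample(model, tokenizer, device, args, "text to summarize ...")
for i, title in enumerate(titles):
    print("Result {}: {}".format(i + 1, title))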