Upload 13 files
- utils/__init__.py +0 -0
- utils/arguments.py +56 -0
- utils/data_prep.py +433 -0
- utils/fs_utils.py +110 -0
- utils/input_process.py +233 -0
- utils/load_model.py +29 -0
- utils/ner_utils.py +311 -0
- utils/train_model.py +264 -0
- web/confirm.html +111 -0
- web/final.html +63 -0
- web/index.html +40 -0
- web/put.html +53 -0
- web/user.html +34 -0
utils/__init__.py
ADDED
File without changes
utils/arguments.py
ADDED
@@ -0,0 +1,56 @@
"""
A
"""
# Predefined arguments
from argparse import ArgumentParser

# User-defined variables
ROOT_DIR = ""  # Project root directory path
BERT_PRETRAINED_DIR = "klue/roberta-large"  # Pretrained BERT model directory path
DATA_PREFIX = "data"  # Parent directory of the data files
CHECKPOINT_DIR = 'model'  # Model checkpoint directory path
LOG_PATH = 'logs'  # Training log directory path


def get_train_args():
    """
    Set up the training arguments.
    """
    parser = ArgumentParser(description='I_S', allow_abbrev=False)

    # Argument parsing
    parser.add_argument('--model_name', type=str, default='KCSN')

    # Model settings
    parser.add_argument('--pooling_type', type=str, default='max_pooling')
    parser.add_argument('--classifier_intermediate_dim', type=int, default=100)
    parser.add_argument('--nonlinear_type', type=str, default='tanh')

    # BERT settings
    parser.add_argument('--bert_pretrained_dir', type=str, default=BERT_PRETRAINED_DIR)

    # Training settings
    parser.add_argument('--margin', type=float, default=1.0)
    parser.add_argument('--lr', type=float, default=2e-5)
    parser.add_argument('--optimizer', type=str, default='adam')
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--num_epochs', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--lr_decay', type=float, default=0.95)
    parser.add_argument('--patience', type=int, default=10)

    # Train, dev, and test data file paths
    parser.add_argument('--train_file', type=str, default=f'{DATA_PREFIX}/train_unsplit.txt')
    parser.add_argument('--dev_file', type=str, default=f'{DATA_PREFIX}/dev_unsplit.txt')
    parser.add_argument('--test_file', type=str, default=f'{DATA_PREFIX}/test_unsplit.txt')
    parser.add_argument('--name_list_path', type=str, default=f'{DATA_PREFIX}/name_list.txt')
    parser.add_argument('--ws', type=int, default=10)  # Window size

    parser.add_argument('--length_limit', type=int, default=510)  # Sequence length limit

    # Checkpoint and log directories
    parser.add_argument('--checkpoint_dir', type=str, default=CHECKPOINT_DIR)
    parser.add_argument('--training_logs', type=str, default=LOG_PATH)

    args, _ = parser.parse_known_args()

    return args
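
For orientation, a minimal usage sketch (not part of the commit) showing how the parsed defaults surface:

# Hypothetical usage sketch: inspect a few of the defaults parsed above.
from utils.arguments import get_train_args

args = get_train_args()
print(args.model_name)   # 'KCSN'
print(args.lr, args.ws)  # 2e-05 10
print(args.train_file)   # 'data/train_unsplit.txt'
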
utils/data_prep.py
ADDED
@@ -0,0 +1,433 @@
"""
Author:
"""
import copy
from typing import Any
from ckonlpy.tag import Twitter
from tqdm import tqdm
import re

import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

# Morphological analyzer that lets users add words to its dictionary
# (names listed in name_list are registered later so they can be recognized and segmented)
twitter = Twitter()


def load_data(filename) -> Any:
    """
    Load data from the given file.
    """
    return torch.load(filename)


def NML(seg_sents, mention_positions, ws):
    """
    Nearest Mention Location: among all positions where a candidate speaker is
    mentioned, find the mention closest to the quote.

    Parameters:
    - seg_sents: list of segmented sentences
    - mention_positions: every position where the candidate is mentioned,
      as [(sentence_index, word_index), ...]
    - ws: number of sentences considered before and after the quote

    Returns:
    - (sentence_index, word_index) of the nearest mention
    """
    def word_dist(pos):
        """
        Word-level distance between a mention of the candidate's name and the quote.

        Parameters:
        - pos: position of the mention (sentence_index, word_index)

        Returns:
        - distance between the mention and the quote (word level)
        """
        if pos[0] == ws:
            w_d = ws * 2
        elif pos[0] < ws:
            w_d = sum(len(
                sent) for sent in seg_sents[pos[0] + 1:ws]) + len(seg_sents[pos[0]][pos[1] + 1:])
        else:
            w_d = sum(
                len(sent) for sent in seg_sents[ws + 1:pos[0]]) + len(seg_sents[pos[0]][:pos[1]])
        return w_d

    # Sort the mention positions by their distance to the quote, nearest first
    sorted_positions = sorted(mention_positions, key=lambda x: word_dist(x))

    # Return the Nearest Mention Location
    return sorted_positions[0]


def max_len_cut(seg_sents, mention_pos, max_len):
    """
    Truncate the given sentences to the maximum length (max_len) the model accepts.

    Parameters:
    - seg_sents: list of segmented sentences
    - mention_pos: position of the candidate mention (sentence_index, word_index)
    - max_len: maximum input length

    Returns:
    - seg_sents: the sentences left after the cut
    - mention_pos: the adjusted mention position
    """

    # Character-level length of each sentence
    sent_char_lens = [sum(len(word) for word in sent) for sent in seg_sents]

    # Total character length
    sum_char_len = sum(sent_char_lens)

    # For each sentence, the position where the next cut happens (the last word)
    running_cut_idx = [len(sent) - 1 for sent in seg_sents]

    while sum_char_len > max_len:
        max_len_sent_idx = max(list(enumerate(sent_char_lens)), key=lambda x: x[1])[0]

        if max_len_sent_idx == mention_pos[0] and running_cut_idx[max_len_sent_idx] == mention_pos[1]:
            running_cut_idx[max_len_sent_idx] -= 1

        if max_len_sent_idx == mention_pos[0] and running_cut_idx[max_len_sent_idx] < mention_pos[1]:
            mention_pos[1] -= 1

        reduced_char_len = len(
            seg_sents[max_len_sent_idx][running_cut_idx[max_len_sent_idx]])
        sent_char_lens[max_len_sent_idx] -= reduced_char_len
        sum_char_len -= reduced_char_len

        # Delete the word at the cut position
        del seg_sents[max_len_sent_idx][running_cut_idx[max_len_sent_idx]]

        # Update the cut position
        running_cut_idx[max_len_sent_idx] -= 1

    return seg_sents, mention_pos


def seg_and_mention_location(raw_sents_in_list, alias2id):
    """
    Segment the given sentences and find every position where a speaker's name is mentioned.

    Parameters:
    - raw_sents_in_list: list of raw sentences to segment
    - alias2id: dictionary mapping each character's names (and aliases) to an ID

    Returns:
    - seg_sents: the sentences segmented into words
    - character_mention_poses: dictionary holding every mention position per character
      {character1_id: [[sent_idx, word_idx], ...]}
    - name_list_index: list of the mentioned character IDs
    """

    character_mention_poses = {}
    seg_sents = []
    id_pattern = ['&C{:02d}&'.format(i) for i in range(51)]

    for sent_idx, sent in enumerate(raw_sents_in_list):
        raw_sent_with_split = sent.split()

        for word_idx, word in enumerate(raw_sent_with_split):
            match = re.search(r'&C\d{1,2}&', word)

            # If the word holds a name in &C00& form, bind it to `result`
            if match:
                result = match.group(0)

                if alias2id[result] in character_mention_poses:
                    character_mention_poses[alias2id[result]].append([sent_idx, word_idx])
                else:
                    character_mention_poses[alias2id[result]] = [[sent_idx, word_idx]]

        seg_sents.append(raw_sent_with_split)

    name_list_index = list(character_mention_poses.keys())

    return seg_sents, character_mention_poses, name_list_index


def create_CSS(seg_sents, candidate_mention_poses, args):
    """
    Create candidate-specific segments (CSS) for each speaker candidate within an instance.

    parameters:
        seg_sents: list of 2 * ws + 1 sentences, each one segmented
        candidate_mention_poses: dictionary of mention positions per candidate, of the form
            {character index: [[sentence index, word index in sentence] of mention 1, ...] ...}.
        args: object holding the run arguments

    return:
        Returned contents are in lists, in which each element corresponds to a candidate.
        The order of candidates is consistent with that in list(candidate_mention_poses.keys()).
        many_css: candidate-specific segments (CSS) of each candidate.
        many_sent_char_len: character-level sentence lengths within each CSS,
            [[character-level length of sentence 1, ...] of the CSS of candidate 1, ...].
        many_mention_pos: position of the mention nearest to the quote within each CSS,
            [(sentence-level index of nearest mention in CSS,
              character-level index of the leftmost character of nearest mention in CSS,
              character-level index of the rightmost character + 1) of candidate 1, ...].
        many_quote_idx: sentence index of the quote within each CSS.
        many_cut_css: each CSS after the maximum-length limit has been applied.
    """
    ws = args.ws
    max_len = args.length_limit
    model_name = args.model_name

    # assert len(seg_sents) == ws * 2 + 1

    many_css = []
    many_sent_char_lens = []
    many_mention_poses = []
    many_quote_idxes = []
    many_cut_css = []

    for candidate_idx in candidate_mention_poses.keys():
        nearest_pos = NML(seg_sents, candidate_mention_poses[candidate_idx], ws)

        if nearest_pos[0] <= ws:
            CSS = copy.deepcopy(seg_sents[nearest_pos[0]:ws + 1])
            mention_pos = [0, nearest_pos[1]]
            quote_idx = ws - nearest_pos[0]
        else:
            CSS = copy.deepcopy(seg_sents[ws:nearest_pos[0] + 1])
            mention_pos = [nearest_pos[0] - ws, nearest_pos[1]]
            quote_idx = 0

        cut_CSS, mention_pos = max_len_cut(CSS, mention_pos, max_len)
        sent_char_lens = [sum(len(word) for word in sent) for sent in cut_CSS]

        mention_pos_left = sum(sent_char_lens[:mention_pos[0]]) + sum(
            len(x) for x in cut_CSS[mention_pos[0]][:mention_pos[1]])
        mention_pos_right = mention_pos_left + len(cut_CSS[mention_pos[0]][mention_pos[1]])

        if model_name == 'CSN':
            mention_pos = (mention_pos[0], mention_pos_left, mention_pos_right)
            cat_CSS = ''.join([''.join(sent) for sent in cut_CSS])
        elif model_name == 'KCSN':
            mention_pos = (mention_pos[0], mention_pos_left, mention_pos_right, mention_pos[1])
            cat_CSS = ' '.join([' '.join(sent) for sent in cut_CSS])

        many_css.append(cat_CSS)
        many_sent_char_lens.append(sent_char_lens)
        many_mention_poses.append(mention_pos)
        many_quote_idxes.append(quote_idx)
        many_cut_css.append(cut_CSS)

    return many_css, many_sent_char_lens, many_mention_poses, many_quote_idxes, many_cut_css


class ISDataset(Dataset):
    """
    Dataset subclass for speaker identification.
    """
    def __init__(self, data_list):
        super(ISDataset, self).__init__()
        self.data = data_list

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def build_data_loader(data_file, alias2id, args, save_name=None) -> DataLoader:
    """
    Build the data loader for training.
    """
    # Add the names to the analyzer's dictionary
    for alias in alias2id:
        twitter.add_dictionary(alias, 'Noun')

    # Read the file line by line
    with open(data_file, 'r', encoding='utf-8') as fin:
        data_lines = fin.readlines()

    # Preprocessing
    data_list = []

    for i, line in enumerate(tqdm(data_lines)):
        offset = i % 31

        if offset == 0:
            instance_index = line.strip().split()[-1]
            raw_sents_in_list = []
            continue

        if offset < 22:
            raw_sents_in_list.append(line.strip())

        if offset == 22:
            speaker_name = line.strip().split()[-1]

            # Drop empty lines
            filtered_list = [li for li in raw_sents_in_list if li]

            # Sentence segmentation and character-mention extraction
            seg_sents, candidate_mention_poses, name_list_index = seg_and_mention_location(
                filtered_list, alias2id)

            # Create the CSS
            css, sent_char_lens, mention_poses, quote_idxes, cut_css = create_CSS(
                seg_sents, candidate_mention_poses, args)

            # Candidate list
            candidates_list = list(candidate_mention_poses.keys())

            # Build the one-hot label
            one_hot_label = [0 if character_idx != alias2id[speaker_name]
                             else 1 for character_idx in candidate_mention_poses.keys()]

            true_index = one_hot_label.index(1) if 1 in one_hot_label else 0

        if offset == 24:
            category = line.strip().split()[-1]

        if offset == 25:
            name = ' '.join(line.strip().split()[1:])

        if offset == 26:
            scene = line.strip().split()[-1]

        if offset == 27:
            place = line.strip().split()[-1]

        if offset == 28:
            time = line.strip().split()[-1]

        if offset == 29:
            cut_position = line.strip().split()[-1]
            data_list.append((seg_sents, css, sent_char_lens, mention_poses, quote_idxes,
                              cut_css, one_hot_label, true_index, category, name_list_index,
                              name, scene, place, time, cut_position, candidates_list,
                              instance_index))

    # Build the data loader
    data_loader = DataLoader(ISDataset(data_list), batch_size=1, collate_fn=lambda x: x[0])

    # Save the data list if a save name was given
    if save_name is not None:
        torch.save(data_list, save_name)

    return data_loader


def load_data_loader(saved_filename: str) -> DataLoader:
    """
    Load the data from the saved file and wrap it in a DataLoader.
    """
    # Load the saved data list
    data_list = load_data(saved_filename)
    return DataLoader(ISDataset(data_list), batch_size=1, collate_fn=lambda x: x[0])


def split_train_val_test(data_file, alias2id, args, save_name=None, test_size=0.2, val_size=0.1, random_state=13):
    """
    Build data loaders using the conventional validation scheme:
    split the given data file into train, validation, and test sets
    and create a DataLoader for each.

    Parameters:
    - data_file: path of the data file to split
    - alias2id: dictionary mapping character names to IDs
    - args: object holding the run arguments
    - save_name: file name under which to save the split data
    - test_size: fraction used for the test set (default: 0.2)
    - val_size: fraction used for the validation set (default: 0.1)
    - random_state: random seed (default: 13)

    Returns:
    - train_loader: training data loader
    - val_loader: validation data loader
    - test_loader: test data loader
    """

    # Add the names to the analyzer's dictionary
    for alias in alias2id:
        twitter.add_dictionary(alias, 'Noun')

    # Load the instances from the file
    with open(data_file, 'r', encoding='utf-8') as fin:
        data_lines = fin.readlines()

    # Preprocessing
    data_list = []

    for i, line in enumerate(tqdm(data_lines)):
        offset = i % 31

        if offset == 0:
            instance_index = line.strip().split()[-1]
            raw_sents_in_list = []
            continue

        if offset < 22:
            raw_sents_in_list.append(line.strip())

        if offset == 22:
            speaker_name = line.strip().split()[-1]

            # Drop empty lines
            filtered_list = [li for li in raw_sents_in_list if li]

            # Sentence segmentation and character-mention extraction
            seg_sents, candidate_mention_poses, name_list_index = seg_and_mention_location(
                filtered_list, alias2id)

            # Create the CSS
            css, sent_char_lens, mention_poses, quote_idxes, cut_css = create_CSS(
                seg_sents, candidate_mention_poses, args)

            # Candidate list
            candidates_list = list(candidate_mention_poses.keys())

            # Build the one-hot label
            one_hot_label = [0 if character_idx != alias2id[speaker_name]
                             else 1 for character_idx in candidate_mention_poses.keys()]

            true_index = one_hot_label.index(1) if 1 in one_hot_label else 0

        if offset == 24:
            category = line.strip().split()[-1]

        if offset == 25:
            name = ' '.join(line.strip().split()[1:])

        if offset == 26:
            scene = line.strip().split()[-1]

        if offset == 27:
            place = line.strip().split()[-1]

        if offset == 28:
            time = line.strip().split()[-1]

        if offset == 29:
            cut_position = line.strip().split()[-1]
            data_list.append((seg_sents, css, sent_char_lens, mention_poses, quote_idxes,
                              cut_css, one_hot_label, true_index, category, name_list_index,
                              name, scene, place, time, cut_position, candidates_list,
                              instance_index))

    # Split the data into train, validation, and test sets
    train_data, test_data = train_test_split(
        data_list, test_size=test_size, random_state=random_state)
    train_data, val_data = train_test_split(
        train_data, test_size=val_size, random_state=random_state)

    # Train DataLoader
    train_loader = DataLoader(ISDataset(train_data), batch_size=1, collate_fn=lambda x: x[0])

    # Validation DataLoader
    val_loader = DataLoader(ISDataset(val_data), batch_size=1, collate_fn=lambda x: x[0])

    # Test DataLoader
    test_loader = DataLoader(ISDataset(test_data), batch_size=1, collate_fn=lambda x: x[0])

    if save_name is not None:
        # Save each split
        torch.save(train_data, save_name.replace(".pt", "_train.pt"))
        torch.save(val_data, save_name.replace(".pt", "_val.pt"))
        torch.save(test_data, save_name.replace(".pt", "_test.pt"))

    return train_loader, val_loader, test_loader
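
A rough wiring sketch (not part of the commit) showing how these loaders are meant to be driven, assuming the default paths from utils/arguments.py exist on disk:

# Hypothetical wiring of the pieces above.
from utils.arguments import get_train_args
from utils.fs_utils import get_alias2id
from utils.data_prep import split_train_val_test

args = get_train_args()
alias2id = get_alias2id(args.name_list_path)
train_loader, val_loader, test_loader = split_train_val_test(
    args.train_file, alias2id, args, save_name='data/instances.pt')

# batch_size=1 with an identity collate_fn, so each batch is one raw instance tuple.
seg_sents, css, sent_char_lens, *rest = next(iter(train_loader))
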
utils/fs_utils.py
ADDED
@@ -0,0 +1,110 @@
"""
Utility functions for the speaker-finding model.
"""
class InputFeatures:
    """
    Inputs to the BERT model.
    """
    def __init__(self, tokens, input_ids, input_mask, input_type_ids):
        self.tokens = tokens
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.input_type_ids = input_type_ids


def convert_examples_to_features(examples, tokenizer):
    """
    Convert text segments into word IDs.
    """
    features = []
    tokens_list = []

    for (ex_index, example) in enumerate(examples):
        tokens = tokenizer.tokenize(example)
        tokens_list.append(tokens)

        new_tokens = []
        input_type_ids = []

        new_tokens.append("[CLS]")
        input_type_ids.append(0)
        new_tokens = new_tokens + tokens
        input_type_ids = input_type_ids + [0] * len(tokens)
        new_tokens.append("[SEP]")
        input_type_ids.append(0)

        input_ids = tokenizer.convert_tokens_to_ids(new_tokens)
        input_mask = [1] * len(input_ids)

        features.append(
            InputFeatures(
                tokens=new_tokens,
                input_ids=input_ids,
                input_mask=input_mask,
                input_type_ids=input_type_ids))

    return features, tokens_list


def get_alias2id(name_list_path) -> dict:
    """
    Build a dictionary mapping each alias to an ID from the given name-list file.
    """
    with open(name_list_path, 'r', encoding='utf-8') as fin:
        name_lines = fin.readlines()
    alias2id = {}

    for i, line in enumerate(name_lines):
        for alias in line.strip().split()[1:]:
            alias2id[alias] = i

    return alias2id


def find_speak(fs_model, input_data, tokenizer, alias2id):
    """
    Find the speaker of each input using the given model and input data.
    """
    model = fs_model
    check_data_iter = iter(input_data)

    names = []

    for _ in range(len(input_data)):

        seg_sents, css, scl, mp, qi, cut_css, name_list_index = next(check_data_iter)
        features, tokens_list = convert_examples_to_features(examples=css, tokenizer=tokenizer)

        try:
            predictions = model(features, scl, mp, qi, 0, "cuda:0", tokens_list, cut_css)
        except RuntimeError:
            predictions = model(features, scl, mp, qi, 0, "cpu", tokens_list, cut_css)

        scores, _, _ = predictions

        # Post-processing
        try:
            scores_np = scores.detach().cpu().numpy()
            scores_list = scores_np.tolist()
            score_index = scores_list.index(max(scores_list))
            name_index = name_list_index[score_index]

            for key, val in alias2id.items():
                if val == name_index:
                    result_key = key

            names.append(result_key)
        except AttributeError:
            names.append('알 수 없음')

    return names


def making_script(text, speaker: list, instance_num: list) -> list:
    """
    Build a dialogue script from the given text, the speaker list,
    and the matching line numbers.
    """
    lines = text.splitlines()
    for num, people in zip(instance_num, speaker):
        lines[num] = f'{people}: {lines[num]}'
    return lines
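
For context, get_alias2id skips the first whitespace-separated token on each line and maps every remaining token to that line's index. A sketch with a hypothetical name_list.txt (the names and exact layout here are illustrative, not taken from the repo's data):

# Hypothetical sketch of the name-list format assumed by get_alias2id.
from utils.fs_utils import get_alias2id

with open('data/name_list.txt', 'w', encoding='utf-8') as f:
    f.write('0 &C00& 한다솜 다솜\n')  # first token skipped, rest become aliases
    f.write('1 &C01& 김철수\n')

print(get_alias2id('data/name_list.txt'))
# {'&C00&': 0, '한다솜': 0, '다솜': 0, '&C01&': 1, '김철수': 1}
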
utils/input_process.py
ADDED
@@ -0,0 +1,233 @@
"""
Module that processes user input.
"""
import copy
import re

from torch.utils.data import DataLoader, Dataset


class ISDataset(Dataset):
    """
    Dataset subclass for identifying speakers.
    """
    def __init__(self, data_list):
        super(ISDataset, self).__init__()
        self.data = data_list

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def make_instance_list(text: str, ws=10) -> tuple:
    """
    Turn the input text into a list of basic instances.
    """
    lines = text.splitlines()
    max_line = len(lines)

    utterance = ['"', '“', '”']
    instance_num = []

    for idx, line in enumerate(lines):
        if any(u in line for u in utterance):
            instance_num.append(idx)

    instance = [[] for _ in range(len(instance_num))]

    for i, num in enumerate(instance_num):
        if num - ws <= 0 and num + ws + 1 < max_line:
            instance[i] += [''] * (ws - num)
            instance[i] += lines[:num + 1 + ws]
        elif num - ws <= 0 and num + ws + 1 >= max_line:
            instance[i] += [''] * (ws - num)
            instance[i] += lines
            instance[i] += [''] * (ws * 2 - len(instance[i]) + 1)
        elif num + ws + 1 >= max_line:
            instance[i] += lines[num - ws:max_line + 1]
            instance[i] += [''] * (num + ws + 1 - max_line)
        else:
            instance[i] += lines[num - ws:num + ws + 1]

    return instance, instance_num


def NML(seg_sents, mention_positions, ws):
    """
    Nearest Mention Location.
    """
    def word_dist(pos):
        """
        The word-level distance between the quote and the mention position.
        """
        if pos[0] == ws:
            w_d = ws * 2
        elif pos[0] < ws:
            w_d = sum(len(
                sent) for sent in seg_sents[pos[0] + 1:ws]) + len(seg_sents[pos[0]][pos[1] + 1:])
        else:
            w_d = sum(
                len(sent) for sent in seg_sents[ws + 1:pos[0]]) + len(seg_sents[pos[0]][:pos[1]])
        return w_d

    sorted_positions = sorted(mention_positions, key=lambda x: word_dist(x))

    return sorted_positions[0]


def max_len_cut(seg_sents, mention_pos, max_len):
    """
    Truncate seg_sents to max_len characters (same logic as utils.data_prep.max_len_cut).
    """
    sent_char_lens = [sum(len(word) for word in sent) for sent in seg_sents]
    sum_char_len = sum(sent_char_lens)

    running_cut_idx = [len(sent) - 1 for sent in seg_sents]

    while sum_char_len > max_len:
        max_len_sent_idx = max(list(enumerate(sent_char_lens)), key=lambda x: x[1])[0]

        if max_len_sent_idx == mention_pos[0] and running_cut_idx[max_len_sent_idx] == mention_pos[1]:
            running_cut_idx[max_len_sent_idx] -= 1

        if max_len_sent_idx == mention_pos[0] and running_cut_idx[max_len_sent_idx] < mention_pos[1]:
            mention_pos[1] -= 1

        reduced_char_len = len(
            seg_sents[max_len_sent_idx][running_cut_idx[max_len_sent_idx]])
        sent_char_lens[max_len_sent_idx] -= reduced_char_len
        sum_char_len -= reduced_char_len

        del seg_sents[max_len_sent_idx][running_cut_idx[max_len_sent_idx]]

        running_cut_idx[max_len_sent_idx] -= 1

    return seg_sents, mention_pos


def seg_and_mention_location(raw_sents_in_list, alias2id):
    """
    Segment the sentences and locate character mentions
    (same logic as utils.data_prep.seg_and_mention_location).
    """
    character_mention_poses = {}
    seg_sents = []
    id_pattern = ['&C{:02d}&'.format(i) for i in range(51)]

    for sent_idx, sent in enumerate(raw_sents_in_list):
        raw_sent_with_split = sent.split()

        for word_idx, word in enumerate(raw_sent_with_split):
            match = re.search(r'&C\d{1,2}&', word)

            if match:
                result = match.group(0)

                if alias2id[result] in character_mention_poses:
                    character_mention_poses[alias2id[result]].append([sent_idx, word_idx])
                else:
                    character_mention_poses[alias2id[result]] = [[sent_idx, word_idx]]

        seg_sents.append(raw_sent_with_split)

    name_list_index = list(character_mention_poses.keys())

    return seg_sents, character_mention_poses, name_list_index


def create_css(seg_sents, candidate_mention_poses, ws=10):
    """
    Create candidate-specific segments for each candidate in an instance.
    """
    # assert len(seg_sents) == ws * 2 + 1

    many_css = []
    many_sent_char_lens = []
    many_mention_poses = []
    many_quote_idxes = []
    many_cut_css = []

    for candidate_idx in candidate_mention_poses.keys():
        nearest_pos = NML(seg_sents, candidate_mention_poses[candidate_idx], ws)

        if nearest_pos[0] <= ws:
            CSS = copy.deepcopy(seg_sents[nearest_pos[0]:ws + 1])
            mention_pos = [0, nearest_pos[1]]
            quote_idx = ws - nearest_pos[0]
        else:
            CSS = copy.deepcopy(seg_sents[ws:nearest_pos[0] + 1])
            mention_pos = [nearest_pos[0] - ws, nearest_pos[1]]
            quote_idx = 0

        cut_CSS, mention_pos = max_len_cut(CSS, mention_pos, 510)
        sent_char_lens = [sum(len(word) for word in sent) for sent in cut_CSS]

        mention_pos_left = sum(sent_char_lens[:mention_pos[0]]) + sum(
            len(x) for x in cut_CSS[mention_pos[0]][:mention_pos[1]])
        mention_pos_right = mention_pos_left + len(cut_CSS[mention_pos[0]][mention_pos[1]])
        mention_pos = (mention_pos[0], mention_pos_left, mention_pos_right, mention_pos[1])
        cat_CSS = ' '.join([' '.join(sent) for sent in cut_CSS])

        many_css.append(cat_CSS)
        many_sent_char_lens.append(sent_char_lens)
        many_mention_poses.append(mention_pos)
        many_quote_idxes.append(quote_idx)
        many_cut_css.append(cut_CSS)

    return many_css, many_sent_char_lens, many_mention_poses, many_quote_idxes, many_cut_css


def input_data_loader(instances: list, alias2id) -> DataLoader:
    """
    Shape the split instances into the form the model expects.
    """
    data_list = []

    for instance in instances:
        seg_sents, candidate_mention_poses, name_list_index = seg_and_mention_location(
            instance, alias2id)
        css, sent_char_lens, mention_poses, quote_idxes, cut_css = create_css(
            seg_sents, candidate_mention_poses)

        data_list.append((seg_sents, css, sent_char_lens, mention_poses, quote_idxes,
                          cut_css, name_list_index))

    data_loader = DataLoader(ISDataset(data_list), batch_size=1, collate_fn=lambda x: x[0])

    return data_loader


def make_ner_input(text, chunk_size=500) -> list:
    """
    Split the text into chunks on newlines.
    When the text is longer than the chunk size, the tail is appended as one final
    chunk of up to chunk_size characters.
    """
    count_text = chunk_size
    max_text = len(text)
    newline_position = []

    while count_text < max_text:
        sentence = text[:count_text]
        last_newline_position = sentence.rfind('\n')
        newline_position.append(last_newline_position)
        count_text = last_newline_position + chunk_size

    split_sentences = []
    start_num = 0

    for _, num in enumerate(newline_position):
        split_sentences.append(text[start_num:num])
        start_num = num

    if max_text % chunk_size != 0:
        f_sentence = text[max_text - chunk_size:]
        first_newline_position = max_text - chunk_size + f_sentence.find('\n')
        split_sentences.append(text[first_newline_position:])

    return split_sentences


def making_script(text, speaker: list, instance_num: list) -> list:
    """
    Build the script.
    """
    lines = text.splitlines()
    for num, people in zip(instance_num, speaker):
        lines[num] = f'{people}: {lines[num]}'
    return lines
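
A rough end-to-end sketch (not part of the commit) of how these helpers fit together; the text and file path are hypothetical, and alias2id comes from utils.fs_utils.get_alias2id:

# Hypothetical pipeline sketch.
from utils.fs_utils import get_alias2id
from utils.input_process import make_instance_list, input_data_loader, make_ner_input

text = '&C00&이 말했다.\n"오늘은 날씨가 좋네."\n&C01&은 고개를 끄덕였다.\n'
alias2id = get_alias2id('data/name_list.txt')

chunks = make_ner_input(text)                       # newline-aligned chunks for NER
instances, instance_num = make_instance_list(text)  # one 2*ws+1 window per quoted line
loader = input_data_loader(instances, alias2id)     # DataLoader of CSS tuples
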
utils/load_model.py
ADDED
@@ -0,0 +1,29 @@
"""
Module that loads the models.
"""
import torch
# from .load_model import KCSN
# from .arguments import get_train_args


# args = get_train_args()

def load_ner(path='model/NER.pth'):
    """
    Load the NER model.
    """
    checkpoint = torch.load(path)
    model = checkpoint['model']
    model.load_state_dict(checkpoint['model_state_dict'])

    return model, checkpoint


# def load_fs(path='model/FS.pth'):
#     """
#     Find Speaker model.
#     """
#     model = KCSN(args)
#     checkpoint = torch.load(path)
#     model.load_state_dict(checkpoint['model_state_dict'])
#     return model, checkpoint
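
A minimal consumption sketch (not part of the commit), assuming the checkpoint stores the 'model', 'model_state_dict', 'tokenizer', and 'tag2id' entries that utils/ner_utils.py reads:

# Hypothetical usage sketch.
from utils.load_model import load_ner

model, checkpoint = load_ner('model/NER.pth')
model.eval()                         # inference only
tokenizer = checkpoint['tokenizer']  # consumed by ner_utils.ner_tokenizer
tag2id = checkpoint['tag2id']        # consumed by ner_utils.get_ner_predictions
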
utils/ner_utils.py
ADDED
@@ -0,0 +1,311 @@
"""
Code that works with the NER model.
"""
import re
import torch
import numpy as np
from collections import Counter

device = "cuda:0" if torch.cuda.is_available() else "cpu"

def ner_tokenizer(text, max_seq_length, checkpoint):
    """
    Tokenize the text for NER.
    Args:
        text: the text to process.
        max_seq_length: BERT's config allows at most 512 characters, so texts longer
            than 512 are split into several strings so the limit is never exceeded.
            Since context information matters, chunking at the longest possible
            length tends to give the best performance.
        checkpoint: the loaded NER model checkpoint.
    Return:
        ner_tokenizer_dict: dictionary with the three entries below.
            input_ids: ID of each token in the model vocabulary.
            attention_mask: whether attention is enabled for each token.
            token_type_ids: type IDs for tokens recognized as named entities.
    """
    # Load the tokenizer of the saved model.
    tokenizer = checkpoint['tokenizer']

    # Get the IDs of the special tokens for padding, sentence start, and sentence end.
    pad_token_id = tokenizer.pad_token_id
    cls_token_id = tokenizer.cls_token_id
    sep_token_id = tokenizer.sep_token_id

    # Variable holding the previous syllable.
    pre_syllable = "_"

    # Lists that will hold the tokenized results.
    input_ids = [pad_token_id] * (max_seq_length - 1)
    attention_mask = [0] * (max_seq_length - 1)
    token_type_ids = [0] * max_seq_length

    # Trim the input text to the maximum sequence length.
    text = text[:max_seq_length-2]

    # Iterate over each syllable of the text.
    for i, syllable in enumerate(text):
        if syllable == '_':
            pre_syllable = syllable
        if pre_syllable != "_":
            syllable = '##' + syllable
        pre_syllable = syllable

        # Convert the syllable to its ID in the model vocabulary and store it in input_ids.
        input_ids[i] = tokenizer.convert_tokens_to_ids(syllable)
        # Enable the attention mask at this position.
        attention_mask[i] = 1

    # Prepend cls_token_id and append sep_token_id to the input sequence.
    input_ids = [cls_token_id] + input_ids[:-1] + [sep_token_id]
    # Adjust the attention mask for the start and end tokens as well.
    attention_mask = [1] + attention_mask[:-1] + [1]

    ner_tokenizer_dict = {"input_ids": input_ids,
                          "attention_mask": attention_mask,
                          "token_type_ids": token_type_ids}

    return ner_tokenizer_dict

def get_ner_predictions(text, checkpoint):
    """
    Build the tokenized sentence (tokenized_sent) and the predicted tags (pred_tags).
    Args:
        text: the text that needs NER predictions.
        checkpoint: the loaded model checkpoint.
    Returns:
        tokenized_sent: tokenized sentence information for model input.
        pred_tags: the predicted tag of each token.
    """
    # Load the saved model.
    model = checkpoint['model']
    # Get the tag-to-ID mapping.
    tag2id = checkpoint['tag2id']
    model.to(device)
    # Replace spaces in the input text with underscores.
    text = text.replace(' ', '_')

    # Empty lists for the predictions and the true labels.
    predictions, true_labels = [], []

    # Tokenize the text with ner_tokenizer.
    tokenized_sent = ner_tokenizer(text, len(text) + 2, checkpoint)

    # Convert the tokenized results into tensors shaped for model input.
    input_ids = torch.tensor(
        tokenized_sent['input_ids']).unsqueeze(0).to(device)
    attention_mask = torch.tensor(
        tokenized_sent['attention_mask']).unsqueeze(0).to(device)
    token_type_ids = torch.tensor(
        tokenized_sent['token_type_ids']).unsqueeze(0).to(device)

    # Run under torch.no_grad() so no gradients are computed (evaluation only, no training).
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids)

    # Take the logits from the model output as NumPy arrays, and move the label IDs to CPU.
    logits = outputs['logits']
    logits = logits.detach().cpu().numpy()
    label_ids = token_type_ids.cpu().numpy()

    # Append the predicted labels.
    predictions.extend([list(p) for p in np.argmax(logits, axis=2)])
    # Append the true labels.
    true_labels.append(label_ids)

    # Convert the predicted label IDs into actual tags.
    pred_tags = [list(tag2id.keys())[p_i] for p in predictions for p_i in p]

    return tokenized_sent, pred_tags


def ner_inference(tokenized_sent, pred_tags, checkpoint, name_len=5) -> list:
    """
    Run NER and extract names plus time and place information.
    Args:
        tokenized_sent: list holding the tokenized sentence
        pred_tags: predicted tag of each token (the NER output)
        checkpoint: the loaded model checkpoint
        name_len: how many following tokens to inspect for more accurate name recognition
    Returns:
        namelist: list of extracted names (aliases included), post-processed.
        scene: dictionary of extracted places and times.
    """
    name_list = []
    speaker = ''
    tokenizer = checkpoint['tokenizer']
    scene = {'장소': [], '시간': []}
    target = ''
    c_tag = None

    for i, tag in enumerate(pred_tags):
        token = tokenizer.convert_ids_to_tokens(tokenized_sent['input_ids'][i]).replace('#', '')
        if 'PER' in tag:
            if 'B' in tag and speaker != '':
                name_list.append(speaker)
                speaker = ''
            speaker += token

        elif speaker != '' and tag != pred_tags[i-1]:
            if speaker in name_list:
                name_list.append(speaker)
            else:
                tmp = speaker
                found_name = False
                # print(f'Unsure about {speaker}; checking the following tokens.')
                for j in range(name_len):
                    if i + j < len(tokenized_sent['input_ids']):
                        token = tokenizer.convert_ids_to_tokens(
                            tokenized_sent['input_ids'][i+j]).replace('#', '')
                        tmp += token
                        # print(f'Checked up to token {j} after {speaker}: {tmp}')
                        if tmp in name_list:
                            name_list.append(tmp)
                            found_name = True
                            # print(f'{tmp} already exists in the list; added it instead of {speaker}.')
                            break

                if not found_name:
                    name_list.append(speaker)
                    # print(f'No match found; added {speaker} as-is.')
                speaker = ''

        elif tag != 'O':
            if tag.startswith('B'):
                if c_tag in ['TIM', 'DAT']:
                    scene['시간'].append(target)
                elif c_tag == 'LOC':
                    scene['장소'].append(target)
                c_tag = tag[2:]
                target = token
            else:
                target += token.replace('_', ' ')

    return name_list, scene


def make_name_list(ner_inputs, checkpoint):
    """
    Run NER over the sentences and build the name list.
    """
    name_list = []
    times = []
    places = []

    for ner_input in ner_inputs:
        tokenized_sent, pred_tags = get_ner_predictions(ner_input, checkpoint)
        names, scene = ner_inference(tokenized_sent, pred_tags, checkpoint)
        name_list.extend(names)
        times.extend(scene['시간'])
        places.extend(scene['장소'])

    return name_list, times, places


def show_name_list(name_list):
    """
    Display the name list in a user-friendly way.
    Arg:
        name_list: the list of extracted names
    Return:
        name: each name together with the number of times it appears.
    """
    name = dict(Counter(name_list))

    return name


def compare_strings(str1, str2):
    """
    Post-process the person names extracted by NER.
    If the two strings differ in length, check whether the shorter one is contained
    in the longer one.
    If they have the same length, treat them as the same name when they share two
    or more characters.
    Running this function together with combine_similar_names below groups variants
    such as '한다솜', '다솜이', and '다솜이가' into a single person.

    Args: the two strings to compare
    Return: True if the two strings are judged to be the same name, False otherwise
    """
    if len(str1) != len(str2):
        # Check whether the shorter string is contained in the longer one
        shorter, longer = (str1, str2) if len(str1) < len(str2) else (str2, str1)
        if shorter in longer:
            return True
    else:
        same_part = []
        for i in range(len(str1)):
            if str1[i] in str2:
                same_part += str1[i]
                continue
            else:
                break
        if len(same_part) >= 2:
            return True

    return False

def combine_similar_names(names_dict):
    """
    Group similar names together, based on compare_strings.
    Two-character strings are very likely given names, so they serve as the anchors.
    """
    names = names_dict.keys()
    similar_groups = [[name] for name in names if len(name) == 2]
    idx = 0
    # print(similar_groups, '\n', idx)

    for name in names:
        found = False
        for group in similar_groups:
            idx += 1
            for item in group:
                if compare_strings(name, item) and len(name) > 1:
                    found = True
                    cleaned_text = re.sub(r'(은|이)$', '', item)
                    if len(name) == len(item):
                        same_part = ''
                        # Check whether there is an exactly matching part
                        for i in range(len(name)):
                            if name[i] in item:
                                same_part += name[i]
                        if same_part not in group and cleaned_text not in group:
                            group.append(cleaned_text)
                            # print(similar_groups, '\n', idx, 'same-length case')
                    else:
                        group.append(name)
                        # print(similar_groups, '\n', idx, 'different-length case')
                    break
            if found:
                break
        if not found:
            similar_groups.append([name])

    updated_names = {tuple(name for name in group if len(name) > 1): counts for group, counts in (
        (group, sum(names_dict[name] for name in group if name != '')) for group in similar_groups)
        if len([name for name in group if len(name) > 1]) > 0}

    return updated_names

def convert_name2codename(codename2name, text):
    """Replace names with codenames using regular expressions.
    Codename numbers are assigned in descending order of frequency."""
    for n_list in codename2name.values():
        n_list.sort(key=lambda x: (len(x), x), reverse=True)

    for codename, n_list in codename2name.items():
        for subname in n_list:
            text = re.sub(subname, codename, text)

    return text


def convert_codename2name(codename2name, text):
    """Convert codenames back to names."""
    outputs = []
    for i in text:
        try:
            outputs.append(codename2name[i][0])
        except (KeyError, IndexError):
            outputs.append('알 수 없음')

    return outputs
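
Roughly how these pieces chain together (a sketch under the same checkpoint assumptions as above; the codename construction is one plausible way to feed convert_name2codename, not taken from the repo):

# Hypothetical NER post-processing flow.
from utils.input_process import make_ner_input
from utils.load_model import load_ner
from utils.ner_utils import (make_name_list, show_name_list,
                             combine_similar_names, convert_name2codename)

_, checkpoint = load_ner('model/NER.pth')
text = open('data/sample.txt', encoding='utf-8').read()

chunks = make_ner_input(text)                    # <=500-char, newline-aligned chunks
name_list, times, places = make_name_list(chunks, checkpoint)
name_counts = show_name_list(name_list)          # {name: frequency}
merged = combine_similar_names(name_counts)      # {(aliases, ...): total count}

# Codenames numbered by descending frequency, e.g. {'&C00&': ['한다솜', '다솜이']}
codename2name = {f'&C{i:02d}&': list(group)
                 for i, (group, _) in enumerate(
                     sorted(merged.items(), key=lambda kv: kv[1], reverse=True))}
tagged_text = convert_name2codename(codename2name, text)
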
utils/train_model.py
ADDED
@@ -0,0 +1,264 @@
"""
Author:
"""
import re
import torch.nn as nn
import torch.nn.functional as functional
import torch
from transformers import AutoModel
import torch.autograd as autograd

def get_nonlinear(nonlinear):
    """
    Activation function.
    """
    nonlinear_dict = {'relu': nn.ReLU(), 'tanh': nn.Tanh(),
                      'sigmoid': nn.Sigmoid(), 'softmax': nn.Softmax(dim=-1)}
    try:
        return nonlinear_dict[nonlinear]
    except KeyError:
        raise ValueError('not a valid nonlinear type!')


class SeqPooling(nn.Module):
    """
    Sequence pooling module.

    Can do max-pooling, mean-pooling and attentive-pooling on a list of sequences of different lengths.
    """
    def __init__(self, pooling_type, hidden_dim):
        super(SeqPooling, self).__init__()
        self.pooling_type = pooling_type
        self.hidden_dim = hidden_dim
        if pooling_type == 'attentive_pooling':
            self.query_vec = nn.parameter.Parameter(torch.randn(hidden_dim))

    def max_pool(self, seq):
        return seq.max(0)[0]

    def mean_pool(self, seq):
        return seq.mean(0)

    def attn_pool(self, seq):
        attn_score = torch.mm(seq, self.query_vec.view(-1, 1)).view(-1)
        attn_w = nn.Softmax(dim=0)(attn_score)
        weighted_sum = torch.mm(attn_w.view(1, -1), seq).view(-1)
        return weighted_sum

    def forward(self, batch_seq):
        pooling_fn = {'max_pooling': self.max_pool,
                      'mean_pooling': self.mean_pool,
                      'attentive_pooling': self.attn_pool}
        pooled_seq = [pooling_fn[self.pooling_type](seq) for seq in batch_seq]
        return torch.stack(pooled_seq, dim=0)


class MLP_Scorer(nn.Module):
    """
    MLP scorer module.

    A perceptron with two layers.
    """
    def __init__(self, args, classifier_input_size):
        super(MLP_Scorer, self).__init__()
        self.scorer = nn.ModuleList()
        self.scorer.append(nn.Linear(classifier_input_size, args.classifier_intermediate_dim))
        self.scorer.append(nn.Linear(args.classifier_intermediate_dim, 1))
        self.nonlinear = get_nonlinear(args.nonlinear_type)

    def forward(self, x):
        for model in self.scorer:
            x = self.nonlinear(model(x))
        return x


class KCSN(nn.Module):
    """
    Candidate Scoring Network.

    It's built on BERT with an MLP and other simple components.
    """
    def __init__(self, args):
        super(KCSN, self).__init__()
        self.args = args
        self.bert_model = AutoModel.from_pretrained(args.bert_pretrained_dir)
        self.pooling = SeqPooling(args.pooling_type, self.bert_model.config.hidden_size)
        self.mlp_scorer = MLP_Scorer(args, self.bert_model.config.hidden_size * 3)
        self.dropout = nn.Dropout(args.dropout)

    def forward(self, features, sent_char_lens, mention_poses, quote_idxes, true_index, device, tokens_list, cut_css):
        # encoding
        qs_hid = []
        ctx_hid = []
        cdd_hid = []

        unk_loc_li = []
        unk_loc = 0

        for i, (cdd_sent_char_lens, cdd_mention_pos, cdd_quote_idx) in enumerate(
                zip(sent_char_lens, mention_poses, quote_idxes)):
            unk_loc = unk_loc + 1

            bert_output = self.bert_model(
                torch.tensor([features[i].input_ids], dtype=torch.long).to(device),
                token_type_ids=None,
                attention_mask=torch.tensor([features[i].input_mask], dtype=torch.long).to(device)
            )

            modified_list = [s.replace('#', '') for s in tokens_list[i]]
            cnt = 1
            verify = 0
            num_check = 0
            num_vid = -999
            accum_char_len = [0]

            for idx, txt in enumerate(cut_css[i]):
                result_string = ''.join(txt)
                replace_dict = {']': r'\]', '[': r'\[', '?': r'\?', '-': r'\-', '!': r'\!'}
                string_processing = result_string[-7:].translate(str.maketrans(replace_dict))
                pattern = re.compile(rf'[{string_processing}]')
                cnt = 1

                if num_check == 1000:
                    accum_char_len.append(num_vid)

                num_check = 1000

                for string in modified_list:
                    string_nospace = string.replace(' ', '')
                    if len(accum_char_len) > idx + 1:
                        continue

                    for letter in string_nospace:
                        match_result = pattern.match(letter)
                        if match_result:
                            verify += 1
                            if verify == len(result_string[-7:]):
                                if cnt > accum_char_len[-1]:
                                    accum_char_len.append(cnt)
                                verify = 0
                                num_check = len(accum_char_len)
                        else:
                            verify = 0
                        cnt = cnt + 1

            if num_check == 1000:
                accum_char_len.append(num_vid)

            if -999 in accum_char_len:
                unk_loc_li.append(unk_loc)
                continue

            CSS_hid = bert_output['last_hidden_state'][0][1:sum(cdd_sent_char_lens) + 1].to(device)

            qs_hid.append(CSS_hid[accum_char_len[cdd_quote_idx]:accum_char_len[cdd_quote_idx + 1]])

            # Find the speaker span: index into the BERT-tokenized output
            cnt = 1
            cdd_mention_pos_bert_li = []
            cdd_mention_pos_unk = []
            name = cut_css[i][cdd_mention_pos[0]][cdd_mention_pos[3]]

            # Extract only the name
            cdd_pattern = re.compile(r'&C[0-5][0-9]&')
            name_process = cdd_pattern.search(name)

            # Find the candidate location in the BERT output
            pattern_unk = re.compile(r'[\[UNK\]]')

            # Once a result has been found, keep the search from running past it.
            if len(accum_char_len) < cdd_mention_pos[0] + 1:
                maxx_len = accum_char_len[len(accum_char_len) - 1]
            elif len(accum_char_len) == cdd_mention_pos[0] + 1:
                maxx_len = accum_char_len[-1] + 1000
            else:
                maxx_len = accum_char_len[cdd_mention_pos[0] + 1]

            # Scan for the mention of the candidate.
            start_name = None
            name_match = '&'
            for string in modified_list:
                string_nospace = string.replace(' ', '')
                for letter in string_nospace:
                    match_result_unk = pattern_unk.match(letter)
                    if match_result_unk:
                        cdd_mention_pos_unk.append(cnt)
                    if start_name is True:
                        name_match += letter
                    if (name_match == name_process.group(0) or letter == '&') \
                            and len(cdd_mention_pos_bert_li) < 3 \
                            and maxx_len > cnt >= accum_char_len[cdd_mention_pos[0]]:
                        # A '&' here marks (the start of) a person mention
                        start_name = True
                        if len(cdd_mention_pos_bert_li) == 1 and name_match != name_process.group(0):
                            # '&' appeared a second time without a full match: reset
                            start_name = None
                            name_match = '&'
                            cdd_mention_pos_bert_li = []
                        elif name_match == name_process.group(0):
                            # Full match: record the end position
                            cdd_mention_pos_bert_li.append(cnt)
                            start_name = None
                            name_match = '&'
                        else:
                            # Record the start position
                            cdd_mention_pos_bert_li.append(cnt - 1)
                    cnt += 1

            if len(cdd_mention_pos_bert_li) == 0 and len(cdd_mention_pos_unk) != 0:
                cdd_mention_pos_bert_li.extend([cdd_mention_pos_unk[0], cdd_mention_pos_unk[0] + 1])
            elif len(cdd_mention_pos_bert_li) != 2:
                cdd_mention_pos_bert_li = []
                cdd_mention_pos_bert_li.extend([int(cdd_mention_pos[1] * accum_char_len[-1] / sum(
                    cdd_sent_char_lens)), int(cdd_mention_pos[2] * accum_char_len[-1] / sum(
                    cdd_sent_char_lens))])
            if cdd_mention_pos_bert_li[0] == cdd_mention_pos_bert_li[1]:
                cdd_mention_pos_bert_li[1] = cdd_mention_pos_bert_li[1] + 1

            # Decide the ctx span: context information around the candidate.
            # With only a single sentence there is no separate context (zero vector).
            if len(cdd_sent_char_lens) == 1:
                ctx_hid.append(torch.zeros(1, CSS_hid.size(1)).to(device))

            # If the mention comes before the quote, take from the first sentence
            # up to just before the last one (the quote).
            elif cdd_mention_pos[0] == 0:
                ctx_hid.append(CSS_hid[:accum_char_len[-2]])

            # Otherwise the mention comes after, so take from the second sentence to the end.
            else:
                ctx_hid.append(CSS_hid[accum_char_len[1]:])

            cdd_mention_pos_bert = (cdd_mention_pos[0], cdd_mention_pos_bert_li[0],
                                    cdd_mention_pos_bert_li[1])
            cdd_hid.append(CSS_hid[cdd_mention_pos_bert[1]:cdd_mention_pos_bert[2]])

        # pooling
        if not qs_hid:
            scores = '1'
            scores_false = 1
            scores_true = 1
            return scores, scores_false, scores_true

        qs_rep = self.pooling(qs_hid).to(device)
        ctx_rep = self.pooling(ctx_hid).to(device)
        cdd_rep = self.pooling(cdd_hid).to(device)

        # concatenate
        feature_vector = torch.cat([qs_rep, ctx_rep, cdd_rep], dim=-1).to(device)

        # dropout
        feature_vector = self.dropout(feature_vector).to(device)

        # scoring
        scores = self.mlp_scorer(feature_vector).view(-1).to(device)

        for i in unk_loc_li:
            # Element to insert
            new_element = torch.tensor([-0.9000], requires_grad=True).to(device)
            # Use torch.cat() and slicing to insert an element at a specific index.
            index_to_insert = i - 1
|
258 |
+
scores = torch.cat((scores[:index_to_insert], new_element, scores[index_to_insert:]),
|
259 |
+
dim=0).to(device)
|
260 |
+
|
261 |
+
scores_false = [scores[i] for i in range(scores.size(0)) if i != true_index]
|
262 |
+
scores_true = [scores[true_index] for i in range(scores.size(0) - 1)]
|
263 |
+
|
264 |
+
return scores, scores_false, scores_true
|
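The forward pass returns `scores_true` and `scores_false` as lists of equal length, which is the shape a margin-based ranking objective expects. Below is a minimal sketch of how a training loop might consume them; the margin value, the stacking of the score lists, and the function name are illustrative assumptions, not the repository's exact training code. A caller would also need to skip the sentinel case where `forward` returns plain ints.

# Hedged sketch: pairing scores_true/scores_false with a margin ranking loss.
# The margin value and the stacking of the score lists are assumptions.
import torch
import torch.nn as nn

def ranking_loss(scores_true, scores_false, margin=1.0):
    s_true = torch.stack(scores_true)    # one copy of the true score per pair
    s_false = torch.stack(scores_false)  # one score per false candidate
    target = torch.ones_like(s_true)     # +1: the true score should rank higher
    return nn.MarginRankingLoss(margin=margin)(s_true, s_false, target)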
web/confirm.html
ADDED
@@ -0,0 +1,111 @@
<!DOCTYPE html>
<html lang="ko">

<head>
    <meta name="viewport" content="width=device-width, initial-scale=1" />
    <meta charset="utf-8">
    <title>Speakers in Text</title>
    <link rel="stylesheet" href="{{ url_for('static', path='css/put.css') }}">
    <link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet"
        integrity="sha384-4bw+/aepP/YC94hEpVNVgiZdgIC5+VKNBQNGCHeKRQN+PtmoHDEXuppvnDJzQIu9" crossorigin="anonymous">
</head>

<body>
    <!-- Load jQuery and Bootstrap JavaScript -->
    <script src="https://code.jquery.com/jquery-3.6.4.min.js"
        integrity="sha256-oP6HI/tZ1aS9sz3Jr4+6zqbc9BE/l6fLx+Vz2I+H/GL4ZiI/Z5L3hMv8w3yXdBi"
        crossorigin="anonymous"></script>
    <script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/js/bootstrap.bundle.min.js"
        integrity="sha384-3ziFjNlAXja/Yb0M7y2BmFvR3s09gRPbrCm0lF+SvL4uIboD5lv3U3BdD7dW7Y3"
        crossorigin="anonymous"></script>

    <!-- Modal display script -->
    <script>
        // Show the modal once the document is ready
        $(document).ready(function () {
            $('#exampleModal').modal("show");
        });

        // Open the add-item modal
        function openAddItemModal() {
            const addItemModal = createAddItemModal();
            openModal(addItemModal);
        }

        // Add an item when the form is submitted
        function addItem() {
            event.preventDefault();
            const newItem = document.getElementById('newItem').value;
            itemList.push(newItem);
            updateItemList();
            closeModal('addItemModal');
        }
    </script>

    <div class="background">
        <div class="header"><a href="/">
            <span class="title">Nouvel : Novel for you</span></a></div>
        <div class="empty"></div>

        <div class="box">
            <div class="subtitle">
                Detected character list
            </div>
            <div class="explain_box">
                <span class="explain">Processing is in progress. Please wait a moment.</span>
            </div>
            <div class="container mt-4">
                <ol id="itemList" class="list-group">
                    <!-- List items are added here dynamically. -->
                </ol>

                <!-- Add button
                <button class="btn btn-primary mt-3" data-toggle="modal" data-target="#addItemModal">Add</button> -->

                <form>
                    <input type="text" class="form-control" id="newItem">
                    <div class="modal-footer">
                        <button type="button" class="btn btn-secondary" data-dismiss="modal">Cancel</button>
                        <!-- Changed: type attribute switched from "button" to "submit" -->
                        <button type="submit" class="btn btn-primary" onclick="addItem()">Add</button>
                    </div>
                </form>

            </div>

            <!-- Edit/delete modal -->
            <div class="modal fade" id="editItemModal" tabindex="-1" role="dialog" aria-labelledby="editItemModalLabel"
                aria-hidden="true">
                <!-- Modal content is added here dynamically. -->
            </div>

            <!-- Add-item modal
            <div class="modal fade" id="addItemModal" tabindex="-1" role="dialog" aria-labelledby="addItemModalLabel" aria-hidden="true">
                Modal content is added here dynamically.
            </div>
            -->

            <div class="buttonbox">
                <div class="transformbox">
                    <button onclick="handleButtonClick()" class="transformButton" type="submit">
                        <span>Start</span>
                    </button>
                </div>
            </div>
            <div class="foot">
                <div class="footer-text">
                    <span>Korea University Intelligent Information SW Academy, Team 5</span>
                </div>
            </div>
        </div>
    </div>
    <script src="https://code.jquery.com/jquery-3.2.1.slim.min.js"></script>
    <script src="https://cdnjs.cloudflare.com/ajax/libs/popper.js/1.12.9/umd/popper.min.js"></script>
    <script src="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/js/bootstrap.min.js"></script>
    <script type="text/javascript" src="../static/js/confirm.js"></script>
    <script>
        document.addEventListener('DOMContentLoaded', function () {handlePageLoad();});
    </script>
</body>

</html>
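These templates resolve assets with `url_for('static', path=...)`, which is the Starlette/FastAPI `Jinja2Templates` convention. A minimal sketch of how such a page could be served is below; the app module, route path, and directory names are assumptions based on that convention, not the repository's actual server code.

# Hedged sketch: serving the templates with FastAPI + Jinja2Templates.
# Route paths and directory names are assumptions.
from fastapi import FastAPI, Request
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="web")

@app.get("/confirm")
async def confirm(request: Request):
    # Jinja2Templates requires the request object in the context.
    return templates.TemplateResponse("confirm.html", {"request": request})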
web/final.html
ADDED
@@ -0,0 +1,63 @@
<!DOCTYPE html>
<html lang="ko">

<head>
    <meta name="viewport" content="width=device-width, initial-scale=1" />
    <meta charset="utf-8">
    <title>Speakers in Text</title>
    <link rel="stylesheet" href="{{ url_for('static', path='css/finishs.css') }}">
    <link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet"
        integrity="sha384-4bw+/aepP/YC94hEpVNVgiZdgIC5+VKNBQNGCHeKRQN+PtmoHDEXuppvnDJzQIu9" crossorigin="anonymous">
</head>

<body>
    <!-- Load jQuery and Bootstrap JavaScript -->
    <script src="https://code.jquery.com/jquery-3.6.4.min.js"
        integrity="sha256-oP6HI/tZ1aS9sz3Jr4+6zqbc9BE/l6fLx+Vz2I+H/GL4ZiI/Z5L3hMv8w3yXdBi" crossorigin="anonymous"></script>
    <script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/js/bootstrap.bundle.min.js"
        integrity="sha384-3ziFjNlAXja/Yb0M7y2BmFvR3s09gRPbrCm0lF+SvL4uIboD5lv3U3BdD7dW7Y3" crossorigin="anonymous"></script>

    <div class="background">
        <div class="header"><a href="/"><span class="title">Nouvel : Novel for you</span></a></div>
        <div class="empty"></div>
        <div class="box">
            <div class="subtitle">txt file conversion result</div>
            <div class="explain_box">
                <span class="explain">The result below shows the recognized speaker for each quotation in the novel.<br>
                    The place and time of each scene were also recognized.</span>
            </div>
            <div class="TLContainer">
                <div class="timeContainer">
                    <span class="timelocation">Time: {{ time }}</span>
                </div>
                <div class="locationContainer">
                    <span class="timelocation">Place: {{ place }}</span>
                </div>
            </div>
            <div class="itemContainer" id="resultContainer">
                {% for out in output %}
                <p>{{ out }}</p>
                {% endfor %}
            </div>
            <div class="buttonContainer">
                <div class="downloadButtonContainer">
                    <button id="downloadButton">Download Text File</button>
                </div>
                <div class="homeButtonContainer">
                    <button id="homeButton" onClick="location.href='/'">Back to home</button>
                </div>
            </div>
        </div>
        <div class="foot">
            <div class="footer-text">
                <span>Korea University Intelligent Information SW Academy, Team 5</span>
            </div>
        </div>
    </div>
    <script src="https://code.jquery.com/jquery-3.2.1.slim.min.js"></script>
    <script src="https://cdnjs.cloudflare.com/ajax/libs/popper.js/1.12.9/umd/popper.min.js"></script>
    <script src="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/js/bootstrap.min.js"></script>
    <script type="text/javascript" src="../js/finish.js"></script>
</body>

</html>
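final.html expects `time`, `place`, and `output` in its render context (the variable names come straight from the template). Continuing the serving sketch above, a route might pass them like this; the route path and the example values are purely illustrative assumptions.

# Hedged sketch, continuing the FastAPI example above.
# The route path and example values are assumptions; only the
# context keys (time, place, output) come from the template.
@app.get("/final")
async def final(request: Request):
    context = {
        "request": request,
        "time": "evening",             # illustrative value only
        "place": "a seaside village",  # illustrative value only
        "output": ['&C00&: "..."', '&C01&: "..."'],
    }
    return templates.TemplateResponse("final.html", context)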
web/index.html
ADDED
@@ -0,0 +1,40 @@
<!DOCTYPE html>
<html lang="ko">

<head>
    <meta name="viewport" content="width=device-width, initial-scale=1" />
    <meta charset="utf-8">
    <title>Speakers in Text</title>
    <link rel="stylesheet" href="{{ url_for('static', path='css/indexs.css') }}">
</head>

<body>
    <div class="background">
        <div class="header">
            <a href="/">
                <span class="title">Nouvel : Novel for you</span>
            </a>
        </div>
        <div class="empty"></div>
        <div class="box">
            <p class="subtitle">About the service</p>
            <div class="explain-box">
                <span class="explain">A BERT-based model that recognizes the speakers in a text file.
                    <br>Press the Start button below to upload a text file and view the
                    speaker recognition results produced by the model.</span>
            </div>
            <div class="buttonbox">
                <button onClick="location.href='put.html'" class="start-button">
                    <span>Start</span>
                </button>
            </div>
        </div>
        <div class="foot">
            <div class="footer-text">
                <span>Korea University Intelligent Information SW Academy, Team 5</span>
            </div>
        </div>
    </div>
</body>

</html>
web/put.html
ADDED
@@ -0,0 +1,53 @@
<!DOCTYPE html>
<html lang="ko">

<head>
    <meta name="viewport" content="width=device-width, initial-scale=1" />
    <meta charset="utf-8">
    <title>Speakers in Text</title>
    <link rel="stylesheet" href="{{ url_for('static', path='css/put.css') }}">
    <link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet"
        integrity="sha384-4bw+/aepP/YC94hEpVNVgiZdgIC5+VKNBQNGCHeKRQN+PtmoHDEXuppvnDJzQIu9" crossorigin="anonymous">
</head>

<body>
    <div class="background">
        <div class="header"><a href="/"><span class="title">Nouvel : Novel for you</span></a></div>
        <div class="empty"></div>
        <div class="box">
            <div class="subtitle">Upload a file to convert</div>
            <div class="explain_box">
                <span class="explain">Upload the txt file whose speakers you want to find, then click the Convert button.</span>
            </div>
            <form id="myForm" onsubmit="return validateForm()">
                <div class="button_box">
                    <p>Click a button below to choose your user type. <br><b><i>Writer (expert) users can review the character list.</i></b></p>
                    <input type="radio" class="btn-check" name="displayOption" id="option1" autocomplete="off" value="pro">
                    <label class="btn btn-secondary" for="option1">Writer (expert)</label>

                    <input type="radio" class="btn-check" name="displayOption" id="option2" autocomplete="off" value="reader">
                    <label class="btn btn-secondary" for="option2">Reader (user)</label>
                </div>
                <div class="formbox">
                    <div class="mb-3">
                        <label for="formFileSm" class="form-label"><b>Upload an attachment</b></label>
                        <input class="form-control form-control-sm" id="formFileSm" type="file">
                    </div>
                </div>
                <div class="transformbox">
                    <button onclick="handleButtonClick()" class="transformButton" type="button">
                        <span>Convert</span>
                    </button>
                </div>
            </form>
        </div>
        <div class="foot">
            <div class="footer-text">
                <span>Korea University Intelligent Information SW Academy, Team 5</span>
            </div>
        </div>
    </div>
    <script src="../static/js/put.js"></script>
</body>

</html>
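put.html collects a user type (`displayOption`, either "pro" or "reader") and a file; the actual submission is handled by `static/js/put.js`. A hedged sketch of an upload endpoint the page could post to is below; the route path, parameter handling, and response shape are assumptions, with only the field values taken from the form above.

# Hedged sketch, continuing the FastAPI example above.
# Route path and response shape are assumptions; the field values
# ("pro"/"reader") come from put.html's radio buttons.
from fastapi import UploadFile, File, Form

@app.post("/upload")
async def upload(displayOption: str = Form(...), file: UploadFile = File(...)):
    text = (await file.read()).decode("utf-8")
    # displayOption distinguishes writer ("pro") from reader ("reader") users.
    return {"user_type": displayOption, "num_chars": len(text)}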
web/user.html
ADDED
@@ -0,0 +1,34 @@
<!DOCTYPE html>
<html lang="ko">

<head>
    <meta name="viewport" content="width=device-width, initial-scale=1" />
    <meta charset="utf-8">
    <title>Speakers in Text</title>
    <link rel="stylesheet" href="{{ url_for('static', path='css/index.css') }}">
</head>

<body>
    <div class="background">
        <div class="header"><a href="/"><span class="title">Nouvel : Novel for you</span></a></div>
        <div class="empty"></div>
        <div class="box">
            <div class="subtitle">Character list</div>
            <div class="body">

            </div>
            <div class="transformbox">
                <button onclick="handleButtonClick()" class="transformButton"><span>Convert to scenes</span></button>
            </div>
        </div>
        <div class="foot">
            <div class="footer-text">
                <span>Korea University Intelligent Information SW Academy, Team 5</span>
            </div>
        </div>
    </div>
    <script src="../static/js/user.js"></script>
</body>

</html>