yuneun92 committed · verified
Commit bcb1848 · 1 Parent(s): f59d2d3

Upload 13 files

utils/__init__.py ADDED
File without changes
utils/arguments.py ADDED
@@ -0,0 +1,56 @@
+ """
+ Training argument definitions.
+ """
+ # Predefined arguments
+ from argparse import ArgumentParser
+
+ # User-defined variables
+ ROOT_DIR = ""  # Project root directory path
+ BERT_PRETRAINED_DIR = "klue/roberta-large"  # Path of the pretrained BERT model
+ DATA_PREFIX = "data"  # Parent directory of the data files
+ CHECKPOINT_DIR = 'model'  # Directory for model checkpoints
+ LOG_PATH = 'logs'  # Directory for training logs
+
+ def get_train_args():
+     """
+     Set up the training arguments.
+     """
+     parser = ArgumentParser(description='I_S', allow_abbrev=False)
+
+     # Argument parsing
+     parser.add_argument('--model_name', type=str, default='KCSN')
+
+     # Model settings
+     parser.add_argument('--pooling_type', type=str, default='max_pooling')
+     parser.add_argument('--classifier_intermediate_dim', type=int, default=100)
+     parser.add_argument('--nonlinear_type', type=str, default='tanh')
+
+     # BERT settings
+     parser.add_argument('--bert_pretrained_dir', type=str, default=BERT_PRETRAINED_DIR)
+
+     # Training settings
+     parser.add_argument('--margin', type=float, default=1.0)
+     parser.add_argument('--lr', type=float, default=2e-5)
+     parser.add_argument('--optimizer', type=str, default='adam')
+     parser.add_argument('--dropout', type=float, default=0.5)
+     parser.add_argument('--num_epochs', type=int, default=50)
+     parser.add_argument('--batch_size', type=int, default=16)
+     parser.add_argument('--lr_decay', type=float, default=0.95)
+     parser.add_argument('--patience', type=int, default=10)
+
+     # Train, dev, and test data file paths
+     parser.add_argument('--train_file', type=str, default=f'{DATA_PREFIX}/train_unsplit.txt')
+     parser.add_argument('--dev_file', type=str, default=f'{DATA_PREFIX}/dev_unsplit.txt')
+     parser.add_argument('--test_file', type=str, default=f'{DATA_PREFIX}/test_unsplit.txt')
+     parser.add_argument('--name_list_path', type=str, default=f'{DATA_PREFIX}/name_list.txt')
+     parser.add_argument('--ws', type=int, default=10)  # Window size
+
+     parser.add_argument('--length_limit', type=int, default=510)  # Sequence length limit
+
+     # Checkpoint and log directories
+     parser.add_argument('--checkpoint_dir', type=str, default=CHECKPOINT_DIR)
+     parser.add_argument('--training_logs', type=str, default=LOG_PATH)
+
+     args, _ = parser.parse_known_args()
+
+     return args
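
A minimal usage sketch (hypothetical; it only assumes the package is importable as `utils` from the project root):

    from utils.arguments import get_train_args

    args = get_train_args()
    print(args.model_name)           # 'KCSN'
    print(args.bert_pretrained_dir)  # 'klue/roberta-large'
    print(args.train_file)           # 'data/train_unsplit.txt'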
utils/data_prep.py ADDED
@@ -0,0 +1,433 @@
+ """
+ Author:
+ """
+ import copy
+ import re
+ from typing import Any
+
+ import torch
+ from torch.utils.data import Dataset, DataLoader
+ from sklearn.model_selection import train_test_split
+ from ckonlpy.tag import Twitter
+ from tqdm import tqdm
+
+ # Use a morphological analyzer that lets the user add words to its dictionary
+ # (names registered in name_list are added later so they can be recognized and segmented).
+ twitter = Twitter()
+
+
+ def load_data(filename) -> Any:
+     """
+     Load data from the given file.
+     """
+     return torch.load(filename)
+
+
+ def NML(seg_sents, mention_positions, ws):
+     """
+     Nearest Mention Location: among the positions where a candidate speaker is
+     mentioned, find the mention closest to the quote.
+
+     Parameters:
+     - seg_sents: list of segmented sentences
+     - mention_positions: all positions where the candidate speaker is mentioned
+       [(sentence_index, word_index), ...]
+     - ws: number of sentences considered before/after the quote
+
+     Returns:
+     - (sentence_index, word_index) of the nearest mention
+     """
+     def word_dist(pos):
+         """
+         Return the word-level distance between the quote and a mention of the
+         candidate speaker.
+
+         Parameters:
+         - pos: position of the mention (sentence_index, word_index)
+
+         Returns:
+         - word-level distance between the mention and the quote
+         """
+         if pos[0] == ws:
+             w_d = ws * 2
+         elif pos[0] < ws:
+             w_d = sum(len(
+                 sent) for sent in seg_sents[pos[0] + 1:ws]) + len(seg_sents[pos[0]][pos[1] + 1:])
+         else:
+             w_d = sum(
+                 len(sent) for sent in seg_sents[ws + 1:pos[0]]) + len(seg_sents[pos[0]][:pos[1]])
+         return w_d
+
+     # Sort the mention positions by their distance to the quote, nearest first
+     sorted_positions = sorted(mention_positions, key=word_dist)
+
+     # Return the Nearest Mention Location
+     return sorted_positions[0]
+
+
+ def max_len_cut(seg_sents, mention_pos, max_len):
+     """
+     Truncate the given sentences to the maximum length (max_len) the model accepts.
+
+     Parameters:
+     - seg_sents: list of segmented sentences
+     - mention_pos: position of the candidate mention (sentence_index, word_index)
+     - max_len: maximum input length
+
+     Returns:
+     - seg_sents: the truncated sentences
+     - mention_pos: the adjusted mention position
+     """
+
+     # Character-level length of each sentence
+     sent_char_lens = [sum(len(word) for word in sent) for sent in seg_sents]
+
+     # Total character length
+     sum_char_len = sum(sent_char_lens)
+
+     # For each sentence, the position of the word to cut next (the last one)
+     running_cut_idx = [len(sent) - 1 for sent in seg_sents]
+
+     while sum_char_len > max_len:
+         max_len_sent_idx = max(list(enumerate(sent_char_lens)), key=lambda x: x[1])[0]
+
+         if max_len_sent_idx == mention_pos[0] and running_cut_idx[max_len_sent_idx] == mention_pos[1]:
+             running_cut_idx[max_len_sent_idx] -= 1
+
+         if max_len_sent_idx == mention_pos[0] and running_cut_idx[max_len_sent_idx] < mention_pos[1]:
+             mention_pos[1] -= 1
+
+         reduced_char_len = len(
+             seg_sents[max_len_sent_idx][running_cut_idx[max_len_sent_idx]])
+         sent_char_lens[max_len_sent_idx] -= reduced_char_len
+         sum_char_len -= reduced_char_len
+
+         # Delete the word at the cut position
+         del seg_sents[max_len_sent_idx][running_cut_idx[max_len_sent_idx]]
+
+         # Update the cut position
+         running_cut_idx[max_len_sent_idx] -= 1
+
+     return seg_sents, mention_pos
+
+
+ def seg_and_mention_location(raw_sents_in_list, alias2id):
+     """
+     Segment the given sentences and find the positions where speaker names are mentioned.
+
+     Parameters:
+     - raw_sents_in_list: list of raw sentences to segment
+     - alias2id: dictionary mapping each character's names (and aliases) to an ID
+
+     Returns:
+     - seg_sents: sentences segmented into words
+     - character_mention_poses: all mention positions per character
+       {character1_id: [[sent_idx, word_idx], ...]}
+     - name_list_index: list of the mentioned character IDs
+     """
+
+     character_mention_poses = {}
+     seg_sents = []
+     id_pattern = ['&C{:02d}&'.format(i) for i in range(51)]
+
+     for sent_idx, sent in enumerate(raw_sents_in_list):
+         raw_sent_with_split = sent.split()
+
+         for word_idx, word in enumerate(raw_sent_with_split):
+             match = re.search(r'&C\d{1,2}&', word)
+
+             # If the word carries a name of the form &C00&, keep it as result
+             if match:
+                 result = match.group(0)
+
+                 if alias2id[result] in character_mention_poses:
+                     character_mention_poses[alias2id[result]].append([sent_idx, word_idx])
+                 else:
+                     character_mention_poses[alias2id[result]] = [[sent_idx, word_idx]]
+
+         seg_sents.append(raw_sent_with_split)
+
+     name_list_index = list(character_mention_poses.keys())
+
+     return seg_sents, character_mention_poses, name_list_index
+
+
+ def create_CSS(seg_sents, candidate_mention_poses, args):
+     """
+     Create candidate-specific segments (CSS) for each speaker candidate in an instance.
+
+     parameters:
+         seg_sents: list of 2 * ws + 1 segmented sentences
+         candidate_mention_poses: dictionary of mention positions per candidate, of the form
+             {character index: [[sentence index, word index in sentence] of mention 1,...]...}.
+         args: object holding the run arguments
+
+     return:
+         Returned contents are in lists, in which each element corresponds to a candidate.
+         The order of candidates is consistent with that in list(candidate_mention_poses.keys()).
+         many_css: the candidate-specific segment (CSS) of each candidate.
+         many_sent_char_len: character-level length information of each CSS
+             [[character-level length of sentence 1,...] of the CSS of candidate 1,...].
+         many_mention_pos: position of the mention nearest to the quote within each CSS
+             [(sentence-level index of nearest mention in CSS,
+               character-level index of the leftmost character of nearest mention in CSS,
+               character-level index of the rightmost character + 1) of candidate 1,...].
+         many_quote_idx: sentence index of the quote within each CSS
+         many_cut_css: each CSS after the maximum-length limit is applied
+     """
+     ws = args.ws
+     max_len = args.length_limit
+     model_name = args.model_name
+
+     # assert len(seg_sents) == ws * 2 + 1
+
+     many_css = []
+     many_sent_char_lens = []
+     many_mention_poses = []
+     many_quote_idxes = []
+     many_cut_css = []
+
+     for candidate_idx in candidate_mention_poses.keys():
+         nearest_pos = NML(seg_sents, candidate_mention_poses[candidate_idx], ws)
+
+         if nearest_pos[0] <= ws:
+             CSS = copy.deepcopy(seg_sents[nearest_pos[0]:ws + 1])
+             mention_pos = [0, nearest_pos[1]]
+             quote_idx = ws - nearest_pos[0]
+         else:
+             CSS = copy.deepcopy(seg_sents[ws:nearest_pos[0] + 1])
+             mention_pos = [nearest_pos[0] - ws, nearest_pos[1]]
+             quote_idx = 0
+
+         cut_CSS, mention_pos = max_len_cut(CSS, mention_pos, max_len)
+         sent_char_lens = [sum(len(word) for word in sent) for sent in cut_CSS]
+
+         mention_pos_left = sum(sent_char_lens[:mention_pos[0]]) + sum(
+             len(x) for x in cut_CSS[mention_pos[0]][:mention_pos[1]])
+         mention_pos_right = mention_pos_left + len(cut_CSS[mention_pos[0]][mention_pos[1]])
+
+         if model_name == 'CSN':
+             mention_pos = (mention_pos[0], mention_pos_left, mention_pos_right)
+             cat_CSS = ''.join([''.join(sent) for sent in cut_CSS])
+         elif model_name == 'KCSN':
+             mention_pos = (mention_pos[0], mention_pos_left, mention_pos_right, mention_pos[1])
+             cat_CSS = ' '.join([' '.join(sent) for sent in cut_CSS])
+
+         many_css.append(cat_CSS)
+         many_sent_char_lens.append(sent_char_lens)
+         many_mention_poses.append(mention_pos)
+         many_quote_idxes.append(quote_idx)
+         many_cut_css.append(cut_CSS)
+
+     return many_css, many_sent_char_lens, many_mention_poses, many_quote_idxes, many_cut_css
+
+
+ class ISDataset(Dataset):
+     """
+     Dataset subclass for speaker identification.
+     """
+     def __init__(self, data_list):
+         super(ISDataset, self).__init__()
+         self.data = data_list
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         return self.data[idx]
+
+
+ def build_data_loader(data_file, alias2id, args, save_name=None) -> DataLoader:
+     """
+     Build the data loader for training.
+     """
+     # Add the names to the analyzer's dictionary
+     for alias in alias2id:
+         twitter.add_dictionary(alias, 'Noun')
+
+     # Read the file line by line
+     with open(data_file, 'r', encoding='utf-8') as fin:
+         data_lines = fin.readlines()
+
+     # Preprocessing
+     data_list = []
+
+     for i, line in enumerate(tqdm(data_lines)):
+         offset = i % 31
+
+         if offset == 0:
+             instance_index = line.strip().split()[-1]
+             raw_sents_in_list = []
+             continue
+
+         if offset < 22:
+             raw_sents_in_list.append(line.strip())
+
+         if offset == 22:
+             speaker_name = line.strip().split()[-1]
+
+             # Drop empty lines
+             filtered_list = [li for li in raw_sents_in_list if li]
+
+             # Segment the sentences and extract character mention positions
+             seg_sents, candidate_mention_poses, name_list_index = seg_and_mention_location(
+                 filtered_list, alias2id)
+
+             # Create the CSS
+             css, sent_char_lens, mention_poses, quote_idxes, cut_css = create_CSS(
+                 seg_sents, candidate_mention_poses, args)
+
+             # Candidate list
+             candidates_list = list(candidate_mention_poses.keys())
+
+             # Build the one-hot label
+             one_hot_label = [0 if character_idx != alias2id[speaker_name]
+                              else 1 for character_idx in candidate_mention_poses.keys()]
+
+             true_index = one_hot_label.index(1) if 1 in one_hot_label else 0
+
+         if offset == 24:
+             category = line.strip().split()[-1]
+
+         if offset == 25:
+             name = ' '.join(line.strip().split()[1:])
+
+         if offset == 26:
+             scene = line.strip().split()[-1]
+
+         if offset == 27:
+             place = line.strip().split()[-1]
+
+         if offset == 28:
+             time = line.strip().split()[-1]
+
+         if offset == 29:
+             cut_position = line.strip().split()[-1]
+             data_list.append((seg_sents, css, sent_char_lens, mention_poses, quote_idxes,
+                               cut_css, one_hot_label, true_index, category, name_list_index,
+                               name, scene, place, time, cut_position, candidates_list,
+                               instance_index))
+     # Build the data loader
+     data_loader = DataLoader(ISDataset(data_list), batch_size=1, collate_fn=lambda x: x[0])
+
+     # Save the data list if a save name was given
+     if save_name is not None:
+         torch.save(data_list, save_name)
+
+     return data_loader
+
+
+ def load_data_loader(saved_filename: str) -> DataLoader:
+     """
+     Load data from a saved file and wrap it in a DataLoader.
+     """
+     # Load the saved data list
+     data_list = load_data(saved_filename)
+     return DataLoader(ISDataset(data_list), batch_size=1, collate_fn=lambda x: x[0])
+
+
+ def split_train_val_test(data_file, alias2id, args, save_name=None, test_size=0.2, val_size=0.1, random_state=13):
+     """
+     Build data loaders using the existing validation scheme.
+     Split the given data file into train, validation, and test sets and create a
+     DataLoader for each.
+
+     Parameters:
+     - data_file: path of the data file to split
+     - alias2id: dictionary mapping character names to IDs
+     - args: object holding the run arguments
+     - save_name: file name for saving the split data
+     - test_size: fraction of the test set (default: 0.2)
+     - val_size: fraction of the validation set (default: 0.1)
+     - random_state: random seed (default: 13)
+
+     Returns:
+     - train_loader: training data loader
+     - val_loader: validation data loader
+     - test_loader: test data loader
+     """
+
+     # Add the names to the analyzer's dictionary
+     for alias in alias2id:
+         twitter.add_dictionary(alias, 'Noun')
+
+     # Load the instances from the file
+     with open(data_file, 'r', encoding='utf-8') as fin:
+         data_lines = fin.readlines()
+
+     # Preprocessing
+     data_list = []
+
+     for i, line in enumerate(tqdm(data_lines)):
+         offset = i % 31
+
+         if offset == 0:
+             instance_index = line.strip().split()[-1]
+             raw_sents_in_list = []
+             continue
+
+         if offset < 22:
+             raw_sents_in_list.append(line.strip())
+
+         if offset == 22:
+             speaker_name = line.strip().split()[-1]
+
+             # Drop empty lines
+             filtered_list = [li for li in raw_sents_in_list if li]
+
+             # Segment the sentences and extract character mention positions
+             seg_sents, candidate_mention_poses, name_list_index = seg_and_mention_location(
+                 filtered_list, alias2id)
+
+             # Create the CSS
+             css, sent_char_lens, mention_poses, quote_idxes, cut_css = create_CSS(
+                 seg_sents, candidate_mention_poses, args)
+
+             # Candidate list
+             candidates_list = list(candidate_mention_poses.keys())
+
+             # Build the one-hot label
+             one_hot_label = [0 if character_idx != alias2id[speaker_name]
+                              else 1 for character_idx in candidate_mention_poses.keys()]
+
+             true_index = one_hot_label.index(1) if 1 in one_hot_label else 0
+
+         if offset == 24:
+             category = line.strip().split()[-1]
+
+         if offset == 25:
+             name = ' '.join(line.strip().split()[1:])
+
+         if offset == 26:
+             scene = line.strip().split()[-1]
+
+         if offset == 27:
+             place = line.strip().split()[-1]
+
+         if offset == 28:
+             time = line.strip().split()[-1]
+
+         if offset == 29:
+             cut_position = line.strip().split()[-1]
+             data_list.append((seg_sents, css, sent_char_lens, mention_poses, quote_idxes,
+                               cut_css, one_hot_label, true_index, category, name_list_index,
+                               name, scene, place, time, cut_position, candidates_list,
+                               instance_index))
+
+     # Split the data into train, validation, and test
+     train_data, test_data = train_test_split(
+         data_list, test_size=test_size, random_state=random_state)
+     train_data, val_data = train_test_split(
+         train_data, test_size=val_size, random_state=random_state)
+
+     # Build the train DataLoader
+     train_loader = DataLoader(ISDataset(train_data), batch_size=1, collate_fn=lambda x: x[0])
+
+     # Build the validation DataLoader
+     val_loader = DataLoader(ISDataset(val_data), batch_size=1, collate_fn=lambda x: x[0])
+
+     # Build the test DataLoader
+     test_loader = DataLoader(ISDataset(test_data), batch_size=1, collate_fn=lambda x: x[0])
+
+     if save_name is not None:
+         # Save each split
+         torch.save(train_data, save_name.replace(".pt", "_train.pt"))
+         torch.save(val_data, save_name.replace(".pt", "_val.pt"))
+         torch.save(test_data, save_name.replace(".pt", "_test.pt"))
+
+     return train_loader, val_loader, test_loader
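
A minimal sketch of how these pieces chain together (file names are the defaults from utils/arguments.py; `get_alias2id` lives in utils/fs_utils.py below):

    from utils.arguments import get_train_args
    from utils.fs_utils import get_alias2id
    from utils.data_prep import build_data_loader

    args = get_train_args()
    alias2id = get_alias2id(args.name_list_path)
    train_loader = build_data_loader(args.train_file, alias2id, args, save_name='train.pt')
    for instance in train_loader:  # batch_size=1; collate_fn yields the raw tuple
        seg_sents, css, *rest = instance
        break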
utils/fs_utils.py ADDED
@@ -0,0 +1,110 @@
+ """
+ Utility functions for the speaker-finding model.
+ """
+ class InputFeatures:
+     """
+     Inputs to the BERT model.
+     """
+     def __init__(self, tokens, input_ids, input_mask, input_type_ids):
+         self.tokens = tokens
+         self.input_ids = input_ids
+         self.input_mask = input_mask
+         self.input_type_ids = input_type_ids
+
+
+ def convert_examples_to_features(examples, tokenizer):
+     """
+     Convert text segments into word IDs.
+     """
+     features = []
+     tokens_list = []
+
+     for (ex_index, example) in enumerate(examples):
+         tokens = tokenizer.tokenize(example)
+         tokens_list.append(tokens)
+
+         new_tokens = []
+         input_type_ids = []
+
+         new_tokens.append("[CLS]")
+         input_type_ids.append(0)
+         new_tokens = new_tokens + tokens
+         input_type_ids = input_type_ids + [0] * len(tokens)
+         new_tokens.append("[SEP]")
+         input_type_ids.append(0)
+
+         input_ids = tokenizer.convert_tokens_to_ids(new_tokens)
+         input_mask = [1] * len(input_ids)
+
+         features.append(
+             InputFeatures(
+                 tokens=new_tokens,
+                 input_ids=input_ids,
+                 input_mask=input_mask,
+                 input_type_ids=input_type_ids))
+
+     return features, tokens_list
+
+
+ def get_alias2id(name_list_path) -> dict:
+     """
+     Build a dictionary mapping aliases to IDs from the given name list file.
+     """
+     with open(name_list_path, 'r', encoding='utf-8') as fin:
+         name_lines = fin.readlines()
+     alias2id = {}
+
+     for i, line in enumerate(name_lines):
+         for alias in line.strip().split()[1:]:
+             alias2id[alias] = i
+
+     return alias2id
+
+
+ def find_speak(fs_model, input_data, tokenizer, alias2id):
+     """
+     Find the speaker of each input using the given model and input data.
+     """
+     model = fs_model
+     check_data_iter = iter(input_data)
+
+     names = []
+
+     for _ in range(len(input_data)):
+
+         seg_sents, css, scl, mp, qi, cut_css, name_list_index = next(check_data_iter)
+         features, tokens_list = convert_examples_to_features(examples=css, tokenizer=tokenizer)
+
+         # Fall back to CPU when the GPU pass fails
+         try:
+             predictions = model(features, scl, mp, qi, 0, "cuda:0", tokens_list, cut_css)
+         except RuntimeError:
+             predictions = model(features, scl, mp, qi, 0, "cpu", tokens_list, cut_css)
+
+         scores, _, _ = predictions
+
+         # Post-processing
+         try:
+             scores_np = scores.detach().cpu().numpy()
+             scores_list = scores_np.tolist()
+             score_index = scores_list.index(max(scores_list))
+             name_index = name_list_index[score_index]
+
+             for key, val in alias2id.items():
+                 if val == name_index:
+                     result_key = key
+
+             names.append(result_key)
+         except AttributeError:
+             names.append('์•Œ ์ˆ˜ ์—†์Œ')  # "unknown"
+
+     return names
+
+
+ def making_script(text, speaker: list, instance_num: list) -> list:
+     """
+     Build a dialogue script from the given text, the speaker list, and the
+     corresponding line numbers.
+     """
+     lines = text.splitlines()
+     for num, people in zip(instance_num, speaker):
+         lines[num] = f'{people}: {lines[num]}'
+     return lines
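
For orientation, `get_alias2id` ignores the first token of each line and maps every remaining token to the 0-based line index. Since `seg_and_mention_location` looks the `&Cxx&` codes up in this dictionary, the codes must appear among those tokens. A hypothetical name_list.txt and the resulting mapping (the exact file layout is an assumption):

    # name_list.txt (hypothetical contents):
    #   0 &C00& ํ•œ๋‹ค์ • ๋‹ค์ •
    #   1 &C01& ๊น€์ฒ ์ˆ˜
    alias2id = get_alias2id('data/name_list.txt')
    # -> {'&C00&': 0, 'ํ•œ๋‹ค์ •': 0, '๋‹ค์ •': 0, '&C01&': 1, '๊น€์ฒ ์ˆ˜': 1}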
utils/input_process.py ADDED
@@ -0,0 +1,233 @@
+ """
+ Module that processes user input.
+ """
+ import copy
+ import re
+
+ from torch.utils.data import DataLoader, Dataset
+
+
+ class ISDataset(Dataset):
+     """
+     Dataset subclass for identifying speakers.
+     """
+     def __init__(self, data_list):
+         super(ISDataset, self).__init__()
+         self.data = data_list
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         return self.data[idx]
+
+
+ def make_instance_list(text: str, ws=10) -> tuple:
+     """
+     Turn the input text into basic instance lists.
+     """
+     lines = text.splitlines()
+     max_line = len(lines)
+
+     utterance = ['"', '“', '‘']
+     instance_num = []
+
+     for idx, line in enumerate(lines):
+         if any(u in line for u in utterance):
+             instance_num.append(idx)
+
+     instance = [[] for _ in range(len(instance_num))]
+
+     for i, num in enumerate(instance_num):
+         if num - ws <= 0 and num + ws + 1 < max_line:
+             instance[i] += [''] * (ws - num)
+             instance[i] += lines[:num + 1 + ws]
+         elif num - ws <= 0 and num + ws + 1 >= max_line:
+             instance[i] += [''] * (ws - num)
+             instance[i] += lines
+             instance[i] += [''] * (ws * 2 - len(instance[i]) + 1)
+         elif num + ws + 1 >= max_line:
+             instance[i] += lines[num - ws:max_line + 1]
+             instance[i] += [''] * (num + ws + 1 - max_line)
+         else:
+             instance[i] += lines[num - ws:num + ws + 1]
+
+     return instance, instance_num
+
+
+ def NML(seg_sents, mention_positions, ws):
+     """
+     Nearest Mention Location
+     """
+     def word_dist(pos):
+         """
+         The word-level distance between the quote and the mention position
+         """
+         if pos[0] == ws:
+             w_d = ws * 2
+         elif pos[0] < ws:
+             w_d = sum(len(
+                 sent) for sent in seg_sents[pos[0] + 1:ws]) + len(seg_sents[pos[0]][pos[1] + 1:])
+         else:
+             w_d = sum(
+                 len(sent) for sent in seg_sents[ws + 1:pos[0]]) + len(seg_sents[pos[0]][:pos[1]])
+         return w_d
+
+     sorted_positions = sorted(mention_positions, key=word_dist)
+
+     return sorted_positions[0]
+
+
+ def max_len_cut(seg_sents, mention_pos, max_len):
+     sent_char_lens = [sum(len(word) for word in sent) for sent in seg_sents]
+     sum_char_len = sum(sent_char_lens)
+
+     running_cut_idx = [len(sent) - 1 for sent in seg_sents]
+
+     while sum_char_len > max_len:
+         max_len_sent_idx = max(list(enumerate(sent_char_lens)), key=lambda x: x[1])[0]
+
+         if max_len_sent_idx == mention_pos[0] and running_cut_idx[max_len_sent_idx] == mention_pos[1]:
+             running_cut_idx[max_len_sent_idx] -= 1
+
+         if max_len_sent_idx == mention_pos[0] and running_cut_idx[max_len_sent_idx] < mention_pos[1]:
+             mention_pos[1] -= 1
+
+         reduced_char_len = len(
+             seg_sents[max_len_sent_idx][running_cut_idx[max_len_sent_idx]])
+         sent_char_lens[max_len_sent_idx] -= reduced_char_len
+         sum_char_len -= reduced_char_len
+
+         del seg_sents[max_len_sent_idx][running_cut_idx[max_len_sent_idx]]
+
+         running_cut_idx[max_len_sent_idx] -= 1
+
+     return seg_sents, mention_pos
+
+
+ def seg_and_mention_location(raw_sents_in_list, alias2id):
+     character_mention_poses = {}
+     seg_sents = []
+     id_pattern = ['&C{:02d}&'.format(i) for i in range(51)]
+
+     for sent_idx, sent in enumerate(raw_sents_in_list):
+         raw_sent_with_split = sent.split()
+
+         for word_idx, word in enumerate(raw_sent_with_split):
+             match = re.search(r'&C\d{1,2}&', word)
+
+             if match:
+                 result = match.group(0)
+
+                 if alias2id[result] in character_mention_poses:
+                     character_mention_poses[alias2id[result]].append([sent_idx, word_idx])
+                 else:
+                     character_mention_poses[alias2id[result]] = [[sent_idx, word_idx]]
+
+         seg_sents.append(raw_sent_with_split)
+
+     name_list_index = list(character_mention_poses.keys())
+
+     return seg_sents, character_mention_poses, name_list_index
+
+
+ def create_css(seg_sents, candidate_mention_poses, ws=10):
+     """
+     Create candidate-specific segments for each candidate in an instance.
+     """
+     # assert len(seg_sents) == ws * 2 + 1
+
+     many_css = []
+     many_sent_char_lens = []
+     many_mention_poses = []
+     many_quote_idxes = []
+     many_cut_css = []
+
+     for candidate_idx in candidate_mention_poses.keys():
+         nearest_pos = NML(seg_sents, candidate_mention_poses[candidate_idx], ws)
+
+         if nearest_pos[0] <= ws:
+             CSS = copy.deepcopy(seg_sents[nearest_pos[0]:ws + 1])
+             mention_pos = [0, nearest_pos[1]]
+             quote_idx = ws - nearest_pos[0]
+         else:
+             CSS = copy.deepcopy(seg_sents[ws:nearest_pos[0] + 1])
+             mention_pos = [nearest_pos[0] - ws, nearest_pos[1]]
+             quote_idx = 0
+
+         cut_CSS, mention_pos = max_len_cut(CSS, mention_pos, 510)
+         sent_char_lens = [sum(len(word) for word in sent) for sent in cut_CSS]
+
+         mention_pos_left = sum(sent_char_lens[:mention_pos[0]]) + sum(
+             len(x) for x in cut_CSS[mention_pos[0]][:mention_pos[1]])
+         mention_pos_right = mention_pos_left + len(cut_CSS[mention_pos[0]][mention_pos[1]])
+         mention_pos = (mention_pos[0], mention_pos_left, mention_pos_right, mention_pos[1])
+         cat_CSS = ' '.join([' '.join(sent) for sent in cut_CSS])
+
+         many_css.append(cat_CSS)
+         many_sent_char_lens.append(sent_char_lens)
+         many_mention_poses.append(mention_pos)
+         many_quote_idxes.append(quote_idx)
+         many_cut_css.append(cut_CSS)
+
+     return many_css, many_sent_char_lens, many_mention_poses, many_quote_idxes, many_cut_css
+
+
+ def input_data_loader(instances: list, alias2id) -> DataLoader:
+     """
+     Process the split data into the shape the model expects.
+     """
+     data_list = []
+
+     for instance in instances:
+         seg_sents, candidate_mention_poses, name_list_index = seg_and_mention_location(
+             instance, alias2id)
+         css, sent_char_lens, mention_poses, quote_idxes, cut_css = create_css(
+             seg_sents, candidate_mention_poses)
+
+         data_list.append((seg_sents, css, sent_char_lens, mention_poses, quote_idxes,
+                           cut_css, name_list_index))
+
+     data_loader = DataLoader(ISDataset(data_list), batch_size=1, collate_fn=lambda x: x[0])
+
+     return data_loader
+
+
+ def make_ner_input(text, chunk_size=500) -> list:
+     """
+     Split the text into chunks at newline boundaries.
+     If the text is longer than chunk_size, the final chunk is taken as the last
+     chunk_size characters counted from the end.
+     """
+     count_text = chunk_size
+     max_text = len(text)
+     newline_position = []
+
+     while count_text < max_text:
+         sentence = text[:count_text]
+         last_newline_position = sentence.rfind('\n')
+         newline_position.append(last_newline_position)
+         count_text = last_newline_position + chunk_size
+
+     split_sentences = []
+     start_num = 0
+
+     for _, num in enumerate(newline_position):
+         split_sentences.append(text[start_num:num])
+         start_num = num
+
+     if max_text % chunk_size != 0:
+         f_sentence = text[max_text - chunk_size:]
+         first_newline_position = max_text - chunk_size + f_sentence.find('\n')
+         split_sentences.append(text[first_newline_position:])
+
+     return split_sentences
+
+
+ def making_script(text, speaker: list, instance_num: list) -> list:
+     """
+     Build the script.
+     """
+     lines = text.splitlines()
+     for num, people in zip(instance_num, speaker):
+         lines[num] = f'{people}: {lines[num]}'
+     return lines
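
A small sketch of the windowing behaviour of `make_instance_list` (ws=2 here only for brevity; the real default is 10):

    text = '๋‚˜๋ ˆ์ด์…˜1\n"์ฒซ ๋Œ€์‚ฌ"\n๋‚˜๋ ˆ์ด์…˜2'
    instances, instance_num = make_instance_list(text, ws=2)
    # instance_num == [1]  -> line 1 contains a quote mark
    # each instance holds 2*ws + 1 = 5 lines, padded with '' at the edges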
utils/load_model.py ADDED
@@ -0,0 +1,29 @@
+ """
+ Module that loads the models.
+ """
+ import torch
+ # from .load_model import KCSN
+ # from .arguments import get_train_args
+
+
+ # args = get_train_args()
+
+ def load_ner(path='model/NER.pth'):
+     """
+     Load the NER model.
+     """
+     checkpoint = torch.load(path)
+     model = checkpoint['model']
+     model.load_state_dict(checkpoint['model_state_dict'])
+
+     return model, checkpoint
+
+
+ # def load_fs(path='model/FS.pth'):
+ #     """
+ #     Load the Find Speaker model.
+ #     """
+ #     model = KCSN(args)
+ #     checkpoint = torch.load(path)
+ #     model.load_state_dict(checkpoint['model_state_dict'])
+ #     return model, checkpoint
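
A minimal loading sketch (it assumes model/NER.pth was saved with both the full model object and its state_dict, exactly as `load_ner` expects; per utils/ner_utils.py the checkpoint also carries 'tokenizer' and 'tag2id'):

    from utils.load_model import load_ner

    ner_model, checkpoint = load_ner('model/NER.pth')
    ner_model.eval()
    tokenizer = checkpoint['tokenizer']  # read back out by utils/ner_utils.py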
utils/ner_utils.py ADDED
@@ -0,0 +1,311 @@
+ """
+ Code that runs tasks with the NER model.
+ """
+ import re
+ from collections import Counter
+
+ import torch
+ import numpy as np
+
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+ def ner_tokenizer(text, max_seq_length, checkpoint):
+     """
+     Tokenize text for NER.
+     Args:
+         text: the text to process.
+         max_seq_length: BERT configs handle strings of at most 512 tokens. To stay
+             within the limit, text longer than 512 is split into several strings.
+             Because context information matters, chunking at the longest possible
+             length tends to give the best performance.
+         checkpoint: carries the information about the NER model.
+     Return:
+         ner_tokenizer_dict: a dictionary with the three elements below.
+             input_ids: the vocabulary ID of each token.
+             attention_mask: whether attention is active for each token.
+             token_type_ids: for tokens recognized as named entities, the ID
+                 (numeric code) of the entity type.
+     """
+     # Load the tokenizer of the saved model.
+     tokenizer = checkpoint['tokenizer']
+
+     # Get the IDs of the special tokens for padding, sentence start, and sentence end.
+     pad_token_id = tokenizer.pad_token_id
+     cls_token_id = tokenizer.cls_token_id
+     sep_token_id = tokenizer.sep_token_id
+
+     # Initialize the variable that stores the previous syllable.
+     pre_syllable = "_"
+
+     # Initialize the lists that hold the tokenized results.
+     input_ids = [pad_token_id] * (max_seq_length - 1)
+     attention_mask = [0] * (max_seq_length - 1)
+     token_type_ids = [0] * max_seq_length
+
+     # Trim the input text to the maximum sequence length.
+     text = text[:max_seq_length - 2]
+
+     # Iterate over each syllable of the text.
+     for i, syllable in enumerate(text):
+         if syllable == '_':
+             pre_syllable = syllable
+         if pre_syllable != "_":
+             syllable = '##' + syllable
+         pre_syllable = syllable
+
+         # Convert the token to its vocabulary ID and store it in input_ids.
+         input_ids[i] = tokenizer.convert_tokens_to_ids(syllable)
+         # Activate the attention mask at this position.
+         attention_mask[i] = 1
+
+     # Prepend cls_token_id and append sep_token_id to the input sequence.
+     input_ids = [cls_token_id] + input_ids[:-1] + [sep_token_id]
+     # Adjust the attention mask for the start and end tokens as well.
+     attention_mask = [1] + attention_mask[:-1] + [1]
+
+     ner_tokenizer_dict = {"input_ids": input_ids,
+                           "attention_mask": attention_mask,
+                           "token_type_ids": token_type_ids}
+
+     return ner_tokenizer_dict
+
+ def get_ner_predictions(text, checkpoint):
+     """
+     Build the tokenized sentence (tokenized_sent) and the predicted tags (pred_tags).
+     Args:
+         text: the text that needs NER predictions.
+         checkpoint: the saved model.
+     Returns:
+         tokenized_sent: tokenized sentence information for the model input.
+         pred_tags: the predicted tag of each token.
+     """
+     # Load the saved model.
+     model = checkpoint['model']
+     # Get the mapping between tags and tag IDs.
+     tag2id = checkpoint['tag2id']
+     model.to(device)
+     # Replace spaces in the input text with underscores (_).
+     text = text.replace(' ', '_')
+
+     # Create empty lists for the predictions and the true labels.
+     predictions, true_labels = [], []
+
+     # Tokenize the text with ner_tokenizer.
+     tokenized_sent = ner_tokenizer(text, len(text) + 2, checkpoint)
+
+     # Convert the tokenized results into tensors shaped for the model input.
+     input_ids = torch.tensor(
+         tokenized_sent['input_ids']).unsqueeze(0).to(device)
+     attention_mask = torch.tensor(
+         tokenized_sent['attention_mask']).unsqueeze(0).to(device)
+     token_type_ids = torch.tensor(
+         tokenized_sent['token_type_ids']).unsqueeze(0).to(device)
+
+     # Run inside torch.no_grad() so no gradients are computed
+     # (this is evaluation, not training).
+     with torch.no_grad():
+         outputs = model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids)
+
+     # Take the logits from the model output, detach them to NumPy,
+     # and bring the label IDs to the CPU as a NumPy array.
+     logits = outputs['logits']
+     logits = logits.detach().cpu().numpy()
+     label_ids = token_type_ids.cpu().numpy()
+
+     # Collect the predicted label IDs.
+     predictions.extend([list(p) for p in np.argmax(logits, axis=2)])
+     # Collect the true labels.
+     true_labels.append(label_ids)
+
+     # Convert the predicted label IDs to the actual tags.
+     pred_tags = [list(tag2id.keys())[p_i] for p in predictions for p_i in p]
+
+     return tokenized_sent, pred_tags
+
+
+ def ner_inference(tokenized_sent, pred_tags, checkpoint, name_len=5) -> list:
+     """
+     Run NER and extract names along with time and place information.
+     Args:
+         tokenized_sent: list holding the tokenized sentence
+         pred_tags: the predicted tag of each token (the NER output)
+         checkpoint: the saved model
+         name_len: how many extra syllables to inspect around a name for more
+             accurate name recognition.
+     Returns:
+         name_list: list of extracted names (aliases included).
+         scene: dictionary of extracted places and times.
+     """
+     name_list = []
+     speaker = ''
+     tokenizer = checkpoint['tokenizer']
+     scene = {'์žฅ์†Œ': [], '์‹œ๊ฐ„': []}  # keys: place, time
+     target = ''
+     c_tag = None
+
+     for i, tag in enumerate(pred_tags):
+         token = tokenizer.convert_ids_to_tokens(tokenized_sent['input_ids'][i]).replace('#', '')
+         if 'PER' in tag:
+             if 'B' in tag and speaker != '':
+                 name_list.append(speaker)
+                 speaker = ''
+             speaker += token
+
+         elif speaker != '' and tag != pred_tags[i-1]:
+             if speaker in name_list:
+                 name_list.append(speaker)
+             else:
+                 # The name is uncertain; look ahead a few syllables to verify it.
+                 tmp = speaker
+                 found_name = False
+                 for j in range(name_len):
+                     if i + j < len(tokenized_sent['input_ids']):
+                         token = tokenizer.convert_ids_to_tokens(
+                             tokenized_sent['input_ids'][i+j]).replace('#', '')
+                         tmp += token
+                         if tmp in name_list:
+                             name_list.append(tmp)
+                             found_name = True
+                             break
+
+                 if not found_name:
+                     name_list.append(speaker)
+                 speaker = ''
+
+         elif tag != 'O':
+             if tag.startswith('B'):
+                 if c_tag in ['TIM', 'DAT']:
+                     scene['์‹œ๊ฐ„'].append(target)
+                 elif c_tag == 'LOC':
+                     scene['์žฅ์†Œ'].append(target)
+                 c_tag = tag[2:]
+                 target = token
+             else:
+                 target += token.replace('_', ' ')
+
+     return name_list, scene
+
+
+ def make_name_list(ner_inputs, checkpoint):
+     """
+     Run NER over the sentences and build the name list.
+     """
+     name_list = []
+     times = []
+     places = []
+
+     for ner_input in ner_inputs:
+         tokenized_sent, pred_tags = get_ner_predictions(ner_input, checkpoint)
+         names, scene = ner_inference(tokenized_sent, pred_tags, checkpoint)
+         name_list.extend(names)
+         times.extend(scene['์‹œ๊ฐ„'])
+         places.extend(scene['์žฅ์†Œ'])
+
+     return name_list, times, places
+
+
+ def show_name_list(name_list):
+     """
+     Show the name list in a user-friendly form.
+     Arg:
+         name_list: list of extracted names
+     Return:
+         name: each name together with how many times it appeared.
+     """
+     name = dict(Counter(name_list))
+
+     return name
+
+
+ def compare_strings(str1, str2):
+     """
+     Post-processing for person names extracted by NER.
+     When the two strings differ in length, check whether the shorter one is
+     contained in the longer one.
+     When they are the same length, treat them as the same name if they share
+     two or more characters.
+     Running this together with combine_similar_names below groups names such as
+     'ํ•œ๋‹ค์ •', '๋‹ค์ •์ด', and '๋‹ค์ •์ด๊ฐ€' into a single character.
+
+     Args: the two strings to compare
+     Return: True if the two strings are judged to be the same name, else False
+     """
+     if len(str1) != len(str2):
+         # Check whether the shorter string is contained in the longer one
+         shorter, longer = (str1, str2) if len(str1) < len(str2) else (str2, str1)
+         if shorter in longer:
+             return True
+     else:
+         same_part = []
+         for i in range(len(str1)):
+             if str1[i] in str2:
+                 same_part += str1[i]
+                 continue
+             else:
+                 break
+         if len(same_part) >= 2:
+             return True
+
+     return False
+
+ def combine_similar_names(names_dict):
+     """
+     Group similar names together based on compare_strings.
+     Two-character strings are very likely to be names, so they serve as anchors.
+     """
+     names = names_dict.keys()
+     similar_groups = [[name] for name in names if len(name) == 2]
+
+     for name in names:
+         found = False
+         for group in similar_groups:
+             for item in group:
+                 if compare_strings(name, item) and len(name) > 1:
+                     found = True
+                     cleaned_text = re.sub(r'(์•„|์ด)$', '', item)
+                     if len(name) == len(item):
+                         same_part = ''
+                         # Check for an exactly matching part
+                         for i in range(len(name)):
+                             if name[i] in item:
+                                 same_part += name[i]
+                         if same_part not in group and cleaned_text not in group:
+                             group.append(cleaned_text)
+                     else:
+                         group.append(name)
+                     break
+             if found:
+                 break
+         if not found:
+             similar_groups.append([name])
+
+     updated_names = {tuple(name for name in group if len(name) > 1): counts for group, counts in (
+         (group, sum(names_dict[name] for name in group if name != '')) for group in similar_groups)
+         if len([name for name in group if len(name) > 1]) > 0}
+
+     return updated_names
+
+ def convert_name2codename(codename2name, text):
+     """Replace names with codenames using regular expressions. Codename numbers
+     are assigned in descending order of frequency."""
+     for n_list in codename2name.values():
+         n_list.sort(key=lambda x: (len(x), x), reverse=True)
+
+     for codename, n_list in codename2name.items():
+         for subname in n_list:
+             text = re.sub(subname, codename, text)
+
+     return text
+
+
+ def convert_codename2name(codename2name, text):
+     """Convert codenames back to names."""
+     outputs = []
+     for i in text:
+         try:
+             outputs.append(codename2name[i][0])
+         except (KeyError, IndexError):
+             outputs.append('์•Œ ์ˆ˜ ์—†์Œ')  # "unknown"
+
+     return outputs
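
Putting the NER helpers together, a hedged end-to-end sketch (`novel_text` is a stand-in for the raw input string; everything else comes from the modules above):

    from utils.load_model import load_ner
    from utils.input_process import make_ner_input
    from utils.ner_utils import make_name_list, show_name_list, combine_similar_names

    model, checkpoint = load_ner('model/NER.pth')
    chunks = make_ner_input(novel_text)
    name_list, times, places = make_name_list(chunks, checkpoint)
    name_counts = show_name_list(name_list)    # e.g. {'ํ•œ๋‹ค์ •': 12, '๋‹ค์ •์ด': 3, ...}
    merged = combine_similar_names(name_counts)  # e.g. {('ํ•œ๋‹ค์ •', '๋‹ค์ •์ด'): 15, ...}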
utils/train_model.py ADDED
@@ -0,0 +1,264 @@
+ """
+ Author:
+ """
+ import re
+
+ import torch
+ import torch.nn as nn
+ from transformers import AutoModel
+
+ def get_nonlinear(nonlinear):
+     """
+     Activation function.
+     """
+     nonlinear_dict = {'relu': nn.ReLU(), 'tanh': nn.Tanh(),
+                       'sigmoid': nn.Sigmoid(), 'softmax': nn.Softmax(dim=-1)}
+     try:
+         return nonlinear_dict[nonlinear]
+     except KeyError as exc:
+         raise ValueError('not a valid nonlinear type!') from exc
+
+
+ class SeqPooling(nn.Module):
+     """
+     Sequence pooling module.
+
+     Can do max-pooling, mean-pooling, and attentive pooling on a list of
+     sequences of different lengths.
+     """
+     def __init__(self, pooling_type, hidden_dim):
+         super(SeqPooling, self).__init__()
+         self.pooling_type = pooling_type
+         self.hidden_dim = hidden_dim
+         if pooling_type == 'attentive_pooling':
+             self.query_vec = nn.parameter.Parameter(torch.randn(hidden_dim))
+
+     def max_pool(self, seq):
+         return seq.max(0)[0]
+
+     def mean_pool(self, seq):
+         return seq.mean(0)
+
+     def attn_pool(self, seq):
+         attn_score = torch.mm(seq, self.query_vec.view(-1, 1)).view(-1)
+         attn_w = nn.Softmax(dim=0)(attn_score)
+         weighted_sum = torch.mm(attn_w.view(1, -1), seq).view(-1)
+         return weighted_sum
+
+     def forward(self, batch_seq):
+         pooling_fn = {'max_pooling': self.max_pool,
+                       'mean_pooling': self.mean_pool,
+                       'attentive_pooling': self.attn_pool}
+         pooled_seq = [pooling_fn[self.pooling_type](seq) for seq in batch_seq]
+         return torch.stack(pooled_seq, dim=0)
+
+
+ class MLP_Scorer(nn.Module):
+     """
+     MLP scorer module.
+
+     A perceptron with two layers.
+     """
+     def __init__(self, args, classifier_input_size):
+         super(MLP_Scorer, self).__init__()
+         self.scorer = nn.ModuleList()
+         self.scorer.append(nn.Linear(classifier_input_size, args.classifier_intermediate_dim))
+         self.scorer.append(nn.Linear(args.classifier_intermediate_dim, 1))
+         self.nonlinear = get_nonlinear(args.nonlinear_type)
+
+     def forward(self, x):
+         for model in self.scorer:
+             x = self.nonlinear(model(x))
+         return x
+
+
+ class KCSN(nn.Module):
+     """
+     Candidate Scoring Network.
+
+     It's built on BERT with an MLP and other simple components.
+     """
+     def __init__(self, args):
+         super(KCSN, self).__init__()
+         self.args = args
+         self.bert_model = AutoModel.from_pretrained(args.bert_pretrained_dir)
+         self.pooling = SeqPooling(args.pooling_type, self.bert_model.config.hidden_size)
+         self.mlp_scorer = MLP_Scorer(args, self.bert_model.config.hidden_size * 3)
+         self.dropout = nn.Dropout(args.dropout)
+
+     def forward(self, features, sent_char_lens, mention_poses, quote_idxes, true_index, device, tokens_list, cut_css):
+         # encoding
+         qs_hid = []
+         ctx_hid = []
+         cdd_hid = []
+
+         unk_loc_li = []
+         unk_loc = 0
+
+         for i, (cdd_sent_char_lens, cdd_mention_pos, cdd_quote_idx) in enumerate(
+                 zip(sent_char_lens, mention_poses, quote_idxes)):
+             unk_loc = unk_loc + 1
+
+             bert_output = self.bert_model(
+                 torch.tensor([features[i].input_ids], dtype=torch.long).to(device),
+                 token_type_ids=None,
+                 attention_mask=torch.tensor([features[i].input_mask], dtype=torch.long).to(device)
+             )
+
+             modified_list = [s.replace('#', '') for s in tokens_list[i]]
+             cnt = 1
+             verify = 0
+             num_check = 0
+             num_vid = -999
+             accum_char_len = [0]
+
+             for idx, txt in enumerate(cut_css[i]):
+                 result_string = ''.join(txt)
+                 replace_dict = {']': r'\]', '[': r'\[', '?': r'\?', '-': r'\-', '!': r'\!'}
+                 string_processing = result_string[-7:].translate(str.maketrans(replace_dict))
+                 pattern = re.compile(rf'[{string_processing}]')
+                 cnt = 1
+
+                 if num_check == 1000:
+                     accum_char_len.append(num_vid)
+
+                 num_check = 1000
+
+                 for string in modified_list:
+                     string_nospace = string.replace(' ', '')
+                     if len(accum_char_len) > idx + 1:
+                         continue
+
+                     for letter in string_nospace:
+                         match_result = pattern.match(letter)
+                         if match_result:
+                             verify += 1
+                             if verify == len(result_string[-7:]):
+                                 if cnt > accum_char_len[-1]:
+                                     accum_char_len.append(cnt)
+                                 verify = 0
+                                 num_check = len(accum_char_len)
+                         else:
+                             verify = 0
+                         cnt = cnt + 1
+
+             if num_check == 1000:
+                 accum_char_len.append(num_vid)
+
+             if -999 in accum_char_len:
+                 unk_loc_li.append(unk_loc)
+                 continue
+
+             CSS_hid = bert_output['last_hidden_state'][0][1:sum(cdd_sent_char_lens) + 1].to(device)
+
+             qs_hid.append(CSS_hid[accum_char_len[cdd_quote_idx]:accum_char_len[cdd_quote_idx + 1]])
+
+             ## Locate the speaker span and index into the BERT-tokenized output
+             cnt = 1
+             cdd_mention_pos_bert_li = []
+             cdd_mention_pos_unk = []
+             name = cut_css[i][cdd_mention_pos[0]][cdd_mention_pos[3]]
+
+             # Extract only the name
+             cdd_pattern = re.compile(r'&C[0-5][0-9]&')
+             name_process = cdd_pattern.search(name)
+
+             # Find the candidate's location in the BERT output
+             pattern_unk = re.compile(r'[\[UNK\]]')
+
+             # Once the result is found, stop advancing past it.
+             if len(accum_char_len) < cdd_mention_pos[0] + 1:
+                 maxx_len = accum_char_len[len(accum_char_len) - 1]
+             elif len(accum_char_len) == cdd_mention_pos[0] + 1:
+                 maxx_len = accum_char_len[-1] + 1000
+             else:
+                 maxx_len = accum_char_len[cdd_mention_pos[0] + 1]
+
+             # Scan for the speaker contained in the span.
+             start_name = None
+             name_match = '&'
+             for string in modified_list:
+                 string_nospace = string.replace(' ', '')
+                 for letter in string_nospace:
+                     match_result_unk = pattern_unk.match(letter)
+                     if match_result_unk:
+                         cdd_mention_pos_unk.append(cnt)
+                     if start_name is True:
+                         name_match += letter
+                     if (name_match == name_process.group(0) or letter == '&') and len(
+                             cdd_mention_pos_bert_li) < 3 and maxx_len > cnt >= accum_char_len[
+                             cdd_mention_pos[0]]:  # if '&' is present, treat the span as a person
+                         start_name = True  # once matching starts, keep accumulating
+                         if len(cdd_mention_pos_bert_li) == 1 and name_match != name_process.group(0):  # a second '&' without a match
+                             start_name = None
+                             name_match = '&'
+                             cdd_mention_pos_bert_li = []
+                         elif name_match == name_process.group(0):  # second boundary found
+                             cdd_mention_pos_bert_li.append(cnt)
+                             start_name = None
+                             name_match = '&'
+                         else:
+                             cdd_mention_pos_bert_li.append(cnt - 1)
+                     cnt += 1
+
+             if len(cdd_mention_pos_bert_li) == 0 and len(cdd_mention_pos_unk) != 0:
+                 cdd_mention_pos_bert_li.extend([cdd_mention_pos_unk[0], cdd_mention_pos_unk[0] + 1])
+             elif len(cdd_mention_pos_bert_li) != 2:
+                 cdd_mention_pos_bert_li = []
+                 cdd_mention_pos_bert_li.extend([int(cdd_mention_pos[1] * accum_char_len[-1] / sum(
+                     cdd_sent_char_lens)), int(cdd_mention_pos[2] * accum_char_len[-1] / sum(
+                     cdd_sent_char_lens))])
+                 if cdd_mention_pos_bert_li[0] == cdd_mention_pos_bert_li[1]:
+                     cdd_mention_pos_bert_li[1] = cdd_mention_pos_bert_li[1] + 1
+
+             # Decide the context (ctx): information around the candidate.
+             # With a single sentence there is no separate context, so use zeros.
+             if len(cdd_sent_char_lens) == 1:
+                 ctx_hid.append(torch.zeros(1, CSS_hid.size(1)).to(device))
+
+             # If the mention comes first, take from the first sentence up to
+             # (but not including) the last one (the quote).
+             elif cdd_mention_pos[0] == 0:
+                 ctx_hid.append(CSS_hid[:accum_char_len[-2]])
+
+             # If the mention comes last, take from the second sentence to the end.
+             else:
+                 ctx_hid.append(CSS_hid[accum_char_len[1]:])
+
+             cdd_mention_pos_bert = (cdd_mention_pos[0], cdd_mention_pos_bert_li[0],
+                                     cdd_mention_pos_bert_li[1])
+             cdd_hid.append(CSS_hid[cdd_mention_pos_bert[1]:cdd_mention_pos_bert[2]])
+
+         # pooling
+         if not qs_hid:
+             scores = '1'
+             scores_false = 1
+             scores_true = 1
+             return scores, scores_false, scores_true
+
+         qs_rep = self.pooling(qs_hid).to(device)
+         ctx_rep = self.pooling(ctx_hid).to(device)
+         cdd_rep = self.pooling(cdd_hid).to(device)
+
+         # concatenate
+         feature_vector = torch.cat([qs_rep, ctx_rep, cdd_rep], dim=-1).to(device)
+
+         # dropout
+         feature_vector = self.dropout(feature_vector).to(device)
+
+         # scoring
+         scores = self.mlp_scorer(feature_vector).view(-1).to(device)
+
+         for i in unk_loc_li:
+             # Element to insert for candidates that were skipped above
+             new_element = torch.tensor([-0.9000], requires_grad=True).to(device)
+             # Use torch.cat() and slicing to insert the element at a specific index.
+             index_to_insert = i - 1
+             scores = torch.cat((scores[:index_to_insert], new_element, scores[index_to_insert:]),
+                                dim=0).to(device)
+
+         scores_false = [scores[i] for i in range(scores.size(0)) if i != true_index]
+         scores_true = [scores[true_index] for i in range(scores.size(0) - 1)]
+
+         return scores, scores_false, scores_true
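
This upload contains no training loop, so the following margin-ranking sketch is an assumption based only on the `--margin` argument and the `(scores, scores_false, scores_true)` return shape; `model`, `features`, `scl`, `mp`, `qi`, `true_index`, `tokens_list`, and `cut_css` stand in for values produced by the data pipeline above:

    import torch
    import torch.nn as nn

    loss_fn = nn.MarginRankingLoss(margin=args.margin)
    scores, scores_false, scores_true = model(
        features, scl, mp, qi, true_index, device, tokens_list, cut_css)
    target = torch.ones(len(scores_true)).to(device)  # true score should rank higher
    loss = loss_fn(torch.stack(scores_true), torch.stack(scores_false), target)
    loss.backward()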
web/confirm.html ADDED
@@ -0,0 +1,111 @@
+ <!DOCTYPE html>
+ <html lang="ko">
+
+ <head>
+     <meta name="viewport" content="width=device-width, initial-scale=1" />
+     <meta charset="utf-8">
+     <title>Speakers in Text</title>
+     <link rel="stylesheet" href="{{ url_for('static', path='css/put.css') }}">
+     <link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet"
+         integrity="sha384-4bw+/aepP/YC94hEpVNVgiZdgIC5+VKNBQNGCHeKRQN+PtmoHDEXuppvnDJzQIu9" crossorigin="anonymous">
+ </head>
+
+ <body>
+     <!-- Load jQuery and the Bootstrap JavaScript -->
+     <script src="https://code.jquery.com/jquery-3.6.4.min.js"
+         integrity="sha256-oP6HI/tZ1aS9sz3Jr4+6zqbc9BE/l6fLx+Vz2I+H/GL4ZiI/Z5L3hMv8w3yXdBi"
+         crossorigin="anonymous"></script>
+     <script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/js/bootstrap.bundle.min.js"
+         integrity="sha384-3ziFjNlAXja/Yb0M7y2BmFvR3s09gRPbrCm0lF+SvL4uIboD5lv3U3BdD7dW7Y3"
+         crossorigin="anonymous"></script>
+
+     <!-- Modal display script -->
+     <script>
+         // Show the modal once the document is ready
+         $(document).ready(function () {
+             $('#exampleModal').modal("show");
+         });
+
+         // Open the add-item modal
+         function openAddItemModal() {
+             const addItemModal = createAddItemModal();
+             openModal(addItemModal);
+         }
+
+         // Add an item when the form is submitted
+         function addItem() {
+             event.preventDefault();
+             const newItem = document.getElementById('newItem').value;
+             itemList.push(newItem);
+             updateItemList();
+             closeModal('addItemModal');
+         }
+     </script>
+
+     <div class="background">
+         <div class="header"><a href="/">
+             <span class="title">Nouvel : Novel for you</span></a></div>
+         <div class="empty"></div>
+
+         <div class="box">
+             <div class="subtitle">
+                 ๊ฐ์ง€ํ•œ ๋“ฑ์žฅ์ธ๋ฌผ ๋ช…๋‹จ
+             </div>
+             <div class="explain_box">
+                 <span class="explain">์ž‘์—…์ด ์ง„ํ–‰ ์ค‘ ์ž…๋‹ˆ๋‹ค. ์ž ์‹œ๋งŒ ๊ธฐ๋‹ค๋ ค ์ฃผ์„ธ์š”.</span>
+             </div>
+             <div class="container mt-4">
+                 <ol id="itemList" class="list-group">
+                     <!-- The list is populated here dynamically. -->
+                 </ol>
+
+                 <!-- Add button
+                 <button class="btn btn-primary mt-3" data-toggle="modal" data-target="#addItemModal">์ถ”๊ฐ€</button> -->
+
+                 <form>
+                     <input type="text" class="form-control" id="newItem">
+                     <div class="modal-footer">
+                         <button type="button" class="btn btn-secondary" data-dismiss="modal">์ทจ์†Œ</button>
+                         <!-- Changed: the type attribute switched from "button" to "submit" -->
+                         <button type="submit" class="btn btn-primary" onclick="addItem()">์ถ”๊ฐ€</button>
+                     </div>
+                 </form>
+
+             </div>
+
+             <!-- Edit/delete modal -->
+             <div class="modal fade" id="editItemModal" tabindex="-1" role="dialog" aria-labelledby="editItemModalLabel"
+                 aria-hidden="true">
+                 <!-- Modal content is added here dynamically. -->
+             </div>
+
+             <!-- Add-item modal
+             <div class="modal fade" id="addItemModal" tabindex="-1" role="dialog" aria-labelledby="addItemModalLabel" aria-hidden="true">
+                 Modal content is added here dynamically.
+             </div>
+             -->
+
+             <div class="buttonbox">
+                 <div class="transformbox">
+                     <button onclick="handleButtonClick()" class="transformButton" type="submit">
+                         <span>์‹œ์ž‘ํ•˜๊ธฐ</span>
+                     </button>
+                 </div>
+             </div>
+             <div class="foot">
+                 <div class="footer-text">
+                     <span>๊ณ ๋ ค๋Œ€ํ•™๊ต ์ง€๋Šฅ์ •๋ณด SW ์•„์นด๋ฐ๋ฏธ 5์กฐ</span>
+                 </div>
+             </div>
+         </div>
+     </div>
+     <script src="https://code.jquery.com/jquery-3.2.1.slim.min.js"></script>
+     <script src="https://cdnjs.cloudflare.com/ajax/libs/popper.js/1.12.9/umd/popper.min.js"></script>
+     <script src="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/js/bootstrap.min.js"></script>
+     <script type="text/javascript" src="../static/js/confirm.js"></script>
+     <script>
+         document.addEventListener('DOMContentLoaded', function () {handlePageLoad();});
+     </script>
+ </body>
+
+ </html>
web/final.html ADDED
@@ -0,0 +1,63 @@
+ <!DOCTYPE html>
+ <html lang="ko">
+
+ <head>
+     <meta name="viewport" content="width=device-width, initial-scale=1" />
+     <meta charset="utf-8">
+     <title>Speakers in Text</title>
+     <link rel="stylesheet" href="{{ url_for('static', path='css/finishs.css') }}">
+     <link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet"
+         integrity="sha384-4bw+/aepP/YC94hEpVNVgiZdgIC5+VKNBQNGCHeKRQN+PtmoHDEXuppvnDJzQIu9" crossorigin="anonymous">
+ </head>
+
+ <body>
+     <!-- Load jQuery and the Bootstrap bundle (the bundle already includes Popper) -->
+     <script src="https://code.jquery.com/jquery-3.6.4.min.js"></script>
+     <script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/js/bootstrap.bundle.min.js"></script>
+
+     <div class="background">
+         <div class="header"><a href="/"><span class="title">Nouvel : Novel for you</span></a></div>
+         <div class="empty"></div>
+         <div class="box">
+             <div class="subtitle">txt ํŒŒ์ผ ๋ณ€ํ™˜ ๊ฒฐ๊ณผ</div>
+             <div class="explain_box">
+                 <span class="explain">๋ฐ‘์˜ ๊ฒฐ๊ณผ๋Š” ์†Œ์„ค ์† ์ธ์šฉ๋ฌธ๋ณ„๋กœ ๋ฐœํ™”์ž๋ฅผ ์ธ์‹ํ•œ ๊ฒฐ๊ณผ์ž…๋‹ˆ๋‹ค.<br>
+                     ๋˜ํ•œ ์žฅ๋ฉด๋ณ„๋กœ ์žฅ์†Œ์™€ ์‹œ๊ฐ„์„ ์ธ์‹ํ•˜์˜€์Šต๋‹ˆ๋‹ค.</span>
+             </div>
+             <div class="TLContainer">
+                 <div class="timeContainer">
+                     <span class="timelocation">์‹œ๊ฐ„ : {{ time }}</span>
+                 </div>
+                 <div class="locationContainer">
+                     <span class="timelocation">์žฅ์†Œ : {{ place }}</span>
+                 </div>
+             </div>
+             <div class="itemContainer" id="resultContainer">
+                 {% for out in output %}
+                 <p>{{ out }}</p>
+                 {% endfor %}
+             </div>
+             <div class="buttonContainer">
+                 <div class="downloadButtonContainer">
+                     <button id="downloadButton">Download Text File</button>
+                 </div>
+                 <div class="homeButtonContainer">
+                     <button id="homeButton" onclick="location.href='/'">ํ™ˆ์œผ๋กœ ๋Œ์•„๊ฐ€๊ธฐ</button>
+                 </div>
+             </div>
+         </div>
+         <div class="foot">
+             <div class="footer-text">
+                 <span>๊ณ ๋ ค๋Œ€ํ•™๊ต ์ง€๋Šฅ์ •๋ณด SW ์•„์นด๋ฐ๋ฏธ 5์กฐ</span>
+             </div>
+         </div>
+     </div>
+     <script type="text/javascript" src="../static/js/finish.js"></script>
+ </body>
+
+ </html>
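final.html interpolates time, place, and an output list, and its stylesheet link uses url_for('static', path=...), which is the Starlette/FastAPI templating convention. A minimal sketch of how the page could be rendered server-side — the app module, directory names, route, and context values are assumptions for illustration, not part of this upload:

from fastapi import FastAPI, Request
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

app = FastAPI()
# url_for('static', path=...) in the templates requires a mount named "static".
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="web")

@app.get("/final")  # hypothetical route
async def final_page(request: Request):
    # The keys mirror the variables final.html reads: time, place, output.
    context = {
        "request": request,
        "time": "example time",
        "place": "example place",
        "output": ["example line 1", "example line 2"],
    }
    return templates.TemplateResponse("final.html", context)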
web/index.html ADDED
@@ -0,0 +1,40 @@
+ <!DOCTYPE html>
+ <html lang="ko">
+
+ <head>
+     <meta name="viewport" content="width=device-width, initial-scale=1" />
+     <meta charset="utf-8">
+     <title>Speakers in Text</title>
+     <link rel="stylesheet" href="{{ url_for('static', path='css/indexs.css') }}">
+ </head>
+
+ <body>
+     <div class="background">
+         <div class="header">
+             <a href="/">
+                 <span class="title">Nouvel : Novel for you</span>
+             </a>
+         </div>
+         <div class="empty"></div>
+         <div class="box">
+             <p class="subtitle">์„œ๋น„์Šค ์†Œ๊ฐœ</p>
+             <div class="explain-box">
+                 <span class="explain">BERT ๋ชจ๋ธ์„ ํ™œ์šฉํ•˜์—ฌ ํ…์ŠคํŠธ ํŒŒ์ผ ๋‚ด์—์„œ ๋ฐœํ™”์ž๋ฅผ ์ธ์‹ํ•  ์ˆ˜ ์žˆ๋Š” ๋ชจ๋ธ์ž…๋‹ˆ๋‹ค.
+                     <br>์•„๋ž˜์˜ ์‹œ์ž‘ํ•˜๊ธฐ ๋ฒ„ํŠผ์„ ๋ˆŒ๋Ÿฌ ํ…์ŠคํŠธ ํŒŒ์ผ์„ ์—…๋กœ๋“œํ•˜๊ณ  ๋ชจ๋ธ์„
+                     ์ด์šฉํ•œ ๋ฐœํ™”์ž ์ธ์‹ ๊ฒฐ๊ณผ๋ฅผ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.</span>
+             </div>
+             <div class="buttonbox">
+                 <button onclick="location.href='put.html'" class="start-button">
+                     <span>์‹œ์ž‘ํ•˜๊ธฐ</span>
+                 </button>
+             </div>
+         </div>
+         <div class="foot">
+             <div class="footer-text">
+                 <span>๊ณ ๋ ค๋Œ€ํ•™๊ต ์ง€๋Šฅ์ •๋ณด SW ์•„์นด๋ฐ๋ฏธ 5์กฐ</span>
+             </div>
+         </div>
+     </div>
+ </body>
+
+ </html>
web/put.html ADDED
@@ -0,0 +1,53 @@
+ <!DOCTYPE html>
+ <html lang="ko">
+
+ <head>
+     <meta name="viewport" content="width=device-width, initial-scale=1" />
+     <meta charset="utf-8">
+     <title>Speakers in Text</title>
+     <link rel="stylesheet" href="{{ url_for('static', path='css/put.css') }}">
+     <link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet"
+         integrity="sha384-4bw+/aepP/YC94hEpVNVgiZdgIC5+VKNBQNGCHeKRQN+PtmoHDEXuppvnDJzQIu9" crossorigin="anonymous">
+ </head>
+
+ <body>
+     <div class="background">
+         <div class="header"><a href="/"><span class="title">Nouvel : Novel for you</span></a></div>
+         <div class="empty"></div>
+         <div class="box">
+             <div class="subtitle">๋ณ€ํ™˜ํ•  ํŒŒ์ผ ์—…๋กœ๋“œํ•˜๊ธฐ</div>
+             <div class="explain_box">
+                 <span class="explain">๋ฐœํ™”์ž๋ฅผ ์ฐพ๊ณ  ์‹ถ์€ txt ํŒŒ์ผ์„ ์—…๋กœ๋“œํ•˜๊ณ  ๋ณ€ํ™˜ํ•˜๊ธฐ ๋ฒ„ํŠผ์„ ํด๋ฆญํ•˜์„ธ์š”.</span>
+             </div>
+             <form id="myForm" onsubmit="return validateForm()">
+                 <div class="button_box">
+                     <p>์•„๋ž˜์˜ ๋ฒ„ํŠผ์„ ํด๋ฆญํ•˜์—ฌ ์‚ฌ์šฉ์ž์˜ ์œ ํ˜•์„ ์„ ํƒํ•ด์ฃผ์„ธ์š”. <br><b><i>์ž‘๊ฐ€(์ „๋ฌธ๊ฐ€) ์‚ฌ์šฉ์ž๋Š” ๋“ฑ์žฅ์ธ๋ฌผ ๋ฆฌ์ŠคํŠธ๋ฅผ ๊ฒ€ํ† ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.</i></b></p>
+                     <input type="radio" class="btn-check" name="displayOption" id="option1" autocomplete="off" value="pro">
+                     <label class="btn btn-secondary" for="option1">์ž‘๊ฐ€(์ „๋ฌธ๊ฐ€)</label>
+
+                     <input type="radio" class="btn-check" name="displayOption" id="option2" autocomplete="off" value="reader">
+                     <label class="btn btn-secondary" for="option2">๋…์ž(์ด์šฉ์ž)</label>
+                 </div>
+                 <div class="formbox">
+                     <div class="mb-3">
+                         <label for="formFileSm" class="form-label"><b>์ฒจ๋ถ€ํŒŒ์ผ ์—…๋กœ๋“œ</b></label>
+                         <input class="form-control form-control-sm" id="formFileSm" type="file">
+                     </div>
+                 </div>
+                 <div class="transformbox">
+                     <button onclick="handleButtonClick()" class="transformButton" type="button">
+                         <span>๋ณ€ํ™˜ํ•˜๊ธฐ</span>
+                     </button>
+                 </div>
+             </form>
+         </div>
+         <div class="foot">
+             <div class="footer-text">
+                 <span>๊ณ ๋ ค๋Œ€ํ•™๊ต ์ง€๋Šฅ์ •๋ณด SW ์•„์นด๋ฐ๋ฏธ 5์กฐ</span>
+             </div>
+         </div>
+     </div>
+     <script src="../static/js/put.js"></script>
+ </body>
+
+ </html>
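put.html collects a user type (the displayOption radio group) and a txt file, and hands submission to handleButtonClick() in static/js/put.js (not shown in this upload), so the server presumably exposes an upload endpoint. A minimal sketch, assuming FastAPI's UploadFile; the route name, handler name, and redirect target are illustrative assumptions:

from fastapi import FastAPI, File, Form, UploadFile
from fastapi.responses import RedirectResponse

app = FastAPI()

@app.post("/upload")  # hypothetical endpoint; put.js defines the real request
async def upload_novel(displayOption: str = Form(...), file: UploadFile = File(...)):
    # Read the uploaded novel as UTF-8 text for speaker identification.
    text = (await file.read()).decode("utf-8")
    # ... run the speaker-identification pipeline on text ...
    return RedirectResponse(url="/confirm", status_code=303)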
web/user.html ADDED
@@ -0,0 +1,34 @@
+ <!DOCTYPE html>
+ <html lang="ko">
+
+ <head>
+     <meta name="viewport" content="width=device-width, initial-scale=1" />
+     <meta charset="utf-8">
+     <title>Speakers in Text</title>
+     <link rel="stylesheet" href="{{ url_for('static', path='css/index.css') }}">
+ </head>
+
+ <body>
+     <div class="background">
+         <div class="header"><a href="/"><span class="title">Nouvel : Novel for you</span></a></div>
+         <div class="empty"></div>
+         <div class="box">
+             <div class="subtitle">๋“ฑ์žฅ์ธ๋ฌผ ๋ช…๋‹จ</div>
+             <div class="body">
+
+             </div>
+             <div class="transformbox">
+                 <button onclick="handleButtonClick()" class="transformButton"><span>์žฅ๋ฉด์œผ๋กœ ๋ณ€ํ™˜ํ•˜๊ธฐ</span></button>
+             </div>
+         </div>
+         <div class="foot">
+             <div class="footer-text">
+                 <span>๊ณ ๋ ค๋Œ€ํ•™๊ต ์ง€๋Šฅ์ •๋ณด SW ์•„์นด๋ฐ๋ฏธ 5์กฐ</span>
+             </div>
+         </div>
+     </div>
+     <script src="../static/js/user.js"></script>
+ </body>
+
+ </html>