# koCSN_SAPR/utils/train_model.py
"""
Author:
"""
import re

import torch
import torch.nn as nn
from transformers import AutoModel


def get_nonlinear(nonlinear):
"""
Activation function.
"""
nonlinear_dict = {'relu': nn.ReLU(), 'tanh': nn.Tanh(),
'sigmoid': nn.Sigmoid(), 'softmax': nn.Softmax(dim=-1)}
    try:
        return nonlinear_dict[nonlinear]
    except KeyError as exc:
        raise ValueError(f'{nonlinear!r} is not a valid nonlinear type!') from exc
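
# Illustrative usage (not part of the original file): the returned module is a
# plain nn.Module, so it can be applied directly to a tensor, e.g.
#   act = get_nonlinear('relu')
#   act(torch.tensor([-1.0, 2.0]))  # -> tensor([0., 2.])
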
class SeqPooling(nn.Module):
"""
Sequence pooling module.
Can do max-pooling, mean-pooling and attentive-pooling on a list of sequences of different lengths.
"""
def __init__(self, pooling_type, hidden_dim):
super(SeqPooling, self).__init__()
self.pooling_type = pooling_type
self.hidden_dim = hidden_dim
        if pooling_type == 'attentive_pooling':
            self.query_vec = nn.Parameter(torch.randn(hidden_dim))
def max_pool(self, seq):
return seq.max(0)[0]
def mean_pool(self, seq):
return seq.mean(0)
def attn_pool(self, seq):
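        # Score each position against the learned query vector, then return
        # the softmax-weighted sum over the sequence (dot-product attention).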
attn_score = torch.mm(seq, self.query_vec.view(-1, 1)).view(-1)
attn_w = nn.Softmax(dim=0)(attn_score)
weighted_sum = torch.mm(attn_w.view(1, -1), seq).view(-1)
return weighted_sum
def forward(self, batch_seq):
pooling_fn = {'max_pooling': self.max_pool,
'mean_pooling': self.mean_pool,
'attentive_pooling': self.attn_pool}
pooled_seq = [pooling_fn[self.pooling_type](seq) for seq in batch_seq]
return torch.stack(pooled_seq, dim=0)
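
# A minimal sketch of how SeqPooling handles variable-length inputs
# (hypothetical shapes, not from the original file):
#   pool = SeqPooling('mean_pooling', hidden_dim=768)
#   batch = [torch.randn(5, 768), torch.randn(9, 768)]  # two sequences
#   pooled = pool(batch)                                # shape: (2, 768)
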
class MLP_Scorer(nn.Module):
"""
MLP scorer module.
    A two-layer perceptron; the nonlinearity is applied after each layer.
"""
def __init__(self, args, classifier_input_size):
super(MLP_Scorer, self).__init__()
self.scorer = nn.ModuleList()
self.scorer.append(nn.Linear(classifier_input_size, args.classifier_intermediate_dim))
self.scorer.append(nn.Linear(args.classifier_intermediate_dim, 1))
self.nonlinear = get_nonlinear(args.nonlinear_type)
def forward(self, x):
for model in self.scorer:
x = self.nonlinear(model(x))
return x
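
# A minimal sketch of the scorer in isolation, assuming an args namespace with
# the two fields used above (field names are from this file, values are
# illustrative):
#   from argparse import Namespace
#   args = Namespace(classifier_intermediate_dim=128, nonlinear_type='tanh')
#   scorer = MLP_Scorer(args, classifier_input_size=768 * 3)
#   scorer(torch.randn(4, 768 * 3)).shape  # -> torch.Size([4, 1])
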
class KCSN(nn.Module):
"""
Candidate Scoring Network.
It's built on BERT with an MLP and other simple components.
"""
def __init__(self, args):
super(KCSN, self).__init__()
self.args = args
self.bert_model = AutoModel.from_pretrained(args.bert_pretrained_dir)
self.pooling = SeqPooling(args.pooling_type, self.bert_model.config.hidden_size)
self.mlp_scorer = MLP_Scorer(args, self.bert_model.config.hidden_size * 3)
self.dropout = nn.Dropout(args.dropout)
def forward(self, features, sent_char_lens, mention_poses, quote_idxes, true_index, device, tokens_list, cut_css):
        # encoding: run BERT over each candidate-specific segment (CSS)
        qs_hid = []
        ctx_hid = []
        cdd_hid = []
        unk_loc_li = []
        unk_loc = 0
        for i, (cdd_sent_char_lens, cdd_mention_pos, cdd_quote_idx) in enumerate(
                zip(sent_char_lens, mention_poses, quote_idxes)):
            unk_loc = unk_loc + 1
            bert_output = self.bert_model(
                torch.tensor([features[i].input_ids], dtype=torch.long).to(device),
                token_type_ids=None,
                attention_mask=torch.tensor([features[i].input_mask], dtype=torch.long).to(device)
            )
            # Strip WordPiece markers ('##') so tokens can be matched against
            # the raw sentence text character by character.
            modified_list = [s.replace('#', '') for s in tokens_list[i]]
            cnt = 1  # current token position (1-based; position 0 is [CLS])
            verify = 0  # number of suffix characters matched so far
            num_check = 0  # sentinel: 1000 means the current boundary was not found
            num_vid = -999  # placeholder for a boundary that could not be located
            accum_char_len = [0]  # cumulative token offsets of sentence boundaries
            for idx, txt in enumerate(cut_css[i]):
                # Match the last few characters of each sentence against the
                # token stream to locate sentence boundaries in token offsets.
                result_string = ''.join(txt)
                replace_dict = {']': r'\]', '[': r'\[', '?': r'\?', '-': r'\-', '!': r'\!'}
                string_processing = result_string[-7:].translate(str.maketrans(replace_dict))
                pattern = re.compile(rf'[{string_processing}]')
                cnt = 1
                if num_check == 1000:
                    # The previous sentence's boundary was never found; record
                    # the placeholder so this CSS can be skipped below.
                    accum_char_len.append(num_vid)
                num_check = 1000
                for string in modified_list:
                    string_nospace = string.replace(' ', '')
                    if len(accum_char_len) > idx + 1:
                        continue
                    for letter in string_nospace:
                        match_result = pattern.match(letter)
                        if match_result:
                            verify += 1
                            if verify == len(result_string[-7:]):
                                # The whole suffix matched: record the boundary.
                                if cnt > accum_char_len[-1]:
                                    accum_char_len.append(cnt)
                                verify = 0
                                num_check = len(accum_char_len)
                        else:
                            verify = 0
                    cnt = cnt + 1
            if num_check == 1000:
                accum_char_len.append(num_vid)
            if -999 in accum_char_len:
                # At least one boundary is unknown; this candidate is scored
                # later with a fixed low score instead of the network output.
                unk_loc_li.append(unk_loc)
                continue
            # Hidden states of the CSS tokens (position 0 is [CLS], so skip it).
            CSS_hid = bert_output['last_hidden_state'][0][1:sum(cdd_sent_char_lens) + 1].to(device)
            # Quote representation: the token span of the quote sentence.
            qs_hid.append(CSS_hid[accum_char_len[cdd_quote_idx]:accum_char_len[cdd_quote_idx + 1]])
            # Locate the speaker mention within the BERT-tokenized output.
            cnt = 1
            cdd_mention_pos_bert_li = []
            cdd_mention_pos_unk = []
            name = cut_css[i][cdd_mention_pos[0]][cdd_mention_pos[3]]
            # Extract only the candidate alias (e.g. '&C01&').
            cdd_pattern = re.compile(r'&C[0-5][0-9]&')
            name_process = cdd_pattern.search(name)
            # Find the candidate's location in the BERT output. The character
            # class matches the individual characters of the literal '[UNK]'
            # token, because matching below is done letter by letter.
            pattern_unk = re.compile(r'[\[UNK\]]')
            # Once a match is found, the bounds below keep the scan from
            # running past the candidate's sentence.
            if len(accum_char_len) < cdd_mention_pos[0] + 1:
                maxx_len = accum_char_len[len(accum_char_len) - 1]
            elif len(accum_char_len) == cdd_mention_pos[0] + 1:
                maxx_len = accum_char_len[-1] + 1000
            else:
                maxx_len = accum_char_len[cdd_mention_pos[0] + 1]
            # Scan the token stream for the speaker alias.
            start_name = None
            name_match = '&'
            for string in modified_list:
                string_nospace = string.replace(' ', '')
                for letter in string_nospace:
                    match_result_unk = pattern_unk.match(letter)
                    if match_result_unk:
                        cdd_mention_pos_unk.append(cnt)
                    if start_name:
                        name_match += letter
                    if (name_match == name_process.group(0) or letter == '&') and len(
                            cdd_mention_pos_bert_li) < 3 and maxx_len > cnt >= accum_char_len[
                            cdd_mention_pos[0]]:  # a '&' marks a person alias
                        start_name = True  # start accumulating alias characters
                        if len(cdd_mention_pos_bert_li) == 1 and name_match != name_process.group(0):
                            # A second '&' appeared before the alias matched: reset.
                            start_name = None
                            name_match = '&'
                            cdd_mention_pos_bert_li = []
                        elif name_match == name_process.group(0):
                            # The full alias matched: record the end position.
                            cdd_mention_pos_bert_li.append(cnt)
                            start_name = None
                            name_match = '&'
                        else:
                            # The first '&': record the start position.
                            cdd_mention_pos_bert_li.append(cnt - 1)
                cnt += 1
            if len(cdd_mention_pos_bert_li) == 0 and len(cdd_mention_pos_unk) != 0:
                # The alias was tokenized as [UNK]: fall back to the first [UNK] span.
                cdd_mention_pos_bert_li.extend([cdd_mention_pos_unk[0], cdd_mention_pos_unk[0] + 1])
            elif len(cdd_mention_pos_bert_li) != 2:
                # Matching failed: approximate the span by scaling character
                # offsets to token offsets proportionally.
                cdd_mention_pos_bert_li = []
                cdd_mention_pos_bert_li.extend([int(cdd_mention_pos[1] * accum_char_len[-1] / sum(
                    cdd_sent_char_lens)), int(cdd_mention_pos[2] * accum_char_len[-1] / sum(
                    cdd_sent_char_lens))])
            if cdd_mention_pos_bert_li[0] == cdd_mention_pos_bert_li[1]:
                # Guarantee a non-empty span.
                cdd_mention_pos_bert_li[1] = cdd_mention_pos_bert_li[1] + 1
            # Decide the context (ctx): the sentences surrounding the candidate.
            if len(cdd_sent_char_lens) == 1:
                # A single-sentence CSS has no surrounding context; use zeros.
                ctx_hid.append(torch.zeros(1, CSS_hid.size(1)).to(device))
            elif cdd_mention_pos[0] == 0:
                # Speaker before the quote: take everything up to the last
                # sentence (the quote).
                ctx_hid.append(CSS_hid[:accum_char_len[-2]])
            else:
                # Speaker after the quote: take everything from the second
                # sentence to the end.
                ctx_hid.append(CSS_hid[accum_char_len[1]:])
            cdd_mention_pos_bert = (cdd_mention_pos[0], cdd_mention_pos_bert_li[0],
                                    cdd_mention_pos_bert_li[1])
            cdd_hid.append(CSS_hid[cdd_mention_pos_bert[1]:cdd_mention_pos_bert[2]])
        # pooling
        if not qs_hid:
            # No candidate segment could be aligned; return sentinel values so
            # the caller can skip this instance.
            scores = '1'
            scores_false = 1
            scores_true = 1
            return scores, scores_false, scores_true
qs_rep = self.pooling(qs_hid).to(device)
ctx_rep = self.pooling(ctx_hid).to(device)
cdd_rep = self.pooling(cdd_hid).to(device)
# concatenate
feature_vector = torch.cat([qs_rep, ctx_rep, cdd_rep], dim=-1).to(device)
# dropout
feature_vector = self.dropout(feature_vector).to(device)
# scoring
scores = self.mlp_scorer(feature_vector).view(-1).to(device)
        for i in unk_loc_li:
            # Element to insert: a fixed low score for candidates whose
            # alignment failed.
            new_element = torch.tensor([-0.9000], requires_grad=True).to(device)
            # Use torch.cat() and slicing to splice the element in at the
            # target index.
            index_to_insert = i - 1
            scores = torch.cat((scores[:index_to_insert], new_element, scores[index_to_insert:]),
                               dim=0).to(device)
scores_false = [scores[i] for i in range(scores.size(0)) if i != true_index]
scores_true = [scores[true_index] for i in range(scores.size(0) - 1)]
return scores, scores_false, scores_true
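
# A minimal sketch of how the three returned values could feed a margin
# ranking loss (the loss is not defined in this file, so the margin value and
# pairing below are assumptions, not the authors' training code):
#   scores, scores_false, scores_true = model(features, sent_char_lens,
#       mention_poses, quote_idxes, true_index, device, tokens_list, cut_css)
#   if not isinstance(scores, str):  # skip the sentinel early-return case
#       loss = sum(torch.relu(0.5 - (s_t - s_f))
#                  for s_t, s_f in zip(scores_true, scores_false))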