# NOTE(review): removed extraction artifacts (file-size line, commit hash,
# line-number dump) that were not valid Python.
"""
ํ์ ์ฐพ๋ ๋ชจ๋ธ ์ ํธ ํ์ผ๋ค
"""
class InputFeatures:
    """Container for one set of BERT model inputs.

    Holds the wrapped token strings, their vocabulary ids, the attention
    mask, and the segment (token-type) ids for a single example.
    """

    def __init__(self, tokens, input_ids, input_mask, input_type_ids):
        # Plain data holder: store the four parallel sequences as-is.
        (self.tokens,
         self.input_ids,
         self.input_mask,
         self.input_type_ids) = tokens, input_ids, input_mask, input_type_ids
def convert_examples_to_features(examples, tokenizer):
    """Convert text segments into model-ready feature objects.

    Each example is tokenized, wrapped in ``[CLS]``/``[SEP]`` markers, and
    mapped to vocabulary ids.

    Args:
        examples: iterable of raw text segments.
        tokenizer: BERT-style tokenizer exposing ``tokenize`` and
            ``convert_tokens_to_ids``.

    Returns:
        Tuple ``(features, tokens_list)`` where ``features`` is a list of
        ``InputFeatures`` and ``tokens_list`` holds the raw (unwrapped)
        token sequences, parallel to ``features``.
    """
    features = []
    tokens_list = []
    for example in examples:  # index was unused; plain iteration suffices
        tokens = tokenizer.tokenize(example)
        tokens_list.append(tokens)
        # Single-segment input: [CLS] <tokens> [SEP], all segment ids 0.
        wrapped = ["[CLS]", *tokens, "[SEP]"]
        input_type_ids = [0] * len(wrapped)
        input_ids = tokenizer.convert_tokens_to_ids(wrapped)
        input_mask = [1] * len(input_ids)  # no padding: attend to every token
        features.append(
            InputFeatures(
                tokens=wrapped,
                input_ids=input_ids,
                input_mask=input_mask,
                input_type_ids=input_type_ids))
    return features, tokens_list
def get_alias2id(name_list_path) -> dict:
    """Build a mapping from character alias to numeric id.

    Each line of the file is ``<canonical> <alias> <alias> ...``; every
    alias appearing on line ``i`` (the first field excluded) maps to ``i``.

    Args:
        name_list_path: path to the UTF-8 name-list file.

    Returns:
        dict mapping each alias string to its 0-based line index.
    """
    with open(name_list_path, 'r', encoding='utf-8') as fin:
        return {
            alias: line_no
            for line_no, row in enumerate(fin)
            for alias in row.strip().split()[1:]
        }
def find_speak(fs_model, input_data, tokenizer, alias2id):
    """Predict the speaker name for each instance in ``input_data``.

    Args:
        fs_model: speaker-identification model; called with the features and
            auxiliary instance data and expected to return a 3-tuple whose
            first element is a score tensor (one score per candidate).
        input_data: iterable dataset with ``len()``; each item unpacks into
            ``(seg_sents, css, scl, mp, qi, cut_css, name_list_index)``.
            NOTE(review): field semantics are defined by the dataset class,
            not visible here — confirm against the data pipeline.
        tokenizer: BERT-style tokenizer forwarded to feature conversion.
        alias2id: alias -> id mapping (see ``get_alias2id``).

    Returns:
        List of speaker names, one per instance; the unknown marker when the
        scores are unavailable or no alias matches the predicted id.
    """
    model = fs_model
    names = []
    for seg_sents, css, scl, mp, qi, cut_css, name_list_index in input_data:
        features, tokens_list = convert_examples_to_features(examples=css, tokenizer=tokenizer)
        # Try GPU first; fall back to CPU when the CUDA call fails.
        try:
            predictions = model(features, scl, mp, qi, 0, "cuda:0", tokens_list, cut_css)
        except RuntimeError:
            predictions = model(features, scl, mp, qi, 0, "cpu", tokens_list, cut_css)
        scores, _, _ = predictions
        # Post-processing: keep the try body minimal — only the tensor ops
        # can raise AttributeError (scores not being a tensor).
        try:
            scores_list = scores.detach().cpu().numpy().tolist()
        except AttributeError:
            names.append('์ ์ ์์')
            continue
        best = scores_list.index(max(scores_list))
        name_index = name_list_index[best]
        # Reverse lookup: last alias mapped to this id wins (matches the
        # original loop). Bug fix: previously `result_key` was unbound on
        # the first miss (NameError) and *stale* on later misses, silently
        # reusing the previous speaker; now a miss yields the unknown marker.
        result_key = None
        for key, val in alias2id.items():
            if val == name_index:
                result_key = key
        names.append(result_key if result_key is not None else '์ ์ ์์')
    return names
def making_script(text, speaker: list, instance_num: list) -> list:
    """Build a dialogue script by tagging lines with their speakers.

    Args:
        text: full source text; split into individual lines.
        speaker: speaker name for each tagged line, parallel to
            ``instance_num`` (extra entries in either list are ignored,
            per ``zip`` semantics).
        instance_num: 0-based indices of the lines to prefix.

    Returns:
        The list of lines with ``"<speaker>: "`` prefixes applied.
        (Annotation fixed: the previous ``-> str`` was wrong — a list of
        lines is returned, never a joined string.)

    Raises:
        IndexError: if an index in ``instance_num`` is out of range for
            the text's line count.
    """
    lines = text.splitlines()
    for num, people in zip(instance_num, speaker):
        lines[num] = f'{people}: {lines[num]}'
    return lines