"""
ํ™”์ž ์ฐพ๋Š” ๋ชจ๋ธ ์œ ํ‹ธ ํŒŒ์ผ๋“ค
"""
class InputFeatures:
    """
    BERT ๋ชจ๋ธ์˜ ์ž…๋ ฅ๋“ค
    """
    def __init__(self, tokens, input_ids, input_mask, input_type_ids):
        self.tokens = tokens
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.input_type_ids = input_type_ids


def convert_examples_to_features(examples, tokenizer):
    """
    Converts text segments into token IDs for the BERT model.
    """
    features = []
    tokens_list = []

    for example in examples:
        tokens = tokenizer.tokenize(example)
        tokens_list.append(tokens)

        # Wrap the segment in the special [CLS]/[SEP] tokens; everything
        # belongs to a single segment, so every type ID is 0.
        new_tokens = ["[CLS]"] + tokens + ["[SEP]"]
        input_type_ids = [0] * len(new_tokens)

        input_ids = tokenizer.convert_tokens_to_ids(new_tokens)
        input_mask = [1] * len(input_ids)

        features.append(
            InputFeatures(
                tokens=new_tokens,
                input_ids=input_ids,
                input_mask=input_mask,
                input_type_ids=input_type_ids))

    return features, tokens_list
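
# Usage sketch for convert_examples_to_features. The tokenizer below is an
# assumption for illustration (any BERT-style tokenizer exposing tokenize()
# and convert_tokens_to_ids() works); the checkpoint name is not part of
# this project:
#
#     from transformers import BertTokenizer
#     tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
#     features, tokens_list = convert_examples_to_features(
#         examples=['"Where are you going?"'], tokenizer=tokenizer)
#     features[0].tokens  # ['[CLS]', ..., '[SEP]']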


def get_alias2id(name_list_path) -> dict:
    """
    Builds a dictionary mapping each alias in the given name-list file to a
    character ID (the character's line index in the file).
    """
    with open(name_list_path, 'r', encoding='utf-8') as fin:
        name_lines = fin.readlines()

    alias2id = {}
    for i, line in enumerate(name_lines):
        # The first token on each line is skipped; every remaining
        # whitespace-separated token is treated as an alias.
        for alias in line.strip().split()[1:]:
            alias2id[alias] = i

    return alias2id
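
# Expected name-list format, inferred from the parsing above (the exact file
# contents are hypothetical): one character per line, a leading token
# (presumably the canonical name, which is skipped) followed by its aliases:
#
#     Gatsby gatsby Jay
#     Daisy daisy Mrs.Buchanan
#
# get_alias2id would then return:
#     {'gatsby': 0, 'Jay': 0, 'daisy': 1, 'Mrs.Buchanan': 1}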


def find_speak(fs_model, input_data, tokenizer, alias2id):
    """
    ์ฃผ์–ด์ง„ ๋ชจ๋ธ๊ณผ ์ž…๋ ฅ ๋ฐ์ดํ„ฐ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๊ฐ ์ž…๋ ฅ์— ๋Œ€ํ•œ ํ™”์ž๋ฅผ ์ฐพ๋Š” ํ•จ์ˆ˜
    """
    model = fs_model
    check_data_iter = iter(input_data)

    names = []

    for _ in range(len(input_data)):

        # Each item supplies the segmented sentences, the candidate-specific
        # sentences (css) that are tokenized below, auxiliary model inputs
        # (scl, mp, qi, cut_css), and name_list_index, which maps each
        # prediction position back to a character ID.
        seg_sents, css, scl, mp, qi, cut_css, name_list_index = next(check_data_iter)
        features, tokens_list = convert_examples_to_features(examples=css, tokenizer=tokenizer)

        # Run inference on the GPU first; retry on the CPU if the CUDA
        # forward pass raises a RuntimeError (e.g. CUDA is unavailable).
        try:
            predictions = model(features, scl, mp, qi, 0, "cuda:0", tokens_list, cut_css)
        except RuntimeError:
            predictions = model(features, scl, mp, qi, 0, "cpu", tokens_list, cut_css)

        scores, _, _ = predictions

        # ํ›„์ฒ˜๋ฆฌ
        try:
            scores_np = scores.detach().cpu().numpy()
            scores_list = scores_np.tolist()
            score_index = scores_list.index(max(scores_list))
            name_index = name_list_index[score_index]

            for key, val in alias2id.items():
                if val == name_index:
                    result_key = key

            names.append(result_key)
        except AttributeError:
            names.append('์•Œ ์ˆ˜ ์—†์Œ')

    return names
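
# Hypothetical call (fs_model and input_data are project-specific; the model
# is expected to return a tuple whose first element is the score tensor):
#
#     alias2id = get_alias2id('data/name_list.txt')  # illustrative path
#     names = find_speak(fs_model, input_data, tokenizer, alias2id)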


def making_script(text, speaker: list, instance_num: list) -> str:
    """
    Builds a dialogue script by prefixing the lines of the given text at the
    given line numbers with the corresponding speaker names.
    """
    lines = text.splitlines()
    for num, people in zip(instance_num, speaker):
        lines[num] = f'{people}: {lines[num]}'
    # Join back into a single string to match the declared return type.
    return '\n'.join(lines)
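
# Usage sketch for making_script (illustrative values):
#
#     text = '"Hello."\nHe paused.\n"Goodbye."'
#     print(making_script(text, speaker=['Daisy', 'Gatsby'], instance_num=[0, 2]))
#     # Daisy: "Hello."
#     # He paused.
#     # Gatsby: "Goodbye."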