import torch

from typing import Callable, List, Tuple, Union
from functools import partial
import itertools

from seqeval.scheme import Tokens, IOB2, IOBES

from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizerBase
from pythainlp.tokenize import word_tokenize as pythainlp_word_tokenize
newmm_word_tokenizer = partial(pythainlp_word_tokenize, keep_whitespace=True, engine='newmm')

from thai2transformers.preprocess import rm_useless_spaces

SPIECE = '▁'  # SentencePiece word-boundary marker (U+2581)

class TokenClassificationPipeline:
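    """NER / token-classification pipeline for Thai text.

    The input is pretokenized into words (PyThaiNLP's newmm tokenizer by default)
    with whitespace mapped to ``space_token``, each word is split into subwords by
    the model tokenizer, and the per-subword predictions are merged back into
    word-level tags. Optionally, tag sequences that violate the IOB/IOBE/IOBES
    scheme are repaired and adjacent tokens of the same entity are grouped into
    spans with character offsets.
    """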

    def __init__(self,
                 model: PreTrainedModel,
                 tokenizer: PreTrainedTokenizerBase,
                 pretokenizer: Callable[[str], List[str]] = newmm_word_tokenizer,
                 lowercase: bool = False,
                 space_token: str = '<_>',
                 device: int = -1,
                 group_entities: bool = False,
                 strict: bool = False,
                 tag_delimiter: str = '-',
                 scheme: str = 'IOB',
                 use_crf: bool = False,
                 remove_spiece: bool = True):
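        """
        Args:
            model: token-classification model whose config provides ``id2label``/``label2id``.
            tokenizer: subword tokenizer matching ``model``.
            pretokenizer: callable splitting raw text into word-level tokens.
            lowercase: lowercase the input before tokenization.
            space_token: placeholder token used to represent whitespace.
            device: CUDA device index; ``-1`` runs on CPU.
            group_entities: merge consecutive tokens of the same entity into spans.
            strict: if False, automatically repair tag sequences that violate the scheme.
            tag_delimiter: separator between prefix and entity type (e.g. ``-`` in ``B-PER``).
            scheme: tagging scheme of the model's labels ('IOB', 'IOBE' or 'IOBES').
            use_crf: whether the model decodes with a CRF head.
            remove_spiece: strip the SentencePiece marker from output words.
        """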

        super().__init__()

        assert isinstance(tokenizer, PreTrainedTokenizerBase)
        # assert isinstance(model, PreTrainedModel)
        
        self.model = model
        self.tokenizer = tokenizer
        self.pretokenizer = pretokenizer
        self.lowercase = lowercase
        self.space_token = space_token
        self.device = 'cpu' if device == -1 or not torch.cuda.is_available() else f'cuda:{device}'
        self.group_entities = group_entities
        self.strict = strict
        self.tag_delimiter = tag_delimiter
        self.scheme = scheme
        self.id2label = self.model.config.id2label
        self.label2id = self.model.config.label2id
        self.use_crf = use_crf
        self.remove_spiece = remove_spiece
        self.model.to(self.device)

    def preprocess(self, inputs: Union[str, List[str]]) -> Union[List[str], List[List[str]]]:
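        """Lowercase (optionally), strip redundant spaces, and pretokenize into words,
        replacing literal spaces with ``space_token``; accepts one text or a list of texts."""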

        if isinstance(inputs, str):
            text = inputs.lower() if self.lowercase else inputs
            text = rm_useless_spaces(text)
            tokens = self.pretokenizer(text)
            # Represent literal spaces with the dedicated space token.
            return [tok.replace(' ', self.space_token) for tok in tokens]

        # A list of texts is preprocessed item by item.
        return [self.preprocess(text) for text in inputs]

    def _inference(self, input: str):
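        """Run the model on a single text and return per-word prediction dicts
        (or grouped entities when ``group_entities`` is enabled)."""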

        tokens = [[self.tokenizer.bos_token]] + \
                    [self.tokenizer.tokenize(tok) if tok != SPIECE else [SPIECE] for tok in self.preprocess(input)] + \
                    [[self.tokenizer.eos_token]]
        ids = [self.tokenizer.convert_tokens_to_ids(token) for token in tokens]
        flatten_tokens = list(itertools.chain(*tokens))
        flatten_ids = list(itertools.chain(*ids))

        input_ids = torch.LongTensor([flatten_ids]).to(self.device)

        if self.use_crf:
            # CRF path: forward pass only; decoding the best tag sequence is
            # left to the CRF-equipped model wrapper.
            out = self.model(input_ids=input_ids)
        else:
            out = self.model(input_ids=input_ids, return_dict=True)
            # Greedy decoding: pick the highest-probability label for each subword.
            probs = torch.softmax(out['logits'], dim=-1)
            vals, indices = probs.topk(1)
            indices_np = indices.detach().cpu().numpy().reshape(-1)

        list_of_token_label_tuple = list(zip(flatten_tokens, [ self.id2label[idx] for idx in indices_np] ))
        merged_preds = self._merged_pred(preds=list_of_token_label_tuple, ids=ids)
        if self.remove_spiece:
            merged_preds = list(map(lambda x: (x[0].replace(SPIECE, ''), x[1]), merged_preds))
       
        # remove start and end tokens
        merged_preds_removed_bos_eos = merged_preds[1:-1]
        # convert to list of Dict objects
        merged_preds_return_dict = [{'word': word if word != self.space_token else ' ', 'entity': tag, 'index': idx}
                                    for idx, (word, tag) in enumerate(merged_preds_removed_bos_eos)]

        if (not self.group_entities or self.scheme is None) and self.strict:
            return merged_preds_return_dict
        elif not self.group_entities and not self.strict:

            tags = list(map(lambda x: x['entity'], merged_preds_return_dict))
            processed_tags = self._fix_incorrect_tags(tags)
            for i, item in enumerate(merged_preds_return_dict):
                merged_preds_return_dict[i]['entity'] = processed_tags[i]
            return merged_preds_return_dict
        elif self.group_entities:
            return self._group_entities(merged_preds_removed_bos_eos)

    def __call__(self, inputs: Union[str, List[str]]):

        """     
            
        """
        if type(inputs) == str:
            return self._inference(inputs)
        
        if type(inputs) == list:
            results = [ self._inference(text) for text in inputs]
            return results
       

    def _merged_pred(self, preds: List[Tuple[str, str]], ids: List[List[int]]):
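        """Merge subword-level predictions back into the pretokenized words.

        ``ids`` preserves the per-word grouping of subwords; each merged word is
        assigned the label predicted for its first subword.
        """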
    
        token_mapping = [ ]
        for i in range(0, len(ids)):
            for j in range(0, len(ids[i])):
                token_mapping.append(i)

        grouped_subtokens = []
        _subtoken = []
        prev_idx = 0
    
        for i, (subtoken, label) in enumerate(preds):
            
            current_idx =  token_mapping[i]
            if prev_idx != current_idx:
                # Word boundary reached: flush the previous group and start a new one.
                grouped_subtokens.append(_subtoken)
                _subtoken = [(subtoken, label)]
                if i == len(preds) - 1:
                    grouped_subtokens.append(_subtoken)
            elif i == len(preds) -1:
                _subtoken += [(subtoken, label)]
                grouped_subtokens.append(_subtoken)
            else:
                _subtoken += [(subtoken, label)]
            prev_idx = current_idx
        
        merged_subtokens = []
        _merged_subtoken = ''
        for subtoken_group in grouped_subtokens:
            
            first_token_pred = subtoken_group[0][1]
            _merged_subtoken = ''.join(list(map(lambda x: x[0], subtoken_group)))
            merged_subtokens.append((_merged_subtoken, first_token_pred))
        return merged_subtokens

    def _fix_incorrect_tags(self, tags: List[str]) -> List[str]:
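        """Repair tag sequences that violate the tagging scheme.

        An I-/E- tag that starts the sequence, follows an O tag, or follows an
        entity of a different type gets its prefix rewritten to B-, e.g.
        ['I-LOC', 'I-LOC', 'O'] becomes ['B-LOC', 'I-LOC', 'O'].
        """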

        I_PREFIX = f'I{self.tag_delimiter}'
        E_PREFIX = f'E{self.tag_delimiter}'
        B_PREFIX = f'B{self.tag_delimiter}'
        O_PREFIX = 'O'
    
        previous_tag_ne = None
        for i, current_tag in enumerate(tags):
            
            current_tag_ne = current_tag.split(self.tag_delimiter)[-1] if current_tag != O_PREFIX else O_PREFIX
            
            if i == 0 and (current_tag.startswith(I_PREFIX) or \
                current_tag.startswith(E_PREFIX)):
                # if an NE tag (with I- or E- prefix) occurs at the beginning of the sentence
                # e.g. (I-LOC, I-LOC) , (E-LOC, B-PER) (I-LOC, O, O)
                # then, change the prefix of the current tag to B{tag_delimiter}
                tags[i] = B_PREFIX + tags[i][2:]
            elif i >= 1 and tags[i-1] == O_PREFIX and (
                current_tag.startswith(I_PREFIX) or \
                current_tag.startswith(E_PREFIX)):
                # if an NE tag (with I- or E- prefix) occurs after an O tag
                # e.g. (O, I-LOC, I-LOC) , (O, E-LOC, B-PER) (O, I-LOC, O, O)
                # then, change the prefix of the current tag to B{tag_delimiter}
                tags[i] = B_PREFIX + tags[i][2:]
            elif i >= 1 and ( tags[i-1].startswith(I_PREFIX) or \
                tags[i-1].startswith(E_PREFIX) or \
                tags[i-1].startswith(B_PREFIX)) and \
                ( current_tag.startswith(I_PREFIX) or current_tag.startswith(E_PREFIX) )  and \
                previous_tag_ne != current_tag_ne:
                # if an NE tag (with I- or E- prefix) occurs after an NE tag of a different entity type
                # e.g. (B-LOC, I-PER) , (B-LOC, E-LOC, E-PER) (B-LOC, I-LOC, I-PER)
                # then, change the prefix of the current tag to B{tag_delimiter}
                tags[i] = B_PREFIX + tags[i][2:]
            elif i == len(tags) - 1 and tags[i-1] == O_PREFIX and (
                current_tag.startswith(I_PREFIX) or \
                current_tag.startswith(E_PREFIX)):
                # if an NE tag (with I- or E- prefix) occurs at the end of the sentence
                # e.g. (O, O, I-LOC)  , (O, O, E-LOC) 
                # then, change the prefix of the current tag to B{tag_delimiter}
                tags[i] = B_PREFIX + tags[i][2:]

            previous_tag_ne = current_tag_ne
        
        return tags

    def _group_entities(self, ner_tags: List[Tuple[str, str]]) -> List[dict]:
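        """Group word-level (token, tag) pairs into entity spans.

        IOBE/IOBES tags are first normalized to IOB2, invalid sequences are
        repaired unless ``strict`` is set, and each entity (plus the plain 'O'
        spans between entities) is returned as a dict containing the entity
        type, the surface text and, for entities, the beginning character index.
        """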
        
        if self.scheme not in ['IOB', 'IOBES', 'IOBE']:
            raise AttributeError(f"Unsupported tagging scheme: {self.scheme}; expected 'IOB', 'IOBE' or 'IOBES'.")

        tokens, tags = zip(*ner_tags)
        tokens, tags = list(tokens), list(tags)

        if self.scheme == 'IOBE':
            # Replace E prefix with I prefix
            tags = list(map(lambda x: x.replace(f'E{self.tag_delimiter}', f'I{self.tag_delimiter}'), tags))
        if self.scheme == 'IOBES':
            # Replace E prefix with I prefix and replace S prefix with B
            tags = list(map(lambda x: x.replace(f'E{self.tag_delimiter}', f'I{self.tag_delimiter}'), tags))
            tags = list(map(lambda x: x.replace(f'S{self.tag_delimiter}', f'B{self.tag_delimiter}'), tags))

        if not self.strict:
            
            tags = self._fix_incorrect_tags(tags)
            
        ent = Tokens(tokens=tags, scheme=IOB2,
                     suffix=False, delimiter=self.tag_delimiter)

        ne_position_mappings = ent.entities
        # Normalize tokens back to surface text: restore spaces and recompose the
        # decomposed sara am (nikhahit + sara aa) into the single character 'ำ'.
        tokens = list(map(lambda x: x.replace(self.space_token, ' ').replace('ํา', 'ำ'), tokens))
        # Character offsets (begin, end) of each pretokenized word.
        token_positions = []
        curr_len = 0
        for token in tokens:
            token_positions.append((curr_len, curr_len + len(token)))
            curr_len += len(token)
        # Token-index and character-index spans for every detected entity.
        begin_end_pos = []
        begin_end_char_pos = []
        for ne_position_mapping in ne_position_mappings:
            begin_end_pos.append((ne_position_mapping.start, ne_position_mapping.end))
            begin_end_char_pos.append((token_positions[ne_position_mapping.start][0],
                                       token_positions[ne_position_mapping.end - 1][1]))

        # Insert plain 'O' spans, as (sent_id, tag, start, end) tuples, covering the
        # gap before the first entity and the gaps between consecutive entities, so
        # that non-entity tokens are also emitted as groups.
        j = 0
        for i, pos_tuple in enumerate(begin_end_pos):
            if pos_tuple[0] > 0 and i == 0:
                ne_position_mappings.insert(0, (None, 'O', 0, pos_tuple[0]))
                j += 1
            if begin_end_pos[i-1][1] != begin_end_pos[i][0] and len(begin_end_pos) > 1 and i > 0:
                ne_position_mappings.insert(j, (None, 'O', begin_end_pos[i-1][1], begin_end_pos[i][0]))
                j += 1

            j += 1

        groups = []
        k = 0
        for ne_position_mapping in ne_position_mappings:
            # seqeval Entity objects are converted to (sent_id, tag, start, end) tuples.
            if not isinstance(ne_position_mapping, tuple):
                ne_position_mapping = ne_position_mapping.to_tuple()
            ne = ne_position_mapping[1]
            
            text = ''
            for ne_position in range(ne_position_mapping[2], ne_position_mapping[3]):
                _token = tokens[ne_position]
                text += _token if _token != self.space_token else ' '
            if ne.lower() != 'o':
                groups.append({
                    'entity_group': ne,
                    'word': text,
                    'begin_char_index': begin_end_char_pos[k][0]
                })
                k+=1
            else:
                groups.append({
                    'entity_group': ne,
                    'word': text,
                })
        return groups
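

# Minimal usage sketch (illustrative only): the checkpoint name below is a
# placeholder, not a real model id; substitute a Thai token-classification
# checkpoint and its matching tokenizer.
if __name__ == '__main__':
    from transformers import AutoModelForTokenClassification, AutoTokenizer

    model_name = 'path/to/thai-ner-checkpoint'  # placeholder
    ner_model = AutoModelForTokenClassification.from_pretrained(model_name)
    ner_tokenizer = AutoTokenizer.from_pretrained(model_name)

    ner_pipeline = TokenClassificationPipeline(model=ner_model,
                                               tokenizer=ner_tokenizer,
                                               scheme='IOB',
                                               group_entities=True)
    # 'Chiang Mai is a province of Thailand.'
    print(ner_pipeline('เชียงใหม่เป็นจังหวัดหนึ่งของประเทศไทย'))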