import torch

from typing import Callable, List, Tuple, Union
from functools import partial
import itertools

from seqeval.scheme import Tokens, IOB2, IOBES

from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizerBase

from pythainlp.tokenize import word_tokenize as pythainlp_word_tokenize

newmm_word_tokenizer = partial(pythainlp_word_tokenize, keep_whitespace=True, engine='newmm')

from thai2transformers.preprocess import rm_useless_spaces

SPIECE = '▁'


class TokenClassificationPipeline:
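    """Token classification (e.g. NER) pipeline for Thai text.

    The input text is pre-tokenized into words (by default with PyThaiNLP's
    newmm tokenizer), each word is split into subword tokens by the model's
    tokenizer, and the per-subword predictions are merged back to word level.
    Optionally, the word-level tags are grouped into entity spans.

    Args:
        model: fine-tuned token-classification model (PreTrainedModel).
        tokenizer: matching subword tokenizer (PreTrainedTokenizerBase).
        pretokenizer: callable that splits raw text into words.
        lowercase: lowercase the input before tokenization.
        space_token: placeholder token used to represent whitespace.
        device: CUDA device index, or -1 for CPU.
        group_entities: return grouped entity spans instead of per-word tags.
        strict: if False, attempt to repair invalid tag transitions.
        tag_delimiter: delimiter between tag prefix and entity type (e.g. 'B-PER').
        scheme: tagging scheme, one of 'IOB', 'IOBE', or 'IOBES'.
        use_crf: if True, call the model without `return_dict=True`.
        remove_spiece: strip the SentencePiece underline from merged words.
    """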

    def __init__(self,
                 model: PreTrainedModel,
                 tokenizer: PreTrainedTokenizerBase,
                 pretokenizer: Callable[[str], List[str]] = newmm_word_tokenizer,
                 lowercase: bool = False,
                 space_token: str = '<_>',
                 device: int = -1,
                 group_entities: bool = False,
                 strict: bool = False,
                 tag_delimiter: str = '-',
                 scheme: str = 'IOB',
                 use_crf: bool = False,
                 remove_spiece: bool = True):
        super().__init__()

        assert isinstance(tokenizer, PreTrainedTokenizerBase)

        self.model = model
        self.tokenizer = tokenizer
        self.pretokenizer = pretokenizer
        self.lowercase = lowercase
        self.space_token = space_token
        # fall back to CPU when device == -1 or no GPU is available
        self.device = 'cpu' if device == -1 or not torch.cuda.is_available() else f'cuda:{device}'
        self.group_entities = group_entities
        self.strict = strict
        self.tag_delimiter = tag_delimiter
        self.scheme = scheme
        self.id2label = self.model.config.id2label
        self.label2id = self.model.config.label2id
        self.use_crf = use_crf
        self.remove_spiece = remove_spiece
        self.model.to(self.device)

    def preprocess(self, inputs: Union[str, List[str]]) -> Union[List[str], List[List[str]]]:
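        """Pre-tokenize raw text into words.

        Lowercases the input (if enabled), removes redundant spaces, splits
        the text into words with the pretokenizer, and replaces literal
        spaces with the configured space token. Accepts a single string or a
        list of strings and returns the corresponding list(s) of words.
        """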
        if self.lowercase:
            inputs = inputs.lower() if isinstance(inputs, str) else list(map(str.lower, inputs))

        inputs = rm_useless_spaces(inputs) if isinstance(inputs, str) else list(map(rm_useless_spaces, inputs))

        tokens = self.pretokenizer(inputs) if isinstance(inputs, str) else list(map(self.pretokenizer, inputs))

        # replace literal spaces with the configured space placeholder token
        tokens = list(map(lambda x: x.replace(' ', self.space_token), tokens)) if isinstance(inputs, str) else \
                 list(map(lambda _tokens: list(map(lambda x: x.replace(' ', self.space_token), _tokens)), tokens))

        return tokens

    def _inference(self, text: str):
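        """Run the model on a single string and return word-level predictions.

        The text is pre-tokenized, wrapped in BOS/EOS tokens, split into
        subword ids, and classified in one forward pass. Subword predictions
        are merged back to word level (keeping the tag of the first subword)
        and, depending on `group_entities` and `strict`, are either returned
        directly, repaired, or grouped into entity spans.
        """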
        tokens = [[self.tokenizer.bos_token]] + \
                 [self.tokenizer.tokenize(tok) if tok != SPIECE else [SPIECE] for tok in self.preprocess(text)] + \
                 [[self.tokenizer.eos_token]]
        ids = [self.tokenizer.convert_tokens_to_ids(token) for token in tokens]
        flatten_tokens = list(itertools.chain(*tokens))
        flatten_ids = list(itertools.chain(*ids))

        input_ids = torch.LongTensor([flatten_ids]).to(self.device)

        if self.use_crf:
            out = self.model(input_ids=input_ids)
        else:
            out = self.model(input_ids=input_ids, return_dict=True)

        # take the argmax label for every subword position
        probs = torch.softmax(out['logits'], dim=-1)
        vals, indices = probs.topk(1)
        indices_np = indices.detach().cpu().numpy().reshape(-1)

        list_of_token_label_tuple = list(zip(flatten_tokens,
                                             [self.id2label[idx] for idx in indices_np]))
        merged_preds = self._merged_pred(preds=list_of_token_label_tuple, ids=ids)
        if self.remove_spiece:
            merged_preds = list(map(lambda x: (x[0].replace(SPIECE, ''), x[1]), merged_preds))

        # drop the BOS and EOS predictions
        merged_preds_removed_bos_eos = merged_preds[1:-1]

        merged_preds_return_dict = [{'word': word if word != self.space_token else ' ',
                                     'entity': tag,
                                     'index': idx}
                                    for idx, (word, tag) in enumerate(merged_preds_removed_bos_eos)]

        if (not self.group_entities or self.scheme is None) and self.strict:
            return merged_preds_return_dict
        elif not self.group_entities and not self.strict:
            # repair invalid tag transitions before returning word-level predictions
            tags = list(map(lambda x: x['entity'], merged_preds_return_dict))
            processed_tags = self._fix_incorrect_tags(tags)
            for i, item in enumerate(merged_preds_return_dict):
                merged_preds_return_dict[i]['entity'] = processed_tags[i]
            return merged_preds_return_dict
        elif self.group_entities:
            return self._group_entities(merged_preds_removed_bos_eos)

    def __call__(self, inputs: Union[str, List[str]]):
        """Tag a single string or a list of strings.

        Returns word-level predictions (or grouped entity spans when
        `group_entities` is True) for each input text.
        """
        if isinstance(inputs, str):
            return self._inference(inputs)

        if isinstance(inputs, list):
            results = [self._inference(text) for text in inputs]
            return results

    def _merged_pred(self, preds: List[Tuple[str, str]], ids: List[List[int]]):
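        """Merge subword-level predictions back into word-level predictions.

        `ids` holds one list of subword ids per original word and is used to
        map each subword prediction in `preds` back to its source word. The
        merged word keeps the tag predicted for its first subword.
        """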
        # map every flattened subword index back to the index of its source word
        token_mapping = []
        for i in range(0, len(ids)):
            for j in range(0, len(ids[i])):
                token_mapping.append(i)

        grouped_subtokens = []
        _subtoken = []
        prev_idx = 0

        for i, (subtoken, label) in enumerate(preds):

            current_idx = token_mapping[i]
            if prev_idx != current_idx:
                grouped_subtokens.append(_subtoken)
                _subtoken = [(subtoken, label)]
                if i == len(preds) - 1:
                    grouped_subtokens.append(_subtoken)
            elif i == len(preds) - 1:
                _subtoken += [(subtoken, label)]
                grouped_subtokens.append(_subtoken)
            else:
                _subtoken += [(subtoken, label)]
            prev_idx = current_idx

        merged_subtokens = []
        for subtoken_group in grouped_subtokens:
            # the merged word keeps the label of its first subword
            first_token_pred = subtoken_group[0][1]
            _merged_subtoken = ''.join(list(map(lambda x: x[0], subtoken_group)))
            merged_subtokens.append((_merged_subtoken, first_token_pred))
        return merged_subtokens

    def _fix_incorrect_tags(self, tags: List[str]) -> List[str]:
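        """Repair invalid tag transitions in-place.

        An I- or E- tag that starts the sequence, follows an O tag, or follows
        a tag with a different entity type is rewritten as a B- tag, e.g.
        ['O', 'I-PER', 'I-PER'] becomes ['O', 'B-PER', 'I-PER'].
        """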
        I_PREFIX = f'I{self.tag_delimiter}'
        E_PREFIX = f'E{self.tag_delimiter}'
        B_PREFIX = f'B{self.tag_delimiter}'
        O_TAG = 'O'

        previous_tag_ne = None
        for i, current_tag in enumerate(tags):

            current_tag_ne = current_tag.split(self.tag_delimiter)[-1] if current_tag != O_TAG else O_TAG

            if i == 0 and (current_tag.startswith(I_PREFIX) or
                           current_tag.startswith(E_PREFIX)):
                # an entity can not begin with an I- or E- tag
                tags[i] = B_PREFIX + tags[i][2:]
            elif i >= 1 and tags[i - 1] == O_TAG and (
                    current_tag.startswith(I_PREFIX) or
                    current_tag.startswith(E_PREFIX)):
                # an I- or E- tag directly after O starts a new entity
                tags[i] = B_PREFIX + tags[i][2:]
            elif i >= 1 and (tags[i - 1].startswith(I_PREFIX) or
                             tags[i - 1].startswith(E_PREFIX) or
                             tags[i - 1].startswith(B_PREFIX)) and \
                    (current_tag.startswith(I_PREFIX) or current_tag.startswith(E_PREFIX)) and \
                    previous_tag_ne != current_tag_ne:
                # the entity type changed without a B- tag in between
                tags[i] = B_PREFIX + tags[i][2:]
            elif i == len(tags) - 1 and tags[i - 1] == O_TAG and (
                    current_tag.startswith(I_PREFIX) or
                    current_tag.startswith(E_PREFIX)):
                tags[i] = B_PREFIX + tags[i][2:]

            previous_tag_ne = current_tag_ne

        return tags

    def _group_entities(self, ner_tags: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
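        """Group word-level tags into entity spans.

        IOBE/IOBES tags are first mapped onto IOB2, invalid transitions are
        optionally repaired (when `strict` is False), and seqeval's `Tokens`
        is used to extract entity spans. Returns a list of dicts with the
        entity group, the surface text and, for non-O spans, the index of the
        first character of the span.
        """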
        if self.scheme not in ['IOB', 'IOBES', 'IOBE']:
            raise AttributeError('scheme must be one of "IOB", "IOBE" or "IOBES"')

        tokens, tags = zip(*ner_tags)
        tokens, tags = list(tokens), list(tags)

        # map IOBE / IOBES tags onto IOB2 so that seqeval can group them
        if self.scheme == 'IOBE':
            tags = list(map(lambda x: x.replace(f'E{self.tag_delimiter}', f'I{self.tag_delimiter}'), tags))
        if self.scheme == 'IOBES':
            tags = list(map(lambda x: x.replace(f'E{self.tag_delimiter}', f'I{self.tag_delimiter}'), tags))
            tags = list(map(lambda x: x.replace(f'S{self.tag_delimiter}', f'B{self.tag_delimiter}'), tags))

        if not self.strict:
            tags = self._fix_incorrect_tags(tags)

        ent = Tokens(tokens=tags, scheme=IOB2,
                     suffix=False, delimiter=self.tag_delimiter)

        ne_position_mappings = ent.entities

        # character offsets of every word, with spaces restored and the
        # decomposed sara am (nikhahit + sara aa) normalized to 'ำ'
        token_positions = []
        curr_len = 0
        tokens = list(map(lambda x: x.replace(self.space_token, ' ').replace('ํา', 'ำ'), tokens))
        for i, token in enumerate(tokens):
            token_len = len(token)
            token_positions.append((curr_len, curr_len + token_len))
            curr_len += token_len

        # word-index and character-index spans of every detected entity
        begin_end_pos = []
        begin_end_char_pos = []
        for i, ne_position_mapping in enumerate(ne_position_mappings):
            begin_end_pos.append((ne_position_mapping.start, ne_position_mapping.end))
            begin_end_char_pos.append((token_positions[ne_position_mapping.start][0],
                                       token_positions[ne_position_mapping.end - 1][1]))

        # insert O spans for the gaps before and between detected entities
        j = 0
        for i, pos_tuple in enumerate(begin_end_pos):
            if pos_tuple[0] > 0 and i == 0:
                ne_position_mappings.insert(0, (None, 'O', 0, pos_tuple[0]))
                j += 1
            if begin_end_pos[i - 1][1] != begin_end_pos[i][0] and len(begin_end_pos) > 1 and i > 0:
                ne_position_mappings.insert(j, (None, 'O', begin_end_pos[i - 1][1], begin_end_pos[i][0]))
                j += 1
            j += 1

        groups = []
        k = 0
        for i, ne_position_mapping in enumerate(ne_position_mappings):
            if not isinstance(ne_position_mapping, tuple):
                ne_position_mapping = ne_position_mapping.to_tuple()
            ne = ne_position_mapping[1]

            text = ''
            for ne_position in range(ne_position_mapping[2], ne_position_mapping[3]):
                _token = tokens[ne_position]
                text += _token if _token != self.space_token else ' '
            if ne.lower() != 'o':
                groups.append({
                    'entity_group': ne,
                    'word': text,
                    'begin_char_index': begin_end_char_pos[k][0]
                })
                k += 1
            else:
                groups.append({
                    'entity_group': ne,
                    'word': text,
                })
        return groups
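

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; not part of the original module).
# The checkpoint name below is a placeholder: substitute a token-classification
# model fine-tuned with a label scheme matching this pipeline ('IOB', 'IOBE'
# or 'IOBES' tags joined by `tag_delimiter`).
#
#   from transformers import AutoModelForTokenClassification, AutoTokenizer
#
#   checkpoint = 'path/to/thai-ner-checkpoint'  # placeholder
#   model = AutoModelForTokenClassification.from_pretrained(checkpoint)
#   tokenizer = AutoTokenizer.from_pretrained(checkpoint)
#
#   ner = TokenClassificationPipeline(model=model,
#                                     tokenizer=tokenizer,
#                                     scheme='IOB',
#                                     group_entities=True)
#   ner('ผมอาศัยอยู่ที่จังหวัดเชียงใหม่')
# ---------------------------------------------------------------------------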