|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse
import unicodedata

import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

import pytorch_lightning as pl
from pytorch_lightning import loggers
from pytorch_lightning.callbacks import ModelCheckpoint

import transformers
from transformers import BertModel, BertPreTrainedModel, BertTokenizer
from transformers.optimization import get_linear_schedule_with_warmup
|
|
|
|
|
transformers.logging.set_verbosity_error() |
|
|
|
|
|
|
|
def search(pattern, sequence):
    """Return all [start, end] (inclusive) index pairs at which `pattern` occurs as a sub-list of `sequence`."""
    n = len(pattern)
|
res = [] |
|
for i in range(len(sequence)): |
|
if sequence[i:i + n] == pattern: |
|
res.append([i, i + n-1]) |
|
return res |
|
|
|
|
|
class UbertDataset(Dataset): |
|
def __init__(self, data, tokenizer, args, used_mask=True): |
|
super().__init__() |
|
self.tokenizer = tokenizer |
|
self.max_length = args.max_length |
|
self.num_labels = args.num_labels |
|
self.used_mask = used_mask |
|
self.data = data |
|
self.args = args |
|
|
|
def __len__(self): |
|
return len(self.data) |
|
|
|
def __getitem__(self, index): |
|
return self.encode(self.data[index], self.used_mask) |
|
|
|
def encode(self, item, used_mask=False): |
|
input_ids1 = [] |
|
attention_mask1 = [] |
|
token_type_ids1 = [] |
|
span_labels1 = [] |
|
span_labels_masks1 = [] |
|
|
|
input_ids0 = [] |
|
attention_mask0 = [] |
|
token_type_ids0 = [] |
|
span_labels0 = [] |
|
span_labels_masks0 = [] |
|
|
|
subtask_type = item['subtask_type'] |
|
for choice in item['choices']: |
|
try: |
|
texta = item['task_type'] + '[SEP]' + \ |
|
subtask_type + '[SEP]' + choice['entity_type'] |
|
textb = item['text'] |
|
encode_dict = self.tokenizer.encode_plus(texta, textb, |
|
max_length=self.max_length, |
|
padding='max_length', |
|
truncation='longest_first') |
|
|
|
encode_sent = encode_dict['input_ids'] |
|
encode_token_type_ids = encode_dict['token_type_ids'] |
|
encode_attention_mask = encode_dict['attention_mask'] |
|
span_label = np.zeros((self.max_length, self.max_length)) |
|
span_label_mask = np.zeros( |
|
(self.max_length, self.max_length))-10000 |
|
|
|
                if item['task_type'] == '分类任务':  # classification task
|
span_label_mask[0, 0] = 0 |
|
span_label[0, 0] = choice['label'] |
|
|
|
else: |
|
question_len = len(self.tokenizer.encode(texta)) |
|
span_label_mask[question_len:, question_len:] = np.zeros( |
|
(self.max_length-question_len, self.max_length-question_len)) |
|
for entity in choice['entity_list']: |
|
|
|
|
|
entity_idx_list = entity['entity_idx'] |
|
if entity_idx_list == []: |
|
continue |
|
for entity_idx in entity_idx_list: |
|
if entity_idx == []: |
|
continue |
|
start_idx_text = item['text'][:entity_idx[0]] |
|
start_idx_text_encode = self.tokenizer.encode( |
|
start_idx_text, add_special_tokens=False) |
|
start_idx = question_len + \ |
|
len(start_idx_text_encode) |
|
|
|
end_idx_text = item['text'][:entity_idx[1]+1] |
|
end_idx_text_encode = self.tokenizer.encode( |
|
end_idx_text, add_special_tokens=False) |
|
end_idx = question_len + \ |
|
len(end_idx_text_encode) - 1 |
|
if start_idx < self.max_length and end_idx < self.max_length: |
|
span_label[start_idx, end_idx] = 1 |
|
|
|
if np.sum(span_label) < 1: |
|
input_ids0.append(encode_sent) |
|
attention_mask0.append(encode_attention_mask) |
|
token_type_ids0.append(encode_token_type_ids) |
|
span_labels0.append(span_label) |
|
span_labels_masks0.append(span_label_mask) |
|
else: |
|
input_ids1.append(encode_sent) |
|
attention_mask1.append(encode_attention_mask) |
|
token_type_ids1.append(encode_token_type_ids) |
|
span_labels1.append(span_label) |
|
span_labels_masks1.append(span_label_mask) |
|
            except Exception:
                # Report the offending sample instead of silently swallowing every error type.
                print('UbertDataset.encode failed for item:', item)
                print(texta)
                print(textb)
|
|
|
randomize = np.arange(len(input_ids0)) |
|
np.random.shuffle(randomize) |
|
cur = 0 |
|
count = len(input_ids1) |
|
while count < self.args.num_labels: |
|
if cur < len(randomize): |
|
input_ids1.append(input_ids0[randomize[cur]]) |
|
attention_mask1.append(attention_mask0[randomize[cur]]) |
|
token_type_ids1.append(token_type_ids0[randomize[cur]]) |
|
span_labels1.append(span_labels0[randomize[cur]]) |
|
span_labels_masks1.append(span_labels_masks0[randomize[cur]]) |
|
cur += 1 |
|
count += 1 |
|
|
|
while len(input_ids1) < self.args.num_labels: |
|
input_ids1.append([0]*self.max_length) |
|
attention_mask1.append([0]*self.max_length) |
|
token_type_ids1.append([0]*self.max_length) |
|
span_labels1.append(np.zeros((self.max_length, self.max_length))) |
|
span_labels_masks1.append( |
|
np.zeros((self.max_length, self.max_length))-10000) |
|
|
|
input_ids = input_ids1[:self.args.num_labels] |
|
attention_mask = attention_mask1[:self.args.num_labels] |
|
token_type_ids = token_type_ids1[:self.args.num_labels] |
|
span_labels = span_labels1[:self.args.num_labels] |
|
span_labels_masks = span_labels_masks1[:self.args.num_labels] |
|
|
|
span_labels = np.array(span_labels) |
|
span_labels_masks = np.array(span_labels_masks) |
|
if np.sum(span_labels) < 1: |
|
span_labels[-1, -1, -1] = 1 |
|
span_labels_masks[-1, -1, -1] = 10000 |
|
|
|
sample = { |
|
"input_ids": torch.tensor(input_ids).long(), |
|
"token_type_ids": torch.tensor(token_type_ids).long(), |
|
"attention_mask": torch.tensor(attention_mask).float(), |
|
"span_labels": torch.tensor(span_labels).float(), |
|
"span_labels_mask": torch.tensor(span_labels_masks).float() |
|
} |
|
|
|
return sample |
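
    # Each sample packs the item's `num_labels` (question, text) pairs together:
    #   input_ids / token_type_ids / attention_mask: (num_labels, max_length)
    #   span_labels / span_labels_mask:              (num_labels, max_length, max_length)
    # so DataLoader batches have shape (batch, num_labels, ...), which UbertModel.forward expects.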
|
|
|
|
|
class UbertDataModel(pl.LightningDataModule): |
|
@staticmethod |
|
def add_data_specific_args(parent_args): |
|
        parser = parent_args.add_argument_group('UbertDataModel')
|
parser.add_argument('--num_workers', default=8, type=int) |
|
parser.add_argument('--batchsize', default=8, type=int) |
|
parser.add_argument('--max_length', default=128, type=int) |
|
return parent_args |
|
|
|
def __init__(self, train_data, val_data, tokenizer, args): |
|
super().__init__() |
|
self.batchsize = args.batchsize |
|
|
|
self.train_data = UbertDataset(train_data, tokenizer, args, True) |
|
self.valid_data = UbertDataset(val_data, tokenizer, args, False) |
|
|
|
def train_dataloader(self): |
|
return DataLoader(self.train_data, shuffle=True, batch_size=self.batchsize, pin_memory=False) |
|
|
|
def val_dataloader(self): |
|
return DataLoader(self.valid_data, shuffle=False, batch_size=self.batchsize, pin_memory=False) |
|
|
|
|
|
class biaffine(nn.Module): |
|
def __init__(self, in_size, out_size, bias_x=True, bias_y=True): |
|
super().__init__() |
|
self.bias_x = bias_x |
|
self.bias_y = bias_y |
|
self.out_size = out_size |
|
self.U = torch.nn.Parameter(torch.zeros( |
|
in_size + int(bias_x), out_size, in_size + int(bias_y))) |
|
torch.nn.init.normal_(self.U, mean=0, std=0.1) |
|
|
|
def forward(self, x, y): |
|
if self.bias_x: |
|
x = torch.cat((x, torch.ones_like(x[..., :1])), dim=-1) |
|
if self.bias_y: |
|
y = torch.cat((y, torch.ones_like(y[..., :1])), dim=-1) |
|
        bilinear_mapping = torch.einsum('bxi,ioj,byj->bxyo', x, self.U, y)
        return bilinear_mapping
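
    # Shape sketch (illustrative, not from the original file): for inputs x, y of shape
    # (batch, seq_len, in_size), the einsum above yields scores of shape
    # (batch, seq_len, seq_len, out_size) -- one score per (start, end) token pair and channel, e.g.
    #   biaffine(in_size=8, out_size=1)(torch.randn(2, 5, 8), torch.randn(2, 5, 8)).shape
    #   == torch.Size([2, 5, 5, 1])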
|
|
|
|
|
class MultilabelCrossEntropy(nn.Module): |
|
def __init__(self): |
|
super().__init__() |
|
|
|
def forward(self, y_pred, y_true): |
|
y_true = y_true.float() |
|
y_pred = torch.mul((1.0 - torch.mul(y_true, 2.0)), y_pred) |
|
y_pred_neg = y_pred - torch.mul(y_true, 1e12) |
|
y_pred_pos = y_pred - torch.mul(1.0 - y_true, 1e12) |
|
zeros = torch.zeros_like(y_pred[..., :1]) |
|
y_pred_neg = torch.cat([y_pred_neg, zeros], axis=-1) |
|
y_pred_pos = torch.cat([y_pred_pos, zeros], axis=-1) |
|
neg_loss = torch.logsumexp(y_pred_neg, axis=-1) |
|
pos_loss = torch.logsumexp(y_pred_pos, axis=-1) |
|
loss = torch.mean(neg_loss + pos_loss) |
|
return loss |
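
    # The forward above is the logit-space "multilabel categorical cross-entropy":
    #   loss = log(1 + sum_{i in negatives} exp(s_i)) + log(1 + sum_{j in positives} exp(-s_j)),
    # where the appended zero column provides the "+1" inside each logsumexp.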
|
|
|
|
|
class UbertModel(BertPreTrainedModel): |
|
|
|
def __init__(self, config): |
|
super().__init__(config) |
|
self.bert = BertModel(config) |
|
self.query_layer = torch.nn.Sequential(torch.nn.Linear(in_features=self.config.hidden_size, |
|
out_features=self.config.biaffine_size), |
|
torch.nn.GELU()) |
|
self.key_layer = torch.nn.Sequential(torch.nn.Linear(in_features=self.config.hidden_size, out_features=self.config.biaffine_size), |
|
torch.nn.GELU()) |
|
self.biaffine_query_key_cls = biaffine(self.config.biaffine_size, 1) |
|
self.loss_softmax = MultilabelCrossEntropy() |
|
self.loss_sigmoid = torch.nn.BCEWithLogitsLoss(reduction='mean') |
|
|
|
def forward(self, |
|
input_ids, |
|
attention_mask, |
|
token_type_ids, |
|
span_labels=None, |
|
span_labels_mask=None): |
|
|
|
batch_size, num_label, seq_len = input_ids.shape |
|
|
|
input_ids = input_ids.view(-1, seq_len) |
|
attention_mask = attention_mask.view(-1, seq_len) |
|
token_type_ids = token_type_ids.view(-1, seq_len) |
|
|
|
batch_size, seq_len = input_ids.shape |
|
outputs = self.bert(input_ids=input_ids, |
|
attention_mask=attention_mask, |
|
token_type_ids=token_type_ids, |
|
output_hidden_states=True) |
|
|
|
hidden_states = outputs[0] |
|
batch_size, seq_len, hidden_size = hidden_states.shape |
|
|
|
query = self.query_layer(hidden_states) |
|
key = self.key_layer(hidden_states) |
|
|
|
span_logits = self.biaffine_query_key_cls( |
|
query, key).reshape(-1, num_label, seq_len, seq_len) |
|
|
|
span_logits = span_logits + span_labels_mask |
|
|
|
        if span_labels is None:
|
return 0, span_logits |
|
else: |
|
soft_loss1 = self.loss_softmax( |
|
span_logits.reshape(-1, num_label, seq_len*seq_len), span_labels.reshape(-1, num_label, seq_len*seq_len)) |
|
soft_loss2 = self.loss_softmax(span_logits.permute( |
|
0, 2, 3, 1), span_labels.permute(0, 2, 3, 1)) |
|
sig_loss = self.loss_sigmoid(span_logits, span_labels) |
|
all_loss = 10*(100*sig_loss+soft_loss1+soft_loss2) |
|
return all_loss, span_logits |
|
|
|
|
|
class UbertLitModel(pl.LightningModule): |
|
@staticmethod |
|
def add_model_specific_args(parent_args): |
|
        parser = parent_args.add_argument_group('UbertLitModel')
|
|
|
parser.add_argument('--learning_rate', default=1e-5, type=float) |
|
parser.add_argument('--weight_decay', default=0.1, type=float) |
|
parser.add_argument('--warmup', default=0.01, type=float) |
|
parser.add_argument('--num_labels', default=10, type=int) |
|
|
|
return parent_args |
|
|
|
def __init__(self, args, num_data=1): |
|
super().__init__() |
|
self.args = args |
|
self.num_data = num_data |
|
self.model = UbertModel.from_pretrained( |
|
self.args.pretrained_model_path) |
|
self.count = 0 |
|
|
|
def setup(self, stage) -> None: |
|
if stage == 'fit': |
|
num_gpus = self.trainer.gpus if self.trainer.gpus is not None else 0 |
|
self.total_step = int(self.trainer.max_epochs * self.num_data / |
|
(max(1, num_gpus) * self.trainer.accumulate_grad_batches)) |
|
print('Total training step:', self.total_step) |
|
|
|
def training_step(self, batch, batch_idx): |
|
loss, span_logits = self.model(**batch) |
|
span_acc, recall, precise = self.comput_metrix_span( |
|
span_logits, batch['span_labels']) |
|
self.log('train_loss', loss) |
|
self.log('train_span_acc', span_acc) |
|
self.log('train_span_recall', recall) |
|
self.log('train_span_precise', precise) |
|
|
|
return loss |
|
|
|
def validation_step(self, batch, batch_idx): |
|
loss, span_logits = self.model(**batch) |
|
span_acc, recall, precise = self.comput_metrix_span( |
|
span_logits, batch['span_labels']) |
|
|
|
self.log('val_loss', loss) |
|
self.log('val_span_acc', span_acc) |
|
self.log('val_span_recall', recall) |
|
self.log('val_span_precise', precise) |
|
|
|
    def predict_step(self, batch, batch_idx):
        loss, span_logits = self.model(**batch)
        # comput_metrix_span returns (f1, recall, precision); report the span F1 for prediction.
        span_acc, recall, precise = self.comput_metrix_span(
            span_logits, batch['span_labels'])
        return span_acc.item()
|
|
|
def configure_optimizers(self): |
|
|
|
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] |
|
paras = list( |
|
filter(lambda p: p[1].requires_grad, self.named_parameters())) |
|
paras = [{ |
|
'params': |
|
[p for n, p in paras if not any(nd in n for nd in no_decay)], |
|
'weight_decay': self.args.weight_decay |
|
}, { |
|
'params': [p for n, p in paras if any(nd in n for nd in no_decay)], |
|
'weight_decay': 0.0 |
|
}] |
|
optimizer = torch.optim.AdamW(paras, lr=self.args.learning_rate) |
|
scheduler = get_linear_schedule_with_warmup( |
|
optimizer, int(self.total_step * self.args.warmup), |
|
self.total_step) |
|
|
|
return [{ |
|
'optimizer': optimizer, |
|
'lr_scheduler': { |
|
'scheduler': scheduler, |
|
'interval': 'step', |
|
'frequency': 1 |
|
} |
|
}] |
|
|
|
    def comput_metrix_span(self, logits, labels):
        # Binarize logits at 0 and return span-level (F1, recall, precision).
        ones = torch.ones_like(logits)
|
zero = torch.zeros_like(logits) |
|
logits = torch.where(logits < 0, zero, ones) |
|
y_pred = logits.view(size=(-1,)) |
|
y_true = labels.view(size=(-1,)) |
|
corr = torch.eq(y_pred, y_true).float() |
|
corr = torch.multiply(y_true, corr) |
|
recall = torch.sum(corr.float())/(torch.sum(y_true.float())+1e-5) |
|
precise = torch.sum(corr.float())/(torch.sum(y_pred.float())+1e-5) |
|
f1 = 2*recall*precise/(recall+precise+1e-5) |
|
return f1, recall, precise |
|
|
|
|
|
class TaskModelCheckpoint: |
|
@staticmethod |
|
def add_argparse_args(parent_args): |
|
        parser = parent_args.add_argument_group('TaskModelCheckpoint')
|
|
|
parser.add_argument('--monitor', default='train_loss', type=str) |
|
parser.add_argument('--mode', default='min', type=str) |
|
parser.add_argument('--checkpoint_path', |
|
default='./checkpoint/', type=str) |
|
parser.add_argument( |
|
'--filename', default='model-{epoch:02d}-{train_loss:.4f}', type=str) |
|
|
|
        parser.add_argument('--save_top_k', default=3, type=int)
        parser.add_argument('--every_n_epochs', default=1, type=int)
        parser.add_argument('--every_n_train_steps', default=100, type=int)
|
|
|
        # Note: argparse's type=bool treats any non-empty string as True.
        parser.add_argument('--save_weights_only', default=True, type=bool)
|
return parent_args |
|
|
|
def __init__(self, args): |
|
self.callbacks = ModelCheckpoint(monitor=args.monitor, |
|
save_top_k=args.save_top_k, |
|
mode=args.mode, |
|
save_last=True, |
|
every_n_train_steps=args.every_n_train_steps, |
|
save_weights_only=args.save_weights_only, |
|
dirpath=args.checkpoint_path, |
|
filename=args.filename) |
|
|
|
|
|
class OffsetMapping: |
|
def __init__(self): |
|
self._do_lower_case = True |
|
|
|
@staticmethod |
|
def stem(token): |
|
if token[:2] == '##': |
|
return token[2:] |
|
else: |
|
return token |
|
|
|
@staticmethod |
|
def _is_control(ch): |
|
return unicodedata.category(ch) in ('Cc', 'Cf') |
|
|
|
@staticmethod |
|
def _is_special(ch): |
|
return bool(ch) and (ch[0] == '[') and (ch[-1] == ']') |
|
|
|
def rematch(self, text, tokens): |
|
if self._do_lower_case: |
|
text = text.lower() |
|
|
|
normalized_text, char_mapping = '', [] |
|
for i, ch in enumerate(text): |
|
if self._do_lower_case: |
|
ch = unicodedata.normalize('NFD', ch) |
|
ch = ''.join( |
|
[c for c in ch if unicodedata.category(c) != 'Mn']) |
|
ch = ''.join([ |
|
c for c in ch |
|
if not (ord(c) == 0 or ord(c) == 0xfffd or self._is_control(c)) |
|
]) |
|
normalized_text += ch |
|
char_mapping.extend([i] * len(ch)) |
|
|
|
text, token_mapping, offset = normalized_text, [], 0 |
|
for token in tokens: |
|
if self._is_special(token): |
|
token_mapping.append([offset]) |
|
offset += 1 |
|
else: |
|
token = self.stem(token) |
|
start = text[offset:].index(token) + offset |
|
end = start + len(token) |
|
token_mapping.append(char_mapping[start:end]) |
|
offset = end |
|
|
|
return token_mapping |
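
    # rematch() returns, for each wordpiece token, the list of character offsets it covers in the
    # original text (after lower-casing and accent/control-character stripping); bracketed special
    # tokens such as [SEP] are assigned a single placeholder offset.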
|
|
|
|
|
class extractModel: |
|
''' |
|
    # In the version of the code submitted here, this method no longer needs to be called.
|
def get_actual_id(self, text, query_text, tokenizer, args): |
|
text_encode = tokenizer.encode(text) |
|
one_input_encode = tokenizer.encode(query_text) |
|
text_start_id = search(text_encode[1:-1], one_input_encode)[0][0] |
|
text_end_id = text_start_id+len(text_encode)-1 |
|
if text_end_id > args.max_length: |
|
text_end_id = args.max_length |
|
|
|
text_token = tokenizer.tokenize(text) |
|
text_mapping = OffsetMapping().rematch(text, text_token) |
|
|
|
return text_start_id, text_end_id, text_mapping, one_input_encode |
|
''' |
|
|
|
def extract_index(self, span_logits, sample_length, split_value=0.5): |
|
result = [] |
|
for i in range(sample_length): |
|
for j in range(i, sample_length): |
|
if span_logits[i, j] > split_value: |
|
result.append((i, j, span_logits[i, j])) |
|
return result |
|
|
|
def extract_entity(self, text, entity_idx, text_start_id, text_mapping): |
|
start_split = text_mapping[entity_idx[0]-text_start_id] if entity_idx[0] - \ |
|
text_start_id < len(text_mapping) and entity_idx[0]-text_start_id >= 0 else [] |
|
end_split = text_mapping[entity_idx[1]-text_start_id] if entity_idx[1] - \ |
|
text_start_id < len(text_mapping) and entity_idx[1]-text_start_id >= 0 else [] |
|
entity = '' |
|
if start_split != [] and end_split != []: |
|
entity = text[start_split[0]:end_split[-1]+1] |
|
return entity |
|
|
|
def extract(self, batch_data, model, tokenizer, args): |
|
input_ids = [] |
|
attention_mask = [] |
|
token_type_ids = [] |
|
span_labels_masks = [] |
|
|
|
for item in batch_data: |
|
input_ids0 = [] |
|
attention_mask0 = [] |
|
token_type_ids0 = [] |
|
span_labels_masks0 = [] |
|
for choice in item['choices']: |
|
texta = item['task_type'] + '[SEP]' + \ |
|
item['subtask_type'] + '[SEP]' + choice['entity_type'] |
|
textb = item['text'] |
|
encode_dict = tokenizer.encode_plus(texta, textb, |
|
max_length=args.max_length, |
|
padding='max_length', |
|
truncation='longest_first') |
|
|
|
encode_sent = encode_dict['input_ids'] |
|
encode_token_type_ids = encode_dict['token_type_ids'] |
|
encode_attention_mask = encode_dict['attention_mask'] |
|
span_label_mask = np.zeros( |
|
(args.max_length, args.max_length))-10000 |
|
|
|
                if item['task_type'] == '分类任务':  # classification task
|
span_label_mask[0, 0] = 0 |
|
else: |
|
question_len = len(tokenizer.encode(texta)) |
|
span_label_mask[question_len:, question_len:] = np.zeros( |
|
(args.max_length-question_len, args.max_length-question_len)) |
|
input_ids0.append(encode_sent) |
|
attention_mask0.append(encode_attention_mask) |
|
token_type_ids0.append(encode_token_type_ids) |
|
span_labels_masks0.append(span_label_mask) |
|
|
|
input_ids.append(input_ids0) |
|
attention_mask.append(attention_mask0) |
|
token_type_ids.append(token_type_ids0) |
|
span_labels_masks.append(span_labels_masks0) |
|
|
|
input_ids = torch.tensor(input_ids).to(model.device) |
|
attention_mask = torch.tensor(attention_mask).to(model.device) |
|
token_type_ids = torch.tensor(token_type_ids).to(model.device) |
|
|
|
|
|
span_labels_mask = torch.tensor(np.array(span_labels_masks)).to(model.device) |
|
|
|
_, span_logits = model.model(input_ids=input_ids, |
|
attention_mask=attention_mask, |
|
token_type_ids=token_type_ids, |
|
span_labels=None, |
|
span_labels_mask=span_labels_mask) |
|
|
|
|
|
span_logits = torch.sigmoid(span_logits) |
|
span_logits = span_logits.cpu().detach().numpy() |
|
|
|
for i, item in enumerate(batch_data): |
|
            if item['task_type'] == '分类任务':  # classification task
|
cls_idx = 0 |
|
max_c = np.argmax(span_logits[i, :, cls_idx, cls_idx]) |
|
batch_data[i]['choices'][max_c]['label'] = 1 |
|
batch_data[i]['choices'][max_c]['score'] = span_logits[i, |
|
max_c, cls_idx, cls_idx] |
|
else: |
|
|
|
                '''
                Efficiency improvements and bug fixes in this version:
                1. The flow was reorganized so that the call
                   "text_start_id, text_end_id, offset_mapping, input_ids = self.get_actual_id(item['text'], texta+'[SEP]'+textb, tokenizer, args)"
                   is no longer needed.
                2. item['text'] is encoded/tokenized only once per item, instead of being re-encoded for
                   every entry in item['choices'] as before.
                3. Fixed extractive reading comprehension failing to extract the text correctly when
                   item['choices'] contains multiple entities, and fixed misaligned extracted spans.
                4. For extractive reading comprehension, a top_k option can be set per choice in the
                   prediction data, e.g. {"entity_type": "***", "top_k": 2}; when top_k is omitted it
                   defaults to 1.
                5. For extraction subtasks other than extractive reading comprehension, results are
                   filtered by "entity_name" so that each entity_name is returned only once.
                '''
|
textb = item['text'] |
|
offset_mapping = OffsetMapping().rematch(textb, tokenizer.tokenize(textb)) |
|
|
|
input_ids = tokenizer.encode('[SEP]' + textb, |
|
max_length=args.max_length, |
|
truncation='longest_first') |
|
|
|
for c in range(len(item['choices'])): |
|
|
|
texta = item['task_type'] + '[SEP]' + item['subtask_type'] + \ |
|
'[SEP]' + item['choices'][c]['entity_type'] |
|
|
|
|
|
text_start_id = len(tokenizer.encode(texta)) |
|
|
|
logits = span_logits[i, c, :, :] |
|
|
|
entity_name_list = [] |
|
entity_list = [] |
|
                    if item['subtask_type'] == '抽取式阅读理解':  # extractive reading comprehension
|
|
|
try: |
|
|
|
top_k = int(item['choices'][c]['top_k']) |
|
except KeyError: |
|
|
|
top_k = 1 |
|
|
|
                        if top_k <= 0:
                            top_k = 1
|
|
|
_, top_indices = torch.topk(torch.flatten(torch.tensor(logits)), top_k) |
|
|
|
for top_idx in top_indices: |
|
|
|
max_index = np.unravel_index(top_idx, logits.shape) |
|
|
|
if logits[max_index] > args.threshold: |
|
|
|
entity = self.extract_entity( |
|
item['text'], (max_index[0], max_index[1]), text_start_id, offset_mapping) |
|
|
|
entity = { |
|
'entity_name': entity, |
|
'score': logits[max_index] |
|
} |
|
|
|
entity_list.append(entity) |
|
else: |
|
|
|
                        # Clamp to max_length so span_logits (max_length x max_length) is never over-indexed.
                        sample_length = min(
                            text_start_id + len(input_ids), args.max_length)
|
entity_idx_type_list = self.extract_index( |
|
logits, sample_length, split_value=args.threshold) |
|
|
|
for entity_idx in entity_idx_type_list: |
|
|
|
entity = self.extract_entity( |
|
item['text'], (entity_idx[0], entity_idx[1]), text_start_id, offset_mapping) |
|
|
|
if entity not in entity_name_list: |
|
|
|
entity_name_list.append(entity) |
|
|
|
entity = { |
|
'entity_name': entity, |
|
'score': entity_idx[2] |
|
} |
|
entity_list.append(entity) |
|
|
|
batch_data[i]['choices'][c]['entity_list'] = entity_list |
|
return batch_data |
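
    # extract() fills the input items in place and returns them: for '分类任务' the best-scoring
    # choice receives 'label' = 1 plus a 'score'; for extraction items every choice receives an
    # 'entity_list' of {'entity_name', 'score'} dicts (de-duplicated by entity_name outside of
    # '抽取式阅读理解').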
|
|
|
|
|
class UbertPipelines: |
|
@staticmethod |
|
def pipelines_args(parent_args): |
|
total_parser = parent_args.add_argument_group("pipelines args") |
|
total_parser.add_argument( |
|
'--pretrained_model_path', default='IDEA-CCNL/Erlangshen-Ubert-110M-Chinese', type=str) |
|
total_parser.add_argument('--output_save_path', |
|
default='./predict.json', type=str) |
|
|
|
total_parser.add_argument('--load_checkpoints_path', |
|
default='', type=str) |
|
|
|
total_parser.add_argument('--max_extract_entity_number', |
|
default=1, type=float) |
|
|
|
total_parser.add_argument('--train', action='store_true') |
|
|
|
total_parser.add_argument('--threshold', |
|
default=0.5, type=float) |
|
|
|
total_parser = UbertDataModel.add_data_specific_args(total_parser) |
|
total_parser = TaskModelCheckpoint.add_argparse_args(total_parser) |
|
total_parser = UbertLitModel.add_model_specific_args(total_parser) |
|
total_parser = pl.Trainer.add_argparse_args(parent_args) |
|
|
|
return parent_args |
|
|
|
def __init__(self, args): |
|
|
|
if args.load_checkpoints_path != '': |
|
self.model = UbertLitModel.load_from_checkpoint( |
|
args.load_checkpoints_path, args=args) |
|
else: |
|
self.model = UbertLitModel(args) |
|
|
|
self.args = args |
|
self.checkpoint_callback = TaskModelCheckpoint(args).callbacks |
|
self.logger = loggers.TensorBoardLogger(save_dir=args.default_root_dir) |
|
self.trainer = pl.Trainer.from_argparse_args(args, |
|
logger=self.logger, |
|
callbacks=[self.checkpoint_callback]) |
|
|
|
self.tokenizer = BertTokenizer.from_pretrained(args.pretrained_model_path, |
|
additional_special_tokens=['[unused'+str(i+1)+']' for i in range(99)]) |
|
|
|
self.em = extractModel() |
|
|
|
def fit(self, train_data, dev_data): |
|
data_model = UbertDataModel( |
|
train_data, dev_data, self.tokenizer, self.args) |
|
self.model.num_data = len(train_data) |
|
self.trainer.fit(self.model, data_model) |
|
|
|
    '''
    By grouping items into "buckets" keyed by their number of "choices", a single batch of
    prediction data may contain items with different numbers of candidate entities.
    '''
|
def predict(self, predict_data, cuda=True): |
|
result = [] |
|
start = 0 |
|
if cuda: |
|
self.model = self.model.cuda() |
|
self.model.eval() |
|
while start < len(predict_data): |
|
batch_data = predict_data[start:start+self.args.batchsize] |
|
start += self.args.batchsize |
|
|
|
|
|
batch_data_bucket = {} |
|
for item in batch_data: |
|
|
|
choice_num = len(item['choices']) |
|
|
|
try: |
|
|
|
batch_data_bucket[choice_num].append(item) |
|
except KeyError: |
|
|
|
batch_data_bucket[choice_num] = [] |
|
batch_data_bucket[choice_num].append(item) |
|
|
|
for k, batch_data in batch_data_bucket.items(): |
|
|
|
batch_result = self.em.extract( |
|
batch_data, self.model, self.tokenizer, self.args) |
|
result.extend(batch_result) |
|
|
|
return result |
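

# A minimal end-to-end usage sketch, not part of the original module. It assumes the default
# 'IDEA-CCNL/Erlangshen-Ubert-110M-Chinese' checkpoint can be downloaded from the Hugging Face hub
# and that prediction items follow the task_type / subtask_type / choices / text layout consumed by
# extractModel.extract; the task strings and sample sentence below are illustrative only.
if __name__ == '__main__':
    total_parser = argparse.ArgumentParser(description='Ubert pipeline usage sketch')
    total_parser = UbertPipelines.pipelines_args(total_parser)
    # Rely on the argparse defaults; only the Lightning logging directory is set explicitly.
    args = total_parser.parse_args(args=['--default_root_dir', './ubert_logs'])

    predict_data = [
        {
            # Generic extraction (entity recognition) item: one choice per entity type.
            'task_type': '抽取任务',
            'subtask_type': '实体识别',
            'text': '北京大学位于北京市海淀区。',
            'choices': [
                {'entity_type': '地址'},
                {'entity_type': '机构名'},
            ],
        },
        {
            # Extractive reading comprehension item: the question goes into entity_type and
            # top_k requests the two best answer spans (defaults to 1 when omitted).
            'task_type': '抽取任务',
            'subtask_type': '抽取式阅读理解',
            'text': '北京大学位于北京市海淀区。',
            'choices': [
                {'entity_type': '北京大学在哪个区?', 'top_k': 2},
            ],
        },
    ]

    pipeline = UbertPipelines(args)
    result = pipeline.predict(predict_data, cuda=torch.cuda.is_available())
    print(result)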
|
|