''' |
Copyright 2022 The International Digital Economy Academy (IDEA). CCNL team. All rights reserved. |
Licensed under the Apache License, Version 2.0 (the "License"); |
you may not use this file except in compliance with the License. |
You may obtain a copy of the License at |
http://www.apache.org/licenses/LICENSE-2.0 |
Unless required by applicable law or agreed to in writing, software |
distributed under the License is distributed on an "AS IS" BASIS, |
@File : finetune_bart.py |
@Time : 2022/10/28 18:23 |
@Author : Qi Yang |
@Version : 1.0 |
@Contact : [email protected] |
@License : (C)Copyright 2022-2023, CCNL-IDEA |
''' |
from fengshen.models.model_utils import configure_optimizers |
from fengshen.data.universal_datamodule import UniversalDataModule |
from fengshen.utils.universal_checkpoint import UniversalCheckpoint |
from fengshen.utils import chinese_char_tokenize |
from utils import truncate_sequence, white_space_fix |
from utils import LabelSmoothingCrossEntropy |
import sys |
import os |
import torch |
import argparse |
import pytorch_lightning as pl |
from dataclasses import dataclass |
from pytorch_lightning import Trainer |
from pytorch_lightning.callbacks import LearningRateMonitor |
from transformers import BartForConditionalGeneration |
from transformers import BertTokenizer, AutoTokenizer |
from torchmetrics.text.rouge import ROUGEScore |
sys.path.append('../../../') |
@dataclass
class QGT5Collator:
    """Collate raw QA samples into tokenized batches for question generation.

    For every sample the (optionally answer-masked) context and the answer are
    joined into an encoder prompt, and the question becomes the decoder
    target. Label positions after the end token are set to -100.
    """

    @staticmethod
    def add_data_specific_args(parent_args):
        """Register the collator's command-line options on ``parent_args``.

        Args:
            parent_args: an ``argparse`` parser (or compatible object).

        Returns:
            The same ``parent_args`` object, with a new argument group added.
        """
        parser = parent_args.add_argument_group('BART DIalo Collator')
        parser.add_argument('--max_seq_length', default=512, type=int)
        parser.add_argument('--max_src_length', default=32, type=int)
        parser.add_argument('--max_kno_length', default=416, type=int)
        parser.add_argument('--max_tgt_length', default=64, type=int)
        parser.add_argument('--mask_ans_style',
                            default='normal',
                            type=str,
                            choices=['normal', 'unmask', 'anstoken', 'postag',
                                     'anstoken_multispan', 'postag_multispan',
                                     'normal_multispan'])
        return parent_args

    def __init__(self, tokenizer, args):
        """Store the tokenizer and the parsed options the collator reads.

        Args:
            tokenizer: tokenizer used for encoding and for its special tokens.
            args: namespace providing ``max_seq_length``, ``mask_ans_style``,
                ``do_eval_only`` and ``tokenizer_type`` (and the length limits
                read later through ``self.args``).
        """
        self.args = args
        self.tokenizer = tokenizer
        self.max_seq_length = args.max_seq_length
        # Print the first collated example once, for a quick visual check.
        self.print_example = True
        self.mask_ans_style = args.mask_ans_style
        self.do_eval_only = args.do_eval_only
        self.tokenizer_type = args.tokenizer_type

    def encode(self, x, y):
        """Tokenize the encoder text ``x`` and the decoder text ``y``.

        For every tokenizer type other than ``"bert"`` the tokenizer's
        bos/eos token strings are added to the text before tokenization.

        Returns:
            Tuple ``(encoder_input, decoder_output)`` as returned by
            ``tokenizer.encode_plus`` with ``return_tensors='pt'``.
        """
        if self.tokenizer_type != "bert":
            x = self.tokenizer.bos_token + x + self.tokenizer.eos_token
            y = y + self.tokenizer.eos_token

        encoder_input = self.tokenizer.encode_plus(
            x,
            max_length=self.args.max_kno_length + self.args.max_src_length,
            padding="max_length",
            truncation=True,
            return_tensors='pt'
        )
        decoder_output = self.tokenizer.encode_plus(
            y,
            max_length=self.args.max_tgt_length,
            padding="max_length",
            truncation=True,
            return_tensors='pt'
        )
        return encoder_input, decoder_output

    def mask(self, s):
        """Return the sample's context with the answer replaced by a marker.

        The marker depends on ``self.mask_ans_style``; it is also stored on
        ``self.anstoken``. Styles containing ``multispan`` replace every
        occurrence of the answer text, the others replace only the first
        annotated span ``s["ans_span"][0]``.

        Args:
            s: sample with ``"context"``, ``"answer"`` and ``"ans_span"``.

        Returns:
            The masked context string (the raw context for ``unmask``).

        Raises:
            ValueError: if ``self.mask_ans_style`` matches no known style.
        """
        style = self.mask_ans_style
        if 'unmask' in style:
            return s["context"]

        if 'normal' in style:
            self.anstoken = self.tokenizer.mask_token
        elif 'anstoken' in style:
            anstoken_dict = {
                "bert": "[ANS]",
                "bart": "<ans>"
            }
            self.anstoken = anstoken_dict[self.tokenizer_type]
        elif 'postag' in style:
            # NOTE(review): the closing tag here is "<eos>", while
            # get_tokenizer in this file registers "<end>" as the special
            # token. Left unchanged because it alters model inputs; confirm
            # which one is intended.
            begtoken, endtoken = "<beg>", "<eos>"
            self.anstoken = begtoken + s["answer"][0] + endtoken
        else:
            # Previously this fell through to an unbound local variable.
            raise ValueError(
                'Unsupported mask_ans_style: {!r}'.format(style))

        context = s["context"]
        if 'multispan' in style:
            return context.replace(s["answer"][0], self.anstoken)
        ans_bos, ans_eos = s["ans_span"][0]
        return context[:ans_bos] + self.anstoken + context[ans_eos:]

    def prompt(self, context, answer, question):
        """Truncate the three fields and wrap them in the prompt template.

        Returns:
            Tuple ``(x_trunc, y_trunc)``: the encoder text (context prompt
            followed by answer prompt) and the decoder text (question prompt).
        """
        pre_prompt, mid_prompt, post_prompt = "知识:", "回答:", "问题:"

        # Each budget leaves room for the prompt prefix plus one extra slot.
        context = truncate_sequence(context, self.args.max_kno_length - len(pre_prompt) - 1)
        answer = truncate_sequence(answer, self.args.max_src_length - len(mid_prompt) - 1)
        question = truncate_sequence(question, self.args.max_tgt_length - len(post_prompt) - 1)

        x_trunc = f'{pre_prompt}{context}{mid_prompt}{answer}'
        y_trunc = f'{post_prompt}{question}'
        return x_trunc, y_trunc

    @staticmethod
    def _mask_after_end_token(labels, end_token_id):
        """Set every label after a row's first end token to -100, in place.

        Rows are addressed by the row index reported by ``torch.where``, so a
        row without an end token is left untouched and a row with several end
        tokens is masked after the first one.
        """
        rows, cols = torch.where(labels == end_token_id)
        for row, col in zip(rows.tolist(), cols.tolist()):
            labels[row, col + 1:] = -100
        return labels

    def __call__(self, samples):
        """Build one batch from a list of samples.

        Only the first answer of each sample (``s["answer"][0]``) is used to
        build the encoder input; in eval-only mode the full answer list is
        passed through alongside the tensors.

        Returns a dict with:
            input_ids: encoder ids (context + answer prompt)
            attention_mask: encoder attention mask
            labels: decoder ids (question prompt), -100 after the end token
        plus, when ``do_eval_only`` is set, the raw ``answer``, ``question``,
        ``context``, ``ans_span``, ``idx`` and ``is_impossible`` lists.
        """
        input_ids, attn_mask, labels = [], [], []
        ans, qes, ctx, ans_spans, idxs, imp = [], [], [], [], [], []

        for s in samples:
            if self.do_eval_only:
                ans.append(s["answer"])
                qes.append(s["question"])
                ctx.append(s["context"])
                ans_spans.append(s["ans_span"])
                idxs.append(s["idx"])

            # Samples without the key are treated as answerable; the same
            # value drives the branch below (it used to index the key
            # directly and raise KeyError when it was missing).
            is_impossible = s["is_impossible"] if "is_impossible" in s else False
            imp.append(is_impossible)

            if not is_impossible:
                context = self.mask(s)
                answer = s["answer"][0]
            else:
                context = s["context"]
                answer = "无答案"
            question = s["question"]

            x_trunc, y_trunc = self.prompt(context, answer, question)
            encoder_input, decoder_output = self.encode(x_trunc, y_trunc)

            input_ids.append(encoder_input["input_ids"])
            attn_mask.append(encoder_input["attention_mask"])
            labels.append(decoder_output["input_ids"])

        labels = torch.cat(labels)
        if self.tokenizer_type == "bart":
            end_token_id = self.tokenizer.eos_token_id
        else:
            end_token_id = self.tokenizer.sep_token_id
        labels = self._mask_after_end_token(labels, end_token_id)

        data = {
            'input_ids': torch.cat(input_ids),
            'attention_mask': torch.cat(attn_mask),
            'labels': labels
        }

        if self.do_eval_only:
            data.update({
                'answer': ans,
                'question': qes,
                'context': ctx,
                'ans_span': ans_spans,
                'idx': idxs,
                'is_impossible': imp
            })

        if self.print_example:
            print(x_trunc)
            print(y_trunc)
            self.print_example = False

        return data
class BARTFinetuneModel(pl.LightningModule):
    """LightningModule wrapping ``BartForConditionalGeneration`` for QG fine-tuning."""

    @staticmethod
    def add_model_specific_args(parent_args):
        """Register model/optimizer command-line options on ``parent_args``.

        Returns:
            The same ``parent_args`` object, with a new argument group added.
        """
        parser = parent_args.add_argument_group('BaseModel')
        parser.add_argument('--model_path', type=str, default='')
        parser.add_argument('--learning_rate', default=1e-5, type=float)
        parser.add_argument('--min_learning_rate', default=1e-7, type=float)
        parser.add_argument('--lr_decay_steps', default=0, type=int)
        parser.add_argument('--lr_decay_ratio', default=1.0, type=float)
        parser.add_argument('--weight_decay', default=0.1, type=float)
        parser.add_argument('--warmup_steps', default=1000, type=int)
        parser.add_argument('--warmup_ratio', default=0.01, type=float)
        parser.add_argument('--label_smooth', default=0, type=float)
        parser.add_argument('--new_token_path', default="./", type=str)
        parser.add_argument('--adam_beta1', default=0.9, type=float)
        parser.add_argument('--adam_beta2', default=0.999, type=float)
        parser.add_argument('--adam_epsilon', default=1e-8, type=float)
        parser.add_argument('--scheduler_type', default='polynomial', type=str)
        return parent_args

    def __init__(self, tokenizer, args):
        """Load the pretrained model and align its embeddings with ``tokenizer``.

        The tokenizer is also saved under ``<model_path>/sp_vocab/``.
        """
        super().__init__()
        self.save_hyperparameters(args)
        self.model = BartForConditionalGeneration.from_pretrained(args.model_path)
        self.tokenizer = tokenizer

        new_vocab = args.model_path + "/sp_vocab/"
        # exist_ok avoids a FileExistsError when several processes reach this
        # point at the same time (check-then-create is racy).
        os.makedirs(new_vocab, exist_ok=True)
        self.tokenizer.save_pretrained(new_vocab)

        # The tokenizer may carry added special tokens, so resize to match it.
        self.model.resize_token_embeddings(len(tokenizer))
        self.vocab_size = len(tokenizer)
        self.rougescore = ROUGEScore(rouge_keys='rougeL', normalizer=lambda x: x)

        if self.hparams.label_smooth:
            # NOTE(review): --label_smooth only acts as an on/off switch here;
            # the smoothing factor itself is fixed at 0.1. Confirm intended.
            self.loss_fct = LabelSmoothingCrossEntropy(smoothing=0.1)

    def setup(self, stage) -> None:
        """Compute ``self.total_steps`` when entering the ``fit`` stage."""
        if stage == 'fit':
            train_loader = self.trainer._data_connector._train_dataloader_source.dataloader()

            if self.trainer.max_epochs > 0:
                world_size = self.trainer.world_size
                tb_size = self.hparams.train_batchsize * max(1, world_size)
                # Only gradient accumulation belongs in the divisor. It used
                # to be multiplied by max_epochs as well, which cancelled the
                # epoch factor in the numerator and produced a float.
                ab_size = self.trainer.accumulate_grad_batches
                self.total_steps = (len(train_loader.dataset) *
                                    self.trainer.max_epochs // tb_size) // ab_size
            else:
                self.total_steps = self.trainer.max_steps // self.trainer.accumulate_grad_batches

            print('Total steps: {}' .format(self.total_steps))

    def configure_optimizers(self):
        """Delegate optimizer/scheduler construction to the shared helper."""
        return configure_optimizers(self)

    def training_step(self, batch, batch_idx):
        """Run a forward pass and return the loss, which is logged as ``train_loss``."""
        output = self.model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels'])

        loss = output.loss
        if self.hparams.label_smooth:
            # Replace the model's own loss with the label-smoothed one.
            loss = self.loss_fct(output.logits.view(-1, self.vocab_size), batch["labels"].view(-1))

        self.log('train_loss', loss, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log loss/accuracy/perplexity and feed generated questions to ROUGE.

        Returns:
            The list of generated questions after whitespace fixing and
            character-level tokenization.
        """
        output = self.model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels'])
        acc = self.compute_acc(output.logits, batch['labels'])
        self.log('val_loss', output.loss, sync_dist=True)
        self.log('val_acc', acc, sync_dist=True)
        self.log('val_ppl', torch.exp(output.loss), sync_dist=True)

        cond_output = self.model.generate(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            do_sample=True,
            num_beams=5,
            early_stopping=True,
            max_length=64,
            top_p=0.9,
        )

        # -100 is not a vocabulary id; swap it for the pad id before decoding.
        batch_label = torch.where(batch["labels"] != -100, batch["labels"], self.tokenizer.pad_token_id)
        pred = self.tokenizer.batch_decode(cond_output, clean_up_tokenization_spaces=True, skip_special_tokens=True)
        ques = self.tokenizer.batch_decode(batch_label, clean_up_tokenization_spaces=True, skip_special_tokens=True)

        pred = [chinese_char_tokenize(white_space_fix(p)) for p in pred]
        ques = [chinese_char_tokenize(white_space_fix(q)) for q in ques]
        self.rougescore.update(pred, ques)

        return pred

    def validation_epoch_end(self, validation_step_outputs):
        """Log the ROUGE-L F-measure accumulated over this validation epoch."""
        rouge = self.rougescore.compute()
        self.log('val_rouge', rouge["rougeL_fmeasure"], sync_dist=True)
        # The metric value is logged by hand, so its state must be cleared by
        # hand too; otherwise later epochs would include earlier predictions.
        self.rougescore.reset()

    def on_predict_start(self):
        """Switch to an unreduced cross-entropy so losses can be scored per sample."""
        self.loss_fct = torch.nn.CrossEntropyLoss(reduction='none')

    def predict_step(self, batch, batch_idx):
        """Generate questions and score each sample.

        Returns:
            Tuple ``(pred, score, loss)``: decoded questions, the generation
            ``sequences_scores`` and a per-sample loss.
        """
        output = self.model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels'])

        # CrossEntropyLoss wants the class dimension second: (N, C, L).
        loss_tensor = self.loss_fct(output.logits.transpose(1, 2), batch["labels"])
        if self.hparams.tokenizer_type == 'bart':
            eos_index = torch.where(batch['labels'] == self.tokenizer.eos_token_id)[1]
        elif self.hparams.tokenizer_type == 'bert':
            eos_index = torch.where(batch['labels'] == self.tokenizer.sep_token_id)[1]
        # NOTE(review): this divides by the end token's column index (not
        # index + 1) and presumes exactly one end token per row — confirm.
        loss = torch.sum(loss_tensor, dim=1) / eos_index

        with torch.no_grad():
            cond_output = self.model.generate(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                do_sample=True,
                num_beams=5,
                max_length=64,
                top_p=0.9,
                output_scores=True,
                return_dict_in_generate=True
            )

        pred = self.tokenizer.batch_decode(
            cond_output.sequences, clean_up_tokenization_spaces=True, skip_special_tokens=True)
        pred = [white_space_fix(p) for p in pred]

        score = cond_output.sequences_scores
        return pred, score, loss

    def compute_acc(self, logits, labels):
        """Return token-level accuracy over positions whose label is not -100.

        Ignored (-100) positions used to be counted as wrong predictions,
        which lowered the reported accuracy in proportion to the padding.
        Returns 0 when no position is scored.
        """
        y_pred = torch.argmax(logits, dim=-1).view(size=(-1,))
        y_true = labels.view(size=(-1,))
        scored = y_true != -100
        correct = torch.eq(y_pred, y_true) & scored
        # clamp guards against division by zero for an all-ignored batch.
        acc = torch.sum(correct.float()) / scored.sum().clamp(min=1)
        return acc

    def on_save_checkpoint(self, checkpoint) -> None:
        """On global rank 0, also export the model in Hugging Face format."""
        if self.trainer._accelerator_connector.cluster_environment.global_rank() == 0:
            self.model.save_pretrained(os.path.join(
                self.trainer.checkpoint_callback.dirpath,
                'hf_pretrained_epoch{}_step{}'.format(checkpoint['epoch'], checkpoint['global_step'])))

    def on_load_checkpoint(self, checkpoint) -> None:
        """Restore the consumed-sample counter and the stepped-batch counter."""
        global_step_offset = checkpoint["global_step"]
        if 'global_samples' in checkpoint:
            self.consumed_samples = checkpoint['global_samples']
        self.trainer.fit_loop.epoch_loop._batches_that_stepped = global_step_offset
def get_tokenizer(tokenizer_type, pretrained_model_path):
    """Load the tokenizer for ``tokenizer_type`` with its extra special tokens.

    Args:
        tokenizer_type: ``'bart'`` or ``'bert'``.
        pretrained_model_path: path or identifier passed to ``from_pretrained``.

    Returns:
        The loaded tokenizer.

    Raises:
        ValueError: if ``tokenizer_type`` is not one of the supported values
            (previously this surfaced as an unbound local variable error).
    """
    if tokenizer_type == 'bart':
        tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_path, use_fast=False, additional_special_tokens=["<ans>", "<beg>", "<end>"])
        print(len(tokenizer))
    elif tokenizer_type == 'bert':
        tokenizer = BertTokenizer.from_pretrained(
            pretrained_model_path, use_fast=False, additional_special_tokens=["[ANS]"])
    else:
        raise ValueError(
            "Unsupported tokenizer_type: {!r} (expected 'bart' or 'bert')".format(tokenizer_type))
    return tokenizer
def main():
    """Parse the command line, build data/model/trainer, and fit unless eval-only."""
    cli = argparse.ArgumentParser("Finetune BART for QG")
    cli.add_argument('--do_eval_only', action='store_true', default=False)
    cli.add_argument('--tokenizer_type', type=str, default="bart", choices=['bart', 'bert'])
    cli.add_argument('--tensorboard_dir', type=str, default="bart")
    cli.add_argument('--deepspeed')

    # Each component contributes its own options; registration order is kept.
    registrars = (
        UniversalDataModule.add_data_specific_args,
        QGT5Collator.add_data_specific_args,
        Trainer.add_argparse_args,
        UniversalCheckpoint.add_argparse_args,
        BARTFinetuneModel.add_model_specific_args,
    )
    for register in registrars:
        cli = register(cli)
    args = cli.parse_args()

    tokenizer = get_tokenizer(args.tokenizer_type, args.model_path)
    data_model = UniversalDataModule(
        collate_fn=QGT5Collator(tokenizer=tokenizer, args=args),
        tokenizer=tokenizer,
        args=args)
    print("Data load complete...")

    # Expose the DeepSpeed config path through the environment when given.
    if args.deepspeed is not None:
        os.environ['PL_DEEPSPEED_CONFIG_PATH'] = args.deepspeed

    model = BARTFinetuneModel(tokenizer, args)
    callbacks = [UniversalCheckpoint(args), LearningRateMonitor(logging_interval='step')]
    trainer = Trainer.from_argparse_args(args, callbacks=callbacks)

    if args.do_eval_only:
        return
    trainer.fit(model, data_model)
# Run the fine-tuning entry point only when executed as a script.
if __name__ == "__main__":
    main()