# -*- encoding: utf-8 -*-
'''
Copyright 2022 The International Digital Economy Academy (IDEA). CCNL team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

@File    :   finetune_bart.py
@Time    :   2022/10/28 18:23
@Author  :   Qi Yang
@Version :   1.0
@Contact :   yangqi@idea.edu.cn
@License :   (C)Copyright 2022-2023, CCNL-IDEA
'''
from fengshen.models.model_utils import configure_optimizers
from fengshen.data.universal_datamodule import UniversalDataModule
from fengshen.utils.universal_checkpoint import UniversalCheckpoint
from fengshen.utils import chinese_char_tokenize
from utils import truncate_sequence, white_space_fix
from utils import LabelSmoothingCrossEntropy
import sys
import os
import torch
import argparse
import pytorch_lightning as pl
from dataclasses import dataclass
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor
from transformers import BartForConditionalGeneration
from transformers import BertTokenizer, AutoTokenizer
from torchmetrics.text.rouge import ROUGEScore
# NOTE(review): this runs after the project imports above, so it cannot
# influence how they are resolved -- confirm whether it is still needed.
sys.path.append('../../../')


@dataclass
class QGT5Collator:
    """Turn raw QA samples into padded tensors for question generation.

    Each sample is rendered as a prompt ``知识:<context>回答:<answer>`` for the
    encoder and ``问题:<question>`` for the decoder, tokenized to fixed lengths
    and stacked into a batch.
    """

    @staticmethod
    def add_data_specific_args(parent_args):
        """Register the collator's command-line options on ``parent_args``.

        Returns:
            The same parser object that was passed in.
        """
        # the hyperparameters should be determined according to the max length of context in dataset
        parser = parent_args.add_argument_group('BART DIalo Collator')
        parser.add_argument('--max_seq_length', default=512, type=int)
        parser.add_argument('--max_src_length', default=32, type=int)
        parser.add_argument('--max_kno_length', default=416, type=int)
        parser.add_argument('--max_tgt_length', default=64, type=int)
        parser.add_argument('--mask_ans_style', default='normal', type=str,
                            choices=['normal', 'unmask', 'anstoken', 'postag',
                                     'anstoken_multispan', 'postag_multispan', 'normal_multispan'])
        return parent_args

    def __init__(self, tokenizer, args):
        """Store the tokenizer and the parsed arguments the collator reads."""
        self.args = args
        self.tokenizer = tokenizer
        self.max_seq_length = args.max_seq_length
        # The first collated batch prints one rendered example, then this flips off.
        self.print_example = True
        self.mask_ans_style = args.mask_ans_style
        self.do_eval_only = args.do_eval_only
        self.tokenizer_type = args.tokenizer_type

    def encode(self, x, y):
        """Tokenize the encoder text ``x`` and decoder text ``y``.

        Both are padded/truncated to fixed lengths: ``max_kno_length +
        max_src_length`` for ``x`` and ``max_tgt_length`` for ``y``.

        Returns:
            Tuple ``(encoder_input, decoder_output)`` of ``encode_plus`` results
            requested as PyTorch tensors.
        """
        if self.tokenizer_type != "bert":
            # Non-BERT tokenizers get explicit bos/eos markers added to the text.
            x = self.tokenizer.bos_token + x + self.tokenizer.eos_token
            y = y + self.tokenizer.eos_token

        encoder_input = self.tokenizer.encode_plus(
            x,
            max_length=self.args.max_kno_length + self.args.max_src_length,
            padding="max_length",
            truncation=True,
            return_tensors='pt'
        )
        decoder_output = self.tokenizer.encode_plus(
            y,
            max_length=self.args.max_tgt_length,
            padding="max_length",
            truncation=True,
            return_tensors='pt'
        )
        return encoder_input, decoder_output

    def mask(self, s):
        """Return the sample's context with the answer masked per ``mask_ans_style``.

        Styles (matched by substring, so ``*_multispan`` variants share a branch):
            unmask   -- context returned unchanged.
            normal   -- answer replaced by the tokenizer's mask token.
            anstoken -- answer replaced by a dedicated answer token.
            postag   -- answer kept, wrapped in begin/end marker tokens.
        ``multispan`` styles replace every occurrence of the answer text;
        the others replace only the first span listed in ``s["ans_span"]``.

        Raises:
            ValueError: if ``mask_ans_style`` matches none of the styles above.
        """
        def replace_span(source, target, sptoken):
            # Only the first annotated span is replaced; `target` is unused here
            # but kept so both helpers share one call signature.
            ans_bos, ans_eos = s["ans_span"][0]
            return source[:ans_bos] + sptoken + source[ans_eos:]

        def replace_all(source, target, sptoken):
            return source.replace(target, sptoken)

        style = self.mask_ans_style
        fn = replace_all if 'multispan' in style else replace_span

        if 'unmask' in style:
            return s["context"]

        if 'normal' in style:
            self.anstoken = self.tokenizer.mask_token
        elif 'anstoken' in style:
            # NOTE(review): the empty "bart" token looks like markup that was
            # lost from the literal -- confirm the intended special token.
            anstoken_dict = {
                "bert": "[ANS]",
                "bart": ""
            }
            self.anstoken = anstoken_dict[self.tokenizer_type]
        elif 'postag' in style:
            # NOTE(review): both markers are empty strings, so the answer is
            # re-inserted unchanged -- confirm the intended begin/end tokens.
            begtoken, endtoken = "", ""
            self.anstoken = begtoken + s["answer"][0] + endtoken
        else:
            # Previously this fell through to an unbound local variable.
            raise ValueError(f"Unsupported mask_ans_style: {style!r}")

        return fn(s["context"], s["answer"][0], self.anstoken)

    def prompt(self, context, answer, question):
        """Truncate the three fields and wrap them in their prompt prefixes.

        Returns:
            Tuple ``(x_trunc, y_trunc)``: encoder text (context + answer) and
            decoder text (question).
        """
        pre_prompt, mid_prompt, post_prompt = "知识:", "回答:", "问题:"  # prompt

        context = truncate_sequence(context, self.args.max_kno_length-len(pre_prompt)-1)

        # used in squad-2.0
        # noted that src and tgt is reversed in qg
        answer = truncate_sequence(answer, self.args.max_src_length - len(mid_prompt)-1)
        question = truncate_sequence(question, self.args.max_tgt_length-len(post_prompt)-1)

        x_trunc = f'{pre_prompt}{context}{mid_prompt}{answer}'
        y_trunc = f'{post_prompt}{question}'
        return x_trunc, y_trunc

    def __call__(self, samples):
        """Collate a list of samples into one batch dictionary.

        Only the first answer of each sample is used to build the model input;
        samples may still carry several answers (e.g. dev data), which are
        passed through untouched in eval-only mode.

        Returns:
            Dict with:
                input_ids: encoder token ids (context + answer prompt).
                attention_mask: encoder attention mask.
                labels: decoder token ids (question prompt); every position
                    after the first end token is set to -100.
            When ``do_eval_only`` is set, the raw ``answer``, ``question``,
            ``context``, ``ans_span``, ``idx`` and ``is_impossible`` values are
            added as lists.
        """
        input_ids, attn_mask, labels = [], [], []
        ans, qes, ctx, ans_spans, idxs, imp = [], [], [], [], [], []

        for s in samples:
            if self.do_eval_only:
                # log origin answer to compare
                ans.append(s["answer"])
                qes.append(s["question"])
                ctx.append(s["context"])
                ans_spans.append(s["ans_span"])
                idxs.append(s["idx"])

            # SQUAD 1.0 don't have is_impossible; treat a missing key as answerable
            # instead of failing on the lookup below.
            is_impossible = s["is_impossible"] if "is_impossible" in s else False
            imp.append(is_impossible)

            question = s["question"]
            if not is_impossible:  # have ans and ans_span
                context = self.mask(s)
                answer = s["answer"][0]
            else:  # no ans and ans_span
                context = s["context"]
                answer = "无答案"

            x_trunc, y_trunc = self.prompt(context, answer, question)
            encoder_input, decoder_output = self.encode(x_trunc, y_trunc)

            input_ids.append(encoder_input["input_ids"])
            attn_mask.append(encoder_input["attention_mask"])
            labels.append(decoder_output["input_ids"])

        labels = torch.cat(labels)
        if self.tokenizer_type == "bart":
            end_token_id = self.tokenizer.eos_token_id
        else:
            end_token_id = self.tokenizer.sep_token_id

        # Mask per row, after the FIRST end token. Walking the flat hit list
        # (one entry per hit, not per row) misaligns rows whenever a row holds
        # zero or several end tokens.
        for row_idx, row in enumerate(labels):
            end_positions = torch.where(row == end_token_id)[0]
            if end_positions.numel() > 0:
                labels[row_idx, int(end_positions[0]) + 1:] = -100  # cross entropy cal

        data = {
            'input_ids': torch.cat(input_ids),
            'attention_mask': torch.cat(attn_mask),
            'labels': labels
        }

        if self.do_eval_only:
            data.update({
                'answer': ans,
                'question': qes,
                'context': ctx,
                'ans_span': ans_spans,
                'idx': idxs,
                'is_impossible': imp
            })

        if self.print_example:
            print(x_trunc)
            print(y_trunc)
            self.print_example = False

        return data


class BARTFinetuneModel(pl.LightningModule):
    """LightningModule wrapping ``BartForConditionalGeneration`` for question generation."""

    @staticmethod
    def add_model_specific_args(parent_args):
        """Register model/optimizer command-line options on ``parent_args``.

        Returns:
            The same parser object that was passed in.
        """
        parser = parent_args.add_argument_group('BaseModel')
        parser.add_argument('--model_path', type=str, default='')
        parser.add_argument('--learning_rate', default=1e-5, type=float)
        parser.add_argument('--min_learning_rate', default=1e-7, type=float)
        parser.add_argument('--lr_decay_steps', default=0, type=int)
        parser.add_argument('--lr_decay_ratio', default=1.0, type=float)
        parser.add_argument('--weight_decay', default=0.1, type=float)
        parser.add_argument('--warmup_steps', default=1000, type=int)
        parser.add_argument('--warmup_ratio', default=0.01, type=float)
        parser.add_argument('--label_smooth', default=0, type=float)
        parser.add_argument('--new_token_path', default="./", type=str)  # save new token after add special token
        parser.add_argument('--adam_beta1', default=0.9, type=float)
        parser.add_argument('--adam_beta2', default=0.999, type=float)
        parser.add_argument('--adam_epsilon', default=1e-8, type=float)
        parser.add_argument('--scheduler_type', default='polynomial', type=str)

        return parent_args

    def __init__(self, tokenizer, args):
        """Load the pretrained model and align its embeddings with ``tokenizer``.

        The tokenizer (including any added special tokens) is saved under
        ``<model_path>/sp_vocab/``.
        """
        super().__init__()
        self.save_hyperparameters(args)
        self.model = BartForConditionalGeneration.from_pretrained(args.model_path)
        self.tokenizer = tokenizer

        # add special token ans
        new_vocab = args.model_path+"/sp_vocab/"
        # exist_ok avoids the check-then-create race when several processes start together.
        os.makedirs(new_vocab, exist_ok=True)
        self.tokenizer.save_pretrained(new_vocab)
        self.model.resize_token_embeddings(len(tokenizer))
        self.vocab_size = len(tokenizer)
        self.rougescore = ROUGEScore(rouge_keys=('rougeL'), normalizer=lambda x: x)

        if self.hparams.label_smooth:
            # NOTE(review): smoothing is fixed at 0.1 regardless of the value
            # passed via --label_smooth -- confirm this is intended.
            self.loss_fct = LabelSmoothingCrossEntropy(smoothing=0.1)

    def setup(self, stage) -> None:
        """Compute and print ``self.total_steps`` when entering the fit stage."""
        if stage == 'fit':
            train_loader = self.trainer._data_connector._train_dataloader_source.dataloader()

            # Calculate total steps
            if self.trainer.max_epochs > 0:
                world_size = self.trainer.world_size
                tb_size = self.hparams.train_batchsize * max(1, world_size)
                # NOTE(review): max_epochs multiplies both the numerator below
                # and ab_size, so it cancels out and the result is a float --
                # confirm whether this is the intended total.
                ab_size = self.trainer.accumulate_grad_batches * float(self.trainer.max_epochs)
                self.total_steps = (len(train_loader.dataset) *
                                    self.trainer.max_epochs // tb_size) // ab_size
            else:
                self.total_steps = self.trainer.max_steps // self.trainer.accumulate_grad_batches

            print('Total steps: {}' .format(self.total_steps))

    def configure_optimizers(self):
        """Delegate optimizer/scheduler construction to the shared fengshen helper."""
        return configure_optimizers(self)

    def training_step(self, batch, batch_idx):
        """Run one forward pass and return the training loss.

        Uses the model's own loss unless ``label_smooth`` is set, in which case
        the loss is recomputed from the logits with the label-smoothing criterion.
        """
        output = self.model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels'])

        loss = output.loss
        if self.hparams.label_smooth:
            loss = self.loss_fct(output.logits.view(-1, self.vocab_size), batch["labels"].view(-1))

        self.log('train_loss', loss, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log loss/accuracy/perplexity, generate questions and update ROUGE.

        Returns:
            The list of generated questions after whitespace cleanup and
            character-level tokenization.
        """
        output = self.model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels'])
        acc = self.compute_acc(output.logits, batch['labels'])
        self.log('val_loss', output.loss, sync_dist=True)
        self.log('val_acc', acc, sync_dist=True)
        self.log('val_ppl', torch.exp(output.loss), sync_dist=True)

        cond_output = self.model.generate(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            do_sample=True,
            num_beams=5,
            early_stopping=True,
            max_length=64,
            top_p=0.9,
        )

        # -100 is not a valid token id, so swap it for the pad id before decoding.
        batch_label = torch.where(batch["labels"] != -100, batch["labels"], self.tokenizer.pad_token_id)
        pred = self.tokenizer.batch_decode(cond_output, clean_up_tokenization_spaces=True, skip_special_tokens=True)
        ques = self.tokenizer.batch_decode(batch_label, clean_up_tokenization_spaces=True, skip_special_tokens=True)

        pred = [chinese_char_tokenize(white_space_fix(p)) for p in pred]
        ques = [chinese_char_tokenize(white_space_fix(q)) for q in ques]
        self.rougescore.update(pred, ques)

        return pred

    def validation_epoch_end(self, validation_step_outputs):
        """Log the ROUGE-L F-measure accumulated over the validation epoch."""
        rouge = self.rougescore.compute()
        self.log('val_rouge', rouge["rougeL_fmeasure"], sync_dist=True)
        # compute() was called by hand and only a plain value was logged, so the
        # metric state must be cleared here; otherwise later epochs would also
        # include the predictions of earlier ones.
        self.rougescore.reset()

    def on_predict_start(self):
        """Switch to an unreduced cross-entropy so per-token losses are available."""
        self.loss_fct = torch.nn.CrossEntropyLoss(reduction='none')

    def predict_step(self, batch, batch_idx):
        """Score the reference question and generate a question for each sample.

        Returns:
            Tuple ``(pred, score, loss)``: decoded generations, the generation
            ``sequences_scores``, and the summed token loss of the reference
            divided by the end-token position.
        """
        output = self.model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels'])

        # CrossEntropyLoss expects the class dimension second, hence the transpose.
        loss_tensor = self.loss_fct(output.logits.transpose(1, 2), batch["labels"])
        # NOTE(review): the division below assumes exactly one end token per
        # row of labels -- confirm against the collator output.
        if self.hparams.tokenizer_type == 'bart':
            eos_index = torch.where(batch['labels'] == self.tokenizer.eos_token_id)[1]
        elif self.hparams.tokenizer_type == 'bert':
            eos_index = torch.where(batch['labels'] == self.tokenizer.sep_token_id)[1]

        loss = torch.sum(loss_tensor, dim=1) / eos_index
        with torch.no_grad():
            cond_output = self.model.generate(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                do_sample=True,
                num_beams=5,
                max_length=64,
                top_p=0.9,
                output_scores=True,
                return_dict_in_generate=True
            )

        pred = self.tokenizer.batch_decode(
            cond_output.sequences, clean_up_tokenization_spaces=True, skip_special_tokens=True)  # ['sequences']
        pred = [white_space_fix(p) for p in pred]  # remove prompt and white space
        score = cond_output.sequences_scores
        return pred, score, loss

    def compute_acc(self, logits, labels):
        """Return the fraction of label positions whose argmax prediction matches.

        NOTE(review): the denominator is every label position, so positions set
        to -100 always count as mismatches -- confirm this is intended.
        """
        y_pred = torch.argmax(logits, dim=-1)
        y_pred = y_pred.view(size=(-1,))
        y_true = labels.view(size=(-1,)).float()
        corr = torch.eq(y_pred, y_true)
        acc = torch.sum(corr.float())/y_true.shape[0]
        return acc

    def on_save_checkpoint(self, checkpoint) -> None:
        """On global rank 0, also export the model in HuggingFace format."""
        if self.trainer._accelerator_connector.cluster_environment.global_rank() == 0:
            self.model.save_pretrained(os.path.join(
                self.trainer.checkpoint_callback.dirpath,
                'hf_pretrained_epoch{}_step{}'.format(checkpoint['epoch'], checkpoint['global_step'])))

    def on_load_checkpoint(self, checkpoint) -> None:
        """Restore the step counter (and consumed samples, if saved) from a checkpoint."""
        global_step_offset = checkpoint["global_step"]
        if 'global_samples' in checkpoint:
            self.consumed_samples = checkpoint['global_samples']
        self.trainer.fit_loop.epoch_loop._batches_that_stepped = global_step_offset


def get_tokenizer(tokenizer_type, pretrained_model_path):
    """Load the tokenizer for ``tokenizer_type`` with its extra special tokens.

    Args:
        tokenizer_type: ``'bart'`` or ``'bert'``.
        pretrained_model_path: path or name handed to ``from_pretrained``.

    Raises:
        ValueError: if ``tokenizer_type`` is not one of the supported values.
    """
    if tokenizer_type == 'bart':
        # NOTE(review): the three special tokens are empty strings, which looks
        # like markup lost from the literals -- confirm the intended tokens.
        tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_path, use_fast=False, additional_special_tokens=["", "", ""])
        print(len(tokenizer))
    elif tokenizer_type == 'bert':
        tokenizer = BertTokenizer.from_pretrained(
            pretrained_model_path, use_fast=False, additional_special_tokens=["[ANS]"])
    else:
        # Previously this fell through to an unbound local variable.
        raise ValueError(f"Unsupported tokenizer_type: {tokenizer_type!r}")

    return tokenizer


def main():
    """Parse arguments, build data/model/trainer and run fine-tuning.

    Training is skipped when ``--do_eval_only`` is given.
    """
    total_parser = argparse.ArgumentParser("Finetune BART for QG")
    total_parser.add_argument('--do_eval_only', action='store_true', default=False)
    total_parser.add_argument('--tokenizer_type', type=str, default="bart", choices=['bart', 'bert'])
    total_parser.add_argument('--tensorboard_dir', type=str, default="bart")
    total_parser.add_argument('--deepspeed')

    total_parser = UniversalDataModule.add_data_specific_args(total_parser)
    total_parser = QGT5Collator.add_data_specific_args(total_parser)
    total_parser = Trainer.add_argparse_args(total_parser)
    total_parser = UniversalCheckpoint.add_argparse_args(total_parser)
    total_parser = BARTFinetuneModel.add_model_specific_args(total_parser)
    args = total_parser.parse_args()

    tokenizer = get_tokenizer(args.tokenizer_type, args.model_path)
    collator = QGT5Collator(tokenizer=tokenizer, args=args)
    data_model = UniversalDataModule(collate_fn=collator, tokenizer=tokenizer, args=args)
    print("Data load complete...")

    if args.deepspeed is not None:
        os.environ['PL_DEEPSPEED_CONFIG_PATH'] = args.deepspeed

    model = BARTFinetuneModel(tokenizer, args)
    checkpoint_callback = UniversalCheckpoint(args)
    lr_monitor = LearningRateMonitor(logging_interval='step')
    trainer = Trainer.from_argparse_args(args,
                                         callbacks=[checkpoint_callback, lr_monitor]
                                         )

    if not args.do_eval_only:
        trainer.fit(model, data_model)


if __name__ == '__main__':
    main()