#!/usr/bin/env python
# -*- coding: utf-8 -*-
# fengshen/examples/translate/finetune_deltalm.py
import sys
sys.path.append('../../../')  # make the local fengshen package importable before the fengshen imports below

import argparse
import json
import logging
import os
from typing import List

import pandas as pd
import torch
from mosestokenizer import MosesDetokenizer
from pytorch_lightning import Trainer, loggers, LightningModule
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.utilities import rank_zero_info
from sacrebleu.metrics import BLEU
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

from fengshen.data.universal_datamodule import UniversalDataModule
from fengshen.models.deltalm.modeling_deltalm import DeltalmForConditionalGeneration
from fengshen.models.deltalm.tokenizer_deltalm import DeltalmTokenizer
from fengshen.models.model_utils import add_module_args, add_inverse_square_args
from fengshen.utils import UniversalCheckpoint
from fengshen.utils.utils import chinese_char_tokenize

moses_detok = MosesDetokenizer()
os.environ["CUDA_VISIBLE_DEVICES"] = '4'  # hard-coded single-GPU pin; adjust or drop for other setups
logger = logging.getLogger(__name__)
EVAL_BLEU_ORDER = 4
def calc_bleu_from_stats(sentence_stats: pd.DataFrame) -> BLEU:
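    """Aggregate accumulated per-batch sufficient statistics into a corpus BLEU.

    Summing n-gram hit counts, n-gram totals, and system/reference lengths over
    all batches and re-running sacreBLEU's compute_bleu reproduces the corpus
    score exactly, which is why batches only need to store these statistics.
    """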
corpus_stats = sentence_stats.sum(axis=0)
smooth = {"smooth_method": "exp"}
corpus_bleu = BLEU.compute_bleu(
correct=[
corpus_stats.correct_1_grams,
corpus_stats.correct_2_grams,
corpus_stats.correct_3_grams,
corpus_stats.correct_4_grams,
],
total=[
corpus_stats.total_1_grams,
corpus_stats.total_2_grams,
corpus_stats.total_3_grams,
corpus_stats.total_4_grams,
],
sys_len=corpus_stats.translation_length,
ref_len=corpus_stats.reference_length,
**smooth
)
return corpus_bleu
def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True):
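    """Fairseq-style label-smoothed negative log-likelihood.

    With smoothing epsilon and vocabulary size V, the target class keeps weight
    (1 - epsilon) and the remaining mass is spread over the other classes:
    loss = (1 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss, where
    eps_i = epsilon / (V - 1) and smooth_loss sums the NLL over the whole
    vocabulary. Positions equal to ignore_index are masked out, and the summed
    loss is normalized by the number of valid tokens.
    """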
if target.dim() == lprobs.dim() - 1:
target = target.unsqueeze(-1)
# logger.debug("Debug: After target.dim() == lprobs.dim(): ", target.dim(), lprobs.dim())
nll_loss = -lprobs.gather(dim=-1, index=target)
smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
if ignore_index is not None:
pad_mask = target.eq(ignore_index)
nll_loss.masked_fill_(pad_mask, 0.0)
smooth_loss.masked_fill_(pad_mask, 0.0)
else:
nll_loss = nll_loss.squeeze(-1)
smooth_loss = smooth_loss.squeeze(-1)
if reduce:
nll_loss = nll_loss.sum()
smooth_loss = smooth_loss.sum()
eps_i = epsilon / (lprobs.size(-1) - 1)
valid_length = target.ne(ignore_index).sum()
# unvalid_length = target.eq(ignore_index).sum()
loss = ((1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss) / valid_length.item()
return loss, nll_loss
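# Illustrative call (shapes only, not from the original script): for lprobs of
# shape [batch * seq, vocab] and target of shape [batch * seq], with the
# tokenizer's pad id as ignore_index:
#   loss, nll = label_smoothed_nll_loss(lprobs, target, 0.1, ignore_index=pad_id)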
class DataCollator:
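    """Tokenizes and pads (src, tgt) pairs into model-ready seq2seq batches.

    Besides input_ids/attention_mask/labels, it attaches decoder_input_ids
    (labels shifted by the model) plus the raw src/tgt strings, which
    validation_step later uses when logging predictions. With reverse_src_tgt
    the translation direction is flipped.
    """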
def __init__(self, model, tokenizer, max_enc_length, max_dec_length, reverse_src_tgt):
self.tokenizer = tokenizer
self.max_enc_length = max_enc_length
self.max_dec_length = max_dec_length
self.model = model
self.reverse_src_tgt = reverse_src_tgt
def __call__(self, batch_samples):
batch_inputs, batch_targets = [], []
for sample in batch_samples:
if self.reverse_src_tgt:
if "tgt" in sample and len(sample["tgt"]) != 0:
batch_inputs.append(sample["tgt"])
batch_targets.append(sample["src"])
else:
if "src" in sample and len(sample["src"]) != 0:
batch_inputs.append(sample["src"])
batch_targets.append(sample["tgt"])
batch_data = self.tokenizer(
batch_inputs,
padding='max_length',
max_length=self.max_enc_length,
truncation=True,
return_tensors="pt"
)
with self.tokenizer.as_target_tokenizer():
labels = self.tokenizer(
batch_targets,
padding='max_length',
max_length=self.max_dec_length,
                truncation=True,  # also truncate targets; over-length rows would otherwise break tensor batching
return_tensors="pt"
)["input_ids"]
batch_data['decoder_input_ids'] = self.model.prepare_decoder_input_ids_from_labels(labels)
batch_data['labels'] = labels
batch_data['src'] = batch_inputs
batch_data['tgt'] = batch_targets
# logger.debug(batch_data)
return batch_data
class FinetuneTranslation(LightningModule):
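    """LightningModule that fine-tunes DeltaLM (or any AutoModelForSeq2SeqLM)
    for translation, with optional label smoothing and sacreBLEU validation."""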
@staticmethod
def add_model_specific_args(parent_args):
parser = parent_args.add_argument_group('deltalm-base finetune')
parser.add_argument('--label_smoothing', default=0.1, type=float)
return parent_args
def __init__(self, args, tokenizer=None):
super().__init__()
self.args = args
self.save_hyperparameters(args)
if args.other_model:
self.model = AutoModelForSeq2SeqLM.from_pretrained(args.model_path)
else:
self.model = DeltalmForConditionalGeneration.from_pretrained(args.model_path, ignore_mismatched_sizes=True)
self.tokenizer = tokenizer
assert self.tokenizer, "tokenizer is None!"
        self.bleu_metric = BLEU()  # sacreBLEU scorer, reused for batch-level statistics
self.sufficient_stats: List[List[int]] = []
self.label_smoothing = self.args.label_smoothing
        self.moses_detok = MosesDetokenizer()
if self.args.label_smoothing != 0:
self.loss_fn = label_smoothed_nll_loss
def setup(self, stage) -> None:
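        """Estimate the total number of optimization steps for LR scheduling.

        Relies on trainer internals (gpus, accumulate_grad_batches) that are
        specific to the pytorch-lightning version this example targets.
        """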
if stage == 'fit':
train_loader = self.trainer._data_connector._train_dataloader_source.dataloader()
# Calculate total steps
tb_size = self.hparams.train_batchsize * max(1, self.trainer.gpus)
ab_size = self.trainer.accumulate_grad_batches * float(
self.trainer.max_epochs)
self.total_steps = (len(train_loader.dataset) //
tb_size) // ab_size
def configure_optimizers(self):
# if self.args.use_default_configure:
from fengshen.models.model_utils import configure_optimizers
return configure_optimizers(self)
def training_step(self, batch, batch_idx):
if self.label_smoothing == 0:
output = self.model(input_ids=batch['input_ids'],
attention_mask=batch['attention_mask'],
labels=batch['labels'])
self.log('train_loss', output.loss, sync_dist=True)
return output.loss
        else:
            # label-smoothing branch: build the smoothed NLL manually from the raw logits
labels = batch["labels"]
output = self.model(input_ids=batch['input_ids'],
attention_mask=batch['attention_mask'],
decoder_input_ids=batch['decoder_input_ids'])
logits = output["logits"]
m = torch.nn.LogSoftmax(dim=-1)
lprobs = m(logits.float())
loss, _ = self.loss_fn(lprobs.view(-1, lprobs.size(-1)), labels.view(-1),
self.label_smoothing, self.tokenizer.pad_token_id)
self.log('train_loss', loss, sync_dist=True)
return loss
    def compute_metrics(self, logits, labels):
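        """Token-level accuracy from greedy argmax, ignoring pad positions."""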
y_pred = torch.argmax(logits, dim=-1)
y_pred = y_pred.view(size=(-1, ))
y_true = labels.view(size=(-1, ))
        pad_mask = y_true.eq(self.tokenizer.pad_token_id)
        valid_length = y_true.ne(self.tokenizer.pad_token_id).sum()
corr = torch.eq(y_pred, y_true.float())
corr.masked_fill_(pad_mask, 0.0)
acc = torch.sum(corr.float()) / valid_length
return acc
def get_sufficient_stats(self, translations: List[str], references: List[str]) -> pd.DataFrame:
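        """Record BLEU sufficient statistics for one batch of translations.

        Only n-gram hit counts, n-gram totals, and system/reference lengths are
        stored; calc_bleu_from_stats re-aggregates them into a corpus score in
        validation_epoch_end.
        """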
assert len(translations) == len(references), (
f"There are {len(translations)} translated sentences "
f"but {len(references)} reference sentences"
)
        # Corpus-style BLEU statistics for this batch; per-batch counts are
        # accumulated and re-aggregated into corpus BLEU at epoch end.
        batch_bleu = self.bleu_metric.corpus_score(translations, [references])
        self.sufficient_stats.append(
            [
                # Number of correct 1-grams, ..., 4-grams in this batch
                batch_bleu.counts[0],
                batch_bleu.counts[1],
                batch_bleu.counts[2],
                batch_bleu.counts[3],
                # Total number of 1-grams, ..., 4-grams in this batch
                batch_bleu.totals[0],
                batch_bleu.totals[1],
                batch_bleu.totals[2],
                batch_bleu.totals[3],
                # Total length of the translations in this batch.
                batch_bleu.sys_len,
                # Total length of the references in this batch.
                batch_bleu.ref_len,
            ]
        )
def on_validation_start(self) -> None:
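        """Remove this rank's stale prediction file before writing new ones."""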
# rm file at validation start
prefix, ext = os.path.splitext(self.hparams.output_save_path)
file_path_rank = '{}_{}{}'.format(
prefix,
self.trainer._accelerator_connector.cluster_environment.
global_rank(), ext)
if os.path.exists(file_path_rank):
# logger.debug('rm {}'.format(file_path_rank))
os.remove(file_path_rank)
def validation_step(self, batch, batch_idx):
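        """Compute loss/accuracy and generate translations for BLEU tracking.

        Predictions are detokenized (Moses, or Chinese character tokenization
        for zh targets), appended to a per-rank JSONL file, and their BLEU
        sufficient statistics accumulated for the epoch-end corpus score.
        """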
def postprocess_text(preds, labels, tgt_zh):
if tgt_zh:
preds = [pred.strip() for pred in preds]
labels = [label.strip() for label in labels]
else:
                preds = list(map(lambda x: moses_detok(x.strip().split()), preds))
                labels = list(map(lambda x: moses_detok(x.strip().split()), labels))
return preds, labels
        # Work on a copy so batch['labels'] is not silently mutated in place;
        # everything after the first eos in each row is set to -100 so the loss
        # ignores it (this assumes every row contains an eos token).
        tmp_label = batch['labels'].clone()
        end_token_index = torch.where(tmp_label == self.tokenizer.eos_token_id)[1]
        for idx, end_idx in enumerate(end_token_index):
            tmp_label[idx][end_idx + 1:] = -100
output = self.model(input_ids=batch['input_ids'],
attention_mask=batch['attention_mask'],
labels=tmp_label)
generated_ids = self.model.generate(
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask'],
max_length=self.hparams.max_dec_length)
preds = self.tokenizer.batch_decode(generated_ids,
skip_special_tokens=True)
labels = torch.where(batch['labels'] != -100, batch['labels'],
self.tokenizer.pad_token_id)
labels = self.tokenizer.batch_decode(labels,
skip_special_tokens=True)
decoded_preds, decoded_labels = postprocess_text(preds, labels, self.args.tgt_zh)
# save preds for every rank
prefix, ext = os.path.splitext(self.hparams.output_save_path)
file_path_rank = '{}_{}{}'.format(
prefix,
self.trainer._accelerator_connector.cluster_environment.
global_rank(), ext)
self.save_prediction_to_file(preds=decoded_preds,
sources=batch['src'],
targets=decoded_labels,
ori_target=batch['tgt'],
file_path=file_path_rank)
if self.args.tgt_zh:
new_preds = [chinese_char_tokenize(p) for p in decoded_preds]
new_labels = [chinese_char_tokenize(label) for label in decoded_labels]
self.get_sufficient_stats(new_preds, new_labels)
else:
self.get_sufficient_stats(decoded_preds, decoded_labels)
        # batch_bleu = self.bleu_metric.corpus_score(decoded_preds, [decoded_labels]).score
        acc = self.compute_metrics(output.logits, batch['labels'])
self.log('val_loss', output.loss, sync_dist=True)
self.log('val_acc', acc, sync_dist=True)
def validation_epoch_end(self, outputs):
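        """Aggregate all batches' sufficient statistics into corpus BLEU."""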
rank_zero_info("***** Validation results *****")
sentence_states = pd.DataFrame(
self.sufficient_stats,
columns=[
"correct_1_grams",
"correct_2_grams",
"correct_3_grams",
"correct_4_grams",
"total_1_grams",
"total_2_grams",
"total_3_grams",
"total_4_grams",
"translation_length",
"reference_length",
]
)
computed_bleu = calc_bleu_from_stats(sentence_states)
rank_zero_info("valid_sacrebleu= {}\n".format(computed_bleu.score))
self.log('valid_sacrebleu', computed_bleu.score, sync_dist=True)
self.sufficient_stats = []
def on_save_checkpoint(self, checkpoint) -> None:
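        """On rank 0, also export the weights in Hugging Face format alongside
        the Lightning checkpoint."""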
if self.trainer._accelerator_connector.cluster_environment.global_rank(
) == 0:
self.model.save_pretrained(
os.path.join(
self.trainer.checkpoint_callback.dirpath,
'finetuned_epoch{}_step{}'.format(
checkpoint['epoch'], checkpoint['global_step'])))
def save_prediction_to_file(self, preds, sources, targets, ori_target, file_path):
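        """Append predictions as JSONL: one {pred, source, label, ori_label}
        object per line."""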
with open(file_path, 'a', encoding='utf-8') as f:
for idx, pred in enumerate(preds):
source = sources[idx]
target = targets[idx]
tmp_result = dict()
tmp_result['pred'] = pred
tmp_result['source'] = source
tmp_result['label'] = target
tmp_result['ori_label'] = ori_target[idx]
json_data = json.dumps(tmp_result, ensure_ascii=False)
f.write(json_data + '\n')
def test_step(self, batch, batch_idx):
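        """Generate translations for the test set and dump them to disk."""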
        texts = batch['src']
        # generate translations and save them together with their references
self.model.eval()
generated_ids = self.model.generate(
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask'],
max_length=self.hparams.max_dec_length
)
preds = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
labels = torch.where(batch['labels'] != -100, batch['labels'],
self.tokenizer.pad_token_id)
labels = self.tokenizer.batch_decode(
labels, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        self.save_prediction_to_file(preds=preds,
                                     sources=texts,
                                     targets=labels,
                                     ori_target=batch['tgt'],
                                     file_path=self.hparams.output_save_path)
def configure_logger(logging_level=logging.INFO):
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logger.setLevel(logging_level)
def main():
    args_parser = argparse.ArgumentParser("DeltaLM Translation Task")
args_parser.add_argument('--do_eval_only',
action='store_true',
default=False)
args_parser.add_argument('--other_model',
action='store_true',
default=False)
args_parser.add_argument('--reverse_src_tgt',
action='store_true',
default=False)
args_parser.add_argument('--tgt_zh',
action='store_true',
default=False)
args_parser.add_argument('--early_stopping_callback',
action='store_true',
default=False)
args_parser.add_argument('--pretrained_model_path',
default='facebook/mbart',
type=str)
args_parser.add_argument('--output_save_path',
default='predict.json',
type=str)
args_parser.add_argument('--max_enc_length', default=512, type=int)
args_parser.add_argument('--max_dec_length', default=512, type=int)
# * Args for data preprocessing
args_parser = UniversalDataModule.add_data_specific_args(args_parser)
# * Args for training
args_parser = Trainer.add_argparse_args(args_parser)
args_parser = UniversalCheckpoint.add_argparse_args(args_parser)
args_parser = FinetuneTranslation.add_model_specific_args(args_parser)
args_parser = add_module_args(args_parser)
args_parser = add_inverse_square_args(args_parser)
args = args_parser.parse_args()
if args.other_model:
tokenizer = AutoTokenizer.from_pretrained(args.model_path)
else:
tokenizer = DeltalmTokenizer.from_pretrained(args.model_path)
# tokenizer = AutoTokenizer.from_pretrained(args.model_path)
print("tokenizer vocab size: ", tokenizer.vocab_size)
model = FinetuneTranslation(args, tokenizer)
collator = DataCollator(model.model, tokenizer, args.max_enc_length, args.max_dec_length, args.reverse_src_tgt)
data_model = UniversalDataModule(tokenizer=tokenizer,
args=args,
# datasets=dataset,
collate_fn=collator)
    configure_logger(logging_level=logging.INFO)
if not args.do_eval_only:
lr_monitor = LearningRateMonitor(logging_interval='step')
tensorboard_logger = loggers.TensorBoardLogger(
save_dir=os.path.join(args.default_root_dir, 'logs/'),
name=os.path.basename(os.path.dirname(args.model_path)))
checkpoint_callback = UniversalCheckpoint(args)
# early_stop = EarlyStopping(monitor=args.monitor, mode=args.mode)
trainer = Trainer.from_argparse_args(
args, logger=tensorboard_logger, callbacks=[lr_monitor, checkpoint_callback])
trainer.fit(model, data_model)
else:
trainer = Trainer.from_argparse_args(args)
trainer.validate(model, data_model)
# trainer.test(model, data_model)
if __name__ == '__main__':
main()
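
# A minimal sketch of how this script might be launched (hypothetical paths and
# values; the full flag set comes from Trainer.add_argparse_args,
# UniversalDataModule, UniversalCheckpoint and add_module_args, and varies by
# Fengshen version):
#
#   python finetune_deltalm.py \
#       --model_path /path/to/deltalm-base \
#       --max_enc_length 512 --max_dec_length 512 \
#       --label_smoothing 0.1 \
#       --output_save_path predict.json \
#       --default_root_dir ./outputs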