|
|
|
|
|
import pandas as pd |
|
import json |
|
import argparse |
|
import torch |
|
import os |
|
import logging |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
from pytorch_lightning.utilities import rank_zero_info |
|
from sacrebleu.metrics import BLEU |
|
from fengshen.utils.utils import chinese_char_tokenize |
|
from fengshen.models.model_utils import add_module_args, add_inverse_square_args |
|
from fengshen.models.deltalm.tokenizer_deltalm import DeltalmTokenizer |
|
from fengshen.models.deltalm.modeling_deltalm import DeltalmForConditionalGeneration |
|
from fengshen.utils import UniversalCheckpoint |
|
from fengshen.data.universal_datamodule import UniversalDataModule |
|
from pytorch_lightning import Trainer, loggers, LightningModule |
|
from pytorch_lightning.callbacks import LearningRateMonitor |
|
from mosestokenizer import MosesDetokenizer |
|
from typing import List |
|
import sys |
|
sys.path.append('../../../') |
|
|
|
|
|
|
|
|
|
|
|
# Module-level Moses detokenizer, reused by validation post-processing below.
mose_decode = MosesDetokenizer()

# NOTE(review): hard-codes GPU 4 for every run of this script — consider
# making the device selection configurable instead of baked in.
os.environ["CUDA_VISIBLE_DEVICES"] = '4'

logger = logging.getLogger(__name__)

# BLEU is computed with n-grams up to this order (matches the 1..4-gram
# sufficient statistics collected during validation).
EVAL_BLEU_ORDER = 4
|
|
|
|
|
def calc_bleu_from_stats(sentence_stats: pd.DataFrame) -> BLEU:
    """Aggregate per-batch n-gram statistics into a corpus-level BLEU score.

    Args:
        sentence_stats: one row per validation batch with the columns
            ``correct_{1..4}_grams``, ``total_{1..4}_grams``,
            ``translation_length`` and ``reference_length``.

    Returns:
        The corpus BLEU object computed by sacrebleu (exp smoothing).
    """
    # Sum every column over all batches to obtain corpus-wide counts.
    totals = sentence_stats.sum(axis=0)
    matched_ngrams = [getattr(totals, "correct_{}_grams".format(n)) for n in range(1, 5)]
    seen_ngrams = [getattr(totals, "total_{}_grams".format(n)) for n in range(1, 5)]
    return BLEU.compute_bleu(
        correct=matched_ngrams,
        total=seen_ngrams,
        sys_len=totals.translation_length,
        ref_len=totals.reference_length,
        smooth_method="exp",
    )
|
|
|
|
|
def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True):
    """Negative log-likelihood loss with label smoothing (fairseq-style).

    Args:
        lprobs: log-probabilities of shape (..., vocab_size).
        target: gold class indices of shape (...) or (..., 1).
        epsilon: smoothing mass; 0 yields plain NLL.
        ignore_index: label id to exclude from the loss (e.g. padding),
            or None to count every position.
        reduce: when True, sum per-token losses before normalizing.

    Returns:
        (loss, nll_loss): smoothed loss normalized by the number of valid
        tokens, and the (possibly reduced) unsmoothed NLL component.
    """
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)

    nll_loss = -lprobs.gather(dim=-1, index=target)
    # Sum over the vocabulary: the uniform part of the smoothed target.
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        nll_loss.masked_fill_(pad_mask, 0.0)
        smooth_loss.masked_fill_(pad_mask, 0.0)
        valid_length = target.ne(ignore_index).sum().item()
    else:
        nll_loss = nll_loss.squeeze(-1)
        smooth_loss = smooth_loss.squeeze(-1)
        # Bug fix: the original computed target.ne(None) here, which raises
        # TypeError. With no ignore_index, every position is valid.
        valid_length = target.numel()
    if reduce:
        nll_loss = nll_loss.sum()
        smooth_loss = smooth_loss.sum()
    # Probability mass assigned to each of the (V - 1) non-gold classes.
    eps_i = epsilon / (lprobs.size(-1) - 1)

    # NOTE(review): divides by valid_length, which is 0 for an all-padding
    # batch — upstream callers never produce that, but worth confirming.
    loss = ((1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss) / valid_length

    return loss, nll_loss
|
|
|
|
|
class DataCollator:
    """Batch collator producing encoder inputs, shifted decoder inputs and
    label tensors for a seq2seq translation model.

    The raw source/target strings are carried through the batch so that
    validation can log them alongside decoded predictions.
    """

    def __init__(self, model, tokenizer, max_enc_length, max_dec_length, reverse_src_tgt):
        """
        Args:
            model: seq2seq model exposing ``prepare_decoder_input_ids_from_labels``.
            tokenizer: tokenizer shared by encoder and decoder sides.
            max_enc_length: maximum encoder sequence length in tokens.
            max_dec_length: maximum decoder sequence length in tokens.
            reverse_src_tgt: when True, translate tgt -> src instead of src -> tgt.
        """
        self.tokenizer = tokenizer
        self.max_enc_length = max_enc_length
        self.max_dec_length = max_dec_length
        self.model = model
        self.reverse_src_tgt = reverse_src_tgt

    def __call__(self, batch_samples):
        batch_inputs, batch_targets = [], []
        for sample in batch_samples:
            if self.reverse_src_tgt:
                # Reversed direction: the "tgt" field becomes the model input.
                if "tgt" in sample and len(sample["tgt"]) != 0:
                    batch_inputs.append(sample["tgt"])
                    batch_targets.append(sample["src"])
            else:
                if "src" in sample and len(sample["src"]) != 0:
                    batch_inputs.append(sample["src"])
                    batch_targets.append(sample["tgt"])
        batch_data = self.tokenizer(
            batch_inputs,
            padding='max_length',
            max_length=self.max_enc_length,
            truncation=True,
            return_tensors="pt"
        )
        with self.tokenizer.as_target_tokenizer():
            labels = self.tokenizer(
                batch_targets,
                padding='max_length',
                max_length=self.max_dec_length,
                # Bug fix: was truncation=False, which lets over-length targets
                # exceed max_length and breaks return_tensors="pt" with ragged
                # sequences; truncate exactly like the encoder side does.
                truncation=True,
                return_tensors="pt"
            )["input_ids"]
        batch_data['decoder_input_ids'] = self.model.prepare_decoder_input_ids_from_labels(labels)
        batch_data['labels'] = labels

        # Keep the raw strings for logging / prediction dumps.
        batch_data['src'] = batch_inputs
        batch_data['tgt'] = batch_targets

        return batch_data
|
|
|
|
|
class FinetuneTranslation(LightningModule):
    """Finetunes a seq2seq translation model and evaluates it with sacreBLEU.

    DeltaLM is used by default; with ``--other_model`` any HuggingFace
    ``AutoModelForSeq2SeqLM`` checkpoint is loaded instead. Validation collects
    per-batch BLEU sufficient statistics that are aggregated into a corpus
    score at epoch end, and appends predictions to one JSON-lines file per
    rank.
    """

    @staticmethod
    def add_model_specific_args(parent_args):
        """Register model-specific CLI flags and return the parser."""
        parser = parent_args.add_argument_group('deltalm-base finetune')
        parser.add_argument('--label_smoothing', default=0.1, type=float)
        return parent_args

    def __init__(self, args, tokenizer=None):
        """
        Args:
            args: parsed namespace; must provide ``model_path``,
                ``other_model`` and ``label_smoothing``.
            tokenizer: tokenizer matching the checkpoint; must not be None.
        """
        super().__init__()
        self.args = args
        self.save_hyperparameters(args)
        if args.other_model:
            self.model = AutoModelForSeq2SeqLM.from_pretrained(args.model_path)
        else:
            self.model = DeltalmForConditionalGeneration.from_pretrained(
                args.model_path, ignore_mismatched_sizes=True)
        self.tokenizer = tokenizer
        assert self.tokenizer, "tokenizer is None!"
        # NOTE(review): "blue" is a typo for "bleu"; name kept for compatibility.
        self.blue_metric = BLEU()
        # Per-batch BLEU sufficient statistics; reset every validation epoch.
        self.sufficient_stats: List[List[int]] = []
        self.label_smoothing = self.args.label_smoothing
        self.mose_decode = MosesDetokenizer()

        if self.args.label_smoothing != 0:
            self.loss_fn = label_smoothed_nll_loss

    def setup(self, stage) -> None:
        """Pre-compute ``self.total_steps`` for LR scheduling before fitting."""
        if stage == 'fit':
            train_loader = self.trainer._data_connector._train_dataloader_source.dataloader()

            # Effective batch size: per-device batch * number of GPUs.
            tb_size = self.hparams.train_batchsize * max(1, self.trainer.gpus)
            ab_size = self.trainer.accumulate_grad_batches * float(
                self.trainer.max_epochs)
            self.total_steps = (len(train_loader.dataset) //
                                tb_size) // ab_size

    def configure_optimizers(self):
        """Delegate optimizer/scheduler construction to the shared helper."""
        # Imported lazily to avoid import cost/cycles at module load time.
        from fengshen.models.model_utils import configure_optimizers
        return configure_optimizers(self)

    def training_step(self, batch, batch_idx):
        """One training step; applies label smoothing when configured."""
        if self.label_smoothing == 0:
            # Let the model compute plain cross-entropy internally.
            output = self.model(input_ids=batch['input_ids'],
                                attention_mask=batch['attention_mask'],
                                labels=batch['labels'])

            self.log('train_loss', output.loss, sync_dist=True)
            return output.loss

        else:
            labels = batch["labels"]
            output = self.model(input_ids=batch['input_ids'],
                                attention_mask=batch['attention_mask'],
                                decoder_input_ids=batch['decoder_input_ids'])

            logits = output["logits"]
            m = torch.nn.LogSoftmax(dim=-1)
            lprobs = m(logits.float())
            loss, _ = self.loss_fn(lprobs.view(-1, lprobs.size(-1)), labels.view(-1),
                                   self.label_smoothing, self.tokenizer.pad_token_id)
            self.log('train_loss', loss, sync_dist=True)
            return loss

    def comput_metrix(self, logits, labels):
        """Token-level accuracy over non-padding positions.

        NOTE(review): the padding id is hard-coded as 1 here — confirm it
        matches ``self.tokenizer.pad_token_id`` for every supported tokenizer.
        """
        y_pred = torch.argmax(logits, dim=-1)
        y_pred = y_pred.view(size=(-1, ))

        y_true = labels.view(size=(-1, ))
        pad_mask = y_true.eq(1)
        valid_length = y_true.ne(1).sum()

        corr = torch.eq(y_pred, y_true.float())
        corr.masked_fill_(pad_mask, 0.0)
        acc = torch.sum(corr.float()) / valid_length
        return acc

    def get_sufficient_stats(self, translations: List[str], references: List[str]) -> pd.DataFrame:
        """Accumulate BLEU n-gram counts for one batch into sufficient_stats."""
        assert len(translations) == len(references), (
            f"There are {len(translations)} translated sentences "
            f"but {len(references)} reference sentences"
        )

        sentence_bleu = self.blue_metric.corpus_score(translations, [references])
        self.sufficient_stats.append(
            [
                # Matched n-gram counts, n = 1..4.
                sentence_bleu.counts[0],
                sentence_bleu.counts[1],
                sentence_bleu.counts[2],
                sentence_bleu.counts[3],
                # Total n-gram counts, n = 1..4.
                sentence_bleu.totals[0],
                sentence_bleu.totals[1],
                sentence_bleu.totals[2],
                sentence_bleu.totals[3],
                # Hypothesis and reference lengths.
                sentence_bleu.sys_len,
                sentence_bleu.ref_len,
            ]
        )

    def _rank_output_path(self):
        """Per-rank prediction file path derived from ``output_save_path``."""
        prefix, ext = os.path.splitext(self.hparams.output_save_path)
        return '{}_{}{}'.format(
            prefix,
            self.trainer._accelerator_connector.cluster_environment.
            global_rank(), ext)

    def on_validation_start(self) -> None:
        """Delete any stale per-rank prediction file before appending anew."""
        file_path_rank = self._rank_output_path()
        if os.path.exists(file_path_rank):
            os.remove(file_path_rank)

    def validation_step(self, batch, batch_idx):
        """Score one validation batch: loss and accuracy plus BLEU
        statistics, and append decoded predictions to the per-rank file."""

        def postprocess_text(preds, labels, tgt_zh):
            # Chinese targets are scored character-level later, so only strip;
            # other languages are detokenized with Moses.
            if tgt_zh:
                preds = [pred.strip() for pred in preds]
                labels = [label.strip() for label in labels]
            else:
                preds = list(map(lambda x: mose_decode(x.strip().split()), preds))
                labels = list(map(lambda x: mose_decode(x.strip().split()), labels))
            return preds, labels

        # NOTE: tmp_label aliases batch['labels'], so masking everything after
        # the first EOS with -100 mutates the batch in place; the torch.where
        # below relies on that to drop post-EOS tokens before decoding.
        tmp_label = batch['labels']
        end_token_index = torch.where(tmp_label == self.tokenizer.eos_token_id)[1]
        for idx, end_idx in enumerate(end_token_index):
            tmp_label[idx][end_idx+1:] = -100
        output = self.model(input_ids=batch['input_ids'],
                            attention_mask=batch['attention_mask'],
                            labels=tmp_label)
        generated_ids = self.model.generate(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            max_length=self.hparams.max_dec_length)

        preds = self.tokenizer.batch_decode(generated_ids,
                                            skip_special_tokens=True)
        # Replace -100 with the pad id so batch_decode can skip it.
        labels = torch.where(batch['labels'] != -100, batch['labels'],
                             self.tokenizer.pad_token_id)

        labels = self.tokenizer.batch_decode(labels,
                                             skip_special_tokens=True)

        decoded_preds, decoded_labels = postprocess_text(preds, labels, self.args.tgt_zh)

        self.save_prediction_to_file(preds=decoded_preds,
                                     sources=batch['src'],
                                     targets=decoded_labels,
                                     ori_target=batch['tgt'],
                                     file_path=self._rank_output_path())

        if self.args.tgt_zh:
            # Character tokenization makes BLEU meaningful for Chinese.
            new_preds = [chinese_char_tokenize(p) for p in decoded_preds]
            new_labels = [chinese_char_tokenize(label) for label in decoded_labels]
            self.get_sufficient_stats(new_preds, new_labels)
        else:
            self.get_sufficient_stats(decoded_preds, decoded_labels)

        acc = self.comput_metrix(output.logits, batch['labels'])
        self.log('val_loss', output.loss, sync_dist=True)
        self.log('val_acc', acc, sync_dist=True)

    def validation_epoch_end(self, outputs):
        """Aggregate per-batch statistics into a corpus BLEU and log it."""
        rank_zero_info("***** Validation results *****")
        sentence_states = pd.DataFrame(
            self.sufficient_stats,
            columns=[
                "correct_1_grams",
                "correct_2_grams",
                "correct_3_grams",
                "correct_4_grams",
                "total_1_grams",
                "total_2_grams",
                "total_3_grams",
                "total_4_grams",
                "translation_length",
                "reference_length",
            ]
        )

        computed_bleu = calc_bleu_from_stats(sentence_states)
        rank_zero_info("valid_sacrebleu= {}\n".format(computed_bleu.score))
        self.log('valid_sacrebleu', computed_bleu.score, sync_dist=True)
        # Reset the accumulator for the next validation epoch.
        self.sufficient_stats = []

    def on_save_checkpoint(self, checkpoint) -> None:
        """Additionally export HF-format weights per checkpoint (rank 0 only)."""
        if self.trainer._accelerator_connector.cluster_environment.global_rank(
        ) == 0:
            self.model.save_pretrained(
                os.path.join(
                    self.trainer.checkpoint_callback.dirpath,
                    'finetuned_epoch{}_step{}'.format(
                        checkpoint['epoch'], checkpoint['global_step'])))

    def save_prediction_to_file(self, preds, sources, targets, ori_target, file_path):
        """Append one JSON line per prediction to *file_path*."""
        with open(file_path, 'a', encoding='utf-8') as f:
            for idx, pred in enumerate(preds):
                source = sources[idx]
                target = targets[idx]
                tmp_result = dict()
                tmp_result['pred'] = pred
                tmp_result['source'] = source
                tmp_result['label'] = target
                tmp_result['ori_label'] = ori_target[idx]
                json_data = json.dumps(tmp_result, ensure_ascii=False)
                f.write(json_data + '\n')

    def test_step(self, batch, batch_idx):
        """Generate translations for a test batch and dump them to
        ``output_save_path`` alongside sources and references."""
        texts = batch['src']

        self.model.eval()
        generated_ids = self.model.generate(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            max_length=self.hparams.max_dec_length
        )
        preds = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)

        labels = torch.where(batch['labels'] != -100, batch['labels'],
                             self.tokenizer.pad_token_id)
        labels = self.tokenizer.batch_decode(
            labels, skip_special_tokens=True, clean_up_tokenization_spaces=True)

        # Bug fix: the original call omitted ``ori_target``, which shifted
        # ``output_save_path`` into that slot and raised TypeError for the
        # missing ``file_path`` argument.
        self.save_prediction_to_file(preds=preds,
                                     sources=texts,
                                     targets=labels,
                                     ori_target=batch['tgt'],
                                     file_path=self.hparams.output_save_path)
|
|
|
def configure_logger(logging_lever=logging.INFO):
    """Send log records to stdout and raise this module's logger to the
    requested level.

    The parameter name ``logging_lever`` (sic) is preserved because callers
    pass it as a keyword argument.
    """
    log_format = "%(asctime)s - %(levelname)s - %(name)s - %(message)s"
    date_format = "%m/%d/%Y %H:%M:%S"
    stdout_handler = logging.StreamHandler(sys.stdout)
    logging.basicConfig(format=log_format,
                        datefmt=date_format,
                        handlers=[stdout_handler])
    logger.setLevel(logging_lever)
|
|
|
|
|
def main():
    """CLI entry point: assemble the argument parser, build tokenizer, model
    and data module, then finetune — or run validation only when
    ``--do_eval_only`` is given."""
    args_parser = argparse.ArgumentParser("Pegasus Task")
    args_parser.add_argument('--do_eval_only',
                             action='store_true',
                             default=False)
    args_parser.add_argument('--other_model',
                             action='store_true',
                             default=False)
    args_parser.add_argument('--reverse_src_tgt',
                             action='store_true',
                             default=False)
    args_parser.add_argument('--tgt_zh',
                             action='store_true',
                             default=False)
    args_parser.add_argument('--early_stopping_callback',
                             action='store_true',
                             default=False)
    # NOTE(review): --pretrained_model_path is parsed but never read below;
    # the code loads from args.model_path (presumably added by
    # add_module_args — confirm and remove this flag if truly dead).
    args_parser.add_argument('--pretrained_model_path',
                             default='facebook/mbart',
                             type=str)
    args_parser.add_argument('--output_save_path',
                             default='predict.json',
                             type=str)
    args_parser.add_argument('--max_enc_length', default=512, type=int)
    args_parser.add_argument('--max_dec_length', default=512, type=int)

    # Let every component register its own CLI options.
    args_parser = UniversalDataModule.add_data_specific_args(args_parser)
    args_parser = Trainer.add_argparse_args(args_parser)
    args_parser = UniversalCheckpoint.add_argparse_args(args_parser)
    args_parser = FinetuneTranslation.add_model_specific_args(args_parser)
    args_parser = add_module_args(args_parser)
    args_parser = add_inverse_square_args(args_parser)

    args = args_parser.parse_args()

    if args.other_model:
        tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    else:
        tokenizer = DeltalmTokenizer.from_pretrained(args.model_path)

    print("tokenizer vocab size: ", tokenizer.vocab_size)
    model = FinetuneTranslation(args, tokenizer)
    collator = DataCollator(model.model, tokenizer, args.max_enc_length,
                            args.max_dec_length, args.reverse_src_tgt)
    data_model = UniversalDataModule(tokenizer=tokenizer,
                                     args=args,
                                     collate_fn=collator)

    # Fix: LearningRateMonitor was instantiated twice (once here, once in the
    # training branch); a single instance is sufficient.
    lr_monitor = LearningRateMonitor(logging_interval='step')

    configure_logger(logging_lever=logging.INFO)

    if not args.do_eval_only:
        tensorboard_logger = loggers.TensorBoardLogger(
            save_dir=os.path.join(args.default_root_dir, 'logs/'),
            name=os.path.basename(os.path.dirname(args.model_path)))
        checkpoint_callback = UniversalCheckpoint(args)

        trainer = Trainer.from_argparse_args(
            args, logger=tensorboard_logger, callbacks=[lr_monitor, checkpoint_callback])
        trainer.fit(model, data_model)

    else:
        trainer = Trainer.from_argparse_args(args)
        trainer.validate(model, data_model)
|
|
|
|
|
|
|
# Script entry point.
if __name__ == '__main__':
    main()
|
|