File size: 18,263 Bytes
8ebda9e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 |
# !/usr/bin/env python
# -*- coding: utf-8 -*-
import pandas as pd
import json
import argparse
import torch
import os
import logging
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from pytorch_lightning.utilities import rank_zero_info
from sacrebleu.metrics import BLEU
from fengshen.utils.utils import chinese_char_tokenize
from fengshen.models.model_utils import add_module_args, add_inverse_square_args
from fengshen.models.deltalm.tokenizer_deltalm import DeltalmTokenizer
from fengshen.models.deltalm.modeling_deltalm import DeltalmForConditionalGeneration
from fengshen.utils import UniversalCheckpoint
from fengshen.data.universal_datamodule import UniversalDataModule
from pytorch_lightning import Trainer, loggers, LightningModule
from pytorch_lightning.callbacks import LearningRateMonitor
from mosestokenizer import MosesDetokenizer
from typing import List
import sys
sys.path.append('../../../')
# from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
# from pytorch_lightning.callbacks.early_stopping import EarlyStopping
mose_decode = MosesDetokenizer()
os.environ["CUDA_VISIBLE_DEVICES"] = '4'
logger = logging.getLogger(__name__)
EVAL_BLEU_ORDER = 4
def calc_bleu_from_stats(sentence_stats: pd.DataFrame) -> BLEU:
corpus_stats = sentence_stats.sum(axis=0)
smooth = {"smooth_method": "exp"}
corpus_bleu = BLEU.compute_bleu(
correct=[
corpus_stats.correct_1_grams,
corpus_stats.correct_2_grams,
corpus_stats.correct_3_grams,
corpus_stats.correct_4_grams,
],
total=[
corpus_stats.total_1_grams,
corpus_stats.total_2_grams,
corpus_stats.total_3_grams,
corpus_stats.total_4_grams,
],
sys_len=corpus_stats.translation_length,
ref_len=corpus_stats.reference_length,
**smooth
)
return corpus_bleu
def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True):
if target.dim() == lprobs.dim() - 1:
target = target.unsqueeze(-1)
# logger.debug("Debug: After target.dim() == lprobs.dim(): ", target.dim(), lprobs.dim())
nll_loss = -lprobs.gather(dim=-1, index=target)
smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
if ignore_index is not None:
pad_mask = target.eq(ignore_index)
nll_loss.masked_fill_(pad_mask, 0.0)
smooth_loss.masked_fill_(pad_mask, 0.0)
else:
nll_loss = nll_loss.squeeze(-1)
smooth_loss = smooth_loss.squeeze(-1)
if reduce:
nll_loss = nll_loss.sum()
smooth_loss = smooth_loss.sum()
eps_i = epsilon / (lprobs.size(-1) - 1)
valid_length = target.ne(ignore_index).sum()
# unvalid_length = target.eq(ignore_index).sum()
loss = ((1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss) / valid_length.item()
return loss, nll_loss
class DataCollator:
def __init__(self, model, tokenizer, max_enc_length, max_dec_length, reverse_src_tgt):
self.tokenizer = tokenizer
self.max_enc_length = max_enc_length
self.max_dec_length = max_dec_length
self.model = model
self.reverse_src_tgt = reverse_src_tgt
def __call__(self, batch_samples):
batch_inputs, batch_targets = [], []
for sample in batch_samples:
if self.reverse_src_tgt:
if "tgt" in sample and len(sample["tgt"]) != 0:
batch_inputs.append(sample["tgt"])
batch_targets.append(sample["src"])
else:
if "src" in sample and len(sample["src"]) != 0:
batch_inputs.append(sample["src"])
batch_targets.append(sample["tgt"])
batch_data = self.tokenizer(
batch_inputs,
padding='max_length',
max_length=self.max_enc_length,
truncation=True,
return_tensors="pt"
)
with self.tokenizer.as_target_tokenizer():
labels = self.tokenizer(
batch_targets,
padding='max_length',
max_length=self.max_dec_length,
truncation=False,
return_tensors="pt"
)["input_ids"]
batch_data['decoder_input_ids'] = self.model.prepare_decoder_input_ids_from_labels(labels)
batch_data['labels'] = labels
batch_data['src'] = batch_inputs
batch_data['tgt'] = batch_targets
# logger.debug(batch_data)
return batch_data
class FinetuneTranslation(LightningModule):
@staticmethod
def add_model_specific_args(parent_args):
parser = parent_args.add_argument_group('deltalm-base finetune')
parser.add_argument('--label_smoothing', default=0.1, type=float)
return parent_args
def __init__(self, args, tokenizer=None):
super().__init__()
self.args = args
self.save_hyperparameters(args)
if args.other_model:
self.model = AutoModelForSeq2SeqLM.from_pretrained(args.model_path)
else:
self.model = DeltalmForConditionalGeneration.from_pretrained(args.model_path, ignore_mismatched_sizes=True)
self.tokenizer = tokenizer
assert self.tokenizer, "tokenizer is None!"
self.blue_metric = BLEU()
self.sufficient_stats: List[List[int]] = []
self.label_smoothing = self.args.label_smoothing
self.mose_decode = MosesDetokenizer()
if self.args.label_smoothing != 0:
self.loss_fn = label_smoothed_nll_loss
def setup(self, stage) -> None:
if stage == 'fit':
train_loader = self.trainer._data_connector._train_dataloader_source.dataloader()
# Calculate total steps
tb_size = self.hparams.train_batchsize * max(1, self.trainer.gpus)
ab_size = self.trainer.accumulate_grad_batches * float(
self.trainer.max_epochs)
self.total_steps = (len(train_loader.dataset) //
tb_size) // ab_size
def configure_optimizers(self):
# if self.args.use_default_configure:
from fengshen.models.model_utils import configure_optimizers
return configure_optimizers(self)
def training_step(self, batch, batch_idx):
if self.label_smoothing == 0:
output = self.model(input_ids=batch['input_ids'],
attention_mask=batch['attention_mask'],
labels=batch['labels'])
self.log('train_loss', output.loss, sync_dist=True)
return output.loss
# TODO label_smoothing should be implemented at here
else:
labels = batch["labels"]
output = self.model(input_ids=batch['input_ids'],
attention_mask=batch['attention_mask'],
decoder_input_ids=batch['decoder_input_ids'])
logits = output["logits"]
m = torch.nn.LogSoftmax(dim=-1)
lprobs = m(logits.float())
loss, _ = self.loss_fn(lprobs.view(-1, lprobs.size(-1)), labels.view(-1),
self.label_smoothing, self.tokenizer.pad_token_id)
self.log('train_loss', loss, sync_dist=True)
return loss
def comput_metrix(self, logits, labels):
y_pred = torch.argmax(logits, dim=-1)
y_pred = y_pred.view(size=(-1, ))
y_true = labels.view(size=(-1, ))
pad_mask = y_true.eq(1)
valid_length = y_true.ne(1).sum()
corr = torch.eq(y_pred, y_true.float())
corr.masked_fill_(pad_mask, 0.0)
acc = torch.sum(corr.float()) / valid_length
return acc
def get_sufficient_stats(self, translations: List[str], references: List[str]) -> pd.DataFrame:
assert len(translations) == len(references), (
f"There are {len(translations)} translated sentences "
f"but {len(references)} reference sentences"
)
# for sentence, ref in zip(translations, references):
sentence_bleu = self.blue_metric.corpus_score(translations, [references])
self.sufficient_stats.append(
[
# Number of correct 1-grams, .., 4-grams
sentence_bleu.counts[0],
sentence_bleu.counts[1],
sentence_bleu.counts[2],
sentence_bleu.counts[3],
# Total number of 1-grams, .., 4-grams
sentence_bleu.totals[0],
sentence_bleu.totals[1],
sentence_bleu.totals[2],
sentence_bleu.totals[3],
# Length of translated sentence.
sentence_bleu.sys_len,
# Length of reference sentence.
sentence_bleu.ref_len,
]
)
def on_validation_start(self) -> None:
# rm file at validation start
prefix, ext = os.path.splitext(self.hparams.output_save_path)
file_path_rank = '{}_{}{}'.format(
prefix,
self.trainer._accelerator_connector.cluster_environment.
global_rank(), ext)
if os.path.exists(file_path_rank):
# logger.debug('rm {}'.format(file_path_rank))
os.remove(file_path_rank)
def validation_step(self, batch, batch_idx):
def postprocess_text(preds, labels, tgt_zh):
if tgt_zh:
preds = [pred.strip() for pred in preds]
labels = [label.strip() for label in labels]
else:
preds = list(map(lambda x: mose_decode(x.strip().split()), preds))
labels = list(map(lambda x: mose_decode(x.strip().split()), labels))
return preds, labels
tmp_label = batch['labels']
end_token_index = torch.where(tmp_label == self.tokenizer.eos_token_id)[1]
for idx, end_idx in enumerate(end_token_index):
tmp_label[idx][end_idx+1:] = -100
output = self.model(input_ids=batch['input_ids'],
attention_mask=batch['attention_mask'],
labels=tmp_label)
generated_ids = self.model.generate(
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask'],
max_length=self.hparams.max_dec_length)
preds = self.tokenizer.batch_decode(generated_ids,
skip_special_tokens=True)
labels = torch.where(batch['labels'] != -100, batch['labels'],
self.tokenizer.pad_token_id)
labels = self.tokenizer.batch_decode(labels,
skip_special_tokens=True)
decoded_preds, decoded_labels = postprocess_text(preds, labels, self.args.tgt_zh)
# save preds for every rank
prefix, ext = os.path.splitext(self.hparams.output_save_path)
file_path_rank = '{}_{}{}'.format(
prefix,
self.trainer._accelerator_connector.cluster_environment.
global_rank(), ext)
self.save_prediction_to_file(preds=decoded_preds,
sources=batch['src'],
targets=decoded_labels,
ori_target=batch['tgt'],
file_path=file_path_rank)
if self.args.tgt_zh:
new_preds = [chinese_char_tokenize(p) for p in decoded_preds]
new_labels = [chinese_char_tokenize(label) for label in decoded_labels]
self.get_sufficient_stats(new_preds, new_labels)
else:
self.get_sufficient_stats(decoded_preds, decoded_labels)
# batch_bleu = self.blue_metric.corpus_score(decoded_preds, [decoded_labels]).score
acc = self.comput_metrix(output.logits, batch['labels'])
self.log('val_loss', output.loss, sync_dist=True)
self.log('val_acc', acc, sync_dist=True)
def validation_epoch_end(self, outputs):
rank_zero_info("***** Validation results *****")
sentence_states = pd.DataFrame(
self.sufficient_stats,
columns=[
"correct_1_grams",
"correct_2_grams",
"correct_3_grams",
"correct_4_grams",
"total_1_grams",
"total_2_grams",
"total_3_grams",
"total_4_grams",
"translation_length",
"reference_length",
]
)
computed_bleu = calc_bleu_from_stats(sentence_states)
rank_zero_info("valid_sacrebleu= {}\n".format(computed_bleu.score))
self.log('valid_sacrebleu', computed_bleu.score, sync_dist=True)
self.sufficient_stats = []
def on_save_checkpoint(self, checkpoint) -> None:
if self.trainer._accelerator_connector.cluster_environment.global_rank(
) == 0:
self.model.save_pretrained(
os.path.join(
self.trainer.checkpoint_callback.dirpath,
'finetuned_epoch{}_step{}'.format(
checkpoint['epoch'], checkpoint['global_step'])))
def save_prediction_to_file(self, preds, sources, targets, ori_target, file_path):
with open(file_path, 'a', encoding='utf-8') as f:
for idx, pred in enumerate(preds):
source = sources[idx]
target = targets[idx]
tmp_result = dict()
tmp_result['pred'] = pred
tmp_result['source'] = source
tmp_result['label'] = target
tmp_result['ori_label'] = ori_target[idx]
json_data = json.dumps(tmp_result, ensure_ascii=False)
f.write(json_data + '\n')
def test_step(self, batch, batch_idx):
# print(batch)
texts = batch['src']
# output summary and metrics
self.model.eval()
generated_ids = self.model.generate(
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask'],
max_length=self.hparams.max_dec_length
)
preds = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
labels = torch.where(batch['labels'] != -100, batch['labels'],
self.tokenizer.pad_token_id)
labels = self.tokenizer.batch_decode(
labels, skip_special_tokens=True, clean_up_tokenization_spaces=True)
self.save_prediction_to_file(preds, texts, labels, self.hparams.output_save_path)
def configure_logger(logging_lever=logging.INFO):
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
logger.setLevel(logging_lever)
def main():
args_parser = argparse.ArgumentParser("Pegasus Task")
args_parser.add_argument('--do_eval_only',
action='store_true',
default=False)
args_parser.add_argument('--other_model',
action='store_true',
default=False)
args_parser.add_argument('--reverse_src_tgt',
action='store_true',
default=False)
args_parser.add_argument('--tgt_zh',
action='store_true',
default=False)
args_parser.add_argument('--early_stopping_callback',
action='store_true',
default=False)
args_parser.add_argument('--pretrained_model_path',
default='facebook/mbart',
type=str)
args_parser.add_argument('--output_save_path',
default='predict.json',
type=str)
args_parser.add_argument('--max_enc_length', default=512, type=int)
args_parser.add_argument('--max_dec_length', default=512, type=int)
# * Args for data preprocessing
args_parser = UniversalDataModule.add_data_specific_args(args_parser)
# * Args for training
args_parser = Trainer.add_argparse_args(args_parser)
args_parser = UniversalCheckpoint.add_argparse_args(args_parser)
args_parser = FinetuneTranslation.add_model_specific_args(args_parser)
args_parser = add_module_args(args_parser)
args_parser = add_inverse_square_args(args_parser)
args = args_parser.parse_args()
if args.other_model:
tokenizer = AutoTokenizer.from_pretrained(args.model_path)
else:
tokenizer = DeltalmTokenizer.from_pretrained(args.model_path)
# tokenizer = AutoTokenizer.from_pretrained(args.model_path)
print("tokenizer vocab size: ", tokenizer.vocab_size)
model = FinetuneTranslation(args, tokenizer)
collator = DataCollator(model.model, tokenizer, args.max_enc_length, args.max_dec_length, args.reverse_src_tgt)
data_model = UniversalDataModule(tokenizer=tokenizer,
args=args,
# datasets=dataset,
collate_fn=collator)
lr_monitor = LearningRateMonitor(logging_interval='step')
configure_logger(logging_lever=logging.INFO)
if not args.do_eval_only:
lr_monitor = LearningRateMonitor(logging_interval='step')
tensorboard_logger = loggers.TensorBoardLogger(
save_dir=os.path.join(args.default_root_dir, 'logs/'),
name=os.path.basename(os.path.dirname(args.model_path)))
checkpoint_callback = UniversalCheckpoint(args)
# early_stop = EarlyStopping(monitor=args.monitor, mode=args.mode)
trainer = Trainer.from_argparse_args(
args, logger=tensorboard_logger, callbacks=[lr_monitor, checkpoint_callback])
trainer.fit(model, data_model)
else:
trainer = Trainer.from_argparse_args(args)
trainer.validate(model, data_model)
# trainer.test(model, data_model)
if __name__ == '__main__':
main()
|