import os
import math
import logging
from pprint import pformat
from argparse import ArgumentParser
from collections import defaultdict
from itertools import chain

import torch
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, TensorDataset
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint
from ignite.metrics import Accuracy, Loss, MetricsLambda, RunningAverage
from ignite.contrib.handlers import ProgressBar, PiecewiseLinear
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OutputHandler, OptimizerParamsHandler
from transformers import (AdamW, OpenAIGPTDoubleHeadsModel, OpenAIGPTTokenizer,
                          GPT2DoubleHeadsModel, GPT2Tokenizer, WEIGHTS_NAME, CONFIG_NAME)

from utils import get_dataset, make_logdir

SPECIAL_TOKENS = ["<bos>", "<eos>", "<speaker1>", "<speaker2>", "<pad>"]
ATTR_TO_SPECIAL_TOKEN = {'bos_token': '<bos>', 'eos_token': '<eos>', 'pad_token': '<pad>',
                         'additional_special_tokens': ['<speaker1>', '<speaker2>']}
MODEL_INPUTS = ["input_ids", "mc_token_ids", "lm_labels", "mc_labels", "token_type_ids"]
PADDED_INPUTS = ["input_ids", "lm_labels", "token_type_ids"]
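# Note: MODEL_INPUTS is ordered to match the batch unpacking in update() and inference() below;
# PADDED_INPUTS lists the per-token fields that pad_dataset() pads to a common length.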

logger = logging.getLogger(__file__)

def average_distributed_scalar(scalar, args):
    """ Average a scalar over the nodes if we are in distributed training. We use this for distributed evaluation. """
    if args.local_rank == -1:
        return scalar
    scalar_t = torch.tensor(scalar, dtype=torch.float, device=args.device) / torch.distributed.get_world_size()
    torch.distributed.all_reduce(scalar_t, op=torch.distributed.ReduceOp.SUM)
    return scalar_t.item()


def pad_dataset(dataset, padding=0):
    """ Pad the dataset. This could be optimized by defining a Dataset class and padding at the batch level, but this is simpler. """
    max_l = max(len(x) for x in dataset["input_ids"])
    for name in PADDED_INPUTS:
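        # lm_labels are padded with -100 so that padded positions are ignored by the cross-entropy loss (ignore_index=-100).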
        dataset[name] = [x + [padding if name != "lm_labels" else -100] * (max_l - len(x)) for x in dataset[name]]
    return dataset


def add_special_tokens_(model, tokenizer):
    """ Add special tokens to the tokenizer and the model if they have not already been added. """
    orig_num_tokens = len(tokenizer.encoder)
    num_added_tokens = tokenizer.add_special_tokens(ATTR_TO_SPECIAL_TOKEN)
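    # Only grow the embedding matrix when new tokens were actually added to the vocabulary.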
    if num_added_tokens > 0:
        model.resize_token_embeddings(new_num_tokens=orig_num_tokens + num_added_tokens)


def build_input_from_segments(persona, history, reply, tokenizer, lm_labels=False, with_eos=True):
    """ Build a sequence of input from 3 segments: persona, history and last reply. """
    bos, eos, speaker1, speaker2 = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[:-1])
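    # Layout: [<bos> + persona] [speaker token + utterance]* [speaker token + reply (+ <eos>)], with
    # <speaker1>/<speaker2> alternating so consecutive turns are attributed to different speakers.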
    sequence = [[bos] + list(chain(*persona))] + history + [reply + ([eos] if with_eos else [])]
    sequence = [sequence[0]] + [[speaker2 if (len(sequence)-i) % 2 else speaker1] + s for i, s in enumerate(sequence[1:])]
    instance = {}
    instance["input_ids"] = list(chain(*sequence))
    instance["token_type_ids"] = [speaker2 if i % 2 else speaker1 for i, s in enumerate(sequence) for _ in s]
    instance["mc_token_ids"] = len(instance["input_ids"]) - 1
    instance["lm_labels"] = [-100] * len(instance["input_ids"])
    if lm_labels:
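        # Only the reply tokens (last segment, minus its speaker token) are LM targets; everything before them stays masked with -100.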
        instance["lm_labels"] = ([-100] * sum(len(s) for s in sequence[:-1])) + [-100] + sequence[-1][1:]
    return instance


def get_data_loaders(args, tokenizer):
    """ Prepare the dataset for training and evaluation """
    personachat = get_dataset(tokenizer, args.dataset_path, args.dataset_cache)

    logger.info("Build inputs and labels")
    datasets = {"train": defaultdict(list), "valid": defaultdict(list)}
    for dataset_name, dataset in personachat.items():
        num_candidates = len(dataset[0]["utterances"][0]["candidates"])
        if args.num_candidates > 0 and dataset_name == 'train':
            num_candidates = min(args.num_candidates, num_candidates)
        for dialog in dataset:
            persona = dialog["personality"].copy()
            for _ in range(args.personality_permutations):
                for utterance in dialog["utterances"]:
                    history = utterance["history"][-(2*args.max_history+1):]
                    for j, candidate in enumerate(utterance["candidates"][-num_candidates:]):
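                        # The gold reply is the last candidate, so only it gets LM labels; the other candidates are distractors for the multiple-choice head.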
                        lm_labels = bool(j == num_candidates-1)
                        instance = build_input_from_segments(persona, history, candidate, tokenizer, lm_labels)
                        for input_name, input_array in instance.items():
                            datasets[dataset_name][input_name].append(input_array)
                    datasets[dataset_name]["mc_labels"].append(num_candidates - 1)
                    datasets[dataset_name]["n_candidates"] = num_candidates
                persona = [persona[-1]] + persona[:-1]

logger.info("Pad inputs and convert to Tensor") |
|
tensor_datasets = {"train": [], "valid": []} |
|
for dataset_name, dataset in datasets.items(): |
|
dataset = pad_dataset(dataset, padding=tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[-1])) |
|
for input_name in MODEL_INPUTS: |
|
tensor = torch.tensor(dataset[input_name]) |
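            # Per-token inputs are reshaped to (n_examples, n_candidates, seq_len); mc_labels stays 1-D with one gold index per example.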
            if input_name != "mc_labels":
                tensor = tensor.view((-1, datasets[dataset_name]["n_candidates"]) + tensor.shape[1:])
            tensor_datasets[dataset_name].append(tensor)

    logger.info("Build train and validation dataloaders")
    train_dataset, valid_dataset = TensorDataset(*tensor_datasets["train"]), TensorDataset(*tensor_datasets["valid"])
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
    valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset) if args.distributed else None
    train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, shuffle=(not args.distributed))
    valid_loader = DataLoader(valid_dataset, sampler=valid_sampler, batch_size=args.valid_batch_size, shuffle=False)

    logger.info("Train dataset (Batch, Candidates, Seq length): {}".format(train_dataset.tensors[0].shape))
    logger.info("Valid dataset (Batch, Candidates, Seq length): {}".format(valid_dataset.tensors[0].shape))
    return train_loader, valid_loader, train_sampler, valid_sampler


def train():
    parser = ArgumentParser()
    parser.add_argument("--dataset_path", type=str, default="/Users/tetracycline/repos/datascience/datascience/projects/counsel_chat_all_data_300-tokens.json", help="Path or url of the dataset. If empty download from S3.")
    parser.add_argument("--dataset_cache", type=str, default='./dataset_cache', help="Path or url of the dataset cache")
    parser.add_argument("--model_checkpoint", type=str, default="openai-gpt", help="Path, url or short name of the model")
    parser.add_argument("--num_candidates", type=int, default=2, help="Number of candidates for training")
    parser.add_argument("--max_history", type=int, default=2, help="Number of previous exchanges to keep in history")
    parser.add_argument("--train_batch_size", type=int, default=4, help="Batch size for training")
    parser.add_argument("--valid_batch_size", type=int, default=4, help="Batch size for validation")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=8, help="Accumulate gradients on several steps")
    parser.add_argument("--lr", type=float, default=6.25e-5, help="Learning rate")
    parser.add_argument("--lm_coef", type=float, default=1.0, help="LM loss coefficient")
    parser.add_argument("--mc_coef", type=float, default=1.0, help="Multiple-choice loss coefficient")
    parser.add_argument("--max_norm", type=float, default=1.0, help="Clipping gradient norm")
    parser.add_argument("--n_epochs", type=int, default=3, help="Number of training epochs")
    parser.add_argument("--personality_permutations", type=int, default=1, help="Number of permutations of personality sentences")
    parser.add_argument("--eval_before_start", action='store_true', help="If true start with a first evaluation before training")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device (cuda or cpu)")
    parser.add_argument("--fp16", type=str, default="", help="Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
    parser.add_argument("--local_rank", type=int, default=-1, help="Local rank for distributed training (-1: not distributed)")
    args = parser.parse_args()

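    # logging is set to INFO for the main process (local_rank -1 or 0) and WARN for the other distributed processes.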
    logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning("Running process %d", args.local_rank)
    logger.info("Arguments: %s", pformat(args))

    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

    logger.info("Prepare tokenizer, pretrained model and optimizer.")
    tokenizer_class = GPT2Tokenizer if "gpt2" in args.model_checkpoint else OpenAIGPTTokenizer
    tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint)

    model_class = GPT2DoubleHeadsModel if "gpt2" in args.model_checkpoint else OpenAIGPTDoubleHeadsModel
    model = model_class.from_pretrained(args.model_checkpoint)
    model.to(args.device)

    add_special_tokens_(model, tokenizer)
    optimizer = AdamW(model.parameters(), lr=args.lr, correct_bias=True)

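    # Note: apex amp initialization happens before the DistributedDataParallel wrapper, as the apex docs recommend.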
    if args.fp16:
        from apex import amp
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(args, tokenizer)

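    # Training step: forward through both heads, combine the LM and multiple-choice losses, backprop (with optional
    # fp16 loss scaling), clip gradients, and step the optimizer every gradient_accumulation_steps iterations.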
    def update(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
        (lm_loss), (mc_loss), *_ = model(
            input_ids, token_type_ids=token_type_ids, mc_token_ids=mc_token_ids,
            mc_labels=mc_labels, lm_labels=lm_labels
        )
        loss = (lm_loss * args.lm_coef + mc_loss * args.mc_coef) / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()
    trainer = Engine(update)

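    # Evaluation step: return shifted LM logits/labels plus multiple-choice logits/labels, which feed the metrics attached below.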
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
            logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))

            lm_logits, mc_logits, *_ = model(
                input_ids, token_type_ids=token_type_ids, mc_token_ids=mc_token_ids,
            )
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(-1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted, mc_logits), (lm_labels_flat_shifted, mc_labels)
    evaluator = Engine(inference)

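    # Evaluate after every epoch, after training if n_epochs < 1, and optionally once before training starts.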
    trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED, lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(val_loader))

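    # DistributedSampler needs set_epoch() at the start of each epoch so the shuffling differs across epochs.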
    if args.distributed:
        trainer.add_event_handler(Events.EPOCH_STARTED, lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(Events.EPOCH_STARTED, lambda engine: valid_sampler.set_epoch(engine.state.epoch))

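    # Linearly decay the learning rate from args.lr to zero over the full n_epochs * len(train_loader) iterations.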
    scheduler = PiecewiseLinear(optimizer, "lr", [(0, args.lr), (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

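    # Metrics: LM negative log-likelihood on the gold replies, multiple-choice accuracy, their distributed averages,
    # and perplexity computed as exp(average_nll).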
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {"nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-100), output_transform=lambda x: (x[0][0], x[1][0])),
               "accuracy": Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))}
    metrics.update({"average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], args),
                    "average_accuracy": MetricsLambda(average_distributed_scalar, metrics["accuracy"], args)})
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

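    # On the main process only: progress bar, TensorBoard logging, checkpointing, and saving of the training args,
    # model config and tokenizer so the fine-tuned model can be reloaded later.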
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(Events.COMPLETED, lambda _: pbar.log_message("Validation: %s" % pformat(evaluator.state.metrics)))

        log_dir = make_logdir(args.model_checkpoint)
        tb_logger = TensorboardLogger(log_dir)

        tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", metric_names=["loss"]), event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator, log_handler=OutputHandler(tag="validation", metric_names=list(metrics.keys()), another_engine=trainer), event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(log_dir, 'checkpoint', save_interval=1, n_saved=3)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {'mymodel': getattr(model, 'module', model)})

        torch.save(args, log_dir + '/model_training_args.bin')
        getattr(model, 'module', model).config.to_json_file(os.path.join(log_dir, CONFIG_NAME))
        tokenizer.save_pretrained(log_dir)

    trainer.run(train_loader, max_epochs=args.n_epochs)

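    # On the main process: rename the last checkpoint to WEIGHTS_NAME so it can be reloaded with from_pretrained(), then close the TensorBoard logger.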
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(os.path.join(log_dir, checkpoint_handler._saved[-1][1]), os.path.join(log_dir, WEIGHTS_NAME))
        tb_logger.close()


if __name__ == "__main__":
    train()