|
import argparse |
|
import glob |
|
import logging |
|
import os |
|
import time |
|
from argparse import Namespace |
|
|
|
import numpy as np |
|
import torch |
|
from lightning_base import BaseTransformer, add_generic_args, generic_train |
|
from torch.utils.data import DataLoader, TensorDataset |
|
|
|
from transformers import glue_compute_metrics as compute_metrics |
|
from transformers import glue_convert_examples_to_features as convert_examples_to_features |
|
from transformers import glue_output_modes, glue_tasks_num_labels |
|
from transformers import glue_processors as processors |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class GLUETransformer(BaseTransformer): |
|
mode = "sequence-classification" |
|
|
|
def __init__(self, hparams): |
|
if type(hparams) == dict: |
|
hparams = Namespace(**hparams) |
|
hparams.glue_output_mode = glue_output_modes[hparams.task] |
|
num_labels = glue_tasks_num_labels[hparams.task] |
|
|
|
super().__init__(hparams, num_labels, self.mode) |
|
|
|
def forward(self, **inputs): |
|
return self.model(**inputs) |
|
|
|
def training_step(self, batch, batch_idx): |
|
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]} |
|
|
|
if self.config.model_type not in ["distilbert", "bart"]: |
|
inputs["token_type_ids"] = batch[2] if self.config.model_type in ["bert", "xlnet", "albert"] else None |
|
|
|
outputs = self(**inputs) |
|
loss = outputs[0] |
|
|
|
lr_scheduler = self.trainer.lr_schedulers[0]["scheduler"] |
|
tensorboard_logs = {"loss": loss, "rate": lr_scheduler.get_last_lr()[-1]} |
|
return {"loss": loss, "log": tensorboard_logs} |
|
|
|
def prepare_data(self): |
|
"Called to initialize data. Use the call to construct features" |
|
args = self.hparams |
|
processor = processors[args.task]() |
|
self.labels = processor.get_labels() |
|
|
|
for mode in ["train", "dev"]: |
|
cached_features_file = self._feature_file(mode) |
|
if os.path.exists(cached_features_file) and not args.overwrite_cache: |
|
logger.info("Loading features from cached file %s", cached_features_file) |
|
else: |
|
logger.info("Creating features from dataset file at %s", args.data_dir) |
|
examples = ( |
|
processor.get_dev_examples(args.data_dir) |
|
if mode == "dev" |
|
else processor.get_train_examples(args.data_dir) |
|
) |
|
features = convert_examples_to_features( |
|
examples, |
|
self.tokenizer, |
|
max_length=args.max_seq_length, |
|
label_list=self.labels, |
|
output_mode=args.glue_output_mode, |
|
) |
|
logger.info("Saving features into cached file %s", cached_features_file) |
|
torch.save(features, cached_features_file) |
|
|
|
def get_dataloader(self, mode: str, batch_size: int, shuffle: bool = False) -> DataLoader: |
|
"Load datasets. Called after prepare data." |
|
|
|
|
|
mode = "dev" if mode == "test" else mode |
|
|
|
cached_features_file = self._feature_file(mode) |
|
logger.info("Loading features from cached file %s", cached_features_file) |
|
features = torch.load(cached_features_file) |
|
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) |
|
all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long) |
|
all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long) |
|
if self.hparams.glue_output_mode == "classification": |
|
all_labels = torch.tensor([f.label for f in features], dtype=torch.long) |
|
elif self.hparams.glue_output_mode == "regression": |
|
all_labels = torch.tensor([f.label for f in features], dtype=torch.float) |
|
|
|
return DataLoader( |
|
TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels), |
|
batch_size=batch_size, |
|
shuffle=shuffle, |
|
) |
|
|
|
def validation_step(self, batch, batch_idx): |
|
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]} |
|
|
|
if self.config.model_type not in ["distilbert", "bart"]: |
|
inputs["token_type_ids"] = batch[2] if self.config.model_type in ["bert", "xlnet", "albert"] else None |
|
|
|
outputs = self(**inputs) |
|
tmp_eval_loss, logits = outputs[:2] |
|
preds = logits.detach().cpu().numpy() |
|
out_label_ids = inputs["labels"].detach().cpu().numpy() |
|
|
|
return {"val_loss": tmp_eval_loss.detach().cpu(), "pred": preds, "target": out_label_ids} |
|
|
|
def _eval_end(self, outputs) -> tuple: |
|
val_loss_mean = torch.stack([x["val_loss"] for x in outputs]).mean().detach().cpu().item() |
|
preds = np.concatenate([x["pred"] for x in outputs], axis=0) |
|
|
|
if self.hparams.glue_output_mode == "classification": |
|
preds = np.argmax(preds, axis=1) |
|
elif self.hparams.glue_output_mode == "regression": |
|
preds = np.squeeze(preds) |
|
|
|
out_label_ids = np.concatenate([x["target"] for x in outputs], axis=0) |
|
out_label_list = [[] for _ in range(out_label_ids.shape[0])] |
|
preds_list = [[] for _ in range(out_label_ids.shape[0])] |
|
|
|
results = {**{"val_loss": val_loss_mean}, **compute_metrics(self.hparams.task, preds, out_label_ids)} |
|
|
|
ret = dict(results.items()) |
|
ret["log"] = results |
|
return ret, preds_list, out_label_list |
|
|
|
def validation_epoch_end(self, outputs: list) -> dict: |
|
ret, preds, targets = self._eval_end(outputs) |
|
logs = ret["log"] |
|
return {"val_loss": logs["val_loss"], "log": logs, "progress_bar": logs} |
|
|
|
def test_epoch_end(self, outputs) -> dict: |
|
ret, predictions, targets = self._eval_end(outputs) |
|
logs = ret["log"] |
|
|
|
return {"avg_test_loss": logs["val_loss"], "log": logs, "progress_bar": logs} |
|
|
|
@staticmethod |
|
def add_model_specific_args(parser, root_dir): |
|
BaseTransformer.add_model_specific_args(parser, root_dir) |
|
parser.add_argument( |
|
"--max_seq_length", |
|
default=128, |
|
type=int, |
|
help=( |
|
"The maximum total input sequence length after tokenization. Sequences longer " |
|
"than this will be truncated, sequences shorter will be padded." |
|
), |
|
) |
|
|
|
parser.add_argument( |
|
"--task", |
|
default="", |
|
type=str, |
|
required=True, |
|
help="The GLUE task to run", |
|
) |
|
parser.add_argument( |
|
"--gpus", |
|
default=0, |
|
type=int, |
|
help="The number of GPUs allocated for this, it is by default 0 meaning none", |
|
) |
|
|
|
parser.add_argument( |
|
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets" |
|
) |
|
|
|
return parser |
|
|
|
|
|
def main(): |
|
parser = argparse.ArgumentParser() |
|
add_generic_args(parser, os.getcwd()) |
|
parser = GLUETransformer.add_model_specific_args(parser, os.getcwd()) |
|
args = parser.parse_args() |
|
|
|
|
|
if args.output_dir is None: |
|
args.output_dir = os.path.join( |
|
"./results", |
|
f"{args.task}_{time.strftime('%Y%m%d_%H%M%S')}", |
|
) |
|
os.makedirs(args.output_dir) |
|
|
|
model = GLUETransformer(args) |
|
trainer = generic_train(model, args) |
|
|
|
|
|
if args.do_predict: |
|
checkpoints = sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True)) |
|
model = model.load_from_checkpoint(checkpoints[-1]) |
|
return trainer.test(model) |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|