fed-lora / examples /NLG /src /gpt2_ft.py
FZH1996
update fed-lora
e7d695a
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import argparse
import time
import math
import os, sys
import numpy as np
import itertools
import torch
import random
from torch.utils.data import DataLoader
torch.set_printoptions(threshold=100000)
from gpu import (
add_gpu_params,
parse_gpu,
distributed_opt,
distributed_gather,
distributed_sync,
cleanup
)
from optimizer import (
create_adam_optimizer,
create_optimizer_scheduler,
add_optimizer_params,
create_adam_optimizer_from_args
)
from data_utils import FT_Dataset
from model import GPT2Config, GPT2LMModel
from exp_utils import create_exp_dir
import loralib as lora
# Command-line interface for the fine-tuning script.
# add_gpu_params / add_optimizer_params (imported from gpu / optimizer) get the
# parser first; the script-specific options below are then registered in the
# order listed, so the --help output keeps its original ordering.
parser = argparse.ArgumentParser(description='PyTorch GPT2 ft script')
add_gpu_params(parser)
add_optimizer_params(parser)

# One (flag, add_argument keyword arguments) pair per option.
_SCRIPT_OPTIONS = (
    ('--train_data', dict(required=True, help='location of training data corpus')),
    ('--valid_data', dict(required=True, help='location of validation data corpus')),
    ('--train_batch_size', dict(type=int, default=8, help='training batch size')),
    ('--valid_batch_size', dict(type=int, default=4, help='validation batch size')),
    ('--grad_acc', dict(type=int, default=1, help='gradient accumulation steps')),
    ('--clip', dict(type=float, default=0.0, help='gradient clip')),
    ('--seq_len', dict(type=int, default=512, help='number of tokens to predict.')),
    ('--model_card', dict(default='gpt2.md', choices=['gpt2.sm', 'gpt2.md', 'gpt2.lg'],
                          help='model names')),
    ('--init_checkpoint', dict(default=None, help='pretrained checkpoint path')),
    ('--fp16', dict(action='store_true', help='train model with fp16')),
    ('--log_interval', dict(type=int, default=100, help='log interval')),
    ('--eval_interval', dict(type=int, default=2000, help='eval interval')),
    ('--save_interval', dict(type=int, default=500, help='save interval')),
    # The default working folder can be overridden through the PT_OUTPUT_DIR
    # environment variable.
    ('--work_dir', dict(type=str, default=os.getenv('PT_OUTPUT_DIR', 'gpt2_model'),
                        help='working folder.')),
    ('--lora_dim', dict(type=int, default=0, help='lora attn dimension')),
    ('--lora_alpha', dict(type=int, default=128, help='lora attn alpha')),
    ('--obj', dict(default='clm', choices=['jlm', 'clm'],
                   help='language model training objective')),
    ('--lora_dropout', dict(default=0.0, type=float,
                            help='dropout probability for lora layers')),
    ('--label_smooth', dict(default=0.0, type=float, help='label smoothing')),
    ('--roll_interval', dict(type=int, default=-1, help='rolling interval')),
    ('--roll_lr', dict(type=float, default=0.00001, help='rolling learning rate')),
    ('--roll_step', dict(type=int, default=100, help='rolling step')),
    ('--eval_epoch', dict(type=int, default=1, help='eval per number of epochs')),
)
for _flag, _spec in _SCRIPT_OPTIONS:
    parser.add_argument(_flag, **_spec)
# influence model, calculate the influence score between two samples.
def print_args(args):
    """Dump every attribute of ``args`` to stdout, framed by two rule lines.

    Nothing is printed unless ``args.rank`` equals 0, so only one process
    emits the listing.
    """
    if args.rank != 0:
        return

    rule = '=' * 100
    print(rule)
    for name, value in args.__dict__.items():
        print(f' - {name} : {value}')
    print(rule)
class AverageMeter(object):
    """Track the latest value and a running weighted average.

    Attributes (all reset to 0 by ``reset``):
        val:   the most recent value passed to ``update``.
        sum:   accumulated ``val * n`` over all updates.
        count: accumulated ``n`` over all updates.
        avg:   ``sum / count`` as of the last update.

    Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
def optimizer_step(_loss, _optimizer, _model, _schedule, args, is_update=True):
    """Back-propagate ``_loss`` and, optionally, apply one optimizer update.

    Args:
        _loss: loss value whose ``backward()`` is invoked.
        _optimizer: optimizer that is stepped and zeroed when ``is_update``.
        _model: model whose parameters are clipped in the non-fp16 path.
        _schedule: LR scheduler to advance, or ``None`` to skip.
        args: namespace providing ``fp16`` and ``clip``.
        is_update: when false, only the backward pass (and the scheduler
            step) run, so gradients keep accumulating.

    With ``args.fp16`` set, the name ``amp`` is looked up at module level; the
    script imports it from apex in its ``__main__`` section.
    """
    mixed_precision = args.fp16

    # Backward pass: through amp's loss scaling when running in fp16.
    if not mixed_precision:
        _loss.backward()
    else:
        with amp.scale_loss(_loss, _optimizer) as _scaled_loss:
            _scaled_loss.backward()

    if is_update:
        if args.clip > 0:
            # In fp16 the clipped tensors are amp's master parameters.
            if mixed_precision:
                clip_targets = amp.master_params(_optimizer)
            else:
                clip_targets = _model.parameters()
            torch.nn.utils.clip_grad_norm_(clip_targets, args.clip)

        _optimizer.step()
        _optimizer.zero_grad()

    # NOTE(review): the scheduler advances on every call, including
    # accumulation-only calls -- confirm this nesting is the intended one.
    if _schedule is not None:
        _schedule.step()
def evaluate(model, valid_loader, args):
    """Run the model over ``valid_loader`` and report loss and perplexity.

    Args:
        model: callable invoked as ``model(input, lm_labels=..., lm_mask=...)``
            and returning ``(logits, loss)``.
        valid_loader: iterable of batches indexable by ``'input'``,
            ``'target'`` and ``'mask'``; each entry is moved to ``args.device``.
        args: namespace providing ``device``.

    Returns:
        ``(avg_loss, math.exp(avg_loss))``. ``avg_loss`` is the plain mean of
        the per-batch mean losses (every batch counts once). With an empty
        loader the meter is never updated and the result is ``(0, 1.0)``.

    The model is switched to eval mode and left there; callers that continue
    training must call ``model.train()`` themselves.
    """
    model.eval()
    avg_lm_loss = AverageMeter()

    with torch.no_grad():
        for idx, data in enumerate(valid_loader):
            _input = data['input'].to(args.device)
            _target = data['target'].to(args.device)
            _msk = data['mask'].to(args.device)

            # Logits are returned by the model but not needed here.
            _lm_logits, _loss = model(_input, lm_labels=_target, lm_mask=_msk)
            loss = _loss.mean()

            avg_lm_loss.update(loss.item())

            # Progress line every 100 batches.
            if idx % 100 == 0:
                print('eval samples:', idx, 'loss:', loss.float())

    print('average loss', avg_lm_loss.avg)
    return avg_lm_loss.avg, math.exp(avg_lm_loss.avg)
def _first_layer_lora_a(model):
    """Return layer 0's ``c_attn.lora_A`` for debug logging, or None if absent."""
    # A wrapped model exposes the real network as ``.module``; an unwrapped
    # one is used directly.
    core = getattr(model, 'module', model)
    try:
        return core.transformer.h[0].attn.c_attn.lora_A
    except AttributeError:
        return None


def train_validate(
    model,
    optimizer,
    scheduler,
    train_loader,
    valid_loader,
    args,
    train_step=0,
    epoch=0
):
    """Train for one pass over ``train_loader`` with periodic log/save/eval.

    Args:
        model: network called as
            ``model(input, lm_labels=..., lm_mask=..., label_smooth=...)``.
        optimizer: optimizer forwarded to ``optimizer_step``.
        scheduler: LR scheduler forwarded to ``optimizer_step``.
        train_loader: iterable of training batches indexable by ``'input'``,
            ``'target'`` and ``'mask'``.
        valid_loader: loader handed to ``evaluate`` every
            ``args.eval_interval`` steps.
        args: namespace providing ``device``, ``label_smooth``, ``grad_acc``,
            ``log_interval``, ``save_interval``, ``eval_interval``,
            ``max_step``, ``rank`` and ``work_dir``.
        train_step: global step count on entry; incremented once per batch.
        epoch: epoch number, used only for logging.

    Returns:
        The global step count after this pass. The loop stops early when it
        reaches ``args.max_step``.

    Side effects: prints progress, writes checkpoints under ``args.work_dir``
    on rank 0, and calls ``distributed_sync`` after saving and evaluating.
    The best validation perplexity is tracked only within this call.
    """
    model.train()
    avg_lm_loss = AverageMeter()
    print('start to train the model................', epoch)
    log_start_time = time.time()
    best_val_ppl = None

    # train_loader.sampler.set_epoch(epoch)

    for idx, data in enumerate(train_loader):
        _input = data['input'].to(args.device)
        _target = data['target'].to(args.device)
        _msk = data['mask'].to(args.device)

        _lm_logits, _lm_loss = model(
            _input, lm_labels=_target, lm_mask=_msk, label_smooth=args.label_smooth
        )
        _lm_loss = _lm_loss.mean()

        train_step += 1
        # Parameters are only updated on every grad_acc-th step; the loss is
        # scaled down so accumulated gradients average over those steps.
        is_update = train_step % args.grad_acc == 0
        avg_lm_loss.update(_lm_loss.item())
        optimizer_step(
            _lm_loss / args.grad_acc, optimizer, model, scheduler, args, is_update=is_update
        )

        if train_step % args.log_interval == 0:
            print(f"_lm_loss = {_lm_loss}")
            # Debug peek at the first LoRA matrix; skipped when the model has
            # none instead of raising AttributeError.
            lora_a = _first_layer_lora_a(model)
            if lora_a is not None:
                lora_a_head = lora_a[0, :100]
                print(f"layer[0].lora_A = {lora_a_head}")

            elapsed = time.time() - log_start_time
            lr = optimizer.param_groups[0]['lr']
            log_str = (
                f'| epoch {epoch:3d} step {train_step:>8d} | {idx + 1:>6d} batches | '
                f'lr {lr:.3g} | ms/batch {elapsed * 1000 / args.log_interval:5.2f} | '
                f'loss {avg_lm_loss.val:5.2f} | avg loss {avg_lm_loss.avg:5.2f} | '
                f'ppl {math.exp(avg_lm_loss.avg):5.2f}'
            )
            if args.rank == 0:
                print(log_str)
            # Start a fresh logging window.
            log_start_time = time.time()
            avg_lm_loss.reset()

        if train_step % args.save_interval == 0:
            if args.rank == 0:
                model_path = os.path.join(args.work_dir, f'model.{train_step}.pt')
                print('saving checkpoint', model_path)
                # Periodic checkpoints hold only the LoRA state dict.
                torch.save({'model_state_dict': lora.lora_state_dict(model)}, model_path)
            distributed_sync(args)

        # evaluation interval
        if train_step % args.eval_interval == 0:
            eval_start_time = time.time()
            valid_loss, valid_ppl = evaluate(model, valid_loader, args)
            if best_val_ppl is None or valid_ppl < best_val_ppl:
                best_val_ppl = valid_ppl
            log_str = (
                f'| Eval {train_step // args.eval_interval:3d} at step {train_step:>8d} | '
                f'time: {time.time() - eval_start_time:5.2f}s | valid loss {valid_loss:5.2f} | '
                f'valid ppl {valid_ppl:5.2f} | best ppl {best_val_ppl:5.2f} '
            )
            if args.rank == 0:
                print('-' * 100)
                print(log_str)
                print('-' * 100)
            # evaluate() leaves the model in eval mode.
            model.train()
            distributed_sync(args)

        if train_step == args.max_step:
            break

    if args.rank == 0:
        model_path = os.path.join(args.work_dir, f'model.{train_step}.pt')
        print('saving checkpoint', model_path)
        # NOTE(review): this end-of-pass checkpoint stores the full
        # model.state_dict(), unlike the LoRA-only periodic ones, and reuses
        # the same file name when the last step was also a save step.
        torch.save({'model_state_dict': model.state_dict()}, model_path)
    distributed_sync(args)
    return train_step
if __name__ == '__main__':
    # Script entry point: parse options, build data loaders and the GPT-2
    # model, then run train_validate epoch by epoch until the step or epoch
    # budget is exhausted.
    args = parser.parse_args()
    parse_gpu(args)
    print_args(args)

    if args.fp16:
        try:
            # Bound at module level on purpose: optimizer_step reads `amp`
            # as a global.
            from apex import amp
        except Exception:
            # `warnings` is not imported at file level, so import it here;
            # previously this fallback itself failed with a NameError.
            import warnings
            warnings.warn('Could not import amp, apex may not be installed')

    torch.manual_seed(args.random_seed)
    random.seed(args.random_seed)

    if args.rank == 0:
        args.logging = create_exp_dir(args.work_dir)

    train_data = FT_Dataset(
        args.train_data, args.train_batch_size, args.seq_len,
        joint_lm=args.obj == 'jlm'
    )
    valid_data = FT_Dataset(
        args.valid_data, args.valid_batch_size, args.seq_len,
    )

    train_loader = DataLoader(
        train_data, batch_size=args.train_batch_size, num_workers=0,
        shuffle=False, pin_memory=False, drop_last=True,
        # sampler=torch.utils.data.distributed.DistributedSampler(train_data, seed=args.random_seed)
    )
    valid_loader = DataLoader(
        valid_data, batch_size=args.valid_batch_size, num_workers=0,
        shuffle=False, pin_memory=False, drop_last=False,
        # sampler=torch.utils.data.distributed.DistributedSampler(valid_data, seed=args.random_seed)
    )
    print(f"train_loader={len(train_loader)}, train_data={len(train_data)}")
    print(f"valid_loader={len(valid_loader)}, valid_data={len(valid_data)}")

    # Architecture sizes per model card; the LoRA settings are shared.
    # argparse `choices` guarantees args.model_card is one of these keys.
    model_shapes = {
        'gpt2.sm': dict(n_embd=768, n_layer=12, n_head=12),
        'gpt2.md': dict(n_embd=1024, n_layer=24, n_head=16),
        'gpt2.lg': dict(n_embd=1280, n_layer=36, n_head=20),
    }
    config = GPT2Config(
        lora_attn_dim=args.lora_dim,
        lora_attn_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        **model_shapes[args.model_card]
    )

    lm_net = GPT2LMModel(config)
    if args.init_checkpoint is not None:
        print('loading model pretrained weight.')
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        lm_net.load_weight(torch.load(args.init_checkpoint))

    lm_net = lm_net.cuda()

    if args.lora_dim > 0:
        lora.mark_only_lora_as_trainable(lm_net)

    # Debug output: model structure and first attention projection shapes.
    print(lm_net)
    first_c_attn = lm_net.transformer.h[0].attn.c_attn
    print(first_c_attn.weight.shape)
    # --lora_dim defaults to 0; print the LoRA shapes only when the layer
    # actually has them rather than failing on a missing attribute.
    for lora_name in ('lora_A', 'lora_B'):
        lora_param = getattr(first_c_attn, lora_name, None)
        if lora_param is not None:
            print(lora_param.shape)

    config_dict = vars(config)
    for param, value in config_dict.items():
        print(f"{param}: {value}")
    print(args)

    optimizer = create_adam_optimizer_from_args(lm_net, args)
    print("optimizer: " + str(optimizer))

    if args.max_step is None:
        # Ceiling division of the total batch count by the world size.
        args.max_step = (args.max_epoch * train_data.num_batches + args.world_size - 1) // args.world_size
        print('set max_step:', args.max_step)
        # NOTE(review): assumed to belong to the max_step branch -- confirm.
        print('train_data.num_batches:', train_data.num_batches)

    scheduler = create_optimizer_scheduler(optimizer, args)
    if args.fp16:
        lm_net, optimizer = amp.initialize(lm_net, optimizer, opt_level="O1")
    lm_net, optimizer = distributed_opt(args, lm_net, optimizer, grad_acc=args.grad_acc)

    try:
        train_step = 0
        for epoch in itertools.count(start=1):
            train_step = train_validate(
                lm_net, optimizer, scheduler, train_loader, valid_loader, args,
                train_step=train_step, epoch=epoch
            )

            if train_step >= args.max_step or (args.max_epoch is not None and epoch >= args.max_epoch):
                if args.rank == 0:
                    print('-' * 100)
                    print('End of training')
                break
    except KeyboardInterrupt:
        # Ctrl-C ends training early but still falls through to cleanup.
        if args.rank == 0:
            print('-' * 100)
            print('Exiting from training early')

    distributed_sync(args)
    print('cleanup dist ...')
    cleanup(args)