# Sreyan Ghosh
# initial commit
# bca6b51
""" Main training script """
import argparse
import functools
import glob
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
import random
import shutil
import sys
sys.path.append('../')
import yaml
import time
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import (
CPUOffload,
MixedPrecision,
ShardingStrategy,
BackwardPrefetch,
)
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
CheckpointWrapper,
CheckpointImpl,
apply_activation_checkpointing,
)
from torch.distributed.fsdp._init_utils import _init_intra_and_inter_node_groups
from torch.distributed.distributed_c10d import _get_default_group
torch.cuda.empty_cache()
from transformers import (
get_constant_schedule_with_warmup,
get_cosine_schedule_with_warmup,
get_linear_schedule_with_warmup,
)
from data.data import get_audiotext_dataloader # AudioTextData, DataCollator
from distributed import init_distributed_device, world_info_from_env
from train_utils import (
train_one_epoch,
get_mp_policy_dtype,
save_checkpoint,
Dict2Class,
get_autocast,
get_cast_dtype
)
from valid_utils import validation_losses
from src.factory import create_model_and_transforms
def random_seed(seed=42, rank=0):
    """Seed the torch, NumPy and stdlib random generators with ``seed + rank``.

    Args:
        seed: Base seed value.
        rank: Offset added to ``seed``; passing the process rank gives each
            rank its own random stream, while the default of 0 yields the
            same stream everywhere.
    """
    effective_seed = seed + rank
    # Same order as always: torch first, then NumPy, then the stdlib RNG.
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(effective_seed)
def main():
    """Run the training job described by a YAML config file.

    The config path is taken from the ``-c/--config`` command-line flag. The
    file must contain ``data_config``, ``model_config``, ``clap_config`` and
    ``train_config`` sections; an optional ``sft_config`` section names a
    pretrained checkpoint that is loaded when the run has no checkpoint of its
    own yet.

    Steps, in order:
      1. Parse the config, create ``<expdir>/<run_name>`` and copy the config
         into it.
      2. Initialise the distributed environment and seed the RNGs.
      3. Build the model and load weights: the newest ``checkpoint_*.pt`` of
         this run if one exists, otherwise the SFT pretrained checkpoint if
         configured, otherwise nothing.
      4. Wrap the model in FSDP or DDP and optionally apply activation
         checkpointing.
      5. Build the AdamW optimizer and LR scheduler, restoring their state
         when resuming.
      6. Train epoch by epoch, saving a checkpoint after every epoch and
         running validation on epochs divisible by 5.

    Returns:
        None. Results are written as checkpoints (via ``save_checkpoint``) and
        TensorBoard logs under the experiment directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, default='../config/config.yaml', help='yaml config path')
    parsed_args = parser.parse_args()

    # NOTE(review): FullLoader can construct more than plain data types; only
    # point this at trusted config files.
    with open(parsed_args.config) as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    data_config = config['data_config']
    model_config = config['model_config']
    clap_config = config["clap_config"]
    args = Dict2Class(config['train_config'])

    if 'sft_config' in config:
        sft_config = config['sft_config']
        unfreeze_full_lm = sft_config['unfreeze_full_lm']
    else:
        sft_config = None
        unfreeze_full_lm = False

    # Experiment directory: keep a copy of the config next to the outputs and
    # redirect the dataset blending output into the same directory.
    exp_path = os.path.join(args.expdir, args.run_name)
    os.makedirs(exp_path, exist_ok=True)
    print('exp_path:', exp_path)
    shutil.copy(parsed_args.config, os.path.join(exp_path, 'config.yaml'))
    data_config["dataset_blending_output"] = os.path.join(exp_path, data_config["dataset_blending_output"])

    # Validate args
    if args.fsdp and not args.fsdp_use_orig_params:
        print(
            "Warning: FSDP is running without fsdp_use_orig_params flag. "
            + "This is not recommended because it means we will use uniform weight decay"
            + " and train all embeddings, not just the newly added ones. "
            + "Note: OPT models are not compatible with fsdp_use_orig_params flag."
        )

    if args.fsdp and args.fsdp_sharding_strategy == "hybrid":
        print(
            "Warning: As of torch=2.0.1, the FSDP logic for optim_state_dict() is broken for hybrid sharding."
            + "To make this method work, we need to modify torch.distributed.fsdp._optim_utils.py"
            + "Copy and paste the code from the _optim_utils.py in this repo into the torch file."
            + "The main issue was the missing group kwarg on line 1596 in _all_gather_optim_state."
        )

    # Set up distributed training
    print('initializing distributed environment')
    if args.offline:
        os.environ["TRANSFORMERS_OFFLINE"] = "1"
    args.local_rank, args.rank, args.world_size = world_info_from_env()
    device_id = init_distributed_device(args)
    # No rank offset here, so every rank uses the same seed while the model
    # is being created.
    random_seed(args.seed)

    # Initialize model
    print('creating model')
    os.environ["TOKENIZERS_PARALLELISM"] = "false"  # disable the tokenizer parallelism warning
    model, tokenizer = create_model_and_transforms(
        **model_config,
        clap_config=clap_config,
        use_local_files=args.offline,
        gradient_checkpointing=args.gradient_checkpointing,
        freeze_lm_embeddings=args.freeze_lm_embeddings,
        unfreeze_full_lm=unfreeze_full_lm
    )
    # From here on each rank gets its own random stream (seed + rank).
    random_seed(args.seed, args.rank)

    # Initialize logging
    print(f"Start running training on rank {args.rank}.")

    # Find the newest checkpoint of this run; the epoch number is the integer
    # between the last "_" and the ".pt" extension of the file name.
    checkpoint_list = glob.glob(f"{args.expdir}/{args.run_name}/checkpoint_*.pt")
    if not checkpoint_list:
        print(f"Found no checkpoints for run {args.run_name}.")
        resume_from_checkpoint = None
    else:
        resume_from_checkpoint = sorted(
            checkpoint_list, key=lambda x: int(x.split("_")[-1].split(".")[0])
        )[-1]
        print(
            f"Found checkpoint {resume_from_checkpoint} for run {args.run_name}."
        )

    # Load model weights on CPU. When neither branch applies (no checkpoint
    # and no SFT config) the model keeps the weights it was created with.
    resume_from_epoch = 0
    if (resume_from_checkpoint is None) and (sft_config is not None):
        # just started SFT: start from the pretrained weights, epoch stays 0
        pretrained_path = os.path.join(
            sft_config['pretrained_path'],
            sft_config['pretrained_ckpt']
        )
        if args.rank == 0:
            print(f"Loading checkpoint from {pretrained_path}")
        checkpoint = torch.load(pretrained_path, map_location="cpu")
        msd = checkpoint["model_state_dict"]
        # strip the "module." prefix from parameter names
        msd = {k.replace("module.", ""): v for k, v in msd.items()}

        # for fsdp, only one rank needs to load the state dict
        if not args.fsdp or args.rank == 0:
            model.load_state_dict(msd, False)
        # drop the references to the loaded model weights
        del checkpoint["model_state_dict"]
        del msd

    elif resume_from_checkpoint is not None:
        # continue training (either pretraining or SFT)
        if args.rank == 0:
            print(f"Loading checkpoint from {resume_from_checkpoint}")
        checkpoint = torch.load(resume_from_checkpoint, map_location="cpu")
        msd = checkpoint["model_state_dict"]
        # strip the "module." prefix from parameter names
        msd = {k.replace("module.", ""): v for k, v in msd.items()}
        resume_from_epoch = checkpoint["epoch"] + 1

        # for fsdp, only one rank needs to load the state dict
        if not args.fsdp or args.rank == 0:
            model.load_state_dict(msd, False)
        # drop the references to the loaded model weights; the optimizer and
        # scheduler entries of `checkpoint` are consumed further down
        del checkpoint["model_state_dict"]
        del msd

    # Initialize FSDP / DDP, and ensure the model is on GPU
    print(f"Initializing distributed training with {args.world_size} GPUs.")
    if args.fsdp:
        print(
            f"Before FSDP parameter num: {sum(p.numel() for p in model.parameters())} on rank {args.rank}"
        )

        # init MixedPrecision: parameters stay fp32, only gradient reduction
        # and buffers use the reduced dtype
        if args.precision != "fp32":
            cast_dtype = get_mp_policy_dtype(args.precision)
            mp_policy = MixedPrecision(
                param_dtype=torch.float32,
                reduce_dtype=cast_dtype,  # gradient communication
                buffer_dtype=cast_dtype,
            )
        else:
            mp_policy = None

        # init process groups
        if args.fsdp_sharding_strategy == "hybrid":
            intra_node_group, inter_node_group = _init_intra_and_inter_node_groups(
                _get_default_group()
            )
            args.my_group = intra_node_group  # for optimizer saving
            process_group = (intra_node_group, inter_node_group)  # for FSDP init
        else:
            args.my_group = None  # for optimizer saving
            process_group = None  # for FSDP init

        # init FSDP; any strategy other than "full" maps to HYBRID_SHARD
        wrapper_kwargs = dict(
            process_group=process_group,
            cpu_offload=CPUOffload(offload_params=False),
            device_id=device_id,
            sync_module_states=True,  # broadcast loaded ckpt from rank 0 -> all ranks
            sharding_strategy=ShardingStrategy.FULL_SHARD
            if args.fsdp_sharding_strategy == "full"
            else ShardingStrategy.HYBRID_SHARD,
            use_orig_params=args.fsdp_use_orig_params,
            mixed_precision=mp_policy,
            forward_prefetch=True,
            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
            limit_all_gathers=True,
        )
        model.wrap_fsdp(wrapper_kwargs, device_id)
        ddp_model = model

        print(
            f"After FSDP parameter num: {sum(p.numel() for p in model.parameters())} on rank {args.rank}"
        )
        print(
            f"After FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB on rank {args.rank}"
        )

    else:
        model = model.to(device_id)
        ddp_model = DDP(model, device_ids=[device_id])

    # Initialize gradient checkpointing: wrap every module flagged with
    # `_use_gradient_checkpointing` that is not already an FSDP or checkpoint
    # wrapper.
    if args.gradient_checkpointing:
        non_reentrant_wrapper = functools.partial(
            checkpoint_wrapper,
            offload_to_cpu=True,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
        )
        apply_activation_checkpointing(
            ddp_model,
            checkpoint_wrapper_fn=non_reentrant_wrapper,
            check_fn=lambda m: getattr(m, "_use_gradient_checkpointing", False)
            and not isinstance(m, FSDP)
            and not isinstance(m, CheckpointWrapper),
        )

    # Initialize optimizer: only trainable parameters that are not explicitly
    # marked `exclude_from_optimizer`.
    params_to_optimize = ddp_model.named_parameters()
    params_to_optimize = list(
        filter(
            lambda x: x[1].requires_grad
            and not getattr(x[1], "exclude_from_optimizer", False),
            params_to_optimize,
        )
    )
    if not args.fsdp or args.fsdp_use_orig_params:
        # apply weight decay only to params in the xattn layers
        def get_grouped_params(named_params):
            """Split (name, param) pairs into weight-decay / no-decay groups."""
            params_with_wd, params_without_wd = [], []
            for n, p in named_params:
                if "gated_cross_attn" in n:
                    params_with_wd.append(p)
                else:
                    params_without_wd.append(p)
            return [
                {"params": params_with_wd, "weight_decay": args.weight_decay},
                {"params": params_without_wd, "weight_decay": 0.0},
            ]

        optimizer = torch.optim.AdamW(
            get_grouped_params(params_to_optimize), lr=args.learning_rate
        )
    else:
        # unclear if we should be using no weight decay or small weight decay for all parameters
        optimizer = torch.optim.AdamW(
            (p for _, p in params_to_optimize),
            lr=args.learning_rate,
            weight_decay=args.weight_decay,
        )

    # load optimizer checkpoint (only when resuming this run, not when
    # starting SFT from a pretrained checkpoint)
    if resume_from_checkpoint is not None:
        osd = checkpoint["optimizer_state_dict"]
        if args.fsdp:
            osd = FSDP.optim_state_dict_to_load(osd, ddp_model, optimizer)
        optimizer.load_state_dict(osd)
        del checkpoint["optimizer_state_dict"]
        del osd

    # Initialize data loaders
    AudioTextDataInfo = get_audiotext_dataloader(
        data_config, clap_config, tokenizer, args.batch_size, split='train',
        epoch=0, force_reblend=True
    )

    # Step count is derived from the epoch-0 dataset size.
    total_training_steps = (
        len(AudioTextDataInfo.dataset) // (args.batch_size * args.world_size)
    ) * args.num_epochs

    # Only rank 0 writes TensorBoard logs.
    if args.rank == 0:
        print(f"Total training steps: {total_training_steps}")
        tb = SummaryWriter(os.path.join(exp_path, 'tensorboard'))
    else:
        tb = None

    # Initialize lr scheduler; anything other than "linear"/"cosine" falls
    # back to a constant schedule with warmup.
    if args.lr_scheduler == "linear":
        lr_scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=total_training_steps,
        )
    elif args.lr_scheduler == "cosine":
        lr_scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=total_training_steps,
        )
    else:
        lr_scheduler = get_constant_schedule_with_warmup(
            optimizer, num_warmup_steps=args.warmup_steps
        )

    # load lr scheduler checkpoint
    if resume_from_checkpoint is not None:
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
        del checkpoint["lr_scheduler_state_dict"]

    # Start training!
    ddp_model.train()

    print('start training from epoch {}'.format(resume_from_epoch))
    # Remember the last epoch that actually ran: the loop body never executes
    # when resuming a run that already reached num_epochs, and the final save
    # below must not reference an unbound loop variable in that case.
    last_epoch = None
    for epoch in range(resume_from_epoch, args.num_epochs):
        last_epoch = epoch

        # force reblending dataset for every epoch
        if epoch > 0:
            AudioTextDataInfo = get_audiotext_dataloader(
                data_config, clap_config, tokenizer, args.batch_size, split='train',
                epoch=epoch, force_reblend=True
            )
        AudioTextDataInfo.set_epoch(epoch)
        trainloader = AudioTextDataInfo.dataloader

        # train one epoch
        train_one_epoch(
            args=args,
            model=ddp_model,
            epoch=epoch,
            tokenizer=tokenizer,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            trainloader=trainloader,
            device_id=device_id,
            tb=tb
        )

        # save checkpoint
        save_checkpoint(ddp_model, optimizer, lr_scheduler, epoch, args)
        time.sleep(1.0)

        # validation on epochs 0, 5, 10, ...
        if epoch % 5 == 0:
            torch.distributed.barrier()
            # Best effort: a validation failure is printed and training goes on.
            try:
                with torch.no_grad():
                    valid_losses = validation_losses(
                        model=ddp_model,
                        data_config=data_config,
                        clap_config=clap_config,
                        tokenizer=tokenizer,
                        batch_size=args.batch_size,
                        autocast=get_autocast(args.precision, cache_enabled=(not args.fsdp)),
                        cast_dtype=get_cast_dtype(args.precision),
                        device_id=device_id
                    )
                if args.rank == 0:
                    # x-axis: number of training batches after this epoch,
                    # assuming every epoch has len(trainloader) batches
                    for key in valid_losses:
                        tb.add_scalar("Valid/{}".format(key), valid_losses[key], (epoch+1)*len(trainloader))
            except Exception as error:
                print("An exception occurred:", error)
            torch.distributed.barrier()

    # save final checkpoint (skipped when no epoch was trained in this call)
    if last_epoch is not None:
        save_checkpoint(ddp_model, optimizer, lr_scheduler, last_epoch, args)
    else:
        print(
            f"No epochs left to train (resume epoch {resume_from_epoch}, "
            f"num_epochs {args.num_epochs}); skipping final checkpoint."
        )
    if args.rank == 0:
        tb.close()
# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()