""" Main training script """ | |
import argparse | |
import functools | |
import glob | |
import os | |
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512" | |
import random | |
import shutil | |
import sys | |
sys.path.append('../') | |
import yaml | |
import time | |
import numpy as np | |
import torch | |
from torch.utils.tensorboard import SummaryWriter | |
from torch.nn.parallel import DistributedDataParallel as DDP | |
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP | |
from torch.distributed.fsdp import ( | |
CPUOffload, | |
MixedPrecision, | |
ShardingStrategy, | |
BackwardPrefetch, | |
) | |
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( | |
checkpoint_wrapper, | |
CheckpointWrapper, | |
CheckpointImpl, | |
apply_activation_checkpointing, | |
) | |
from torch.distributed.fsdp._init_utils import _init_intra_and_inter_node_groups | |
from torch.distributed.distributed_c10d import _get_default_group | |
torch.cuda.empty_cache() | |
from transformers import (
    get_constant_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
    get_linear_schedule_with_warmup,
)

from data.data import get_audiotext_dataloader  # AudioTextData, DataCollator
from distributed import init_distributed_device, world_info_from_env
from train_utils import (
    train_one_epoch,
    get_mp_policy_dtype,
    save_checkpoint,
    Dict2Class,
    get_autocast,
    get_cast_dtype,
)
from valid_utils import validation_losses
from src.factory import create_model_and_transforms


def random_seed(seed=42, rank=0):
    torch.manual_seed(seed + rank)
    np.random.seed(seed + rank)
    random.seed(seed + rank)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, default='../config/config.yaml', help='yaml config path')
    parsed_args = parser.parse_args()

    config = yaml.load(open(parsed_args.config), Loader=yaml.FullLoader)
    data_config = config['data_config']
    model_config = config['model_config']
    clap_config = config["clap_config"]
    args = Dict2Class(config['train_config'])

    if 'sft_config' in config:
        sft_config = config['sft_config']
        unfreeze_full_lm = sft_config['unfreeze_full_lm']
    else:
        sft_config = None
        unfreeze_full_lm = False
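
    # Rough sketch of the YAML layout this script assumes, inferred from the keys
    # accessed in this file (real configs may carry additional fields):
    #   train_config:  expdir, run_name, seed, offline, precision, batch_size,
    #                  num_epochs, learning_rate, weight_decay, lr_scheduler,
    #                  warmup_steps, gradient_checkpointing, freeze_lm_embeddings,
    #                  fsdp, fsdp_use_orig_params, fsdp_sharding_strategy
    #   data_config:   dataset_blending_output plus whatever get_audiotext_dataloader expects
    #   model_config:  kwargs forwarded to create_model_and_transforms
    #   clap_config:   forwarded to both the model factory and the dataloader
    #   sft_config:    optional; unfreeze_full_lm, pretrained_path, pretrained_ckpt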

    # set up the experiment directory and keep a copy of the config there
    exp_path = os.path.join(args.expdir, args.run_name)
    os.makedirs(exp_path, exist_ok=True)
    print('exp_path:', exp_path)
    shutil.copy(parsed_args.config, os.path.join(exp_path, 'config.yaml'))
    data_config["dataset_blending_output"] = os.path.join(exp_path, data_config["dataset_blending_output"])

    # Validate args
    if args.fsdp and not args.fsdp_use_orig_params:
        print(
            "Warning: FSDP is running without the fsdp_use_orig_params flag. "
            "This is not recommended because it means we will use uniform weight decay "
            "and train all embeddings, not just the newly added ones. "
            "Note: OPT models are not compatible with the fsdp_use_orig_params flag."
        )

    if args.fsdp and args.fsdp_sharding_strategy == "hybrid":
        print(
            "Warning: As of torch==2.0.1, the FSDP logic for optim_state_dict() is broken for hybrid sharding. "
            "To make this method work, we need to modify torch.distributed.fsdp._optim_utils.py. "
            "Copy and paste the code from the _optim_utils.py in this repo into the torch file. "
            "The main issue was the missing group kwarg on line 1596 in _all_gather_optim_state."
        )

    # Set up distributed training
    print('initializing distributed environment')
    if args.offline:
        os.environ["TRANSFORMERS_OFFLINE"] = "1"
    args.local_rank, args.rank, args.world_size = world_info_from_env()
    device_id = init_distributed_device(args)
    random_seed(args.seed)

    # Initialize model
    print('creating model')
    os.environ["TOKENIZERS_PARALLELISM"] = "false"  # disable the tokenizer parallelism warning
    model, tokenizer = create_model_and_transforms(
        **model_config,
        clap_config=clap_config,
        use_local_files=args.offline,
        gradient_checkpointing=args.gradient_checkpointing,
        freeze_lm_embeddings=args.freeze_lm_embeddings,
        unfreeze_full_lm=unfreeze_full_lm,
    )
    random_seed(args.seed, args.rank)

    # Initialize logging
    print(f"Start running training on rank {args.rank}.")

    # Load model checkpoint on CPU
    checkpoint_list = glob.glob(f"{args.expdir}/{args.run_name}/checkpoint_*.pt")
    if len(checkpoint_list) == 0:
        print(f"Found no checkpoints for run {args.run_name}.")
        resume_from_checkpoint = None
    else:
        resume_from_checkpoint = sorted(
            checkpoint_list, key=lambda x: int(x.split("_")[-1].split(".")[0])
        )[-1]
        print(f"Found checkpoint {resume_from_checkpoint} for run {args.run_name}.")
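
    # Checkpoint files are assumed to be named checkpoint_<epoch>.pt (presumably the
    # convention used by save_checkpoint), so sorting by the integer after the last
    # underscore picks the most recent epoch to resume from.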

    # load pretrained model
    resume_from_epoch = 0
    if (resume_from_checkpoint is None) and (sft_config is not None):
        # just started SFT: initialize from the pretrained checkpoint
        pretrained_path = os.path.join(
            sft_config['pretrained_path'],
            sft_config['pretrained_ckpt']
        )
        if args.rank == 0:
            print(f"Loading checkpoint from {pretrained_path}")
        checkpoint = torch.load(pretrained_path, map_location="cpu")
        msd = checkpoint["model_state_dict"]
        msd = {k.replace("module.", ""): v for k, v in msd.items()}

        # for fsdp, only one rank needs to load the state dict
        if not args.fsdp or args.rank == 0:
            model.load_state_dict(msd, False)
        del checkpoint["model_state_dict"]
        del msd

    elif resume_from_checkpoint is not None:
        # continue training (either pretraining or SFT)
        if args.rank == 0:
            print(f"Loading checkpoint from {resume_from_checkpoint}")
        checkpoint = torch.load(resume_from_checkpoint, map_location="cpu")
        msd = checkpoint["model_state_dict"]
        msd = {k.replace("module.", ""): v for k, v in msd.items()}
        resume_from_epoch = checkpoint["epoch"] + 1

        # for fsdp, only one rank needs to load the state dict
        if not args.fsdp or args.rank == 0:
            model.load_state_dict(msd, False)
        del checkpoint["model_state_dict"]
        del msd

    else:
        pass

    # Initialize FSDP / DDP, and ensure the model is on GPU
    print(f"Initializing distributed training with {args.world_size} GPUs.")
    if args.fsdp:
        print(
            f"Before FSDP parameter num: {sum(p.numel() for p in model.parameters())} on rank {args.rank}"
        )

        # init MixedPrecision
        if args.precision != "fp32":
            cast_dtype = get_mp_policy_dtype(args.precision)
            mp_policy = MixedPrecision(
                param_dtype=torch.float32,
                reduce_dtype=cast_dtype,  # gradient communication
                buffer_dtype=cast_dtype,
            )
        else:
            mp_policy = None
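
        # Note on the policy above: parameters are kept (and sharded) in fp32
        # (param_dtype=torch.float32); only gradient reduction and buffers are cast
        # to the dtype returned by get_mp_policy_dtype for args.precision. The
        # forward-pass compute dtype is assumed to be handled by autocast inside
        # train_one_epoch, as it is for the validation pass below.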

        # init process groups
        if args.fsdp_sharding_strategy == "hybrid":
            intra_node_group, inter_node_group = _init_intra_and_inter_node_groups(
                _get_default_group()
            )
            args.my_group = intra_node_group  # for optimizer saving
            process_group = (intra_node_group, inter_node_group)  # for FSDP init
        else:
            args.my_group = None  # for optimizer saving
            process_group = None  # for FSDP init

        # init FSDP
        wrapper_kwargs = dict(
            process_group=process_group,
            cpu_offload=CPUOffload(offload_params=False),
            device_id=device_id,
            sync_module_states=True,  # broadcast loaded ckpt from rank 0 -> all ranks
            sharding_strategy=ShardingStrategy.FULL_SHARD
            if args.fsdp_sharding_strategy == "full"
            else ShardingStrategy.HYBRID_SHARD,
            use_orig_params=args.fsdp_use_orig_params,
            mixed_precision=mp_policy,
            forward_prefetch=True,
            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
            limit_all_gathers=True,
        )
        model.wrap_fsdp(wrapper_kwargs, device_id)
        ddp_model = model

        print(
            f"After FSDP parameter num: {sum(p.numel() for p in model.parameters())} on rank {args.rank}"
        )
        print(
            f"After FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB on rank {args.rank}"
        )

    else:
        model = model.to(device_id)
        ddp_model = DDP(model, device_ids=[device_id])

    # Initialize gradient checkpointing
    if args.gradient_checkpointing:
        non_reentrant_wrapper = functools.partial(
            checkpoint_wrapper,
            offload_to_cpu=True,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
        )
        apply_activation_checkpointing(
            ddp_model,
            checkpoint_wrapper_fn=non_reentrant_wrapper,
            check_fn=lambda m: getattr(m, "_use_gradient_checkpointing", False)
            and not isinstance(m, FSDP)
            and not isinstance(m, CheckpointWrapper),
        )
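
    # The check_fn above restricts activation checkpointing to submodules tagged with
    # _use_gradient_checkpointing (presumably set inside create_model_and_transforms)
    # and avoids re-wrapping FSDP or already-checkpointed modules; offload_to_cpu=True
    # additionally moves the saved activations to host memory.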

    # Initialize optimizer
    params_to_optimize = ddp_model.named_parameters()
    params_to_optimize = list(
        filter(
            lambda x: x[1].requires_grad
            and not getattr(x[1], "exclude_from_optimizer", False),
            params_to_optimize,
        )
    )

    if not args.fsdp or args.fsdp_use_orig_params:
        # apply weight decay only to params in the xattn layers
        def get_grouped_params(named_params):
            params_with_wd, params_without_wd = [], []
            for n, p in named_params:
                if "gated_cross_attn" in n:
                    params_with_wd.append(p)
                else:
                    params_without_wd.append(p)
            return [
                {"params": params_with_wd, "weight_decay": args.weight_decay},
                {"params": params_without_wd, "weight_decay": 0.0},
            ]

        optimizer = torch.optim.AdamW(
            get_grouped_params(params_to_optimize), lr=args.learning_rate
        )
    else:
        # unclear if we should be using no weight decay or small weight decay for all parameters
        optimizer = torch.optim.AdamW(
            (p for _, p in params_to_optimize),
            lr=args.learning_rate,
            weight_decay=args.weight_decay,
        )
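
    # Rationale: the named-parameter grouping above needs the original parameter
    # names, which are only visible under DDP or under FSDP with use_orig_params;
    # with flattened FSDP parameters we fall back to a single uniform weight decay,
    # matching the warning printed near the top of main().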

    # load optimizer checkpoint
    if resume_from_checkpoint is not None:
        osd = checkpoint["optimizer_state_dict"]
        if args.fsdp:
            osd = FSDP.optim_state_dict_to_load(osd, ddp_model, optimizer)
        optimizer.load_state_dict(osd)
        del checkpoint["optimizer_state_dict"]
        del osd

    # Initialize data loaders
    AudioTextDataInfo = get_audiotext_dataloader(
        data_config, clap_config, tokenizer, args.batch_size, split='train',
        epoch=0, force_reblend=True
    )
    total_training_steps = (
        len(AudioTextDataInfo.dataset) // (args.batch_size * args.world_size)
    ) * args.num_epochs
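
    # This is a rough estimate: each rank is assumed to see about
    # len(dataset) / (batch_size * world_size) batches per epoch, and the number
    # is only used to size the warmup/decay schedules below.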

    if args.rank == 0:
        print(f"Total training steps: {total_training_steps}")
        tb = SummaryWriter(os.path.join(exp_path, 'tensorboard'))
    else:
        tb = None

    # Initialize lr scheduler
    if args.lr_scheduler == "linear":
        lr_scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=total_training_steps,
        )
    elif args.lr_scheduler == "cosine":
        lr_scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=total_training_steps,
        )
    else:
        lr_scheduler = get_constant_schedule_with_warmup(
            optimizer, num_warmup_steps=args.warmup_steps
        )

    # load lr scheduler checkpoint
    if resume_from_checkpoint is not None:
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
        del checkpoint["lr_scheduler_state_dict"]

    # Start training!
    ddp_model.train()
    print('start training from epoch {}'.format(resume_from_epoch))

    for epoch in range(resume_from_epoch, args.num_epochs):
        # force reblending dataset for every epoch
        if epoch > 0:
            AudioTextDataInfo = get_audiotext_dataloader(
                data_config, clap_config, tokenizer, args.batch_size, split='train',
                epoch=epoch, force_reblend=True
            )
        AudioTextDataInfo.set_epoch(epoch)
        trainloader = AudioTextDataInfo.dataloader

        # train one epoch
        train_one_epoch(
            args=args,
            model=ddp_model,
            epoch=epoch,
            tokenizer=tokenizer,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            trainloader=trainloader,
            device_id=device_id,
            tb=tb,
        )

        # save checkpoint
        save_checkpoint(ddp_model, optimizer, lr_scheduler, epoch, args)
        time.sleep(1.0)

        # validation
        if epoch % 5 == 0:
            torch.distributed.barrier()
            try:
                with torch.no_grad():
                    valid_losses = validation_losses(
                        model=ddp_model,
                        data_config=data_config,
                        clap_config=clap_config,
                        tokenizer=tokenizer,
                        batch_size=args.batch_size,
                        autocast=get_autocast(args.precision, cache_enabled=(not args.fsdp)),
                        cast_dtype=get_cast_dtype(args.precision),
                        device_id=device_id,
                    )
                if args.rank == 0:
                    for key in valid_losses:
                        tb.add_scalar("Valid/{}".format(key), valid_losses[key], (epoch + 1) * len(trainloader))
            except Exception as error:
                print("An exception occurred:", error)
            torch.distributed.barrier()

    # save final checkpoint
    save_checkpoint(ddp_model, optimizer, lr_scheduler, epoch, args)
    if args.rank == 0:
        tb.close()


if __name__ == "__main__":
    main()