| """ Main training script """ | |
| import argparse | |
| import functools | |
| import glob | |
| import os | |
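# Cap the CUDA caching allocator's block splitting to reduce memory fragmentation;
# this takes effect because it is set before torch touches the GPU.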
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
import random
import shutil
import sys
sys.path.append('../')
import yaml
import time

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import (
    CPUOffload,
    MixedPrecision,
    ShardingStrategy,
    BackwardPrefetch,
)
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointWrapper,
    CheckpointImpl,
    apply_activation_checkpointing,
)
from torch.distributed.fsdp._init_utils import _init_intra_and_inter_node_groups
from torch.distributed.distributed_c10d import _get_default_group
torch.cuda.empty_cache()

from transformers import (
    get_constant_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
    get_linear_schedule_with_warmup,
)

from data.data import get_audiotext_dataloader  # AudioTextData, DataCollator
from distributed import init_distributed_device, world_info_from_env
from train_utils import (
    train_one_epoch,
    get_mp_policy_dtype,
    save_checkpoint,
    Dict2Class,
    get_autocast,
    get_cast_dtype,
)
from valid_utils import validation_losses
from src.factory import create_model_and_transforms

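# Offsetting the seed by the process rank gives each worker a distinct but reproducible RNG stream.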
def random_seed(seed=42, rank=0):
    torch.manual_seed(seed + rank)
    np.random.seed(seed + rank)
    random.seed(seed + rank)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, default='../config/config.yaml', help='yaml config path')
    parsed_args = parser.parse_args()

    config = yaml.load(open(parsed_args.config), Loader=yaml.FullLoader)
    data_config = config['data_config']
    model_config = config['model_config']
    clap_config = config["clap_config"]
    args = Dict2Class(config['train_config'])

    if 'sft_config' in config:
        sft_config = config['sft_config']
        unfreeze_full_lm = sft_config['unfreeze_full_lm']
    else:
        sft_config = None
        unfreeze_full_lm = False
    # set up experiment paths
    exp_path = os.path.join(args.expdir, args.run_name)
    os.makedirs(exp_path, exist_ok=True)
    print('exp_path:', exp_path)
    shutil.copy(parsed_args.config, os.path.join(exp_path, 'config.yaml'))
    data_config["dataset_blending_output"] = os.path.join(exp_path, data_config["dataset_blending_output"])

    # Validate args
    if args.fsdp and not args.fsdp_use_orig_params:
        print(
            "Warning: FSDP is running without the fsdp_use_orig_params flag. "
            + "This is not recommended because it means we will use uniform weight decay"
            + " and train all embeddings, not just the newly added ones. "
            + "Note: OPT models are not compatible with the fsdp_use_orig_params flag."
        )

    if args.fsdp and args.fsdp_sharding_strategy == "hybrid":
        print(
            "Warning: As of torch==2.0.1, the FSDP logic for optim_state_dict() is broken for hybrid sharding. "
            + "To make this method work, we need to modify torch.distributed.fsdp._optim_utils.py. "
            + "Copy and paste the code from the _optim_utils.py in this repo into the torch file. "
            + "The main issue was the missing group kwarg on line 1596 in _all_gather_optim_state."
        )

    # Set up distributed training
    print('initializing distributed environment')
    if args.offline:
        os.environ["TRANSFORMERS_OFFLINE"] = "1"
    args.local_rank, args.rank, args.world_size = world_info_from_env()
    device_id = init_distributed_device(args)
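    # seed identically on all ranks so model initialization matches everywhere;
    # the per-rank offset is applied again after the model is built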
    random_seed(args.seed)

    # Initialize model
    print('creating model')
    os.environ["TOKENIZERS_PARALLELISM"] = "false"  # disable the tokenizer parallelism warning
    model, tokenizer = create_model_and_transforms(
        **model_config,
        clap_config=clap_config,
        use_local_files=args.offline,
        gradient_checkpointing=args.gradient_checkpointing,
        freeze_lm_embeddings=args.freeze_lm_embeddings,
        unfreeze_full_lm=unfreeze_full_lm,
    )

    random_seed(args.seed, args.rank)

    # Initialize logging
    print(f"Start running training on rank {args.rank}.")

    # Load model checkpoint on CPU
    checkpoint_list = glob.glob(f"{args.expdir}/{args.run_name}/checkpoint_*.pt")
    if len(checkpoint_list) == 0:
        print(f"Found no checkpoints for run {args.run_name}.")
        resume_from_checkpoint = None
    else:
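        # checkpoints are named checkpoint_<epoch>.pt; sort by the numeric suffix to resume from the latest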
        resume_from_checkpoint = sorted(
            checkpoint_list, key=lambda x: int(x.split("_")[-1].split(".")[0])
        )[-1]
        print(
            f"Found checkpoint {resume_from_checkpoint} for run {args.run_name}."
        )

    # load pretrained model
    resume_from_epoch = 0
    if (resume_from_checkpoint is None) and (sft_config is not None):
        # just started SFT
        pretrained_path = os.path.join(
            sft_config['pretrained_path'],
            sft_config['pretrained_ckpt']
        )
        if args.rank == 0:
            print(f"Loading checkpoint from {pretrained_path}")
        checkpoint = torch.load(pretrained_path, map_location="cpu")
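        # the saved state dict may carry a DDP "module." prefix; strip it so keys match the unwrapped model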
        msd = checkpoint["model_state_dict"]
        msd = {k.replace("module.", ""): v for k, v in msd.items()}

        # for fsdp, only one rank needs to load the state dict
        if not args.fsdp or args.rank == 0:
            model.load_state_dict(msd, strict=False)

        del checkpoint["model_state_dict"]
        del msd

    elif resume_from_checkpoint is not None:
        # continue training (either pretraining or SFT)
        if args.rank == 0:
            print(f"Loading checkpoint from {resume_from_checkpoint}")
        checkpoint = torch.load(resume_from_checkpoint, map_location="cpu")
        msd = checkpoint["model_state_dict"]
        msd = {k.replace("module.", ""): v for k, v in msd.items()}
        resume_from_epoch = checkpoint["epoch"] + 1

        # for fsdp, only one rank needs to load the state dict
        if not args.fsdp or args.rank == 0:
            model.load_state_dict(msd, strict=False)

        del checkpoint["model_state_dict"]
        del msd

    else:
        pass

    # Initialize FSDP / DDP, and ensure the model is on GPU
    print(f"Initializing distributed training with {args.world_size} GPUs.")
    if args.fsdp:
        print(
            f"Before FSDP parameter num: {sum(p.numel() for p in model.parameters())} on rank {args.rank}"
        )

        # init MixedPrecision
        if args.precision != "fp32":
            cast_dtype = get_mp_policy_dtype(args.precision)
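            # keep parameters in fp32; only gradient reduction and buffers are cast to the low-precision dtype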
            mp_policy = MixedPrecision(
                param_dtype=torch.float32,
                reduce_dtype=cast_dtype,  # gradient communication
                buffer_dtype=cast_dtype,
            )
        else:
            mp_policy = None

        # init process groups
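        # hybrid sharding shards parameters within each node and replicates them across nodes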
        if args.fsdp_sharding_strategy == "hybrid":
            intra_node_group, inter_node_group = _init_intra_and_inter_node_groups(
                _get_default_group()
            )
            args.my_group = intra_node_group  # for optimizer saving
            process_group = (intra_node_group, inter_node_group)  # for FSDP init
        else:
            args.my_group = None  # for optimizer saving
            process_group = None  # for FSDP init

        # init FSDP
        wrapper_kwargs = dict(
            process_group=process_group,
            cpu_offload=CPUOffload(offload_params=False),
            device_id=device_id,
            sync_module_states=True,  # broadcast loaded ckpt from rank 0 -> all ranks
            sharding_strategy=ShardingStrategy.FULL_SHARD
            if args.fsdp_sharding_strategy == "full"
            else ShardingStrategy.HYBRID_SHARD,
            use_orig_params=args.fsdp_use_orig_params,
            mixed_precision=mp_policy,
            forward_prefetch=True,
            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
            limit_all_gathers=True,
        )
        model.wrap_fsdp(wrapper_kwargs, device_id)
        ddp_model = model

        print(
            f"After FSDP parameter num: {sum(p.numel() for p in model.parameters())} on rank {args.rank}"
        )
        print(
            f"After FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB on rank {args.rank}"
        )

    else:
        model = model.to(device_id)
        ddp_model = DDP(model, device_ids=[device_id])

    # Initialize gradient checkpointing
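    # wrap only submodules flagged with _use_gradient_checkpointing, skipping anything
    # already wrapped by FSDP or by a CheckpointWrapper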
    if args.gradient_checkpointing:
        non_reentrant_wrapper = functools.partial(
            checkpoint_wrapper,
            offload_to_cpu=True,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
        )
        apply_activation_checkpointing(
            ddp_model,
            checkpoint_wrapper_fn=non_reentrant_wrapper,
            check_fn=lambda m: getattr(m, "_use_gradient_checkpointing", False)
            and not isinstance(m, FSDP)
            and not isinstance(m, CheckpointWrapper),
        )

    # Initialize optimizer
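    # optimize only parameters that require gradients and are not explicitly excluded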
    params_to_optimize = ddp_model.named_parameters()
    params_to_optimize = list(
        filter(
            lambda x: x[1].requires_grad
            and not getattr(x[1], "exclude_from_optimizer", False),
            params_to_optimize,
        )
    )

    if not args.fsdp or args.fsdp_use_orig_params:
        # apply weight decay only to params in the xattn layers
        def get_grouped_params(params_to_optimize):
            params_with_wd, params_without_wd = [], []
            for n, p in params_to_optimize:
                if "gated_cross_attn" in n:
                    params_with_wd.append(p)
                else:
                    params_without_wd.append(p)
            return [
                {"params": params_with_wd, "weight_decay": args.weight_decay},
                {"params": params_without_wd, "weight_decay": 0.0},
            ]

        optimizer = torch.optim.AdamW(
            get_grouped_params(params_to_optimize), lr=args.learning_rate
        )
    else:
        # unclear if we should be using no weight decay or small weight decay for all parameters
        optimizer = torch.optim.AdamW(
            (p for _, p in params_to_optimize),
            lr=args.learning_rate,
            weight_decay=args.weight_decay,
        )

    # load optimizer checkpoint
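    # under FSDP, the saved optimizer state must first be mapped back onto this rank's sharded parameters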
    if resume_from_checkpoint is not None:
        osd = checkpoint["optimizer_state_dict"]
        if args.fsdp:
            osd = FSDP.optim_state_dict_to_load(osd, ddp_model, optimizer)
        optimizer.load_state_dict(osd)
        del checkpoint["optimizer_state_dict"]
        del osd

    # Initialize data loaders
    AudioTextDataInfo = get_audiotext_dataloader(
        data_config, clap_config, tokenizer, args.batch_size, split='train',
        epoch=0, force_reblend=True
    )
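    # total steps = (dataset size // (per-GPU batch size * world size)) * num_epochs; the LR schedule uses this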
    total_training_steps = (
        len(AudioTextDataInfo.dataset) // (args.batch_size * args.world_size)
    ) * args.num_epochs

    if args.rank == 0:
        print(f"Total training steps: {total_training_steps}")
        tb = SummaryWriter(os.path.join(exp_path, 'tensorboard'))
    else:
        tb = None

    # Initialize lr scheduler
    if args.lr_scheduler == "linear":
        lr_scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=total_training_steps,
        )
    elif args.lr_scheduler == "cosine":
        lr_scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=total_training_steps,
        )
    else:
        lr_scheduler = get_constant_schedule_with_warmup(
            optimizer, num_warmup_steps=args.warmup_steps
        )

    # load lr scheduler checkpoint
    if resume_from_checkpoint is not None:
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
        del checkpoint["lr_scheduler_state_dict"]

    # Start training!
    ddp_model.train()

    print('start training from epoch {}'.format(resume_from_epoch))
    for epoch in range(resume_from_epoch, args.num_epochs):
        # force reblending dataset for every epoch
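        # (rebuilding the dataloader with force_reblend re-samples the dataset mixture each epoch)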
        if epoch > 0:
            AudioTextDataInfo = get_audiotext_dataloader(
                data_config, clap_config, tokenizer, args.batch_size, split='train',
                epoch=epoch, force_reblend=True
            )

        AudioTextDataInfo.set_epoch(epoch)
        trainloader = AudioTextDataInfo.dataloader

        # train one epoch
        train_one_epoch(
            args=args,
            model=ddp_model,
            epoch=epoch,
            tokenizer=tokenizer,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            trainloader=trainloader,
            device_id=device_id,
            tb=tb
        )

        # save checkpoint
        save_checkpoint(ddp_model, optimizer, lr_scheduler, epoch, args)
        time.sleep(1.0)

        # validation
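        # run validation every 5 epochs; the barriers keep all ranks in sync around the evaluation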
        if epoch % 5 == 0:
            torch.distributed.barrier()
            try:
                with torch.no_grad():
                    valid_losses = validation_losses(
                        model=ddp_model,
                        data_config=data_config,
                        clap_config=clap_config,
                        tokenizer=tokenizer,
                        batch_size=args.batch_size,
                        autocast=get_autocast(args.precision, cache_enabled=(not args.fsdp)),
                        cast_dtype=get_cast_dtype(args.precision),
                        device_id=device_id
                    )
                if args.rank == 0:
                    for key in valid_losses:
                        tb.add_scalar("Valid/{}".format(key), valid_losses[key], (epoch + 1) * len(trainloader))
            except Exception as error:
                print("An exception occurred:", error)
            torch.distributed.barrier()

    # save final checkpoint
    save_checkpoint(ddp_model, optimizer, lr_scheduler, epoch, args)

    if args.rank == 0:
        tb.close()


if __name__ == "__main__":
    main()