import os # os.environ['CUDA_LAUNCH_BLOCKING']='1' import numpy as np import torch.distributed as dist import torch.nn as nn from torch.nn.parallel import DistributedDataParallel as DDP import argparse from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter from datetime import datetime # import cv2 import torch from torch.nn.utils import clip_grad_norm_ import yaml from dist_train_utils import print, get_world_size, get_rank, get_local_rank, barrier, reduce_sum, reduce_mean from tqdm import tqdm from lr_utils import CosineAnnealingWarmupRestarts import ssl ssl._create_default_https_context = ssl._create_unverified_context import webdataset as wds from dataset.get_vqgan_wds import get_dataset, handle_exception # from dataset.vq_wds import get_dataset, my_sample_decoder, my_sample_prec import json import time from accelerate import Accelerator from accelerate.utils import set_seed from data_generation.vqgan.load import load_model import random parser = argparse.ArgumentParser() #ddp parser.add_argument('--local_rank', default=-1, type=int, help='node rank for distributed training') parser.add_argument("--batch_size", type=int, default=16) parser.add_argument("--epoch", type=int, default=20) parser.add_argument("--base_lr", type=float, default=4.5e-6) parser.add_argument("--log_folder", type=str, default='f16') args = parser.parse_args() set_seed(42) accelerator = Accelerator() # class MyLoader(wds.WebLoader, torch.utils.data.DataLoader): # pass # train_dataloader = MyLoader(train_dataset, ...) 
# ---------------------------------------------------------------------------
# Logging directory setup (rank 0 only) and VQGAN checkpoint loading.
# ---------------------------------------------------------------------------
log_folder = args.log_folder
log_path = os.path.join('/mnt/bn/roboicl-jirong/codebase/DeLVM/logs/', log_folder)
if get_rank() == 0:
    # exist_ok=True removes the check-then-create race and lets the job
    # restart cleanly when the directory already exists.
    os.makedirs(log_path, exist_ok=True)
    # NOTE(review): blanket 777 permissions — presumably so other cluster
    # jobs/users can read the logs; confirm before tightening.
    os.system('chmod -R 777 ' + log_path)
accelerator.wait_for_everyone()  # all ranks wait until rank 0 created the dir

# One TensorBoard run per launch, keyed by start time; only rank 0 writes.
timestamp = "{0:%Y-%m-%dT%H-%M-%S/}".format(datetime.now())
if get_rank() == 0:
    WRITER = SummaryWriter(log_path+'/'+timestamp, max_queue=1000)

print('world_size: {}, rank: {}, local_rank: {}'.format(get_world_size(), get_rank(), get_local_rank()), all=True)

# Accelerate owns process-group init and device selection; the manual
# dist.init_process_group / torch.cuda.set_device calls were retired.
device = accelerator.device
print('device:', device, all=True)

# Resume the VQGAN from the latest experiment checkpoint.
# assumes load_model(path, n) builds the model with an n-entry codebook and
# loads the state dict from `path` — TODO confirm against load_model's source.
vqmodel = load_model('/mnt/bn/roboicl-jirong/codebase/DeLVM/logs/f16_192_real_calvin_robot_datacomp_1e-5_disc_start_0_weight_0.2_acc_1/checkpoint_vq_epoch_9999.tar', 8192)
# Synchronize BatchNorm statistics across ranks before accelerate wraps the
# model, then move it to this process's device.
vqmodel = nn.SyncBatchNorm.convert_sync_batchnorm(vqmodel)
vqmodel = vqmodel.to(device)

# ---------------------------------------------------------------------------
# Data: a weighted random mix of webdataset tar shards.
# NOTE(review): the weights sum to 1.10, not 1.00 — wds.RandomMix treats them
# as relative probabilities, but confirm the intended proportions.
# ---------------------------------------------------------------------------
ds_list = []
ds3 = get_dataset('/mnt/bn/roboicllq-data1/calvin_img/calvin_img_{00000000..00000110}.tar', seed=42)
ds_list += [(ds3, 0.1)]
ds6 = get_dataset('/mnt/bn/roboicllq-data1/calvin_img/hands/calvin_hands_img_{00000000..00000110}.tar', seed=42)
ds_list += [(ds6, 0.02)]
with open('/mnt/bn/roboicl-jirong/codebase/RoboICL/robot_img.json', 'r') as f:
    tars = json.load(f)
ds4 = get_dataset(tars, seed=42)
ds_list += [(ds4, 0.3)]
with open('/mnt/bn/roboicllq-data1/aligned_robot_ds/calvin/datacomp.json', 'r') as f:
    tars = json.load(f)
ds0 = get_dataset(tars, seed=42)
ds_list += [(ds0, 0.68)]

ds = wds.RandomMix(*zip(*ds_list))
ds = wds.DataPipeline(ds)
# NOTE(review): .with_epoch(args.epoch) caps the loader at args.epoch batches,
# so `--epoch` here behaves as "total training steps" — TODO confirm.
loader = (
    wds.WebLoader(ds, num_workers=4, batch_size=args.batch_size, pin_memory=True).with_epoch(args.epoch)
)

# ---------------------------------------------------------------------------
# Optimizers and LR schedules (separate autoencoder / discriminator optims).
# ---------------------------------------------------------------------------
base_lr = 1e-5  # NOTE(review): deliberately overrides args.base_lr — confirm
opt, _ = vqmodel.configure_optimizers(base_lr)
disc_start = 0  # global step at which discriminator updates begin
ae_opt, disc_opt = opt[0], opt[1]
# min_lr ≈ max_lr makes both schedules effectively flat after warmup.
ae_scheduler = CosineAnnealingWarmupRestarts(
    ae_opt, first_cycle_steps=args.epoch, cycle_mult=1, max_lr=base_lr,
    min_lr=base_lr*0.9999999, warmup_steps=args.epoch/10, gamma=1)
disc_scheduler = CosineAnnealingWarmupRestarts(
    disc_opt, first_cycle_steps=args.epoch, cycle_mult=1, max_lr=base_lr,
    min_lr=base_lr*0.9999999, warmup_steps=args.epoch/10, gamma=1)

vqmodel, ae_opt, disc_opt, ae_scheduler, disc_scheduler, loader = accelerator.prepare(
    vqmodel, ae_opt, disc_opt, ae_scheduler, disc_scheduler, loader)
print ('global config end ---------------------------------------------------')

log_iter = 50   # average & emit TensorBoard scalars every this many steps
acc_steps = 1   # gradient-accumulation steps (currently unused)


def train_one_epoch(args):
    """Run the full training loop over `loader`.

    Alternates autoencoder and discriminator updates (taming-transformers
    style two-optimizer scheme — optimizer_idx 0 = AE, 1 = disc), logs
    averaged scalars every `log_iter` steps on rank 0, and checkpoints every
    `log_iter * 1000` steps.  Uses the module-level globals prepared above
    (vqmodel, loader, optimizers, schedulers, WRITER, accelerator).
    """
    STEP_CNT = 0
    DISC_STEP_CNT = 0
    vqmodel.train()
    stat_dict = {}
    for i, data in enumerate(loader):
        batch = data.to(device)
        # training_step(batch, optimizer_idx, device, step): returns
        # (loss, dict-of-scalar-tensors) — presumably Lightning-style; confirm.
        aeloss, log_dict_ae = vqmodel.module.training_step(batch, 0, device, STEP_CNT)
        if STEP_CNT >= disc_start:
            discloss, log_dict_disc = vqmodel.module.training_step(batch, 1, device, STEP_CNT)
        accelerator.backward(aeloss)
        ae_opt.step()
        ae_opt.zero_grad()
        if STEP_CNT >= disc_start:
            accelerator.backward(discloss)
            disc_opt.step()
            disc_opt.zero_grad()

        if get_rank() == 0:
            # BUG FIX: stats were previously accumulated only on logging
            # steps yet still divided by log_iter, under-reporting every
            # scalar by ~log_iter x.  Accumulate on EVERY step instead.
            for k, v in log_dict_ae.items():
                stat_dict[k] = stat_dict.get(k, 0) + v.cpu().item()
            if STEP_CNT >= disc_start:
                for k, v in log_dict_disc.items():
                    stat_dict[k] = stat_dict.get(k, 0) + v.cpu().item()
            if (STEP_CNT + 1) % log_iter == 0:
                # x-axis is global samples seen, not steps.
                global_sample = STEP_CNT*args.batch_size*get_world_size()
                for k, v in stat_dict.items():
                    WRITER.add_scalar(k, v/log_iter, global_sample)
                    stat_dict[k] = 0
                WRITER.add_scalar('lr_ae', ae_opt.param_groups[0]['lr'], global_sample)
                WRITER.add_scalar('lr_disc', disc_opt.param_groups[0]['lr'], global_sample)

        STEP_CNT += 1
        ae_scheduler.step(STEP_CNT)
        if STEP_CNT >= disc_start:
            DISC_STEP_CNT += 1
            disc_scheduler.step(STEP_CNT)

        # Periodic checkpoint (every log_iter*1000 = 50k steps by default);
        # only rank 0 writes, everyone synchronizes afterwards.
        if (STEP_CNT + 1) % (log_iter*1000) == 0:
            if get_rank() == 0:
                unwrapped_model = accelerator.unwrap_model(vqmodel)
                accelerator.save(unwrapped_model.state_dict(), os.path.join(log_path, 'checkpoint_vq_epoch_' +str(STEP_CNT)+'.tar'))
            accelerator.wait_for_everyone()
            print ('epoch {} train done'.format(STEP_CNT))


def main(args):
    """Entry point: run one pass of the (step-capped) training loop."""
    train_one_epoch(args)


if __name__ == '__main__':
    main(args)
    time.sleep(60)  # grace period so writers / lagging ranks can drain
    # BUG FIX: destroy_process_group() raises if no group was ever
    # initialized (the manual init is gone; accelerate may use one or not),
    # so only tear down when a group actually exists.
    if dist.is_initialized():
        dist.destroy_process_group()
    print ('train done!')