import argparse
import datetime
import os
import shutil
import sys
import time
import warnings
from functools import partial

import cv2
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data as data
from loguru import logger
from torch.optim.lr_scheduler import MultiStepLR

import utils.config as config
import wandb
from utils.misc import (init_random_seed, set_random_seed, setup_logger,
                        worker_init_fn)

warnings.filterwarnings("ignore")
cv2.setNumThreads(0)


def get_parser():
    parser = argparse.ArgumentParser(
        description='Pytorch Referring Expression Segmentation')
    parser.add_argument('--config',
                        default='path to xxx.yaml',
                        type=str,
                        help='config file')
    parser.add_argument('--opts',
                        default=None,
                        nargs=argparse.REMAINDER,
                        help='override some settings in the config.')
    args = parser.parse_args()
    assert args.config is not None
    cfg = config.load_cfg_from_cfg_file(args.config)
    if args.opts is not None:
        cfg = config.merge_cfg_from_list(cfg, args.opts)
    return cfg


@logger.catch
def main():
    args = get_parser()
    args.manual_seed = init_random_seed(args.manual_seed)
    set_random_seed(args.manual_seed, deterministic=False)

    args.ngpus_per_node = torch.cuda.device_count()
    args.world_size = args.ngpus_per_node * args.world_size
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available!")
    mp.spawn(main_worker, nprocs=args.ngpus_per_node, args=(args,), join=True)


def main_worker(gpu, args):
    args.output_dir = os.path.join(args.output_folder, args.exp_name)

    # local rank & global rank
    args.gpu = gpu
    args.rank = args.rank * args.ngpus_per_node + gpu
    torch.cuda.set_device(args.gpu)

    # logger
    setup_logger(args.output_dir,
                 distributed_rank=args.gpu,
                 filename="train.log",
                 mode="a")

    # dist init
    dist.init_process_group(backend=args.dist_backend,
                            init_method=args.dist_url,
                            world_size=args.world_size,
                            rank=args.rank)
    print(f"Initializing process: GPU {gpu}, Rank {args.rank}, "
          f"World Size {args.world_size}")

    # wandb (rank 0 only); the API key is read from the WANDB_API_KEY
    # environment variable instead of being stored in the source
    if args.rank == 0:
        wandb.login(key=os.environ.get('WANDB_API_KEY'))
        print('login succeeded!')
        print()
        wandb.init(job_type="training",
                   mode="online",
                   config=args,
                   project="Hardpos_CRIS",
                   name=args.exp_name,
                   tags=[args.dataset, args.clip_pretrain])
    dist.barrier()

    # build model: pick the engine, builder, and dataset class that match the
    # requested metric mode
    if args.metric_mode == "original":
        from engine.engine import train, validate
        from model_ import build_segmenter_original
        from utils.dataset import RefDataset
        model, param_list = build_segmenter_original(args)
    elif args.metric_mode in ("hardpos_only", "hardpos_only_op2"):
        from engine.engine_verbonly import train, validate
        from model_ import build_segmenter_pos
        from utils.dataset_verbonly import RefDataset
        model, param_list = build_segmenter_pos(args)
    elif "hardpos_only_rev" in args.metric_mode:
        from engine.engine_verbonly import train, validate
        from model_ import build_segmenter_pos_rev
        from utils.dataset_verbonly import RefDataset
        model, param_list = build_segmenter_pos_rev(args)
    else:
        from engine.engine_verbonly_hardneg import train, validate
        from model_ import build_segmenter
        from utils.dataset_verbonly import RefDataset
        model, param_list = build_segmenter(args)
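
    # Optionally convert BatchNorm layers to SyncBatchNorm so that batch-norm
    # statistics are synchronized across processes, then wrap the model in DDP.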
    if args.sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    logger.info(model)
    model = nn.parallel.DistributedDataParallel(model.cuda(),
                                                device_ids=[args.gpu],
                                                find_unused_parameters=True)
    dist.barrier()

    # build optimizer & lr scheduler
    optimizer = torch.optim.Adam(param_list,
                                 lr=args.base_lr,
                                 weight_decay=args.weight_decay)
    scheduler = MultiStepLR(optimizer,
                            milestones=args.milestones,
                            gamma=args.lr_decay)
    scaler = amp.GradScaler()

    # dataset check: make sure the LMDB files and mask root exist
    assert os.path.exists(args.train_lmdb), f"Train LMDB path {args.train_lmdb} does not exist."
    assert os.path.exists(args.mask_root), f"Mask root path {args.mask_root} does not exist."
    assert os.path.exists(args.val_lmdb), f"Val LMDB path {args.val_lmdb} does not exist."

    # per-process batch size and worker count
    args.batch_size = int(args.batch_size / args.ngpus_per_node)
    args.batch_size_val = int(args.batch_size_val / args.ngpus_per_node)
    args.workers = int(
        (args.workers + args.ngpus_per_node - 1) / args.ngpus_per_node)

    # build datasets; fail loudly if the training split cannot be loaded
    try:
        train_data = RefDataset(lmdb_dir=args.train_lmdb,
                                mask_dir=args.mask_root,
                                dataset=args.dataset,
                                split=args.train_split,
                                mode='train',
                                input_size=args.input_size,
                                word_length=args.word_len,
                                args=args)
        print(f"Dataset size: {len(train_data)}")
    except Exception as e:
        print(f"Dataset initialization error: {e}")
        raise
    val_data = RefDataset(lmdb_dir=args.val_lmdb,
                          mask_dir=args.mask_root,
                          dataset=args.dataset,
                          split=args.val_split,
                          mode='val',
                          input_size=args.input_size,
                          word_length=args.word_len,
                          args=args)
    print("Successfully loaded datasets!")

    # build dataloaders
    init_fn = partial(worker_init_fn,
                      num_workers=args.workers,
                      rank=args.rank,
                      seed=args.manual_seed)
    train_sampler = data.distributed.DistributedSampler(train_data, shuffle=True)
    val_sampler = data.distributed.DistributedSampler(val_data, shuffle=False)
    train_loader = data.DataLoader(train_data,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   num_workers=args.workers,
                                   pin_memory=True,
                                   worker_init_fn=init_fn,
                                   sampler=train_sampler,
                                   drop_last=True)
    val_loader = data.DataLoader(val_data,
                                 batch_size=args.batch_size_val,
                                 shuffle=False,
                                 num_workers=args.workers_val,
                                 pin_memory=True,
                                 sampler=val_sampler,
                                 drop_last=True)
    print("Successfully loaded dataloaders!")

    best_IoU = 0.0
    best_oIoU = 0.0

    # resume
    if args.resume:
        path = None
        if os.path.isfile(args.resume):
            path = args.resume
        elif args.resume == 'latest' and os.path.isdir(args.output_dir):
            # pick up the most recent checkpoint from the output directory
            if "last_model.pth" in os.listdir(args.output_dir):
                path = os.path.join(args.output_dir, "last_model.pth")
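
        # Restore training state from the checkpoint if one was found;
        # otherwise start a fresh run from args.start_epoch.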
        if path is None or not os.path.isfile(path):
            print(f"Checkpoint '{path}' does not exist. Starting a new training run.")
        else:
            logger.info(f"=> loading checkpoint '{path}'")
            checkpoint = torch.load(path, map_location='cpu')
            args.start_epoch = checkpoint['epoch']
            best_IoU = checkpoint["best_iou"]
            best_oIoU = checkpoint["best_oiou"]
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            logger.info(
                f"=> loaded checkpoint '{path}' (epoch {checkpoint['epoch']})")

    # start training
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        epoch_log = epoch + 1

        # shuffle loader
        train_sampler.set_epoch(epoch_log)

        # train
        train(train_loader, model, optimizer, scheduler, scaler, epoch_log,
              args)

        # evaluation
        iou, oiou, prec_dict = validate(val_loader, model, epoch_log, args)

        # save model (rank 0 only)
        if dist.get_rank() == 0:
            lastname = os.path.join(args.output_dir, "last_model.pth")
            torch.save(
                {
                    'epoch': epoch_log,
                    'cur_iou': iou,
                    'best_iou': best_IoU,
                    'best_oiou': best_oIoU,
                    'prec': prec_dict,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict()
                }, lastname)
            if iou >= best_IoU:
                best_IoU = iou
                bestname = os.path.join(args.output_dir, "best_model_miou.pth")
                shutil.copyfile(lastname, bestname)
            if oiou >= best_oIoU:
                best_oIoU = oiou
                bestname_oiou = os.path.join(args.output_dir,
                                             "best_model_oiou.pth")
                shutil.copyfile(lastname, bestname_oiou)

        # update lr
        scheduler.step(epoch_log)
        torch.cuda.empty_cache()

    time.sleep(2)
    if dist.get_rank() == 0:
        wandb.finish()

    logger.info("* Best IoU={} * ".format(best_IoU))
    logger.info("* Best oIoU={} * ".format(best_oIoU))
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('* Training time {} *'.format(total_time_str))


if __name__ == '__main__':
    main()
    sys.exit(0)
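
# Example launch (illustrative only; the config path and override keys depend
# on your own setup and YAML files):
#   python -u train.py --config path/to/config.yaml --opts exp_name my_run
# Anything after --opts is forwarded as KEY VALUE pairs to
# config.merge_cfg_from_list to override settings loaded from the YAML file.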