import pytorch_lightning as pl
import os
import shutil
import fnmatch
import torch
from argparse import ArgumentParser
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.loggers import WandbLogger

from dev.utils.func import Logging, load_config_act
from dev.datasets.scalable_dataset import MultiDataModule
from dev.model.smart import SMART


def backup(source_dir, backup_dir, include_patterns, exclude_patterns):
    """
    Back up the source directory (code and configs) to a backup directory.

    Files whose names match `include_patterns` are copied; directories whose
    names match `exclude_patterns` are skipped. If `backup_dir` already
    exists, the existing backup is left untouched.
    """
    if os.path.exists(backup_dir):
        return
    os.makedirs(backup_dir, exist_ok=False)

    def should_exclude(path):
        """Return True if the path's basename matches any exclude pattern."""
        return any(fnmatch.fnmatch(os.path.basename(path), pattern) for pattern in exclude_patterns)

    # Mirror the directory structure of source_dir under backup_dir
    for root, dirs, files in os.walk(source_dir):
        # Prune excluded directories in place so os.walk does not descend into them
        dirs[:] = [d for d in dirs if not should_exclude(d)]

        # Determine the destination directory relative to backup_dir
        rel_path = os.path.relpath(root, source_dir)
        dest_dir = os.path.join(backup_dir, rel_path)
        os.makedirs(dest_dir, exist_ok=True)

        # Copy files whose names match one of the include patterns
        for file in files:
            if any(fnmatch.fnmatch(file, pattern) for pattern in include_patterns):
                shutil.copy2(os.path.join(root, file), os.path.join(dest_dir, file))

    print(f"Backup completed. Files saved to: {backup_dir}")


if __name__ == '__main__':
    pl.seed_everything(2024, workers=True)
    torch.set_printoptions(precision=3)

    Predictor_hash = {'smart': SMART}

    parser = ArgumentParser()
    parser.add_argument('--config', type=str, default='configs/ours_long_term.yaml')
    parser.add_argument('--pretrain_ckpt', type=str, default='')
    parser.add_argument('--ckpt_path', type=str, default='')
    parser.add_argument('--save_ckpt_path', type=str, default="output/debug")
    parser.add_argument('--devices', type=int, default=1)
    args = parser.parse_args()
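    # Example launch (script name and output path are illustrative; flags match the parser above):
    #   WANDB=1 python train.py --config configs/ours_long_term.yaml \
    #       --save_ckpt_path output/long_term --devices 8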

    # Back up the code and configs alongside the checkpoints
    exclude_patterns = ['*output*', '*logs', 'wandb', 'data', '*debug*', '*backup*', 'interact_*', '*edge_map*', '__pycache__']
    include_patterns = ['*.py', '*.json', '*.yaml', '*.yml', '*.sh']
    backup(os.getcwd(), os.path.join(args.save_ckpt_path, 'backups'),
           include_patterns=include_patterns, exclude_patterns=exclude_patterns)

    logger = Logging().log(level='DEBUG')
    config = load_config_act(args.config)
    Predictor = Predictor_hash[config.Model.predictor]
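    # DDP notes: find_unused_parameters=True tolerates submodules that receive no gradients on a
    # given step; gradient_as_bucket_view=True lets gradients alias the DDP communication buckets
    # to reduce memory use.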
    strategy = DDPStrategy(find_unused_parameters=True, gradient_as_bucket_view=True)
    Data_config = config.Dataset
    datamodule = MultiDataModule(**vars(Data_config), logger=logger)
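    # Optional environment-variable switches (read as integers, default 0):
    #   WANDB=1         enable Weights & Biases logging (unless DEBUG=1)
    #   DEBUG=1         suppress W&B logging even when WANDB=1
    #   OVERFIT=1       train for trainer_config.overfit_epochs instead of max_epochs
    #   CHECK_INPUTS=1  limit training to a single epoch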

    wandb_logger = None
    if int(os.getenv('WANDB', 0)) and not int(os.getenv('DEBUG', 0)):
        wandb_logger = WandbLogger(project='simagent')

    trainer_config = config.Trainer
    max_epochs = trainer_config.max_epochs

    model = Predictor(config.Model, save_path=args.save_ckpt_path, logger=logger, max_epochs=max_epochs)
    if args.pretrain_ckpt != '':
        # Warm-start from pretrained weights before training
        model.load_params_from_file(filename=args.pretrain_ckpt)

    every_n_epochs = 1
    if int(os.getenv('OVERFIT', 0)):
        max_epochs = trainer_config.overfit_epochs
        every_n_epochs = 100

    if int(os.getenv('CHECK_INPUTS', 0)):
        max_epochs = 1

    check_val_every_n_epoch = 1  # run validation once per epoch
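    # The checkpoint callback keeps the 5 checkpoints with the largest 'epoch' value (the most
    # recent ones) plus 'last.ckpt', saving every 1000 training steps and at each train-epoch end.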
    model_checkpoint = ModelCheckpoint(dirpath=args.save_ckpt_path,
                                       filename='{epoch:02d}',
                                       save_top_k=5,
                                       monitor='epoch',
                                       mode='max',
                                       save_last=True,
                                       every_n_train_steps=1000,
                                       save_on_train_epoch_end=True)
    lr_monitor = LearningRateMonitor(logging_interval='epoch')
    trainer = pl.Trainer(accelerator=trainer_config.accelerator,
                         devices=args.devices if args.devices is not None else trainer_config.devices,
                         strategy=strategy, logger=wandb_logger,
                         accumulate_grad_batches=trainer_config.accumulate_grad_batches,
                         num_nodes=trainer_config.num_nodes,
                         callbacks=[model_checkpoint, lr_monitor],
                         max_epochs=max_epochs,
                         num_sanity_val_steps=0,
                         check_val_every_n_epoch=check_val_every_n_epoch,
                         log_every_n_steps=1,
                         gradient_clip_val=0.5)

    # Resume the full trainer state (optimizer, schedulers, epoch counter) when --ckpt_path is set.
    trainer.fit(model, datamodule, ckpt_path=args.ckpt_path or None)
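    # Resuming an interrupted run (script name and paths are illustrative):
    #   python train.py --config configs/ours_long_term.yaml \
    #       --save_ckpt_path output/long_term --ckpt_path output/long_term/last.ckpt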