|
import argparse |
|
import os.path as osp |
|
import itertools |
|
from omegaconf import OmegaConf |
|
from paintmind.engine.util import instantiate_from_config |
|
from paintmind.utils.device_utils import configure_compute_backend |
|
|
|
def parse_args(argv=None):
    """Parse command line arguments for the test script.

    ``--model`` and ``--step`` accept one or more values and are iterated
    over by the caller. The sweepable options (``--cfg_value``, ``--ae_cfg``,
    ``--diff_cfg``, ``--cfg_schedule``, ``--diff_cfg_schedule``,
    ``--test_num_slots``, ``--temperature``) also accept one or more values
    and default to ``[None]``, which downstream code treats as "not set".

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so
            existing ``parse_args()`` callers behave as before.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    # The first positional parameter of ArgumentParser is `prog`, not the
    # description; passing the text positionally replaced the program name
    # in usage/error output. Pass it as the description instead.
    parser = argparse.ArgumentParser(description="Test a model")

    # What to evaluate.
    parser.add_argument('--model', type=str, nargs='+', default=[None], help="Path to model directory")
    parser.add_argument('--step', type=int, nargs='+', default=[250000], help="Step number to test")
    parser.add_argument('--cfg', type=str, default=None, help="Path to config file")
    parser.add_argument('--dataset', type=str, default='imagenet', help="Dataset to use")

    # Sweepable inference parameters; [None] means "leave the config value as is".
    parser.add_argument('--cfg_value', type=float, nargs='+', default=[None],
                        help='Legacy parameter for GPT classifier-free guidance scale')
    parser.add_argument('--ae_cfg', type=float, nargs='+', default=[None],
                        help="Autoencoder classifier-free guidance scale")
    parser.add_argument('--diff_cfg', type=float, nargs='+', default=[None],
                        help="Diffusion classifier-free guidance scale")
    parser.add_argument('--cfg_schedule', type=str, nargs='+', default=[None],
                        help="CFG schedule type (e.g., constant, linear)")
    parser.add_argument('--diff_cfg_schedule', type=str, nargs='+', default=[None],
                        help="Diffusion CFG schedule type (e.g., constant, inv_linear)")
    parser.add_argument('--test_num_slots', type=int, nargs='+', default=[None],
                        help="Number of slots to use for inference")
    parser.add_argument('--temperature', type=float, nargs='+', default=[None],
                        help="Temperature for sampling")

    return parser.parse_args(argv)
|
|
|
|
|
def load_config(model_path, cfg_path=None):
    """Load the OmegaConf configuration for a run.

    Lookup order:
      1. ``cfg_path``, if it is not ``None`` and exists on disk.
      2. ``<model_path>/config.yaml``, if ``model_path`` is truthy and that
         file exists.

    Args:
        model_path: Model directory expected to contain ``config.yaml``;
            may be ``None`` or empty.
        cfg_path: Explicit config file path; takes precedence when it exists.

    Returns:
        The object returned by ``OmegaConf.load`` for the selected file.

    Raises:
        ValueError: If neither location yields an existing config file.
    """
    # An explicit config wins, but only if the file is actually there;
    # otherwise fall through to the config stored with the model.
    if cfg_path is not None and osp.exists(cfg_path):
        return OmegaConf.load(cfg_path)

    if model_path:
        bundled_config = osp.join(model_path, 'config.yaml')
        if osp.exists(bundled_config):
            return OmegaConf.load(bundled_config)

    raise ValueError(f"No config file found at {model_path} or {cfg_path}")
|
|
|
|
|
def setup_checkpoint_path(model_path, step, config):
    """Resolve the checkpoint directory for ``step`` and record it in ``config``.

    The checkpoint lives at ``<root>/models/step<step>``, where ``<root>`` is
    ``model_path`` when given, otherwise ``config.trainer.params.result_folder``.
    The resolved path is written to ``config.trainer.params.model.params.ckpt_path``
    when the trainer params have a ``model`` attribute, else to
    ``config.trainer.params.gpt_model.params.ckpt_path``.

    Args:
        model_path: Model directory, or a falsy value to use the config's
            result folder instead.
        step: Training step used to build the ``step<step>`` directory name.
        config: Config object; mutated in place.

    Returns:
        The checkpoint path, or ``None`` when ``model_path`` was given but the
        checkpoint does not exist (in which case ``config`` is left untouched).
    """
    step_dir = f'step{step}'

    if model_path:
        ckpt_path = osp.join(model_path, 'models', step_dir)
        # Existence is only verified for an explicitly supplied model
        # directory; the result-folder fallback below is not checked.
        if not osp.exists(ckpt_path):
            print(f"Skipping non-existent checkpoint: {ckpt_path}")
            return None
    else:
        ckpt_path = osp.join(config.trainer.params.result_folder, 'models', step_dir)

    trainer_params = config.trainer.params
    if hasattr(trainer_params, 'model'):
        model_node = trainer_params.model
    else:
        model_node = trainer_params.gpt_model
    model_node.params.ckpt_path = ckpt_path

    return ckpt_path
|
|
|
|
|
def setup_test_config(config, use_coco=False):
    """Switch ``config`` into evaluation mode (mutates it in place).

    Sets ``trainer.params.test_dataset`` from ``trainer.params.dataset`` and
    then adjusts it: the split becomes ``'val'`` by default, or the dataset is
    redirected to COCO (``target``, ``root`` and split ``'val2017'``) when
    ``use_coco`` is true. It also sets ``test_only=True``, ``compile=False``,
    ``eval_fid=True``, the FID statistics file, and
    ``num_sampling_steps='250'`` on ``model`` (or ``ae_model`` when the
    trainer params have no ``model`` attribute).

    Args:
        config: Config object exposing ``trainer.params``.
        use_coco: Evaluate on COCO instead of the configured dataset's
            validation split.

    Returns:
        None.
    """
    trainer_params = config.trainer.params

    # NOTE(review): whether the later edits to test_dataset also affect
    # `dataset` depends on the config type's assignment semantics (copy vs.
    # alias) — confirm if the training dataset entry matters afterwards.
    trainer_params.test_dataset = trainer_params.dataset
    test_dataset = trainer_params.test_dataset

    if use_coco:
        test_dataset.target = 'paintmind.utils.datasets.COCO'
        test_dataset.params.root = './dataset/coco'
        test_dataset.params.split = 'val2017'
    else:
        test_dataset.params.split = 'val'

    trainer_params.test_only = True
    trainer_params.compile = False
    trainer_params.eval_fid = True
    trainer_params.fid_stats = 'fid_stats/adm_in256_stats.npz'

    # The step count is stored as a string, exactly as the original did.
    if hasattr(trainer_params, 'model'):
        sampled_model = trainer_params.model
    else:
        sampled_model = trainer_params.ae_model
    sampled_model.params.num_sampling_steps = '250'
|
|
|
|
|
def apply_cfg_params(config, param_dict):
    """Copy the non-``None`` sweep parameters onto ``config.trainer.params``.

    Each recognised key of ``param_dict`` whose value is not ``None`` is
    written to the trainer params and announced with a ``Setting ... to ...``
    line on stdout. The legacy ``cfg_value`` key is stored under the name
    ``cfg``; all other keys keep their own name.

    Args:
        config: Config object exposing ``trainer.params``; mutated in place.
        param_dict: Mapping from parameter name to value (or ``None`` to
            leave the config untouched for that parameter).

    Returns:
        None.
    """
    # (key in param_dict, attribute name on config.trainer.params)
    overrides = (
        ('cfg_value', 'cfg'),
        ('ae_cfg', 'ae_cfg'),
        ('diff_cfg', 'diff_cfg'),
        ('cfg_schedule', 'cfg_schedule'),
        ('diff_cfg_schedule', 'diff_cfg_schedule'),
        ('test_num_slots', 'test_num_slots'),
        ('temperature', 'temperature'),
    )

    for key, attr_name in overrides:
        value = param_dict.get(key)
        if value is None:
            continue
        # Resolve config.trainer.params only when there is something to set.
        setattr(config.trainer.params, attr_name, value)
        print(f"Setting {attr_name} to {value}")
|
|
|
|
|
def run_test(config):
    """Build the trainer described by ``config.trainer`` and run it.

    The trainer's ``train()`` method is the only entry point invoked here.

    Args:
        config: Config object whose ``trainer`` node is passed to
            ``instantiate_from_config``.

    Returns:
        None.
    """
    # NOTE(review): presumably the trainer runs evaluation instead of training
    # when its config has test_only set — confirm against the trainer class.
    instantiate_from_config(config.trainer).train()
|
|
|
|
|
def generate_param_combinations(args):
    """Yield one parameter dict per point of the requested sweep grid.

    Reads the sweepable attributes of ``args`` (``cfg_value``, ``ae_cfg``,
    ``diff_cfg``, ``cfg_schedule``, ``diff_cfg_schedule``, ``test_num_slots``,
    ``temperature``). An attribute equal to ``[None]`` is inactive. The
    Cartesian product of the active value lists is enumerated, and each
    yielded dict contains every parameter name, with inactive ones mapped to
    ``None``. When nothing is active, a single all-``None`` dict is yielded.

    Args:
        args: Object exposing the attributes listed above, each a list.

    Yields:
        dict: Parameter name -> value for one combination.
    """
    names = (
        'cfg_value',
        'ae_cfg',
        'diff_cfg',
        'cfg_schedule',
        'diff_cfg_schedule',
        'test_num_slots',
        'temperature',
    )
    grid = {name: getattr(args, name) for name in names}

    # Only parameters the user actually supplied take part in the product.
    swept = [name for name, values in grid.items() if values != [None]]
    template = dict.fromkeys(grid)

    # product() over zero iterables yields a single empty tuple, so the
    # "nothing supplied" case naturally produces one all-None dict.
    for values in itertools.product(*(grid[name] for name in swept)):
        combination = dict(template)
        combination.update(zip(swept, values))
        yield combination
|
|
|
|
|
def test(args):
    """Evaluate every requested (model, step) pair over the parameter sweep.

    For each pair the config is loaded, the checkpoint path is resolved
    (pairs whose checkpoint is reported missing are skipped), and the config
    is switched to test mode. Then, for each combination produced by
    ``generate_param_combinations``, a fresh resolved copy of the config gets
    the combination applied and is run.

    Args:
        args: Parsed command line arguments (see ``parse_args``); uses
            ``model``, ``step``, ``cfg``, ``dataset`` and the sweep options.

    Returns:
        None.
    """
    # Same visiting order as a nested loop: every step for the first model,
    # then every step for the next one.
    for model, step in itertools.product(args.model, args.step):
        print(f"Testing model: {model} at step: {step}")

        config = load_config(model, args.cfg)

        if setup_checkpoint_path(model, step, config) is None:
            continue

        setup_test_config(config, args.dataset in ('coco', 'COCO'))

        for param_dict in generate_param_combinations(args):
            # Round-trip through a plain container so each run starts from an
            # independent, fully resolved copy of the base config.
            snapshot = OmegaConf.to_container(config, resolve=True)
            current_config = OmegaConf.create(snapshot)

            shown = [f"{k}={v}" for k, v in param_dict.items() if v is not None]
            param_str = ", ".join(shown)
            print(f"Testing with parameters: {param_str}")

            apply_cfg_params(current_config, param_dict)
            run_test(current_config)
|
|
|
|
|
def main():
    """Script entry point: parse the CLI, configure the backend, run the tests.

    Arguments are parsed first, so argparse can exit (e.g. on ``--help`` or a
    bad flag) before ``configure_compute_backend`` is called.

    Returns:
        None.
    """
    cli_args = parse_args()
    configure_compute_backend()
    test(cli_args)
|
|
|
|
|
# Run the test driver only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()
|
|