| """ | |
| Copyright (c) Meta Platforms, Inc. and affiliates. | |
| All rights reserved. | |
| This source code is licensed under the license found in the | |
| LICENSE file in the root directory of this source tree. | |
| """ | |
import copy
import glob
import logging
import os
import re
from collections import OrderedDict
from typing import Any, Dict, Iterator, List, Mapping, Optional, Tuple, Union

import torch as th
import torch.nn as nn
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.tensorboard import SummaryWriter
from omegaconf import OmegaConf, DictConfig

from visualize.ca_body.utils.torch import to_device
from visualize.ca_body.utils.module_loader import load_class, build_optimizer

logging.basicConfig(
    format="[%(asctime)s][%(levelname)s][%(name)s]:%(message)s",
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)

def process_losses(
    loss_dict: Dict[str, Any], reduce: bool = True, detach: bool = True
) -> Dict[str, th.Tensor]:
    """Preprocess a dict of model outputs, keeping only the `loss_`-prefixed entries."""
    result = {k.replace("loss_", ""): v for k, v in loss_dict.items() if k.startswith("loss_")}
    if detach:
        result = {k: v.detach() for k, v in result.items()}
    if reduce:
        result = {k: float(v.mean().item()) for k, v in result.items()}
    return result
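
# Usage sketch (loss names and shapes are illustrative): entries not prefixed
# with `loss_` are dropped, the prefix is stripped, and values are reduced to
# detached scalars by default.
#   loss_dict = {"loss_rgb": th.rand(4), "loss_kl": th.rand(4), "preds": None}
#   process_losses(loss_dict)  # -> e.g. {"rgb": 0.49, "kl": 0.52}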

def load_config(path: str) -> DictConfig:
    # NOTE: THIS IS THE ONLY PLACE WHERE WE MODIFY CONFIG
    config = OmegaConf.load(path)
    # TODO: we should get rid of this in favor of the DB
    assert 'CARE_ROOT' in os.environ
    config.CARE_ROOT = os.environ['CARE_ROOT']
    logger.info(f'{config.CARE_ROOT=}')
    if not os.path.isabs(config.train.run_dir):
        config.train.run_dir = os.path.join(os.environ['CARE_ROOT'], config.train.run_dir)
    logger.info(f'{config.train.run_dir=}')
    os.makedirs(config.train.run_dir, exist_ok=True)
    return config
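
# Usage sketch (the path is illustrative; `CARE_ROOT` must be set and the YAML
# must contain a `train.run_dir` entry, which is resolved and created here):
#   os.environ["CARE_ROOT"] = "/path/to/care"
#   config = load_config("configs/example.yml")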

def load_from_config(config: Mapping[str, Any], **kwargs):
    """Instantiate an object given a config and arguments."""
    assert 'class_name' in config and 'module_name' not in config
    config = copy.deepcopy(config)
    ckpt = None if 'ckpt' not in config else config.pop('ckpt')
    class_name = config.pop('class_name')
    object_class = load_class(class_name)
    instance = object_class(**config, **kwargs)
    if ckpt is not None:
        load_checkpoint(
            ckpt_path=ckpt.path,
            modules={ckpt.get('module_name', 'model'): instance},
            ignore_names=ckpt.get('ignore_names', []),
            strict=ckpt.get('strict', False),
        )
    return instance
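
# Usage sketch (the class path and kwargs are illustrative; any importable
# class whose constructor matches the remaining keys works):
#   conf = OmegaConf.create({"class_name": "torch.nn.Linear", "in_features": 8, "out_features": 4})
#   layer = load_from_config(conf)
# An optional `ckpt` sub-config ({"path": ..., "module_name": ..., "strict": ...})
# additionally restores weights via `load_checkpoint`.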

def save_checkpoint(ckpt_path: str, modules: Dict[str, Any], iteration=None, keep_last_k=None):
    if keep_last_k is not None:
        raise NotImplementedError()
    ckpt_dict = {}
    if os.path.isdir(ckpt_path):
        assert iteration is not None
        ckpt_path = os.path.join(ckpt_path, f"{iteration:06d}.pt")
        ckpt_dict["iteration"] = iteration
    for name, mod in modules.items():
        # unwrap DataParallel / DistributedDataParallel wrappers
        if hasattr(mod, "module"):
            mod = mod.module
        ckpt_dict[name] = mod.state_dict()
    th.save(ckpt_dict, ckpt_path)
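
# Usage sketch (paths are illustrative): a directory path writes a numbered
# `{iteration:06d}.pt` file and records the iteration; a file path saves the
# state dicts directly.
#   save_checkpoint("/tmp/ckpts", {"model": model, "optimizer": optimizer}, iteration=500)
#   save_checkpoint("/tmp/model.pt", {"model": model})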

def filter_params(params, ignore_names):
    """Drop parameters whose names match any of the given regex patterns."""
    return OrderedDict(
        (k, v)
        for k, v in params.items()
        if not any(re.match(n, k) is not None for n in ignore_names)
    )
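
# Usage sketch (parameter names and patterns are illustrative): `re.match`
# anchors at the start of the name, so patterns behave as prefix matches
# unless they contain explicit anchors.
#   filter_params(model.state_dict(), ignore_names=[r"decoder\.", r"encoder\.bias"])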

def save_file_summaries(path: str, summaries: Dict[str, Tuple[Any, str]], prefix: Optional[str] = None):
    """Save regular file summaries for monitoring purposes."""
    # TODO: write each (value, ext) entry to disk,
    # e.g. save(f'{path}/{prefix}_{name}.{ext}', value)
    raise NotImplementedError()

def load_checkpoint(
    ckpt_path: str,
    modules: Dict[str, Any],
    iteration: Optional[int] = None,
    strict: bool = False,
    map_location: Optional[str] = None,
    ignore_names: Optional[Dict[str, List[str]]] = None,
):
    """Load a checkpoint.

    Args:
        ckpt_path: directory or the full path to the checkpoint
    """
    if map_location is None:
        map_location = "cpu"
    if os.path.isdir(ckpt_path):
        if iteration is None:
            # look up the latest iteration
            iteration = max(
                int(os.path.splitext(os.path.basename(p))[0])
                for p in glob.glob(os.path.join(ckpt_path, "*.pt"))
            )
        ckpt_path = os.path.join(ckpt_path, f"{iteration:06d}.pt")
    logger.info(f"loading checkpoint {ckpt_path}")
    ckpt_dict = th.load(ckpt_path, map_location=map_location)
    for name, mod in modules.items():
        params = ckpt_dict[name]
        if ignore_names is not None and name in ignore_names:
            logger.info(f"skipping: {ignore_names[name]}")
            params = filter_params(params, ignore_names[name])
        mod.load_state_dict(params, strict=strict)
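
# Usage sketch (paths and patterns are illustrative):
#   load_checkpoint("/tmp/ckpts", {"model": model, "optimizer": optimizer})
#   load_checkpoint("/tmp/ckpts", {"model": model}, iteration=500,
#                   ignore_names={"model": [r"decoder\."]})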

def train(
    model: nn.Module,
    loss_fn: nn.Module,
    optimizer: th.optim.Optimizer,
    train_data: Iterator,
    config: Mapping[str, Any],
    lr_scheduler: Optional[LRScheduler] = None,
    train_writer: Optional[SummaryWriter] = None,
    saving_enabled: bool = True,
    logging_enabled: bool = True,
    iteration: int = 0,
    device: Optional[Union[th.device, str]] = "cuda:0",
) -> None:
    for batch in train_data:
        if batch is None:
            logger.info("skipping empty batch")
            continue
        batch = to_device(batch, device)
        batch["iteration"] = iteration

        # leaving only the inputs actually used by the model
        preds = model(**filter_inputs(batch, model, required_only=False))

        # TODO: switch to the old-school loss computation
        loss, loss_dict = loss_fn(preds, batch, iteration=iteration)

        if th.isnan(loss):
            _loss_dict = process_losses(loss_dict)
            loss_str = " ".join([f"{k}={v:.4f}" for k, v in _loss_dict.items()])
            logger.info(f"iter={iteration}: {loss_str}")
            raise ValueError("loss is NaN")

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if logging_enabled and iteration % config.train.log_every_n_steps == 0:
            _loss_dict = process_losses(loss_dict)
            loss_str = " ".join([f"{k}={v:.4f}" for k, v in _loss_dict.items()])
            logger.info(f"iter={iteration}: {loss_str}")

            if train_writer:
                for name, value in _loss_dict.items():
                    train_writer.add_scalar(f"Losses/{name}", value, global_step=iteration)
                train_writer.flush()

        if saving_enabled and iteration % config.train.ckpt_every_n_steps == 0:
            logger.info(f"iter={iteration}: saving checkpoint to `{config.train.ckpt_dir}`")
            save_checkpoint(
                config.train.ckpt_dir,
                {"model": model, "optimizer": optimizer},
                iteration=iteration,
            )

        if logging_enabled and iteration % config.train.summary_every_n_steps == 0:
            summaries = model.compute_summaries(preds, batch)
            save_file_summaries(config.train.run_dir, summaries, prefix="train")
        if lr_scheduler is not None and iteration and iteration % config.train.update_lr_every == 0:
            lr_scheduler.step()

        iteration += 1
        if iteration >= config.train.n_max_iters:
            logger.info(f"reached max number of iters ({config.train.n_max_iters})")
            break

    if saving_enabled:
        logger.info(f"saving the final checkpoint to `{config.train.run_dir}/model.pt`")
        save_checkpoint(f"{config.train.run_dir}/model.pt", {"model": model})