# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Callable, List, Optional, Union

import torch

from ..dist_utils import master_only
from .hook import HOOKS, Hook


@HOOKS.register_module()
class ProfilerHook(Hook):
| """Profiler to analyze performance during training. | |
| PyTorch Profiler is a tool that allows the collection of the performance | |
| metrics during the training. More details on Profiler can be found at | |
| https://pytorch.org/docs/1.8.1/profiler.html#torch.profiler.profile | |
| Args: | |
| by_epoch (bool): Profile performance by epoch or by iteration. | |
| Default: True. | |
| profile_iters (int): Number of iterations for profiling. | |
| If ``by_epoch=True``, profile_iters indicates that they are the | |
| first profile_iters epochs at the beginning of the | |
| training, otherwise it indicates the first profile_iters | |
| iterations. Default: 1. | |
| activities (list[str]): List of activity groups (CPU, CUDA) to use in | |
| profiling. Default: ['cpu', 'cuda']. | |
| schedule (dict, optional): Config of generating the callable schedule. | |
| if schedule is None, profiler will not add step markers into the | |
| trace and table view. Default: None. | |
| on_trace_ready (callable, dict): Either a handler or a dict of generate | |
| handler. Default: None. | |
| record_shapes (bool): Save information about operator's input shapes. | |
| Default: False. | |
| profile_memory (bool): Track tensor memory allocation/deallocation. | |
| Default: False. | |
| with_stack (bool): Record source information (file and line number) | |
| for the ops. Default: False. | |
| with_flops (bool): Use formula to estimate the FLOPS of specific | |
| operators (matrix multiplication and 2D convolution). | |
| Default: False. | |
| json_trace_path (str, optional): Exports the collected trace in Chrome | |
| JSON format. Default: None. | |
| Example: | |
| >>> runner = ... # instantiate a Runner | |
| >>> # tensorboard trace | |
| >>> trace_config = dict(type='tb_trace', dir_name='work_dir') | |
| >>> profiler_config = dict(on_trace_ready=trace_config) | |
| >>> runner.register_profiler_hook(profiler_config) | |
| >>> runner.run(data_loaders=[trainloader], workflow=[('train', 1)]) | |
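        >>> # log_trace: a minimal sketch that prints a summary table when
        >>> # the trace is ready; ``sort_by`` and ``row_limit`` are
        >>> # forwarded to ``key_averages().table()``, and the ``schedule``
        >>> # dict is unpacked into ``torch.profiler.schedule``
        >>> trace_config = dict(type='log_trace',
        ...                     sort_by='self_cpu_time_total', row_limit=10)
        >>> profiler_config = dict(on_trace_ready=trace_config,
        ...                        schedule=dict(wait=1, warmup=1, active=2))
        >>> runner.register_profiler_hook(profiler_config)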
| """ | |

    def __init__(self,
                 by_epoch: bool = True,
                 profile_iters: int = 1,
                 activities: List[str] = ['cpu', 'cuda'],
                 schedule: Optional[dict] = None,
                 on_trace_ready: Optional[Union[Callable, dict]] = None,
                 record_shapes: bool = False,
                 profile_memory: bool = False,
                 with_stack: bool = False,
                 with_flops: bool = False,
                 json_trace_path: Optional[str] = None) -> None:
        try:
            from torch import profiler  # torch version >= 1.8.1
        except ImportError:
            raise ImportError('profiler is the new feature of torch1.8.1, '
                              f'but your version is {torch.__version__}')
        assert isinstance(by_epoch, bool), '``by_epoch`` should be a boolean.'
        self.by_epoch = by_epoch

        if profile_iters < 1:
            raise ValueError('profile_iters should be greater than 0, but '
                             f'got {profile_iters}')
        self.profile_iters = profile_iters

        if not isinstance(activities, list):
            raise ValueError(
                f'activities should be list, but got {type(activities)}')
        self.activities = []
        for activity in activities:
            activity = activity.lower()
            if activity == 'cpu':
                self.activities.append(profiler.ProfilerActivity.CPU)
            elif activity == 'cuda':
                self.activities.append(profiler.ProfilerActivity.CUDA)
            else:
                raise ValueError(
                    f'activity should be "cpu" or "cuda", but got {activity}')

        if schedule is not None:
            self.schedule = profiler.schedule(**schedule)
        else:
            self.schedule = None

        self.on_trace_ready = on_trace_ready
        self.record_shapes = record_shapes
        self.profile_memory = profile_memory
        self.with_stack = with_stack
        self.with_flops = with_flops
        self.json_trace_path = json_trace_path

    @master_only
    def before_run(self, runner):
        if self.by_epoch and runner.max_epochs < self.profile_iters:
            raise ValueError('self.profile_iters should not be greater than '
                             f'{runner.max_epochs}')

        if not self.by_epoch and runner.max_iters < self.profile_iters:
            raise ValueError('self.profile_iters should not be greater than '
                             f'{runner.max_iters}')

        if callable(self.on_trace_ready):  # handler
            _on_trace_ready = self.on_trace_ready
        elif isinstance(self.on_trace_ready, dict):  # config of handler
            trace_cfg = self.on_trace_ready.copy()
            trace_type = trace_cfg.pop('type')  # log_trace handler
            if trace_type == 'log_trace':

                def _log_handler(prof):
                    print(prof.key_averages().table(**trace_cfg))

                _on_trace_ready = _log_handler
            elif trace_type == 'tb_trace':  # tensorboard_trace handler
                try:
                    import torch_tb_profiler  # noqa: F401
                except ImportError:
                    raise ImportError('please run "pip install '
                                      'torch-tb-profiler" to install '
                                      'torch_tb_profiler')
                _on_trace_ready = torch.profiler.tensorboard_trace_handler(
                    **trace_cfg)
            else:
                raise ValueError('trace_type should be "log_trace" or '
                                 f'"tb_trace", but got {trace_type}')
        elif self.on_trace_ready is None:
            _on_trace_ready = None  # type: ignore
        else:
            raise ValueError('on_trace_ready should be handler, dict or '
                             f'None, but got {type(self.on_trace_ready)}')

        if runner.max_epochs > 1:
            warnings.warn(f'profiler will profile {runner.max_epochs} epochs '
                          'instead of 1 epoch. Since profiler will slow down '
                          'the training, it is recommended to train 1 epoch '
                          'with ProfilerHook and adjust your setting according'
                          ' to the profiler summary. During normal training '
                          '(epoch > 1), you may disable the ProfilerHook.')

        self.profiler = torch.profiler.profile(
            activities=self.activities,
            schedule=self.schedule,
            on_trace_ready=_on_trace_ready,
            record_shapes=self.record_shapes,
            profile_memory=self.profile_memory,
            with_stack=self.with_stack,
            with_flops=self.with_flops)

        self.profiler.__enter__()
        runner.logger.info('profiler is profiling...')

    @master_only
    def after_train_epoch(self, runner):
        if self.by_epoch and runner.epoch == self.profile_iters - 1:
            runner.logger.info('profiler may take a few minutes...')
            self.profiler.__exit__(None, None, None)
            if self.json_trace_path is not None:
                self.profiler.export_chrome_trace(self.json_trace_path)

    @master_only
    def after_train_iter(self, runner):
        self.profiler.step()
        if not self.by_epoch and runner.iter == self.profile_iters - 1:
            runner.logger.info('profiler may take a few minutes...')
            self.profiler.__exit__(None, None, None)
            if self.json_trace_path is not None:
                self.profiler.export_chrome_trace(self.json_trace_path)
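

# For reference, a minimal standalone sketch of the ``torch.profiler`` calls
# this hook wraps in ``before_run`` / ``after_train_iter``; ``train_step``
# and ``loader`` are hypothetical placeholders:
#
#     with torch.profiler.profile(
#             activities=[torch.profiler.ProfilerActivity.CPU],
#             schedule=torch.profiler.schedule(wait=1, warmup=1, active=2),
#             on_trace_ready=torch.profiler.tensorboard_trace_handler(
#                 'work_dir')) as prof:
#         for data in loader:
#             train_step(data)  # one training step
#             prof.step()       # advance the profiler schedule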