import time
from typing import Mapping, Union

import torch
from ignite.contrib.handlers import TensorboardLogger
from ignite.contrib.handlers.base_logger import BaseHandler
from ignite.engine import Engine, EventEnum, Events
from ignite.handlers import global_step_from_engine


def add_time_handlers(engine: Engine) -> None:
    """Attach per-iteration and per-batch-loading timing handlers to `engine`."""
    iteration_time_handler = TimeHandler("iter", freq=True, period=True)
    batch_time_handler = TimeHandler("get_batch", freq=False, period=True)
    engine.add_event_handler(
        Events.ITERATION_STARTED, iteration_time_handler.start_timing
    )
    engine.add_event_handler(
        Events.ITERATION_COMPLETED, iteration_time_handler.end_timing
    )
    engine.add_event_handler(Events.GET_BATCH_STARTED, batch_time_handler.start_timing)
    engine.add_event_handler(Events.GET_BATCH_COMPLETED, batch_time_handler.end_timing)
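
# Minimal usage sketch (hedged: `train_step` and `loader` are hypothetical
# caller-side names, not part of this module). After attaching, per-iteration
# and per-batch timings accumulate in `engine.state.times`, where
# MetricLoggingHandler below picks them up:
#
#     trainer = Engine(train_step)
#     add_time_handlers(trainer)
#     trainer.run(loader, max_epochs=1)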


class MetricLoggingHandler(BaseHandler):
    """Log losses, metrics, and timings from the engine state to TensorBoard."""

    def __init__(
        self,
        tag,
        optimizer=None,
        log_loss=True,
        log_metrics=True,
        log_timings=True,
        global_step_transform=None,
    ):
        self.tag = tag
        self.optimizer = optimizer
        self.log_loss = log_loss
        self.log_metrics = log_metrics
        self.log_timings = log_timings
        self.gst = global_step_transform
        super().__init__()

    def __call__(
        self,
        engine: Engine,
        logger: TensorboardLogger,
        event_name: Union[str, EventEnum],
    ) -> None:
        if not isinstance(logger, TensorboardLogger):
            raise RuntimeError(
                "Handler 'MetricLoggingHandler' works only with TensorboardLogger"
            )

        gst = self.gst if self.gst is not None else global_step_from_engine(engine)
        global_step = gst(engine, event_name)  # type: ignore[misc]
        if not isinstance(global_step, int):
            raise TypeError(
                f"global_step must be int, got {type(global_step)}."
                " Please check the output of global_step_transform."
            )

        writer = logger.writer

        # Optimizer parameters: one learning-rate scalar per param group
        if self.optimizer is not None:
            params = {
                k: float(param_group["lr"])
                for k, param_group in enumerate(self.optimizer.param_groups)
            }
            for k, param in params.items():
                writer.add_scalar(f"lr-{self.tag}/{k}", param, global_step)

        # Losses
        if self.log_loss:
            loss_dict = engine.state.output["loss_dict"]
            for k, v in loss_dict.items():
                writer.add_scalar(f"loss-{self.tag}/{k}", v, global_step)

        # Metrics: non-scalar tensors become histograms, everything else scalars
        if self.log_metrics:
            metrics_dict = engine.state.metrics
            metrics_dict_custom = engine.state.output["metrics_dict"]
            for k, v in metrics_dict.items():
                # Skip dictionaries because of Ignite's special handling of
                # Mapping metrics.
                # TODO: remove the hard-coded "assignment" check
                if isinstance(v, Mapping) or k.endswith("assignment"):
                    continue
                if isinstance(v, torch.Tensor) and v.ndim > 0:
                    writer.add_histogram(f"metrics-{self.tag}/{k}", v, global_step)
                else:
                    writer.add_scalar(f"metrics-{self.tag}/{k}", v, global_step)
            for k, v in metrics_dict_custom.items():
                if isinstance(v, Mapping):
                    continue
                if isinstance(v, torch.Tensor) and v.ndim > 0:
                    writer.add_histogram(f"metrics-{self.tag}/{k}", v, global_step)
                else:
                    writer.add_scalar(f"metrics-{self.tag}/{k}", v, global_step)

        # Timings: built-in engine times (minus the aggregate "COMPLETED" entry)
        # plus any custom timings from the process function's output
        if self.log_timings:
            timings_dict = engine.state.times
            timings_dict_custom = engine.state.output["timings_dict"]
            for k, v in timings_dict.items():
                if k == "COMPLETED":
                    continue
                writer.add_scalar(f"timing-{self.tag}/{k}", v, global_step)
            for k, v in timings_dict_custom.items():
                writer.add_scalar(f"timing-{self.tag}/{k}", v, global_step)

        # For memory efficiency: validation results do not need to stay in the state
        engine.state.output = None
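
# Usage sketch (hedged: `trainer` and `optimizer` are hypothetical caller-side
# names, and the engine's process function is assumed to return a dict with
# "loss_dict", "metrics_dict", and "timings_dict" keys, as read above):
#
#     tb_logger = TensorboardLogger(log_dir="logs")
#     tb_logger.attach(
#         trainer,
#         log_handler=MetricLoggingHandler("train", optimizer=optimizer),
#         event_name=Events.ITERATION_COMPLETED(every=100),
#     )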


class TimeHandler:
    """Measure the wall-clock time between a start/end event pair and record it
    in ``engine.state.times`` as a period (seconds) and/or a frequency (1/s)."""

    def __init__(self, name: str, freq: bool = False, period: bool = False) -> None:
        self.name = name
        self.freq = freq
        self.period = period
        if not self.period and not self.freq:
            print(f"Warning: No timings logged for {name}")
        self._start_time = None

    def start_timing(self, engine):
        self._start_time = time.time()

    def end_timing(self, engine):
        if self._start_time is None:
            period = 0
            freq = 0
        else:
            # Clamp to a small positive value to avoid division by zero
            period = max(time.time() - self._start_time, 1e-6)
            freq = 1 / period
        # Ensure the times dict exists, then record the requested timings
        if not hasattr(engine.state, "times"):
            engine.state.times = {}
        if self.period:
            engine.state.times[f"secs_per_{self.name}"] = period
        if self.freq:
            engine.state.times[f"num_{self.name}_per_sec"] = freq


class VisualizationHandler(BaseHandler):
    """Delegate TensorBoard visualization to a user-provided callable."""

    def __init__(self, tag, visualizer, global_step_transform=None):
        self.tag = tag
        self.visualizer = visualizer
        self.gst = global_step_transform
        super().__init__()

    def __call__(
        self,
        engine: Engine,
        logger: TensorboardLogger,
        event_name: Union[str, EventEnum],
    ) -> None:
        if not isinstance(logger, TensorboardLogger):
            raise RuntimeError(
                "Handler 'VisualizationHandler' works only with TensorboardLogger"
            )

        gst = self.gst if self.gst is not None else global_step_from_engine(engine)
        global_step = gst(engine, event_name)  # type: ignore[misc]
        if not isinstance(global_step, int):
            raise TypeError(
                f"global_step must be int, got {type(global_step)}."
                " Please check the output of global_step_transform."
            )
        self.visualizer(engine, logger, global_step, self.tag)
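
# Usage sketch, continuing the `tb_logger` example above (hedged: `evaluator`
# is a hypothetical caller-side engine, and `visualize_fn` a hypothetical
# callable with signature (engine, logger, global_step, tag), e.g. one that
# writes images via logger.writer.add_image; neither is defined here):
#
#     tb_logger.attach(
#         evaluator,
#         log_handler=VisualizationHandler("val", visualize_fn),
#         event_name=Events.EPOCH_COMPLETED,
#     )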