# Spaces: Running on T4
""" | |
Integrate numerical values for some iterations.
Typically used for loss computation / logging to tensorboard.
Call finalize and create a new Integrator when you want to display/log.
""" | |
from typing import Callable, Union | |
import torch | |
from mmaudio.utils.logger import TensorboardLogger | |
from mmaudio.utils.tensor_utils import distribute_into_histogram | |
class Integrator:
    """Accumulate scalars and binned tensors, then average and log them.

    Scalars are summed per key together with a running count; `finalize`
    divides the sum by the count, optionally reduces across processes with
    `torch.distributed`, and forwards the result to the logger.
    """

    def __init__(self, logger: "TensorboardLogger", distributed: bool = True):
        """Create an empty integrator.

        Args:
            logger: object whose `log_metrics` and `log_histogram` methods are
                called from `finalize`.
            distributed: when True, `finalize` reduces values to rank 0 with
                `torch.distributed.reduce` before logging.
        """
        self.values = {}
        self.counts = {}
        self.hooks = []  # List is used here to maintain insertion order

        # for binned tensors
        self.binned_tensors = {}
        self.binned_tensor_indices = {}

        self.logger = logger

        self.distributed = distributed
        # Query the process group whenever one exists. In distributed mode we
        # still call get_rank() unconditionally so that a missing process
        # group raises here, exactly as before. In non-distributed mode with
        # no process group we fall back to a single-process view instead of
        # crashing.
        group_ready = (torch.distributed.is_available()
                       and torch.distributed.is_initialized())
        if distributed or group_ready:
            self.local_rank = torch.distributed.get_rank()
            self.world_size = torch.distributed.get_world_size()
        else:
            self.local_rank = 0
            self.world_size = 1

    def add_scalar(self, key: str, x: Union[torch.Tensor, int, float]):
        """Add `x` to the running sum stored under `key` and bump its count.

        Tensors are detached first; long/int/bool tensors are converted to
        float.
        """
        if isinstance(x, torch.Tensor):
            x = x.detach()
            if x.dtype in [torch.long, torch.int, torch.bool]:
                x = x.float()

        if key not in self.values:
            self.counts[key] = 1
            self.values[key] = x
        else:
            self.counts[key] += 1
            # Out-of-place on purpose: the stored value may be a detached
            # tensor that shares storage with the caller's tensor, so an
            # in-place `+=` would silently modify the caller's data.
            self.values[key] = self.values[key] + x

    def add_dict(self, tensor_dict: dict[str, torch.Tensor]):
        """Call `add_scalar` for every (key, value) pair in `tensor_dict`."""
        for k, v in tensor_dict.items():
            self.add_scalar(k, v)

    def add_binned_tensor(self, key: str, x: torch.Tensor, indices: torch.Tensor):
        """Store flattened, detached `x` and `indices` under `key`.

        The stored chunks are concatenated and passed to
        `distribute_into_histogram` in `finalize`.
        """
        self.binned_tensors.setdefault(key, []).append(x.detach().flatten())
        self.binned_tensor_indices.setdefault(key, []).append(indices.detach().flatten())

    def add_hook(self, hook: Callable[[dict], tuple[str, torch.Tensor]]):
        """
        Adds a custom hook, i.e. compute new metrics using values in the dict
        The hook takes the dict as argument, and returns a (k, v) tuple
        e.g. for computing IoU
        """
        self.hooks.append(hook)

    def reset_except_hooks(self):
        """Discard all accumulated data while keeping the registered hooks."""
        self.values = {}
        self.counts = {}
        # Binned tensors are accumulated state too; drop them so they neither
        # grow without bound nor leak into the next logging window.
        self.binned_tensors = {}
        self.binned_tensor_indices = {}

    # Average and output the metrics
    def finalize(self, prefix: str, it: int, ignore_timer: bool = False) -> None:
        """Run hooks, average the accumulated values and log them.

        Args:
            prefix: passed to `logger.log_metrics`, and used as the leading
                path component of each histogram tag.
            it: iteration index forwarded to the logger.
            ignore_timer: forwarded to `logger.log_metrics`.

        In distributed mode every value is moved to CUDA and reduced to rank 0,
        and only rank 0 logs. Otherwise the local averages are logged directly.
        """
        for hook in self.hooks:
            k, v = hook(self.values)
            self.add_scalar(k, v)

        # for the metrics
        outputs = {}
        for k, v in self.values.items():
            avg = v / self.counts[k]

            if self.distributed:
                # reduce() works in place on a tensor, so wrap plain numbers
                if isinstance(avg, torch.Tensor):
                    avg = avg.cuda()
                else:
                    avg = torch.tensor(avg).cuda()

                torch.distributed.reduce(avg, dst=0)

                if self.local_rank == 0:
                    # reduce() summed over ranks; divide to get the mean
                    avg = (avg / self.world_size).cpu().item()
                    outputs[k] = avg
            else:
                # Simple does it
                outputs[k] = avg

        if (not self.distributed) or (self.local_rank == 0):
            self.logger.log_metrics(prefix, outputs, it, ignore_timer=ignore_timer)

        # for the binned tensors
        for k, v in self.binned_tensors.items():
            x = torch.cat(v, dim=0)
            indices = torch.cat(self.binned_tensor_indices[k], dim=0)
            # NOTE(review): presumably returns per-bin sums and per-bin counts;
            # confirm against mmaudio.utils.tensor_utils.
            hist, count = distribute_into_histogram(x, indices)

            if self.distributed:
                torch.distributed.reduce(hist, dst=0)
                torch.distributed.reduce(count, dst=0)
                if self.local_rank == 0:
                    hist = hist / count
            else:
                hist = hist / count

            if (not self.distributed) or (self.local_rank == 0):
                self.logger.log_histogram(f'{prefix}/{k}', hist, it)