"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
import os
import random
import socket
import time
import typing as tp

import numpy as np
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Change this to reflect your cluster layout.
# The GPU for a given rank is (rank % GPUS_PER_NODE).
GPUS_PER_NODE = 8
SETUP_RETRY_COUNT = 3

used_device = 0


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    # Initialize the process group.
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()
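
# Example usage (illustrative sketch, not part of the original module): the
# setup()/cleanup() pair is typically driven by torch.multiprocessing.spawn,
# one process per rank. The worker name `_demo_worker` is hypothetical and
# used only for this example.
#
#     import torch.multiprocessing as mp
#
#     def _demo_worker(rank, world_size):
#         setup(rank, world_size)
#         try:
#             ...  # build model, wrap in DDP, train
#         finally:
#             cleanup()
#
#     if __name__ == "__main__":
#         mp.spawn(_demo_worker, args=(2,), nprocs=2)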


def setup_dist(device=0):
    """
    Record the device to use for distributed runs.

    If a torch.distributed process group is already initialized, this is a
    no-op; otherwise it only stores the requested device index for dev().
    """
    global used_device
    used_device = device
    if dist.is_initialized():
        return


def dev():
    """
    Get the device to use for torch.distributed.
    """
    global used_device
    if torch.cuda.is_available() and used_device >= 0:
        return torch.device(f"cuda:{used_device}")
    return torch.device("cpu")


def load_state_dict(path, **kwargs):
    """
    Load a PyTorch checkpoint. This is a thin wrapper around torch.load;
    every rank reads the file independently.
    """
    return torch.load(path, **kwargs)
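
# Illustrative usage (an assumption, not from the original source): load the
# checkpoint onto the device selected by dev(). "model.pt" is a placeholder
# path.
#
#     state_dict = load_state_dict("model.pt", map_location=dev())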


def sync_params(params):
    """
    Synchronize a sequence of Tensors across ranks from rank 0.
    """
    for p in params:
        with torch.no_grad():
            dist.broadcast(p, 0)


def _find_free_port():
    # Bind to port 0 so the OS picks a free port, then report it.
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(("", 0))
        return s.getsockname()[1]
    finally:
        s.close()


def world_size():
    if torch.distributed.is_initialized():
        return torch.distributed.get_world_size()
    else:
        return 1


def is_distributed():
    return world_size() > 1


def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
    if is_distributed():
        return torch.distributed.all_reduce(tensor, op)


def _is_complex_or_float(tensor):
    return torch.is_floating_point(tensor) or torch.is_complex(tensor)


def _check_number_of_params(params: tp.List[torch.Tensor]):
    # Utility function to check that all workers hold the same number of params,
    # and thus avoid a deadlock in the distributed all_reduce/broadcast calls.
    if not is_distributed() or not params:
        return
    tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
    all_reduce(tensor)
    if tensor.item() != len(params) * world_size():
        # If the workers disagree on the number of params, the summed count will
        # differ from len(params) * world_size() on at least one of them.
        raise RuntimeError(
            f"Mismatch in number of params: ours is {len(params)}, "
            "at least one worker has a different one."
        )


def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
    """Broadcast the given tensors from the source rank to all workers.
    This can be used to ensure that all workers start with the same model.
    """
    if not is_distributed():
        return
    tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
    _check_number_of_params(tensors)
    handles = []
    for tensor in tensors:
        handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
        handles.append(handle)
    for handle in handles:
        handle.wait()
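
# Illustrative usage (sketch, assuming a `model` constructed on every rank;
# `build_model` is a hypothetical name): broadcast rank 0's weights so all
# workers start from the same state.
#
#     model = build_model()
#     broadcast_tensors(model.parameters())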


def fixseed(seed):
    torch.backends.cudnn.benchmark = False
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def prGreen(skk):
    # Print in green using ANSI escape codes.
    print("\033[92m {}\033[00m".format(skk))


def prRed(skk):
    # Print in red using ANSI escape codes.
    print("\033[91m {}\033[00m".format(skk))


def to_numpy(tensor):
    if torch.is_tensor(tensor):
        return tensor.cpu().numpy()
    elif type(tensor).__module__ != "numpy":
        raise ValueError("Cannot convert {} to numpy array".format(type(tensor)))
    return tensor


def to_torch(ndarray):
    if type(ndarray).__module__ == "numpy":
        return torch.from_numpy(ndarray)
    elif not torch.is_tensor(ndarray):
        raise ValueError("Cannot convert {} to torch tensor".format(type(ndarray)))
    return ndarray
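
# Illustrative round trip (not part of the original module):
#
#     arr = to_numpy(torch.zeros(3))   # tensor -> np.ndarray
#     ten = to_torch(arr)              # np.ndarray -> tensor (shares memory)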


def cleanexit():
    import sys

    # Exit immediately, bypassing any remaining cleanup handlers.
    try:
        sys.exit(0)
    except SystemExit:
        os._exit(0)


def load_model_wo_clip(model, state_dict):
    # Load weights while tolerating (only) missing CLIP sub-module keys.
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    assert len(unexpected_keys) == 0
    assert all([k.startswith("clip_model.") for k in missing_keys])


def freeze_joints(x, joints_to_freeze):
    # Freezes selected joint *rotations* as they appear in the first frame.
    # x: [bs, [root+n_joints], joint_dim(6), seqlen]
    frozen = x.detach().clone()
    frozen[:, joints_to_freeze, :, :] = frozen[:, joints_to_freeze, :, :1]
    return frozen
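
# Illustrative call (shapes and joint indices are assumptions for the example only):
#
#     x = torch.randn(4, 25, 6, 60)        # [bs, joints, 6D rotation, frames]
#     frozen = freeze_joints(x, [20, 21])  # joints 20 and 21 keep their frame-0 pose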


class TimerError(Exception):
    """A custom exception used to report errors in use of the Timer class."""


class Timer:
    def __init__(self):
        self._start_time = None

    def start(self):
        """Start a new timer."""
        if self._start_time is not None:
            raise TimerError("Timer is running. Use .stop() to stop it")
        self._start_time = time.perf_counter()

    def stop(self, iter=None):
        """Stop the timer and report the elapsed time (and iteration rate)."""
        if self._start_time is None:
            raise TimerError("Timer is not running. Use .start() to start it")
        elapsed_time = time.perf_counter() - self._start_time
        self._start_time = None
        iter_msg = ""
        if iter is not None:
            # Report whichever rate is >= 1 for readability.
            if iter > elapsed_time:
                iter_per_sec = iter / elapsed_time
                iter_msg = f"[iter/s: {iter_per_sec:0.4f}]"
            else:
                sec_per_iter = elapsed_time / iter
                iter_msg = f"[s/iter: {sec_per_iter:0.4f}]"
        print(f"Elapsed time: {elapsed_time:0.4f} seconds {iter_msg}")
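

# Minimal smoke test (an illustrative addition, not part of the original
# module): exercises fixseed and Timer on a single process, without any
# distributed setup.
if __name__ == "__main__":
    fixseed(0)
    timer = Timer()
    timer.start()
    _ = torch.randn(1000, 1000) @ torch.randn(1000, 1000)
    timer.stop(iter=1)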