File size: 3,637 Bytes
824afbf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
import os
import torch as T
import re
from tqdm import tqdm
from datetime import timedelta
import requests
import hashlib
from io import BytesIO
def rank0():
    """Return True if this process is global rank 0.

    A missing ``RANK`` environment variable is treated as rank 0, so
    single-process (non-distributed) runs act as the primary worker.
    """
    return os.environ.get('RANK') in (None, '0')
def local0():
    """Return True if this process has local rank 0.

    A missing ``LOCAL_RANK`` environment variable is treated as local rank 0,
    mirroring :func:`rank0` for single-process runs.
    """
    return os.environ.get('LOCAL_RANK') in (None, '0')
class tqdm0(tqdm):
    """tqdm progress bar that is disabled on every process except rank 0.

    It also applies a compact default ``bar_format``. When the total is known,
    it defaults ``miniters`` so the bar refreshes roughly every 5% of progress.
    """

    def __init__(self, *args, **kwargs):
        total = kwargs.get('total', None)
        if total is None and len(args) > 0:
            # Infer the total from the iterable when it has a length;
            # generators and other unsized iterables just leave it unknown.
            try:
                total = len(args[0])
            except TypeError:
                pass
        if total is not None:
            # Refresh about 20 times over the run rather than every iteration.
            # An explicit caller-supplied miniters wins.
            kwargs.setdefault('miniters', max(1, total // 20))
        # Merge our defaults into kwargs instead of passing them as extra
        # keywords. Passing them separately raised "got multiple values for
        # keyword argument" whenever a caller supplied disable/bar_format.
        if not rank0():
            kwargs['disable'] = True
        else:
            kwargs.setdefault('disable', False)
        kwargs.setdefault('bar_format', '{bar}| {n_fmt}/{total_fmt} [{rate_fmt}{postfix}]')
        super().__init__(*args, **kwargs)
def print0(*args, **kwargs):
    """Forward all arguments to ``print`` only when running as rank 0."""
    if not rank0():
        return
    print(*args, **kwargs)
_PRINTED_IDS = set()  # dedup keys already emitted by printonce() in this process


def printonce(*args, id=None, **kwargs):
    """Print a message at most once per process for a given key.

    Args:
        *args, **kwargs: forwarded unchanged to ``print``.
        id: deduplication key. It defaults to the space-joined ``str()`` of
            *args*, so an identical message is only printed the first time.
    """
    if id is None:
        id = ' '.join(map(str, args))
    if id in _PRINTED_IDS:
        return
    # Record the key; without this the "once" guard never triggers.
    _PRINTED_IDS.add(id)
    print(*args, **kwargs)
def print0once(*args, **kwargs):
    """Rank-0-only variant of ``printonce``; other ranks print nothing."""
    if not rank0():
        return
    printonce(*args, **kwargs)
def init_dist():
    """Initialise, or attach to, the NCCL ``torch.distributed`` process group.

    If a process group already exists, its rank and world size are reused.
    Otherwise ``RANK``, ``LOCAL_RANK`` and ``WORLD_SIZE`` are read from the
    environment (presumably set by a launcher such as torchrun; confirm). The
    group is then created on ``cuda:{LOCAL_RANK}``.

    Returns:
        ``(rank, local_rank, world_size)`` as ints. If initialisation raises
        for any reason (e.g. the env vars are unset), it falls back to
        ``(0, 0, 1)``.
    """
    dist = T.distributed
    # is_initialized() is only meaningful when torch was built with
    # distributed support.
    if dist.is_available() and dist.is_initialized():
        print0('Distributed already initialized')
        rank = dist.get_rank()
        local_rank = int(os.environ.get('LOCAL_RANK', 0))
        world_size = dist.get_world_size()
        # Return here; falling through would call init_process_group a
        # second time, which fails.
        return rank, local_rank, world_size
    try:
        rank = int(os.environ['RANK'])
        local_rank = int(os.environ['LOCAL_RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        device = f'cuda:{local_rank}'
        dist.init_process_group(backend='nccl', timeout=timedelta(minutes=30), rank=rank, world_size=world_size, device_id=T.device(device))
        print(f'Rank {rank} of {world_size}.')
    except Exception as e:
        # Deliberate best-effort fallback to single-process mode.
        # NOTE(review): this also swallows real init failures raised after the
        # env vars were read, e.g. NCCL errors, or a TypeError on torch versions
        # whose init_process_group lacks device_id. Every process would then
        # believe it is rank 0. Consider narrowing to KeyError.
        print0once(f'Not initializing distributed env: {e}')
        rank = 0
        local_rank = 0
        world_size = 1
    return rank, local_rank, world_size
def _md5_of_file(path, chunk_size=1 << 20):
    """Return the hex MD5 digest of the file at *path*, hashing it in chunks.

    Fixed-size reads keep memory flat for multi-GB checkpoints instead of
    loading the whole file with a single ``read()``.
    """
    digest = hashlib.md5()
    with open(path, 'rb') as fh:
        for block in iter(lambda: fh.read(chunk_size), b''):
            digest.update(block)
    return digest.hexdigest()


def _download_file(url, dest, desc):
    """Stream *url* into *dest* with a progress bar, replacing *dest* atomically.

    Bytes are written to a temporary ``.part`` file, which is renamed over
    *dest* only after the transfer completes. An interrupted download
    therefore never leaves a truncated file at *dest* that a later call would
    mistake for a finished checkpoint.

    Raises:
        requests.HTTPError: if the server returns an error status. Without
            this, the error page would be saved as the checkpoint.
    """
    tmp_path = f"{dest}.{os.getpid()}.part"
    try:
        with requests.get(url, stream=True, timeout=60) as response:
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))
            with open(tmp_path, 'wb') as f, tqdm(total=total_size, desc=desc, unit='GB', unit_scale=1/(1024*1024*1024)) as pbar:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
                    pbar.update(len(chunk))
        os.replace(tmp_path, dest)
    finally:
        # The temp file only survives to here if something above failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def load_ckpt(load_from_location, expected_hash=None, max_attempts=3):
    """Download a ``.pt`` checkpoint if it is not cached, then load it onto the CPU.

    Only the process with ``LOCAL_RANK`` 0 (or no ``LOCAL_RANK``) downloads
    and verifies the file into ``ckpt/``. When ``torch.distributed`` is
    initialised, all ranks then wait on a barrier before loading, so no rank
    reads a file that is still being written.

    Args:
        load_from_location: checkpoint name. The file is fetched from
            ``f"{load_from_location}.pt"`` and cached at
            ``ckpt/{load_from_location}.pt``.
        expected_hash: optional hex MD5 digest. When given, the cached or
            downloaded file is verified and re-downloaded on mismatch.
        max_attempts: maximum number of verification rounds before giving up
            (at least 1).

    Returns:
        The object returned by ``torch.load`` for the checkpoint, mapped to CPU.

    Raises:
        RuntimeError: if the file still fails hash verification after
            ``max_attempts`` rounds.
        requests.HTTPError: if the download request returns an error status.
    """
    save_path = f"ckpt/{load_from_location}.pt"
    if local0():
        # NOTE(review): as written, the URL has no scheme or host, so requests
        # will reject it unless load_from_location is itself a full URL
        # prefix. Confirm.
        url = f"{load_from_location}.pt"
        # dirname() is 'ckpt' for plain names; this also covers names with '/'.
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        attempts = max(1, max_attempts)
        for attempt in range(attempts):
            if not os.path.exists(save_path):
                _download_file(url, save_path, desc=f'Downloading {load_from_location}.pt')
            if expected_hash is None:
                break
            file_hash = _md5_of_file(save_path)
            if file_hash == expected_hash.lower():
                break
            # Really delete the bad file. Otherwise the exists() check above
            # skips the re-download and the retry can never succeed.
            os.remove(save_path)
            if attempt + 1 == attempts:
                raise RuntimeError(f'Hash mismatch for {save_path} after {attempts} attempts. Expected {expected_hash} but got {file_hash}.')
            print(f'Hash mismatch for {save_path}. Expected {expected_hash} but got {file_hash}. Deleting checkpoint and trying again.')
    if T.distributed.is_available() and T.distributed.is_initialized():
        T.distributed.barrier()  # so that ranks don't try to load checkpoint before it's finished downloading
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects,
    # which means running code embedded in the file. Only load from trusted
    # sources; MD5 checks integrity and does not protect against a malicious
    # file.
    return T.load(save_path, weights_only=False, map_location='cpu')