|
|
|
|
|
import argparse |
|
import math |
|
import os |
|
import os.path as osp |
|
from multiprocessing import Pool |
|
|
|
import torch |
|
from mmengine.config import Config |
|
from mmengine.utils import mkdir_or_exist |
|
|
|
|
|
def download(url, out_file, min_bytes=math.pow(1024, 2), progress=True):
    """Download ``url`` to ``out_file``, retrying once with ``curl``.

    The file is first fetched with ``torch.hub.download_url_to_file``. If
    that raises, or the result is smaller than ``min_bytes``, whatever was
    written is deleted and the transfer is attempted again with the ``curl``
    command-line tool. Download failures are printed rather than raised, so
    a single bad URL does not abort a batch of downloads.

    Args:
        url (str): Address of the file to fetch.
        out_file (str): Destination path of the downloaded file.
        min_bytes (float): Smallest size in bytes accepted as a complete
            download. Defaults to ``math.pow(1024, 2)`` (1 MiB).
        progress (bool): Forwarded to ``torch.hub.download_url_to_file``.
            Defaults to True.
    """
    # Local import: only the curl fallback below needs it.
    import subprocess

    out_file = str(out_file)
    assert_msg = f"Downloaded url '{url}' does not exist " \
                 f'or size is < min_bytes={min_bytes}'

    def _is_complete():
        # A single threshold for both the post-download check and the final
        # cleanup, so a file of exactly ``min_bytes`` is treated the same
        # way in both places (the message above says "< min_bytes" fails).
        return osp.exists(out_file) and osp.getsize(out_file) >= min_bytes

    try:
        print(f'Downloading {url} to {out_file}...')
        torch.hub.download_url_to_file(url, out_file, progress=progress)
        # Explicit raise instead of ``assert``: assertions are stripped
        # under ``python -O``, which would silently skip the retry.
        if not _is_complete():
            raise RuntimeError(assert_msg)
    except Exception as e:
        if osp.exists(out_file):
            os.remove(out_file)
        print(f'ERROR: {e}\nRe-attempting {url} to {out_file} ...')
        try:
            # Argument list (no shell), so quotes or other shell
            # metacharacters in the URL or path cannot alter the command.
            # ``--retry 3 -C -`` retries and resumes on failure.
            subprocess.run(
                ['curl', '-L', url, '-o', out_file, '--retry', '3', '-C', '-'],
                check=False)
        except OSError as curl_error:
            # curl is missing or not executable; keep this best-effort.
            print(f'ERROR: failed to run curl: {curl_error}')
    finally:
        # Remove partial downloads.
        if osp.exists(out_file) and not _is_complete():
            os.remove(out_file)

        if not osp.exists(out_file):
            print(f'ERROR: {assert_msg}\n')
        print('=========================================\n')
|
|
|
|
|
def parse_args():
    """Parse the command-line arguments of the checkpoint download script.

    Returns:
        argparse.Namespace: Namespace with the attributes ``config``,
        ``out``, ``nproc`` and ``intranet``.
    """
    cli = argparse.ArgumentParser(description='Download checkpoints')

    # (name-or-flags, keyword options) for each argument, in the order
    # they are registered on the parser.
    argument_specs = (
        (('config', ), dict(help='test config file path')),
        (('out', ),
         dict(type=str, help='output dir of checkpoints to be stored')),
        (('--nproc', ), dict(type=int, default=16, help='num of Processes')),
        (('--intranet', ),
         dict(action='store_true', help='switch to internal network url')),
    )
    for name_or_flags, options in argument_specs:
        cli.add_argument(*name_or_flags, **options)

    return cli.parse_args()
|
|
|
|
|
# Script entry point: collect the checkpoints listed in the config that are
# not yet present in the output directory and download them in parallel.
if __name__ == '__main__':
    args = parse_args()
    mkdir_or_exist(args.out)

    cfg = Config.fromfile(args.config)

    # (url, out_file) pairs for every checkpoint that still has to be
    # fetched; each pair is unpacked into ``download(url, out_file)``.
    pending_downloads = []

    for model in cfg:
        model_infos = cfg[model]
        # A config entry may hold a single model info or a list of them.
        if not isinstance(model_infos, list):
            model_infos = [model_infos]
        for model_info in model_infos:
            out_file = osp.join(args.out, model_info['checkpoint'])
            if osp.exists(out_file):
                # Already on disk; nothing to do for this checkpoint.
                continue

            url = model_info['url']
            if args.intranet:
                # Rewrite the public URL to the internal-network mirror.
                url = url.replace('.com', '.sensetime.com')
                url = url.replace('https', 'http')

            pending_downloads.append((url, out_file))

    if pending_downloads:
        # os.cpu_count() may return None when the count is undetermined;
        # fall back to a single worker instead of failing in min().
        num_workers = min(os.cpu_count() or 1, args.nproc)
        # The context manager releases the worker processes once the
        # (blocking) starmap call has returned.
        with Pool(num_workers) as pool:
            pool.starmap(download, pending_downloads)
    else:
        print('No files to download!')
|
|