jbilcke-hf's picture
jbilcke-hf HF Staff
Upload 76 files
260ff53 verified
#! /usr/bin/env python
import argparse
from functools import partial
import json
from pathlib import Path
import shlex
import subprocess
import sys
from typing import Optional
def main() -> None:
    """Interactively import one training run from a remote host.

    The script lists ``<rootdir>/outputs`` on the remote host over ssh, lets
    the user pick a date directory and then a run directory inside it, and
    copies that run with ``rsync`` into ``<host>/<date>/`` relative to the
    local working directory. Finally the local path of the imported run is
    printed and, when possible, piped to ``xclip``.

    Command line:
        host       remote host name (positional).
        --user     optional remote user; when given, ``user@host`` is used.
        --rootdir  optional remote project directory. When omitted, the name
                   of the local git repository containing this file is
                   searched for on the remote host with ``find``.
        -v         ask what to download instead of using the defaults (only
                   the last checkpoint, no train/test dataset).

    Raises:
        RuntimeError: if this file is not inside a git repository, if the
            remote project directory cannot be determined unambiguously, or
            if a remote directory that must offer a choice is empty.
        subprocess.CalledProcessError: if an ssh or rsync command fails.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("host", type=str)
    parser.add_argument("-v", "--verbose", action="store_true")
    # ``type`` has to be a callable that converts the raw string.
    # ``Optional[str]`` is a typing construct, not a converter, so argparse
    # could never accept a value for these options.
    parser.add_argument("--user", type=str, default=None)
    parser.add_argument("--rootdir", type=str, default=None)
    args = parser.parse_args()

    host = args.host if args.user is None else f"{args.user}@{args.host}"

    def run_remote_cmd(cmd: str) -> str:
        """Run ``cmd`` in the remote shell over ssh and return its stdout."""
        # An argument list (no local shell) means only the remote shell
        # interprets ``cmd``; callers quote remote paths with shlex.quote.
        return subprocess.check_output(["ssh", host, cmd], text=True)

    def ls(p: str) -> list:
        """List remote directory ``p`` in reverse ``ls`` order."""
        out = run_remote_cmd(f"ls {shlex.quote(p)}")
        # Drop empty strings so an empty directory yields [] rather than [""].
        entries = [entry for entry in out.strip().split("\n") if entry]
        return entries[::-1]

    def ask(choices: list, info: Optional[dict] = None):
        """Print a numbered menu of ``choices`` and return the selected one.

        When ``info`` is given, ``info[choice]`` is shown next to each entry.
        """
        if not choices:
            # Without this guard the prompt below could never be satisfied.
            raise RuntimeError("Nothing to choose from")
        width = len(str(len(choices)))
        for idx, choice in enumerate(choices, 1):
            extra = f" ({info[choice]})" if info is not None else ""
            print(f"{idx:{width}d}: {choice}{extra}")
        while True:
            answer = input("\nEnter a number: ").strip()
            # isdecimal() only accepts characters int() can parse, unlike
            # isdigit(), which also accepts e.g. superscript digits.
            if answer.isdecimal() and 1 <= int(answer) <= len(choices):
                return choices[int(answer) - 1]
            print("\n/!\\ Invalid choice\n")

    def ask_if_verbose(question: str, default: bool) -> bool:
        """Return ``default`` unless --verbose is set, else ask a yes/no question."""
        if not args.verbose:
            return default
        suffix = "[Y|n]" if default else "[y|N]"
        answer = input(f"{question} {suffix} ").strip().lower()
        # Anything but an explicit opposite answer keeps the default.
        return (answer != "n") if default else (answer == "y")

    def get_info(rundir: str):
        """Return the parsed ``info_for_import_script.json`` of a remote run."""
        info_path = f"{rundir}/checkpoints/info_for_import_script.json"
        try:
            return json.loads(run_remote_cmd(f"cat {shlex.quote(info_path)}"))
        except (subprocess.CalledProcessError, json.JSONDecodeError):
            # One run without a readable info file should not prevent the
            # other runs from being listed.
            return "info unavailable"

    if args.rootdir is None:
        # Locate the local git repository containing this file, then look for
        # a directory with the same name on the remote host.
        for p in Path(__file__).resolve().parents:
            if (p / ".git").is_dir():
                break
        else:
            raise RuntimeError("This file is not in a git repository")
        # The explicit "." start path is required by non-GNU find.
        found = run_remote_cmd(f"find . -type d -name {shlex.quote(p.name)}")
        matches = [match for match in found.strip().split("\n") if match]
        # Explicit check instead of ``assert`` (stripped under -O); it also
        # rejects the no-match case, which previously slipped through.
        if len(matches) != 1:
            raise RuntimeError(
                f"Expected exactly one remote directory named {p.name!r}, "
                f"found {len(matches)}: {matches}"
            )
        rootdir = matches[0]
    else:
        # NOTE(review): strip("/") also removes a leading slash, so an
        # absolute path becomes relative to the remote login directory --
        # confirm this is intended.
        rootdir = args.rootdir.strip().strip("/")

    dates = ls(f"{rootdir}/outputs")
    date = ask(dates)

    times = ls(f"{rootdir}/outputs/{date}")
    infos = {
        run_time: get_info(rundir=f"{rootdir}/outputs/{date}/{run_time}")
        for run_time in times
    }
    time = ask(times, infos)

    src = f"{rootdir}/outputs/{date}/{time}"
    dst = Path(args.host) / date
    dst.mkdir(exist_ok=True, parents=True)

    exclude = [
        "*.log",
        "checkpoints/*",
        "checkpoints_tmp",
        ".hydra",
        "media",
        "__pycache__",
        "wandb",
    ]
    include = ["checkpoints/agent_versions"]

    if ask_if_verbose("Download only last checkpoint?", default=True):
        checkpoints = ls(f"{src}/checkpoints/agent_versions")
        if not checkpoints:
            raise RuntimeError(f"No checkpoint found in {src}/checkpoints/agent_versions")
        # ls() returns entries in reverse order, so the first one is the last
        # checkpoint in ``ls`` ordering.
        last_ckpt = checkpoints[0]
        exclude.append("checkpoints/agent_versions/*")
        include.append(f"checkpoints/agent_versions/{last_ckpt}")

    if not ask_if_verbose("Download train dataset?", default=False):
        exclude.append("dataset/train")

    if not ask_if_verbose("Download test dataset?", default=False):
        exclude.append("dataset/test")

    # Includes are emitted before excludes, as in the original command line.
    rsync_cmd = ["rsync", "-av"]
    rsync_cmd += [f"--include={pattern}" for pattern in include]
    rsync_cmd += [f"--exclude={pattern}" for pattern in exclude]
    rsync_cmd += [f"{host}:{src}", str(dst)]
    subprocess.run(rsync_cmd, check=True)

    path = (dst / time).absolute()
    print(f"\n--> Run imported in:\n{path}")

    # Handing the path to xclip is a convenience only: a missing or failing
    # xclip must not turn a successful import into an error.
    try:
        subprocess.run(["xclip"], input=f"{path}\n", text=True, check=True)
    except (OSError, subprocess.CalledProcessError) as err:
        print(f"(could not pass the path to xclip: {err})", file=sys.stderr)
# Run the interactive import only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()