#! /usr/bin/env python
"""Interactively import a run directory from a remote host using ssh and rsync."""

import argparse
from functools import partial
import json
from pathlib import Path
import subprocess
from typing import Dict, List, Optional


def _build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this script.

    ``--user`` and ``--rootdir`` are plain optional strings that default to
    ``None``. ``typing.Optional[str]`` must not be used as ``type=`` because
    argparse calls ``type`` on the raw string and ``Optional[str]`` cannot be
    called that way, which makes every supplied value fail to parse.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("host", type=str)
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--user", type=str, default=None)
    parser.add_argument("--rootdir", type=str, default=None)
    return parser


def _split_lines(out: str) -> List[str]:
    """Split command output into lines; blank output gives an empty list.

    ``"".split("\\n")`` would give ``[""]``, which looks like one (empty)
    result, so ``splitlines`` is used instead.
    """
    return out.strip().splitlines()


def _parse_choice(raw: str, n_choices: int) -> Optional[int]:
    """Convert a 1-based menu answer into a 0-based index.

    Returns ``None`` when ``raw`` is not a decimal number in
    ``1..n_choices``. ``isdecimal`` is used rather than ``isdigit`` because
    the latter accepts characters (e.g. superscripts) that ``int`` rejects.
    """
    raw = raw.strip()
    if raw.isdecimal() and 1 <= int(raw) <= n_choices:
        return int(raw) - 1
    return None


def _format_choices(items: List[str], info: Optional[Dict] = None) -> str:
    """Render ``items`` as a numbered menu, one entry per line.

    Numbers are right-aligned to the width of the largest index. When
    ``info`` is given, ``info[item]`` is appended in parentheses.
    """
    width = len(str(len(items)))
    lines = []
    for number, item in enumerate(items, 1):
        line = f"{number:{width}d}: {item}"
        if info is not None:
            line += f" ({info[item]})"
        lines.append(line)
    return "\n".join(lines)


def _build_rsync_cmd(
    host: str, src: str, dst, include: List[str], exclude: List[str]
) -> List[str]:
    """Build the rsync argument list; include rules precede exclude rules.

    The command is an argv list (no local shell), so patterns such as
    ``checkpoints/*`` need no quoting to reach rsync unexpanded.
    """
    cmd = ["rsync", "-av"]
    cmd.extend(f"--include={pattern}" for pattern in include)
    cmd.extend(f"--exclude={pattern}" for pattern in exclude)
    cmd.extend([f"{host}:{src}", str(dst)])
    return cmd


def main() -> None:
    """Pick a run on the remote host interactively and rsync it locally.

    The run is chosen as ``<rootdir>/outputs/<date>/<time>`` and copied into
    ``<host>/<date>/`` under the current working directory. The resulting
    absolute path is printed and, if possible, sent to ``xclip``.

    Raises:
        RuntimeError: if the remote root directory cannot be determined
            unambiguously, or a remote listing that must be non-empty is empty.
        subprocess.CalledProcessError: if an ssh or rsync command fails.
    """
    args = _build_parser().parse_args()

    run = partial(subprocess.run, check=True, text=True)

    host = args.host if args.user is None else f"{args.user}@{args.host}"

    def run_remote_cmd(cmd: str) -> str:
        # ssh hands `cmd` to the remote shell; no local shell is involved.
        return subprocess.check_output(["ssh", host, cmd], text=True)

    def ls(p: str) -> List[str]:
        # Reverse the listing so the last entry in `ls` order comes first.
        return _split_lines(run_remote_cmd(f"ls {p}"))[::-1]

    def ask(choices: List[str], info: Optional[Dict] = None) -> str:
        if not choices:
            # Without this guard the prompt loop below could never terminate.
            raise RuntimeError("Nothing to choose from: the remote listing is empty")
        print(_format_choices(choices, info))
        while True:
            index = _parse_choice(input("\nEnter a number: "), len(choices))
            if index is not None:
                return choices[index]
            print("\n/!\\ Invalid choice\n")

    def ask_if_verbose(question: str, default: bool) -> bool:
        # Questions are only asked with --verbose; otherwise take the default.
        if not args.verbose:
            return default
        suffix = "[Y|n]" if default else "[y|N]"
        answer = input(f"{question} {suffix} ").strip().lower()
        return (answer != "n") if default else (answer == "y")

    def get_info(rundir: str):
        return json.loads(
            run_remote_cmd(f"cat {rundir}/checkpoints/info_for_import_script.json")
        )

    if args.rootdir is None:
        # Use the name of the enclosing local git repository to locate the
        # matching directory on the remote host.
        for p in Path(__file__).resolve().parents:
            if (p / ".git").is_dir():
                break
        else:
            raise RuntimeError("This file is not in a git repository")
        matches = _split_lines(run_remote_cmd(f"find -type d -name {p.name}"))
        if len(matches) != 1:
            raise RuntimeError(
                f"Expected exactly one remote directory named {p.name!r}, "
                f"found {len(matches)}; pass --rootdir explicitly"
            )
        rootdir = matches[0]
    else:
        # Only drop trailing slashes: stripping the leading one would turn an
        # absolute path into a relative one.
        rootdir = args.rootdir.strip().rstrip("/")

    dates = ls(f"{rootdir}/outputs")
    date = ask(dates)

    times = ls(f"{rootdir}/outputs/{date}")
    infos = {t: get_info(rundir=f"{rootdir}/outputs/{date}/{t}") for t in times}
    run_time = ask(times, infos)

    src = f"{rootdir}/outputs/{date}/{run_time}"
    dst = Path(args.host) / date
    dst.mkdir(exist_ok=True, parents=True)

    exclude = [
        "*.log",
        "checkpoints/*",
        "checkpoints_tmp",
        ".hydra",
        "media",
        "__pycache__",
        "wandb",
    ]
    include = ["checkpoints/agent_versions"]

    if ask_if_verbose("Download only last checkpoint?", default=True):
        versions = ls(f"{src}/checkpoints/agent_versions")
        if not versions:
            raise RuntimeError(f"No checkpoint found in {src}/checkpoints/agent_versions")
        last_ckpt = versions[0]
        exclude.append("checkpoints/agent_versions/*")
        include.append(f"checkpoints/agent_versions/{last_ckpt}")

    if not ask_if_verbose("Download train dataset?", default=False):
        exclude.append("dataset/train")

    if not ask_if_verbose("Download test dataset?", default=False):
        exclude.append("dataset/test")

    run(_build_rsync_cmd(host, src, dst, include, exclude))

    path = (dst / run_time).absolute()
    print(f"\n--> Run imported in:\n{path}")

    # Copying to the clipboard is a convenience; the import already succeeded,
    # so a missing or failing xclip must not abort the script.
    try:
        run(["xclip"], input=f"{path}\n")
    except (FileNotFoundError, subprocess.CalledProcessError) as err:
        print(f"(could not copy the path to the clipboard: {err})")


if __name__ == "__main__":
    main()