Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import argparse | |
from huggingface_hub import snapshot_download | |
# HuggingFace Hub repository that hosts every checkpoint handled by this script.
_REPO_ID = "chenguolin/DiffSplat"

# `GSRecon` checkpoint folder, downloaded together with each base model below.
_GSRECON_PATTERN = "gsrecon_gobj265k_cnp_even4/*"

# Base generative models: model type -> (`GSVAE` glob, `DiffSplat` glob template).
# `{suffix}` is filled with "_image" for image-conditioned checkpoints, else "".
_BASE_MODEL_PATTERNS = {
    # DiffSplat (SD1.5)
    "sd15": ("gsvae_gobj265k_sd/*", "gsdiff_gobj83k_sd15{suffix}__render/*"),
    # DiffSplat (PixArt-Sigma)
    "pas": ("gsvae_gobj265k_sdxl_fp16/*", "gsdiff_gobj83k_pas_fp16{suffix}__render/*"),
    # DiffSplat (SD3.5m)
    "sd35m": ("gsvae_gobj265k_sd3/*", "gsdiff_gobj83k_sd35m{suffix}__render/*"),
}

# DiffSplat ControlNet (SD1.5) variants; each maps to its own folder.
_CONTROLNET_TYPES = ("depth", "normal", "canny")

# Elevation estimation checkpoint folder.
_ELEVEST_PATTERN = "elevest_gobj265k_b_C25/*"

# Single source of truth for the accepted `--model_type` values
# (same order as the original CLI choices).
_MODEL_TYPES = (*_BASE_MODEL_PATTERNS, *_CONTROLNET_TYPES, "elevest")


def _allow_patterns(model_type, image_cond=False):
    """Return the `allow_patterns` globs selecting the files for one model type.

    Args:
        model_type: One of the values in `_MODEL_TYPES`.
        image_cond: If true, select the image-conditioned `DiffSplat` folder.
            Only the base models ("sd15", "pas", "sd35m") have such a variant
            here; the flag is ignored for the other model types.

    Returns:
        A list of glob strings to pass to `snapshot_download(allow_patterns=...)`.

    Raises:
        ValueError: If `model_type` is not a known model type.
    """
    if model_type in _BASE_MODEL_PATTERNS:
        suffix = "_image" if image_cond else ""
        gsvae_pattern, gsdiff_template = _BASE_MODEL_PATTERNS[model_type]
        return [
            _GSRECON_PATTERN,  # `GSRecon`
            gsvae_pattern,  # `GSVAE`
            gsdiff_template.format(suffix=suffix),  # `DiffSplat`
        ]
    if model_type in _CONTROLNET_TYPES:
        # `DiffSplat ControlNet (SD1.5)`
        return [f"gsdiff_gobj83k_sd15__render__{model_type}/*"]
    if model_type == "elevest":
        return [_ELEVEST_PATTERN]
    raise ValueError(f"Choose from {list(_MODEL_TYPES)}, but got [{model_type}]")


def download_ckpt():
    """Parse command-line options and download the matching checkpoints.

    Command-line options:
        --local_dir: Directory to save the checkpoints into (default "./out");
            created if it does not exist.
        --model_type: Which model's checkpoints to fetch (default "sd15").
        --image_cond: Fetch the image-conditioned variant of a base model.

    Only the folders of the Hub repository that match the selected model type
    are downloaded.
    """
    parser = argparse.ArgumentParser(description="Download checkpoints from HuggingFace Hub")
    parser.add_argument(
        "--local_dir",
        type=str,
        default="./out",
        help="Local directory to save the checkpoints",
    )
    parser.add_argument(
        "--model_type",
        type=str,
        default="sd15",
        choices=list(_MODEL_TYPES),
        help="Model type to download",
    )
    parser.add_argument(
        "--image_cond",
        action="store_true",
        help="Whether to download image-conditioned models",
    )
    args = parser.parse_args()

    # The flag only changes the folder name for base models; tell the user
    # instead of silently ignoring it for the other model types.
    if args.image_cond and args.model_type not in _BASE_MODEL_PATTERNS:
        print(f"[WARNING] `--image_cond` has no effect for model type [{args.model_type}]")

    os.makedirs(args.local_dir, exist_ok=True)
    snapshot_download(
        repo_id=_REPO_ID,
        local_dir=args.local_dir,
        allow_patterns=_allow_patterns(args.model_type, args.image_cond),
    )
if __name__ == "__main__":
    # Run the downloader only when executed as a script, not when imported.
    download_ckpt()