Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,920 Bytes
476e0f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
import os
import argparse
from huggingface_hub import snapshot_download
# Every `--model_type` accepted on the command line (single source of truth for
# both the argparse `choices` and the error message in `_allow_patterns`).
_MODEL_TYPES = ("sd15", "pas", "sd35m", "depth", "normal", "canny", "elevest")
# Model types whose DiffSplat checkpoint folder has an image-conditioned
# (`_image`) variant selected by `--image_cond`.
_IMAGE_COND_TYPES = ("sd15", "pas", "sd35m")
# DiffSplat ControlNet (SD1.5) condition types.
_CONTROLNET_TYPES = ("depth", "normal", "canny")


def _allow_patterns(model_type, image_cond):
    """Return the repository folder globs to download for ``model_type``.

    Args:
        model_type: One of ``_MODEL_TYPES``.
        image_cond: If true, select the image-conditioned (``_image``) DiffSplat
            checkpoint. Only affects the types in ``_IMAGE_COND_TYPES``.

    Returns:
        A list of ``allow_patterns`` globs for ``snapshot_download``.

    Raises:
        ValueError: If ``model_type`` is not in ``_MODEL_TYPES``.
    """
    suffix = "_image" if image_cond else ""
    gsrecon = "gsrecon_gobj265k_cnp_even4/*"  # `GSRecon`

    # DiffSplat (SD1.5)
    if model_type == "sd15":
        return [
            gsrecon,
            "gsvae_gobj265k_sd/*",  # `GSVAE (SD)`
            f"gsdiff_gobj83k_sd15{suffix}__render/*",  # `DiffSplat (SD)`
        ]
    # DiffSplat (PixArt-Sigma)
    if model_type == "pas":
        return [
            gsrecon,
            "gsvae_gobj265k_sdxl_fp16/*",  # `GSVAE (SDXL)`
            f"gsdiff_gobj83k_pas_fp16{suffix}__render/*",  # `DiffSplat (PixArt-Sigma)`
        ]
    # DiffSplat (SD3.5m)
    if model_type == "sd35m":
        return [
            gsrecon,
            "gsvae_gobj265k_sd3/*",  # `GSVAE (SD3)`
            f"gsdiff_gobj83k_sd35m{suffix}__render/*",  # `DiffSplat (SD3.5m)`
        ]
    # DiffSplat ControlNet (SD1.5): only the ControlNet folder itself is fetched.
    if model_type in _CONTROLNET_TYPES:
        return [f"gsdiff_gobj83k_sd15__render__{model_type}/*"]
    # Elevation Estimation
    if model_type == "elevest":
        return ["elevest_gobj265k_b_C25/*"]
    raise ValueError(f"Choose from {list(_MODEL_TYPES)}, but got [{model_type}]")


def download_ckpt(argv=None):
    """Download DiffSplat checkpoints from the HuggingFace Hub.

    Parses command-line flags, then calls ``snapshot_download`` restricted to
    the repository folders needed by the selected ``--model_type``.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. The default
            (``None``) keeps the plain command-line behavior; passing a list
            lets the downloader be driven programmatically.

    Raises:
        SystemExit: Raised by argparse for invalid arguments or ``--help``.
    """
    parser = argparse.ArgumentParser(description="Download checkpoints from HuggingFace Hub")
    parser.add_argument(
        "--local_dir",
        type=str,
        default="./out",
        help="Local directory to save the checkpoints",
    )
    parser.add_argument(
        "--model_type",
        type=str,
        default="sd15",
        choices=list(_MODEL_TYPES),
        help="Model type to download",
    )
    parser.add_argument(
        "--image_cond",
        action="store_true",
        help="Whether to download image-conditioned models",
    )
    parser.add_argument(
        "--repo_id",
        type=str,
        default="chenguolin/DiffSplat",
        help="HuggingFace Hub repository to download from",
    )
    args = parser.parse_args(argv)

    if args.image_cond and args.model_type not in _IMAGE_COND_TYPES:
        # `--image_cond` only changes the downloaded folder for
        # `_IMAGE_COND_TYPES`; tell the user instead of ignoring it silently.
        import warnings

        warnings.warn(
            f"--image_cond has no effect for --model_type {args.model_type}; "
            f"it only applies to {list(_IMAGE_COND_TYPES)}"
        )

    # Resolve the patterns before touching the filesystem so an invalid model
    # type never leaves an empty output directory behind.
    allow_patterns = _allow_patterns(args.model_type, args.image_cond)
    os.makedirs(args.local_dir, exist_ok=True)
    snapshot_download(
        repo_id=args.repo_id,
        local_dir=args.local_dir,
        allow_patterns=allow_patterns,
    )
def _main():
    """Run the checkpoint downloader as a command-line script."""
    download_ckpt()


if __name__ == "__main__":
    _main()
|