jbilcke-hf HF Staff Claude commited on
Commit
7f51853
·
1 Parent(s): a046781

Make weights path configurable via WEIGHTS_PATH environment variable

Browse files

- Add WEIGHTS_PATH constant that reads from environment variable with /data/weights/ as fallback
- Update MODEL_BASE to use configurable weights path
- Update model checkpoint path in create_args() function
- Update model download logic to use configurable path instead of hardcoded ./weights/

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>

Files changed (1) hide show
  1. app.py +9 -5
app.py CHANGED
@@ -15,7 +15,10 @@ from hymm_sp.data_kits.data_tools import save_videos_grid
15
  from hymm_sp.config import parse_args
16
  import argparse
17
 
18
- os.environ["MODEL_BASE"] = "weights/stdmodels"
 
 
 
19
  os.environ["DISABLE_SP"] = "1"
20
  os.environ["CPU_OFFLOAD"] = "1"
21
 
@@ -42,7 +45,7 @@ class CropResize:
42
 
43
  def create_args():
44
  args = argparse.Namespace()
45
- args.ckpt = "weights/gamecraft_models/mp_rank_00_model_states_distill.pt"
46
  args.video_size = [704, 1216]
47
  args.cfg_scale = 1.0
48
  args.image_start = True
@@ -65,13 +68,14 @@ def create_args():
65
 
66
  logger.info("Initializing Hunyuan-GameCraft model...")
67
 
68
- if not os.path.exists("weights/gamecraft_models/mp_rank_00_model_states_distill.pt"):
 
69
  logger.info("Downloading model weights from Hugging Face...")
70
- os.makedirs("weights/gamecraft_models", exist_ok=True)
71
  hf_hub_download(
72
  repo_id="tencent/Hunyuan-GameCraft-1.0",
73
  filename="gamecraft_models/mp_rank_00_model_states_distill.pt",
74
- local_dir="weights/",
75
  local_dir_use_symlinks=False
76
  )
77
 
 
15
  from hymm_sp.config import parse_args
16
  import argparse
17
 
18
+ # Get weights path from environment variable or use default
19
+ WEIGHTS_PATH = os.environ.get("WEIGHTS_PATH", "/data/weights")
20
+
21
+ os.environ["MODEL_BASE"] = os.path.join(WEIGHTS_PATH, "stdmodels")
22
  os.environ["DISABLE_SP"] = "1"
23
  os.environ["CPU_OFFLOAD"] = "1"
24
 
 
45
 
46
  def create_args():
47
  args = argparse.Namespace()
48
+ args.ckpt = os.path.join(WEIGHTS_PATH, "gamecraft_models/mp_rank_00_model_states_distill.pt")
49
  args.video_size = [704, 1216]
50
  args.cfg_scale = 1.0
51
  args.image_start = True
 
68
 
69
  logger.info("Initializing Hunyuan-GameCraft model...")
70
 
71
+ model_path = os.path.join(WEIGHTS_PATH, "gamecraft_models/mp_rank_00_model_states_distill.pt")
72
+ if not os.path.exists(model_path):
73
  logger.info("Downloading model weights from Hugging Face...")
74
+ os.makedirs(os.path.join(WEIGHTS_PATH, "gamecraft_models"), exist_ok=True)
75
  hf_hub_download(
76
  repo_id="tencent/Hunyuan-GameCraft-1.0",
77
  filename="gamecraft_models/mp_rank_00_model_states_distill.pt",
78
+ local_dir=WEIGHTS_PATH,
79
  local_dir_use_symlinks=False
80
  )
81