Hamnivore commited on
Commit
5a56341
·
1 Parent(s): 2c37566

Convert to runtime weight downloading - Remove LFS tracked weights, add download functionality

Browse files
.gitattributes CHANGED
@@ -1,2 +1 @@
1
- weights/** filter=lfs diff=lfs merge=lfs -text
2
  examples/** filter=lfs diff=lfs merge=lfs -text
 
 
1
  examples/** filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -8,4 +8,20 @@ flagged
8
  Synthetic4Relight
9
  **/__pycache__/
10
  vis_*/
11
- src/
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  Synthetic4Relight
9
  **/__pycache__/
10
  vis_*/
11
+ src/
12
+
13
+ # Downloaded model weights (large files)
14
+ weights/*/checkpoints/*.ckpt
15
+ weights/*/checkpoints/*.bin
16
+ weights/*/checkpoints/*.safetensors
17
+
18
+ # HuggingFace cache directory
19
+ hf_cache/
20
+ .cache/
21
+
22
+ # Virtual environment (if not already excluded)
23
+ test_env/
24
+
25
+ # Temporary files
26
+ *.tmp
27
+ *.temp
app.py CHANGED
@@ -14,11 +14,72 @@ import rembg
14
  import sys
15
  from loguru import logger
16
 
 
 
 
 
17
  _SAMPLE_TAB_ID_ = 0
18
  _HIGHRES_TAB_ID_ = 1
19
  _FOREGROUND_TAB_ID_ = 2
20
 
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  def set_loggers(level):
23
  logger.remove()
24
  logger.add(sys.stderr, level=level)
@@ -92,6 +153,11 @@ description = \
92
 
93
  set_loggers("INFO")
94
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
95
  logger.info(f"Loading Models...")
96
  model_dict = {
97
  "Albedo": InferenceModel(ckpt_path="weights/albedo",
 
14
  import sys
15
  from loguru import logger
16
 
17
+ # Add imports for downloading weights
18
+ from huggingface_hub import hf_hub_download
19
+ import shutil
20
+
21
  _SAMPLE_TAB_ID_ = 0
22
  _HIGHRES_TAB_ID_ = 1
23
  _FOREGROUND_TAB_ID_ = 2
24
 
25
 
26
+ def download_model_weights():
27
+ """Download model weights from the original repository if they don't exist."""
28
+
29
+ # Original repository ID
30
+ repo_id = "LittleFrog/IntrinsicAnything"
31
+
32
+ # Define the paths and files to download
33
+ model_configs = [
34
+ {
35
+ "local_dir": "weights/albedo",
36
+ "files": [
37
+ "weights/albedo/checkpoints/last.ckpt",
38
+ "weights/albedo/configs/project.yaml"
39
+ ]
40
+ },
41
+ {
42
+ "local_dir": "weights/specular",
43
+ "files": [
44
+ "weights/specular/checkpoints/last.ckpt",
45
+ "weights/specular/configs/project.yaml"
46
+ ]
47
+ }
48
+ ]
49
+
50
+ for config in model_configs:
51
+ local_dir = config["local_dir"]
52
+
53
+ # Create directories if they don't exist
54
+ os.makedirs(f"{local_dir}/checkpoints", exist_ok=True)
55
+ os.makedirs(f"{local_dir}/configs", exist_ok=True)
56
+
57
+ for file_path in config["files"]:
58
+ local_file_path = file_path
59
+
60
+ # Check if file exists and is not a LFS pointer (> 1KB)
61
+ if not os.path.exists(local_file_path) or os.path.getsize(local_file_path) < 1024:
62
+ logger.info(f"Downloading {file_path} from HuggingFace...")
63
+
64
+ try:
65
+ # Download the file
66
+ downloaded_file = hf_hub_download(
67
+ repo_id=repo_id,
68
+ filename=file_path,
69
+ repo_type="space",
70
+ cache_dir="./hf_cache"
71
+ )
72
+
73
+ # Copy to the expected location
74
+ shutil.copy2(downloaded_file, local_file_path)
75
+ logger.info(f"Successfully downloaded {file_path}")
76
+
77
+ except Exception as e:
78
+ logger.error(f"Failed to download {file_path}: {e}")
79
+ raise e
80
+ else:
81
+ logger.info(f"{local_file_path} already exists and appears to be valid")
82
+
83
  def set_loggers(level):
84
  logger.remove()
85
  logger.add(sys.stderr, level=level)
 
153
 
154
  set_loggers("INFO")
155
  device = "cuda" if torch.cuda.is_available() else "cpu"
156
+
157
+ # Download model weights if needed
158
+ logger.info("Checking and downloading model weights if needed...")
159
+ download_model_weights()
160
+
161
  logger.info(f"Loading Models...")
162
  model_dict = {
163
  "Albedo": InferenceModel(ckpt_path="weights/albedo",
weights/albedo/checkpoints/last.ckpt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:a5fa7a1caa7e1e3818119cd9a2e8715ee7b86a77fa66447cc4b0767d8ab550f8
3
- size 15458840153
 
 
 
 
weights/specular/checkpoints/last.ckpt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:361b7c3824f4603c4137657ff2fc8a127e0972954d425b440516522f034de7a0
3
- size 15458847705