ouclxy commited on
Commit
a5c8b6d
·
verified ·
1 Parent(s): d3513d5

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +24 -7
gradio_app.py CHANGED
@@ -6,8 +6,8 @@ import base64
6
  import shutil
7
  from typing import Optional, Tuple
8
 
9
- import spaces
10
  import gradio as gr
 
11
  import torch
12
  import cv2
13
  import numpy as np
@@ -121,10 +121,18 @@ def _download_models() -> Tuple[Optional[str], Optional[str], Optional[str]]:
121
  cache_dir=cache_dir,
122
  token=HF_AUTH_TOKEN,
123
  )
 
 
 
 
 
 
 
 
124
  # Create a symlink so that imports like "from HairMapper..." work
125
- hairmapper_dir = _ensure_symlink(hm_snap, os.path.abspath("HairMapper"))
126
  if hairmapper_dir not in sys.path:
127
- sys.path.insert(0, hairmapper_dir)
128
 
129
  # 3) FFHQFaceAlignment
130
  ffhq_dir = None
@@ -135,10 +143,18 @@ def _download_models() -> Tuple[Optional[str], Optional[str], Optional[str]]:
135
  cache_dir=cache_dir,
136
  token=HF_AUTH_TOKEN,
137
  )
138
- # Create a symlink so that test_stablehairv2._maybe_align_image("./FFHQFaceAlignment") resolves
139
- ffhq_dir = _ensure_symlink(fa_snap, os.path.abspath("FFHQFaceAlignment"))
 
 
 
 
 
 
 
 
140
  if ffhq_dir not in sys.path:
141
- sys.path.insert(0, ffhq_dir)
142
 
143
  # 4) Optional: Trained model weights (motion/control/ref)
144
  if TRAINED_MODEL_REPO:
@@ -186,9 +202,10 @@ SD15_PATH, _, _ = _download_models()
186
  # -----------------------------------------------------------------------------
187
  # Gradio inference
188
  # -----------------------------------------------------------------------------
189
- with open("imgs/background.png", "rb") as f:
190
  _b64_bg = base64.b64encode(f.read()).decode()
191
 
 
192
  @spaces.GPU
193
  def inference(id_image, hair_image):
194
  # Require GPU (HairMapper currently uses CUDA explicitly)
 
6
  import shutil
7
  from typing import Optional, Tuple
8
 
 
9
  import gradio as gr
10
+ import spaces
11
  import torch
12
  import cv2
13
  import numpy as np
 
121
  cache_dir=cache_dir,
122
  token=HF_AUTH_TOKEN,
123
  )
124
+ # If repo root contains a nested "HairMapper" folder, link to that subfolder.
125
+ hm_src = hm_snap
126
+ nested_hm = os.path.join(hm_snap, "HairMapper")
127
+ if os.path.isdir(nested_hm) and (
128
+ os.path.isfile(os.path.join(nested_hm, "hair_mapper_run.py")) or
129
+ os.path.isdir(os.path.join(nested_hm, "mapper"))
130
+ ):
131
+ hm_src = nested_hm
132
  # Create a symlink so that imports like "from HairMapper..." work
133
+ hairmapper_dir = _ensure_symlink(hm_src, os.path.abspath("HairMapper"))
134
  if hairmapper_dir not in sys.path:
135
+ sys.path.insert(0, os.path.dirname(hairmapper_dir))
136
 
137
  # 3) FFHQFaceAlignment
138
  ffhq_dir = None
 
143
  cache_dir=cache_dir,
144
  token=HF_AUTH_TOKEN,
145
  )
146
+ # If repo root contains a nested "FFHQFaceAlignment" folder, link to that subfolder.
147
+ fa_src = fa_snap
148
+ nested_fa = os.path.join(fa_snap, "FFHQFaceAlignment")
149
+ if os.path.isdir(nested_fa) and (
150
+ os.path.isfile(os.path.join(nested_fa, "align.py")) or
151
+ os.path.isdir(os.path.join(nested_fa, "lib"))
152
+ ):
153
+ fa_src = nested_fa
154
+ # Create a symlink so that _maybe_align_image can import modules relatively
155
+ ffhq_dir = _ensure_symlink(fa_src, os.path.abspath("FFHQFaceAlignment"))
156
  if ffhq_dir not in sys.path:
157
+ sys.path.insert(0, os.path.dirname(ffhq_dir))
158
 
159
  # 4) Optional: Trained model weights (motion/control/ref)
160
  if TRAINED_MODEL_REPO:
 
202
  # -----------------------------------------------------------------------------
203
  # Gradio inference
204
  # -----------------------------------------------------------------------------
205
+ with open("imgs/background.jpg", "rb") as f:
206
  _b64_bg = base64.b64encode(f.read()).decode()
207
 
208
+
209
  @spaces.GPU
210
  def inference(id_image, hair_image):
211
  # Require GPU (HairMapper currently uses CUDA explicitly)