ouclxy committed
Commit af8f9f7 · verified · 1 Parent(s): eae8684

Update gradio_app.py

Files changed (1):
  gradio_app.py +226 -116
gradio_app.py CHANGED
@@ -1,38 +1,190 @@
  import os


- # Replace the original directory setup
  os.environ.setdefault("GRADIO_TEMP_DIR", "/tmp/gradio")
  os.environ.setdefault("TMPDIR", "/tmp")
- os.makedirs("/tmp/gradio", exist_ok=True)
- os.makedirs("/tmp", exist_ok=True)

- # Also change the output directories to relative paths
- os.makedirs("gradio_inputs", exist_ok=True)
- os.makedirs("gradio_outputs", exist_ok=True)


- import logging
- import gradio as gr
- import torch
- import os
- import uuid
- from test_stablehairv2 import log_validation
- from test_stablehairv2 import UNet3DConditionModel, ControlNetModel, CCProjection
- from test_stablehairv2 import AutoTokenizer, CLIPVisionModelWithProjection, AutoencoderKL, UNet2DConditionModel
- from omegaconf import OmegaConf
- import numpy as np
- import cv2
- from test_stablehairv2 import _maybe_align_image
- from HairMapper.hair_mapper_run import bald_head
-
- import base64

  with open("imgs/background.jpg", "rb") as f:
-     b64_img = base64.b64encode(f.read()).decode()


  def inference(id_image, hair_image):
      os.makedirs("gradio_inputs", exist_ok=True)
      os.makedirs("gradio_outputs", exist_ok=True)

@@ -41,40 +193,46 @@ def inference(id_image, hair_image):
      id_image.save(id_path)
      hair_image.save(hair_path)

-     # ===== Image alignment =====
      aligned_id = _maybe_align_image(id_path, output_size=1024, prefer_cuda=True)
      aligned_hair = _maybe_align_image(hair_path, output_size=1024, prefer_cuda=True)

-     # Save the aligned results (handy for the Gradio outputs)
      aligned_id_path = "gradio_outputs/aligned_id.png"
      aligned_hair_path = "gradio_outputs/aligned_hair.png"
      cv2.imwrite(aligned_id_path, cv2.cvtColor(aligned_id, cv2.COLOR_RGB2BGR))
      cv2.imwrite(aligned_hair_path, cv2.cvtColor(aligned_hair, cv2.COLOR_RGB2BGR))

-     # ===== Bald the head with HairMapper =====
      bald_id_path = "gradio_outputs/bald_id.png"
      cv2.imwrite(bald_id_path, cv2.cvtColor(aligned_id, cv2.COLOR_RGB2BGR))
      bald_head(bald_id_path, bald_id_path)

-     # ===== Original Args =====
      class Args:
-         pretrained_model_name_or_path = "./stable-diffusion-v1-5/stable-diffusion-v1-5"
-         model_path = "./trained_model"
          image_encoder = "openai/clip-vit-large-patch14"
          controlnet_model_name_or_path = None
          revision = None
          output_dir = "gradio_outputs"
          seed = 42
          num_validation_images = 1
-         validation_ids = [aligned_id_path]  # use the aligned image
-         validation_hairs = [aligned_hair_path]  # use the aligned image
          use_fp16 = False

      args = Args()

      device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

-     # Initialize the logger
      logging.basicConfig(
          format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
          datefmt="%m/%d/%Y %H:%M:%S",
@@ -82,15 +240,17 @@
      )
      logger = logging.getLogger(__name__)

-     # ===== Model loading (mirrors main()) =====
      tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer",
                                                revision=args.revision)
      image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder, revision=args.revision).to(device)
-     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision).to(
-         device, dtype=torch.float32)

      infer_config = OmegaConf.load('./configs/inference/inference_v2.yaml')

      unet2 = UNet2DConditionModel.from_pretrained(
          args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, torch_dtype=torch.float32
      ).to(device)
@@ -126,10 +286,10 @@
          args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False,
          device_map=None, ignore_mismatched_sizes=True
      ).to(device)
-     state_dict2 = torch.load(os.path.join(args.model_path, "pytorch_model_2.bin"), map_location="cpu")
-     Hair_Encoder.load_state_dict(state_dict2, strict=False)

-     # Inference
      log_validation(
          vae, tokenizer, image_encoder, denoising_unet,
          args, device, logger,
@@ -138,7 +298,7 @@

      output_video = os.path.join(args.output_dir, "validation", "generated_video_0.mp4")

-     # Extract video frames for a draggable preview
      frames_dir = os.path.join(args.output_dir, "frames", uuid.uuid4().hex)
      os.makedirs(frames_dir, exist_ok=True)
      cap = cv2.VideoCapture(output_video)
@@ -157,34 +317,21 @@
      max_frames = len(frames_list) if frames_list else 1
      first_frame = frames_list[0] if frames_list else None

-     return aligned_id_path, aligned_hair_path, bald_id_path, output_video, frames_list, gr.update(minimum=1,
-                                                                                                   maximum=max_frames,
-                                                                                                   value=1,
-                                                                                                   step=1), first_frame
-
-
- # Gradio frontend
- # Original Interface version (kept in case we need to roll back)
- # demo = gr.Interface(
- #     fn=inference,
- #     inputs=[
- #         gr.Image(type="pil", label="上传身份图(ID Image)"),
- #         gr.Image(type="pil", label="上传发型图(Hair Reference Image)")
- #     ],
- #     outputs=[
- #         gr.Image(type="filepath", label="对齐后的身份图"),
- #         gr.Image(type="filepath", label="对齐后的发型图"),
- #         gr.Image(type="filepath", label="秃头化后的身份图"),
- #         gr.Video(label="生成的视频")
- #     ],
- #     title="StableHairV2 多视角发型迁移",
- #     description="上传身份图和发型参考图,查看对齐结果并生成多视角视频"
- # )
- # if __name__ == "__main__":
- #     demo.launch(server_name="0.0.0.0", server_port=7860)
-
- # Polished Blocks version
- css = f"""
  html, body {{
      height: 100%;
      margin: 0;
@@ -195,10 +342,10 @@ css = f"""
      height: 100% !important;
      margin: 0 !important;
      padding: 0 !important;
-     background-image: url("data:image/jpeg;base64,{b64_img}");
      background-size: cover;
      background-position: center;
-     background-attachment: fixed;  /* keep the background fixed */
  }}
  #title-card {{
      background: rgba(255, 255, 255, 0.8);
@@ -226,7 +373,6 @@ css = f"""
  }}
  .left-pane {{min-width: 360px}}
  .right-pane {{min-width: 680px}}
- /* Tab styling */
  .tabs {{
      background: rgba(255,255,255,0.88);
      border-radius: 12px;
@@ -240,31 +386,11 @@ css = f"""
      border-bottom: 1px solid #e5e7eb;
      padding-bottom: 6px;
  }}
- .tab-nav button {{
-     background: rgba(255,255,255,0.7);
-     border: 1px solid #e5e7eb;
-     backdrop-filter: blur(6px);
-     border-radius: 8px;
-     padding: 6px 12px;
-     color: #111827;
-     transition: all .2s ease;
- }}
- .tab-nav button:hover {{
-     transform: translateY(-1px);
-     box-shadow: 0 4px 10px rgba(0,0,0,0.06);
- }}
- .tab-nav button[aria-selected="true"] {{
-     background: #4f46e5;
-     color: #fff;
-     border-color: #4f46e5;
-     box-shadow: 0 6px 14px rgba(79,70,229,0.25);
- }}
  .tabitem {{
      background: rgba(255,255,255,0.88);
      border-radius: 10px;
      padding: 8px;
  }}
- /* Hairstyle gallery scroll container: fixed 260px height, scrollable inside */
  #hair_gallery_wrap {{
      height: 260px !important;
      overflow-y: scroll !important;
@@ -274,17 +400,13 @@ css = f"""
      height: 100% !important;
      overflow-y: scroll !important;
  }}
- /* Make sure the gallery itself fills the container so the scrollbar doesn't land at the page bottom */
  #hair_gallery {{
      height: 100% !important;
  }}
  """

- with gr.Blocks(
-     theme=gr.themes.Soft(primary_hue="indigo", neutral_hue="slate"),
-     css=css
- ) as demo:
-     # ==== Top panel ====
      with gr.Group(elem_id="title-card"):
          gr.Markdown("""
  <h2 id='title'>StableHairV2 多视角发型迁移</h2>
@@ -300,13 +422,10 @@ with gr.Blocks(
          run_btn = gr.Button("开始生成", variant="primary")
          clear_btn = gr.Button("清空")

-     # ========= Hairstyle gallery (clicking fills the hair reference image) =========
      def _list_imgs(dir_path: str):
          exts = (".png", ".jpg", ".jpeg", ".webp")
-         # exts = (".jpg")
          try:
-             files = [os.path.join(dir_path, f) for f in sorted(os.listdir(dir_path))
-                      if f.lower().endswith(exts)]
              return files
          except Exception:
              return []
@@ -315,11 +434,8 @@

      with gr.Accordion("发型库(点击选择后自动填充)", open=True):
          with gr.Group(elem_id="hair_gallery_wrap"):
-             gallery = gr.Gallery(
-                 value=hair_list,
-                 columns=4, rows=2, allow_preview=True, label="发型库",
-                 elem_id="hair_gallery"
-             )

      def _pick_hair(evt: gr.SelectData):  # type: ignore[name-defined]
          i = evt.index if hasattr(evt, 'index') else 0
@@ -350,12 +466,11 @@
        with gr.Group(elem_classes=["out-card"]):
            bald_id_out = gr.Image(type="filepath", label="秃头化后的身份图", height=260)

-     # Logic unchanged
-     run_btn.click(fn=inference,
-                   inputs=[id_input, hair_input],
-                   outputs=[aligned_id_out, aligned_hair_out, bald_id_out,
-                            video_out, frames_state, frame_slider, frame_preview])
-

      def _on_slide(frames, idx):
          if not frames:
@@ -364,20 +479,15 @@
          i = max(0, min(i, len(frames) - 1))
          return gr.update(value=frames[i])

-
      frame_slider.change(_on_slide, inputs=[frames_state, frame_slider], outputs=frame_preview)

-
      def _clear():
          return None, None, None, None, None


-     clear_btn.click(_clear, None,
-                     [id_input, hair_input, aligned_id_out, aligned_hair_out, bald_id_out])

  if __name__ == "__main__":
      demo.queue().launch(server_name="0.0.0.0", server_port=7860)


-
-
 
  import os
+ import sys
+ import uuid
+ import logging
+ import base64
+ import shutil
+ from typing import Optional, Tuple
+
+ import gradio as gr
+ import torch
+ import cv2
+ import numpy as np
+
+ from huggingface_hub import snapshot_download


+ # -----------------------------------------------------------------------------
+ # Environment for HF Spaces
+ # -----------------------------------------------------------------------------
  os.environ.setdefault("GRADIO_TEMP_DIR", "/tmp/gradio")
  os.environ.setdefault("TMPDIR", "/tmp")
+ os.makedirs(os.environ["GRADIO_TEMP_DIR"], exist_ok=True)
+ os.makedirs(os.environ["TMPDIR"], exist_ok=True)
+
+
+ # -----------------------------------------------------------------------------
+ # Config via environment variables (set these in your Space settings)
+ # -----------------------------------------------------------------------------
+ # Required (you uploaded these as separate model repos on HF):
+ #   - FFHQFACEALIGNMENT_REPO (e.g., "yourname/FFHQFaceAlignment")
+ #   - HAIRMAPPER_REPO (e.g., "yourname/HairMapper")
+ #   - SD15_REPO (e.g., "yourname/stable-diffusion-v1-5")
+ # Optional:
+ #   - TRAINED_MODEL_REPO (if you uploaded the motion/control/ref ckpts as a repo)
+ # If TRAINED_MODEL_REPO is not provided, we will try to use the local "./pretrain".
+ FFHQFACEALIGNMENT_REPO = os.getenv("FFHQFACEALIGNMENT_REPO", "")
+ HAIRMAPPER_REPO = os.getenv("HAIRMAPPER_REPO", "")
+ SD15_REPO = os.getenv("SD15_REPO", "")
+ TRAINED_MODEL_REPO = os.getenv("TRAINED_MODEL_REPO", "")
+
+
+ # -----------------------------------------------------------------------------
+ # Utilities
+ # -----------------------------------------------------------------------------
+ def _ensure_symlink(src_dir: str, dst_path: str) -> str:
+     """Create a directory symlink at dst_path pointing to src_dir if one does not exist.
+
+     If symlink creation is unavailable, fall back to copying a minimal structure.
+     Returns the final path that imports should use (dst_path if created, else src_dir).
+     """
+     try:
+         if os.path.islink(dst_path) or os.path.isdir(dst_path):
+             return dst_path
+         os.symlink(src_dir, dst_path, target_is_directory=True)
+         return dst_path
+     except Exception:
+         # Fallback: try to create the directory and copy only the top-level files/dirs needed
+         try:
+             if not os.path.exists(dst_path):
+                 os.makedirs(dst_path, exist_ok=True)
+             # Last resort: shallow copy (can still be heavy; a symlink is preferred on HF Linux)
+             for name in os.listdir(src_dir):
+                 src = os.path.join(src_dir, name)
+                 dst = os.path.join(dst_path, name)
+                 if os.path.exists(dst):
+                     continue
+                 if os.path.isdir(src):
+                     shutil.copytree(src, dst)
+                 else:
+                     shutil.copy2(src, dst)
+             return dst_path
+         except Exception:
+             # Give up and return the original source
+             return src_dir
+
+
+ def _find_model_root(path: str) -> str:
+     """Given a snapshot path, return the directory containing model_index.json.
+
+     Handles repos that nest the folder (e.g., repo/stable-diffusion-v1-5/...).
+     """
+     if os.path.isfile(os.path.join(path, "model_index.json")):
+         return path
+     # Search one level deep for a folder with model_index.json
+     for name in os.listdir(path):
+         cand = os.path.join(path, name)
+         if os.path.isdir(cand) and os.path.isfile(os.path.join(cand, "model_index.json")):
+             return cand
+     # As a fallback, return the original path
+     return path
+
+
+ def _download_models() -> Tuple[Optional[str], Optional[str], Optional[str]]:
+     """Download the HF model repos and prepare local paths.
+
+     Returns:
+       - sd15_path: path to the Stable Diffusion v1-5 folder (with model_index.json)
+       - hairmapper_dir: path to the local HairMapper folder (import root)
+       - ffhq_dir: path to the local FFHQFaceAlignment folder (import root)
+     """
+     cache_dir = os.getenv("HF_HUB_CACHE", None)
+
+     # 1) Stable Diffusion 1.5
+     sd15_path = None
+     if SD15_REPO:
+         sd_snap = snapshot_download(repo_id=SD15_REPO, local_files_only=False, cache_dir=cache_dir)
+         sd15_path = _find_model_root(sd_snap)
+
+     # 2) HairMapper
+     hairmapper_dir = None
+     if HAIRMAPPER_REPO:
+         hm_snap = snapshot_download(repo_id=HAIRMAPPER_REPO, local_files_only=False, cache_dir=cache_dir)
+         # Create a symlink so that imports like "from HairMapper..." work
+         hairmapper_dir = _ensure_symlink(hm_snap, os.path.abspath("HairMapper"))
+         if hairmapper_dir not in sys.path:
+             sys.path.insert(0, hairmapper_dir)
+
+     # 3) FFHQFaceAlignment
+     ffhq_dir = None
+     if FFHQFACEALIGNMENT_REPO:
+         fa_snap = snapshot_download(repo_id=FFHQFACEALIGNMENT_REPO, local_files_only=False, cache_dir=cache_dir)
+         # Create a symlink so that test_stablehairv2._maybe_align_image("./FFHQFaceAlignment") resolves
+         ffhq_dir = _ensure_symlink(fa_snap, os.path.abspath("FFHQFaceAlignment"))
+         if ffhq_dir not in sys.path:
+             sys.path.insert(0, ffhq_dir)
+
+     # 4) Optional: trained model weights (motion/control/ref)
+     if TRAINED_MODEL_REPO:
+         tm_snap = snapshot_download(repo_id=TRAINED_MODEL_REPO, local_files_only=False, cache_dir=cache_dir)
+         # Symlink to ./trained_model so downstream code can load from there
+         _ = _ensure_symlink(tm_snap, os.path.abspath("trained_model"))
+
+     return sd15_path, hairmapper_dir, ffhq_dir
+
+
+ # -----------------------------------------------------------------------------
+ # Lazy imports that rely on downloaded models/paths
+ # -----------------------------------------------------------------------------
+ def _import_inference_bits():
+     from test_stablehairv2 import log_validation
+     from test_stablehairv2 import UNet3DConditionModel, ControlNetModel, CCProjection
+     from test_stablehairv2 import AutoTokenizer, CLIPVisionModelWithProjection, AutoencoderKL, UNet2DConditionModel
+     from test_stablehairv2 import _maybe_align_image
+     from HairMapper.hair_mapper_run import bald_head
+     return (
+         log_validation,
+         UNet3DConditionModel,
+         ControlNetModel,
+         CCProjection,
+         AutoTokenizer,
+         CLIPVisionModelWithProjection,
+         AutoencoderKL,
+         UNet2DConditionModel,
+         _maybe_align_image,
+         bald_head,
+     )


+ # -----------------------------------------------------------------------------
+ # Prepare models on startup
+ # -----------------------------------------------------------------------------
+ SD15_PATH, _, _ = _download_models()


+ # -----------------------------------------------------------------------------
+ # Gradio inference
+ # -----------------------------------------------------------------------------
  with open("imgs/background.jpg", "rb") as f:
+     _b64_bg = base64.b64encode(f.read()).decode()


  def inference(id_image, hair_image):
+     # Require GPU (HairMapper currently uses CUDA explicitly)
+     if not torch.cuda.is_available():
+         raise RuntimeError("This demo requires a GPU Space. Please enable a GPU in this Space.")
+
+     (
+         log_validation,
+         UNet3DConditionModel,
+         ControlNetModel,
+         CCProjection,
+         AutoTokenizer,
+         CLIPVisionModelWithProjection,
+         AutoencoderKL,
+         UNet2DConditionModel,
+         _maybe_align_image,
+         bald_head,
+     ) = _import_inference_bits()
+
      os.makedirs("gradio_inputs", exist_ok=True)
      os.makedirs("gradio_outputs", exist_ok=True)

 
      id_image.save(id_path)
      hair_image.save(hair_path)

+     # Align
      aligned_id = _maybe_align_image(id_path, output_size=1024, prefer_cuda=True)
      aligned_hair = _maybe_align_image(hair_path, output_size=1024, prefer_cuda=True)

      aligned_id_path = "gradio_outputs/aligned_id.png"
      aligned_hair_path = "gradio_outputs/aligned_hair.png"
      cv2.imwrite(aligned_id_path, cv2.cvtColor(aligned_id, cv2.COLOR_RGB2BGR))
      cv2.imwrite(aligned_hair_path, cv2.cvtColor(aligned_hair, cv2.COLOR_RGB2BGR))

+     # Balding
      bald_id_path = "gradio_outputs/bald_id.png"
      cv2.imwrite(bald_id_path, cv2.cvtColor(aligned_id, cv2.COLOR_RGB2BGR))
      bald_head(bald_id_path, bald_id_path)

+     # Resolve trained model dir
+     trained_model_dir = os.path.abspath("trained_model") if os.path.isdir("trained_model") else None
+     if trained_model_dir is None and os.path.isdir("pretrain"):
+         trained_model_dir = os.path.abspath("pretrain")
+     if trained_model_dir is None:
+         raise RuntimeError("Missing trained model weights. Provide TRAINED_MODEL_REPO or include ./pretrain.")
+
      class Args:
+         pretrained_model_name_or_path = SD15_PATH or os.path.abspath("stable-diffusion-v1-5/stable-diffusion-v1-5")
+         model_path = trained_model_dir
          image_encoder = "openai/clip-vit-large-patch14"
          controlnet_model_name_or_path = None
          revision = None
          output_dir = "gradio_outputs"
          seed = 42
          num_validation_images = 1
+         validation_ids = [aligned_id_path]
+         validation_hairs = [aligned_hair_path]
          use_fp16 = False
+         align_before_infer = True
+         align_size = 1024

      args = Args()

      device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

      logging.basicConfig(
          format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
          datefmt="%m/%d/%Y %H:%M:%S",

      )
      logger = logging.getLogger(__name__)

+     # Load tokenizer/encoders/vae
      tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer",
                                                revision=args.revision)
      image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder, revision=args.revision).to(device)
+     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae",
+                                         revision=args.revision).to(device, dtype=torch.float32)

+     from omegaconf import OmegaConf
      infer_config = OmegaConf.load('./configs/inference/inference_v2.yaml')

+     # UNet2D with 8-channel conv_in
      unet2 = UNet2DConditionModel.from_pretrained(
          args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, torch_dtype=torch.float32
      ).to(device)

          args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False,
          device_map=None, ignore_mismatched_sizes=True
      ).to(device)
+     state_dict4 = torch.load(os.path.join(args.model_path, "pytorch_model_2.bin"), map_location="cpu")
+     Hair_Encoder.load_state_dict(state_dict4, strict=False)

+     # Run inference
      log_validation(
          vae, tokenizer, image_encoder, denoising_unet,
          args, device, logger,

      output_video = os.path.join(args.output_dir, "validation", "generated_video_0.mp4")

+     # Extract frames for slider preview
      frames_dir = os.path.join(args.output_dir, "frames", uuid.uuid4().hex)
      os.makedirs(frames_dir, exist_ok=True)
      cap = cv2.VideoCapture(output_video)

      max_frames = len(frames_list) if frames_list else 1
      first_frame = frames_list[0] if frames_list else None

+     return (
+         aligned_id_path,
+         aligned_hair_path,
+         bald_id_path,
+         output_video,
+         frames_list,
+         gr.update(minimum=1, maximum=max_frames, value=1, step=1),
+         first_frame,
+     )
+
+
+ # -----------------------------------------------------------------------------
+ # UI (Blocks)
+ # -----------------------------------------------------------------------------
+ CSS = f"""
  html, body {{
      height: 100%;
      margin: 0;

      height: 100% !important;
      margin: 0 !important;
      padding: 0 !important;
+     background-image: url("data:image/jpeg;base64,{_b64_bg}");
      background-size: cover;
      background-position: center;
+     background-attachment: fixed;
  }}
  #title-card {{
      background: rgba(255, 255, 255, 0.8);

  }}
  .left-pane {{min-width: 360px}}
  .right-pane {{min-width: 680px}}
  .tabs {{
      background: rgba(255,255,255,0.88);
      border-radius: 12px;

      border-bottom: 1px solid #e5e7eb;
      padding-bottom: 6px;
  }}
  .tabitem {{
      background: rgba(255,255,255,0.88);
      border-radius: 10px;
      padding: 8px;
  }}
  #hair_gallery_wrap {{
      height: 260px !important;
      overflow-y: scroll !important;

      height: 100% !important;
      overflow-y: scroll !important;
  }}
  #hair_gallery {{
      height: 100% !important;
  }}
  """

+
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo", neutral_hue="slate"), css=CSS) as demo:
      with gr.Group(elem_id="title-card"):
          gr.Markdown("""
  <h2 id='title'>StableHairV2 多视角发型迁移</h2>

          run_btn = gr.Button("开始生成", variant="primary")
          clear_btn = gr.Button("清空")

      def _list_imgs(dir_path: str):
          exts = (".png", ".jpg", ".jpeg", ".webp")
          try:
+             files = [os.path.join(dir_path, f) for f in sorted(os.listdir(dir_path)) if f.lower().endswith(exts)]
              return files
          except Exception:
              return []

      with gr.Accordion("发型库(点击选择后自动填充)", open=True):
          with gr.Group(elem_id="hair_gallery_wrap"):
+             gallery = gr.Gallery(value=hair_list, columns=4, rows=2, allow_preview=True, label="发型库",
+                                  elem_id="hair_gallery")

      def _pick_hair(evt: gr.SelectData):  # type: ignore[name-defined]
          i = evt.index if hasattr(evt, 'index') else 0

        with gr.Group(elem_classes=["out-card"]):
            bald_id_out = gr.Image(type="filepath", label="秃头化后的身份图", height=260)

+     run_btn.click(
+         fn=inference,
+         inputs=[id_input, hair_input],
+         outputs=[aligned_id_out, aligned_hair_out, bald_id_out, video_out, frames_state, frame_slider, frame_preview],
+     )

      def _on_slide(frames, idx):
          if not frames:

          i = max(0, min(i, len(frames) - 1))
          return gr.update(value=frames[i])

      frame_slider.change(_on_slide, inputs=[frames_state, frame_slider], outputs=frame_preview)

      def _clear():
          return None, None, None, None, None

+     clear_btn.click(_clear, None, [id_input, hair_input, aligned_id_out, aligned_hair_out, bald_id_out])


  if __name__ == "__main__":
      demo.queue().launch(server_name="0.0.0.0", server_port=7860)
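
Note: with this change the model repos are resolved from environment variables at import time, so those variables must be set before the app starts. A minimal local smoke test, assuming the Space is checked out with its imgs/ and configs/ folders, and using placeholder repo ids (not the author's actual repos):

    import os

    # Hypothetical repo ids; substitute your own uploads.
    os.environ["SD15_REPO"] = "yourname/stable-diffusion-v1-5"
    os.environ["HAIRMAPPER_REPO"] = "yourname/HairMapper"
    os.environ["FFHQFACEALIGNMENT_REPO"] = "yourname/FFHQFaceAlignment"
    # Optional; when unset, inference() falls back to a local ./pretrain folder.
    # os.environ["TRAINED_MODEL_REPO"] = "yourname/stablehairv2-trained"

    import gradio_app  # runs _download_models() at import time

    gradio_app.demo.queue().launch(server_name="0.0.0.0", server_port=7860)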
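
Since inference() now returns a flat 7-tuple (aligned id path, aligned hair path, bald id path, video path, frame list, a gr.update for the slider, and the first frame), it can also be driven without the UI. A sketch under the same assumptions, with hypothetical example images and a GPU available:

    from PIL import Image
    import gradio_app

    id_img = Image.open("imgs/id_example.png")      # hypothetical input image
    hair_img = Image.open("imgs/hair_example.png")  # hypothetical input image

    (aligned_id, aligned_hair, bald_id, video_path,
     frames, slider_update, first_frame) = gradio_app.inference(id_img, hair_img)

    print(video_path)  # gradio_outputs/validation/generated_video_0.mp4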