Spaces: Running on Zero
Update gradio_app.py
Browse files

gradio_app.py  CHANGED  (+226 -116)
@@ -1,38 +1,190 @@
 import os
+import sys
+import uuid
+import logging
+import base64
+import shutil
+from typing import Optional, Tuple
+
+import gradio as gr
+import torch
+import cv2
+import numpy as np
+
+from huggingface_hub import snapshot_download


-#
+# -----------------------------------------------------------------------------
+# Environment for HF Spaces
+# -----------------------------------------------------------------------------
 os.environ.setdefault("GRADIO_TEMP_DIR", "/tmp/gradio")
 os.environ.setdefault("TMPDIR", "/tmp")
-os.makedirs("
-os.makedirs("
+os.makedirs(os.environ["GRADIO_TEMP_DIR"], exist_ok=True)
+os.makedirs(os.environ["TMPDIR"], exist_ok=True)
+
+
+# -----------------------------------------------------------------------------
+# Config via environment variables (set these in your Space settings)
+# -----------------------------------------------------------------------------
+# Required (you uploaded these as separate model repos on HF):
+# - FFHQFACEALIGNMENT_REPO (e.g., "yourname/FFHQFaceAlignment")
+# - HAIRMAPPER_REPO (e.g., "yourname/HairMapper")
+# - SD15_REPO (e.g., "yourname/stable-diffusion-v1-5")
+# Optional:
+# - TRAINED_MODEL_REPO (if you uploaded motion/control/ref ckpts as a repo)
+# If TRAINED_MODEL_REPO not provided, we will try to use local "./pretrain".
+FFHQFACEALIGNMENT_REPO = os.getenv("FFHQFACEALIGNMENT_REPO", "")
+HAIRMAPPER_REPO = os.getenv("HAIRMAPPER_REPO", "")
+SD15_REPO = os.getenv("SD15_REPO", "")
+TRAINED_MODEL_REPO = os.getenv("TRAINED_MODEL_REPO", "")
+
+
+# -----------------------------------------------------------------------------
+# Utilities
+# -----------------------------------------------------------------------------
+def _ensure_symlink(src_dir: str, dst_path: str) -> str:
+    """Create a directory symlink at dst_path pointing to src_dir if not exists.
+    If symlink creation is unavailable, fallback to copying a minimal structure.
+    Returns the final path that should be used by imports (dst_path if created, else src_dir).
+    """
+    try:
+        if os.path.islink(dst_path) or os.path.isdir(dst_path):
+            return dst_path
+        os.symlink(src_dir, dst_path, target_is_directory=True)
+        return dst_path
+    except Exception:
+        # Fallback: try to create the directory and copy only top-level python files/dirs needed
+        try:
+            if not os.path.exists(dst_path):
+                os.makedirs(dst_path, exist_ok=True)
+            # Last resort: shallow copy (can still be heavy; symlink is preferred on HF Linux)
+            for name in os.listdir(src_dir):
+                src = os.path.join(src_dir, name)
+                dst = os.path.join(dst_path, name)
+                if os.path.exists(dst):
+                    continue
+                if os.path.isdir(src):
+                    shutil.copytree(src, dst)
+                else:
+                    shutil.copy2(src, dst)
+            return dst_path
+        except Exception:
+            # Give up and return original source
+            return src_dir
+
+
+def _find_model_root(path: str) -> str:
+    """Given a snapshot path, return the directory containing model_index.json.
+    Handles repos that nest the folder (e.g., repo/stable-diffusion-v1-5/...).
+    """
+    if os.path.isfile(os.path.join(path, "model_index.json")):
+        return path
+    # Search one level deep for a folder with model_index.json
+    for name in os.listdir(path):
+        cand = os.path.join(path, name)
+        if os.path.isdir(cand) and os.path.isfile(os.path.join(cand, "model_index.json")):
+            return cand
+    # As a fallback, return original path
+    return path
+
+
+def _download_models() -> Tuple[Optional[str], Optional[str], Optional[str]]:
+    """Download HF model repos and prepare local paths.
+
+    Returns:
+    - sd15_path: path to the Stable Diffusion v1-5 folder (with model_index.json)
+    - hairmapper_dir: path to local HairMapper folder (import root)
+    - ffhq_dir: path to local FFHQFaceAlignment folder (import root)
+    """
+    cache_dir = os.getenv("HF_HUB_CACHE", None)
+
+    # 1) Stable Diffusion 1.5
+    sd15_path = None
+    if SD15_REPO:
+        sd_snap = snapshot_download(repo_id=SD15_REPO, local_files_only=False, cache_dir=cache_dir)
+        sd15_path = _find_model_root(sd_snap)
+
+    # 2) HairMapper
+    hairmapper_dir = None
+    if HAIRMAPPER_REPO:
+        hm_snap = snapshot_download(repo_id=HAIRMAPPER_REPO, local_files_only=False, cache_dir=cache_dir)
+        # Create a symlink so that imports like "from HairMapper..." work
+        hairmapper_dir = _ensure_symlink(hm_snap, os.path.abspath("HairMapper"))
+        if hairmapper_dir not in sys.path:
+            sys.path.insert(0, hairmapper_dir)
+
+    # 3) FFHQFaceAlignment
+    ffhq_dir = None
+    if FFHQFACEALIGNMENT_REPO:
+        fa_snap = snapshot_download(repo_id=FFHQFACEALIGNMENT_REPO, local_files_only=False, cache_dir=cache_dir)
+        # Create a symlink so that test_stablehairv2._maybe_align_image("./FFHQFaceAlignment") resolves
+        ffhq_dir = _ensure_symlink(fa_snap, os.path.abspath("FFHQFaceAlignment"))
+        if ffhq_dir not in sys.path:
+            sys.path.insert(0, ffhq_dir)
+
+    # 4) Optional: Trained model weights (motion/control/ref)
+    if TRAINED_MODEL_REPO:
+        tm_snap = snapshot_download(repo_id=TRAINED_MODEL_REPO, local_files_only=False, cache_dir=cache_dir)
+        # Symlink to ./trained_model so downstream code can load from there
+        _ = _ensure_symlink(tm_snap, os.path.abspath("trained_model"))
+
+    return sd15_path, hairmapper_dir, ffhq_dir
+
+
+# -----------------------------------------------------------------------------
+# Lazy imports that rely on downloaded models/paths
+# -----------------------------------------------------------------------------
+def _import_inference_bits():
+    from test_stablehairv2 import log_validation
+    from test_stablehairv2 import UNet3DConditionModel, ControlNetModel, CCProjection
+    from test_stablehairv2 import AutoTokenizer, CLIPVisionModelWithProjection, AutoencoderKL, UNet2DConditionModel
+    from test_stablehairv2 import _maybe_align_image
+    from HairMapper.hair_mapper_run import bald_head
+    return (
+        log_validation,
+        UNet3DConditionModel,
+        ControlNetModel,
+        CCProjection,
+        AutoTokenizer,
+        CLIPVisionModelWithProjection,
+        AutoencoderKL,
+        UNet2DConditionModel,
+        _maybe_align_image,
+        bald_head,
+    )

-# 同时修改你的输出目录为相对路径
-os.makedirs("gradio_inputs", exist_ok=True)
-os.makedirs("gradio_outputs", exist_ok=True)

+# -----------------------------------------------------------------------------
+# Prepare models on startup
+# -----------------------------------------------------------------------------
+SD15_PATH, _, _ = _download_models()

-import logging
-import gradio as gr
-import torch
-import os
-import uuid
-from test_stablehairv2 import log_validation
-from test_stablehairv2 import UNet3DConditionModel, ControlNetModel, CCProjection
-from test_stablehairv2 import AutoTokenizer, CLIPVisionModelWithProjection, AutoencoderKL, UNet2DConditionModel
-from omegaconf import OmegaConf
-import numpy as np
-import cv2
-from test_stablehairv2 import _maybe_align_image
-from HairMapper.hair_mapper_run import bald_head
-
-import base64

+# -----------------------------------------------------------------------------
+# Gradio inference
+# -----------------------------------------------------------------------------
 with open("imgs/background.jpg", "rb") as f:
-
+    _b64_bg = base64.b64encode(f.read()).decode()


 def inference(id_image, hair_image):
+    # Require GPU (HairMapper currently uses CUDA explicitly)
+    if not torch.cuda.is_available():
+        raise RuntimeError("This demo requires a GPU Space. Please enable a GPU in this Space.")
+
+    (
+        log_validation,
+        UNet3DConditionModel,
+        ControlNetModel,
+        CCProjection,
+        AutoTokenizer,
+        CLIPVisionModelWithProjection,
+        AutoencoderKL,
+        UNet2DConditionModel,
+        _maybe_align_image,
+        bald_head,
+    ) = _import_inference_bits()
+
     os.makedirs("gradio_inputs", exist_ok=True)
     os.makedirs("gradio_outputs", exist_ok=True)

@@ -41,40 +193,46 @@ def inference(id_image, hair_image):
     id_image.save(id_path)
     hair_image.save(hair_path)

-    #
+    # Align
     aligned_id = _maybe_align_image(id_path, output_size=1024, prefer_cuda=True)
     aligned_hair = _maybe_align_image(hair_path, output_size=1024, prefer_cuda=True)

-    # 保存对齐结果(方便 Gradio 输出)
    aligned_id_path = "gradio_outputs/aligned_id.png"
    aligned_hair_path = "gradio_outputs/aligned_hair.png"
    cv2.imwrite(aligned_id_path, cv2.cvtColor(aligned_id, cv2.COLOR_RGB2BGR))
    cv2.imwrite(aligned_hair_path, cv2.cvtColor(aligned_hair, cv2.COLOR_RGB2BGR))

-    #
+    # Balding
     bald_id_path = "gradio_outputs/bald_id.png"
     cv2.imwrite(bald_id_path, cv2.cvtColor(aligned_id, cv2.COLOR_RGB2BGR))
     bald_head(bald_id_path, bald_id_path)

-    #
+    # Resolve trained model dir
+    trained_model_dir = os.path.abspath("trained_model") if os.path.isdir("trained_model") else None
+    if trained_model_dir is None and os.path.isdir("pretrain"):
+        trained_model_dir = os.path.abspath("pretrain")
+    if trained_model_dir is None:
+        raise RuntimeError("Missing trained model weights. Provide TRAINED_MODEL_REPO or include ./pretrain.")
+
     class Args:
-        pretrained_model_name_or_path = "
-        model_path =
+        pretrained_model_name_or_path = SD15_PATH or os.path.abspath("stable-diffusion-v1-5/stable-diffusion-v1-5")
+        model_path = trained_model_dir
         image_encoder = "openai/clip-vit-large-patch14"
         controlnet_model_name_or_path = None
         revision = None
         output_dir = "gradio_outputs"
         seed = 42
         num_validation_images = 1
-        validation_ids = [aligned_id_path]
-        validation_hairs = [aligned_hair_path]
+        validation_ids = [aligned_id_path]
+        validation_hairs = [aligned_hair_path]
         use_fp16 = False
+        align_before_infer = True
+        align_size = 1024

     args = Args()

     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

-    # 初始化 logger
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
         datefmt="%m/%d/%Y %H:%M:%S",
@@ -82,15 +240,17 @@ def inference(id_image, hair_image):
     )
     logger = logging.getLogger(__name__)

-    #
+    # Load tokenizer/encoders/vae
     tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer",
                                               revision=args.revision)
     image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder, revision=args.revision).to(device)
-    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae",
-
+    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae",
+                                        revision=args.revision).to(device, dtype=torch.float32)

+    from omegaconf import OmegaConf
     infer_config = OmegaConf.load('./configs/inference/inference_v2.yaml')

+    # UNet2D with 8-channel conv_in
     unet2 = UNet2DConditionModel.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, torch_dtype=torch.float32
     ).to(device)
@@ -126,10 +286,10 @@
         args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False,
         device_map=None, ignore_mismatched_sizes=True
     ).to(device)
-
-    Hair_Encoder.load_state_dict(
+    state_dict4 = torch.load(os.path.join(args.model_path, "pytorch_model_2.bin"), map_location="cpu")
+    Hair_Encoder.load_state_dict(state_dict4, strict=False)

-    #
+    # Run inference
     log_validation(
         vae, tokenizer, image_encoder, denoising_unet,
         args, device, logger,
@@ -138,7 +298,7 @@

     output_video = os.path.join(args.output_dir, "validation", "generated_video_0.mp4")

-    #
+    # Extract frames for slider preview
     frames_dir = os.path.join(args.output_dir, "frames", uuid.uuid4().hex)
     os.makedirs(frames_dir, exist_ok=True)
     cap = cv2.VideoCapture(output_video)
@@ -157,34 +317,21 @@
     max_frames = len(frames_list) if frames_list else 1
     first_frame = frames_list[0] if frames_list else None

-    return
-
-
-
-
-
-
-
-
-
-    #
-    #
-    #
-
-    #         gr.Image(type="filepath", label="对齐后的身份图"),
-    #         gr.Image(type="filepath", label="对齐后的发型图"),
-    #         gr.Image(type="filepath", label="秃头化后的身份图"),
-    #         gr.Video(label="生成的视频")
-    #     ],
-    #     title="StableHairV2 多视角发型迁移",
-    #     description="上传身份图和发型参考图,查看对齐结果并生成多视角视频"
-    # )
-    # if __name__ == "__main__":
-    #     demo.launch(server_name="0.0.0.0", server_port=7860)
-
-# Blocks 美化版
-css = f"""
+    return (
+        aligned_id_path,
+        aligned_hair_path,
+        bald_id_path,
+        output_video,
+        frames_list,
+        gr.update(minimum=1, maximum=max_frames, value=1, step=1),
+        first_frame,
+    )
+
+
+# -----------------------------------------------------------------------------
+# UI (Blocks)
+# -----------------------------------------------------------------------------
+CSS = f"""
 html, body {{
     height: 100%;
     margin: 0;
@@ -195,10 +342,10 @@ css = f"""
     height: 100% !important;
     margin: 0 !important;
     padding: 0 !important;
-    background-image: url("data:image/jpeg;base64,{
+    background-image: url("data:image/jpeg;base64,{_b64_bg}");
     background-size: cover;
     background-position: center;
-    background-attachment: fixed;
+    background-attachment: fixed;
 }}
 #title-card {{
     background: rgba(255, 255, 255, 0.8);
@@ -226,7 +373,6 @@
 }}
 .left-pane {{min-width: 360px}}
 .right-pane {{min-width: 680px}}
-/* Tabs 美化 */
 .tabs {{
     background: rgba(255,255,255,0.88);
     border-radius: 12px;
@@ -240,31 +386,11 @@
     border-bottom: 1px solid #e5e7eb;
     padding-bottom: 6px;
 }}
-.tab-nav button {{
-    background: rgba(255,255,255,0.7);
-    border: 1px solid #e5e7eb;
-    backdrop-filter: blur(6px);
-    border-radius: 8px;
-    padding: 6px 12px;
-    color: #111827;
-    transition: all .2s ease;
-}}
-.tab-nav button:hover {{
-    transform: translateY(-1px);
-    box-shadow: 0 4px 10px rgba(0,0,0,0.06);
-}}
-.tab-nav button[aria-selected="true"] {{
-    background: #4f46e5;
-    color: #fff;
-    border-color: #4f46e5;
-    box-shadow: 0 6px 14px rgba(79,70,229,0.25);
-}}
 .tabitem {{
     background: rgba(255,255,255,0.88);
     border-radius: 10px;
     padding: 8px;
 }}
-/* 发型库滚动限制容器:固定260px高度,内部可滚动 */
 #hair_gallery_wrap {{
     height: 260px !important;
     overflow-y: scroll !important;
@@ -274,17 +400,13 @@
     height: 100% !important;
     overflow-y: scroll !important;
 }}
-/* 确保画廊本体占满容器高度,避免滚动条落到页面底部 */
 #hair_gallery {{
     height: 100% !important;
 }}
 """

-
-
-    css=css
-) as demo:
-    # ==== 顶部 Panel ====
+
+with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo", neutral_hue="slate"), css=CSS) as demo:
     with gr.Group(elem_id="title-card"):
         gr.Markdown("""
 <h2 id='title'>StableHairV2 多视角发型迁移</h2>
@@ -300,13 +422,10 @@ with gr.Blocks(
             run_btn = gr.Button("开始生成", variant="primary")
             clear_btn = gr.Button("清空")

-    # ========= 发型库(点击即填充到“发型参考图”) =========
     def _list_imgs(dir_path: str):
         exts = (".png", ".jpg", ".jpeg", ".webp")
-        # exts = (".jpg")
         try:
-            files = [os.path.join(dir_path, f) for f in sorted(os.listdir(dir_path))
-                     if f.lower().endswith(exts)]
+            files = [os.path.join(dir_path, f) for f in sorted(os.listdir(dir_path)) if f.lower().endswith(exts)]
             return files
         except Exception:
             return []
@@ -315,11 +434,8 @@ with gr.Blocks(

     with gr.Accordion("发型库(点击选择后自动填充)", open=True):
         with gr.Group(elem_id="hair_gallery_wrap"):
-            gallery = gr.Gallery(
-
-                columns=4, rows=2, allow_preview=True, label="发型库",
-                elem_id="hair_gallery"
-            )
+            gallery = gr.Gallery(value=hair_list, columns=4, rows=2, allow_preview=True, label="发型库",
+                                 elem_id="hair_gallery")

     def _pick_hair(evt: gr.SelectData):  # type: ignore[name-defined]
         i = evt.index if hasattr(evt, 'index') else 0
@@ -350,12 +466,11 @@ with gr.Blocks(
         with gr.Group(elem_classes=["out-card"]):
             bald_id_out = gr.Image(type="filepath", label="秃头化后的身份图", height=260)

-
-
-
-
-
-
+    run_btn.click(
+        fn=inference,
+        inputs=[id_input, hair_input],
+        outputs=[aligned_id_out, aligned_hair_out, bald_id_out, video_out, frames_state, frame_slider, frame_preview],
+    )

     def _on_slide(frames, idx):
         if not frames:
@@ -364,20 +479,15 @@ with gr.Blocks(
         i = max(0, min(i, len(frames) - 1))
         return gr.update(value=frames[i])

-
     frame_slider.change(_on_slide, inputs=[frames_state, frame_slider], outputs=frame_preview)

-
     def _clear():
         return None, None, None, None, None

+    clear_btn.click(_clear, None, [id_input, hair_input, aligned_id_out, aligned_hair_out, bald_id_out])

-    clear_btn.click(_clear, None,
-                    [id_input, hair_input, aligned_id_out, aligned_hair_out, bald_id_out])

 if __name__ == "__main__":
     demo.queue().launch(server_name="0.0.0.0", server_port=7860)


-
-
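
Note: the new bootstrap path is driven entirely by Space variables. Below is a minimal local smoke-test sketch, assuming you run it from the Space's repo root (so imgs/background.jpg, configs/, and test_stablehairv2.py are present); the repo IDs are placeholders for your own model repos, not values taken from this commit.

# Hypothetical smoke test for the new model bootstrap (repo IDs are placeholders).
import os

os.environ["SD15_REPO"] = "yourname/stable-diffusion-v1-5"
os.environ["HAIRMAPPER_REPO"] = "yourname/HairMapper"
os.environ["FFHQFACEALIGNMENT_REPO"] = "yourname/FFHQFaceAlignment"
# TRAINED_MODEL_REPO is optional; without it the app falls back to ./pretrain.

# Importing the app runs snapshot_download and creates the ./HairMapper and
# ./FFHQFaceAlignment symlinks (plus ./trained_model if TRAINED_MODEL_REPO is set).
import gradio_app

print(gradio_app.SD15_PATH)  # directory containing model_index.json, or None if SD15_REPO is unset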