import os
import sys
import uuid
import logging
import base64
import shutil
from typing import Optional, Tuple

import gradio as gr
import spaces
import torch
import cv2
import numpy as np
import time

from huggingface_hub import snapshot_download

# -----------------------------------------------------------------------------
# Environment for HF Spaces
# -----------------------------------------------------------------------------
os.environ.setdefault("GRADIO_TEMP_DIR", "/tmp/gradio")
os.environ.setdefault("TMPDIR", "/tmp")
os.makedirs(os.environ["GRADIO_TEMP_DIR"], exist_ok=True)
os.makedirs(os.environ["TMPDIR"], exist_ok=True)
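# Gradio keeps uploads and temporary outputs under GRADIO_TEMP_DIR; pointing it (and TMPDIR)
# at /tmp keeps those files in a location that is reliably writable inside the Space container.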

# -----------------------------------------------------------------------------
# Config via environment variables (set these in your Space settings)
# -----------------------------------------------------------------------------
# Required (you uploaded these as separate model repos on HF):
# - FFHQFACEALIGNMENT_REPO (e.g., "yourname/FFHQFaceAlignment")
# - HAIRMAPPER_REPO        (e.g., "yourname/HairMapper")
# - SD15_REPO              (e.g., "yourname/stable-diffusion-v1-5")
# Optional:
# - TRAINED_MODEL_REPO (if you uploaded the motion/control/ref checkpoints as a repo)
# If TRAINED_MODEL_REPO is not provided, the local "./pretrain" folder is used instead.
FFHQFACEALIGNMENT_REPO = os.getenv("FFHQFACEALIGNMENT_REPO", "")
HAIRMAPPER_REPO = os.getenv("HAIRMAPPER_REPO", "")
SD15_REPO = os.getenv("SD15_REPO", "")
TRAINED_MODEL_REPO = os.getenv("TRAINED_MODEL_REPO", "")

# Prefer the official variable name; fall back to HF_TOKEN for compatibility.
HF_AUTH_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN") or os.getenv("HF_TOKEN")

# Required weight files (checked in ./pretrain or the TRAINED_MODEL_REPO snapshot).
REQUIRED_WEIGHT_FILENAMES = [
    "pytorch_model.bin",
    "motion_module-4140000.pth",
    "pytorch_model_1.bin",
    "pytorch_model_2.bin",
]
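# How these checkpoints are consumed in _load_models_cpu_once below:
#   pytorch_model.bin          -> ControlNet weights (ControlNetModel built from the SD UNet)
#   motion_module-4140000.pth  -> temporal/motion weights for the 3D denoising UNet
#   pytorch_model_1.bin        -> CCProjection weights
#   pytorch_model_2.bin        -> hair reference encoder (ref_unet) weights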

# -----------------------------------------------------------------------------
# Utilities
# -----------------------------------------------------------------------------
def _ensure_symlink(src_dir: str, dst_path: str) -> str:
    """Create a directory symlink at dst_path pointing to src_dir if it does not already exist.

    If symlink creation is unavailable, fall back to copying a minimal structure.
    Returns the path that imports should use (dst_path if created, else src_dir).
    """
    try:
        if os.path.islink(dst_path) or os.path.isdir(dst_path):
            return dst_path
        os.symlink(src_dir, dst_path, target_is_directory=True)
        return dst_path
    except Exception:
        # Fallback: create the directory and copy only the top-level files/dirs needed.
        try:
            if not os.path.exists(dst_path):
                os.makedirs(dst_path, exist_ok=True)
            # Last resort: shallow copy (can still be heavy; a symlink is preferred on HF Linux).
            for name in os.listdir(src_dir):
                src = os.path.join(src_dir, name)
                dst = os.path.join(dst_path, name)
                if os.path.exists(dst):
                    continue
                if os.path.isdir(src):
                    shutil.copytree(src, dst)
                else:
                    shutil.copy2(src, dst)
            return dst_path
        except Exception:
            # Give up and return the original source.
            return src_dir
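
# Example use (hypothetical snapshot path), mirroring _download_models below: expose a downloaded
# snapshot under a fixed package name so "HairMapper.*" imports resolve:
#   local_pkg = _ensure_symlink("/data/snapshots/HairMapper", os.path.abspath("HairMapper"))
#   sys.path.insert(0, os.path.dirname(local_pkg))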

def _find_model_root(path: str) -> str:
    """Given a snapshot path, return the directory containing model_index.json.

    Handles repos that nest the folder (e.g., repo/stable-diffusion-v1-5/...).
    """
    if os.path.isfile(os.path.join(path, "model_index.json")):
        return path
    # Search one level deep for a folder with model_index.json.
    for name in os.listdir(path):
        cand = os.path.join(path, name)
        if os.path.isdir(cand) and os.path.isfile(os.path.join(cand, "model_index.json")):
            return cand
    # As a fallback, return the original path.
    return path
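
# Example: a repo laid out as <snapshot>/stable-diffusion-v1-5/model_index.json resolves to the
# nested "stable-diffusion-v1-5" folder, while a flat diffusers repo resolves to <snapshot> itself.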

def _download_models() -> Tuple[Optional[str], Optional[str], Optional[str]]:
    """Download HF model repos and prepare local paths.

    Returns:
        - sd15_path: path to the Stable Diffusion v1-5 folder (with model_index.json)
        - hairmapper_dir: path to the local HairMapper folder (import root)
        - ffhq_dir: path to the local FFHQFaceAlignment folder (import root)
    """
    cache_dir = os.getenv("HF_HUB_CACHE", None)
    # 1) Stable Diffusion 1.5
    sd15_path = None
    if SD15_REPO:
        sd_snap = snapshot_download(
            repo_id=SD15_REPO,
            local_files_only=False,
            cache_dir=cache_dir,
            token=HF_AUTH_TOKEN,
        )
        sd15_path = _find_model_root(sd_snap)
    # 2) HairMapper
    hairmapper_dir = None
    if HAIRMAPPER_REPO:
        hm_snap = snapshot_download(
            repo_id=HAIRMAPPER_REPO,
            local_files_only=False,
            cache_dir=cache_dir,
            token=HF_AUTH_TOKEN,
        )
        # If the repo root contains a nested "HairMapper" folder, link to that subfolder.
        hm_src = hm_snap
        nested_hm = os.path.join(hm_snap, "HairMapper")
        if os.path.isdir(nested_hm) and (
            os.path.isfile(os.path.join(nested_hm, "hair_mapper_run.py")) or
            os.path.isdir(os.path.join(nested_hm, "mapper"))
        ):
            hm_src = nested_hm
        # Create a symlink so that imports like "from HairMapper..." work.
        hairmapper_dir = _ensure_symlink(hm_src, os.path.abspath("HairMapper"))
        hm_parent = os.path.dirname(hairmapper_dir)
        if hm_parent not in sys.path:
            sys.path.insert(0, hm_parent)
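        # Putting the parent of the "HairMapper" link on sys.path is what lets
        # "from HairMapper.hair_mapper_run import bald_head" (see _import_inference_bits) resolve.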
    # 3) FFHQFaceAlignment
    ffhq_dir = None
    if FFHQFACEALIGNMENT_REPO:
        fa_snap = snapshot_download(
            repo_id=FFHQFACEALIGNMENT_REPO,
            local_files_only=False,
            cache_dir=cache_dir,
            token=HF_AUTH_TOKEN,
        )
        # If the repo root contains a nested "FFHQFaceAlignment" folder, link to that subfolder.
        fa_src = fa_snap
        nested_fa = os.path.join(fa_snap, "FFHQFaceAlignment")
        if os.path.isdir(nested_fa) and (
            os.path.isfile(os.path.join(nested_fa, "align.py")) or
            os.path.isdir(os.path.join(nested_fa, "lib"))
        ):
            fa_src = nested_fa
        # Create a symlink so that _maybe_align_image can import modules relatively.
        ffhq_dir = _ensure_symlink(fa_src, os.path.abspath("FFHQFaceAlignment"))
        fa_parent = os.path.dirname(ffhq_dir)
        if fa_parent not in sys.path:
            sys.path.insert(0, fa_parent)
    # 4) Optional: trained model weights (motion/control/ref)
    if TRAINED_MODEL_REPO:
        tm_snap = snapshot_download(
            repo_id=TRAINED_MODEL_REPO,
            local_files_only=False,
            cache_dir=cache_dir,
            token=HF_AUTH_TOKEN,
        )
        # Symlink to ./trained_model so downstream code can load from there.
        tm_linked = _ensure_symlink(tm_snap, os.path.abspath("trained_model"))
        # If the repo contains a nested pretrain/ folder, also expose it at ./pretrain.
        nested_pretrain = os.path.join(tm_linked, "pretrain")
        if os.path.isdir(nested_pretrain):
            _ensure_symlink(nested_pretrain, os.path.abspath("pretrain"))
    return sd15_path, hairmapper_dir, ffhq_dir

# -----------------------------------------------------------------------------
# Lazy imports that rely on downloaded models/paths
# -----------------------------------------------------------------------------
def _import_inference_bits():
    from test_stablehairv2 import log_validation
    from test_stablehairv2 import UNet3DConditionModel, ControlNetModel, CCProjection
    from test_stablehairv2 import AutoTokenizer, CLIPVisionModelWithProjection, AutoencoderKL, UNet2DConditionModel
    from test_stablehairv2 import _maybe_align_image
    from HairMapper.hair_mapper_run import bald_head
    return (
        log_validation,
        UNet3DConditionModel,
        ControlNetModel,
        CCProjection,
        AutoTokenizer,
        CLIPVisionModelWithProjection,
        AutoencoderKL,
        UNet2DConditionModel,
        _maybe_align_image,
        bald_head,
    )

# -----------------------------------------------------------------------------
# Prepare models on startup
# -----------------------------------------------------------------------------
SD15_PATH, _, _ = _download_models()

# -----------------------------------------------------------------------------
# Global model loading (CPU) so the GPU task only does inference
# -----------------------------------------------------------------------------
def _has_all_weights(dir_path: str) -> bool:
    return all(os.path.isfile(os.path.join(dir_path, name)) for name in REQUIRED_WEIGHT_FILENAMES)


def _resolve_trained_model_dir() -> str:
    pretrain_dir = os.path.abspath("pretrain") if os.path.isdir("pretrain") else None
    trained_dir = os.path.abspath("trained_model") if os.path.isdir("trained_model") else None
    trained_dir_nested = os.path.join(trained_dir, "pretrain") if trained_dir else None
    # Prefer ./pretrain (where the weights are expected to live) and verify it is complete.
    if pretrain_dir and _has_all_weights(pretrain_dir):
        return pretrain_dir
    # Next, try ./trained_model and verify it is complete.
    if trained_dir and _has_all_weights(trained_dir):
        return trained_dir
    # Finally, try the trained_model/pretrain subdirectory.
    if trained_dir_nested and os.path.isdir(trained_dir_nested) and _has_all_weights(trained_dir_nested):
        return trained_dir_nested

    # Build a friendlier error message.
    def _missing_list(dir_path: str) -> str:
        if not dir_path:
            return "directory does not exist"
        missing = [n for n in REQUIRED_WEIGHT_FILENAMES if not os.path.isfile(os.path.join(dir_path, n))]
        if not missing:
            return "all files present"
        return "missing: " + ", ".join(missing)

    msg = (
        "Missing trained model weights. Provide TRAINED_MODEL_REPO or include ./pretrain.\n"
        f"pretrain status: {_missing_list(pretrain_dir)}\n"
        f"trained_model status: {_missing_list(trained_dir)}\n"
        f"trained_model/pretrain status: {_missing_list(trained_dir_nested)}"
    )
    raise RuntimeError(msg)

# Lazy globals
G_ARGS = None
G_INFER_CONFIG = None
G_TOKENIZER = None
G_IMAGE_ENCODER = None
G_VAE = None
G_UNET2 = None
G_CONTROLNET = None
G_DENOISING_UNET = None
G_CC_PROJ = None
G_HAIR_ENCODER = None


def _load_models_cpu_once():
    global G_ARGS, G_INFER_CONFIG, G_TOKENIZER, G_IMAGE_ENCODER, G_VAE
    global G_UNET2, G_CONTROLNET, G_DENOISING_UNET, G_CC_PROJ, G_HAIR_ENCODER
    if all(x is not None for x in (
        G_ARGS, G_INFER_CONFIG, G_TOKENIZER, G_IMAGE_ENCODER, G_VAE,
        G_UNET2, G_CONTROLNET, G_DENOISING_UNET, G_CC_PROJ, G_HAIR_ENCODER
    )):
        return
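
    # _Args mirrors the argparse-style namespace that test_stablehairv2.log_validation expects
    # (model paths, validation image lists, seed, etc.), so the preloaded components can be
    # reused without going through the original CLI entry point.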
    class _Args:
        pretrained_model_name_or_path = SD15_PATH or os.path.abspath("stable-diffusion-v1-5/stable-diffusion-v1-5")
        model_path = _resolve_trained_model_dir()
        image_encoder = "openai/clip-vit-large-patch14"
        controlnet_model_name_or_path = None
        revision = None
        output_dir = "gradio_outputs"
        seed = 42
        num_validation_images = 1
        validation_ids = []
        validation_hairs = []
        use_fp16 = False
        align_before_infer = True
        align_size = 1024

    G_ARGS = _Args()

    # Import heavy libs only here
    from test_stablehairv2 import AutoTokenizer, CLIPVisionModelWithProjection, AutoencoderKL, UNet2DConditionModel
    from test_stablehairv2 import UNet3DConditionModel, CCProjection, ControlNetModel
    from omegaconf import OmegaConf

    # Config
    t0 = time.perf_counter()
    t = time.perf_counter()
    G_INFER_CONFIG = OmegaConf.load('./configs/inference/inference_v2.yaml')
    print(f"[timing:init] load infer config: {time.perf_counter()-t:.2f}s", flush=True)

    # Tokenizer / encoders / vae (CPU)
    t = time.perf_counter()
    G_TOKENIZER = AutoTokenizer.from_pretrained(G_ARGS.pretrained_model_name_or_path, subfolder="tokenizer",
                                                revision=G_ARGS.revision)
    print(f"[timing:init] tokenizer: {time.perf_counter()-t:.2f}s", flush=True)
    t = time.perf_counter()
    G_IMAGE_ENCODER = CLIPVisionModelWithProjection.from_pretrained(G_ARGS.image_encoder, revision=G_ARGS.revision)
    print(f"[timing:init] image_encoder: {time.perf_counter()-t:.2f}s", flush=True)
    t = time.perf_counter()
    G_VAE = AutoencoderKL.from_pretrained(G_ARGS.pretrained_model_name_or_path, subfolder="vae",
                                          revision=G_ARGS.revision)
    print(f"[timing:init] vae: {time.perf_counter()-t:.2f}s", flush=True)

    # UNet2D with 8-channel conv_in (CPU)
    t = time.perf_counter()
    G_UNET2 = UNet2DConditionModel.from_pretrained(
        G_ARGS.pretrained_model_name_or_path, subfolder="unet", revision=G_ARGS.revision, torch_dtype=torch.float32
    )
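    # The pretrained SD1.5 UNet has a 4-channel conv_in (latent channels). This pipeline feeds an
    # 8-channel input (presumably the noisy latent concatenated with a 4-channel conditioning
    # latent), so conv_in is widened: the original weights are copied into the first 4 input
    # channels and the new channels are zero-initialized, so the extra input is initially ignored.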
    conv_in_8 = torch.nn.Conv2d(8, G_UNET2.conv_in.out_channels, kernel_size=G_UNET2.conv_in.kernel_size,
                                padding=G_UNET2.conv_in.padding)
    conv_in_8.requires_grad_(False)
    G_UNET2.conv_in.requires_grad_(False)
    torch.nn.init.zeros_(conv_in_8.weight)
    conv_in_8.weight[:, :4, :, :].copy_(G_UNET2.conv_in.weight)
    conv_in_8.bias.copy_(G_UNET2.conv_in.bias)
    G_UNET2.conv_in = conv_in_8
    print(f"[timing:init] unet2 + conv_in adapt: {time.perf_counter()-t:.2f}s", flush=True)

    # ControlNet (CPU)
    t = time.perf_counter()
    G_CONTROLNET = ControlNetModel.from_unet(G_UNET2)
    state_dict2 = torch.load(os.path.join(G_ARGS.model_path, "pytorch_model.bin"), map_location="cpu")
    G_CONTROLNET.load_state_dict(state_dict2, strict=False)
    print(f"[timing:init] controlnet load_state: {time.perf_counter()-t:.2f}s", flush=True)

    # UNet3D (CPU)
    t = time.perf_counter()
    prefix = "motion_module"
    ckpt_num = "4140000"
    save_path = os.path.join(G_ARGS.model_path, f"{prefix}-{ckpt_num}.pth")
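    # from_pretrained_2d appears to build the 3D UNet from the 2D SD weights and load the temporal
    # (motion-module) weights from save_path; the checkpoint name matches REQUIRED_WEIGHT_FILENAMES.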
    G_DENOISING_UNET = UNet3DConditionModel.from_pretrained_2d(
        G_ARGS.pretrained_model_name_or_path,
        save_path,
        subfolder="unet",
        unet_additional_kwargs=G_INFER_CONFIG.unet_additional_kwargs,
    )
    print(f"[timing:init] unet3d from_pretrained_2d: {time.perf_counter()-t:.2f}s", flush=True)

    # CC projection (CPU)
    t = time.perf_counter()
    G_CC_PROJ = CCProjection()
    state_dict3 = torch.load(os.path.join(G_ARGS.model_path, "pytorch_model_1.bin"), map_location="cpu")
    G_CC_PROJ.load_state_dict(state_dict3, strict=False)
    print(f"[timing:init] cc_projection load_state: {time.perf_counter()-t:.2f}s", flush=True)

    # Hair encoder (CPU)
    t = time.perf_counter()
    from ref_encoder.reference_unet import ref_unet
    G_HAIR_ENCODER = ref_unet.from_pretrained(
        G_ARGS.pretrained_model_name_or_path, subfolder="unet", revision=G_ARGS.revision, low_cpu_mem_usage=False,
        device_map=None, ignore_mismatched_sizes=True
    )
    state_dict4 = torch.load(os.path.join(G_ARGS.model_path, "pytorch_model_2.bin"), map_location="cpu")
    G_HAIR_ENCODER.load_state_dict(state_dict4, strict=False)
    print(f"[timing:init] hair_encoder load_state: {time.perf_counter()-t:.2f}s", flush=True)
    print(f"[timing:init] total preload: {time.perf_counter()-t0:.2f}s", flush=True)


try:
    _load_models_cpu_once()
except Exception as _e:
    print(f"[init] Model preload warning: {_e}", flush=True)


def _ensure_models_loaded():
    """Ensure global models are loaded on CPU. If missing, try to load now; otherwise raise with a hint."""
    global G_ARGS, G_INFER_CONFIG, G_TOKENIZER, G_IMAGE_ENCODER, G_VAE
    global G_UNET2, G_CONTROLNET, G_DENOISING_UNET, G_CC_PROJ, G_HAIR_ENCODER
    if any(x is None for x in (
        G_ARGS, G_INFER_CONFIG, G_TOKENIZER, G_IMAGE_ENCODER, G_VAE,
        G_UNET2, G_CONTROLNET, G_DENOISING_UNET, G_CC_PROJ, G_HAIR_ENCODER
    )):
        print("[inference] Detected unloaded models. Loading on CPU...", flush=True)
        _load_models_cpu_once()
    if any(x is None for x in (
        G_ARGS, G_INFER_CONFIG, G_TOKENIZER, G_IMAGE_ENCODER, G_VAE,
        G_UNET2, G_CONTROLNET, G_DENOISING_UNET, G_CC_PROJ, G_HAIR_ENCODER
    )):
        raise RuntimeError(
            "Models failed to load. Check SD15_REPO (must be a valid SD1.5 repo) and weights in ./pretrain or TRAINED_MODEL_REPO."
        )

# -----------------------------------------------------------------------------
# Gradio inference
# -----------------------------------------------------------------------------
with open("imgs/background.png", "rb") as f:
    _b64_bg = base64.b64encode(f.read()).decode()


@spaces.GPU
def inference(id_image, hair_image):
    # ZeroGPU: the @spaces.GPU decorator requests a GPU for this call; force the 'cuda' device
    # explicitly (torch.cuda.is_available() may report False outside the decorated call).
    device = torch.device("cuda")
    t_total = time.perf_counter()
    # Make sure the global models are loaded.
    _ensure_models_loaded()
    # Import helpers (lightweight functions; no large models are loaded here).
    (
        log_validation,
        UNet3DConditionModel,
        ControlNetModel,
        CCProjection,
        AutoTokenizer,
        CLIPVisionModelWithProjection,
        AutoencoderKL,
        UNet2DConditionModel,
        _maybe_align_image,
        bald_head,
    ) = _import_inference_bits()

    os.makedirs("gradio_inputs", exist_ok=True)
    os.makedirs("gradio_outputs", exist_ok=True)
    id_path = "gradio_inputs/id.png"
    hair_path = "gradio_inputs/hair.png"
    id_image.save(id_path)
    hair_image.save(hair_path)

    # Align
    t = time.perf_counter()
    aligned_id = _maybe_align_image(id_path, output_size=1024, prefer_cuda=True)
    aligned_hair = _maybe_align_image(hair_path, output_size=1024, prefer_cuda=True)
    print(f"[timing] align total: {time.perf_counter()-t:.2f}s", flush=True)
    aligned_id_path = "gradio_outputs/aligned_id.png"
    aligned_hair_path = "gradio_outputs/aligned_hair.png"
    cv2.imwrite(aligned_id_path, cv2.cvtColor(aligned_id, cv2.COLOR_RGB2BGR))
    cv2.imwrite(aligned_hair_path, cv2.cvtColor(aligned_hair, cv2.COLOR_RGB2BGR))

    # Balding
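    # HairMapper's bald_head(input_path, output_path) removes the hair from the aligned identity
    # image; input and output are the same file here, so it is overwritten in place.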
    t = time.perf_counter()
    bald_id_path = "gradio_outputs/bald_id.png"
    cv2.imwrite(bald_id_path, cv2.cvtColor(aligned_id, cv2.COLOR_RGB2BGR))
    bald_head(bald_id_path, bald_id_path)
    print(f"[timing] bald_head: {time.perf_counter()-t:.2f}s", flush=True)

    # Resolve trained model dir
    trained_model_dir = os.path.abspath("trained_model") if os.path.isdir("trained_model") else None
    if trained_model_dir is None and os.path.isdir("pretrain"):
        trained_model_dir = os.path.abspath("pretrain")
    if trained_model_dir is None:
        raise RuntimeError("Missing trained model weights. Provide TRAINED_MODEL_REPO or include ./pretrain.")

    class Args:
        pretrained_model_name_or_path = SD15_PATH or os.path.abspath("stable-diffusion-v1-5/stable-diffusion-v1-5")
        model_path = trained_model_dir
        image_encoder = "openai/clip-vit-large-patch14"
        controlnet_model_name_or_path = None
        revision = None
        output_dir = "gradio_outputs"
        seed = 42
        num_validation_images = 1
        validation_ids = [aligned_id_path]
        validation_hairs = [aligned_hair_path]
        use_fp16 = False
        align_before_infer = True
        align_size = 1024

    args = Args()
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger = logging.getLogger(__name__)

    # Move the preloaded global models to the GPU.
    t = time.perf_counter()
    tokenizer = G_TOKENIZER
    image_encoder = G_IMAGE_ENCODER.to(device)
    vae = G_VAE.to(device, dtype=torch.float32)
    unet2 = G_UNET2.to(device)
    controlnet = G_CONTROLNET.to(device)
    denoising_unet = G_DENOISING_UNET.to(device)
    cc_projection = G_CC_PROJ.to(device)
    Hair_Encoder = G_HAIR_ENCODER.to(device)
    print(f"[timing] move models to cuda: {time.perf_counter()-t:.2f}s", flush=True)

    # Run inference
    t = time.perf_counter()
    log_validation(
        vae, tokenizer, image_encoder, denoising_unet,
        args, device, logger,
        cc_projection, controlnet, Hair_Encoder
    )
    print(f"[timing] sd pipeline (log_validation): {time.perf_counter()-t:.2f}s", flush=True)

    output_video = os.path.join(args.output_dir, "validation", "generated_video_0.mp4")

    # Extract frames for the slider preview
    t = time.perf_counter()
    frames_dir = os.path.join(args.output_dir, "frames", uuid.uuid4().hex)
    os.makedirs(frames_dir, exist_ok=True)
    cap = cv2.VideoCapture(output_video)
    frames_list = []
    idx = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        fp = os.path.join(frames_dir, f"{idx:03d}.png")
        cv2.imwrite(fp, frame)
        frames_list.append(fp)
        idx += 1
    cap.release()
    print(f"[timing] extract frames: {time.perf_counter()-t:.2f}s", flush=True)
    print(f"[timing] total inference: {time.perf_counter()-t_total:.2f}s", flush=True)

    max_frames = len(frames_list) if frames_list else 1
    first_frame = frames_list[0] if frames_list else None
    return (
        aligned_id_path,
        aligned_hair_path,
        bald_id_path,
        output_video,
        frames_list,
        gr.update(minimum=1, maximum=max_frames, value=1, step=1),
        first_frame,
    )
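

# Note: the tuple returned by inference() must stay in the same order as the `outputs=` list
# passed to run_btn.click(...) below (aligned id, aligned hair, bald id, video, frame list,
# slider update, first preview frame).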

# -----------------------------------------------------------------------------
# UI (Blocks)
# -----------------------------------------------------------------------------
CSS = f"""
html, body {{
    height: 100%;
    margin: 0;
    padding: 0;
}}
.gradio-container {{
    width: 100% !important;
    height: 100% !important;
    margin: 0 !important;
    padding: 0 !important;
    background-image: url("data:image/png;base64,{_b64_bg}");
    background-size: cover;
    background-position: center;
    background-attachment: fixed;
}}
#title-card {{
    background: rgba(255, 255, 255, 0.8);
    border-radius: 12px;
    padding: 16px 24px;
    box-shadow: 0 2px 8px rgba(0,0,0,0.15);
    margin-bottom: 20px;
}}
#title-card h2 {{
    text-align: center;
    margin: 4px 0 12px 0;
    font-size: 28px;
}}
#title-card p {{
    text-align: center;
    font-size: 16px;
    color: #374151;
}}
.out-card {{
    border: 1px solid #e5e7eb; border-radius: 10px; padding: 10px;
    background: rgba(255,255,255,0.85);
}}
.two-col {{
    display: grid !important; grid-template-columns: 360px minmax(680px, 1fr); gap: 16px;
}}
.left-pane {{min-width: 360px}}
.right-pane {{min-width: 680px}}
.tabs {{
    background: rgba(255,255,255,0.88);
    border-radius: 12px;
    box-shadow: 0 8px 24px rgba(0,0,0,0.08);
    padding: 8px;
    border: 1px solid #e5e7eb;
}}
.tab-nav {{
    display: flex; gap: 8px; margin-bottom: 8px;
    background: transparent;
    border-bottom: 1px solid #e5e7eb;
    padding-bottom: 6px;
}}
.tabitem {{
    background: rgba(255,255,255,0.88);
    border-radius: 10px;
    padding: 8px;
}}
#hair_gallery_wrap {{
    height: 260px !important;
    overflow-y: scroll !important;
    overflow-x: auto !important;
}}
#hair_gallery_wrap .grid, #hair_gallery_wrap .wrap {{
    height: 100% !important;
    overflow-y: scroll !important;
}}
#hair_gallery {{
    height: 100% !important;
}}
"""

with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo", neutral_hue="slate"), css=CSS) as demo:
    with gr.Group(elem_id="title-card"):
        gr.Markdown("""
        <h2 id='title'>StableHairV2 Multi-View Hairstyle Transfer</h2>
        <p>Upload an identity image and a hairstyle reference image; the pipeline runs <b>alignment → hair removal (balding) → video generation</b> automatically.</p>
        """)
    with gr.Row(elem_classes=["two-col"]):
        with gr.Column(scale=5, min_width=260, elem_classes=["left-pane"]):
            id_input = gr.Image(type="pil", label="Identity image", height=200)
            hair_input = gr.Image(type="pil", label="Hairstyle reference image", height=200)
            with gr.Row():
                run_btn = gr.Button("Generate", variant="primary")
                clear_btn = gr.Button("Clear")

            def _list_imgs(dir_path: str):
                exts = (".png", ".jpg", ".jpeg", ".webp")
                try:
                    files = [os.path.join(dir_path, f) for f in sorted(os.listdir(dir_path)) if f.lower().endswith(exts)]
                    return files
                except Exception:
                    return []

            hair_list = _list_imgs("hair_resposity")
            with gr.Accordion("Hairstyle gallery (click a style to fill the reference image)", open=True):
                with gr.Group(elem_id="hair_gallery_wrap"):
                    gallery = gr.Gallery(value=hair_list, columns=4, rows=2, allow_preview=True,
                                         label="Hairstyle gallery", elem_id="hair_gallery")

                    def _pick_hair(evt: gr.SelectData):  # type: ignore[name-defined]
                        i = evt.index if hasattr(evt, 'index') else 0
                        i = 0 if i is None else int(i)
                        if 0 <= i < len(hair_list):
                            return gr.update(value=hair_list[i])
                        return gr.update()

                    gallery.select(_pick_hair, inputs=None, outputs=hair_input)
        with gr.Column(scale=7, min_width=520, elem_classes=["right-pane"]):
            with gr.Tabs():
                with gr.TabItem("Generated Video"):
                    with gr.Group(elem_classes=["out-card"]):
                        video_out = gr.Video(label="Generated video", height=340)
                        with gr.Row():
                            frame_slider = gr.Slider(1, 21, value=1, step=1, label="Multi-view preview (drag to browse frames)")
                            frame_preview = gr.Image(type="filepath", label="Preview frame", height=260)
                        frames_state = gr.State([])
                with gr.TabItem("Alignment Results"):
                    with gr.Group(elem_classes=["out-card"]):
                        with gr.Row():
                            aligned_id_out = gr.Image(type="filepath", label="Aligned identity image", height=240)
                            aligned_hair_out = gr.Image(type="filepath", label="Aligned hairstyle image", height=240)
                with gr.TabItem("Balding Result"):
                    with gr.Group(elem_classes=["out-card"]):
                        bald_id_out = gr.Image(type="filepath", label="Bald identity image", height=260)

    run_btn.click(
        fn=inference,
        inputs=[id_input, hair_input],
        outputs=[aligned_id_out, aligned_hair_out, bald_id_out, video_out, frames_state, frame_slider, frame_preview],
    )

    def _on_slide(frames, idx):
        if not frames:
            return gr.update()
        i = int(idx) - 1
        i = max(0, min(i, len(frames) - 1))
        return gr.update(value=frames[i])

    frame_slider.change(_on_slide, inputs=[frames_state, frame_slider], outputs=frame_preview)

    def _clear():
        return None, None, None, None, None

    clear_btn.click(_clear, None, [id_input, hair_input, aligned_id_out, aligned_hair_out, bald_id_out])


if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)