import base64
import logging
import os
import uuid

# Redirect Gradio's temp files (and TMPDIR in general) to the data partition.
# These are set before Gradio is imported.
os.environ.setdefault("GRADIO_TEMP_DIR", "/data2/lzliu/tmp/gradio")
os.environ.setdefault("TMPDIR", "/data2/lzliu/tmp")
os.makedirs("/data2/lzliu/tmp/gradio", exist_ok=True)
os.makedirs("/data2/lzliu/tmp", exist_ok=True)

import cv2
import gradio as gr
import numpy as np
import torch
from omegaconf import OmegaConf

from test_stablehairv2 import log_validation
from test_stablehairv2 import UNet3DConditionModel, ControlNetModel, CCProjection
from test_stablehairv2 import AutoTokenizer, CLIPVisionModelWithProjection, AutoencoderKL, UNet2DConditionModel
from test_stablehairv2 import _maybe_align_image
from HairMapper.hair_mapper_run import bald_head

# Embed the page background as base64 so the CSS below does not depend on a served file.
with open("imgs/background.jpg", "rb") as f:
    b64_img = base64.b64encode(f.read()).decode()

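# Files and directories this script expects (paths referenced in the code below):
#   ./stable-diffusion-v1-5/stable-diffusion-v1-5/   base Stable Diffusion 1.5 weights
#   ./trained_model/pytorch_model.bin                ControlNet weights
#   ./trained_model/pytorch_model_1.bin              CCProjection weights
#   ./trained_model/pytorch_model_2.bin              hair reference encoder (ref_unet) weights
#   ./trained_model/motion_module-4140000.pth        motion module checkpoint
#   ./configs/inference/inference_v2.yaml            UNet3D inference config
#   ./imgs/background.jpg                            page background image
#   ./hair_resposity/                                hairstyle gallery images
#
# Pipeline implemented by `inference`:
#   1. Save the two uploaded PIL images to disk.
#   2. Face-align both images with `_maybe_align_image`.
#   3. Remove the hair from the aligned identity image with HairMapper (`bald_head`).
#   4. Load the StableHairV2 models and run `log_validation` to render a multi-view video.
#   5. Split the generated video into frames for the draggable per-frame preview.
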
def inference(id_image, hair_image):
    os.makedirs("gradio_inputs", exist_ok=True)
    os.makedirs("gradio_outputs", exist_ok=True)
    id_path = "gradio_inputs/id.png"
    hair_path = "gradio_inputs/hair.png"
    id_image.save(id_path)
    hair_image.save(hair_path)

    # ===== Face alignment =====
    aligned_id = _maybe_align_image(id_path, output_size=1024, prefer_cuda=True)
    aligned_hair = _maybe_align_image(hair_path, output_size=1024, prefer_cuda=True)

    # Save the aligned results so Gradio can display them.
    aligned_id_path = "gradio_outputs/aligned_id.png"
    aligned_hair_path = "gradio_outputs/aligned_hair.png"
    cv2.imwrite(aligned_id_path, cv2.cvtColor(aligned_id, cv2.COLOR_RGB2BGR))
    cv2.imwrite(aligned_hair_path, cv2.cvtColor(aligned_hair, cv2.COLOR_RGB2BGR))

    # ===== HairMapper: remove the hair from the identity image (in place) =====
    bald_id_path = "gradio_outputs/bald_id.png"
    cv2.imwrite(bald_id_path, cv2.cvtColor(aligned_id, cv2.COLOR_RGB2BGR))
    bald_head(bald_id_path, bald_id_path)

    # ===== Args (same fields as the original test script) =====
    class Args:
        pretrained_model_name_or_path = "./stable-diffusion-v1-5/stable-diffusion-v1-5"
        model_path = "./trained_model"
        image_encoder = "openai/clip-vit-large-patch14"
        controlnet_model_name_or_path = None
        revision = None
        output_dir = "gradio_outputs"
        seed = 42
        num_validation_images = 1
        validation_ids = [aligned_id_path]      # use the aligned identity image
        validation_hairs = [aligned_hair_path]  # use the aligned hair image
        use_fp16 = False

    args = Args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize the logger
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger = logging.getLogger(__name__)

    # ===== Model loading (kept in sync with main()) =====
    tokenizer = AutoTokenizer.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision)
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        args.image_encoder, revision=args.revision).to(device)
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
    ).to(device, dtype=torch.float32)
    infer_config = OmegaConf.load('./configs/inference/inference_v2.yaml')

    unet2 = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision,
        torch_dtype=torch.float32
    ).to(device)
    # Widen conv_in from 4 to 8 input channels: the first 4 channels keep the pretrained
    # weights, the extra 4 are zero-initialized for the additional conditioning latents.
    conv_in_8 = torch.nn.Conv2d(8, unet2.conv_in.out_channels,
                                kernel_size=unet2.conv_in.kernel_size,
                                padding=unet2.conv_in.padding)
    conv_in_8.requires_grad_(False)
    unet2.conv_in.requires_grad_(False)
    torch.nn.init.zeros_(conv_in_8.weight)
    conv_in_8.weight[:, :4, :, :].copy_(unet2.conv_in.weight)
    conv_in_8.bias.copy_(unet2.conv_in.bias)
    unet2.conv_in = conv_in_8

    controlnet = ControlNetModel.from_unet(unet2).to(device)
    state_dict2 = torch.load(os.path.join(args.model_path, "pytorch_model.bin"),
                             map_location="cpu")
    controlnet.load_state_dict(state_dict2, strict=False)

    prefix = "motion_module"
    ckpt_num = "4140000"
    save_path = os.path.join(args.model_path, f"{prefix}-{ckpt_num}.pth")
    denoising_unet = UNet3DConditionModel.from_pretrained_2d(
        args.pretrained_model_name_or_path, save_path, subfolder="unet",
        unet_additional_kwargs=infer_config.unet_additional_kwargs,
    ).to(device)

    cc_projection = CCProjection().to(device)
    state_dict3 = torch.load(os.path.join(args.model_path, "pytorch_model_1.bin"),
                             map_location="cpu")
    cc_projection.load_state_dict(state_dict3, strict=False)

    from ref_encoder.reference_unet import ref_unet
    Hair_Encoder = ref_unet.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision,
        low_cpu_mem_usage=False, device_map=None, ignore_mismatched_sizes=True
    ).to(device)
    state_dict2 = torch.load(os.path.join(args.model_path, "pytorch_model_2.bin"),
                             map_location="cpu")
    Hair_Encoder.load_state_dict(state_dict2, strict=False)

    # Run inference
    log_validation(
        vae, tokenizer, image_encoder, denoising_unet, args, device, logger,
        cc_projection, controlnet, Hair_Encoder
    )
    output_video = os.path.join(args.output_dir, "validation", "generated_video_0.mp4")

    # Extract the video frames for the draggable multi-view preview.
    frames_dir = os.path.join(args.output_dir, "frames", uuid.uuid4().hex)
    os.makedirs(frames_dir, exist_ok=True)
    cap = cv2.VideoCapture(output_video)
    frames_list = []
    idx = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        fp = os.path.join(frames_dir, f"{idx:03d}.png")
        cv2.imwrite(fp, frame)
        frames_list.append(fp)
        idx += 1
    cap.release()

    max_frames = len(frames_list) if frames_list else 1
    first_frame = frames_list[0] if frames_list else None

    return (aligned_id_path, aligned_hair_path, bald_id_path, output_video, frames_list,
            gr.update(minimum=1, maximum=max_frames, value=1, step=1), first_frame)

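# Example of calling `inference` directly, without the UI (hypothetical paths,
# shown for illustration only):
#
#   from PIL import Image
#   (aligned_id, aligned_hair, bald_id, video_path,
#    frames, slider_update, first_frame) = inference(
#       Image.open("examples/id.png"), Image.open("examples/hair.png"))
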
# Gradio front end
# Original gr.Interface version (kept for easy rollback)
# demo = gr.Interface(
#     fn=inference,
#     inputs=[
#         gr.Image(type="pil", label="Identity (ID) image"),
#         gr.Image(type="pil", label="Hair reference image")
#     ],
#     outputs=[
#         gr.Image(type="filepath", label="Aligned identity image"),
#         gr.Image(type="filepath", label="Aligned hair image"),
#         gr.Image(type="filepath", label="Bald identity image"),
#         gr.Video(label="Generated video")
#     ],
#     title="StableHairV2 Multi-View Hairstyle Transfer",
#     description="Upload an identity image and a hair reference image to view the alignment results and generate a multi-view video"
# )
# if __name__ == "__main__":
#     demo.launch(server_name="0.0.0.0", server_port=7860)

# Styled gr.Blocks version
css = f"""
html, body {{ height: 100%; margin: 0; padding: 0; }}
.gradio-container {{
    width: 100% !important;
    height: 100% !important;
    margin: 0 !important;
    padding: 0 !important;
    background-image: url("data:image/jpeg;base64,{b64_img}");
    background-size: cover;
    background-position: center;
    background-attachment: fixed;  /* keep the background fixed while scrolling */
}}
#title-card {{
    background: rgba(255, 255, 255, 0.8);
    border-radius: 12px;
    padding: 16px 24px;
    box-shadow: 0 2px 8px rgba(0,0,0,0.15);
    margin-bottom: 20px;
}}
#title-card h2 {{ text-align: center; margin: 4px 0 12px 0; font-size: 28px; }}
#title-card p {{ text-align: center; font-size: 16px; color: #374151; }}
.out-card {{ border: 1px solid #e5e7eb; border-radius: 10px; padding: 10px; background: rgba(255,255,255,0.85); }}
.two-col {{ display: grid !important; grid-template-columns: 360px minmax(680px, 1fr); gap: 16px; }}
.left-pane {{ min-width: 360px; }}
.right-pane {{ min-width: 680px; }}
/* Tab styling */
.tabs {{
    background: rgba(255,255,255,0.88);
    border-radius: 12px;
    box-shadow: 0 8px 24px rgba(0,0,0,0.08);
    padding: 8px;
    border: 1px solid #e5e7eb;
}}
.tab-nav {{
    display: flex;
    gap: 8px;
    margin-bottom: 8px;
    background: transparent;
    border-bottom: 1px solid #e5e7eb;
    padding-bottom: 6px;
}}
.tab-nav button {{
    background: rgba(255,255,255,0.7);
    border: 1px solid #e5e7eb;
    backdrop-filter: blur(6px);
    border-radius: 8px;
    padding: 6px 12px;
    color: #111827;
    transition: all .2s ease;
}}
.tab-nav button:hover {{ transform: translateY(-1px); box-shadow: 0 4px 10px rgba(0,0,0,0.06); }}
.tab-nav button[aria-selected="true"] {{
    background: #4f46e5;
    color: #fff;
    border-color: #4f46e5;
    box-shadow: 0 6px 14px rgba(79,70,229,0.25);
}}
.tabitem {{ background: rgba(255,255,255,0.88); border-radius: 10px; padding: 8px; }}
/* Hairstyle gallery container: fixed 260px height, scrollable inside */
#hair_gallery_wrap {{
    height: 260px !important;
    overflow-y: scroll !important;
    overflow-x: auto !important;
}}
#hair_gallery_wrap .grid, #hair_gallery_wrap .wrap {{
    height: 100% !important;
    overflow-y: scroll !important;
}}
/* Make the gallery fill its container so the scrollbar stays inside it, not at the page bottom */
#hair_gallery {{ height: 100% !important; }}
"""

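# UI layout: a two-column Blocks page. The left pane holds the identity/hair inputs,
# the Generate/Clear buttons and a scrollable hairstyle gallery; the right pane holds
# three tabs (generated video with per-frame preview, alignment results, bald result).
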
with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="indigo", neutral_hue="slate"),
    css=css
) as demo:
    # ==== Title panel ====
    with gr.Group(elem_id="title-card"):
        gr.Markdown("""
## StableHairV2 Multi-View Hairstyle Transfer

Upload an identity image and a hair reference image; the system will automatically run alignment → hair removal → video generation.
        """)

    with gr.Row(elem_classes=["two-col"]):
        with gr.Column(scale=5, min_width=260, elem_classes=["left-pane"]):
            id_input = gr.Image(type="pil", label="Identity image", height=200)
            hair_input = gr.Image(type="pil", label="Hair reference image", height=200)
            with gr.Row():
                run_btn = gr.Button("Generate", variant="primary")
                clear_btn = gr.Button("Clear")

            # ========= Hairstyle gallery (clicking an item fills the hair reference input) =========
            def _list_imgs(dir_path: str):
                exts = (".png", ".jpg", ".jpeg", ".webp")
                try:
                    return [os.path.join(dir_path, f)
                            for f in sorted(os.listdir(dir_path))
                            if f.lower().endswith(exts)]
                except Exception:
                    return []

            hair_list = _list_imgs("hair_resposity")
            with gr.Accordion("Hairstyle gallery (click to auto-fill)", open=True):
                with gr.Group(elem_id="hair_gallery_wrap"):
                    gallery = gr.Gallery(
                        value=hair_list, columns=4, rows=2, allow_preview=True,
                        label="Hairstyle gallery", elem_id="hair_gallery"
                    )

            def _pick_hair(evt: gr.SelectData):
                i = evt.index if hasattr(evt, "index") else 0
                i = 0 if i is None else int(i)
                if 0 <= i < len(hair_list):
                    return gr.update(value=hair_list[i])
                return gr.update()

            gallery.select(_pick_hair, inputs=None, outputs=hair_input)

        with gr.Column(scale=7, min_width=520, elem_classes=["right-pane"]):
            with gr.Tabs():
                with gr.TabItem("Generated video"):
                    with gr.Group(elem_classes=["out-card"]):
                        video_out = gr.Video(label="Generated video", height=340)
                        with gr.Row():
                            frame_slider = gr.Slider(1, 21, value=1, step=1,
                                                     label="Multi-view preview (drag to browse frames)")
                        frame_preview = gr.Image(type="filepath", label="Preview frame", height=260)
                        frames_state = gr.State([])
                with gr.TabItem("Alignment results"):
                    with gr.Group(elem_classes=["out-card"]):
                        with gr.Row():
                            aligned_id_out = gr.Image(type="filepath", label="Aligned identity image", height=240)
                            aligned_hair_out = gr.Image(type="filepath", label="Aligned hair image", height=240)
                with gr.TabItem("Bald result"):
                    with gr.Group(elem_classes=["out-card"]):
                        bald_id_out = gr.Image(type="filepath", label="Bald identity image", height=260)

    # ==== Event wiring (logic unchanged) ====
    run_btn.click(
        fn=inference,
        inputs=[id_input, hair_input],
        outputs=[aligned_id_out, aligned_hair_out, bald_id_out, video_out,
                 frames_state, frame_slider, frame_preview]
    )

    def _on_slide(frames, idx):
        if not frames:
            return gr.update()
        i = int(idx) - 1
        i = max(0, min(i, len(frames) - 1))
        return gr.update(value=frames[i])

    frame_slider.change(_on_slide, inputs=[frames_state, frame_slider], outputs=frame_preview)

    def _clear():
        return None, None, None, None, None

    clear_btn.click(_clear, None, [id_input, hair_input, aligned_id_out, aligned_hair_out, bald_id_out])


if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)