#!/usr/bin/env python3
import argparse
import logging
import sys
import os
import random

import numpy as np
import cv2
import torch
from PIL import Image
from transformers import AutoTokenizer, CLIPVisionModelWithProjection
from diffusers import AutoencoderKL, UniPCMultistepScheduler, UNet2DConditionModel
from omegaconf import OmegaConf

from src.models.unet_3d import UNet3DConditionModel
from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline as Hair3dPipeline
from src.utils.util import save_videos_grid
from ref_encoder.reference_unet import CCProjection
from ref_encoder.latent_controlnet import ControlNetModel
from HairMapper.hair_mapper_run import bald_head
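
# Example invocation (the script name and all paths below are placeholders, not taken from the repo):
#   python infer.py \
#       --pretrained_model_name_or_path <base_pipeline_dir> \
#       --image_encoder <clip_vision_encoder_dir> \
#       --model_path <trained_checkpoint_dir> \
#       --validation_ids id.png \
#       --validation_hairs hair.png \
#       --output_dir inference_output --use_fp16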


# face align
def _maybe_align_image(image_path: str, output_size: int, prefer_cuda: bool = True):
    """Align and crop a face image to FFHQ style using FFHQFaceAlignment if available.

    Falls back to a simple resize if alignment fails.
    Returns an RGB uint8 numpy array of shape (H, W, 3).
    """
    try:
        ffhq_dir = os.path.join(os.path.dirname(__file__), 'FFHQFaceAlignment')
        if ffhq_dir not in sys.path:
            sys.path.insert(0, ffhq_dir)
        # Lazy imports to avoid a hard dependency if the user doesn't enable alignment
        from lib.landmarks_pytorch import LandmarksEstimation
        from align import align_crop_image

        # Read image as RGB uint8
        img_bgr = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if img_bgr is None:
            raise RuntimeError(f"Failed to read image: {image_path}")
        img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB).astype('uint8')

        device = torch.device('cuda' if prefer_cuda and torch.cuda.is_available() else 'cpu')
        le = LandmarksEstimation(type='2D')
        img_tensor = torch.tensor(np.transpose(img, (2, 0, 1))).float().to(device)
        with torch.no_grad():
            landmarks, _ = le.detect_landmarks(img_tensor.unsqueeze(0), detected_faces=None)
        if len(landmarks) > 0:
            lm = np.asarray(landmarks[0].detach().cpu().numpy())
            aligned = align_crop_image(image=img, landmarks=lm, transform_size=output_size)
            if aligned is None or aligned.size == 0:
                return cv2.resize(img, (output_size, output_size))
            return aligned
        else:
            return cv2.resize(img, (output_size, output_size))
    except Exception:
        # Silent fallback to a simple resize on any failure
        img_bgr = cv2.imread(image_path, cv2.IMREAD_COLOR)
        img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB).astype('uint8') if img_bgr is not None else None
        if img is None:
            raise
        return cv2.resize(img, (output_size, output_size))


def log_validation(
    vae, tokenizer, image_encoder, denoising_unet,
    args, device, logger, cc_projection,
    controlnet, hair_encoder, feature_extractor=None
):
    """
    Run inference on validation pairs and save the generated videos.
    """
    logger.info("Starting validation inference...")

    # Initialize inference pipeline
    pipeline = Hair3dPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        image_encoder=image_encoder,
        feature_extractor=feature_extractor,
        controlnet=controlnet,
        vae=vae,
        tokenizer=tokenizer,
        denoising_unet=denoising_unet,
        safety_checker=None,
        revision=args.revision,
        torch_dtype=torch.float16 if args.use_fp16 else torch.float32,
    ).to(device)
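    # Swap in the UniPC multistep scheduler for faster sampling with fewer denoising steps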
    pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
    pipeline.set_progress_bar_config(disable=True)

    # Create output directory
    output_dir = os.path.join(args.output_dir, "validation")
    os.makedirs(output_dir, exist_ok=True)
    logger.info(f"Saving validation outputs to {output_dir}")

    # Speed/length overrides via environment variables, falling back to args
    steps = int(os.getenv('SH_STEPS', getattr(args, 'num_inference_steps', 30)))
    gscale = float(os.getenv('SH_GUIDANCE', getattr(args, 'guidance_scale', 1.5)))
    vlen = int(os.getenv('SH_VIDEO_LENGTH', getattr(args, 'video_length', 21)))
    # Unified temporal length: the context window always equals the number of video frames
    # (SH_CONTEXT_FRAMES is no longer read)
    cframes = vlen
    # Generate a camera trajectory with exactly vlen frames
    angles = np.linspace(0, 2 * np.pi, vlen, endpoint=False)
    X = 0.4 * np.sin(angles)
    Y = -0.05 + 0.3 * np.cos(angles)
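    # X/Y sweep one full period of an ellipse centered at (0, -0.05); each frame gets one (x, y) camera offset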
    x_tensor = torch.tensor(X, dtype=torch.float32).unsqueeze(1).to(device)
    y_tensor = torch.tensor(Y, dtype=torch.float32).unsqueeze(1).to(device)
    # Load reference images (optionally aligned to FFHQ style)
    align_enabled = getattr(args, 'align_before_infer', True)
    align_size = getattr(args, 'align_size', 1024)
    prefer_cuda = device.type == 'cuda'
    if align_enabled:
        id_image = _maybe_align_image(args.validation_ids[0], output_size=align_size, prefer_cuda=prefer_cuda)
    else:
        id_image = cv2.cvtColor(cv2.imread(args.validation_ids[0]), cv2.COLOR_BGR2RGB)
        id_image = cv2.resize(id_image, (512, 512))
    # ===== Use HairMapper to produce a bald version of the identity image =====
    temp_bald_path = os.path.join(args.output_dir, "bald_id.png")
    cv2.imwrite(temp_bald_path, cv2.cvtColor(id_image, cv2.COLOR_RGB2BGR))  # save the identity image
    bald_head(temp_bald_path, temp_bald_path)  # bald-head processing, overwritten in place
    # Reload the bald image (RGB)
    id_image = cv2.cvtColor(cv2.imread(temp_bald_path), cv2.COLOR_BGR2RGB)
    id_image = cv2.resize(id_image, (512, 512))
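    # Repeat the bald identity image once per output frame; this list is used as the ControlNet condition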
    id_list = [id_image for _ in range(vlen)]

    if align_enabled:
        hair_image = _maybe_align_image(args.validation_hairs[0], output_size=align_size, prefer_cuda=prefer_cuda)
        prompt_img = _maybe_align_image(args.validation_ids[0], output_size=align_size, prefer_cuda=prefer_cuda)
    else:
        hair_image = cv2.cvtColor(cv2.imread(args.validation_hairs[0]), cv2.COLOR_BGR2RGB)
        prompt_img = cv2.cvtColor(cv2.imread(args.validation_ids[0]), cv2.COLOR_BGR2RGB)
    # Both branches are brought to the 512x512 resolution expected by the pipeline
    hair_image = cv2.resize(hair_image, (512, 512))
    prompt_img = cv2.resize(prompt_img, (512, 512))
    prompt_img = [prompt_img]
    # Perform inference and save videos
    for idx in range(args.num_validation_images):
        result = pipeline(
            prompt="",
            negative_prompt="",
            num_inference_steps=steps,
            guidance_scale=gscale,
            width=512,
            height=512,
            controlnet_condition=id_list,
            controlnet_conditioning_scale=1.0,
            generator=torch.Generator(device).manual_seed(args.seed),
            ref_image=hair_image,
            prompt_img=prompt_img,
            reference_encoder=hair_encoder,
            poses=None,
            x=x_tensor,
            y=y_tensor,
            video_length=vlen,
            context_frames=cframes,
        )
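        # Duplicate the generated clip along the batch dimension before writing the grid video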
        video = torch.cat([result.videos, result.videos], dim=0)
        video_path = os.path.join(output_dir, f"generated_video_{idx}.mp4")
        save_videos_grid(video, video_path, n_rows=5, fps=24)
        logger.info(f"Saved generated video: {video_path}")


def parse_args():
    parser = argparse.ArgumentParser(
        description="Inference script for 3D hairstyle generation"
    )
    parser.add_argument(
        "--pretrained_model_name_or_path", type=str, required=True,
        help="Path or ID of the pretrained pipeline"
    )
    parser.add_argument(
        "--model_path", type=str, required=True,
        help="Directory containing the trained checkpoint weights (pytorch_model*.bin, motion_module-*.pth)"
    )
    parser.add_argument(
        "--image_encoder", type=str, required=True,
        help="Path or ID of the CLIP vision encoder"
    )
    parser.add_argument(
        "--controlnet_model_name_or_path", type=str, default=None,
        help="Path or ID of the ControlNet model"
    )
    parser.add_argument(
        "--revision", type=str, default=None,
        help="Model revision or Git reference"
    )
    parser.add_argument(
        "--output_dir", type=str, default="inference_output",
        help="Directory to save inference results"
    )
    parser.add_argument(
        "--seed", type=int, default=42,
        help="Random seed for reproducibility"
    )
    parser.add_argument(
        "--num_validation_images", type=int, default=3,
        help="Number of videos to generate per input pair"
    )
    parser.add_argument(
        "--validation_ids", type=str, nargs='+', required=True,
        help="Path(s) to identity conditioning images"
    )
    parser.add_argument(
        "--validation_hairs", type=str, nargs='+', required=True,
        help="Path(s) to hairstyle reference images"
    )
    parser.add_argument(
        "--use_fp16", action="store_true",
        help="Enable fp16 inference"
    )
    parser.add_argument(
        "--align_before_infer", action=argparse.BooleanOptionalAction, default=True,
        help="Align and crop input images to FFHQ style before inference (pass --no-align_before_infer to disable)"
    )
    parser.add_argument(
        "--align_size", type=int, default=1024,
        help="Output size for aligned images when alignment is enabled"
    )
    return parser.parse_args()


def main():
    args = parse_args()

    # Setup device and logger
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger = logging.getLogger(__name__)
    # Set random seeds for reproducibility
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if device.type == "cuda":
        torch.cuda.manual_seed_all(args.seed)
    # Load models
    tokenizer = AutoTokenizer.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="tokenizer",
        revision=args.revision
    )
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        args.image_encoder,
        revision=args.revision
    ).to(device)
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="vae",
        revision=args.revision
    ).to(device)
    infer_config = OmegaConf.load('./configs/inference/inference_v2.yaml')

    unet2 = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", use_safetensors=True, revision=args.revision,
        torch_dtype=torch.float16
    ).to(device)
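    # Expand the UNet input conv from 4 to 8 latent channels: the pretrained 4-channel weights
    # are copied into the first half and the remaining channels are zero-initialized, so the
    # extra input channels start out as a no-op.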
    conv_in_8 = torch.nn.Conv2d(8, unet2.conv_in.out_channels, kernel_size=unet2.conv_in.kernel_size,
                                padding=unet2.conv_in.padding)
    conv_in_8.requires_grad_(False)
    unet2.conv_in.requires_grad_(False)
    torch.nn.init.zeros_(conv_in_8.weight)
    conv_in_8.weight[:, :4, :, :].copy_(unet2.conv_in.weight)
    conv_in_8.bias.copy_(unet2.conv_in.bias)
    unet2.conv_in = conv_in_8
    # Load or initialize ControlNet
    controlnet = ControlNetModel.from_unet(unet2).to(device)
    state_dict2 = torch.load(os.path.join(args.model_path, "pytorch_model.bin"), map_location=torch.device('cpu'))
    controlnet.load_state_dict(state_dict2, strict=False)
    # Load 3D UNet motion module
    prefix = "motion_module"
    ckpt_num = "4140000"
    save_path = os.path.join(args.model_path, f"{prefix}-{ckpt_num}.pth")
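    # from_pretrained_2d builds the 3D UNet from the base 2D UNet weights and merges in the
    # temporal motion-module weights from save_path (assumed behaviour of the repo helper)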
    denoising_unet = UNet3DConditionModel.from_pretrained_2d(
        args.pretrained_model_name_or_path,
        save_path,
        subfolder="unet",
        unet_additional_kwargs=infer_config.unet_additional_kwargs,
    ).to(device)

    # Load projection and hair encoder
    cc_projection = CCProjection().to(device)
    state_dict3 = torch.load(os.path.join(args.model_path, "pytorch_model_1.bin"), map_location=torch.device('cpu'))
    cc_projection.load_state_dict(state_dict3, strict=False)
    from ref_encoder.reference_unet import ref_unet
    Hair_Encoder = ref_unet.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False,
        device_map=None, ignore_mismatched_sizes=True
    ).to(device)
    state_dict2 = torch.load(os.path.join(args.model_path, "pytorch_model_2.bin"), map_location=torch.device('cpu'))
    Hair_Encoder.load_state_dict(state_dict2, strict=False)
    # Run validation inference
    log_validation(
        vae, tokenizer, image_encoder, denoising_unet,
        args, device, logger,
        cc_projection, controlnet, Hair_Encoder
    )


if __name__ == "__main__":
    main()