|
import os
import time

import numpy as np
import torch
import torch.nn.functional as F
import torchvision
import trimesh
import rembg
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image
from torchvision import transforms
from torchvision.transforms import v2
from transformers import AutoProcessor, AutoModelForCausalLM
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler, HeunDiscreteScheduler
from diffusers import FluxPipeline, FluxControlNetImg2ImgPipeline
from diffusers.models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
from pytorch_lightning import seed_everything

from models.ISOMER.reconstruction_func import reconstruction
from models.ISOMER.projection_func import projection
from models.lrm.utils.infer_util import remove_background, resize_foreground, save_video
from models.lrm.utils.mesh_util import save_obj, save_obj_with_mtl
from models.lrm.utils.render_utils import rotate_x, rotate_y
from models.lrm.utils.train_util import instantiate_from_config
from models.lrm.utils.camera_util import get_zero123plus_input_cameras, get_custom_zero123plus_input_cameras, get_flux_input_cameras
from utils.tool import NormalTransfer, get_render_cameras_frames, load_mipmap
from utils.tool import get_background, get_render_cameras_video, render_frames
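
# Pipeline overview:
#   1. Remove the background from a single input image and caption it with Florence-2.
#   2. Zero123++ generates a 2x2 grid of novel views.
#   3. LRM reconstructs a coarse FlexiCubes mesh from those views and renders
#      RGB/normal images at the ISOMER camera poses.
#   4. Flux (img2img + tile ControlNet + LoRA) refines a 2x4 RGB/normal bundle image.
#   5. ISOMER refines the geometry and projects the refined views back onto the
#      mesh, exporting the final textured asset.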
|
|
|
device = "cuda" |
|
resolution = 512 |
|
save_dir = "./outputs" |
|
zero123plus_diffusion_steps = 75 |
|
normal_transfer = NormalTransfer() |
|
rembg_session = rembg.new_session() |
|
isomer_azimuths = torch.from_numpy(np.array([270, 0, 90, 180])).to(device) |
|
isomer_elevations = torch.from_numpy(np.array([5, 5, 5, 5])).to(device) |
|
isomer_radius = 4.1 |
|
isomer_geo_weights = torch.from_numpy(np.array([1, 0.9, 1, 0.9])).float().to(device) |
|
isomer_color_weights = torch.from_numpy(np.array([1, 0.5, 1, 0.5])).float().to(device) |
|
|
|
|
|
|
|
|
|
print('==> Loading Flux model ...') |
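# FLUX.1-dev with the ControlNet-Union-Pro checkpoint, used as an img2img
# refiner over the LRM-rendered RGB/normal bundle image (paths point to local
# checkpoint locations).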
|
flux_base_model_pth = "/hpc2hdd/JH_DATA/share/yingcongchen/PrivateShareGroup/yingcongchen_datasets/model_checkpoint/models--black-forest-labs--FLUX.1-dev" |
|
flux_controlnet = FluxControlNetModel.from_pretrained("/hpc2hdd/JH_DATA/share/yingcongchen/PrivateShareGroup/yingcongchen_datasets/model_checkpoint/flux_controlnets/FLUX.1-dev-ControlNet-Union-Pro") |
|
flux_pipe = FluxControlNetImg2ImgPipeline.from_pretrained(flux_base_model_pth, controlnet=[flux_controlnet], torch_dtype=torch.bfloat16).to(device=device, dtype=torch.bfloat16) |
|
|
|
flux_pipe.load_lora_weights('./checkpoint/flux_lora/rgb_normal_large.safetensors') |
|
|
|
|
|
flux_pipe.to(device=device, dtype=torch.bfloat16) |
|
generator = torch.Generator(device=device).manual_seed(0) |
|
|
|
|
|
print('==> Loading LRM model ...') |
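# PRM/LRM sparse-view reconstruction model; the checkpoint stores weights under
# a 'lrm_generator.' prefix, which is stripped before loading.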
|
config = OmegaConf.load("./models/lrm/config/PRM_inference.yaml") |
|
model_config = config.model_config |
|
infer_config = config.infer_config |
|
model = instantiate_from_config(model_config) |
|
model_ckpt_path = "./checkpoint/lrm/final_ckpt.ckpt" |
|
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict'] |
|
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')} |
|
model.load_state_dict(state_dict, strict=True) |
|
|
|
model = model.to(device) |
|
model.init_flexicubes_geometry(device, fovy=50.0) |
|
model = model.eval() |
|
|
|
|
|
print('==> Loading diffusion model ...') |
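# Zero123++ v1.2 with a custom local pipeline and a fine-tuned UNet checkpoint;
# the scheduler uses trailing timestep spacing.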
|
zero123plus_pipeline = DiffusionPipeline.from_pretrained( |
|
"sudo-ai/zero123plus-v1.2", |
|
custom_pipeline="./models/zero123plus", |
|
torch_dtype=torch.float16, |
|
) |
|
zero123plus_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config( |
|
zero123plus_pipeline.scheduler.config, timestep_spacing='trailing' |
|
) |
|
unet_ckpt_path = "./checkpoint/zero123++/flexgen_19w.ckpt" |
|
state_dict = torch.load(unet_ckpt_path, map_location='cpu')['state_dict'] |
|
state_dict = {k[10:]: v for k, v in state_dict.items() if k.startswith('unet.unet.')} |
|
zero123plus_pipeline.unet.load_state_dict(state_dict, strict=True) |
|
zero123plus_pipeline = zero123plus_pipeline.to(device) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
caption_model = AutoModelForCausalLM.from_pretrained( |
|
"/hpc2hdd/home/jlin695/.cache/huggingface/hub/models--multimodalart--Florence-2-large-no-flash-attn/snapshots/8db3793cf5b453b2ccfb3a4f613b403b2e6b7ca2", torch_dtype=torch.bfloat16, trust_remote_code=True, |
|
).to(device) |
|
caption_processor = AutoProcessor.from_pretrained("/hpc2hdd/home/jlin695/.cache/huggingface/hub/models--multimodalart--Florence-2-large-no-flash-attn/snapshots/8db3793cf5b453b2ccfb3a4f613b403b2e6b7ca2", trust_remote_code=True) |
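# The Florence-2 captioner (no-flash-attn snapshot) above produces a detailed
# caption of the input image, which is appended to the Flux prompt in main().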
|
|
|
|
|
def multi_view_rgb_normal_generation_with_controlnet(prompt, image, strength=1.0, |
|
control_image=[], |
|
control_mode=[], |
|
control_guidance_start=None, |
|
control_guidance_end=None, |
|
controlnet_conditioning_scale=None, |
|
lora_scale=1.0 |
|
): |
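    """
    Refine a multi-view RGB/normal bundle image with the Flux ControlNet
    img2img pipeline and return the generated images (numpy arrays, since
    output_type='np').
    """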
|
control_mode_dict = { |
|
'canny': 0, |
|
'tile': 1, |
|
'depth': 2, |
|
'blur': 3, |
|
'pose': 4, |
|
'gray': 5, |
|
'lq': 6, |
|
} |
|
|
|
hparam_dict = { |
|
'prompt': prompt, |
|
'image': image, |
|
'strength': strength, |
|
'num_inference_steps': 30, |
|
'guidance_scale': 3.5, |
|
'num_images_per_prompt': 1, |
|
'width': resolution*4, |
|
'height': resolution*2, |
|
'output_type': 'np', |
|
'generator': generator, |
|
'joint_attention_kwargs': {"scale": lora_scale} |
|
} |
|
|
|
|
|
if len(control_image) > 0: |
|
assert len(control_mode) == len(control_image) |
|
|
|
ctrl_hparams = { |
|
'control_mode': [control_mode_dict[mode_] for mode_ in control_mode], |
|
'control_image': control_image, |
|
'control_guidance_start': control_guidance_start or [0.0 for i in range(len(control_image))], |
|
'control_guidance_end': control_guidance_end or [1.0 for i in range(len(control_image))], |
|
'controlnet_conditioning_scale': controlnet_conditioning_scale or [1.0 for i in range(len(control_image))], |
|
} |
|
|
|
hparam_dict.update(ctrl_hparams) |
|
|
|
|
|
with torch.no_grad(): |
|
image = flux_pipe( |
|
**hparam_dict |
|
).images |
|
return image |
|
|
|
|
|
def run_captioning(image): |
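    """Caption an image (path or PIL.Image) with Florence-2 using the <MORE_DETAILED_CAPTION> task."""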
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
torch_dtype = torch.bfloat16 |
|
|
|
if isinstance(image, str): |
|
image = Image.open(image).convert("RGB") |
|
|
|
prompt = "<MORE_DETAILED_CAPTION>" |
|
inputs = caption_processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype) |
|
|
|
|
|
generated_ids = caption_model.generate( |
|
input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, num_beams=3 |
|
) |
|
|
|
generated_text = caption_processor.batch_decode(generated_ids, skip_special_tokens=False)[0] |
|
parsed_answer = caption_processor.post_process_generation( |
|
generated_text, task=prompt, image_size=(image.width, image.height) |
|
) |
|
|
|
caption_text = parsed_answer["<MORE_DETAILED_CAPTION>"].replace("The image is ", "") |
|
return caption_text |
|
|
|
|
|
|
|
def multi_view_rgb_generation(cond_img): |
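    """Generate a 2x2 grid of novel views (resolution*2 square) from a single conditioning image with Zero123++."""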
|
|
|
with torch.no_grad(): |
|
output_image = zero123plus_pipeline( |
|
cond_img, |
|
num_inference_steps=zero123plus_diffusion_steps, |
|
width=resolution*2, |
|
height=resolution*2, |
|
).images[0] |
|
return output_image |
|
|
|
|
|
def lrm_reconstructions(image, input_cameras, save_path=None, name="temp", export_texmap=False, if_save_video=False, render_azimuths=None, render_elevations=None, render_radius=None, render_fov=30): |
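    """
    Reconstruct a coarse mesh from multi-view images with the LRM model.

    Extracts a FlexiCubes mesh from the predicted planes and saves it as an OBJ
    (optionally with a texture map). Optionally renders an orbit video and/or
    frames at the given azimuths/elevations; returns the mesh vertices and faces
    together with the rendered normals, frames and albedos (None if no render
    poses are given).
    """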
|
images = image.unsqueeze(0).to(device) |
|
images = v2.functional.resize(images, 512, interpolation=3, antialias=True).clamp(0, 1) |
|
|
|
with torch.no_grad(): |
|
|
|
planes = model.forward_planes(images, input_cameras) |
|
|
|
mesh_path_idx = os.path.join(save_path, f'{name}.obj') |
|
|
|
mesh_out = model.extract_mesh( |
|
planes, |
|
use_texture_map=export_texmap, |
|
**infer_config, |
|
) |
|
if export_texmap: |
|
vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out |
|
save_obj_with_mtl( |
|
vertices.data.cpu().numpy(), |
|
uvs.data.cpu().numpy(), |
|
faces.data.cpu().numpy(), |
|
mesh_tex_idx.data.cpu().numpy(), |
|
tex_map.permute(1, 2, 0).data.cpu().numpy(), |
|
mesh_path_idx, |
|
) |
|
else: |
|
vertices, faces, vertex_colors = mesh_out |
|
save_obj(vertices, faces, vertex_colors, mesh_path_idx) |
|
print(f"Mesh saved to {mesh_path_idx}") |
|
|
|
render_size = 512 |
|
if if_save_video: |
|
video_path_idx = os.path.join(save_path, f'{name}.mp4') |
|
render_size = infer_config.render_resolution |
|
ENV = load_mipmap("models/lrm/env_mipmap/6") |
|
materials = (0.0,0.9) |
|
|
|
all_mv, all_mvp, all_campos = get_render_cameras_video( |
|
batch_size=1, |
|
M=240, |
|
radius=4.5, |
|
elevation=(90, 60.0), |
|
is_flexicubes=True, |
|
fov=30 |
|
) |
|
|
|
frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames( |
|
model, |
|
planes, |
|
render_cameras=all_mvp, |
|
camera_pos=all_campos, |
|
env=ENV, |
|
materials=materials, |
|
render_size=render_size, |
|
chunk_size=20, |
|
is_flexicubes=True, |
|
) |
|
normals = (torch.nn.functional.normalize(normals) + 1) / 2 |
|
normals = normals * alphas + (1-alphas) |
|
all_frames = torch.cat([frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals], dim=3) |
|
|
|
|
|
save_video( |
|
all_frames, |
|
video_path_idx, |
|
fps=30, |
|
) |
|
print(f"Video saved to {video_path_idx}") |
|
|
|
if render_azimuths is not None and render_elevations is not None and render_radius is not None: |
|
render_size = infer_config.render_resolution |
|
ENV = load_mipmap("models/lrm/env_mipmap/6") |
|
materials = (0.0,0.9) |
|
all_mv, all_mvp, all_campos, identity_mv = get_render_cameras_frames( |
|
batch_size=1, |
|
radius=render_radius, |
|
azimuths=render_azimuths, |
|
elevations=render_elevations, |
|
fov=30 |
|
) |
|
frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames( |
|
model, |
|
planes, |
|
render_cameras=all_mvp, |
|
camera_pos=all_campos, |
|
env=ENV, |
|
materials=materials, |
|
render_size=render_size, |
|
render_mv = all_mv, |
|
local_normal=True, |
|
identity_mv=identity_mv, |
|
) |
|
else: |
|
normals = None |
|
frames = None |
|
albedos = None |
|
|
|
return vertices, faces, normals, frames, albedos |
|
|
|
|
|
def transform_normal(input_normal, azimuths_deg, elevations_deg, radius=4.5, is_global_to_local=False): |
|
""" |
|
input_normal: in range [-1, 1], shape (b c h w) |
|
""" |
|
|
|
input_normal = input_normal.permute(0, 2, 3, 1).cpu() |
|
|
|
azimuths_deg = np.array(azimuths_deg) |
|
elevations_deg = np.array(elevations_deg) |
|
|
|
if is_global_to_local: |
|
local_normal = normal_transfer.trans_global_2_local(input_normal, azimuths_deg, elevations_deg) |
|
return local_normal.permute(0, 3, 1, 2) |
|
else: |
|
global_normal = normal_transfer.trans_local_2_global(input_normal, azimuths_deg, elevations_deg, radius=radius, for_lotus=False) |
|
global_normal[..., 0] *= -1 |
|
return global_normal.permute(0, 3, 1, 2) |
|
|
|
def local_normal_global_transform(local_normal_images,azimuths_deg,elevations_deg): |
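    """Map per-view normal images (in [0, 1] or [-1, 1]) to world-space normals in [0, 1], returned as (b, c, h, w)."""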
|
if local_normal_images.min() >= 0: |
|
local_normal = local_normal_images.float() * 2 - 1 |
|
else: |
|
local_normal = local_normal_images.float() |
|
global_normal = normal_transfer.trans_local_2_global(local_normal, azimuths_deg, elevations_deg, radius=4.5, for_lotus=False) |
|
global_normal[...,0] *= -1 |
|
global_normal = (global_normal + 1) / 2 |
|
global_normal = global_normal.permute(0, 3, 1, 2) |
|
return global_normal |
|
|
|
def main(): |
|
image_pth = "examples/蓝色小怪物.webp" |
|
    save_dir_path = os.path.join(save_dir, os.path.splitext(os.path.basename(image_pth))[0])
|
os.makedirs(save_dir_path, exist_ok=True) |
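    # Load the input image, remove its background, and rescale the foreground.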
|
input_image = Image.open(image_pth) |
|
|
|
input_image = remove_background(input_image, rembg_session) |
|
input_image = resize_foreground(input_image, 0.85) |
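    # Caption the original image with Florence-2; the caption conditions the Flux prompt later.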
|
|
|
|
|
image_caption = run_captioning(image_pth) |
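    # Stage 1: Zero123++ generates a 2x2 grid of novel views, split into four (c, h, w) tensors.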
|
|
|
|
|
output_image = multi_view_rgb_generation(input_image) |
|
|
|
|
|
rgb_multi_view = np.asarray(output_image, dtype=np.float32) / 255.0 |
|
rgb_multi_view = torch.from_numpy(rgb_multi_view).squeeze(0).permute(2, 0, 1).contiguous().float() |
|
rgb_multi_view = rearrange(rgb_multi_view, 'c (n h) (m w) -> (n m) c h w', n=2, m=2) |
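    # Stage 2: coarse LRM reconstruction from the four views, rendering normals at the ISOMER camera poses.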
|
|
|
input_cameras = get_custom_zero123plus_input_cameras(batch_size=1, radius=3.5, fov=30).to(device) |
|
|
|
vertices, faces, lrm_multi_view_normals, lrm_multi_view_rgb, lrm_multi_view_albedo = \ |
|
lrm_reconstructions(rgb_multi_view, input_cameras, save_path=save_dir_path, name='lrm', |
|
export_texmap=False, if_save_video=False, render_azimuths=isomer_azimuths, |
|
render_elevations=isomer_elevations, render_radius=isomer_radius, render_fov=30) |
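    # Move the mesh to the GPU and rotate it (presumably to align with the coordinate frame used downstream).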
|
|
|
vertices = torch.from_numpy(vertices).to(device) |
|
faces = torch.from_numpy(faces).to(device) |
|
vertices = vertices @ rotate_x(np.pi / 2, device=vertices.device)[:3, :3] |
|
vertices = vertices @ rotate_y(np.pi / 2, device=vertices.device)[:3, :3] |
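    # Assemble a 2x4 bundle image: re-ordered generated RGB views on top, LRM normals below.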
|
|
|
|
|
|
|
lrm_3D_bundle_image = torchvision.utils.make_grid(torch.cat([rgb_multi_view[[3,0,1,2]].cpu(), (lrm_multi_view_normals.cpu() + 1) / 2], dim=0), nrow=4, padding=0).unsqueeze(0) |
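    # Stage 3: refine the bundle with Flux img2img guided by a 'tile' ControlNet
    # (ControlNet guidance ends at 30% of the denoising steps).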
|
|
|
|
|
|
|
|
|
|
|
|
|
control_image = [lrm_3D_bundle_image * 2 - 1] |
|
control_mode = ['tile'] |
|
control_guidance_start = [0.0] |
|
control_guidance_end = [0.3] |
|
controlnet_conditioning_scale = [0.8] |
|
|
|
flux_pipe.controlnet = FluxMultiControlNetModel([flux_controlnet for _ in control_mode]) |
|
|
|
rgb_normal_grid = multi_view_rgb_normal_generation_with_controlnet( |
|
prompt= ' '.join(['A grid of 2x4 multi-view image, elevation 5. White background.', image_caption]), |
|
image=lrm_3D_bundle_image, |
|
strength=0.6, |
|
control_image=control_image, |
|
control_mode=control_mode, |
|
control_guidance_start=control_guidance_start, |
|
control_guidance_end=control_guidance_end, |
|
controlnet_conditioning_scale=controlnet_conditioning_scale, |
|
lora_scale=1.0 |
|
) |
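    # Split the refined 2x4 grid into four RGB views and four normal maps, and
    # composite them over a white background using the normal-derived mask.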
|
|
|
rgb_normal_grid = torch.from_numpy(rgb_normal_grid).contiguous().float() |
|
    rgb_normal_grid = rearrange(rgb_normal_grid.squeeze(0), '(n h) (m w) c -> (n m) c h w', n=2, m=4)
|
rgb_multi_view = rgb_normal_grid[:4, :3, :, :].cuda() |
|
normal_multi_view = rgb_normal_grid[4:, :3, :, :].cuda() |
|
multi_view_mask = get_background(normal_multi_view).cuda() |
|
rgb_multi_view = rgb_multi_view * multi_view_mask + (1-multi_view_mask) |
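    # Lift the per-view normals to world space for ISOMER.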
|
|
|
|
|
global_normal = local_normal_global_transform(normal_multi_view.permute(0, 2, 3, 1).cpu(), isomer_azimuths, isomer_elevations).cuda() |
|
|
|
global_normal = global_normal * multi_view_mask + (1-multi_view_mask) |
|
|
|
global_normal = global_normal.permute(0,2,3,1) |
|
multi_view_mask = multi_view_mask.squeeze(1) |
|
rgb_multi_view = rgb_multi_view.permute(0,2,3,1) |
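    # Stage 4: ISOMER geometry refinement initialized from the LRM mesh, followed by
    # texture projection of the refined views onto the final mesh.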
|
|
|
|
|
|
|
|
|
|
|
meshes = reconstruction( |
|
normal_pils=global_normal, |
|
masks=multi_view_mask, |
|
weights=isomer_geo_weights, |
|
fov=30, |
|
radius=isomer_radius, |
|
camera_angles_azi=isomer_azimuths, |
|
camera_angles_ele=isomer_elevations, |
|
expansion_weight_stage1=0.1, |
|
init_type="file", |
|
init_verts=vertices, |
|
init_faces=faces, |
|
stage1_steps=0, |
|
stage2_steps=50, |
|
start_edge_len_stage1=0.1, |
|
end_edge_len_stage1=0.02, |
|
start_edge_len_stage2=0.02, |
|
end_edge_len_stage2=0.005, |
|
) |
|
|
|
save_glb_addr = projection( |
|
meshes=meshes, |
|
masks=multi_view_mask, |
|
images=rgb_multi_view, |
|
azimuths=isomer_azimuths, |
|
elevations=isomer_elevations, |
|
weights=isomer_color_weights, |
|
fov=30, |
|
radius=isomer_radius, |
|
save_dir=f"{save_dir_path}/ISOMER/", |
|
) |
|
print(f'saved to {save_glb_addr}') |
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
main() |
|
|