import os
import time

import numpy as np
import torch
import torchvision
import torch.nn.functional as F
import trimesh
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image
from torchvision import transforms
from torchvision.transforms import v2
from diffusers import HeunDiscreteScheduler
from diffusers import FluxPipeline
from pytorch_lightning import seed_everything

from models.lrm.utils.infer_util import save_video
from models.lrm.utils.mesh_util import save_obj, save_obj_with_mtl
from models.lrm.utils.render_utils import rotate_x, rotate_y
from models.lrm.utils.train_util import instantiate_from_config
from models.lrm.utils.camera_util import get_flux_input_cameras
from models.ISOMER.reconstruction_func import reconstruction
from models.ISOMER.projection_func import projection
from utils.tool import NormalTransfer, load_mipmap
from utils.tool import get_background, get_render_cameras_video, render_frames, mask_fix

device = "cuda"
resolution = 512
save_dir = "./outputs/text2"

normal_transfer = NormalTransfer()
isomer_azimuths = torch.from_numpy(np.array([0, 90, 180, 270])).float().to(device)
isomer_elevations = torch.from_numpy(np.array([5, 5, 5, 5])).float().to(device)
isomer_radius = 4.5
isomer_geo_weights = torch.from_numpy(np.array([1, 0.9, 1, 0.9])).float().to(device)
isomer_color_weights = torch.from_numpy(np.array([1, 0.5, 1, 0.5])).float().to(device)

# model initialization and loading
# flux
flux_pipe = FluxPipeline.from_pretrained(
    "/hpc2hdd/JH_DATA/share/yingcongchen/PrivateShareGroup/yingcongchen_datasets/model_checkpoint/models--black-forest-labs--FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to(device=device, dtype=torch.bfloat16)
flux_pipe.load_lora_weights('./checkpoint/flux_lora/rgb_normal_large.safetensors')
flux_pipe.to(device=device, dtype=torch.bfloat16)
generator = torch.Generator(device=device).manual_seed(10)

# lrm
config = OmegaConf.load("./models/lrm/config/PRM_inference.yaml")
model_config = config.model_config
infer_config = config.infer_config
model = instantiate_from_config(model_config)
model_ckpt_path = "./checkpoint/lrm/final_ckpt.ckpt"
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
# keep only the LRM generator weights and strip the 'lrm_generator.' prefix
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
model.load_state_dict(state_dict, strict=True)
model = model.to(device)
model.init_flexicubes_geometry(device, fovy=50.0)
model = model.eval()


# Flux multi-view generation
def multi_view_rgb_normal_generation(prompt, save_path=None):
    # generate a 2x4 grid of multi-view images (RGB in the top row, normals in the bottom row)
    with torch.no_grad():
        image = flux_pipe(
            prompt=prompt,
            num_inference_steps=30,
            guidance_scale=3.5,
            num_images_per_prompt=1,
            width=resolution * 4,
            height=resolution * 2,
            output_type='np',
            generator=generator,
        ).images
    return image


# lrm reconstructions
def lrm_reconstructions(image, input_cameras, save_path=None, name="temp",
                        export_texmap=False, if_save_video=False):
    images = image.unsqueeze(0).to(device)
    images = v2.functional.resize(images, 512, interpolation=3, antialias=True).clamp(0, 1)
    with torch.no_grad():
        # get triplane
        planes = model.forward_planes(images, input_cameras)

        mesh_path_idx = os.path.join(save_path, f'{name}.obj')
        mesh_out = model.extract_mesh(
            planes,
            use_texture_map=export_texmap,
            **infer_config,
        )
        if export_texmap:
            vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
            save_obj_with_mtl(
                vertices.data.cpu().numpy(),
                uvs.data.cpu().numpy(),
                faces.data.cpu().numpy(),
                mesh_tex_idx.data.cpu().numpy(),
                tex_map.permute(1, 2, 0).data.cpu().numpy(),
                mesh_path_idx,
            )
        else:
            vertices, faces, vertex_colors = mesh_out
            save_obj(vertices, faces, vertex_colors, mesh_path_idx)
        print(f"Mesh saved to {mesh_path_idx}")

        render_size = 512
        if if_save_video:
            video_path_idx = os.path.join(save_path, f'{name}.mp4')
            render_size = infer_config.render_resolution
            ENV = load_mipmap("models/lrm/env_mipmap/6")
            materials = (0.0, 0.9)
            all_mv, all_mvp, all_campos = get_render_cameras_video(
                batch_size=1,
                M=240,
                radius=4.5,
                elevation=(90, 60.0),
                is_flexicubes=True,
                fov=30,
            )
            frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames(
                model,
                planes,
                render_cameras=all_mvp,
                camera_pos=all_campos,
                env=ENV,
                materials=materials,
                render_size=render_size,
                chunk_size=20,
                is_flexicubes=True,
            )
            # map normals to [0, 1] and composite onto a white background
            normals = (torch.nn.functional.normalize(normals) + 1) / 2
            normals = normals * alphas + (1 - alphas)
            all_frames = torch.cat([frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals], dim=3)
            save_video(
                all_frames,
                video_path_idx,
                fps=30,
            )
            print(f"Video saved to {video_path_idx}")
    return vertices, faces


def local_normal_global_transform(local_normal_images, azimuths_deg, elevations_deg):
    # convert per-view (camera-space) normal maps to global normals
    if local_normal_images.min() >= 0:
        local_normal = local_normal_images.float() * 2 - 1
    else:
        local_normal = local_normal_images.float()
    global_normal = normal_transfer.trans_local_2_global(
        local_normal, azimuths_deg, elevations_deg, radius=4.5, for_lotus=False)
    global_normal[..., 0] *= -1
    global_normal = (global_normal + 1) / 2
    global_normal = global_normal.permute(0, 3, 1, 2)
    return global_normal


def main(prompt="an owl wearing a hat."):
    fix_prompt = 'a grid of 2x4 multi-view image. elevation 5. white background.'
    # user prompt
    save_dir_path = os.path.join(save_dir, prompt.split(".")[0].replace(" ", "_"))
    os.makedirs(save_dir_path, exist_ok=True)
    prompt = fix_prompt + " " + prompt

    # generate multi-view images
    rgb_normal_grid = multi_view_rgb_normal_generation(prompt)

    # lrm reconstructions
    images = torch.from_numpy(rgb_normal_grid).squeeze(0).permute(2, 0, 1).contiguous().float()  # (3, 1024, 2048)
    images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=2, m=4)  # (8, 3, 512, 512)
    rgb_multi_view = images[:4, :3, :, :]
    normal_multi_view = images[4:, :3, :, :]
    multi_view_mask = get_background(normal_multi_view)
    # composite the RGB views onto a white background using the foreground mask
    rgb_multi_view = rgb_multi_view * multi_view_mask + (1 - multi_view_mask)
    input_cameras = get_flux_input_cameras(batch_size=1, radius=4.2, fov=30).to(device)
    vertices, faces = lrm_reconstructions(rgb_multi_view, input_cameras, save_path=save_dir_path,
                                          name='lrm', export_texmap=False, if_save_video=False)

    # local normal to global normal
    global_normal = local_normal_global_transform(normal_multi_view.permute(0, 2, 3, 1),
                                                  isomer_azimuths, isomer_elevations)
    global_normal = global_normal * multi_view_mask + (1 - multi_view_mask)

    global_normal = global_normal.permute(0, 2, 3, 1)
    rgb_multi_view = rgb_multi_view.permute(0, 2, 3, 1)
    multi_view_mask = multi_view_mask.permute(0, 2, 3, 1).squeeze(-1)

    vertices = torch.from_numpy(vertices).to(device)
    faces = torch.from_numpy(faces).to(device)
    # rotate the reconstructed mesh 90 degrees about X, then 90 degrees about Y, before passing it to ISOMER
    vertices = vertices @ rotate_x(np.pi / 2, device=vertices.device)[:3, :3]
    vertices = vertices @ rotate_y(np.pi / 2, device=vertices.device)[:3, :3]

    # global_normal: B,H,W,3
    # multi_view_mask: B,H,W
    # rgb_multi_view: B,H,W,3
    multi_view_mask_proj = mask_fix(multi_view_mask, erode_dilate=-6, blur=5)

    meshes = reconstruction(
        normal_pils=global_normal,
        masks=multi_view_mask,
        weights=isomer_geo_weights,
        fov=30,
        radius=isomer_radius,
        camera_angles_azi=isomer_azimuths,
        camera_angles_ele=isomer_elevations,
        expansion_weight_stage1=0.1,
        init_type="file",
        init_verts=vertices,
        init_faces=faces,
        stage1_steps=0,
        stage2_steps=50,
        start_edge_len_stage1=0.1,
        end_edge_len_stage1=0.02,
        start_edge_len_stage2=0.02,
        end_edge_len_stage2=0.005,
    )

    multi_view_mask_proj = mask_fix(multi_view_mask, erode_dilate=-10, blur=5)

    save_glb_addr = projection(
        meshes,
        masks=multi_view_mask_proj,
        images=rgb_multi_view,
        azimuths=isomer_azimuths,
        elevations=isomer_elevations,
        weights=isomer_color_weights,
        fov=30,
        radius=isomer_radius,
        save_dir=f"{save_dir_path}/ISOMER/",
    )

    print(f'saved to {save_glb_addr}')


if __name__ == '__main__':
    start_time = time.time()
    prompts = [
        "A red dragon soaring",
        "A running Chihuahua",
        "A dancing rabbit",
        "A girl with blue hair and white dress",
        "A teacher",
        "A tiger playing guitar",
        "A red rose",
        "A red peony",
        "A rose in a vase",
        "A golden retriever sitting",
        "A golden retriever running",
    ]
    for prompt in prompts:
        main(prompt)
    end_time = time.time()
    print(f"Time taken: {end_time - start_time:.2f} seconds for {len(prompts)} prompts")