import os |
from einops import rearrange |
from omegaconf import OmegaConf |
import torch |
import numpy as np |
import trimesh |
import torchvision |
import torch.nn.functional as F |
from PIL import Image |
from torchvision import transforms |
from torchvision.transforms import v2 |
from diffusers import HeunDiscreteScheduler |
from diffusers import FluxPipeline |
from pytorch_lightning import seed_everything |
import os |
import time |
from models.lrm.utils.infer_util import save_video |
from models.lrm.utils.mesh_util import save_obj, save_obj_with_mtl |
from models.lrm.utils.render_utils import rotate_x, rotate_y |
from models.lrm.utils.train_util import instantiate_from_config |
from models.lrm.utils.camera_util import get_flux_input_cameras |
from models.ISOMER.reconstruction_func import reconstruction |
from models.ISOMER.projection_func import projection |
from utils.tool import NormalTransfer, load_mipmap |
from utils.tool import get_background, get_render_cameras_video, render_frames, mask_fix |
# ---------------------------------------------------------------------------
# Global settings shared by every stage of the pipeline.
# ---------------------------------------------------------------------------
device = "cuda"
# Base size in pixels; the FLUX grid is requested at (resolution*4) x (resolution*2).
resolution = 512
save_dir = "./outputs/text2"
normal_transfer = NormalTransfer()

# ---------------------------------------------------------------------------
# ISOMER camera setup: four views, one entry per view in every tensor below.
# ---------------------------------------------------------------------------
isomer_azimuths = torch.tensor([0, 90, 180, 270], dtype=torch.float32, device=device)
isomer_elevations = torch.tensor([5, 5, 5, 5], dtype=torch.float32, device=device)
isomer_radius = 4.5
# Per-view weights handed to the geometry (reconstruction) and color (projection) stages.
isomer_geo_weights = torch.tensor([1, 0.9, 1, 0.9], dtype=torch.float32, device=device)
isomer_color_weights = torch.tensor([1, 0.5, 1, 0.5], dtype=torch.float32, device=device)

# ---------------------------------------------------------------------------
# FLUX text-to-image pipeline with the RGB+normal LoRA.
# ---------------------------------------------------------------------------
flux_pipe = FluxPipeline.from_pretrained(
    "/hpc2hdd/JH_DATA/share/yingcongchen/PrivateShareGroup/yingcongchen_datasets/model_checkpoint/models--black-forest-labs--FLUX.1-dev",
    torch_dtype=torch.bfloat16,
)
flux_pipe = flux_pipe.to(device=device, dtype=torch.bfloat16)
flux_pipe.load_lora_weights('./checkpoint/flux_lora/rgb_normal_large.safetensors')
# Move/cast again after the LoRA weights have been attached.
flux_pipe.to(device=device, dtype=torch.bfloat16)

# One generator is shared by all calls, so successive prompts continue the
# same seeded random stream rather than restarting from seed 10 each time.
generator = torch.Generator(device=device)
generator.manual_seed(10)

# ---------------------------------------------------------------------------
# LRM reconstruction model.
# ---------------------------------------------------------------------------
config = OmegaConf.load("./models/lrm/config/PRM_inference.yaml")
model_config = config.model_config
infer_config = config.infer_config
model_ckpt_path = "./checkpoint/lrm/final_ckpt.ckpt"
model = instantiate_from_config(model_config)

# NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
# Keep only the 'lrm_generator.' entries and strip that prefix from their keys.
state_dict = {
    key[len('lrm_generator.'):]: weight
    for key, weight in state_dict.items()
    if key.startswith('lrm_generator.')
}
model.load_state_dict(state_dict, strict=True)

model = model.to(device)
model.init_flexicubes_geometry(device, fovy=50.0)
model = model.eval()
def multi_view_rgb_normal_generation(prompt, save_path=None):
    """Generate one image grid for ``prompt`` with the module-level FLUX pipeline.

    Args:
        prompt: Full text prompt passed to ``flux_pipe``.
        save_path: Accepted for interface compatibility; not used here.

    Returns:
        The ``images`` attribute of the pipeline output, requested with
        ``output_type='np'`` at ``resolution*4`` wide by ``resolution*2`` high.
    """
    grid_width = resolution * 4
    grid_height = resolution * 2

    # Inference only: no autograd bookkeeping is needed for sampling.
    with torch.no_grad():
        pipe_output = flux_pipe(
            prompt=prompt,
            num_inference_steps=30,
            guidance_scale=3.5,
            num_images_per_prompt=1,
            width=grid_width,
            height=grid_height,
            output_type='np',
            generator=generator,
        )
    return pipe_output.images
def lrm_reconstructions(image, input_cameras, save_path=None, name="temp", export_texmap=False, if_save_video=False):
    """Reconstruct a mesh from multi-view images with the LRM model and save it.

    Writes ``<save_path>/<name>.obj`` and, when ``if_save_video`` is true,
    ``<save_path>/<name>.mp4``.

    Args:
        image: Multi-view image tensor; a leading batch dimension is added
            here before it is resized and passed to the model.
        input_cameras: Camera tensor forwarded to ``model.forward_planes``.
        save_path: Output directory. ``None`` falls back to the module-level
            ``save_dir``. The directory is created if it does not exist.
        name: Base file name (without extension) for the outputs.
        export_texmap: If true, extract a UV-textured mesh and save it with
            an MTL file; otherwise save a vertex-colored OBJ.
        if_save_video: If true, also render and save a turntable-style video.

    Returns:
        ``(vertices, faces)`` exactly as produced by ``model.extract_mesh``.
        NOTE(review): the two branches appear to differ in type. The texmap
        branch calls ``.data.cpu().numpy()`` on them (tensors), while the
        caller in this file applies ``torch.from_numpy`` to the vertex-color
        result (presumably numpy arrays) - confirm before relying on either.
    """
    # The default used to reach os.path.join(None, ...) and raise TypeError.
    if save_path is None:
        save_path = save_dir
    os.makedirs(save_path, exist_ok=True)

    images = image.unsqueeze(0).to(device)
    # interpolation=3 is the Pillow integer code for bicubic resampling.
    images = v2.functional.resize(images, 512, interpolation=3, antialias=True).clamp(0, 1)

    with torch.no_grad():
        planes = model.forward_planes(images, input_cameras)

        mesh_path_idx = os.path.join(save_path, f'{name}.obj')
        mesh_out = model.extract_mesh(
            planes,
            use_texture_map=export_texmap,
            **infer_config,
        )

        if export_texmap:
            vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
            save_obj_with_mtl(
                vertices.data.cpu().numpy(),
                uvs.data.cpu().numpy(),
                faces.data.cpu().numpy(),
                mesh_tex_idx.data.cpu().numpy(),
                # Move the first axis of the texture map to the end for saving.
                tex_map.permute(1, 2, 0).data.cpu().numpy(),
                mesh_path_idx,
            )
        else:
            vertices, faces, vertex_colors = mesh_out
            save_obj(vertices, faces, vertex_colors, mesh_path_idx)
        print(f"Mesh saved to {mesh_path_idx}")

        if if_save_video:
            video_path_idx = os.path.join(save_path, f'{name}.mp4')
            render_size = infer_config.render_resolution
            ENV = load_mipmap("models/lrm/env_mipmap/6")
            # Presumably (metallic, roughness) - TODO confirm against render_frames.
            materials = (0.0, 0.9)

            all_mv, all_mvp, all_campos = get_render_cameras_video(
                batch_size=1,
                M=240,
                radius=4.5,
                elevation=(90, 60.0),
                is_flexicubes=True,
                fov=30,
            )

            frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames(
                model,
                planes,
                render_cameras=all_mvp,
                camera_pos=all_campos,
                env=ENV,
                materials=materials,
                render_size=render_size,
                chunk_size=20,
                is_flexicubes=True,
            )

            # Map unit normals from [-1, 1] to [0, 1], then composite over
            # white wherever alpha is zero.
            normals = (torch.nn.functional.normalize(normals) + 1) / 2
            normals = normals * alphas + (1 - alphas)

            # Concatenate every render pass along dim 3 into one frame.
            all_frames = torch.cat([frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals], dim=3)
            save_video(
                all_frames,
                video_path_idx,
                fps=30,
            )
            print(f"Video saved to {video_path_idx}")

    return vertices, faces
def local_normal_global_transform(local_normal_images, azimuths_deg, elevations_deg):
    """Convert camera-local normal images to global normals encoded in [0, 1].

    Args:
        local_normal_images: Normal images with the channel axis last (the
            result is permuted with ``(0, 3, 1, 2)``). If the minimum value
            is non-negative the input is treated as [0, 1]-encoded and is
            rescaled to [-1, 1]; otherwise it is used as-is.
        azimuths_deg: Per-view azimuths passed to the normal transfer.
        elevations_deg: Per-view elevations passed to the normal transfer.

    Returns:
        The transformed normals, x component negated, remapped with
        ``(n + 1) / 2`` and permuted to put the last axis at position 1.
    """
    normals = local_normal_images.float()
    if local_normal_images.min() >= 0:
        # Input looks [0, 1]-encoded: decode to signed normals first.
        normals = normals * 2 - 1

    world_normals = normal_transfer.trans_local_2_global(
        normals,
        azimuths_deg,
        elevations_deg,
        radius=4.5,
        for_lotus=False,
    )
    # Flip the x component in place.
    world_normals[..., 0] *= -1

    encoded = (world_normals + 1) / 2
    return encoded.permute(0, 3, 1, 2)
def main(prompt="a owl wearing a hat."):
    """Run the full text-to-3D pipeline for one prompt.

    Steps: generate a 2x4 RGB/normal grid with FLUX, reconstruct a coarse
    mesh with the LRM model, refine it with ISOMER ``reconstruction`` using
    the normal maps, then texture it with ISOMER ``projection``.

    Outputs go to ``<save_dir>/<slug>``, where ``<slug>`` is the prompt text
    up to its first ``.`` with spaces replaced by underscores.

    Args:
        prompt: Object description; a fixed multi-view prefix is prepended
            before it is sent to the image generator.
    """
    fix_prompt = 'a grid of 2x4 multi-view image. elevation 5. white background.'

    save_dir_path = os.path.join(save_dir, prompt.split(".")[0].replace(" ", "_"))
    os.makedirs(save_dir_path, exist_ok=True)
    prompt = fix_prompt + " " + prompt

    # --- Stage 1: multi-view RGB + normal generation -----------------------
    rgb_normal_grid = multi_view_rgb_normal_generation(prompt)
    images = torch.from_numpy(rgb_normal_grid).squeeze(0).permute(2, 0, 1).contiguous().float()
    # Split the 2x4 grid into 8 tiles in row-major order:
    # the first row holds the RGB views, the second row the normal maps.
    images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=2, m=4)
    rgb_multi_view = images[:4, :3, :, :]
    normal_multi_view = images[4:, :3, :, :]

    multi_view_mask = get_background(normal_multi_view)
    # Composite the RGB views over white using the mask. The previous code
    # multiplied rgb_multi_view by itself instead of by the mask, which
    # squared the colors and left the background unmasked.
    rgb_multi_view = rgb_multi_view * multi_view_mask + (1 - multi_view_mask)

    # --- Stage 2: coarse mesh from the LRM model ---------------------------
    input_cameras = get_flux_input_cameras(batch_size=1, radius=4.2, fov=30).to(device)
    vertices, faces = lrm_reconstructions(
        rgb_multi_view,
        input_cameras,
        save_path=save_dir_path,
        name='lrm',
        export_texmap=False,
        if_save_video=False,
    )

    # --- Stage 3: prepare ISOMER inputs ------------------------------------
    global_normal = local_normal_global_transform(
        normal_multi_view.permute(0, 2, 3, 1), isomer_azimuths, isomer_elevations
    )
    global_normal = global_normal * multi_view_mask + (1 - multi_view_mask)

    # ISOMER is fed channel-last images and a mask without a channel axis.
    global_normal = global_normal.permute(0, 2, 3, 1)
    rgb_multi_view = rgb_multi_view.permute(0, 2, 3, 1)
    multi_view_mask = multi_view_mask.permute(0, 2, 3, 1).squeeze(-1)

    vertices = torch.from_numpy(vertices).to(device)
    faces = torch.from_numpy(faces).to(device)
    # Rotate the LRM mesh by pi/2 with rotate_x, then rotate_y; presumably this
    # aligns the LRM frame with the ISOMER camera frame - TODO confirm.
    vertices = vertices @ rotate_x(np.pi / 2, device=vertices.device)[:3, :3]
    vertices = vertices @ rotate_y(np.pi / 2, device=vertices.device)[:3, :3]

    # NOTE(review): this result is overwritten below before it is ever read.
    # The call is kept because mask_fix cannot be confirmed side-effect free
    # from this file; drop it once that is verified.
    multi_view_mask_proj = mask_fix(multi_view_mask, erode_dilate=-6, blur=5)

    # --- Stage 4: ISOMER geometry refinement (stage 1 skipped) -------------
    meshes = reconstruction(
        normal_pils=global_normal,
        masks=multi_view_mask,
        weights=isomer_geo_weights,
        fov=30,
        radius=isomer_radius,
        camera_angles_azi=isomer_azimuths,
        camera_angles_ele=isomer_elevations,
        expansion_weight_stage1=0.1,
        init_type="file",
        init_verts=vertices,
        init_faces=faces,
        stage1_steps=0,
        stage2_steps=50,
        start_edge_len_stage1=0.1,
        end_edge_len_stage1=0.02,
        start_edge_len_stage2=0.02,
        end_edge_len_stage2=0.005,
    )

    # --- Stage 5: ISOMER color projection ----------------------------------
    multi_view_mask_proj = mask_fix(multi_view_mask, erode_dilate=-10, blur=5)
    save_glb_addr = projection(
        meshes,
        masks=multi_view_mask_proj,
        images=rgb_multi_view,
        azimuths=isomer_azimuths,
        elevations=isomer_elevations,
        weights=isomer_color_weights,
        fov=30,
        radius=isomer_radius,
        save_dir=f"{save_dir_path}/ISOMER/",
    )
    print(f'saved to {save_glb_addr}')
if __name__ == '__main__':
    # Batch driver: run the pipeline for each prompt and report wall time.
    # `time` comes from the file-level import; perf_counter is monotonic, so
    # the measurement is not disturbed by system clock adjustments.
    start_time = time.perf_counter()
    prompts = [
        "A red dragon soaring",
        "A running Chihuahua",
        "A dancing rabbit",
        "A girl with blue hair and white dress",
        "A teacher",
        "A tiger playing guitar",
        "A red rose",
        "A red peony",
        "A rose in a vase",
        "A golden retriever sitting",
        "A golden retriever running",
    ]
    for prompt in prompts:
        main(prompt)
    elapsed = time.perf_counter() - start_time
    print(f"Time taken: {elapsed:.2f} seconds for {len(prompts)} prompts")
    # The trailing breakpoint() was removed: it dropped every run into the
    # debugger after completion, which hangs non-interactive/batch jobs.