diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..9eec9927fa53a74357665122cc46805291eab763 --- /dev/null +++ b/app.py @@ -0,0 +1,319 @@ +import gradio as gr +import os +import subprocess +import shlex +subprocess.run( + shlex.split( + "pip install ./extension/nvdiffrast-0.3.1.torch-cp310-cp310-linux_x86_64.whl --force-reinstall --no-deps" + ) +) + +subprocess.run( + shlex.split( + "pip install ./extension/renderutils_plugin-1.0-cp310-cp310-linux_x86_64.whl --force-reinstall --no-deps" + ) +) +import torch +import numpy as np +from PIL import Image +from einops import rearrange +from diffusers import FluxPipeline +from models.lrm.utils.camera_util import get_flux_input_cameras +from models.lrm.utils.infer_util import save_video +from models.lrm.utils.mesh_util import save_obj, save_obj_with_mtl +from models.lrm.utils.render_utils import rotate_x, rotate_y +from models.lrm.utils.train_util import instantiate_from_config +from models.ISOMER.reconstruction_func import reconstruction +from models.ISOMER.projection_func import projection +import os +from einops import rearrange +from omegaconf import OmegaConf +import spaces +import torch +import numpy as np +import trimesh +import torchvision +import torch.nn.functional as F +from PIL import Image +from torchvision import transforms +from torchvision.transforms import v2 +from diffusers import HeunDiscreteScheduler +from diffusers import FluxPipeline +from pytorch_lightning import seed_everything +import os +from huggingface_hub import hf_hub_download + + +from utils.tool import NormalTransfer, get_background, get_render_cameras_video, load_mipmap, render_frames + +device = "cuda" +resolution = 512 +save_dir = "./outputs" +normal_transfer = NormalTransfer() +isomer_azimuths = torch.from_numpy(np.array([0, 90, 180, 270])).float().to(device) +isomer_elevations = torch.from_numpy(np.array([5, 5, 5, 5])).float().to(device) +isomer_radius = 4.5 +isomer_geo_weights = torch.from_numpy(np.array([1, 0.9, 1, 0.9])).float().to(device) +isomer_color_weights = torch.from_numpy(np.array([1, 0.5, 1, 0.5])).float().to(device) + +# model initialization and loading +# flux +flux_pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to(device=device, dtype=torch.bfloat16) +flux_lora_ckpt_path = hf_hub_download(repo_id="LTT/xxx-ckpt", filename="rgb_normal_large.safetensors", repo_type="model") +flux_pipe.load_lora_weights(flux_lora_ckpt_path) + +flux_pipe.to(device=device, dtype=torch.bfloat16) +generator = torch.Generator(device=device).manual_seed(10) + +# lrm +config = OmegaConf.load("./models/lrm/config/PRM_inference.yaml") +model_config = config.model_config +infer_config = config.infer_config +model = instantiate_from_config(model_config) +model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt", repo_type="model") +state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict'] +state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')} +model.load_state_dict(state_dict, strict=True) + +model = model.to(device) +model.init_flexicubes_geometry(device, fovy=50.0) +model = model.eval() + +@spaces.GPU +def lrm_reconstructions(image, input_cameras, save_path=None, name="temp", export_texmap=False, if_save_video=False): + images = image.unsqueeze(0).to(device) + images = v2.functional.resize(images, 512, interpolation=3, antialias=True).clamp(0, 1) + # breakpoint() + with torch.no_grad(): + # get triplane + planes = model.forward_planes(images, input_cameras) + + mesh_path_idx = os.path.join(save_path, f'{name}.obj') + + mesh_out = model.extract_mesh( + planes, + use_texture_map=export_texmap, + **infer_config, + ) + if export_texmap: + vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out + save_obj_with_mtl( + vertices.data.cpu().numpy(), + uvs.data.cpu().numpy(), + faces.data.cpu().numpy(), + mesh_tex_idx.data.cpu().numpy(), + tex_map.permute(1, 2, 0).data.cpu().numpy(), + mesh_path_idx, + ) + else: + vertices, faces, vertex_colors = mesh_out + save_obj(vertices, faces, vertex_colors, mesh_path_idx) + print(f"Mesh saved to {mesh_path_idx}") + + render_size = 512 + if if_save_video: + video_path_idx = os.path.join(save_path, f'{name}.mp4') + render_size = infer_config.render_resolution + ENV = load_mipmap("models/lrm/env_mipmap/6") + materials = (0.0,0.9) + + all_mv, all_mvp, all_campos = get_render_cameras_video( + batch_size=1, + M=240, + radius=4.5, + elevation=(90, 60.0), + is_flexicubes=True, + fov=30 + ) + + frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames( + model, + planes, + render_cameras=all_mvp, + camera_pos=all_campos, + env=ENV, + materials=materials, + render_size=render_size, + chunk_size=20, + is_flexicubes=True, + ) + normals = (torch.nn.functional.normalize(normals) + 1) / 2 + normals = normals * alphas + (1-alphas) + all_frames = torch.cat([frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals], dim=3) + + save_video( + all_frames, + video_path_idx, + fps=30, + ) + print(f"Video saved to {video_path_idx}") + + return vertices, faces + + +def local_normal_global_transform(local_normal_images, azimuths_deg, elevations_deg): + if local_normal_images.min() >= 0: + local_normal = local_normal_images.float() * 2 - 1 + else: + local_normal = local_normal_images.float() + global_normal = normal_transfer.trans_local_2_global(local_normal, azimuths_deg, elevations_deg, radius=4.5, for_lotus=False) + global_normal[...,0] *= -1 + global_normal = (global_normal + 1) / 2 + global_normal = global_normal.permute(0, 3, 1, 2) + return global_normal + +# 生成多视图图像 +@spaces.GPU +def generate_multi_view_images(prompt, seed): + generator = torch.manual_seed(seed) + with torch.no_grad(): + images = flux_pipe( + prompt=prompt, + num_inference_steps=30, + guidance_scale=3.5, + num_images_per_prompt=1, + width=resolution * 4, + height=resolution * 2, + output_type='np', + generator=generator + ).images + return images + +# 重建 3D 模型 +@spaces.GPU +def reconstruct_3d_model(images, prompt): + rgb_normal_grid = images + save_dir_path = os.path.join(save_dir, prompt.replace(" ", "_")) + os.makedirs(save_dir_path, exist_ok=True) + + images = torch.from_numpy(rgb_normal_grid).squeeze(0).permute(2, 0, 1).contiguous().float() # (3, 1024, 2048) + images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=2, m=4) # (8, 3, 512, 512) + rgb_multi_view = images[:4, :3, :, :] + normal_multi_view = images[4:, :3, :, :] + multi_view_mask = get_background(normal_multi_view) + rgb_multi_view = rgb_multi_view * rgb_multi_view + (1-multi_view_mask) + input_cameras = get_flux_input_cameras(batch_size=1, radius=4.2, fov=30).to(device) + vertices, faces = lrm_reconstructions(rgb_multi_view, input_cameras, save_path=save_dir_path, name='lrm', export_texmap=False, if_save_video=False) + # local normal to global normal + + global_normal = local_normal_global_transform(normal_multi_view.permute(0, 2, 3, 1), isomer_azimuths, isomer_elevations) + global_normal = global_normal * multi_view_mask + (1-multi_view_mask) + + global_normal = global_normal.permute(0,2,3,1) + rgb_multi_view = rgb_multi_view.permute(0,2,3,1) + multi_view_mask = multi_view_mask.permute(0,2,3,1).squeeze(-1) + vertices = torch.from_numpy(vertices).to(device) + faces = torch.from_numpy(faces).to(device) + vertices = vertices @ rotate_x(np.pi / 2, device=vertices.device)[:3, :3] + vertices = vertices @ rotate_y(np.pi / 2, device=vertices.device)[:3, :3] + + # global_normal: B,H,W,3 + # multi_view_mask: B,H,W + # rgb_multi_view: B,H,W,3 + + meshes = reconstruction( + normal_pils=global_normal, + masks=multi_view_mask, + weights=isomer_geo_weights, + fov=30, + radius=isomer_radius, + camera_angles_azi=isomer_azimuths, + camera_angles_ele=isomer_elevations, + expansion_weight_stage1=0.1, + init_type="file", + init_verts=vertices, + init_faces=faces, + stage1_steps=0, + stage2_steps=50, + start_edge_len_stage1=0.1, + end_edge_len_stage1=0.02, + start_edge_len_stage2=0.02, + end_edge_len_stage2=0.005, + ) + + + save_glb_addr = projection( + meshes, + masks=multi_view_mask, + images=rgb_multi_view, + azimuths=isomer_azimuths, + elevations=isomer_elevations, + weights=isomer_color_weights, + fov=30, + radius=isomer_radius, + save_dir=f"{save_dir_path}/ISOMER/", + ) + + return save_glb_addr + +# Gradio 接口函数 +def gradio_pipeline(prompt, seed): + # 生成多视图图像 + rgb_normal_grid = generate_multi_view_images(prompt, seed) + image_preview = Image.fromarray((rgb_normal_grid * 255).astype(np.uint8)) + + # 3d reconstruction + + + # 重建 3D 模型并返回 glb 路径 + save_glb_addr = reconstruct_3d_model(rgb_normal_grid, prompt) + + return image_preview, save_glb_addr + +# Gradio Blocks 应用 +with gr.Blocks() as demo: + with gr.Row(variant="panel"): + # 左侧输入区域 + with gr.Column(): + with gr.Row(): + prompt_input = gr.Textbox( + label="Enter Prompt", + placeholder="Describe your 3D model...", + lines=2, + elem_id="prompt_input" + ) + + with gr.Row(): + sample_seed = gr.Number(value=42, label="Seed Value", precision=0) + + with gr.Row(): + submit = gr.Button("Generate", elem_id="generate", variant="primary") + + with gr.Row(variant="panel"): + gr.Markdown("Examples:") + gr.Examples( + examples=[ + ["a castle on a hill"], + ["an owl wearing a hat"], + ["a futuristic car"] + ], + inputs=[prompt_input], + label="Prompt Examples" + ) + + # 右侧输出区域 + with gr.Column(): + with gr.Row(): + rgb_normal_grid_image = gr.Image( + label="RGB Normal Grid", + type="pil", + interactive=False + ) + + with gr.Row(): + with gr.Tab("GLB"): + output_glb_model = gr.Model3D( + label="Generated 3D Model (GLB Format)", + interactive=False + ) + gr.Markdown("Download the model for proper visualization.") + + # 处理逻辑 + submit.click( + fn=gradio_pipeline, inputs=[prompt_input, sample_seed], + outputs=[rgb_normal_grid_image, output_glb_model] + ) + +# 启动应用 +demo.queue(max_size=10) +demo.launch(server_port=1211) diff --git a/extension/put_here.txt b/extension/put_here.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/image_to_mesh.py b/image_to_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..7e7d86a4680e1d81706072b0f6ba1ed170b23fa1 --- /dev/null +++ b/image_to_mesh.py @@ -0,0 +1,437 @@ +import os +from einops import rearrange +from omegaconf import OmegaConf +import torch +import numpy as np +import trimesh +import torchvision +import torch.nn.functional as F +from PIL import Image +from torchvision import transforms +from torchvision.transforms import v2 +from transformers import AutoProcessor, AutoModelForCausalLM +import rembg +from diffusers import FluxPipeline, FluxControlNetImg2ImgPipeline +from diffusers.models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel +from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler, HeunDiscreteScheduler +from pytorch_lightning import seed_everything +import os + +from models.ISOMER.reconstruction_func import reconstruction +from models.ISOMER.projection_func import projection +from models.lrm.utils.infer_util import remove_background, resize_foreground, save_video +from models.lrm.utils.mesh_util import save_obj, save_obj_with_mtl +from models.lrm.utils.render_utils import rotate_x, rotate_y +from models.lrm.utils.train_util import instantiate_from_config +from models.lrm.utils.camera_util import get_zero123plus_input_cameras, get_custom_zero123plus_input_cameras, get_flux_input_cameras +from utils.tool import NormalTransfer, get_render_cameras_frames, load_mipmap +from utils.tool import get_background, get_render_cameras_video, render_frames +import time + +device = "cuda" +resolution = 512 +save_dir = "./outputs" +zero123plus_diffusion_steps = 75 +normal_transfer = NormalTransfer() +rembg_session = rembg.new_session() +isomer_azimuths = torch.from_numpy(np.array([270, 0, 90, 180])).to(device) +isomer_elevations = torch.from_numpy(np.array([5, 5, 5, 5])).to(device) +isomer_radius = 4.1 +isomer_geo_weights = torch.from_numpy(np.array([1, 0.9, 1, 0.9])).float().to(device) +isomer_color_weights = torch.from_numpy(np.array([1, 0.5, 1, 0.5])).float().to(device) +# seed_everything(42) + +# model initialization and loading +# flux +print('==> Loading Flux model ...') +flux_base_model_pth = "/hpc2hdd/JH_DATA/share/yingcongchen/PrivateShareGroup/yingcongchen_datasets/model_checkpoint/models--black-forest-labs--FLUX.1-dev" +flux_controlnet = FluxControlNetModel.from_pretrained("/hpc2hdd/JH_DATA/share/yingcongchen/PrivateShareGroup/yingcongchen_datasets/model_checkpoint/flux_controlnets/FLUX.1-dev-ControlNet-Union-Pro") +flux_pipe = FluxControlNetImg2ImgPipeline.from_pretrained(flux_base_model_pth, controlnet=[flux_controlnet], torch_dtype=torch.bfloat16).to(device=device, dtype=torch.bfloat16) + +flux_pipe.load_lora_weights('./checkpoint/flux_lora/rgb_normal_large.safetensors') + + +flux_pipe.to(device=device, dtype=torch.bfloat16) +generator = torch.Generator(device=device).manual_seed(0) + +# lrm +print('==> Loading LRM model ...') +config = OmegaConf.load("./models/lrm/config/PRM_inference.yaml") +model_config = config.model_config +infer_config = config.infer_config +model = instantiate_from_config(model_config) +model_ckpt_path = "./checkpoint/lrm/final_ckpt.ckpt" +state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict'] +state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')} +model.load_state_dict(state_dict, strict=True) + +model = model.to(device) +model.init_flexicubes_geometry(device, fovy=50.0) +model = model.eval() + +# zero123++ +print('==> Loading diffusion model ...') +zero123plus_pipeline = DiffusionPipeline.from_pretrained( + "sudo-ai/zero123plus-v1.2", + custom_pipeline="./models/zero123plus", + torch_dtype=torch.float16, +) +zero123plus_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config( + zero123plus_pipeline.scheduler.config, timestep_spacing='trailing' +) +unet_ckpt_path = "./checkpoint/zero123++/flexgen_19w.ckpt" +state_dict = torch.load(unet_ckpt_path, map_location='cpu')['state_dict'] +state_dict = {k[10:]: v for k, v in state_dict.items() if k.startswith('unet.unet.')} +zero123plus_pipeline.unet.load_state_dict(state_dict, strict=True) +zero123plus_pipeline = zero123plus_pipeline.to(device) + +# unet_ckpt_path = "checkpoint/zero123++/diffusion_pytorch_model.bin" +# state_dict = torch.load(unet_ckpt_path, map_location='cpu') +# zero123plus_pipeline.unet.load_state_dict(state_dict, strict=True) +# zero123plus_pipeline = zero123plus_pipeline.to(device) + +# florence +caption_model = AutoModelForCausalLM.from_pretrained( + "/hpc2hdd/home/jlin695/.cache/huggingface/hub/models--multimodalart--Florence-2-large-no-flash-attn/snapshots/8db3793cf5b453b2ccfb3a4f613b403b2e6b7ca2", torch_dtype=torch.bfloat16, trust_remote_code=True, + ).to(device) +caption_processor = AutoProcessor.from_pretrained("/hpc2hdd/home/jlin695/.cache/huggingface/hub/models--multimodalart--Florence-2-large-no-flash-attn/snapshots/8db3793cf5b453b2ccfb3a4f613b403b2e6b7ca2", trust_remote_code=True) + +# Flux multi-view generation +def multi_view_rgb_normal_generation_with_controlnet(prompt, image, strength=1.0, + control_image=[], + control_mode=[], + control_guidance_start=None, + control_guidance_end=None, + controlnet_conditioning_scale=None, + lora_scale=1.0 + ): + control_mode_dict = { + 'canny': 0, + 'tile': 1, + 'depth': 2, + 'blur': 3, + 'pose': 4, + 'gray': 5, + 'lq': 6, + } # for https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union only + + hparam_dict = { + 'prompt': prompt, + 'image': image, + 'strength': strength, + 'num_inference_steps': 30, + 'guidance_scale': 3.5, + 'num_images_per_prompt': 1, + 'width': resolution*4, + 'height': resolution*2, + 'output_type': 'np', + 'generator': generator, + 'joint_attention_kwargs': {"scale": lora_scale} + } + + # append controlnet hparams + if len(control_image) > 0: + assert len(control_mode) == len(control_image) # the count of image should be the same as control mode + + ctrl_hparams = { + 'control_mode': [control_mode_dict[mode_] for mode_ in control_mode], + 'control_image': control_image, + 'control_guidance_start': control_guidance_start or [0.0 for i in range(len(control_image))], + 'control_guidance_end': control_guidance_end or [1.0 for i in range(len(control_image))], + 'controlnet_conditioning_scale': controlnet_conditioning_scale or [1.0 for i in range(len(control_image))], + } + + hparam_dict.update(ctrl_hparams) + + # generate multi-view images + with torch.no_grad(): + image = flux_pipe( + **hparam_dict + ).images + return image + +# captioning +def run_captioning(image): + device = "cuda" if torch.cuda.is_available() else "cpu" + torch_dtype = torch.bfloat16 + + if isinstance(image, str): # If image is a file path + image = Image.open(image).convert("RGB") + + prompt = "" + inputs = caption_processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype) + # print(f"inputs {inputs}") + + generated_ids = caption_model.generate( + input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, num_beams=3 + ) + + generated_text = caption_processor.batch_decode(generated_ids, skip_special_tokens=False)[0] + parsed_answer = caption_processor.post_process_generation( + generated_text, task=prompt, image_size=(image.width, image.height) + ) + # print(f"parsed_answer = {parsed_answer}") + caption_text = parsed_answer[""].replace("The image is ", "") + return caption_text + + +# zero123++ multi-view generation +def multi_view_rgb_generation(cond_img): + # generate multi-view images + with torch.no_grad(): + output_image = zero123plus_pipeline( + cond_img, + num_inference_steps=zero123plus_diffusion_steps, + width=resolution*2, + height=resolution*2, + ).images[0] + return output_image + +# lrm reconstructions +def lrm_reconstructions(image, input_cameras, save_path=None, name="temp", export_texmap=False, if_save_video=False, render_azimuths=None, render_elevations=None, render_radius=None, render_fov=30): + images = image.unsqueeze(0).to(device) + images = v2.functional.resize(images, 512, interpolation=3, antialias=True).clamp(0, 1) + # breakpoint() + with torch.no_grad(): + # get triplane + planes = model.forward_planes(images, input_cameras) + + mesh_path_idx = os.path.join(save_path, f'{name}.obj') + + mesh_out = model.extract_mesh( + planes, + use_texture_map=export_texmap, + **infer_config, + ) + if export_texmap: + vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out + save_obj_with_mtl( + vertices.data.cpu().numpy(), + uvs.data.cpu().numpy(), + faces.data.cpu().numpy(), + mesh_tex_idx.data.cpu().numpy(), + tex_map.permute(1, 2, 0).data.cpu().numpy(), + mesh_path_idx, + ) + else: + vertices, faces, vertex_colors = mesh_out + save_obj(vertices, faces, vertex_colors, mesh_path_idx) + print(f"Mesh saved to {mesh_path_idx}") + + render_size = 512 + if if_save_video: + video_path_idx = os.path.join(save_path, f'{name}.mp4') + render_size = infer_config.render_resolution + ENV = load_mipmap("models/lrm/env_mipmap/6") + materials = (0.0,0.9) + + all_mv, all_mvp, all_campos = get_render_cameras_video( + batch_size=1, + M=240, + radius=4.5, + elevation=(90, 60.0), + is_flexicubes=True, + fov=30 + ) + + frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames( + model, + planes, + render_cameras=all_mvp, + camera_pos=all_campos, + env=ENV, + materials=materials, + render_size=render_size, + chunk_size=20, + is_flexicubes=True, + ) + normals = (torch.nn.functional.normalize(normals) + 1) / 2 + normals = normals * alphas + (1-alphas) + all_frames = torch.cat([frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals], dim=3) + + # breakpoint() + save_video( + all_frames, + video_path_idx, + fps=30, + ) + print(f"Video saved to {video_path_idx}") + + if render_azimuths is not None and render_elevations is not None and render_radius is not None: + render_size = infer_config.render_resolution + ENV = load_mipmap("models/lrm/env_mipmap/6") + materials = (0.0,0.9) + all_mv, all_mvp, all_campos, identity_mv = get_render_cameras_frames( + batch_size=1, + radius=render_radius, + azimuths=render_azimuths, + elevations=render_elevations, + fov=30 + ) + frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames( + model, + planes, + render_cameras=all_mvp, + camera_pos=all_campos, + env=ENV, + materials=materials, + render_size=render_size, + render_mv = all_mv, + local_normal=True, + identity_mv=identity_mv, + ) + else: + normals = None + frames = None + albedos = None + + return vertices, faces, normals, frames, albedos + + +def transform_normal(input_normal, azimuths_deg, elevations_deg, radius=4.5, is_global_to_local=False): + """ + input_normal: in range [-1, 1], shape (b c h w) + """ + + input_normal = input_normal.permute(0, 2, 3, 1).cpu() + + azimuths_deg = np.array(azimuths_deg) + elevations_deg = np.array(elevations_deg) + + if is_global_to_local: + local_normal = normal_transfer.trans_global_2_local(input_normal, azimuths_deg, elevations_deg) + return local_normal.permute(0, 3, 1, 2) + else: + global_normal = normal_transfer.trans_local_2_global(input_normal, azimuths_deg, elevations_deg, radius=radius, for_lotus=False) + global_normal[..., 0] *= -1 + return global_normal.permute(0, 3, 1, 2) + +def local_normal_global_transform(local_normal_images,azimuths_deg,elevations_deg): + if local_normal_images.min() >= 0: + local_normal = local_normal_images.float() * 2 - 1 + else: + local_normal = local_normal_images.float() + global_normal = normal_transfer.trans_local_2_global(local_normal, azimuths_deg, elevations_deg, radius=4.5, for_lotus=False) + global_normal[...,0] *= -1 + global_normal = (global_normal + 1) / 2 + global_normal = global_normal.permute(0, 3, 1, 2) + return global_normal + +def main(): + image_pth = "examples/蓝色小怪物.webp" + save_dir_path = os.path.join(save_dir, image_pth.split("/")[-1].split(".")[0]) + os.makedirs(save_dir_path, exist_ok=True) + input_image = Image.open(image_pth) + # if not args.no_rembg: + input_image = remove_background(input_image, rembg_session) + input_image = resize_foreground(input_image, 0.85) + + # generate caption + image_caption = run_captioning(image_pth) + + # generate multi-view images + output_image = multi_view_rgb_generation(input_image) + + # lrm reconstructions + rgb_multi_view = np.asarray(output_image, dtype=np.float32) / 255.0 + rgb_multi_view = torch.from_numpy(rgb_multi_view).squeeze(0).permute(2, 0, 1).contiguous().float() # (3, 1024, 2048) + rgb_multi_view = rearrange(rgb_multi_view, 'c (n h) (m w) -> (n m) c h w', n=2, m=2) # (8, 3, 512, 512) + + input_cameras = get_custom_zero123plus_input_cameras(batch_size=1, radius=3.5, fov=30).to(device) + + vertices, faces, lrm_multi_view_normals, lrm_multi_view_rgb, lrm_multi_view_albedo = \ + lrm_reconstructions(rgb_multi_view, input_cameras, save_path=save_dir_path, name='lrm', + export_texmap=False, if_save_video=False, render_azimuths=isomer_azimuths, + render_elevations=isomer_elevations, render_radius=isomer_radius, render_fov=30) + + vertices = torch.from_numpy(vertices).to(device) + faces = torch.from_numpy(faces).to(device) + vertices = vertices @ rotate_x(np.pi / 2, device=vertices.device)[:3, :3] + vertices = vertices @ rotate_y(np.pi / 2, device=vertices.device)[:3, :3] + + + # lrm_3D_bundle_image = torchvision.utils.make_grid(torch.cat([lrm_multi_view_rgb.cpu(), (lrm_multi_view_normals.cpu() + 1) / 2], dim=0), nrow=4, padding=0).unsqueeze(0) # range [0, 1] + lrm_3D_bundle_image = torchvision.utils.make_grid(torch.cat([rgb_multi_view[[3,0,1,2]].cpu(), (lrm_multi_view_normals.cpu() + 1) / 2], dim=0), nrow=4, padding=0).unsqueeze(0) # range [0, 1] + # rgb_multi_view[[3,0,1,2]] : (B,3,H,W) + # lrm_multi_view_normals : (B,3,H,W) + # combined_images = 0.5 * rgb_multi_view[[3,0,1,2]].cpu() + 0.5 * (lrm_multi_view_normals.cpu() + 1) / 2 + # torchvision.utils.save_image(combined_images, os.path.join("debug_output", 'combined.png')) + # breakpoint() + # Use the low-quality controlnet by default, feel free to try the others + control_image = [lrm_3D_bundle_image * 2 - 1] + control_mode = ['tile'] + control_guidance_start = [0.0] + control_guidance_end = [0.3] + controlnet_conditioning_scale = [0.8] + + flux_pipe.controlnet = FluxMultiControlNetModel([flux_controlnet for _ in control_mode]) + # breakpoint() + rgb_normal_grid = multi_view_rgb_normal_generation_with_controlnet( + prompt= ' '.join(['A grid of 2x4 multi-view image, elevation 5. White background.', image_caption]), + image=lrm_3D_bundle_image, + strength=0.6, + control_image=control_image, + control_mode=control_mode, + control_guidance_start=control_guidance_start, + control_guidance_end=control_guidance_end, + controlnet_conditioning_scale=controlnet_conditioning_scale, + lora_scale=1.0 + ) # noted that rgb_normal_grid is a (b, h, w, c) numpy array + + rgb_normal_grid = torch.from_numpy(rgb_normal_grid).contiguous().float() + rgb_normal_grid = rearrange(rgb_normal_grid.squeeze(0), '(n h) (m w) c-> (n m) c h w', n=2, m=4) # (8, 3, 512, 512) + rgb_multi_view = rgb_normal_grid[:4, :3, :, :].cuda() + normal_multi_view = rgb_normal_grid[4:, :3, :, :].cuda() + multi_view_mask = get_background(normal_multi_view).cuda() + rgb_multi_view = rgb_multi_view * multi_view_mask + (1-multi_view_mask) + + # local normal to global normal + global_normal = local_normal_global_transform(normal_multi_view.permute(0, 2, 3, 1).cpu(), isomer_azimuths, isomer_elevations).cuda() + + global_normal = global_normal * multi_view_mask + (1-multi_view_mask) + + global_normal = global_normal.permute(0,2,3,1) + multi_view_mask = multi_view_mask.squeeze(1) + rgb_multi_view = rgb_multi_view.permute(0,2,3,1) + # global_normal: B,H,W,3 + # multi_view_mask: B,H,W + # rgb_multi_view: B,H,W,3 + + + meshes = reconstruction( + normal_pils=global_normal, + masks=multi_view_mask, + weights=isomer_geo_weights, + fov=30, + radius=isomer_radius, + camera_angles_azi=isomer_azimuths, + camera_angles_ele=isomer_elevations, + expansion_weight_stage1=0.1, + init_type="file", + init_verts=vertices, + init_faces=faces, + stage1_steps=0, + stage2_steps=50, + start_edge_len_stage1=0.1, + end_edge_len_stage1=0.02, + start_edge_len_stage2=0.02, + end_edge_len_stage2=0.005, + ) + + save_glb_addr = projection( + meshes=meshes, + masks=multi_view_mask, + images=rgb_multi_view, + azimuths=isomer_azimuths, + elevations=isomer_elevations, + weights=isomer_color_weights, + fov=30, + radius=isomer_radius, + save_dir=f"{save_dir_path}/ISOMER/", + ) + print(f'saved to {save_glb_addr}') + + + +if __name__ == '__main__': + main() diff --git a/models/ISOMER/__init__.py b/models/ISOMER/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/ISOMER/data/__init__.py b/models/ISOMER/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/ISOMER/data/utils.py b/models/ISOMER/data/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a4408ecc161e6a148e75fc92845456e201caf149 --- /dev/null +++ b/models/ISOMER/data/utils.py @@ -0,0 +1,87 @@ +import torch +import numpy as np +from PIL import Image +import os +from pytorch3d.io import load_obj +import trimesh +from pytorch3d.structures import Meshes +# from rembg import remove + +def remove_color(arr): + if arr.shape[-1] == 4: + arr = arr[..., :3] + + # Convert to torch tensor + if type(arr) is not torch.Tensor: + arr = torch.tensor(arr, dtype=torch.int32) + + # Calculate diffs + base = arr[0, 0] + diffs = torch.abs(arr - base).sum(dim=-1) + alpha = (diffs <= 80) + + arr[alpha] = 255 + alpha = ~alpha + alpha = alpha.unsqueeze(-1).int() * 255 + arr = torch.cat([arr, alpha], dim=-1) + + return arr + +def simple_remove_bkg_normal(imgs, rm_bkg_with_rembg, return_Image=False): + """Only works for normal""" + rets = [] + for img in imgs: + if rm_bkg_with_rembg: + from rembg import remove + image = Image.fromarray(img.to(torch.uint8).detach().cpu().numpy()) if isinstance(img, torch.Tensor) else img + removed_image = remove(image) + arr = np.array(removed_image) + arr = torch.tensor(arr, dtype=torch.uint8) + else: + arr = remove_color(img) + + if return_Image: + rets.append(Image.fromarray(arr.to(torch.uint8).detach().cpu().numpy())) + else: + rets.append(arr.to(torch.uint8)) + + return rets + + +def load_glb(file_path): + # Load the .glb file as a scene and merge all meshes + scene_or_mesh = trimesh.load(file_path) + + mesh = scene_or_mesh.dump(concatenate=True) if isinstance(scene_or_mesh, trimesh.Scene) else scene_or_mesh + + # Extract vertices and faces from the merged mesh + verts = torch.tensor(mesh.vertices, dtype=torch.float32) + faces = torch.tensor(mesh.faces, dtype=torch.int64) + + + textured_mesh = Meshes(verts=[verts], faces=[faces]) + + + return textured_mesh + +def load_obj_with_verts_faces(file_path, return_mesh=True): + verts, faces, _ = load_obj(file_path) + + verts = torch.tensor(verts, dtype=torch.float32) + faces = faces.verts_idx + faces = torch.tensor(faces, dtype=torch.int64) + + if return_mesh: + return Meshes(verts=[verts], faces=[faces]) + else: + return verts, faces + +def normalize_mesh(vertices): + min_vals, _ = torch.min(vertices, axis=0) + max_vals, _ = torch.max(vertices, axis=0) + center = (max_vals + min_vals) / 2 + vertices = vertices - center + max_extent = torch.max(max_vals - min_vals) + scale = 2.0 / max_extent + vertices = vertices * scale + return vertices diff --git a/models/ISOMER/mesh_reconstruction/__init__.py b/models/ISOMER/mesh_reconstruction/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/ISOMER/mesh_reconstruction/func.py b/models/ISOMER/mesh_reconstruction/func.py new file mode 100644 index 0000000000000000000000000000000000000000..399a030e70f975bfd247f94d393327cec95899f7 --- /dev/null +++ b/models/ISOMER/mesh_reconstruction/func.py @@ -0,0 +1,227 @@ +# modified from https://github.com/Profactor/continuous-remeshing +import torch +import numpy as np +import trimesh +from typing import Tuple +from pytorch3d.renderer.cameras import camera_position_from_spherical_angles, look_at_rotation +from pytorch3d.renderer import ( + FoVOrthographicCameras, + look_at_view_transform, +) + +def to_numpy(*args): + def convert(a): + if isinstance(a,torch.Tensor): + return a.detach().cpu().numpy() + assert a is None or isinstance(a,np.ndarray) + return a + + return convert(args[0]) if len(args)==1 else tuple(convert(a) for a in args) + +def laplacian( + num_verts:int, + edges: torch.Tensor #E,2 + ) -> torch.Tensor: #sparse V,V + """create sparse Laplacian matrix""" + V = num_verts + E = edges.shape[0] + + #adjacency matrix, + idx = torch.cat([edges, edges.fliplr()], dim=0).type(torch.long).T # (2, 2*E) + ones = torch.ones(2*E, dtype=torch.float32, device=edges.device) + A = torch.sparse.FloatTensor(idx, ones, (V, V)) + + #degree matrix + deg = torch.sparse.sum(A, dim=1).to_dense() + idx = torch.arange(V, device=edges.device) + idx = torch.stack([idx, idx], dim=0) + D = torch.sparse.FloatTensor(idx, deg, (V, V)) + + return D - A + +def _translation(x, y, z, device): + return torch.tensor([[1., 0, 0, x], + [0, 1, 0, y], + [0, 0, 1, z], + [0, 0, 0, 1]],device=device) #4,4 + + +def _perspective(fovy, aspect=1.0, n=0.1, f=1000.0, device=None): + fovy = fovy * torch.pi / 180 + y = np.tan(fovy / 2) + return torch.tensor([[1/(y*aspect), 0, 0, 0], + [ 0, 1/-y, 0, 0], + [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)], + [ 0, 0, -1, 0]], dtype=torch.float32, device=device) + +def _projection(r, device, l=None, t=None, b=None, n=1.0, f=50.0, flip_y=True): + """ + see https://blog.csdn.net/wodownload2/article/details/85069240/ + """ + if l is None: + l = -r + if t is None: + t = r + if b is None: + b = -t + p = torch.zeros([4,4],device=device) + p[0,0] = 2*n/(r-l) + p[0,2] = (r+l)/(r-l) + p[1,1] = 2*n/(t-b) * (-1 if flip_y else 1) + p[1,2] = (t+b)/(t-b) + p[2,2] = -(f+n)/(f-n) + p[2,3] = -(2*f*n)/(f-n) + p[3,2] = -1 + return p #4,4 + +def _orthographic(r, device, l=None, t=None, b=None, n=1.0, f=50.0, flip_y=True): + if l is None: + l = -r + if t is None: + t = r + if b is None: + b = -t + o = torch.zeros([4,4],device=device) + o[0,0] = 2/(r-l) + o[0,3] = -(r+l)/(r-l) + o[1,1] = 2/(t-b) * (-1 if flip_y else 1) + o[1,3] = -(t+b)/(t-b) + o[2,2] = -2/(f-n) + o[2,3] = -(f+n)/(f-n) + o[3,3] = 1 + return o #4,4 + +def make_star_cameras_orig(phis,pol_count,distance:float=10.,r=None,image_size=[512,512],device='cuda'): + if r is None: + r = 1/distance + A = len(phis) + P = pol_count + C = A * P # total number of cameras + + phi = phis * torch.pi / 180 + phi_rot = torch.eye(3,device=device)[None,None].expand(A,1,3,3).clone() + phi_rot[:,0,2,2] = phi.cos() + phi_rot[:,0,2,0] = -phi.sin() + phi_rot[:,0,0,2] = phi.sin() + phi_rot[:,0,0,0] = phi.cos() + + theta = torch.arange(1,P+1) * (torch.pi/(P+1)) - torch.pi/2 + theta_rot = torch.eye(3,device=device)[None,None].expand(1,P,3,3).clone() + theta_rot[0,:,1,1] = theta.cos() + theta_rot[0,:,1,2] = -theta.sin() + theta_rot[0,:,2,1] = theta.sin() + theta_rot[0,:,2,2] = theta.cos() + + mv = torch.empty((C,4,4), device=device) + mv[:] = torch.eye(4, device=device) + mv[:,:3,:3] = (theta_rot @ phi_rot).reshape(C,3,3) + mv_ = _translation(0, 0, -distance, device) @ mv + + return mv_, _projection(r,device) + +def make_star_cameras_mv_new(phis,eles,distance:float=10.,r=None,fov=None,image_size=[512,512],device='cuda',translation=True): + import glm + def sample_spherical(phi, theta, cam_radius): + theta = torch.deg2rad(theta) + phi = torch.deg2rad(phi) + + z = cam_radius * torch.cos(phi) * torch.sin(theta) + x = cam_radius * torch.sin(phi) * torch.sin(theta) + y = cam_radius * torch.cos(theta) + + return x, y, z + + all_mvs = [] + for i in range(len(phis)): + azimuth = - phis[i] + 1e-10 + ele = - eles[i] + 1e-10 + 90 + x, y, z = sample_spherical(azimuth, ele, distance) + eye = glm.vec3(x, y, z) + at = glm.vec3(0.0, 0.0, 0.0) + up = glm.vec3(0.0, 1.0, 0.0) + view_matrix = glm.lookAt(eye, at, up) + all_mvs.append(torch.from_numpy(np.array(view_matrix)).cuda()) + mv = torch.stack(all_mvs) + + return mv + +def make_star_cameras_mv(phis,eles,distance:float=10.,r=None,fov=None,image_size=[512,512],device='cuda',translation=True): + if r is None: + r = 0.15 + A = len(phis) + assert len(eles) == A, f'len(phis): {len(phis)}, len(eles): {len(eles)}' + + phi = phis * torch.pi / 180 + phi_rot = torch.eye(3,device=device)[None].expand(A,3,3).clone() + phi_rot[:,2,2] = phi.cos() + phi_rot[:,2,0] = -phi.sin() + phi_rot[:,0,2] = phi.sin() + phi_rot[:,0,0] = phi.cos() + + + theta = eles * torch.pi / 180 + theta_rot = torch.eye(3,device=device)[None].expand(A,3,3).clone() + theta_rot[:,1,1] = theta.cos() + theta_rot[:,1,2] = -theta.sin() + theta_rot[:,2,1] = theta.sin() + theta_rot[:,2,2] = theta.cos() + + mv = torch.empty((A,4,4), device=device) + mv[:] = torch.eye(4, device=device) + mv[:,:3,:3] = (theta_rot @ phi_rot).reshape(A,3,3) + + if translation: + mv_ = _translation(0, 0, -distance, device) @ mv + else: + mv_ = mv + return mv_ + +def make_star_cameras(phis,eles,distance:float=10.,r=None,fov=None,image_size=[512,512],device='cuda',translation=True): + mv_ = make_star_cameras_mv_new(phis, eles, distance, r, device=device, translation=translation) + return mv_, _perspective(fov,device=device) + +def make_star_cameras_perspective(phis, eles, distance:float=10., r=None, fov=None, device='cuda'): + + return make_star_cameras(phis, eles, distance, r, fov=fov, device=device, translation=True) + +def make_star_cameras_orthographic(phis, eles, distance:float=10., r=None, device='cuda'): + + mv = make_star_cameras_mv_new(phis, eles, distance, r, device=device) + if r is None: + r = 1 + return mv, _orthographic(r,device) + +def make_sphere(level:int=2,radius=1.,device='cuda') -> Tuple[torch.Tensor,torch.Tensor]: + sphere = trimesh.creation.icosphere(subdivisions=level, radius=1.0, color=None) + vertices = torch.tensor(sphere.vertices, device=device, dtype=torch.float32) * radius + faces = torch.tensor(sphere.faces, device=device, dtype=torch.long) + return vertices,faces + + +def get_camera(R, T, focal_length=1 / (2**0.5)): + focal_length = 1 / focal_length + camera = FoVOrthographicCameras(device=R.device, R=R, T=T, min_x=-focal_length, max_x=focal_length, min_y=-focal_length, max_y=focal_length) + return camera + +def make_star_cameras_orthographic_py3d(azim_list, device, focal=2/1.35, dist=1.1): + R, T = look_at_view_transform(dist, 0, azim_list) + focal_length = 1 / focal + return FoVOrthographicCameras(device=R.device, R=R, T=T, min_x=-focal_length, max_x=focal_length, min_y=-focal_length, max_y=focal_length).to(device) + + +def rotation_matrix_to_euler_angles(R, return_degrees=True): + sy = torch.sqrt(R[0, 0] * R[0, 0] + R[1, 0] * R[1, 0]) + singular = sy < 1e-6 + if not singular: + x = torch.atan2(R[2, 1], R[2, 2]) + y = torch.atan2(-R[2, 0], sy) + z = torch.atan2(R[1, 0], R[0, 0]) + else: + x = torch.atan2(-R[1, 2], R[1, 1]) + y = torch.atan2(-R[2, 0], sy) + z = 0 + + if return_degrees: + return torch.tensor([x, y, z]) * 180 / np.pi + else: + return torch.tensor([x, y, z]) diff --git a/models/ISOMER/mesh_reconstruction/opt.py b/models/ISOMER/mesh_reconstruction/opt.py new file mode 100644 index 0000000000000000000000000000000000000000..438948ef7e1f833049a838d64318bcf68f8422d6 --- /dev/null +++ b/models/ISOMER/mesh_reconstruction/opt.py @@ -0,0 +1,191 @@ +# modified from https://github.com/Profactor/continuous-remeshing +import time +import torch +import torch_scatter +from typing import Tuple +from ..mesh_reconstruction.remesh import calc_edge_length, calc_edges, calc_face_collapses, calc_face_normals, calc_vertex_normals, collapse_edges, flip_edges, pack, prepend_dummies, remove_dummies, split_edges + +@torch.no_grad() +def remesh( + vertices_etc:torch.Tensor, #V,D + faces:torch.Tensor, #F,3 long + min_edgelen:torch.Tensor, #V + max_edgelen:torch.Tensor, #V + flip:bool, + max_vertices=1e6 + ): + + # dummies + vertices_etc,faces = prepend_dummies(vertices_etc,faces) + vertices = vertices_etc[:,:3] #V,3 + nan_tensor = torch.tensor([torch.nan],device=min_edgelen.device) + min_edgelen = torch.concat((nan_tensor,min_edgelen)) + max_edgelen = torch.concat((nan_tensor,max_edgelen)) + + # collapse + edges,face_to_edge = calc_edges(faces) #E,2 F,3 + edge_length = calc_edge_length(vertices,edges) #E + face_normals = calc_face_normals(vertices,faces,normalize=False) #F,3 + vertex_normals = calc_vertex_normals(vertices,faces,face_normals) #V,3 + # then calculates the face collapses, which are the faces that can be removed without changing the overall shape of the object. + face_collapse = calc_face_collapses(vertices,faces,edges,face_to_edge,edge_length,face_normals,vertex_normals,min_edgelen,area_ratio=0.5) + shortness = (1 - edge_length / min_edgelen[edges].mean(dim=-1)).clamp_min_(0) #e[0,1] 0...ok, 1...edgelen=0 + priority = face_collapse.float() + shortness + vertices_etc, faces = collapse_edges(vertices_etc, faces, edges, priority) + + # split: If the number of vertices is less than the maximum allowed, the function splits the edges that are longer than the maximum edge length. + if vertices.shape[0] max_edgelen[edges].mean(dim=-1) + vertices_etc,faces = split_edges(vertices_etc,faces,edges,face_to_edge,splits,pack_faces=False) + + vertices_etc,faces = pack(vertices_etc,faces) + vertices = vertices_etc[:,:3] + + if flip: # flips the edges of the faces + edges,_,edge_to_face = calc_edges(faces,with_edge_to_face=True) #E,2 F,3 + flip_edges(vertices,faces,edges,edge_to_face,with_border=False) + + return remove_dummies(vertices_etc,faces) + +def lerp_unbiased(a:torch.Tensor,b:torch.Tensor,weight:float,step:int): + """lerp with adam's bias correction""" + c_prev = 1-weight**(step-1) + c = 1-weight**step + a_weight = weight*c_prev/c + b_weight = (1-weight)/c + a.mul_(a_weight).add_(b, alpha=b_weight) + + +class MeshOptimizer: + """Use this like a pytorch Optimizer, but after calling opt.step(), do vertices,faces = opt.remesh().""" + + def __init__(self, + vertices:torch.Tensor, #V,3 + faces:torch.Tensor, #F,3 + lr=0.3, #learning rate + betas=(0.8,0.8,0), #betas[0:2] are the same as in Adam, betas[2] may be used to time-smooth the relative velocity nu + gammas=(0,0,0), #optional spatial smoothing for m1,m2,nu, values between 0 (no smoothing) and 1 (max. smoothing) + nu_ref=0.3, #reference velocity for edge length controller + edge_len_lims=(.01,.15), #smallest and largest allowed reference edge length + edge_len_tol=.5, #edge length tolerance for split and collapse + gain=.2, #gain value for edge length controller + laplacian_weight=.02, #for laplacian smoothing/regularization + ramp=1, #learning rate ramp, actual ramp width is ramp/(1-betas[0]) + grad_lim=10., #gradients are clipped to m1.abs()*grad_lim + remesh_interval=1, #larger intervals are faster but with worse mesh quality + local_edgelen=True, #set to False to use a global scalar reference edge length instead + ): + self._vertices = vertices + self._faces = faces + self._lr = lr + self._betas = betas + self._gammas = gammas + self._nu_ref = nu_ref + self._edge_len_lims = edge_len_lims + self._edge_len_tol = edge_len_tol + self._gain = gain + self._laplacian_weight = laplacian_weight + self._ramp = ramp + self._grad_lim = grad_lim + self._remesh_interval = remesh_interval + self._local_edgelen = local_edgelen + self._step = 0 + + V = self._vertices.shape[0] + # prepare continuous tensor for all vertex-based data + self._vertices_etc = torch.zeros([V,9],device=vertices.device) + self._split_vertices_etc() + self.vertices.copy_(vertices) #initialize vertices + self._vertices.requires_grad_() + self._ref_len.fill_(edge_len_lims[1]) + + @property + def vertices(self): + return self._vertices + + @property + def faces(self): + return self._faces + + def _split_vertices_etc(self): + self._vertices = self._vertices_etc[:,:3] + self._m2 = self._vertices_etc[:,3] + self._nu = self._vertices_etc[:,4] + self._m1 = self._vertices_etc[:,5:8] + self._ref_len = self._vertices_etc[:,8] + + with_gammas = any(g!=0 for g in self._gammas) + self._smooth = self._vertices_etc[:,:8] if with_gammas else self._vertices_etc[:,:3] + + def zero_grad(self): + self._vertices.grad = None + + @torch.no_grad() + def step(self): + + eps = 1e-8 + + self._step += 1 + + # spatial smoothing + edges,_ = calc_edges(self._faces) #E,2 + E = edges.shape[0] + edge_smooth = self._smooth[edges] #E,2,S + neighbor_smooth = torch.zeros_like(self._smooth) #V,S + torch_scatter.scatter_mean(src=edge_smooth.flip(dims=[1]).reshape(E*2,-1),index=edges.reshape(E*2,1),dim=0,out=neighbor_smooth) + + #apply optional smoothing of m1,m2,nu + if self._gammas[0]: + self._m1.lerp_(neighbor_smooth[:,5:8],self._gammas[0]) + if self._gammas[1]: + self._m2.lerp_(neighbor_smooth[:,3],self._gammas[1]) + if self._gammas[2]: + self._nu.lerp_(neighbor_smooth[:,4],self._gammas[2]) + + #add laplace smoothing to gradients + laplace = self._vertices - neighbor_smooth[:,:3] + grad = torch.addcmul(self._vertices.grad, laplace, self._nu[:,None], value=self._laplacian_weight) + + #gradient clipping + if self._step>1: + grad_lim = self._m1.abs().mul_(self._grad_lim) + grad.clamp_(min=-grad_lim,max=grad_lim) + + # moment updates + lerp_unbiased(self._m1, grad, self._betas[0], self._step) + lerp_unbiased(self._m2, (grad**2).sum(dim=-1), self._betas[1], self._step) + + velocity = self._m1 / self._m2[:,None].sqrt().add_(eps) #V,3 + speed = velocity.norm(dim=-1) #V + + if self._betas[2]: + lerp_unbiased(self._nu,speed,self._betas[2],self._step) #V + else: + self._nu.copy_(speed) #V + + # update vertices + ramped_lr = self._lr * min(1,self._step * (1-self._betas[0]) / self._ramp) + self._vertices.add_(velocity * self._ref_len[:,None], alpha=-ramped_lr) + + # update target edge length + if self._step % self._remesh_interval == 0: + if self._local_edgelen: + len_change = (1 + (self._nu - self._nu_ref) * self._gain) + else: + len_change = (1 + (self._nu.mean() - self._nu_ref) * self._gain) + self._ref_len *= len_change + self._ref_len.clamp_(*self._edge_len_lims) + + def remesh(self, flip:bool=True, poisson=False)->Tuple[torch.Tensor,torch.Tensor]: + min_edge_len = self._ref_len * (1 - self._edge_len_tol) + max_edge_len = self._ref_len * (1 + self._edge_len_tol) + + self._vertices_etc,self._faces = remesh(self._vertices_etc,self._faces,min_edge_len,max_edge_len,flip, max_vertices=1e6) + + self._split_vertices_etc() + self._vertices.requires_grad_() + + return self._vertices, self._faces diff --git a/models/ISOMER/mesh_reconstruction/recon.py b/models/ISOMER/mesh_reconstruction/recon.py new file mode 100644 index 0000000000000000000000000000000000000000..54bfc8cdc7de13e9f2f82b4529434769e9e8aba2 --- /dev/null +++ b/models/ISOMER/mesh_reconstruction/recon.py @@ -0,0 +1,58 @@ +from tqdm import tqdm +from PIL import Image +import numpy as np +import torch +from torchvision.utils import make_grid +from typing import List +from ..mesh_reconstruction.remesh import calc_vertex_normals +from ..mesh_reconstruction.opt import MeshOptimizer +from ..mesh_reconstruction.func import make_star_cameras_orthographic, make_star_cameras_orthographic_py3d +from ..mesh_reconstruction.render import NormalsRenderer, Pytorch3DNormalsRenderer +from ..scripts.utils import to_py3d_mesh, init_target + +def reconstruct_stage1(pils: List[Image.Image], mv, proj, steps=100, vertices=None, faces=None, start_edge_len=0.15, end_edge_len=0.005, decay=0.995, return_mesh=True, loss_expansion_weight=0.1, gain=0.1, use_remesh=True): + + vertices, faces = vertices.to("cuda"), faces.to("cuda") + + renderer = NormalsRenderer(mv,proj,list(pils[0].size)) + + + target_images = init_target(pils, new_bkgd=(0., 0., 0.)) # 4s + + opt = MeshOptimizer(vertices,faces, local_edgelen=False, gain=gain, edge_len_lims=(end_edge_len, start_edge_len)) + + vertices = opt.vertices + + mask = target_images[..., -1] < 0.5 + + for i in tqdm(range(steps)): + opt._lr *= decay + + normals = calc_vertex_normals(vertices,faces) + images = renderer.render(vertices,normals,faces) + + loss_expand = 0.5 * ((vertices+normals).detach() - vertices).pow(2).mean() + + t_mask = images[..., -1] > 0.5 + loss_target_l2 = (images[t_mask] - target_images[t_mask]).abs().pow(2).mean() + + loss_alpha_target_mask_l2 = (images[..., -1][mask] - target_images[..., -1][mask]).pow(2).mean() + + loss = loss_target_l2 + loss_alpha_target_mask_l2 + loss_expand * loss_expansion_weight + + loss_oob = (vertices.abs() > 0.99).float().mean() * 10 + loss = loss + loss_oob + + + loss.backward() + opt.step() + + if use_remesh: + vertices,faces = opt.remesh(poisson=False) + + vertices, faces = vertices.detach(), faces.detach() + + if return_mesh: + return to_py3d_mesh(vertices, faces) + else: + return vertices, faces diff --git a/models/ISOMER/mesh_reconstruction/refine.py b/models/ISOMER/mesh_reconstruction/refine.py new file mode 100644 index 0000000000000000000000000000000000000000..1d7fa5b98131c98b71487912f239dcfb33cd1b4d --- /dev/null +++ b/models/ISOMER/mesh_reconstruction/refine.py @@ -0,0 +1,86 @@ +from tqdm import tqdm +from PIL import Image +import torch +import numpy as np +from typing import List +from ..mesh_reconstruction.remesh import calc_vertex_normals +from ..mesh_reconstruction.opt import MeshOptimizer +from ..mesh_reconstruction.func import make_star_cameras_orthographic, make_star_cameras_orthographic_py3d +from ..mesh_reconstruction.render import NormalsRenderer, Pytorch3DNormalsRenderer +from ..scripts.project_mesh import multiview_color_projection, get_cameras_list +from ..scripts.utils import to_py3d_mesh, from_py3d_mesh, init_target + +def run_mesh_refine(vertices, faces, pils: List[Image.Image], mv, proj, weights, cameras, steps=100, start_edge_len=0.02, end_edge_len=0.005, decay=0.99, update_normal_interval=10, update_warmup=10, return_mesh=True, process_inputs=True, process_outputs=True, use_remesh=True, loss_expansion_weight=0): + + if process_inputs: + vertices = vertices * 2 / 1.35 + vertices[..., [0, 2]] = - vertices[..., [0, 2]] + + poission_steps = [] + + renderer = NormalsRenderer(mv,proj,list(pils[0].size)) + + + target_images = init_target(pils, new_bkgd=(0., 0., 0.)) # 4s + + opt = MeshOptimizer(vertices,faces, ramp=5, edge_len_lims=(end_edge_len, start_edge_len), local_edgelen=False, laplacian_weight=0.02) + + vertices = opt.vertices + alpha_init = None + + mask = target_images[..., -1] < 0.5 + + for i in tqdm(range(steps)): + opt.zero_grad() + opt._lr *= decay + normals = calc_vertex_normals(vertices,faces) + images = renderer.render(vertices,normals,faces) + + if alpha_init is None: + alpha_init = images.detach() + + # update explicit target and render images for L_ET calculation + if i < update_warmup or i % update_normal_interval == 0: + with torch.no_grad(): + + py3d_mesh = to_py3d_mesh(vertices, faces, normals) + + _, _, target_normal = from_py3d_mesh(multiview_color_projection(py3d_mesh, pils, cameras_list=cameras, weights=weights, confidence_threshold=0.1, complete_unseen=False, below_confidence_strategy='original', reweight_with_cosangle='linear')) + + target_normal = target_normal * 2 - 1 + target_normal = torch.nn.functional.normalize(target_normal, dim=-1) + debug_images = renderer.render(vertices,target_normal,faces) + + d_mask = images[..., -1] > 0.5 + loss_debug_l2 = (images[..., :3][d_mask] - debug_images[..., :3][d_mask]).pow(2).mean() + + loss_alpha_target_mask_l2 = (images[..., -1][mask] - target_images[..., -1][mask]).pow(2).mean() + + loss = loss_debug_l2 + loss_alpha_target_mask_l2 + + loss_oob = (vertices.abs() > 0.99).float().mean() * 10 + + loss = loss + loss_oob + + + # this loss_expand does not exist in original ISOMER. we add it here (but default loss_expansion_weight is 0) + loss_expand = 0.5 * ((vertices+normals).detach() - vertices).pow(2).mean() + loss += loss_expand * loss_expansion_weight + + loss.backward() + opt.step() + + + if use_remesh: + vertices,faces = opt.remesh(poisson=(i in poission_steps)) + + vertices, faces = vertices.detach(), faces.detach() + + if process_outputs: + vertices = vertices / 2 * 1.35 + vertices[..., [0, 2]] = - vertices[..., [0, 2]] + + if return_mesh: + return to_py3d_mesh(vertices, faces) + else: + return vertices, faces diff --git a/models/ISOMER/mesh_reconstruction/remesh.py b/models/ISOMER/mesh_reconstruction/remesh.py new file mode 100644 index 0000000000000000000000000000000000000000..e9e78d1df700bc25ae6cd07a6daad99b4eab2b9a --- /dev/null +++ b/models/ISOMER/mesh_reconstruction/remesh.py @@ -0,0 +1,363 @@ +# modified from https://github.com/Profactor/continuous-remeshing +import torch +import torch.nn.functional as tfunc +import torch_scatter +from typing import Tuple + +def prepend_dummies( + vertices:torch.Tensor, #V,D + faces:torch.Tensor, #F,3 long + )->Tuple[torch.Tensor,torch.Tensor]: + """prepend dummy elements to vertices and faces to enable "masked" scatter operations""" + V,D = vertices.shape + vertices = torch.concat((torch.full((1,D),fill_value=torch.nan,device=vertices.device),vertices),dim=0) + faces = torch.concat((torch.zeros((1,3),dtype=torch.long,device=faces.device),faces+1),dim=0) + return vertices,faces + +def remove_dummies( + vertices:torch.Tensor, #V,D - first vertex all nan and unreferenced + faces:torch.Tensor, #F,3 long - first face all zeros + )->Tuple[torch.Tensor,torch.Tensor]: + """remove dummy elements added with prepend_dummies()""" + return vertices[1:],faces[1:]-1 + + +def calc_edges( + faces: torch.Tensor, # F,3 long - first face may be dummy with all zeros + with_edge_to_face: bool = False + ) -> Tuple[torch.Tensor, ...]: + """ + returns Tuple of + - edges E,2 long, 0 for unused, lower vertex index first + - face_to_edge F,3 long + - (optional) edge_to_face shape=E,[left,right],[face,side] + + o-<-----e1 e0,e1...edge, e0-o + """ + + F = faces.shape[0] + + # make full edges, lower vertex index first + face_edges = torch.stack((faces,faces.roll(-1,1)),dim=-1) #F*3,3,2 + full_edges = face_edges.reshape(F*3,2) + sorted_edges,_ = full_edges.sort(dim=-1) #F*3,2 + + # make unique edges + edges,full_to_unique = torch.unique(input=sorted_edges,sorted=True,return_inverse=True,dim=0) #(E,2),(F*3) + E = edges.shape[0] + face_to_edge = full_to_unique.reshape(F,3) #F,3 + + if not with_edge_to_face: + return edges, face_to_edge + + is_right = full_edges[:,0]!=sorted_edges[:,0] #F*3 + edge_to_face = torch.zeros((E,2,2),dtype=torch.long,device=faces.device) #E,LR=2,S=2 + scatter_src = torch.cartesian_prod(torch.arange(0,F,device=faces.device),torch.arange(0,3,device=faces.device)) #F*3,2 + edge_to_face.reshape(2*E,2).scatter_(dim=0,index=(2*full_to_unique+is_right)[:,None].expand(F*3,2),src=scatter_src) #E,LR=2,S=2 + edge_to_face[0] = 0 + return edges, face_to_edge, edge_to_face + +def calc_edge_length( + vertices:torch.Tensor, #V,3 first may be dummy + edges:torch.Tensor, #E,2 long, lower vertex index first, (0,0) for unused + )->torch.Tensor: #E + + full_vertices = vertices[edges] #E,2,3 + a,b = full_vertices.unbind(dim=1) #E,3 + return torch.norm(a-b,p=2,dim=-1) + +def calc_face_normals( + vertices:torch.Tensor, #V,3 first vertex may be unreferenced + faces:torch.Tensor, #F,3 long, first face may be all zero + normalize:bool=False, + )->torch.Tensor: #F,3 + """ + n + | + c0 corners ordered counterclockwise when + / \ looking onto surface (in neg normal direction) + c1---c2 + """ + full_vertices = vertices[faces] #F,C=3,3 + v0,v1,v2 = full_vertices.unbind(dim=1) #F,3 + face_normals = torch.cross(v1-v0,v2-v0, dim=1) #F,3 + if normalize: + face_normals = tfunc.normalize(face_normals, eps=1e-6, dim=1) + return face_normals #F,3 + +def calc_vertex_normals( + vertices:torch.Tensor, #V,3 first vertex may be unreferenced + faces:torch.Tensor, #F,3 long, first face may be all zero + face_normals:torch.Tensor=None, #F,3, not normalized + )->torch.Tensor: #F,3 + + F = faces.shape[0] + + if face_normals is None: + face_normals = calc_face_normals(vertices,faces) # this no grad + + vertex_normals = torch.zeros((vertices.shape[0],3,3),dtype=vertices.dtype,device=vertices.device) #V,C=3,3 + + + vertex_normals.scatter_add_(dim=0,index=faces[:,:,None].expand(F,3,3),src=face_normals[:,None,:].expand(F,3,3)) # This no grad + vertex_normals = vertex_normals.sum(dim=1) #V,3 + return tfunc.normalize(vertex_normals, eps=1e-6, dim=1) + +def calc_face_ref_normals( + faces:torch.Tensor, #F,3 long, 0 for unused + vertex_normals:torch.Tensor, #V,3 first unused + normalize:bool=False, + )->torch.Tensor: #F,3 + """calculate reference normals for face flip detection""" + full_normals = vertex_normals[faces] #F,C=3,3 + ref_normals = full_normals.sum(dim=1) #F,3 + if normalize: + ref_normals = tfunc.normalize(ref_normals, eps=1e-6, dim=1) + return ref_normals + +def pack( + vertices:torch.Tensor, #V,3 first unused and nan + faces:torch.Tensor, #F,3 long, 0 for unused + )->Tuple[torch.Tensor,torch.Tensor]: #(vertices,faces), keeps first vertex unused + """removes unused elements in vertices and faces""" + V = vertices.shape[0] + + # remove unused faces + used_faces = faces[:,0]!=0 + used_faces[0] = True + faces = faces[used_faces] #sync + + # remove unused vertices + used_vertices = torch.zeros(V,3,dtype=torch.bool,device=vertices.device) + used_vertices.scatter_(dim=0,index=faces,value=True,reduce='add') + used_vertices = used_vertices.any(dim=1) + used_vertices[0] = True + vertices = vertices[used_vertices] #sync + + # update used faces + ind = torch.zeros(V,dtype=torch.long,device=vertices.device) + V1 = used_vertices.sum() + ind[used_vertices] = torch.arange(0,V1,device=vertices.device) #sync + faces = ind[faces] + + return vertices,faces + +def split_edges( + vertices:torch.Tensor, #V,3 first unused + faces:torch.Tensor, #F,3 long, 0 for unused + edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first + face_to_edge:torch.Tensor, #F,3 long 0 for unused + splits, #E bool + pack_faces:bool=True, + )->Tuple[torch.Tensor,torch.Tensor]: #(vertices,faces) + + # c2 c2 c...corners = faces + # . . . . s...side_vert, 0 means no split + # . . .N2 . S...shrunk_face + # . . . . Ni...new_faces + # s2 s1 s2|c2...s1|c1 + # . . . . . + # . . . S . . + # . . . . N1 . + # c0...(s0=0)....c1 s0|c0...........c1 + # + # pseudo-code: + # S = [s0|c0,s1|c1,s2|c2] example:[c0,s1,s2] + # split = side_vert!=0 example:[False,True,True] + # N0 = split[0]*[c0,s0,s2|c2] example:[0,0,0] + # N1 = split[1]*[c1,s1,s0|c0] example:[c1,s1,c0] + # N2 = split[2]*[c2,s2,s1|c1] example:[c2,s2,s1] + + V = vertices.shape[0] + F = faces.shape[0] + S = splits.sum().item() #sync + + if S==0: + return vertices,faces + + edge_vert = torch.zeros_like(splits, dtype=torch.long) #E + edge_vert[splits] = torch.arange(V,V+S,dtype=torch.long,device=vertices.device) #E 0 for no split, sync + side_vert = edge_vert[face_to_edge] #F,3 long, 0 for no split + split_edges = edges[splits] #S sync + + #vertices + split_vertices = vertices[split_edges].mean(dim=1) #S,3 + vertices = torch.concat((vertices,split_vertices),dim=0) + + #faces + side_split = side_vert!=0 #F,3 + shrunk_faces = torch.where(side_split,side_vert,faces) #F,3 long, 0 for no split + new_faces = side_split[:,:,None] * torch.stack((faces,side_vert,shrunk_faces.roll(1,dims=-1)),dim=-1) #F,N=3,C=3 + faces = torch.concat((shrunk_faces,new_faces.reshape(F*3,3))) #4F,3 + if pack_faces: + mask = faces[:,0]!=0 + mask[0] = True + faces = faces[mask] #F',3 sync + + return vertices,faces + +def collapse_edges( + vertices:torch.Tensor, #V,3 first unused + faces:torch.Tensor, #F,3 long 0 for unused + edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first + priorities:torch.Tensor, #E float + stable:bool=False, #only for unit testing + )->Tuple[torch.Tensor,torch.Tensor]: #(vertices,faces) + + V = vertices.shape[0] + + # check spacing + _,order = priorities.sort(stable=stable) #E + rank = torch.zeros_like(order) + rank[order] = torch.arange(0,len(rank),device=rank.device) + vert_rank = torch.zeros(V,dtype=torch.long,device=vertices.device) #V + edge_rank = rank #E + for i in range(3): + torch_scatter.scatter_max(src=edge_rank[:,None].expand(-1,2).reshape(-1),index=edges.reshape(-1),dim=0,out=vert_rank) + edge_rank,_ = vert_rank[edges].max(dim=-1) #E + candidates = edges[(edge_rank==rank).logical_and_(priorities>0)] #E',2 + + # check connectivity + vert_connections = torch.zeros(V,dtype=torch.long,device=vertices.device) #V + vert_connections[candidates[:,0]] = 1 #start + edge_connections = vert_connections[edges].sum(dim=-1) #E, edge connected to start + vert_connections.scatter_add_(dim=0,index=edges.reshape(-1),src=edge_connections[:,None].expand(-1,2).reshape(-1))# one edge from start + vert_connections[candidates] = 0 #clear start and end + edge_connections = vert_connections[edges].sum(dim=-1) #E, one or two edges from start + vert_connections.scatter_add_(dim=0,index=edges.reshape(-1),src=edge_connections[:,None].expand(-1,2).reshape(-1)) #one or two edges from start + collapses = candidates[vert_connections[candidates[:,1]] <= 2] # E" not more than two connections between start and end + + # mean vertices + vertices[collapses[:,0]] = vertices[collapses].mean(dim=1) + + # update faces + dest = torch.arange(0,V,dtype=torch.long,device=vertices.device) #V + dest[collapses[:,1]] = dest[collapses[:,0]] + faces = dest[faces] #F,3 + c0,c1,c2 = faces.unbind(dim=-1) + collapsed = (c0==c1).logical_or_(c1==c2).logical_or_(c0==c2) + faces[collapsed] = 0 + + return vertices,faces + +def calc_face_collapses( + vertices:torch.Tensor, #V,3 first unused + faces:torch.Tensor, #F,3 long, 0 for unused + edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first + face_to_edge:torch.Tensor, #F,3 long 0 for unused + edge_length:torch.Tensor, #E + face_normals:torch.Tensor, #F,3 + vertex_normals:torch.Tensor, #V,3 first unused + min_edge_length:torch.Tensor=None, #V + area_ratio = 0.5, #collapse if area < min_edge_length**2 * area_ratio + shortest_probability = 0.8 + )->torch.Tensor: #E edges to collapse + + E = edges.shape[0] + F = faces.shape[0] + + # face flips + ref_normals = calc_face_ref_normals(faces,vertex_normals,normalize=False) #F,3 + face_collapses = (face_normals*ref_normals).sum(dim=-1)<0 #F + + # small faces + if min_edge_length is not None: + min_face_length = min_edge_length[faces].mean(dim=-1) #F + min_area = min_face_length**2 * area_ratio #F + face_collapses.logical_or_(face_normals.norm(dim=-1) < min_area*2) #F + face_collapses[0] = False + + # faces to edges + face_length = edge_length[face_to_edge] #F,3 + + if shortest_probability<1: + #select shortest edge with shortest_probability chance + randlim = round(2/(1-shortest_probability)) + rand_ind = torch.randint(0,randlim,size=(F,),device=faces.device).clamp_max_(2) #selected edge local index in face + sort_ind = torch.argsort(face_length,dim=-1,descending=True) #F,3 + local_ind = sort_ind.gather(dim=-1,index=rand_ind[:,None]) + else: + local_ind = torch.argmin(face_length,dim=-1)[:,None] #F,1 0...2 shortest edge local index in face + + edge_ind = face_to_edge.gather(dim=1,index=local_ind)[:,0] #F 0...E selected edge global index + edge_collapses = torch.zeros(E,dtype=torch.long,device=vertices.device) + edge_collapses.scatter_add_(dim=0,index=edge_ind,src=face_collapses.long()) + + return edge_collapses.bool() + +def flip_edges( + vertices:torch.Tensor, #V,3 first unused + faces:torch.Tensor, #F,3 long, first must be 0, 0 for unused + edges:torch.Tensor, #E,2 long, first must be 0, 0 for unused, lower vertex index first + edge_to_face:torch.Tensor, #E,[left,right],[face,side] + with_border:bool=True, #handle border edges (D=4 instead of D=6) + with_normal_check:bool=True, #check face normal flips + stable:bool=False, #only for unit testing + ): + V = vertices.shape[0] + E = edges.shape[0] + device=vertices.device + vertex_degree = torch.zeros(V,dtype=torch.long,device=device) #V long + vertex_degree.scatter_(dim=0,index=edges.reshape(E*2),value=1,reduce='add') + neighbor_corner = (edge_to_face[:,:,1] + 2) % 3 #go from side to corner + neighbors = faces[edge_to_face[:,:,0],neighbor_corner] #E,LR=2 + edge_is_inside = neighbors.all(dim=-1) #E + + if with_border: + # inside vertices should have D=6, border edges D=4, so we subtract 2 for all inside vertices + # need to use float for masks in order to use scatter(reduce='multiply') + vertex_is_inside = torch.ones(V,2,dtype=torch.float32,device=vertices.device) #V,2 float + src = edge_is_inside.type(torch.float32)[:,None].expand(E,2) #E,2 float + vertex_is_inside.scatter_(dim=0,index=edges,src=src,reduce='multiply') + vertex_is_inside = vertex_is_inside.prod(dim=-1,dtype=torch.long) #V long + vertex_degree -= 2 * vertex_is_inside #V long + + neighbor_degrees = vertex_degree[neighbors] #E,LR=2 + edge_degrees = vertex_degree[edges] #E,2 + # + # loss = Sum_over_affected_vertices((new_degree-6)**2) + # loss_change = Sum_over_neighbor_vertices((degree+1-6)**2-(degree-6)**2) + # + Sum_over_edge_vertices((degree-1-6)**2-(degree-6)**2) + # = 2 * (2 + Sum_over_neighbor_vertices(degree) - Sum_over_edge_vertices(degree)) + # + loss_change = 2 + neighbor_degrees.sum(dim=-1) - edge_degrees.sum(dim=-1) #E + candidates = torch.logical_and(loss_change<0, edge_is_inside) #E + loss_change = loss_change[candidates] #E' + if loss_change.shape[0]==0: + return + + edges_neighbors = torch.concat((edges[candidates],neighbors[candidates]),dim=-1) #E',4 + _,order = loss_change.sort(descending=True, stable=stable) #E' + rank = torch.zeros_like(order) + rank[order] = torch.arange(0,len(rank),device=rank.device) + vertex_rank = torch.zeros((V,4),dtype=torch.long,device=device) #V,4 + torch_scatter.scatter_max(src=rank[:,None].expand(-1,4),index=edges_neighbors,dim=0,out=vertex_rank) + vertex_rank,_ = vertex_rank.max(dim=-1) #V + neighborhood_rank,_ = vertex_rank[edges_neighbors].max(dim=-1) #E' + flip = rank==neighborhood_rank #E' + + if with_normal_check: + # cl-<-----e1 e0,e1...edge, e0-cr + v = vertices[edges_neighbors] #E",4,3 + v = v - v[:,0:1] #make relative to e0 + e1 = v[:,1] + cl = v[:,2] + cr = v[:,3] + n = torch.cross(e1,cl) + torch.cross(cr,e1) #sum of old normal vectors + flip.logical_and_(torch.sum(n*torch.cross(cr,cl),dim=-1)>0) #first new face + flip.logical_and_(torch.sum(n*torch.cross(cl-e1,cr-e1),dim=-1)>0) #second new face + + flip_edges_neighbors = edges_neighbors[flip] #E",4 + flip_edge_to_face = edge_to_face[candidates,:,0][flip] #E",2 + flip_faces = flip_edges_neighbors[:,[[0,3,2],[1,2,3]]] #E",2,3 + faces.scatter_(dim=0,index=flip_edge_to_face.reshape(-1,1).expand(-1,3),src=flip_faces.reshape(-1,3)) diff --git a/models/ISOMER/mesh_reconstruction/render.py b/models/ISOMER/mesh_reconstruction/render.py new file mode 100644 index 0000000000000000000000000000000000000000..699428b99003a3a0f95a05cdc851eb9a0334cdb3 --- /dev/null +++ b/models/ISOMER/mesh_reconstruction/render.py @@ -0,0 +1,142 @@ +# modified from https://github.com/Profactor/continuous-remeshing +import nvdiffrast.torch as dr +import torch +from typing import Tuple + +def _warmup(glctx, device=None): + device = 'cuda' if device is None else device + #windows workaround for https://github.com/NVlabs/nvdiffrast/issues/59 + def tensor(*args, **kwargs): + return torch.tensor(*args, device=device, **kwargs) + + # defines a triangle in homogeneous coordinates and calls dr.rasterize to render this triangle, which may help to initialize or warm up the GPU context + pos = tensor([[[-0.8, -0.8, 0, 1], [0.8, -0.8, 0, 1], [-0.8, 0.8, 0, 1]]], dtype=torch.float32) + tri = tensor([[0, 1, 2]], dtype=torch.int32) + dr.rasterize(glctx, pos, tri, resolution=[256, 256]) + +# glctx = dr.RasterizeGLContext(output_db=False, device="cuda") +glctx = dr.RasterizeCudaContext(device="cuda") + +class NormalsRenderer: + + _glctx:dr.RasterizeCudaContext = None + + def __init__( + self, + mv: torch.Tensor, #C,4,4 # normal column-major (unlike pytorch3d) + proj: torch.Tensor, #C,4,4 + image_size: Tuple[int,int], + mvp = None, + device=None, + ): + if mvp is None: + self._mvp = proj @ mv #C,4,4 + else: + self._mvp = mvp + self._image_size = image_size + self._glctx = glctx + _warmup(self._glctx, device) + + def render(self, + vertices: torch.Tensor, #V,3 float + normals: torch.Tensor, #V,3 float in [-1, 1] + faces: torch.Tensor, #F,3 long + ) ->torch.Tensor: #C,H,W,4 + + V = vertices.shape[0] + faces = faces.type(torch.int32) + vert_hom = torch.cat((vertices, torch.ones(V,1,device=vertices.device)),axis=-1) #V,3 -> V,4 + # transforms the vertices into clip space using the mvp matrix. + vertices_clip = vert_hom @ self._mvp.transpose(-2,-1) #C,V,4 # the .transpose(-2,-1) operation ensures that the matrix multiplication aligns with the row-major convention. + rast_out,_ = dr.rasterize(self._glctx, vertices_clip, faces, resolution=self._image_size, grad_db=False) #C,H,W,4 -> 4 includes the barycentric coordinates and other data. + vert_col = (normals+1)/2 #V,3 + # this function takes the attributes (colors) defined at the vertices and computes their values at each pixel (or fragment) within the triangles + col,_ = dr.interpolate(vert_col, rast_out, faces) #C,H,W,3 + alpha = torch.clamp(rast_out[..., -1:], max=1) #C,H,W,1 + col = torch.concat((col,alpha),dim=-1) #C,H,W,4 + col = dr.antialias(col, rast_out, vertices_clip, faces) #C,H,W,4 + return col #C,H,W,4 + + + +from pytorch3d.structures import Meshes +from pytorch3d.renderer.mesh.shader import ShaderBase +from pytorch3d.renderer import ( + RasterizationSettings, + MeshRendererWithFragments, + TexturesVertex, + MeshRasterizer, + BlendParams, + FoVOrthographicCameras, + look_at_view_transform, + hard_rgb_blend, +) + +class VertexColorShader(ShaderBase): + def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: + blend_params = kwargs.get("blend_params", self.blend_params) + texels = meshes.sample_textures(fragments) + return hard_rgb_blend(texels, fragments, blend_params) + +def render_mesh_vertex_color(mesh, cameras, H, W, blur_radius=0.0, faces_per_pixel=1, bkgd=(0., 0., 0.), dtype=torch.float32, device="cuda"): + if len(mesh) != len(cameras): + if len(cameras) % len(mesh) == 0: + mesh = mesh.extend(len(cameras)) + else: + raise NotImplementedError() + + # render requires everything in float16 or float32 + input_dtype = dtype + blend_params = BlendParams(1e-4, 1e-4, bkgd) + + # Define the settings for rasterization and shading + raster_settings = RasterizationSettings( + image_size=(H, W), + blur_radius=blur_radius, + faces_per_pixel=faces_per_pixel, + clip_barycentric_coords=True, + bin_size=None, + max_faces_per_bin=None, + ) + + # Create a renderer by composing a rasterizer and a shader + # We simply render vertex colors through the custom VertexColorShader (no lighting, materials are used) + renderer = MeshRendererWithFragments( + rasterizer=MeshRasterizer( + cameras=cameras, + raster_settings=raster_settings + ), + shader=VertexColorShader( + device=device, + cameras=cameras, + blend_params=blend_params + ) + ) + + # render RGB and depth, get mask + with torch.autocast(dtype=input_dtype, device_type=torch.device(device).type): + images, _ = renderer(mesh) + return images # BHW4 + +class Pytorch3DNormalsRenderer: # 100 times slower!!! + def __init__(self, cameras, image_size, device): + self.cameras = cameras.to(device) + self._image_size = image_size + self.device = device + + def render(self, + vertices: torch.Tensor, #V,3 float + normals: torch.Tensor, #V,3 float in [-1, 1] + faces: torch.Tensor, #F,3 long + ) ->torch.Tensor: #C,H,W,4 + mesh = Meshes(verts=[vertices], faces=[faces], textures=TexturesVertex(verts_features=[(normals + 1) / 2])).to(self.device) + return render_mesh_vertex_color(mesh, self.cameras, self._image_size[0], self._image_size[1], device=self.device) + +def save_tensor_to_img(tensor, save_dir): + from PIL import Image + import numpy as np + for idx, img in enumerate(tensor): + img = img[..., :3].cpu().numpy() + img = (img * 255).astype(np.uint8) + img = Image.fromarray(img) + img.save(save_dir + f"{idx}.png") diff --git a/models/ISOMER/model/__init__.py b/models/ISOMER/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/ISOMER/model/inference_pipeline.py b/models/ISOMER/model/inference_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..754713ed84fa4a82975285cf88d7c891ca26bdc3 --- /dev/null +++ b/models/ISOMER/model/inference_pipeline.py @@ -0,0 +1,189 @@ +import os +import numpy as np +import torch +from PIL import Image + +from pytorch3d.structures import Meshes +from pytorch3d.renderer import TexturesVertex + +from ..scripts.fast_geo import fast_geo, create_sphere, create_box +from ..scripts.project_mesh import get_cameras_list_azi_ele +from ..mesh_reconstruction.recon import reconstruct_stage1 +from ..mesh_reconstruction.refine import run_mesh_refine +from ..mesh_reconstruction.func import make_star_cameras_orthographic, make_star_cameras_perspective + +from ..data.utils import ( + simple_remove_bkg_normal, + load_glb, + load_obj_with_verts_faces) +from ..scripts.utils import ( + to_pyml_mesh, + simple_clean_mesh, + normal_rotation_img2img_c2w, + rotate_normal_R, + get_rotation_matrix_azi_ele, + manage_elevation_azimuth) + +@torch.enable_grad() +def reconstruction_pipe(normal_pils, + rotation_angles_azi, + rotation_angles_ele, + front_index=0, + back_index=2, + side_index=1, + weights=None, + expansion_weight=0.1, + expansion_weight_stage2=0.0, + init_type="ball", + sphere_r=None, # only used if init_type=="ball" + box_width=1.0, # only used if init_type=="box" + box_length=1.0, # only used if init_type=="box" + box_height=1.0, # only used if init_type=="box" + init_verts=None, + init_faces=None, + init_mesh_from_file="", + stage1_steps=200, + stage2_steps=200, + projection_type="orthographic", + fovy=None, + radius=None, + ortho_dist=1.1, + camera_angles_azi=None, + camera_angles_ele=None, + rm_bkg=False, + rm_bkg_with_rembg=False, # only used if rm_bkg + normal_rotation_R=None, + train_stage1=True, + train_stage2=True, + use_remesh_stage1=True, + use_remesh_stage2=True, + start_edge_len_stage1=0.1, + end_edge_len_stage1=0.02, + start_edge_len_stage2=0.02, + end_edge_len_stage2=0.005, + ): + + assert projection_type in ['perspective', 'orthographic'], f"projection_type ({projection_type}) should be one of ['perspective', 'orthographic']" + + if stage1_steps == 0: + train_stage1 = False + if stage2_steps == 0: + train_stage2 = False + + if normal_rotation_R is not None: + assert normal_rotation_R.shape[-2] == 3 and normal_rotation_R.shape[-1] == 3 + assert len(normal_rotation_R.shape) == 2 + normal_rotation_R = normal_rotation_R.float() + + camera_angles_azi = camera_angles_azi.float() + camera_angles_ele = camera_angles_ele.float() + + camera_angles_ele, camera_angles_azi = manage_elevation_azimuth(camera_angles_ele, camera_angles_azi) + + if init_type in ["std", "thin"]: + assert camera_angles_azi[front_index]%360==0, f"the camera_angles_azi associated with front image (index {front_index}) should be 0 not {camera_angles_azi[front_index]}" + assert camera_angles_azi[back_index]%360==180, f"the camera_angles_azi associated with back image (index {back_index}) should be 180 not {camera_angles_azi[back_index]}" + assert camera_angles_azi[side_index]%360==90, f"the camera_angles_azi associated with left side image (index {side_index}) should be 90, not {camera_angles_azi[back_index]}" + + if rm_bkg: + if rm_bkg_with_rembg: + os.environ["OMP_NUM_THREADS"] = '8' + normal_pils = simple_remove_bkg_normal(normal_pils,rm_bkg_with_rembg) + + if rotation_angles_azi is not None: + rotation_angles_azi = -rotation_angles_azi.float() + rotation_angles_ele = rotation_angles_ele.float() + + rotation_angles_ele, rotation_angles_azi = manage_elevation_azimuth(rotation_angles_ele, rotation_angles_azi) + + assert len(normal_pils) == len(rotation_angles_azi), f'len(normal_pils) ({len(normal_pils)}) != len(rotation_angles_azi) ({len(rotation_angles_azi)})' + if rotation_angles_ele is None: + rotation_angles_ele = [0] * len(normal_pils) + + normal_pils_rotated = [] + for i in range(len(normal_pils)): + c2w_R = get_rotation_matrix_azi_ele(rotation_angles_azi[i], rotation_angles_ele[i]) + + rotated_ = normal_rotation_img2img_c2w(normal_pils[i], c2w=c2w_R) + normal_pils_rotated.append(rotated_) + + normal_pils = normal_pils_rotated + + if normal_rotation_R is not None: + normal_pils_rotated = [] + for i in range(len(normal_pils)): + rotated_ = rotate_normal_R(normal_pils[i], normal_rotation_R, save_addr="", device="cuda") + normal_pils_rotated.append(rotated_) + + normal_pils = normal_pils_rotated + + normal_stg1 = [img for img in normal_pils] + + if init_type in ['thin', 'std']: + front_ = normal_stg1[front_index] + back_ = normal_stg1[back_index] + side_ = normal_stg1[side_index] + meshes, depth_front, depth_back, mesh_front, mesh_back = fast_geo(front_, back_, side_, init_type=init_type, return_depth_and_sep_mesh=True) + + + elif init_type in ["ball", "box"]: + + if init_type == "ball": + assert sphere_r is not None, f"sphere_r ({sphere_r}) should not be None when init_type is 'ball'" + meshes = create_sphere(sphere_r) + + if init_type == "box": + assert box_width is not None and box_length is not None and box_height is not None, f"box_width ({box_width}), box_length ({box_length}), and box_height ({box_height}) should not be None when init_type is 'box'" + meshes = create_box(width=box_width, length=box_length, height=box_height) + + # add texture just in case + num_meshes = len(meshes) + num_verts_per_mesh = meshes.verts_packed().shape[0] // num_meshes + black_texture = torch.zeros((num_meshes, num_verts_per_mesh, 3), device="cuda") + textures = TexturesVertex(verts_features=black_texture) + meshes.textures = textures + + elif init_type == "file": + assert init_mesh_from_file or (init_verts is not None and init_faces is not None), f"init_mesh_from_file ({init_mesh_from_file}) should not be None when init_type is 'file', else init_verts and init_faces should not be None" + + if init_verts is not None and init_faces is not None: + meshes = Meshes(verts=[init_verts], faces=[init_faces]).to('cuda') + elif init_mesh_from_file.endswith('.glb'): + meshes = load_glb(init_mesh_from_file).to('cuda') + else: + meshes = load_obj_with_verts_faces(init_mesh_from_file).to('cuda') + + # add texture just in case + num_meshes = len(meshes) + num_verts_per_mesh = meshes.verts_packed().shape[0] // num_meshes + black_texture = torch.zeros((num_meshes, num_verts_per_mesh, 3), device="cuda") + textures = TexturesVertex(verts_features=black_texture) + meshes.textures = textures + + if projection_type == 'perspective': + assert fovy is not None and radius is not None, f"fovy ({fovy}) and radius ({radius}) should not be None when projection_type is 'perspective'" + cameras = get_cameras_list_azi_ele(camera_angles_azi, camera_angles_ele, fov_in_degrees=fovy,device="cuda", dist=radius, cam_type='fov') + + elif projection_type == 'orthographic': + cameras = get_cameras_list_azi_ele(camera_angles_azi, camera_angles_ele, fov_in_degrees=fovy, device="cuda", focal=1., dist=ortho_dist, cam_type='orthographic') + + vertices, faces = meshes.verts_list()[0], meshes.faces_list()[0] + + render_camera_angles_azi = -camera_angles_azi + render_camera_angles_ele = camera_angles_ele + if projection_type == 'orthographic': + mv, proj = make_star_cameras_orthographic(render_camera_angles_azi, render_camera_angles_ele) + else: + mv, proj = make_star_cameras_perspective(render_camera_angles_azi, render_camera_angles_ele, distance=radius, r=radius, fov=fovy, device='cuda') + + # stage 1 + if train_stage1: + vertices, faces = reconstruct_stage1(normal_stg1, mv=mv, proj=proj, steps=stage1_steps, vertices=vertices, faces=faces, start_edge_len=start_edge_len_stage1, end_edge_len=end_edge_len_stage1, gain=0.05, return_mesh=False, loss_expansion_weight=expansion_weight, use_remesh=use_remesh_stage1) + + # stage 2 + if train_stage2: + vertices, faces = run_mesh_refine(vertices, faces, normal_pils, mv=mv, proj=proj, weights=weights, steps=stage2_steps, start_edge_len=start_edge_len_stage2, end_edge_len=end_edge_len_stage2, decay=0.99, update_normal_interval=20, update_warmup=5, return_mesh=False, process_inputs=False, process_outputs=False, cameras=cameras, use_remesh=use_remesh_stage2, loss_expansion_weight=expansion_weight_stage2) + + meshes = simple_clean_mesh(to_pyml_mesh(vertices, faces), apply_smooth=True, stepsmoothnum=1, apply_sub_divide=True, sub_divide_threshold=0.25).to("cuda") + + return meshes diff --git a/models/ISOMER/projection_func.py b/models/ISOMER/projection_func.py new file mode 100644 index 0000000000000000000000000000000000000000..5704227fe47e26231c42c6b13b5827a01cb9f5fb --- /dev/null +++ b/models/ISOMER/projection_func.py @@ -0,0 +1,86 @@ +import os +import numpy as np +import torch +from PIL import Image +import os +from .scripts.proj_commands import projection as isomer_projection +from .data.utils import simple_remove_bkg_normal + +# mesh_address, +def projection( + meshes, + masks, + images, + azimuths, + elevations, + weights, + fov, + radius, + save_dir, + save_glb_addr=None, + remove_background=False, + auto_center=False, + projection_type="perspective", + below_confidence_strategy="smooth", + complete_unseen=True, + mesh_scale_factor=1.0, + rm_bkg_with_rembg=True, +): + + if save_glb_addr is None: + os.makedirs(save_dir, exist_ok=True) + save_glb_addr=os.path.join(save_dir, "rgb_projected.glb") + + bs = len(images) + assert len(azimuths) == bs, f'len(azimuths) ({len(azimuths)} != batchsize ({bs}))' + assert len(elevations) == bs, f'len(elevations) ({len(elevations)} != batchsize ({bs}))' + assert len(weights) == bs, f'len(weights) ({len(weights)} != batchsize ({bs}))' + + image_rgba = torch.cat([images[:,:,:,:3], masks.unsqueeze(-1)], dim=-1) + + assert image_rgba.shape[-1] == 4, f'image_rgba.shape is {image_rgba.shape}' + + img_list = [Image.fromarray((image.cpu()*255).numpy().astype(np.uint8)) for image in image_rgba] + + + if remove_background: + if rm_bkg_with_rembg: + os.environ["OMP_NUM_THREADS"] = '8' + img_list = simple_remove_bkg_normal(img_list, rm_bkg_with_rembg, return_Image=True) + + resolution = img_list[0].size[0] + new_img_list = [] + for i in range(len(img_list)): + new_img = img_list[i].resize((resolution,resolution)) + + path_dir = os.path.join(save_dir, f'projection_images') + os.makedirs(path_dir, exist_ok=True) + + path_ = os.path.join(path_dir, f'ProjectionImg{i}.png') + + new_img.save(path_) + + new_img_list.append(new_img) + + img_list = new_img_list + + isomer_projection(meshes, + img_list=img_list, + weights=weights, + azimuths=azimuths, + elevations=elevations, + projection_type=projection_type, + auto_center=auto_center, + resolution=resolution, + fovy=fov, + radius=radius, + scale_factor=mesh_scale_factor, + save_glb_addr=save_glb_addr, + scale_verts=True, + complete_unseen=complete_unseen, + below_confidence_strategy=below_confidence_strategy + ) + + return save_glb_addr + + diff --git a/models/ISOMER/reconstruction_func.py b/models/ISOMER/reconstruction_func.py new file mode 100644 index 0000000000000000000000000000000000000000..81408315f7cccd830aff49fd825337fa75d45fb2 --- /dev/null +++ b/models/ISOMER/reconstruction_func.py @@ -0,0 +1,88 @@ +import os +import numpy as np +import torch +from PIL import Image +import os +from .model.inference_pipeline import reconstruction_pipe + +def reconstruction( + normal_pils, + masks, + weights, + fov, + radius, + camera_angles_azi, + camera_angles_ele, + expansion_weight_stage1=0.1, + init_type="ball", + init_verts=None, + init_faces=None, + init_mesh_from_file="", + stage1_steps=200, + stage2_steps=200, + projection_type="perspective", + need_normal_rotation=False, + rotation_angles_azi=None, # only used if need_normal_rotation + rotation_angles_ele=None, # only used if need_normal_rotation + normal_rotation_R=None, # only used if need_normal_rotation + rm_bkg=False, + rm_bkg_with_rembg=True, # only used if rm_bkg + start_edge_len_stage1=0.1, + end_edge_len_stage1=0.02, + start_edge_len_stage2=0.02, + end_edge_len_stage2=0.005, + expansion_weight_stage2=0.0, +): + + if init_type == "file": + assert ((init_verts is not None and init_faces is not None) or init_mesh_from_file), f'init_mesh_from_file or (init_verts and init_faces) must be provided if init_type=="file"' + + if not need_normal_rotation: + rotation_angles_azi = None + rotation_angles_ele = None + normal_rotation_R = None + + bs = len(normal_pils) + + assert len(camera_angles_azi) == bs, f'len(camera_angles_azi) ({len(camera_angles_azi)} != batchsize ({bs}))' + assert len(camera_angles_ele) == bs, f'len(camera_angles_ele) ({len(camera_angles_ele)} != batchsize ({bs}))' + + normal_pils_rgba = torch.cat([normal_pils[:,:,:,:3], masks.unsqueeze(-1)], dim=-1) + + assert normal_pils_rgba.shape[-1] == 4, f'normal_pils_rgba.shape is {normal_pils_rgba.shape}' + + + normal_pils = [Image.fromarray((normal_pil.cpu()*255).numpy().astype(np.uint8)) for normal_pil in normal_pils_rgba] + + + meshes = reconstruction_pipe( + normal_pils=normal_pils, + rotation_angles_azi=rotation_angles_azi, + rotation_angles_ele=rotation_angles_ele, + weights=weights, + expansion_weight=expansion_weight_stage1, + init_type=init_type, + stage1_steps=stage1_steps, + stage2_steps=stage2_steps, + projection_type=projection_type, + fovy=fov, + radius=radius, + camera_angles_azi=camera_angles_azi, + camera_angles_ele=camera_angles_ele, + rm_bkg=rm_bkg, rm_bkg_with_rembg=rm_bkg_with_rembg, + normal_rotation_R=normal_rotation_R, + init_mesh_from_file=init_mesh_from_file, + start_edge_len_stage1=start_edge_len_stage1, + end_edge_len_stage1=end_edge_len_stage1, + start_edge_len_stage2=start_edge_len_stage2, + end_edge_len_stage2=end_edge_len_stage2, + expansion_weight_stage2=expansion_weight_stage2, + init_verts=init_verts, + init_faces=init_faces, + + ) + + + return meshes + + diff --git a/models/ISOMER/scripts/__init__.py b/models/ISOMER/scripts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/ISOMER/scripts/all_typing.py b/models/ISOMER/scripts/all_typing.py new file mode 100644 index 0000000000000000000000000000000000000000..87b19aaefff18dc04a8a7c185cd9086a27e91c62 --- /dev/null +++ b/models/ISOMER/scripts/all_typing.py @@ -0,0 +1,42 @@ +# code from https://github.com/threestudio-project + +""" +This module contains type annotations for the project, using +1. Python type hints (https://docs.python.org/3/library/typing.html) for Python objects +2. jaxtyping (https://github.com/google/jaxtyping/blob/main/API.md) for PyTorch tensors + +Two types of typing checking can be used: +1. Static type checking with mypy (install with pip and enabled as the default linter in VSCode) +2. Runtime type checking with typeguard (install with pip and triggered at runtime, mainly for tensor dtype and shape checking) +""" + +# Basic types +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Literal, + NamedTuple, + NewType, + Optional, + Sized, + Tuple, + Type, + TypeVar, + Union, +) + +# Tensor dtype +# for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md +from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt + +# Config type +from omegaconf import DictConfig + +# PyTorch Tensor type +from torch import Tensor + +# Runtime type checking decorator +from typeguard import typechecked as typechecker diff --git a/models/ISOMER/scripts/fast_geo.py b/models/ISOMER/scripts/fast_geo.py new file mode 100644 index 0000000000000000000000000000000000000000..3bb623bde1e30dc78d3507b165e7005d425bb287 --- /dev/null +++ b/models/ISOMER/scripts/fast_geo.py @@ -0,0 +1,86 @@ +import os +from PIL import Image +from .mesh_init import build_mesh, calc_w_over_h, fix_border_with_pymeshlab_fast +from pytorch3d.structures import Meshes, join_meshes_as_scene +import numpy as np + +import torch +from pytorch3d.structures import Meshes +from pytorch3d.utils import ico_sphere + +def create_sphere(radius, device='cuda'): + + sphere_mesh = ico_sphere(3, device=device) # Increase the subdivision level (e.g., 2) for higher resolution sphere + sphere_mesh = sphere_mesh.scale_verts(radius) + + meshes = Meshes(verts=[sphere_mesh.verts_list()[0]], faces=[sphere_mesh.faces_list()[0]]) + return meshes + + +def create_box(width, length, height, device='cuda'): + """ + Create a box mesh given the width, length, and height. + + Args: + width (float): Width of the box. + length (float): Length of the box. + height (float): Height of the box. + device (str): Device for the tensor operations, default is 'cuda'. + + Returns: + Meshes: A PyTorch3D Meshes object representing the box. + """ + # Define the 8 vertices of the box + verts = torch.tensor([ + [-width / 2, -length / 2, -height / 2], + [ width / 2, -length / 2, -height / 2], + [ width / 2, length / 2, -height / 2], + [-width / 2, length / 2, -height / 2], + [-width / 2, -length / 2, height / 2], + [ width / 2, -length / 2, height / 2], + [ width / 2, length / 2, height / 2], + [-width / 2, length / 2, height / 2] + ], device=device) + + # Define the 12 triangles (faces) of the box using vertex indices + faces = torch.tensor([ + [0, 1, 2], [0, 2, 3], # Bottom face + [4, 5, 6], [4, 6, 7], # Top face + [0, 1, 5], [0, 5, 4], # Front face + [1, 2, 6], [1, 6, 5], # Right face + [2, 3, 7], [2, 7, 6], # Back face + [3, 0, 4], [3, 4, 7] # Left face + ], device=device) + + # Create the Meshes object + meshes = Meshes(verts=[verts], faces=[faces]) + + return meshes + + +# stage 0 inital mesh estimation +def fast_geo(front_normal: Image.Image, back_normal: Image.Image, side_normal: Image.Image, clamp=0., init_type="std", return_depth_and_sep_mesh=False): + + import time + assert front_normal.mode != "RGB" + assert back_normal.mode != "RGB" + assert side_normal.mode != "RGB" + + front_normal = front_normal.resize((192, 192)) + back_normal = back_normal.resize((192, 192)) + side_normal = side_normal.resize((192, 192)) + + # build mesh with front back projection # ~3s + side_w_over_h = calc_w_over_h(side_normal) + mesh_front, depth_front = build_mesh(front_normal, front_normal, clamp_min=clamp, scale=side_w_over_h, init_type=init_type, return_depth=True) + mesh_back, depth_back = build_mesh(back_normal, back_normal, is_back=True, clamp_min=clamp, scale=side_w_over_h, init_type=init_type, return_depth=True) + meshes = join_meshes_as_scene([mesh_front, mesh_back]) + + # poisson reconstruction which guarantees a smooth connection between meshes + # and simplify into 2000 fewer faces + meshes = fix_border_with_pymeshlab_fast(meshes, poissson_depth=6, simplification=2000) + + + if return_depth_and_sep_mesh: + return meshes, depth_front, depth_back, mesh_front, mesh_back + return meshes \ No newline at end of file diff --git a/models/ISOMER/scripts/load_onnx.py b/models/ISOMER/scripts/load_onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..954cd444b988d372314cc1c7983bfc9dca4e998e --- /dev/null +++ b/models/ISOMER/scripts/load_onnx.py @@ -0,0 +1,48 @@ +import onnxruntime +import torch + +providers = [ + ('TensorrtExecutionProvider', { + 'device_id': 0, + 'trt_max_workspace_size': 8 * 1024 * 1024 * 1024, + 'trt_fp16_enable': True, + 'trt_engine_cache_enable': True, + }), + ('CUDAExecutionProvider', { + 'device_id': 0, + 'arena_extend_strategy': 'kSameAsRequested', + 'gpu_mem_limit': 8 * 1024 * 1024 * 1024, + 'cudnn_conv_algo_search': 'HEURISTIC', + }) +] + +def load_onnx(file_path: str): + assert file_path.endswith(".onnx") + sess_opt = onnxruntime.SessionOptions() + ort_session = onnxruntime.InferenceSession(file_path, sess_opt=sess_opt, providers=providers) + return ort_session + + +def load_onnx_caller(file_path: str, single_output=False): + ort_session = load_onnx(file_path) + def caller(*args): + torch_input = isinstance(args[0], torch.Tensor) + if torch_input: + torch_input_dtype = args[0].dtype + torch_input_device = args[0].device + # check all are torch.Tensor and have same dtype and device + assert all([isinstance(arg, torch.Tensor) for arg in args]), "All inputs should be torch.Tensor, if first input is torch.Tensor" + assert all([arg.dtype == torch_input_dtype for arg in args]), "All inputs should have same dtype, if first input is torch.Tensor" + assert all([arg.device == torch_input_device for arg in args]), "All inputs should have same device, if first input is torch.Tensor" + args = [arg.cpu().float().numpy() for arg in args] + + ort_inputs = {ort_session.get_inputs()[idx].name: args[idx] for idx in range(len(args))} + ort_outs = ort_session.run(None, ort_inputs) + + if torch_input: + ort_outs = [torch.tensor(ort_out, dtype=torch_input_dtype, device=torch_input_device) for ort_out in ort_outs] + + if single_output: + return ort_outs[0] + return ort_outs + return caller diff --git a/models/ISOMER/scripts/mesh_init.py b/models/ISOMER/scripts/mesh_init.py new file mode 100644 index 0000000000000000000000000000000000000000..a68a4c4b312e1cb6ffa3244bb0f91126db6a6a7d --- /dev/null +++ b/models/ISOMER/scripts/mesh_init.py @@ -0,0 +1,142 @@ +from PIL import Image +import torch +import numpy as np +from pytorch3d.structures import Meshes +from pytorch3d.renderer import TexturesVertex +from .utils import meshlab_mesh_to_py3dmesh, py3dmesh_to_meshlab_mesh +import pymeshlab + +_MAX_THREAD = 8 + +# rgb and depth to mesh +def get_ortho_ray_directions_origins(W, H, use_pixel_centers=True, device="cuda"): + pixel_center = 0.5 if use_pixel_centers else 0 + i, j = np.meshgrid( + np.arange(W, dtype=np.float32) + pixel_center, + np.arange(H, dtype=np.float32) + pixel_center, + indexing='xy' + ) + i, j = torch.from_numpy(i).to(device), torch.from_numpy(j).to(device) + + origins = torch.stack([(i/W-0.5)*2, (j/H-0.5)*2 * H / W, torch.zeros_like(i)], dim=-1) # W, H, 3 + directions = torch.stack([torch.zeros_like(i), torch.zeros_like(j), torch.ones_like(i)], dim=-1) # W, H, 3 + + return origins, directions + +def depth_and_color_to_mesh(rgb_BCHW, pred_HWC, valid_HWC=None, is_back=False): + if valid_HWC is None: + valid_HWC = torch.ones_like(pred_HWC).bool() + H, W = rgb_BCHW.shape[-2:] + rgb_BCHW = rgb_BCHW.flip(-2) + pred_HWC = pred_HWC.flip(0) + valid_HWC = valid_HWC.flip(0) + rays_o, rays_d = get_ortho_ray_directions_origins(W, H, device=rgb_BCHW.device) + verts = rays_o + rays_d * pred_HWC # [H, W, 3] + verts = verts.reshape(-1, 3) # [V, 3] + indexes = torch.arange(H * W).reshape(H, W).to(rgb_BCHW.device) + faces1 = torch.stack([indexes[:-1, :-1], indexes[:-1, 1:], indexes[1:, :-1]], dim=-1) + # faces1_valid = valid_HWC[:-1, :-1] | valid_HWC[:-1, 1:] | valid_HWC[1:, :-1] + faces1_valid = valid_HWC[:-1, :-1] & valid_HWC[:-1, 1:] & valid_HWC[1:, :-1] + faces2 = torch.stack([indexes[1:, 1:], indexes[1:, :-1], indexes[:-1, 1:]], dim=-1) + # faces2_valid = valid_HWC[1:, 1:] | valid_HWC[1:, :-1] | valid_HWC[:-1, 1:] + faces2_valid = valid_HWC[1:, 1:] & valid_HWC[1:, :-1] & valid_HWC[:-1, 1:] + faces = torch.cat([faces1[faces1_valid.expand_as(faces1)].reshape(-1, 3), + faces2[faces2_valid.expand_as(faces2)].reshape(-1, 3)], + dim=0) # (F, 3) + colors = (rgb_BCHW[0].permute((1,2,0)) / 2 + 0.5).reshape(-1, 3) # (V, 3) + if is_back: + verts = verts * torch.tensor([-1, 1, -1], dtype=verts.dtype, device=verts.device) + + used_verts = faces.unique() + old_to_new_mapping = torch.zeros_like(verts[..., 0]).long() + old_to_new_mapping[used_verts] = torch.arange(used_verts.shape[0], device=verts.device) + new_faces = old_to_new_mapping[faces] + mesh = Meshes(verts=[verts[used_verts]], faces=[new_faces], textures=TexturesVertex(verts_features=[colors[used_verts]])) + return mesh + +def normalmap_to_depthmap(normal_np): + from .normal_to_height_map import estimate_height_map + height = estimate_height_map(normal_np, raw_values=True, thread_count=_MAX_THREAD, target_iteration_count=96) + return height + +def transform_back_normal_to_front(normal_pil): + arr = np.array(normal_pil) # in [0, 255] + arr[..., 0] = 255-arr[..., 0] + arr[..., 2] = 255-arr[..., 2] + return Image.fromarray(arr.astype(np.uint8)) + +def calc_w_over_h(normal_pil): + if isinstance(normal_pil, Image.Image): + arr = np.array(normal_pil) + else: + assert isinstance(normal_pil, np.ndarray) + arr = normal_pil + if arr.shape[-1] == 4: + alpha = arr[..., -1] / 255. + alpha[alpha >= 0.5] = 1 + alpha[alpha < 0.5] = 0 + else: + alpha = ~(arr.min(axis=-1) >= 250) + h_min, w_min = np.min(np.where(alpha), axis=1) + h_max, w_max = np.max(np.where(alpha), axis=1) + return (w_max - w_min) / (h_max - h_min) + +def build_mesh(normal_pil, rgb_pil, is_back=False, clamp_min=-1, scale=0.3, init_type="std", offset=0, return_depth=False): + if is_back: + normal_pil = transform_back_normal_to_front(normal_pil) + normal_img = np.array(normal_pil) + rgb_img = np.array(rgb_pil) + if normal_img.shape[-1] == 4: + valid_HWC = normal_img[..., [3]] / 255 + elif rgb_img.shape[-1] == 4: + valid_HWC = rgb_img[..., [3]] / 255 + else: + raise ValueError("invalid input, either normal or rgb should have alpha channel") + + # object area pixels height + real_height_pix = np.max(np.where(valid_HWC>0.5)[0]) - np.min(np.where(valid_HWC>0.5)[0]) + + heights = normalmap_to_depthmap(normal_img) + rgb_BCHW = torch.from_numpy(rgb_img[..., :3] / 255.).permute((2,0,1))[None] + valid_HWC[valid_HWC < 0.5] = 0 + valid_HWC[valid_HWC >= 0.5] = 1 + valid_HWC = torch.from_numpy(valid_HWC).bool() + + if init_type == "std": + # accurate but not stable + pred_HWC = torch.from_numpy(heights / heights.max() * (real_height_pix / heights.shape[0]) * scale * 2).float()[..., None] + elif init_type == "thin": + heights = heights - heights.min() + heights = (heights / heights.max() * 0.2) + pred_HWC = torch.from_numpy(heights * scale).float()[..., None] + else: + # stable but not accurate + heights = heights - heights.min() + heights = (heights / heights.max() * (1-offset)) + offset # to [0.2, 1] + pred_HWC = torch.from_numpy(heights * scale).float()[..., None] + + # set the boarder pixels to 0 height + import cv2 + # edge filter + edge = cv2.Canny((valid_HWC[..., 0] * 255).numpy().astype(np.uint8), 0, 255) + edge = torch.from_numpy(edge).bool()[..., None] + pred_HWC[edge] = 0 + + valid_HWC[pred_HWC < clamp_min] = False + rt_mesh = depth_and_color_to_mesh(rgb_BCHW.cuda(), pred_HWC.cuda(), valid_HWC.cuda(), is_back) + + if return_depth: + return rt_mesh, pred_HWC + return rt_mesh + +# poisson reconstruction which guarantees a smooth connection between meshes +# and simplify into 2000 fewer faces +def fix_border_with_pymeshlab_fast(meshes: Meshes, poissson_depth=6, simplification=0): + ms = pymeshlab.MeshSet() + ms.add_mesh(py3dmesh_to_meshlab_mesh(meshes), "cube_vcolor_mesh") + if simplification > 0: + ms.apply_filter('meshing_decimation_quadric_edge_collapse', targetfacenum=simplification, preservetopology=True) + ms.apply_filter('generate_surface_reconstruction_screened_poisson', threads = 6, depth = poissson_depth, preclean = True) + if simplification > 0: + ms.apply_filter('meshing_decimation_quadric_edge_collapse', targetfacenum=simplification, preservetopology=True) + return meshlab_mesh_to_py3dmesh(ms.current_mesh()) diff --git a/models/ISOMER/scripts/normal_to_height_map.py b/models/ISOMER/scripts/normal_to_height_map.py new file mode 100644 index 0000000000000000000000000000000000000000..86cec902a37c280714659363e6b466324c8e7455 --- /dev/null +++ b/models/ISOMER/scripts/normal_to_height_map.py @@ -0,0 +1,205 @@ +# code modified from https://github.com/YertleTurtleGit/depth-from-normals +import numpy as np +import cv2 as cv +from multiprocessing.pool import ThreadPool as Pool +from multiprocessing import cpu_count +from typing import Tuple, List, Union +import numba + + +def calculate_gradients( + normals: np.ndarray, mask: np.ndarray +) -> Tuple[np.ndarray, np.ndarray]: + horizontal_angle_map = np.arccos(np.clip(normals[:, :, 0], -1, 1)) + left_gradients = np.zeros(normals.shape[:2]) + left_gradients[mask != 0] = (1 - np.sin(horizontal_angle_map[mask != 0])) * np.sign( + horizontal_angle_map[mask != 0] - np.pi / 2 + ) + + vertical_angle_map = np.arccos(np.clip(normals[:, :, 1], -1, 1)) + top_gradients = np.zeros(normals.shape[:2]) + top_gradients[mask != 0] = -(1 - np.sin(vertical_angle_map[mask != 0])) * np.sign( + vertical_angle_map[mask != 0] - np.pi / 2 + ) + + return left_gradients, top_gradients + + +@numba.jit(nopython=True) +def integrate_gradient_field( + gradient_field: np.ndarray, axis: int, mask: np.ndarray +) -> np.ndarray: + heights = np.zeros(gradient_field.shape) + + for d1 in numba.prange(heights.shape[1 - axis]): # numba.prange: executes the loop in parallel + sum_value = 0 + for d2 in range(heights.shape[axis]): + coordinates = (d1, d2) if axis == 1 else (d2, d1) + + if mask[coordinates] != 0: + sum_value = sum_value + gradient_field[coordinates] # equation 1 in paper along `axis` axis + heights[coordinates] = sum_value + else: + sum_value = 0 + + return heights + +# equation 1 in paper wrt these directions +def calculate_heights( + left_gradients: np.ndarray, top_gradients, mask: np.ndarray +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + left_heights = integrate_gradient_field(left_gradients, 1, mask) + right_heights = np.fliplr( + integrate_gradient_field(np.fliplr(-left_gradients), 1, np.fliplr(mask)) + ) + top_heights = integrate_gradient_field(top_gradients, 0, mask) + bottom_heights = np.flipud( + integrate_gradient_field(np.flipud(-top_gradients), 0, np.flipud(mask)) + ) + return left_heights, right_heights, top_heights, bottom_heights + + +def combine_heights(*heights: np.ndarray) -> np.ndarray: + return np.mean(np.stack(heights, axis=0), axis=0) + + +def rotate(matrix: np.ndarray, angle: float) -> np.ndarray: + h, w = matrix.shape[:2] + center = (w / 2, h / 2) + + rotation_matrix = cv.getRotationMatrix2D(center, angle, 1.0) + corners = cv.transform( + np.array([[[0, 0], [w, 0], [w, h], [0, h]]]), rotation_matrix + )[0] + + _, _, w, h = cv.boundingRect(corners) + + rotation_matrix[0, 2] += w / 2 - center[0] + rotation_matrix[1, 2] += h / 2 - center[1] + result = cv.warpAffine(matrix, rotation_matrix, (w, h), flags=cv.INTER_LINEAR) + + return result + + +def rotate_vector_field_normals(normals: np.ndarray, angle: float) -> np.ndarray: + angle = np.radians(angle) + cos_angle = np.cos(angle) + sin_angle = np.sin(angle) + + rotated_normals = np.empty_like(normals) + rotated_normals[:, :, 0] = ( + normals[:, :, 0] * cos_angle - normals[:, :, 1] * sin_angle + ) + rotated_normals[:, :, 1] = ( + normals[:, :, 0] * sin_angle + normals[:, :, 1] * cos_angle + ) + + return rotated_normals + + +def centered_crop(image: np.ndarray, target_resolution: Tuple[int, int]) -> np.ndarray: + return image[ + (image.shape[0] - target_resolution[0]) + // 2 : (image.shape[0] - target_resolution[0]) + // 2 + + target_resolution[0], + (image.shape[1] - target_resolution[1]) + // 2 : (image.shape[1] - target_resolution[1]) + // 2 + + target_resolution[1], + ] + + +def integrate_vector_field( + vector_field: np.ndarray, + mask: np.ndarray, + target_iteration_count: int, + thread_count: int, +) -> np.ndarray: + shape = vector_field.shape[:2] + angles = np.linspace(0, 90, target_iteration_count, endpoint=False) + + def integrate_vector_field_angles(angles: List[float]) -> np.ndarray: + all_combined_heights = np.zeros(shape) + + for angle in angles: + rotated_vector_field = rotate_vector_field_normals( + rotate(vector_field, angle), angle + ) # rotate twice: first rotate the whole in image level, then rotate the individual normal vectors + + rotated_mask = rotate(mask, angle) + + left_gradients, top_gradients = calculate_gradients( + rotated_vector_field, rotated_mask + ) + ( + left_heights, + right_heights, + top_heights, + bottom_heights, + ) = calculate_heights(left_gradients, top_gradients, rotated_mask) + + combined_heights = combine_heights( + left_heights, right_heights, top_heights, bottom_heights + ) # = mean of these heights + combined_heights = centered_crop(rotate(combined_heights, -angle), shape) + all_combined_heights += combined_heights / len(angles) + + return all_combined_heights + + with Pool(processes=thread_count) as pool: + heights = pool.map( + integrate_vector_field_angles, + np.array( + np.array_split(angles, thread_count), + dtype=object, + ), + ) + pool.close() + pool.join() + + isotropic_height = np.zeros(shape) + for height in heights: + isotropic_height += height / thread_count + + return isotropic_height + + +def estimate_height_map( + normal_map: np.ndarray, + mask: Union[np.ndarray, None] = None, + height_divisor: float = 1, + target_iteration_count: int = 250, + thread_count: int = cpu_count(), + raw_values: bool = False, +) -> np.ndarray: + if mask is None: + if normal_map.shape[-1] == 4: + mask = normal_map[:, :, 3] / 255 + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + else: + mask = np.ones(normal_map.shape[:2], dtype=np.uint8) + + normals = ((normal_map[:, :, :3].astype(np.float64) / 255) - 0.5) * 2 + heights = integrate_vector_field( + normals, mask, target_iteration_count, thread_count + ) # equation 1 in paper, repeat `target_iteration_count` (8?) times with rotation in angle np.linspace(0, 90, target_iteration_count), then find mean + # target_iteration_count=8 ? defined _MAX_THREAD = 8 in mesh_init.py + + if raw_values: + return heights + + heights /= height_divisor + heights[mask > 0] += 1 / 2 + heights[mask == 0] = 1 / 2 + + heights *= 2**16 - 1 + + if np.min(heights) < 0 or np.max(heights) > 2**16 - 1: + raise OverflowError("Height values are clipping.") + + heights = np.clip(heights, 0, 2**16 - 1) + heights = heights.astype(np.uint16) + + return heights diff --git a/models/ISOMER/scripts/proj_commands.py b/models/ISOMER/scripts/proj_commands.py new file mode 100644 index 0000000000000000000000000000000000000000..f76e7e5446c78f97092863b81082e0c97f89b3dc --- /dev/null +++ b/models/ISOMER/scripts/proj_commands.py @@ -0,0 +1,69 @@ +import numpy as np +import torch +from PIL import Image +from pytorch3d.renderer import ( + TexturesVertex, +) +from .project_mesh import ( + get_cameras_list_azi_ele, + multiview_color_projection + +) +from .utils import save_py3dmesh_with_trimesh_fast + +def projection(meshes, + img_list, + weights, + azimuths, + elevations, + projection_type='orthographic', + auto_center=True, + resolution=1024, + fovy=None, + radius=None, + ortho_dist=1.1, + scale_factor=1.0, + save_glb_addr=None, + scale_verts=True, + complete_unseen=True, + below_confidence_strategy="smooth" + ): + + assert len(img_list) == len(azimuths) == len(elevations) == len(weights), f"len(img_list) ({len(img_list)}) != len(azimuths) ({len(azimuths)}) != len(elevations) ({len(elevations)}) != len(weights) ({len(weights)})" + + projection_types = ['perspective', 'orthographic'] + assert projection_type in projection_types, f"projection_type ({projection_type}) should be one of {projection_types}" + + if auto_center: + verts = meshes.verts_packed() + max_bb = (verts - 0).max(0)[0] + min_bb = (verts - 0).min(0)[0] + scale = (max_bb - min_bb).max() / 2 + center = (max_bb + min_bb) / 2 + meshes.offset_verts_(-center) + if scale_verts: + meshes.scale_verts_((scale_factor / float(scale))) + elif scale_verts: + meshes.scale_verts_((scale_factor)) + + if projection_type == 'perspective': + assert fovy is not None and radius is not None, f"fovy ({fovy}) and radius ({radius}) should not be None when projection_type is 'perspective'" + cameras = get_cameras_list_azi_ele(azimuths, elevations, fov_in_degrees=fovy,device="cuda", dist=radius, cam_type='fov') + elif projection_type == 'orthographic': + cameras = get_cameras_list_azi_ele(azimuths, elevations, fov_in_degrees=fovy, device="cuda", focal=2/1.35, dist=ortho_dist, cam_type='orthographic') + + + num_meshes = len(meshes) + num_verts_per_mesh = meshes.verts_packed().shape[0] // num_meshes + black_texture = torch.zeros((num_meshes, num_verts_per_mesh, 3), device="cuda") + textures = TexturesVertex(verts_features=black_texture) + meshes.textures = textures + + + proj_mesh = multiview_color_projection(meshes, img_list, cameras, weights=weights, eps=0.05, resolution=resolution, device="cuda", reweight_with_cosangle="square", use_alpha=True, confidence_threshold=0.1, complete_unseen=complete_unseen, below_confidence_strategy=below_confidence_strategy) + + + if save_glb_addr is not None: + save_py3dmesh_with_trimesh_fast(proj_mesh, save_glb_addr) + + diff --git a/models/ISOMER/scripts/project_mesh.py b/models/ISOMER/scripts/project_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..443b17d3e3a83f9ee4d2bea427ec7d5b39969032 --- /dev/null +++ b/models/ISOMER/scripts/project_mesh.py @@ -0,0 +1,401 @@ +from typing import List +import torch +import numpy as np +from PIL import Image +from pytorch3d.renderer.cameras import look_at_view_transform, OrthographicCameras, CamerasBase +from pytorch3d.io import load_objs_as_meshes +from pytorch3d.renderer.mesh.rasterizer import Fragments +from pytorch3d.structures import Meshes +from pytorch3d.renderer import ( + RasterizationSettings, + TexturesVertex, + FoVPerspectiveCameras, + FoVOrthographicCameras, +) +from pytorch3d.renderer import MeshRasterizer + +def get_camera(world_to_cam, fov_in_degrees=60, focal_length=1 / (2**0.5), cam_type='fov'): + # pytorch3d expects transforms as row-vectors, so flip rotation: https://github.com/facebookresearch/pytorch3d/issues/1183 + R = world_to_cam[:3, :3].t()[None, ...] + T = world_to_cam[:3, 3][None, ...] + if cam_type == 'fov': + assert fov_in_degrees is not None, "fov_in_degrees should not be None when cam_type is fov" + camera = FoVPerspectiveCameras(device=world_to_cam.device, R=R, T=T, fov=fov_in_degrees, degrees=True) + else: + focal_length = 1 / focal_length + camera = FoVOrthographicCameras(device=world_to_cam.device, R=R, T=T, min_x=-focal_length, max_x=focal_length, min_y=-focal_length, max_y=focal_length) + return camera + +def render_pix2faces_py3d(meshes, cameras, H=512, W=512, blur_radius=0.0, faces_per_pixel=1): + """ + Renders pix2face of visible faces. + + :param mesh: Pytorch3d.structures.Meshes + :param cameras: pytorch3d.renderer.Cameras + :param H: target image height + :param W: target image width + :param blur_radius: Float distance in the range [0, 2] used to expand the face + bounding boxes for rasterization. Setting blur radius + results in blurred edges around the shape instead of a + hard boundary. Set to 0 for no blur. + :param faces_per_pixel: (int) Number of faces to keep track of per pixel. + We return the nearest faces_per_pixel faces along the z-axis. + """ + # Define the settings for rasterization and shading + raster_settings = RasterizationSettings( + image_size=(H, W), + blur_radius=blur_radius, + faces_per_pixel=faces_per_pixel + ) + rasterizer=MeshRasterizer( + cameras=cameras, + raster_settings=raster_settings + ) + fragments: Fragments = rasterizer(meshes, cameras=cameras) + return { + "pix_to_face": fragments.pix_to_face[..., 0], + } + +import nvdiffrast.torch as dr + +def _warmup(glctx, device=None): + device = 'cuda' if device is None else device + #windows workaround for https://github.com/NVlabs/nvdiffrast/issues/59 + def tensor(*args, **kwargs): + return torch.tensor(*args, device=device, **kwargs) + pos = tensor([[[-0.8, -0.8, 0, 1], [0.8, -0.8, 0, 1], [-0.8, 0.8, 0, 1]]], dtype=torch.float32) + tri = tensor([[0, 1, 2]], dtype=torch.int32) + dr.rasterize(glctx, pos, tri, resolution=[256, 256]) + +class Pix2FacesRenderer: + def __init__(self, device="cuda"): + # self._glctx = dr.RasterizeGLContext(output_db=False, device=device) + self._glctx = dr.RasterizeCudaContext(device=device) + self.device = device + _warmup(self._glctx, device) + + def transform_vertices(self, meshes: Meshes, cameras: CamerasBase): + vertices = cameras.transform_points_ndc(meshes.verts_padded()) + + perspective_correct = cameras.is_perspective() + znear = cameras.get_znear() + if isinstance(znear, torch.Tensor): + znear = znear.min().item() + z_clip = None #if not perspective_correct or znear is None else znear / 2 + + if z_clip: + vertices = vertices[vertices[..., 2] >= cameras.get_znear()][None] # clip + vertices = vertices * torch.tensor([-1, -1, 1]).to(vertices) + vertices = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1).to(torch.float32) + return vertices + + def render_pix2faces_nvdiff(self, meshes: Meshes, cameras: CamerasBase, H=512, W=512): + meshes = meshes.to(self.device) + cameras = cameras.to(self.device) + vertices = self.transform_vertices(meshes, cameras) + faces = meshes.faces_packed().to(torch.int32) + rast_out,_ = dr.rasterize(self._glctx, vertices, faces, resolution=(H, W), grad_db=False) #C,H,W,4 + pix_to_face = rast_out[..., -1].to(torch.int32) - 1 + return pix_to_face + +pix2faces_renderer = Pix2FacesRenderer() + +def get_visible_faces(meshes: Meshes, cameras: CamerasBase, resolution=1024): + # pix_to_face = render_pix2faces_py3d(meshes, cameras, H=resolution, W=resolution)['pix_to_face'] + pix_to_face = pix2faces_renderer.render_pix2faces_nvdiff(meshes, cameras, H=resolution, W=resolution) + + unique_faces = torch.unique(pix_to_face.flatten()) + unique_faces = unique_faces[unique_faces != -1] + return unique_faces + +def project_color(meshes: Meshes, cameras: CamerasBase, pil_image: Image.Image, use_alpha=True, eps=0.05, resolution=1024, device="cuda") -> dict: + """ + Projects color from a given image onto a 3D mesh. + + Args: + meshes (pytorch3d.structures.Meshes): The 3D mesh object. + cameras (pytorch3d.renderer.cameras.CamerasBase): The camera object. + pil_image (PIL.Image.Image): The input image. + use_alpha (bool, optional): Whether to use the alpha channel of the image. Defaults to True. + eps (float, optional): The threshold for selecting visible faces. Defaults to 0.05. + resolution (int, optional): The resolution of the projection. Defaults to 1024. + device (str, optional): The device to use for computation. Defaults to "cuda". + debug (bool, optional): Whether to save debug images. Defaults to False. + + Returns: + dict: A dictionary containing the following keys: + - "new_texture" (TexturesVertex): The updated texture with interpolated colors. + - "valid_verts" (Tensor of [M,3]): The indices of the vertices being projected. + - "valid_colors" (Tensor of [M,3]): The interpolated colors for the valid vertices. + """ + meshes = meshes.to(device) + cameras = cameras.to(device) + image = torch.from_numpy(np.array(pil_image.convert("RGBA")) / 255.).permute((2, 0, 1)).float().to(device) # in CHW format of [0, 1.] + unique_faces = get_visible_faces(meshes, cameras, resolution=resolution) + + # visible faces + faces_normals = meshes.faces_normals_packed()[unique_faces] + faces_normals = faces_normals / faces_normals.norm(dim=1, keepdim=True) + world_points = cameras.unproject_points(torch.tensor([[[0., 0., 0.1], [0., 0., 0.2]]]).to(device))[0] + view_direction = world_points[1] - world_points[0] + view_direction = view_direction / view_direction.norm(dim=0, keepdim=True) + + + # find invalid faces + cos_angles = (faces_normals * view_direction).sum(dim=1) + # assert cos_angles.mean() < 0, f"The view direction is not correct. cos_angles.mean()={cos_angles.mean()}" + selected_faces = unique_faces[cos_angles < -eps] + + # find verts + faces = meshes.faces_packed()[selected_faces] # [N, 3] + verts = torch.unique(faces.flatten()) # [N, 1] + verts_coordinates = meshes.verts_packed()[verts] # [N, 3] + + # compute color + pt_tensor = cameras.transform_points(verts_coordinates)[..., :2] # NDC space points + valid = ~((pt_tensor.isnan()|(pt_tensor<-1)|(1 dict: + """ + meshes: the mesh with vertex color to be completed. + valid_index: the index of the valid vertices, where valid means colors are fixed. [V, 1] + """ + valid_index = valid_index.to(meshes.device) + colors = meshes.textures.verts_features_packed() # [V, 3] + V = colors.shape[0] + + invalid_index = torch.ones_like(colors[:, 0]).bool() # [V] + invalid_index[valid_index] = False + invalid_index = torch.arange(V).to(meshes.device)[invalid_index] + + L = meshes.laplacian_packed() # connectivity + E = torch.sparse_coo_tensor(torch.tensor([list(range(V))] * 2), torch.ones((V,)), size=(V, V)).to(meshes.device) + L = L + E + # E = torch.eye(V, layout=torch.sparse_coo, device=meshes.device) + # L = L + E + colored_count = torch.ones_like(colors[:, 0]) # [V] + colored_count[invalid_index] = 0 + L_invalid = torch.index_select(L, 0, invalid_index) # sparse [IV, V] + + total_colored = colored_count.sum() + coloring_round = 0 + stage = "uncolored" + from tqdm import tqdm + pbar = tqdm(miniters=100) + while stage == "uncolored" or coloring_round > 0: + new_color = torch.matmul(L_invalid, colors * colored_count[:, None]) # [IV, 3] + new_count = torch.matmul(L_invalid, colored_count)[:, None] # [IV, 1] + colors[invalid_index] = torch.where(new_count > 0, new_color / new_count, colors[invalid_index]) + colored_count[invalid_index] = (new_count[:, 0] > 0).float() + + new_total_colored = colored_count.sum() + if new_total_colored > total_colored: + total_colored = new_total_colored + coloring_round += 1 + else: + stage = "colored" + coloring_round -= 1 + pbar.update(1) + if coloring_round > 10000: + print("coloring_round > 10000, break") + break + assert not torch.isnan(colors).any() + meshes.textures = TexturesVertex(verts_features=[colors]) + return meshes + +def load_glb_mesh(glb_path, device="cuda"): + meshes = load_objs_as_meshes([glb_path], device=device) + return meshes + +def get_separated_images_from_img_grid(img_grid_path, image_num): + img_list = [] + grid = Image.open(img_grid_path) + w, h = grid.size + for i in range(0, image_num): + img_list.append(grid.crop((i*h, 0, i*h + h, h))) + return img_list + +def get_fov_camera_(azimuth, elevation, fovy, radius, mesh, auto_center, scale_factor, device='cuda'): + if auto_center: + verts = mesh.verts_packed() + max_bb = (verts - 0).max(0)[0] + min_bb = (verts - 0).min(0)[0] + scale = (max_bb - min_bb).max() / 2 + center = (max_bb + min_bb) / 2 + mesh.offset_verts_(-center) + mesh.scale_verts_((scale_factor / float(scale))) + else: + mesh.scale_verts_((scale_factor)) + R, T = look_at_view_transform(radius, azimuth, elevation, device=device) + cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=fovy) + return cameras + +def multiview_color_projection(meshes: Meshes, image_list: List[Image.Image], cameras_list: List[CamerasBase], weights=None, eps=0.05, resolution=1024, device="cuda", reweight_with_cosangle="square", use_alpha=True, confidence_threshold=0.1, complete_unseen=False, below_confidence_strategy="smooth") -> Meshes: + """ + Projects color from a given image onto a 3D mesh. + + Args: + meshes (pytorch3d.structures.Meshes): The 3D mesh object, only one mesh. + image_list (PIL.Image.Image): List of images. + cameras_list (list): List of cameras. + weights (list, optional): List of weights for each image, for ['front', 'front_right', 'right', 'back', 'left', 'front_left']. Defaults to None. + eps (float, optional): The threshold for selecting visible faces. Defaults to 0.05. + resolution (int, optional): The resolution of the projection. Defaults to 1024. + device (str, optional): The device to use for computation. Defaults to "cuda". + reweight_with_cosangle (str, optional): Whether to reweight the color with the angle between the view direction and the vertex normal. Defaults to None. + use_alpha (bool, optional): Whether to use the alpha channel of the image. Defaults to True. + confidence_threshold (float, optional): The threshold for the confidence of the projected color, if final projection weight is less than this, we will use the original color. Defaults to 0.1. + complete_unseen (bool, optional): Whether to complete the unseen vertex color using laplacian. Defaults to False. + + Returns: + Meshes: the colored mesh + """ + + if image_list is None: + raise ValueError("image_list is None") + + + meshes = meshes.clone().to(device) + if weights is None: + weights = [1. for _ in range(len(cameras_list))] + + assert len(cameras_list) == len(image_list) == len(weights), f'the following three lengths should be equal: len(cameras_list)({len(cameras_list)}), len(image_list)({len(image_list)}), len(weights)({len(weights)})' + + original_color = meshes.textures.verts_features_packed() + assert not torch.isnan(original_color).any() + texture_counts = torch.zeros_like(original_color[..., :1]) + texture_values = torch.zeros_like(original_color) + max_texture_counts = torch.zeros_like(original_color[..., :1]) + max_texture_values = torch.zeros_like(original_color) + for camera, image, weight in zip(cameras_list, image_list, weights): + ret = project_color(meshes, camera, image, eps=eps, resolution=resolution, device=device, use_alpha=use_alpha) + if reweight_with_cosangle == "linear": + weight = (ret['cos_angles'].abs() * weight)[:, None] + elif reweight_with_cosangle == "square": + weight = (ret['cos_angles'].abs() ** 2 * weight)[:, None] + if use_alpha: + weight = weight * ret['valid_alpha'] + + try: + assert weight.min() > -0.0001, f'weight.min() is {weight.min()}, but shoule be > -0.0001' + except Exception as e: + raise e + + texture_counts[ret['valid_verts']] += weight + texture_values[ret['valid_verts']] += ret['valid_colors'] * weight + max_texture_values[ret['valid_verts']] = torch.where(weight > max_texture_counts[ret['valid_verts']], ret['valid_colors'], max_texture_values[ret['valid_verts']]) + max_texture_counts[ret['valid_verts']] = torch.max(max_texture_counts[ret['valid_verts']], weight) + + texture_values = torch.where(texture_counts > confidence_threshold, texture_values / texture_counts, texture_values) + if below_confidence_strategy == "smooth": + texture_values = torch.where(texture_counts <= confidence_threshold, (original_color * (confidence_threshold - texture_counts) + texture_values) / confidence_threshold, texture_values) + elif below_confidence_strategy == "original": + texture_values = torch.where(texture_counts <= confidence_threshold, original_color, texture_values) + else: + raise ValueError(f"below_confidence_strategy={below_confidence_strategy} is not supported") + assert not torch.isnan(texture_values).any() + meshes.textures = TexturesVertex(verts_features=[texture_values]) + + if complete_unseen: + meshes = complete_unseen_vertex_color(meshes, torch.arange(texture_values.shape[0]).to(device)[texture_counts[:, 0] >= confidence_threshold]) + ret_mesh = meshes.detach() + del meshes + return ret_mesh + +def get_cameras_list(azim_list, device, elevation, fov_in_degrees=None, focal=2/1.35, dist=1.1, cam_type='orthographic'): + ret = [] + for azim in azim_list: + R, T = look_at_view_transform(dist, elevation, azim) + w2c = torch.cat([R[0].T, T[0, :, None]], dim=1) + cameras = get_camera(w2c, fov_in_degrees=fov_in_degrees, focal_length=focal, cam_type=cam_type).to(device) + ret.append(cameras) + return ret + +def get_cameras_list_azi_ele(azim_list, elev_list, device, fov_in_degrees=None, focal=2/1.35, dist=1.1, cam_type='orthographic'): + ret = [] + for i in range(len(azim_list)): + R, T = look_at_view_transform(dist, elev_list[i], azim_list[i]) + w2c = torch.cat([R[0].T, T[0, :, None]], dim=1) + cameras = get_camera(w2c, fov_in_degrees=fov_in_degrees, focal_length=focal, cam_type=cam_type).to(device) + ret.append(cameras) + return ret + +def get_8view_cameras(device, focal=2/1.35): + return get_cameras_list(azim_list = [180, 225, 270, 315, 0, 45, 90, 135], elevation=0, device=device, focal=focal) + +def get_6view_cameras(device, focal=2/1.35): + return get_cameras_list(azim_list = [180, 225, 270, 0, 90, 135], elevation=0, device=device, focal=focal) + +def get_4view_cameras(device, focal=2/1.35): + return get_cameras_list(azim_list = [180, 270, 0, 90], elevation=0, device=device, focal=focal) + +def get_2view_cameras(device, focal=2/1.35): + return get_cameras_list(azim_list = [180, 0], elevation=0, device=device, focal=focal) + +def get_multiple_view_cameras(device, focal=2/1.35, offset=180, num_views=8, dist=1.1): + return get_cameras_list(azim_list = (np.linspace(0, 360, num_views+1)[:-1] + offset) % 360, elevation=0, device=device, focal=focal, dist=dist) + +def align_with_alpha_bbox(source_img, target_img, final_size=1024): + # align source_img with target_img using alpha channel + # source_img and target_img are PIL.Image.Image + source_img = source_img.convert("RGBA") + target_img = target_img.convert("RGBA").resize((final_size, final_size)) + source_np = np.array(source_img) + target_np = np.array(target_img) + source_alpha = source_np[:, :, 3] + target_alpha = target_np[:, :, 3] + bbox_source_min, bbox_source_max = np.argwhere(source_alpha > 0).min(axis=0), np.argwhere(source_alpha > 0).max(axis=0) + bbox_target_min, bbox_target_max = np.argwhere(target_alpha > 0).min(axis=0), np.argwhere(target_alpha > 0).max(axis=0) + source_content = source_np[bbox_source_min[0]:bbox_source_max[0]+1, bbox_source_min[1]:bbox_source_max[1]+1, :] + # resize source_content to fit in the position of target_content + source_content = Image.fromarray(source_content).resize((bbox_target_max[1]-bbox_target_min[1]+1, bbox_target_max[0]-bbox_target_min[0]+1), resample=Image.BICUBIC) + target_np[bbox_target_min[0]:bbox_target_max[0]+1, bbox_target_min[1]:bbox_target_max[1]+1, :] = np.array(source_content) + return Image.fromarray(target_np) + +def load_image_list_from_mvdiffusion(mvdiffusion_path, front_from_pil_or_path=None): + import os + image_list = [] + for dir in ['front', 'front_right', 'right', 'back', 'left', 'front_left']: + image_path = os.path.join(mvdiffusion_path, f"rgb_000_{dir}.png") + pil = Image.open(image_path) + if dir == 'front': + if front_from_pil_or_path is not None: + if isinstance(front_from_pil_or_path, str): + replace_pil = Image.open(front_from_pil_or_path) + else: + replace_pil = front_from_pil_or_path + # align replace_pil with pil using bounding box in alpha channel + pil = align_with_alpha_bbox(replace_pil, pil, final_size=1024) + image_list.append(pil) + return image_list + +def load_image_list_from_img_grid(img_grid_path, resolution = 1024): + img_list = [] + grid = Image.open(img_grid_path) + w, h = grid.size + for row in range(0, h, resolution): + for col in range(0, w, resolution): + img_list.append(grid.crop((col, row, col + resolution, row + resolution))) + return img_list \ No newline at end of file diff --git a/models/ISOMER/scripts/refine_lr_to_sr.py b/models/ISOMER/scripts/refine_lr_to_sr.py new file mode 100644 index 0000000000000000000000000000000000000000..a9a50cab8cd6d722ce0bf47ceb42ba168070fd9d --- /dev/null +++ b/models/ISOMER/scripts/refine_lr_to_sr.py @@ -0,0 +1,60 @@ +import torch +import os + +import numpy as np +from hashlib import md5 +def hash_img(img): + return md5(np.array(img).tobytes()).hexdigest() +def hash_any(obj): + return md5(str(obj).encode()).hexdigest() + +def refine_lr_with_sd(pil_image_list, concept_img_list, control_image_list, prompt_list, pipe=None, strength=0.35, neg_prompt_list="", output_size=(512, 512), controlnet_conditioning_scale=1.): + with torch.no_grad(): + images = pipe( + image=pil_image_list, + ip_adapter_image=concept_img_list, + prompt=prompt_list, + neg_prompt=neg_prompt_list, + num_inference_steps=50, + strength=strength, + height=output_size[0], + width=output_size[1], + control_image=control_image_list, + guidance_scale=5.0, + controlnet_conditioning_scale=controlnet_conditioning_scale, + generator=torch.manual_seed(233), + ).images + return images + +SR_cache = None + +def run_sr_fast(source_pils, scale=4): + from PIL import Image + from scripts.upsampler import RealESRGANer + import numpy as np + global SR_cache + if SR_cache is not None: + upsampler = SR_cache + else: + upsampler = RealESRGANer( + scale=4, + onnx_path="ckpt/realesrgan-x4.onnx", + tile=0, + tile_pad=10, + pre_pad=0, + half=True, + gpu_id=0, + ) + ret_pils = [] + for idx, img_pils in enumerate(source_pils): + np_in = isinstance(img_pils, np.ndarray) + assert isinstance(img_pils, (Image.Image, np.ndarray)) + img = np.array(img_pils) + output, _ = upsampler.enhance(img, outscale=scale) + if np_in: + ret_pils.append(output) + else: + ret_pils.append(Image.fromarray(output)) + if SR_cache is None: + SR_cache = upsampler + return ret_pils diff --git a/models/ISOMER/scripts/sd_model_zoo.py b/models/ISOMER/scripts/sd_model_zoo.py new file mode 100644 index 0000000000000000000000000000000000000000..6e5e271ca5b8f242998e151482b86984aac1c10d --- /dev/null +++ b/models/ISOMER/scripts/sd_model_zoo.py @@ -0,0 +1,131 @@ +from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, EulerAncestralDiscreteScheduler, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionPipeline +from transformers import CLIPVisionModelWithProjection +import torch +from copy import deepcopy + +ENABLE_CPU_CACHE = False +DEFAULT_BASE_MODEL = "runwayml/stable-diffusion-v1-5" + +cached_models = {} # cache for models to avoid repeated loading, key is model name +def cache_model(func): + def wrapper(*args, **kwargs): + if ENABLE_CPU_CACHE: + model_name = func.__name__ + str(args) + str(kwargs) + if model_name not in cached_models: + cached_models[model_name] = func(*args, **kwargs) + return cached_models[model_name] + else: + return func(*args, **kwargs) + return wrapper + +def copied_cache_model(func): + def wrapper(*args, **kwargs): + if ENABLE_CPU_CACHE: + model_name = func.__name__ + str(args) + str(kwargs) + if model_name not in cached_models: + cached_models[model_name] = func(*args, **kwargs) + return deepcopy(cached_models[model_name]) + else: + return func(*args, **kwargs) + return wrapper + +def model_from_ckpt_or_pretrained(ckpt_or_pretrained, model_cls, original_config_file='ckpt/v1-inference.yaml', torch_dtype=torch.float16, **kwargs): + if ckpt_or_pretrained.endswith(".safetensors"): + pipe = model_cls.from_single_file(ckpt_or_pretrained, original_config_file=original_config_file, torch_dtype=torch_dtype, **kwargs) + else: + pipe = model_cls.from_pretrained(ckpt_or_pretrained, torch_dtype=torch_dtype, **kwargs) + return pipe + +@copied_cache_model +def load_base_model_components(base_model=DEFAULT_BASE_MODEL, torch_dtype=torch.float16): + model_kwargs = dict( + torch_dtype=torch_dtype, + requires_safety_checker=False, + safety_checker=None, + ) + pipe: StableDiffusionPipeline = model_from_ckpt_or_pretrained( + base_model, + StableDiffusionPipeline, + **model_kwargs + ) + pipe.to("cpu") + return pipe.components + +@cache_model +def load_controlnet(controlnet_path, torch_dtype=torch.float16): + controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch_dtype) + return controlnet + +@cache_model +def load_image_encoder(): + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + "h94/IP-Adapter", + subfolder="models/image_encoder", + torch_dtype=torch.float16, + ) + return image_encoder + +def load_common_sd15_pipe(base_model=DEFAULT_BASE_MODEL, device="auto", controlnet=None, ip_adapter=False, plus_model=True, torch_dtype=torch.float16, model_cpu_offload_seq=None, enable_sequential_cpu_offload=False, vae_slicing=False, pipeline_class=None, **kwargs): + model_kwargs = dict( + torch_dtype=torch_dtype, + device_map=device, + requires_safety_checker=False, + safety_checker=None, + ) + components = load_base_model_components(base_model=base_model, torch_dtype=torch_dtype) + model_kwargs.update(components) + model_kwargs.update(kwargs) + + if controlnet is not None: + if isinstance(controlnet, list): + controlnet = [load_controlnet(controlnet_path, torch_dtype=torch_dtype) for controlnet_path in controlnet] + else: + controlnet = load_controlnet(controlnet, torch_dtype=torch_dtype) + model_kwargs.update(controlnet=controlnet) + + if pipeline_class is None: + if controlnet is not None: + pipeline_class = StableDiffusionControlNetPipeline + else: + pipeline_class = StableDiffusionPipeline + + pipe: StableDiffusionPipeline = model_from_ckpt_or_pretrained( + base_model, + pipeline_class, + **model_kwargs + ) + + if ip_adapter: + image_encoder = load_image_encoder() + pipe.image_encoder = image_encoder + if plus_model: + pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.safetensors") + else: + pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.safetensors") + pipe.set_ip_adapter_scale(1.0) + else: + pipe.unload_ip_adapter() + + pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config) + + if model_cpu_offload_seq is None: + if isinstance(pipe, StableDiffusionControlNetPipeline): + pipe.model_cpu_offload_seq = "text_encoder->controlnet->unet->vae" + elif isinstance(pipe, StableDiffusionControlNetImg2ImgPipeline): + pipe.model_cpu_offload_seq = "text_encoder->controlnet->vae->unet->vae" + else: + pipe.model_cpu_offload_seq = model_cpu_offload_seq + + if enable_sequential_cpu_offload: + pipe.enable_sequential_cpu_offload() + else: + pipe = pipe.to("cuda") + pass + # pipe.enable_model_cpu_offload() + if vae_slicing: + pipe.enable_vae_slicing() + + import gc + gc.collect() + return pipe + diff --git a/models/ISOMER/scripts/upsampler.py b/models/ISOMER/scripts/upsampler.py new file mode 100644 index 0000000000000000000000000000000000000000..befa44aad183bb51895057112d5a0ed8d13b7828 --- /dev/null +++ b/models/ISOMER/scripts/upsampler.py @@ -0,0 +1,260 @@ +import cv2 +import math +import numpy as np +import os +import torch +from torch.nn import functional as F +from scripts.load_onnx import load_onnx_caller +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + + +class RealESRGANer(): + """A helper class for upsampling images with RealESRGAN. + + Args: + scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4. + model_path (str): The path to the pretrained model. It can be urls (will first download it automatically). + model (nn.Module): The defined network. Default: None. + tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop + input images into tiles, and then process each of them. Finally, they will be merged into one image. + 0 denotes for do not use tile. Default: 0. + tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10. + pre_pad (int): Pad the input images to avoid border artifacts. Default: 10. + half (float): Whether to use half precision during inference. Default: False. + """ + + def __init__(self, + scale, + onnx_path, + tile=0, + tile_pad=10, + pre_pad=10, + half=False, + device=None, + gpu_id=None): + self.scale = scale + self.tile_size = tile + self.tile_pad = tile_pad + self.pre_pad = pre_pad + self.mod_scale = None + self.half = half + + print('about to initialize model') + # initialize model + if gpu_id: + self.device = torch.device( + f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device + else: + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device + print('self.device set') + print(f'about to self.model = load_onnx_caller({onnx_path}, single_output=True)') + self.model = load_onnx_caller(onnx_path, single_output=True) + print('self.model loaded') + + print('about to warm up') + # warm up + sample_input = torch.randn(1,3,512,512).cuda().float() + print(f'sample_input.shape = {sample_input.shape}') + self.model(sample_input) + print('finished warming up') + + def pre_process(self, img): + """Pre-process, such as pre-pad and mod pad, so that the images can be divisible + """ + img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float() + self.img = img.unsqueeze(0).to(self.device) + if self.half: + self.img = self.img.half() + + # pre_pad + if self.pre_pad != 0: + self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect') + # mod pad for divisible borders + if self.scale == 2: + self.mod_scale = 2 + elif self.scale == 1: + self.mod_scale = 4 + if self.mod_scale is not None: + self.mod_pad_h, self.mod_pad_w = 0, 0 + _, _, h, w = self.img.size() + if (h % self.mod_scale != 0): + self.mod_pad_h = (self.mod_scale - h % self.mod_scale) + if (w % self.mod_scale != 0): + self.mod_pad_w = (self.mod_scale - w % self.mod_scale) + self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect') + + def process(self): + # model inference + self.output = self.model(self.img) + + def tile_process(self): + """It will first crop input images to tiles, and then process each tile. + Finally, all the processed tiles are merged into one images. + + Modified from: https://github.com/ata4/esrgan-launcher + """ + batch, channel, height, width = self.img.shape + output_height = height * self.scale + output_width = width * self.scale + output_shape = (batch, channel, output_height, output_width) + + # start with black image + self.output = self.img.new_zeros(output_shape) + tiles_x = math.ceil(width / self.tile_size) + tiles_y = math.ceil(height / self.tile_size) + + # loop over all tiles + for y in range(tiles_y): + for x in range(tiles_x): + # extract tile from input image + ofs_x = x * self.tile_size + ofs_y = y * self.tile_size + # input tile area on total image + input_start_x = ofs_x + input_end_x = min(ofs_x + self.tile_size, width) + input_start_y = ofs_y + input_end_y = min(ofs_y + self.tile_size, height) + + # input tile area on total image with padding + input_start_x_pad = max(input_start_x - self.tile_pad, 0) + input_end_x_pad = min(input_end_x + self.tile_pad, width) + input_start_y_pad = max(input_start_y - self.tile_pad, 0) + input_end_y_pad = min(input_end_y + self.tile_pad, height) + + # input tile dimensions + input_tile_width = input_end_x - input_start_x + input_tile_height = input_end_y - input_start_y + tile_idx = y * tiles_x + x + 1 + input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad] + + # upscale tile + try: + with torch.no_grad(): + output_tile = self.model(input_tile) + except RuntimeError as error: + print('Error', error) + print(f'\tTile {tile_idx}/{tiles_x * tiles_y}') + + # output tile area on total image + output_start_x = input_start_x * self.scale + output_end_x = input_end_x * self.scale + output_start_y = input_start_y * self.scale + output_end_y = input_end_y * self.scale + + # output tile area without padding + output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale + output_end_x_tile = output_start_x_tile + input_tile_width * self.scale + output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale + output_end_y_tile = output_start_y_tile + input_tile_height * self.scale + + # put tile into output image + self.output[:, :, output_start_y:output_end_y, + output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile, + output_start_x_tile:output_end_x_tile] + + def post_process(self): + # remove extra pad + if self.mod_scale is not None: + _, _, h, w = self.output.size() + self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale] + # remove prepad + if self.pre_pad != 0: + _, _, h, w = self.output.size() + self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale] + return self.output + + @torch.no_grad() + def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'): + print('inside enhance') + h_input, w_input = img.shape[0:2] + # img: numpy + img = img.astype(np.float32) + if np.max(img) > 256: # 16-bit image + max_range = 65535 + print('\tInput is a 16-bit image') + else: + max_range = 255 + img = img / max_range + if len(img.shape) == 2: # gray image + img_mode = 'L' + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) + elif img.shape[2] == 4: # RGBA image with alpha channel + img_mode = 'RGBA' + alpha = img[:, :, 3] + img = img[:, :, 0:3] + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + if alpha_upsampler == 'realesrgan': + alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB) + else: + img_mode = 'RGB' + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + # ------------------- process image (without the alpha channel) ------------------- # + print('about to process image (without the alpha channel)') + self.pre_process(img) + if self.tile_size > 0: + print(f'self.tile_size is {self.tile_size}, thus about to self.tile_process()') + self.tile_process() + print('finished self.tile_process()') + else: + print('about to self.process()') + self.process() + print('finished self.process()') + + print('about to self.post_process()') + output_img = self.post_process() + print('finished self.post_process()') + output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy() + output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0)) + if img_mode == 'L': + output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY) + print('finished process image (without the alpha channel)') + + # ------------------- process the alpha channel if necessary ------------------- # + if img_mode == 'RGBA': + print("img_mode == 'RGBA' thus about to process alpha channel") + if alpha_upsampler == 'realesrgan': + print(f"alpha_upsampler == 'realesrgan', about to self.pre_process({alpha})") + self.pre_process(alpha) + print('finished self.pre_process') + if self.tile_size > 0: + print(f'self.tile_size is {self.tile_size}, thus about to self.tile_process()') + self.tile_process() + print('finished self.tile_process()') + else: + print('about to self.process()') + self.process() + print('finished self.process()') + print('about to self.post_process()') + output_alpha = self.post_process() + print('finished self.post_process()') + output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy() + output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0)) + output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY) + else: # use the cv2 resize for alpha channel + print('about to use the cv2 resize for alpha channel') + h, w = alpha.shape[0:2] + output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR) + + print('about to merge the alpha channel') + # merge the alpha channel + output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA) + output_img[:, :, 3] = output_alpha + print('finished process alpha channel') + + print('about to resize and return') + # ------------------------------ return ------------------------------ # + if max_range == 65535: # 16-bit image + output = (output_img * 65535.0).round().astype(np.uint16) + else: + output = (output_img * 255.0).round().astype(np.uint8) + + if outscale is not None and outscale != float(self.scale): + output = cv2.resize( + output, ( + int(w_input * outscale), + int(h_input * outscale), + ), interpolation=cv2.INTER_LANCZOS4) + + return output, img_mode + diff --git a/models/ISOMER/scripts/utils.py b/models/ISOMER/scripts/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..58ab46810d99dadeed941a578920433f015f6b28 --- /dev/null +++ b/models/ISOMER/scripts/utils.py @@ -0,0 +1,611 @@ +import torch +import numpy as np +from PIL import Image +import pymeshlab +import pymeshlab as ml +from pymeshlab import PercentageValue +from pytorch3d.renderer import TexturesVertex +from pytorch3d.structures import Meshes +import torch +import torch.nn.functional as F +from typing import List, Tuple +from PIL import Image +import trimesh + +EPSILON = 1e-8 + +def load_mesh_with_trimesh(file_name, file_type=None): + import trimesh + mesh: trimesh.Trimesh = trimesh.load(file_name, file_type=file_type) + if isinstance(mesh, trimesh.Scene): + assert len(mesh.geometry) > 0 + # save to obj first and load again to avoid offset issue + from io import BytesIO + with BytesIO() as f: + mesh.export(f, file_type="obj") + f.seek(0) + mesh = trimesh.load(f, file_type="obj") + if isinstance(mesh, trimesh.Scene): + # we lose texture information here + mesh = trimesh.util.concatenate( + tuple(trimesh.Trimesh(vertices=g.vertices, faces=g.faces) + for g in mesh.geometry.values())) + assert isinstance(mesh, trimesh.Trimesh) + + vertices = torch.from_numpy(mesh.vertices).T + faces = torch.from_numpy(mesh.faces).T + colors = None + if mesh.visual is not None: + if hasattr(mesh.visual, 'vertex_colors'): + colors = torch.from_numpy(mesh.visual.vertex_colors)[..., :3].T / 255. + if colors is None: + colors = torch.ones_like(vertices) * 0.5 + return vertices, faces, colors + +def meshlab_mesh_to_py3dmesh(mesh: pymeshlab.Mesh) -> Meshes: + verts = torch.from_numpy(mesh.vertex_matrix()).float() + faces = torch.from_numpy(mesh.face_matrix()).long() + colors = torch.from_numpy(mesh.vertex_color_matrix()[..., :3]).float() + textures = TexturesVertex(verts_features=[colors]) + return Meshes(verts=[verts], faces=[faces], textures=textures) + + +def py3dmesh_to_meshlab_mesh(meshes: Meshes) -> pymeshlab.Mesh: + colors_in = F.pad(meshes.textures.verts_features_packed().cpu().float(), [0,1], value=1).numpy().astype(np.float64) + m1 = pymeshlab.Mesh( + vertex_matrix=meshes.verts_packed().cpu().float().numpy().astype(np.float64), + face_matrix=meshes.faces_packed().cpu().long().numpy().astype(np.int32), + v_normals_matrix=meshes.verts_normals_packed().cpu().float().numpy().astype(np.float64), + v_color_matrix=colors_in) + return m1 + + +def to_pyml_mesh(vertices,faces): + m1 = pymeshlab.Mesh( + vertex_matrix=vertices.cpu().float().numpy().astype(np.float64), + face_matrix=faces.cpu().long().numpy().astype(np.int32), + ) + return m1 + + +def to_py3d_mesh(vertices, faces, normals=None): + from pytorch3d.structures import Meshes + from pytorch3d.renderer.mesh.textures import TexturesVertex + mesh = Meshes(verts=[vertices], faces=[faces], textures=None) + if normals is None: + normals = mesh.verts_normals_packed() + # set normals as vertext colors + mesh.textures = TexturesVertex(verts_features=[normals / 2 + 0.5]) + return mesh + + +def from_py3d_mesh(mesh): + return mesh.verts_list()[0], mesh.faces_list()[0], mesh.textures.verts_features_packed() + +def rotate_normalmap_by_angle(normal_map: np.ndarray, angle: float): + """ + rotate along y-axis + normal_map: np.array, shape=(H, W, 3) in [-1, 1] + angle: float, in degree + """ + angle = angle / 180 * np.pi + R = np.array([[np.cos(angle), 0, np.sin(angle)], [0, 1, 0], [-np.sin(angle), 0, np.cos(angle)]]) + return np.dot(normal_map.reshape(-1, 3), R.T).reshape(normal_map.shape) + +# from view coord to front view world coord +def rotate_normals(normal_pils, return_types='np', rotate_direction=1) -> np.ndarray: # [0, 255] + n_views = len(normal_pils) + ret = [] + for idx, rgba_normal in enumerate(normal_pils): + # rotate normal + normal_np = np.array(rgba_normal)[:, :, :3] / 255 # in [-1, 1] + alpha_np = np.array(rgba_normal)[:, :, 3] / 255 # in [0, 1] + normal_np = normal_np * 2 - 1 + normal_np = rotate_normalmap_by_angle(normal_np, rotate_direction * idx * (360 / n_views)) + normal_np = (normal_np + 1) / 2 + normal_np = normal_np * alpha_np[..., None] # make bg black + rgba_normal_np = np.concatenate([normal_np * 255, alpha_np[:, :, None] * 255] , axis=-1) + if return_types == 'np': + ret.append(rgba_normal_np) + elif return_types == 'pil': + ret.append(Image.fromarray(rgba_normal_np.astype(np.uint8))) + else: + raise ValueError(f"return_types should be 'np' or 'pil', but got {return_types}") + return ret + + +def rotate_normalmap_by_angle_torch(normal_map, angle): + """ + rotate along y-axis + normal_map: torch.Tensor, shape=(H, W, 3) in [-1, 1], device='cuda' + angle: float, in degree + """ + angle = torch.tensor(angle / 180 * np.pi).to(normal_map) + R = torch.tensor([[torch.cos(angle), 0, torch.sin(angle)], + [0, 1, 0], + [-torch.sin(angle), 0, torch.cos(angle)]]).to(normal_map) + return torch.matmul(normal_map.view(-1, 3), R.T).view(normal_map.shape) + +def do_rotate(rgba_normal, angle): + rgba_normal = torch.from_numpy(rgba_normal).float().cuda() / 255 + rotated_normal_tensor = rotate_normalmap_by_angle_torch(rgba_normal[..., :3] * 2 - 1, angle) + rotated_normal_tensor = (rotated_normal_tensor + 1) / 2 + rotated_normal_tensor = rotated_normal_tensor * rgba_normal[:, :, [3]] # make bg black + rgba_normal_np = torch.cat([rotated_normal_tensor * 255, rgba_normal[:, :, [3]] * 255], dim=-1).cpu().numpy() + return rgba_normal_np + +def rotate_normals_torch(normal_pils, return_types='np', rotate_direction=1): + n_views = len(normal_pils) + ret = [] + for idx, rgba_normal in enumerate(normal_pils): + # rotate normal + angle = rotate_direction * idx * (360 / n_views) + rgba_normal_np = do_rotate(np.array(rgba_normal), angle) + if return_types == 'np': + ret.append(rgba_normal_np) + elif return_types == 'pil': + ret.append(Image.fromarray(rgba_normal_np.astype(np.uint8))) + else: + raise ValueError(f"return_types should be 'np' or 'pil', but got {return_types}") + return ret + +def change_bkgd(img_pils, new_bkgd=(0., 0., 0.)): + ret = [] + new_bkgd = np.array(new_bkgd).reshape(1, 1, 3) + for rgba_img in img_pils: + img_np = np.array(rgba_img)[:, :, :3] / 255 + alpha_np = np.array(rgba_img)[:, :, 3] / 255 + ori_bkgd = img_np[:1, :1] + # color = ori_color * alpha + bkgd * (1-alpha) + # ori_color = (color - bkgd * (1-alpha)) / alpha + alpha_np_clamp = np.clip(alpha_np, 1e-6, 1) # avoid divide by zero + ori_img_np = (img_np - ori_bkgd * (1 - alpha_np[..., None])) / alpha_np_clamp[..., None] + img_np = np.where(alpha_np[..., None] > 0.05, ori_img_np * alpha_np[..., None] + new_bkgd * (1 - alpha_np[..., None]), new_bkgd) + rgba_img_np = np.concatenate([img_np * 255, alpha_np[..., None] * 255], axis=-1) + ret.append(Image.fromarray(rgba_img_np.astype(np.uint8))) + return ret + +def change_bkgd_to_normal(normal_pils) -> List[Image.Image]: + n_views = len(normal_pils) + ret = [] + for idx, rgba_normal in enumerate(normal_pils): + # calcuate background normal + target_bkgd = rotate_normalmap_by_angle(np.array([[[0., 0., 1.]]]), idx * (360 / n_views)) + normal_np = np.array(rgba_normal)[:, :, :3] / 255 # in [-1, 1] + alpha_np = np.array(rgba_normal)[:, :, 3] / 255 # in [0, 1] + normal_np = normal_np * 2 - 1 + old_bkgd = normal_np[:1,:1] + normal_np[alpha_np > 0.05] = (normal_np[alpha_np > 0.05] - old_bkgd * (1 - alpha_np[alpha_np > 0.05][..., None])) / alpha_np[alpha_np > 0.05][..., None] + normal_np = normal_np * alpha_np[..., None] + target_bkgd * (1 - alpha_np[..., None]) + normal_np = (normal_np + 1) / 2 + rgba_normal_np = np.concatenate([normal_np * 255, alpha_np[..., None] * 255] , axis=-1) + ret.append(Image.fromarray(rgba_normal_np.astype(np.uint8))) + return ret + + +def fix_vert_color_glb(mesh_path): + from pygltflib import GLTF2, Material, PbrMetallicRoughness + obj1 = GLTF2().load(mesh_path) + obj1.meshes[0].primitives[0].material = 0 + obj1.materials.append(Material( + pbrMetallicRoughness = PbrMetallicRoughness( + baseColorFactor = [1.0, 1.0, 1.0, 1.0], + metallicFactor = 0., + roughnessFactor = 1.0, + ), + emissiveFactor = [0.0, 0.0, 0.0], + doubleSided = True, + )) + obj1.save(mesh_path) + + +def srgb_to_linear(c_srgb): + c_linear = np.where(c_srgb <= 0.04045, c_srgb / 12.92, ((c_srgb + 0.055) / 1.055) ** 2.4) + return c_linear.clip(0, 1.) + + +def save_py3dmesh_with_trimesh_fast(meshes: Meshes, save_glb_path, apply_sRGB_to_LinearRGB=True): + # convert from pytorch3d meshes to trimesh mesh + vertices = meshes.verts_packed().cpu().float().numpy() + triangles = meshes.faces_packed().cpu().long().numpy() + np_color = meshes.textures.verts_features_packed().cpu().float().numpy() + if save_glb_path.endswith(".glb"): + # rotate 180 along +Y + vertices[:, [0, 2]] = -vertices[:, [0, 2]] + + if apply_sRGB_to_LinearRGB: + np_color = srgb_to_linear(np_color) + assert vertices.shape[0] == np_color.shape[0] + assert np_color.shape[1] == 3 + assert 0 <= np_color.min() and np_color.max() <= 1, f"min={np_color.min()}, max={np_color.max()}" + mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, vertex_colors=np_color) + mesh.remove_unreferenced_vertices() + # save mesh + mesh.export(save_glb_path) + if save_glb_path.endswith(".glb"): + fix_vert_color_glb(save_glb_path) + + +def save_glb_and_video(save_mesh_prefix: str, meshes: Meshes, with_timestamp=True, dist=3.5, azim_offset=180, resolution=512, fov_in_degrees=1 / 1.15, cam_type="ortho", view_padding=60, export_video=True) -> Tuple[str, str]: + import time + if '.' in save_mesh_prefix: + save_mesh_prefix = ".".join(save_mesh_prefix.split('.')[:-1]) + if with_timestamp: + save_mesh_prefix = save_mesh_prefix + f"_{int(time.time())}" + ret_mesh = save_mesh_prefix + ".glb" + # optimizied version + save_py3dmesh_with_trimesh_fast(meshes, ret_mesh) + return ret_mesh, None + + +def simple_clean_mesh(pyml_mesh: ml.Mesh, apply_smooth=True, stepsmoothnum=1, apply_sub_divide=False, sub_divide_threshold=0.25): + ms = ml.MeshSet() + ms.add_mesh(pyml_mesh, "cube_mesh") + + if apply_smooth: + ms.apply_filter("apply_coord_laplacian_smoothing", stepsmoothnum=stepsmoothnum, cotangentweight=False) + if apply_sub_divide: # 5s, slow + ms.apply_filter("meshing_repair_non_manifold_vertices") + ms.apply_filter("meshing_repair_non_manifold_edges", method='Remove Faces') + ms.apply_filter("meshing_surface_subdivision_loop", iterations=2, threshold=PercentageValue(sub_divide_threshold)) + return meshlab_mesh_to_py3dmesh(ms.current_mesh()) + + +def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + + +def init_target(img_pils, new_bkgd=(0., 0., 0.), device="cuda"): + new_bkgd = torch.tensor(new_bkgd, dtype=torch.float32).view(1, 1, 3).to(device) + + imgs = torch.stack([torch.from_numpy(np.array(img, dtype=np.float32)) for img in img_pils]).to(device) / 255 + img_nps = imgs[..., :3] + alpha_nps = imgs[..., 3] + ori_bkgds = img_nps[:, :1, :1] + + alpha_nps_clamp = torch.clamp(alpha_nps, 1e-6, 1) + ori_img_nps = (img_nps - ori_bkgds * (1 - alpha_nps.unsqueeze(-1))) / alpha_nps_clamp.unsqueeze(-1) + ori_img_nps = torch.clamp(ori_img_nps, 0, 1) + img_nps = torch.where(alpha_nps.unsqueeze(-1) > 0.05, ori_img_nps * alpha_nps.unsqueeze(-1) + new_bkgd * (1 - alpha_nps.unsqueeze(-1)), new_bkgd) + + rgba_img_np = torch.cat([img_nps, alpha_nps.unsqueeze(-1)], dim=-1) + return rgba_img_np + + + +def rotation_matrix_axis_angle(axis, angle, device='cuda'): + """ + Return the rotation matrix associated with counterclockwise rotation about + the given axis by angle degrees, using PyTorch. + """ + if type(axis) != torch.tensor: + axis = torch.tensor(axis, device=device) + axis = axis.float().to(device) + if type(angle) != torch.tensor: + angle = torch.tensor(angle, device=device) + angle = angle.float().to(device) + + theta = angle * torch.pi / 180.0 + axis = torch.tensor(axis, dtype=torch.float32) + if torch.dot(axis, axis) > 0: + denom = torch.sqrt(torch.dot(axis, axis)) + demon = torch.where(denom == 0, torch.tensor(EPSILON).to(denom.device), denom) + axis = axis / torch.sqrt(demon) + a = torch.cos(theta / 2.0) + b, c, d = -axis[0] * torch.sin(theta / 2.0), -axis[1] * torch.sin(theta / 2.0), -axis[2] * torch.sin(theta / 2.0) + + aa, bb, cc, dd = a*a, b*b, c*c, d*d + bc, ad, ac, ab, bd, cd = b*c, a*d, a*c, a*b, b*d, c*d + return torch.stack([ + torch.stack([aa+bb-cc-dd, 2*(bc+ad), 2*(bd-ac)]), + torch.stack([2*(bc-ad), aa+cc-bb-dd, 2*(cd+ab)]), + torch.stack([2*(bd+ac), 2*(cd-ab), aa+dd-bb-cc]) + ]) + else: + return torch.eye(3) + + + +def normal_rotation_img2img_angle_axis(image, angle, axis=None, device='cuda'): + """ + Rotate an image by a given angle around a given axis using PyTorch. + + Args: + image: Input Image to rotate. + angle: Rotation angle in degrees. + axis: Rotation axis as a array of 3 floats. + + Returns: + Image: Rotated Image. + """ + if axis is None: + axis = [0,1,0] + axis = torch.tensor(axis, device=device) + + + if type(image) == Image.Image: + image_array = torch.tensor(np.array(image, dtype='float32')) + else: + image_array = image + image_array = image_array.to(device) + + if type(angle) != torch.Tensor: + angle = torch.tensor(angle) + angle = angle.to(device) + + if image_array.shape[2] == 4: + rgb_array, alpha_array = image_array[:, :, :3], image_array[:, :, 3] + else: + rgb_array = image_array + alpha_array = None + + rgb_array = rgb_array / 255.0 - 0.5 + + rgb_array = rgb_array.permute(2, 0, 1) + + rotated_tensor = apply_rotation_angle_axis(rgb_array.unsqueeze(0), axis, torch.tensor([angle], device=rgb_array.device)) + + + rotated_array = rotated_tensor.squeeze().permute(1, 2, 0) + + rotated_array = (rotated_array/2 + 0.5) * 255 + + if alpha_array is not None: + rotated_array = torch.cat((rotated_array, alpha_array.unsqueeze(2)), dim=2) + + rotated_array_uint8 = np.array(rotated_array.detach().cpu()).astype('uint8') + + rotated_normal = Image.fromarray(rotated_array_uint8) + + return rotated_normal + +def normal_rotation_img2img_c2w(image, c2w, device='cuda'): + + if type(image) != torch.Tensor: + image_array = torch.tensor(np.array(image, dtype='float32')) + else: + image_array = image + + + image_array = image_array.to(device) + + if image_array.shape[2] == 4: + rgb_array, alpha_array = image_array[:, :, :3], image_array[:, :, 3] + else: + rgb_array = image_array + alpha_array = None + + rgb_array = rgb_array / 255.0 - 0.5 + + rotation_matrix = c2w + + rotated_tensor = transform_normals_R(rgb_array, rotation_matrix) + + rotated_array = rotated_tensor.squeeze().permute(1, 2, 0) + rotated_array = (rotated_array/2 + 0.5) * 255 + + if alpha_array is not None: + rotated_array = torch.cat((rotated_array, alpha_array.unsqueeze(2)), dim=2) + + rotated_array_uint8 = np.array(rotated_array.detach().cpu()).astype('uint8') + + rotated_normal = Image.fromarray(rotated_array_uint8) + + return rotated_normal + +def normal_rotation_img2img_azi_ele(image, azi, ele, device='cuda'): + """ + Rotate an image by a given angle around a given axis using PyTorch. + + Args: + image: Input Image to rotate. + + Returns: + Image: Rotated Image. + """ + + if type(image) == Image.Image: + image_array = torch.tensor(np.array(image, dtype='float32')) + else: + image_array = image + image_array = image_array.to(device) + + if type(azi) != torch.Tensor: + azi = torch.tensor(azi) + azi = azi.to(device) + + if type(ele) != torch.Tensor: + ele = torch.tensor(ele) + ele = ele.to(device) + + if image_array.shape[2] == 4: + rgb_array, alpha_array = image_array[:, :, :3], image_array[:, :, 3] + else: + rgb_array = image_array + alpha_array = None + + rgb_array = rgb_array / 255.0 - 0.5 + + rotation_matrix = get_rotation_matrix_azi_ele(azi, ele) + rotated_tensor = transform_normals_R(rgb_array, rotation_matrix) + + rotated_array = rotated_tensor.squeeze().permute(1, 2, 0) + + rotated_array = (rotated_array/2 + 0.5) * 255 + + if alpha_array is not None: + rotated_array = torch.cat((rotated_array, alpha_array.unsqueeze(2)), dim=2) + + rotated_array_uint8 = np.array(rotated_array.detach().cpu()).astype('uint8') + + rotated_normal = Image.fromarray(rotated_array_uint8) + + return rotated_normal + + +def rotate_normal_R(image, rotation_matrix, save_addr="", device="cuda"): + """ + Rotate a normal map by a given Rotation matrix using PyTorch. + + Args: + image: Input Image to rotate. + + Returns: + Image: Rotated Image. + """ + + if type(image) != torch.tensor: + image_array = torch.tensor(np.array(image, dtype='float32')) + else: + image_array = image + image_array = image_array.to(device) + + if image_array.shape[2] == 4: + rgb_array, alpha_array = image_array[:, :, :3], image_array[:, :, 3] + else: + rgb_array = image_array + alpha_array = None + + rgb_array = rgb_array / 255.0 - 0.5 + + rotated_tensor = transform_normals_R(rgb_array, rotation_matrix.to(device)) + + rotated_array = rotated_tensor.squeeze().permute(1, 2, 0) + + rotated_array = (rotated_array/2 + 0.5) * 255 + + if alpha_array is not None: + rotated_array = torch.cat((rotated_array, alpha_array.unsqueeze(2)), dim=2) + + rotated_array_uint8 = np.array(rotated_array.detach().cpu()).astype('uint8') + + rotated_normal = Image.fromarray(rotated_array_uint8) + + if save_addr: + rotated_normal.save(save_addr) + return rotated_normal + + + +def transform_normals_R(local_normals, rotation_matrix): + assert local_normals.shape[2] ==3 ,f'local_normals.shape[2]: {local_normals.shape[2]}. only support rgb image' + + h, w = local_normals.shape[:2] + local_normals_flat = local_normals.view(-1, 3).permute(1, 0) + + images_flat = local_normals_flat.unsqueeze(0) + rotation_matrices = rotation_matrix.unsqueeze(0) + rotated_images_flat = torch.bmm(rotation_matrices, images_flat) + + rotated_images = rotated_images_flat.view(1, 3, h, w) + + norms = torch.norm(rotated_images, p=2, dim=1, keepdim=True) + norms = torch.where(norms == 0, torch.tensor(EPSILON).to(norms.device), norms) + normalized_images = rotated_images / norms + + return normalized_images + + +def manage_elevation_azimuth(ele_list, azi_list): + """deal with cases when elevation > 90""" + + for i in range(len(ele_list)): + elevation = ele_list[i] % 360 + azimuth = azi_list[i] % 360 + if elevation > 90 and elevation<=270: + # when elevation is too big,camera gets to the other side + # print(f'!!! elevation({elevation}) > 90 and <=270, set to 180-elevation, and add 180 to azimuth') + elevation = 180 - elevation + azimuth = azimuth + 180 + # print(f'new elevation: {elevation}, new azimuth: {azimuth}') + + elif elevation>270: + # print(f'!!! elevation({elevation}) > 270, set to elevation-360, and use original azimuth') + elevation = elevation - 360 + azimuth = azimuth + # print(f'new elevation: {elevation}, new azimuth: {azimuth}') + + ele_list[i] = elevation + azi_list[i] = azimuth + + return ele_list, azi_list + +def get_rotation_matrix_azi_ele(azimuth, elevation): + + ele = elevation/180 * torch.pi + azi = azimuth/180 * torch.pi + + Rz = torch.tensor([ + [torch.cos(azi), 0, -torch.sin(azi)], + [0, 1, 0], + [torch.sin(azi), 0, torch.cos(azi)], + ]).to(azimuth.device) + + Re = torch.tensor([ + [1, 0, 0], + [0, torch.cos(ele), torch.sin(ele)], + [0, -torch.sin(ele), torch.cos(ele)], + ]).to(elevation.device) + + return torch.matmul(Rz,Re).to(azimuth.device) + + +def rotate_vector(vector, axis, angle, device='cuda'): + rot_matrix = rotation_matrix_axis_angle(axis, angle) + return torch.matmul(vector.to(device).float(), rot_matrix.to(device).float()) + +def apply_rotation_angle_axis(image, axis, angle, device='cuda'): + """Apply rotation to a batch of images with shape [batch_size, 3(rgb), h, w] using PyTorch. + + Args: + image (torch.Tensor): Input RGB image tensor of shape [batch_size, 3, h, w]. each pixel's rgb channels refer to direction of normal (can be negative) + axis (torch.Tensor): Rotation axis of shape [3]. + angle (torch.Tensor): Rotation angles in degrees, of shape [batch_size]. + Returns: + torch.Tensor: Rotated image tensor of shape [batch_size, 3, h, w]. values between [-1., 1.] + + """ + + if not isinstance(image, torch.Tensor): + image_tensor = torch.tensor(image).to(device) + else: + image_tensor = image.to(device) + + if not isinstance(axis, torch.Tensor): + axis = torch.tensor(axis) + axis = axis.to(device) + + if not isinstance(angle, torch.Tensor): + angle = torch.tensor(angle) + angle = angle.to(device) + + batch_size, channels, h, w = image_tensor.shape + rot_matrix = rotation_matrix_axis_angle(axis, angle) + + rotation_matrices = rot_matrix.permute(2, 0, 1) + + batch_size, c, h, w = image_tensor.shape + images_flat = image_tensor.view(batch_size, c, h * w) + + rotated_images_flat = torch.bmm(rotation_matrices, images_flat) + + rotated_images = rotated_images_flat.view(batch_size, c, h, w) + + norms = torch.norm(rotated_images, p=2, dim=1, keepdim=True) + + norms = torch.where(norms == 0, torch.tensor(EPSILON).to(norms.device), norms) + + normalized_images = rotated_images / norms + + return normalized_images diff --git a/models/lrm/config/PRM_inference.yaml b/models/lrm/config/PRM_inference.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b585790199534f499a19b9d40a56a18fc0c8c452 --- /dev/null +++ b/models/lrm/config/PRM_inference.yaml @@ -0,0 +1,22 @@ +model_config: + target: models.lrm.models.lrm_mesh.PRM + params: + encoder_feat_dim: 768 + encoder_freeze: false + encoder_model_name: facebook/dino-vitb16 + transformer_dim: 1024 + transformer_layers: 16 + transformer_heads: 16 + triplane_low_res: 32 + triplane_high_res: 64 + triplane_dim: 80 + rendering_samples_per_ray: 128 + grid_res: 128 + grid_scale: 2.1 + + +infer_config: + unet_path: ckpts/diffusion_pytorch_model.bin + model_path: ckpts/final_ckpt.ckpt + texture_resolution: 2048 + render_resolution: 512 \ No newline at end of file diff --git a/models/lrm/models/__init__.py b/models/lrm/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/lrm/models/decoder/__init__.py b/models/lrm/models/decoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/lrm/models/decoder/transformer.py b/models/lrm/models/decoder/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..d8e628c0bf589ee827908c894b93cc107f1c58b9 --- /dev/null +++ b/models/lrm/models/decoder/transformer.py @@ -0,0 +1,123 @@ +# Copyright (c) 2023, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn + + +class BasicTransformerBlock(nn.Module): + """ + Transformer block that takes in a cross-attention condition and another modulation vector applied to sub-blocks. + """ + # use attention from torch.nn.MultiHeadAttention + # Block contains a cross-attention layer, a self-attention layer, and a MLP + def __init__( + self, + inner_dim: int, + cond_dim: int, + num_heads: int, + eps: float, + attn_drop: float = 0., + attn_bias: bool = False, + mlp_ratio: float = 4., + mlp_drop: float = 0., + ): + super().__init__() + + self.norm1 = nn.LayerNorm(inner_dim) + self.cross_attn = nn.MultiheadAttention( + embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim, + dropout=attn_drop, bias=attn_bias, batch_first=True) + self.norm2 = nn.LayerNorm(inner_dim) + self.self_attn = nn.MultiheadAttention( + embed_dim=inner_dim, num_heads=num_heads, + dropout=attn_drop, bias=attn_bias, batch_first=True) + self.norm3 = nn.LayerNorm(inner_dim) + self.mlp = nn.Sequential( + nn.Linear(inner_dim, int(inner_dim * mlp_ratio)), + nn.GELU(), + nn.Dropout(mlp_drop), + nn.Linear(int(inner_dim * mlp_ratio), inner_dim), + nn.Dropout(mlp_drop), + ) + + def forward(self, x, cond): + # x: [N, L, D] + # cond: [N, L_cond, D_cond] + x = x + self.cross_attn(self.norm1(x), cond, cond)[0] + before_sa = self.norm2(x) + x = x + self.self_attn(before_sa, before_sa, before_sa)[0] + x = x + self.mlp(self.norm3(x)) + return x + + +class TriplaneTransformer(nn.Module): + """ + Transformer with condition that generates a triplane representation. + + Reference: + Timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L486 + """ + def __init__( + self, + inner_dim: int, + image_feat_dim: int, + triplane_low_res: int, + triplane_high_res: int, + triplane_dim: int, + num_layers: int, + num_heads: int, + eps: float = 1e-6, + ): + super().__init__() + + # attributes + self.triplane_low_res = triplane_low_res + self.triplane_high_res = triplane_high_res + self.triplane_dim = triplane_dim + + # modules + # initialize pos_embed with 1/sqrt(dim) * N(0, 1) + self.pos_embed = nn.Parameter(torch.randn(1, 3*triplane_low_res**2, inner_dim) * (1. / inner_dim) ** 0.5) + self.layers = nn.ModuleList([ + BasicTransformerBlock( + inner_dim=inner_dim, cond_dim=image_feat_dim, num_heads=num_heads, eps=eps) + for _ in range(num_layers) + ]) + self.norm = nn.LayerNorm(inner_dim, eps=eps) + self.deconv = nn.ConvTranspose2d(inner_dim, triplane_dim, kernel_size=2, stride=2, padding=0) + + def forward(self, image_feats): + # image_feats: [N, L_cond, D_cond] + + N = image_feats.shape[0] + H = W = self.triplane_low_res + L = 3 * H * W + + x = self.pos_embed.repeat(N, 1, 1) # [N, L, D] + for layer in self.layers: + x = layer(x, image_feats) + x = self.norm(x) + + # separate each plane and apply deconv + x = x.view(N, 3, H, W, -1) + x = torch.einsum('nihwd->indhw', x) # [3, N, D, H, W] + x = x.contiguous().view(3*N, -1, H, W) # [3*N, D, H, W] + x = self.deconv(x) # [3*N, D', H', W'] + x = x.view(3, N, *x.shape[-3:]) # [3, N, D', H', W'] + x = torch.einsum('indhw->nidhw', x) # [N, 3, D', H', W'] + x = x.contiguous() + + return x diff --git a/models/lrm/models/encoder/__init__.py b/models/lrm/models/encoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/lrm/models/encoder/dino.py b/models/lrm/models/encoder/dino.py new file mode 100644 index 0000000000000000000000000000000000000000..684444cab2a13979bcd5688069e9f7729d4ca784 --- /dev/null +++ b/models/lrm/models/encoder/dino.py @@ -0,0 +1,550 @@ +# coding=utf-8 +# Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch ViT model.""" + + +import collections.abc +import math +from typing import Dict, List, Optional, Set, Tuple, Union + +import torch +from torch import nn + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPooling, +) +from transformers import PreTrainedModel, ViTConfig +from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer + + +class ViTEmbeddings(nn.Module): + """ + Construct the CLS token, position and patch embeddings. Optionally, also the mask token. + """ + + def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None: + super().__init__() + + self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None + self.patch_embeddings = ViTPatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size)) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.config = config + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return self.position_embeddings + class_pos_embed = self.position_embeddings[:, 0] + patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] + h0 = height // self.config.patch_size + w0 = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + h0, w0 = h0 + 0.1, w0 + 0.1 + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), + mode="bicubic", + align_corners=False, + ) + assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward( + self, + pixel_values: torch.Tensor, + bool_masked_pos: Optional[torch.BoolTensor] = None, + interpolate_pos_encoding: bool = False, + ) -> torch.Tensor: + batch_size, num_channels, height, width = pixel_values.shape + embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + + if bool_masked_pos is not None: + seq_length = embeddings.shape[1] + mask_tokens = self.mask_token.expand(batch_size, seq_length, -1) + # replace the masked visual tokens by mask_tokens + mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) + embeddings = embeddings * (1.0 - mask) + mask_tokens * mask + + # add the [CLS] token to the embedded patch tokens + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + # add positional encoding to each token + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embeddings + + embeddings = self.dropout(embeddings) + + return embeddings + + +class ViTPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: + batch_size, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + f" Expected {self.num_channels} but got {num_channels}." + ) + if not interpolate_pos_encoding: + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" + f" ({self.image_size[0]}*{self.image_size[1]})." + ) + embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) + return embeddings + + +class ViTSelfAttention(nn.Module): + def __init__(self, config: ViTConfig) -> None: + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size {config.hidden_size,} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +class ViTSelfOutput(nn.Module): + """ + The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the + layernorm applied before each block. + """ + + def __init__(self, config: ViTConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +class ViTAttention(nn.Module): + def __init__(self, config: ViTConfig) -> None: + super().__init__() + self.attention = ViTSelfAttention(config) + self.output = ViTSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads: Set[int]) -> None: + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) + self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_outputs = self.attention(hidden_states, head_mask, output_attentions) + + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class ViTIntermediate(nn.Module): + def __init__(self, config: ViTConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + +class ViTOutput(nn.Module): + def __init__(self, config: ViTConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + hidden_states = hidden_states + input_tensor + + return hidden_states + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +class ViTLayer(nn.Module): + """This corresponds to the Block class in the timm implementation.""" + + def __init__(self, config: ViTConfig) -> None: + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = ViTAttention(config) + self.intermediate = ViTIntermediate(config) + self.output = ViTOutput(config) + self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(config.hidden_size, 4 * config.hidden_size, bias=True) + ) + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + + def forward( + self, + hidden_states: torch.Tensor, + adaln_input: torch.Tensor = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + shift_msa, scale_msa, shift_mlp, scale_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1) + + self_attention_outputs = self.attention( + modulate(self.layernorm_before(hidden_states), shift_msa, scale_msa), # in ViT, layernorm is applied before self-attention + head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + # first residual connection + hidden_states = attention_output + hidden_states + + # in ViT, layernorm is also applied after self-attention + layer_output = modulate(self.layernorm_after(hidden_states), shift_mlp, scale_mlp) + layer_output = self.intermediate(layer_output) + + # second residual connection is done here + layer_output = self.output(layer_output, hidden_states) + + outputs = (layer_output,) + outputs + + return outputs + + +class ViTEncoder(nn.Module): + def __init__(self, config: ViTConfig) -> None: + super().__init__() + self.config = config + self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + adaln_input: torch.Tensor = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + adaln_input, + layer_head_mask, + output_attentions, + ) + else: + layer_outputs = layer_module(hidden_states, adaln_input, layer_head_mask, output_attentions) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class ViTPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = ViTConfig + base_model_prefix = "vit" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _no_split_modules = ["ViTEmbeddings", "ViTLayer"] + + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, ViTEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + + module.cls_token.data = nn.init.trunc_normal_( + module.cls_token.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.cls_token.dtype) + + +class ViTModel(ViTPreTrainedModel): + def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False): + super().__init__(config) + self.config = config + + self.embeddings = ViTEmbeddings(config, use_mask_token=use_mask_token) + self.encoder = ViTEncoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = ViTPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> ViTPatchEmbeddings: + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + adaln_input: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?) + expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype + if pixel_values.dtype != expected_dtype: + pixel_values = pixel_values.to(expected_dtype) + + embedding_output = self.embeddings( + pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding + ) + + encoder_outputs = self.encoder( + embedding_output, + adaln_input=adaln_input, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,) + return head_outputs + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class ViTPooler(nn.Module): + def __init__(self, config: ViTConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output \ No newline at end of file diff --git a/models/lrm/models/encoder/dino_wrapper.py b/models/lrm/models/encoder/dino_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..e84fd51e7dfcfd1a969b763f5a49aeb7f608e6f9 --- /dev/null +++ b/models/lrm/models/encoder/dino_wrapper.py @@ -0,0 +1,80 @@ +# Copyright (c) 2023, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch.nn as nn +from transformers import ViTImageProcessor +from einops import rearrange, repeat +from .dino import ViTModel + + +class DinoWrapper(nn.Module): + """ + Dino v1 wrapper using huggingface transformer implementation. + """ + def __init__(self, model_name: str, freeze: bool = True): + super().__init__() + self.model, self.processor = self._build_dino(model_name) + self.camera_embedder = nn.Sequential( + nn.Linear(16, self.model.config.hidden_size, bias=True), + nn.SiLU(), + nn.Linear(self.model.config.hidden_size, self.model.config.hidden_size, bias=True) + ) + if freeze: + self._freeze() + + def forward(self, image, camera): + # image: [B, N, C, H, W] + # camera: [B, N, D] + # RGB image with [0,1] scale and properly sized + if image.ndim == 5: + image = rearrange(image, 'b n c h w -> (b n) c h w') + dtype = image.dtype + inputs = self.processor( + images=image.float(), + return_tensors="pt", + do_rescale=False, + do_resize=False, + ).to(self.model.device).to(dtype) + # embed camera + N = camera.shape[1] + camera_embeddings = self.camera_embedder(camera) + camera_embeddings = rearrange(camera_embeddings, 'b n d -> (b n) d') + embeddings = camera_embeddings + # This resampling of positional embedding uses bicubic interpolation + outputs = self.model(**inputs, adaln_input=embeddings, interpolate_pos_encoding=True) + last_hidden_states = outputs.last_hidden_state + return last_hidden_states + + def _freeze(self): + print(f"======== Freezing DinoWrapper ========") + self.model.eval() + for name, param in self.model.named_parameters(): + param.requires_grad = False + + @staticmethod + def _build_dino(model_name: str, proxy_error_retries: int = 3, proxy_error_cooldown: int = 5): + import requests + try: + model = ViTModel.from_pretrained(model_name, add_pooling_layer=False) + processor = ViTImageProcessor.from_pretrained(model_name) + return model, processor + except requests.exceptions.ProxyError as err: + if proxy_error_retries > 0: + print(f"Huggingface ProxyError: Retrying in {proxy_error_cooldown} seconds...") + import time + time.sleep(proxy_error_cooldown) + return DinoWrapper._build_dino(model_name, proxy_error_retries - 1, proxy_error_cooldown) + else: + raise err diff --git a/models/lrm/models/geometry/__init__.py b/models/lrm/models/geometry/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..89e9a6c2fffe82a55693885dae78c1a630924389 --- /dev/null +++ b/models/lrm/models/geometry/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. diff --git a/models/lrm/models/geometry/camera/__init__.py b/models/lrm/models/geometry/camera/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c5c7082e47c65a08e25489b3c3fd010d07ad9758 --- /dev/null +++ b/models/lrm/models/geometry/camera/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +from torch import nn + + +class Camera(nn.Module): + def __init__(self): + super(Camera, self).__init__() + pass diff --git a/models/lrm/models/geometry/camera/perspective_camera.py b/models/lrm/models/geometry/camera/perspective_camera.py new file mode 100644 index 0000000000000000000000000000000000000000..7dcab0d2a321a77a5d3c2d4c3f40ba2cc32f6dfa --- /dev/null +++ b/models/lrm/models/geometry/camera/perspective_camera.py @@ -0,0 +1,35 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +from . import Camera +import numpy as np + + +def projection(x=0.1, n=1.0, f=50.0, near_plane=None): + if near_plane is None: + near_plane = n + return np.array( + [[n / x, 0, 0, 0], + [0, n / -x, 0, 0], + [0, 0, -(f + near_plane) / (f - near_plane), -(2 * f * near_plane) / (f - near_plane)], + [0, 0, -1, 0]]).astype(np.float32) + + +class PerspectiveCamera(Camera): + def __init__(self, fovy=49.0, device='cuda'): + super(PerspectiveCamera, self).__init__() + self.device = device + focal = np.tan(fovy / 180.0 * np.pi * 0.5) + self.proj_mtx = torch.from_numpy(projection(x=focal, f=1000.0, n=1.0, near_plane=0.1)).to(self.device).unsqueeze(dim=0) + + def project(self, points_bxnx4): + out = torch.matmul( + points_bxnx4, + torch.transpose(self.proj_mtx, 1, 2)) + return out diff --git a/models/lrm/models/geometry/render/__init__.py b/models/lrm/models/geometry/render/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..483cfabbf395853f1ca3e67b856d5f17b9889d1b --- /dev/null +++ b/models/lrm/models/geometry/render/__init__.py @@ -0,0 +1,8 @@ +import torch + +class Renderer(): + def __init__(self): + pass + + def forward(self): + pass \ No newline at end of file diff --git a/models/lrm/models/geometry/render/neural_render.py b/models/lrm/models/geometry/render/neural_render.py new file mode 100644 index 0000000000000000000000000000000000000000..5d86fcc3f752fa4fcc7e7088438e0f980d6cf64a --- /dev/null +++ b/models/lrm/models/geometry/render/neural_render.py @@ -0,0 +1,293 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +import torch.nn.functional as F +import nvdiffrast.torch as dr +from . import Renderer +from . import util +from . import renderutils as ru +_FG_LUT = None + + +def interpolate(attr, rast, attr_idx, rast_db=None): + return dr.interpolate( + attr.contiguous(), rast, attr_idx, rast_db=rast_db, + diff_attrs=None if rast_db is None else 'all') + + +def xfm_points(points, matrix, use_python=True): + '''Transform points. + Args: + points: Tensor containing 3D points with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3] + matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4] + use_python: Use PyTorch's torch.matmul (for validation) + Returns: + Transformed points in homogeneous 4D with shape [minibatch_size, num_vertices, 4]. + ''' + out = torch.matmul(torch.nn.functional.pad(points, pad=(0, 1), mode='constant', value=1.0), torch.transpose(matrix, 1, 2)) + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of xfm_points contains inf or NaN" + return out + + +def dot(x, y): + return torch.sum(x * y, -1, keepdim=True) + + +def compute_vertex_normal(v_pos, t_pos_idx): + i0 = t_pos_idx[:, 0] + i1 = t_pos_idx[:, 1] + i2 = t_pos_idx[:, 2] + + v0 = v_pos[i0, :] + v1 = v_pos[i1, :] + v2 = v_pos[i2, :] + + face_normals = torch.cross(v1 - v0, v2 - v0) + + # Splat face normals to vertices + v_nrm = torch.zeros_like(v_pos) + v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals) + v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals) + v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals) + + # Normalize, replace zero (degenerated) normals with some default value + v_nrm = torch.where( + dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm) + ) + v_nrm = F.normalize(v_nrm, dim=1) + assert torch.all(torch.isfinite(v_nrm)) + + return v_nrm + + +class NeuralRender(Renderer): + def __init__(self, device='cuda', camera_model=None): + super(NeuralRender, self).__init__() + self.device = device + self.ctx = dr.RasterizeCudaContext(device=device) + self.projection_mtx = None + self.camera = camera_model + + # ============================================================================================== + # pixel shader + # ============================================================================================== + # def shade( + # self, + # gb_pos, + # gb_geometric_normal, + # gb_normal, + # gb_tangent, + # gb_texc, + # gb_texc_deriv, + # view_pos, + # ): + + # ################################################################################ + # # Texture lookups + # ################################################################################ + # breakpoint() + # # Separate kd into alpha and color, default alpha = 1 + # alpha = kd[..., 3:4] if kd.shape[-1] == 4 else torch.ones_like(kd[..., 0:1]) + # kd = kd[..., 0:3] + + # ################################################################################ + # # Normal perturbation & normal bend + # ################################################################################ + + # perturbed_nrm = None + + # gb_normal = ru.prepare_shading_normal(gb_pos, view_pos, perturbed_nrm, gb_normal, gb_tangent, gb_geometric_normal, two_sided_shading=True, opengl=True) + + # ################################################################################ + # # Evaluate BSDF + # ################################################################################ + + # assert 'bsdf' in material or bsdf is not None, "Material must specify a BSDF type" + # bsdf = material['bsdf'] if bsdf is None else bsdf + # if bsdf == 'pbr': + # if isinstance(lgt, light.EnvironmentLight): + # shaded_col = lgt.shade(gb_pos, gb_normal, kd, ks, view_pos, specular=True) + # else: + # assert False, "Invalid light type" + # elif bsdf == 'diffuse': + # if isinstance(lgt, light.EnvironmentLight): + # shaded_col = lgt.shade(gb_pos, gb_normal, kd, ks, view_pos, specular=False) + # else: + # assert False, "Invalid light type" + # elif bsdf == 'normal': + # shaded_col = (gb_normal + 1.0)*0.5 + # elif bsdf == 'tangent': + # shaded_col = (gb_tangent + 1.0)*0.5 + # elif bsdf == 'kd': + # shaded_col = kd + # elif bsdf == 'ks': + # shaded_col = ks + # else: + # assert False, "Invalid BSDF '%s'" % bsdf + + # # Return multiple buffers + # buffers = { + # 'shaded' : torch.cat((shaded_col, alpha), dim=-1), + # 'kd_grad' : torch.cat((kd_grad, alpha), dim=-1), + # 'occlusion' : torch.cat((ks[..., :1], alpha), dim=-1) + # } + # return buffers + + # ============================================================================================== + # Render a depth slice of the mesh (scene), some limitations: + # - Single mesh + # - Single light + # - Single material + # ============================================================================================== + def render_layer( + self, + rast, + rast_deriv, + mesh, + view_pos, + resolution, + spp, + msaa + ): + + # Scale down to shading resolution when MSAA is enabled, otherwise shade at full resolution + rast_out_s = rast + rast_out_deriv_s = rast_deriv + + ################################################################################ + # Interpolate attributes + ################################################################################ + + # Interpolate world space position + gb_pos, _ = interpolate(mesh.v_pos[None, ...], rast_out_s, mesh.t_pos_idx.int()) + + # Compute geometric normals. We need those because of bent normals trick (for bump mapping) + v0 = mesh.v_pos[mesh.t_pos_idx[:, 0], :] + v1 = mesh.v_pos[mesh.t_pos_idx[:, 1], :] + v2 = mesh.v_pos[mesh.t_pos_idx[:, 2], :] + face_normals = util.safe_normalize(torch.cross(v1 - v0, v2 - v0)) + face_normal_indices = (torch.arange(0, face_normals.shape[0], dtype=torch.int64, device='cuda')[:, None]).repeat(1, 3) + gb_geometric_normal, _ = interpolate(face_normals[None, ...], rast_out_s, face_normal_indices.int()) + + # Compute tangent space + assert mesh.v_nrm is not None and mesh.v_tng is not None + gb_normal, _ = interpolate(mesh.v_nrm[None, ...], rast_out_s, mesh.t_nrm_idx.int()) + gb_tangent, _ = interpolate(mesh.v_tng[None, ...], rast_out_s, mesh.t_tng_idx.int()) # Interpolate tangents + + # Texture coordinate + # assert mesh.v_tex is not None + # gb_texc, gb_texc_deriv = interpolate(mesh.v_tex[None, ...], rast_out_s, mesh.t_tex_idx.int(), rast_db=rast_out_deriv_s) + perturbed_nrm = None + gb_normal = ru.prepare_shading_normal(gb_pos, view_pos[:,None,None,:], perturbed_nrm, gb_normal, gb_tangent, gb_geometric_normal, two_sided_shading=True, opengl=True) + + return gb_pos, gb_normal + + def render_mesh( + self, + mesh_v_pos_bxnx3, + mesh_t_pos_idx_fx3, + mesh, + camera_mv_bx4x4, + camera_pos, + mesh_v_feat_bxnxd, + resolution=256, + spp=1, + device='cuda', + hierarchical_mask=False + ): + assert not hierarchical_mask + + mtx_in = torch.tensor(camera_mv_bx4x4, dtype=torch.float32, device=device) if not torch.is_tensor(camera_mv_bx4x4) else camera_mv_bx4x4 + v_pos = xfm_points(mesh_v_pos_bxnx3, mtx_in) # Rotate it to camera coordinates + v_pos_clip = self.camera.project(v_pos) # Projection in the camera + + # view_pos = torch.linalg.inv(mtx_in)[:, :3, 3] + view_pos = camera_pos + v_nrm = mesh.v_nrm #compute_vertex_normal(mesh_v_pos_bxnx3[0], mesh_t_pos_idx_fx3.long()) # vertex normals in world coordinates + + # Render the image, + # Here we only return the feature (3D location) at each pixel, which will be used as the input for neural render + num_layers = 1 + mask_pyramid = None + assert mesh_t_pos_idx_fx3.shape[0] > 0 # Make sure we have shapes + + mesh_v_feat_bxnxd = torch.cat([mesh_v_feat_bxnxd.repeat(v_pos.shape[0], 1, 1), v_pos], dim=-1) # Concatenate the pos [org_pos, clip space pose for rasterization] + + layers = [] + with dr.DepthPeeler(self.ctx, v_pos_clip, mesh.t_pos_idx.int(), [resolution * spp, resolution * spp]) as peeler: + for _ in range(num_layers): + rast, db = peeler.rasterize_next_layer() + gb_pos, gb_normal = self.render_layer(rast, db, mesh, view_pos, resolution, spp, msaa=False) + + with dr.DepthPeeler(self.ctx, v_pos_clip, mesh_t_pos_idx_fx3, [resolution * spp, resolution * spp]) as peeler: + for _ in range(num_layers): + rast, db = peeler.rasterize_next_layer() + gb_feat, _ = interpolate(mesh_v_feat_bxnxd, rast, mesh_t_pos_idx_fx3) + + hard_mask = torch.clamp(rast[..., -1:], 0, 1) + antialias_mask = dr.antialias( + hard_mask.clone().contiguous(), rast, v_pos_clip, + mesh_t_pos_idx_fx3) + + depth = gb_feat[..., -2:-1] + ori_mesh_feature = gb_feat[..., :-4] + + normal, _ = interpolate(v_nrm[None, ...], rast, mesh_t_pos_idx_fx3) + normal = dr.antialias(normal.clone().contiguous(), rast, v_pos_clip, mesh_t_pos_idx_fx3) + # normal = F.normalize(normal, dim=-1) + # normal = torch.lerp(torch.zeros_like(normal), (normal + 1.0) / 2.0, hard_mask.float()) # black background + return ori_mesh_feature, antialias_mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal, gb_normal + + def render_mesh_light( + self, + mesh_v_pos_bxnx3, + mesh_t_pos_idx_fx3, + mesh, + camera_mv_bx4x4, + mesh_v_feat_bxnxd, + resolution=256, + spp=1, + device='cuda', + hierarchical_mask=False + ): + assert not hierarchical_mask + + mtx_in = torch.tensor(camera_mv_bx4x4, dtype=torch.float32, device=device) if not torch.is_tensor(camera_mv_bx4x4) else camera_mv_bx4x4 + v_pos = xfm_points(mesh_v_pos_bxnx3, mtx_in) # Rotate it to camera coordinates + v_pos_clip = self.camera.project(v_pos) # Projection in the camera + + v_nrm = compute_vertex_normal(mesh_v_pos_bxnx3[0], mesh_t_pos_idx_fx3.long()) # vertex normals in world coordinates + + # Render the image, + # Here we only return the feature (3D location) at each pixel, which will be used as the input for neural render + num_layers = 1 + mask_pyramid = None + assert mesh_t_pos_idx_fx3.shape[0] > 0 # Make sure we have shapes + mesh_v_feat_bxnxd = torch.cat([mesh_v_feat_bxnxd.repeat(v_pos.shape[0], 1, 1), v_pos], dim=-1) # Concatenate the pos + + with dr.DepthPeeler(self.ctx, v_pos_clip, mesh_t_pos_idx_fx3, [resolution * spp, resolution * spp]) as peeler: + for _ in range(num_layers): + rast, db = peeler.rasterize_next_layer() + gb_feat, _ = interpolate(mesh_v_feat_bxnxd, rast, mesh_t_pos_idx_fx3) + + hard_mask = torch.clamp(rast[..., -1:], 0, 1) + antialias_mask = dr.antialias( + hard_mask.clone().contiguous(), rast, v_pos_clip, + mesh_t_pos_idx_fx3) + + depth = gb_feat[..., -2:-1] + ori_mesh_feature = gb_feat[..., :-4] + + normal, _ = interpolate(v_nrm[None, ...], rast, mesh_t_pos_idx_fx3) + normal = dr.antialias(normal.clone().contiguous(), rast, v_pos_clip, mesh_t_pos_idx_fx3) + normal = F.normalize(normal, dim=-1) + normal = torch.lerp(torch.zeros_like(normal), (normal + 1.0) / 2.0, hard_mask.float()) # black background + + return ori_mesh_feature, antialias_mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal diff --git a/models/lrm/models/geometry/render/renderutils/__init__.py b/models/lrm/models/geometry/render/renderutils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f29739f961e48de71c58b4bbc45801654df49a70 --- /dev/null +++ b/models/lrm/models/geometry/render/renderutils/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +from .ops import xfm_points, xfm_vectors, image_loss, diffuse_cubemap, specular_cubemap, prepare_shading_normal, lambert, frostbite_diffuse, pbr_specular, pbr_bsdf, _fresnel_shlick, _ndf_ggx, _lambda_ggx, _masking_smith +__all__ = ["xfm_vectors", "xfm_points", "image_loss", "diffuse_cubemap","specular_cubemap", "prepare_shading_normal", "lambert", "frostbite_diffuse", "pbr_specular", "pbr_bsdf", "_fresnel_shlick", "_ndf_ggx", "_lambda_ggx", "_masking_smith", ] diff --git a/models/lrm/models/geometry/render/renderutils/bsdf.py b/models/lrm/models/geometry/render/renderutils/bsdf.py new file mode 100644 index 0000000000000000000000000000000000000000..38457ed58ee447cdf74bb780eb7457d4db1f7f92 --- /dev/null +++ b/models/lrm/models/geometry/render/renderutils/bsdf.py @@ -0,0 +1,151 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import math +import torch + +NORMAL_THRESHOLD = 0.1 + +################################################################################ +# Vector utility functions +################################################################################ + +def _dot(x, y): + return torch.sum(x*y, -1, keepdim=True) + +def _reflect(x, n): + return 2*_dot(x, n)*n - x + +def _safe_normalize(x): + return torch.nn.functional.normalize(x, dim = -1) + +def _bend_normal(view_vec, smooth_nrm, geom_nrm, two_sided_shading): + # Swap normal direction for backfacing surfaces + if two_sided_shading: + smooth_nrm = torch.where(_dot(geom_nrm, view_vec) > 0, smooth_nrm, -smooth_nrm) + geom_nrm = torch.where(_dot(geom_nrm, view_vec) > 0, geom_nrm, -geom_nrm) + + t = torch.clamp(_dot(view_vec, smooth_nrm) / NORMAL_THRESHOLD, min=0, max=1) + return torch.lerp(geom_nrm, smooth_nrm, t) + + +def _perturb_normal(perturbed_nrm, smooth_nrm, smooth_tng, opengl): + smooth_bitang = _safe_normalize(torch.cross(smooth_tng, smooth_nrm)) + if opengl: + shading_nrm = smooth_tng * perturbed_nrm[..., 0:1] - smooth_bitang * perturbed_nrm[..., 1:2] + smooth_nrm * torch.clamp(perturbed_nrm[..., 2:3], min=0.0) + else: + shading_nrm = smooth_tng * perturbed_nrm[..., 0:1] + smooth_bitang * perturbed_nrm[..., 1:2] + smooth_nrm * torch.clamp(perturbed_nrm[..., 2:3], min=0.0) + return _safe_normalize(shading_nrm) + +def bsdf_prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl): + smooth_nrm = _safe_normalize(smooth_nrm) + smooth_tng = _safe_normalize(smooth_tng) + view_vec = _safe_normalize(view_pos - pos) + shading_nrm = _perturb_normal(perturbed_nrm, smooth_nrm, smooth_tng, opengl) + return _bend_normal(view_vec, shading_nrm, geom_nrm, two_sided_shading) + +################################################################################ +# Simple lambertian diffuse BSDF +################################################################################ + +def bsdf_lambert(nrm, wi): + return torch.clamp(_dot(nrm, wi), min=0.0) / math.pi + +################################################################################ +# Frostbite diffuse +################################################################################ + +def bsdf_frostbite(nrm, wi, wo, linearRoughness): + wiDotN = _dot(wi, nrm) + woDotN = _dot(wo, nrm) + + h = _safe_normalize(wo + wi) + wiDotH = _dot(wi, h) + + energyBias = 0.5 * linearRoughness + energyFactor = 1.0 - (0.51 / 1.51) * linearRoughness + f90 = energyBias + 2.0 * wiDotH * wiDotH * linearRoughness + f0 = 1.0 + + wiScatter = bsdf_fresnel_shlick(f0, f90, wiDotN) + woScatter = bsdf_fresnel_shlick(f0, f90, woDotN) + res = wiScatter * woScatter * energyFactor + return torch.where((wiDotN > 0.0) & (woDotN > 0.0), res, torch.zeros_like(res)) + +################################################################################ +# Phong specular, loosely based on mitsuba implementation +################################################################################ + +def bsdf_phong(nrm, wo, wi, N): + dp_r = torch.clamp(_dot(_reflect(wo, nrm), wi), min=0.0, max=1.0) + dp_l = torch.clamp(_dot(nrm, wi), min=0.0, max=1.0) + return (dp_r ** N) * dp_l * (N + 2) / (2 * math.pi) + +################################################################################ +# PBR's implementation of GGX specular +################################################################################ + +specular_epsilon = 1e-4 + +def bsdf_fresnel_shlick(f0, f90, cosTheta): + _cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon) + return f0 + (f90 - f0) * (1.0 - _cosTheta) ** 5.0 + +def bsdf_ndf_ggx(alphaSqr, cosTheta): + _cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon) + d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1 + return alphaSqr / (d * d * math.pi) + +def bsdf_lambda_ggx(alphaSqr, cosTheta): + _cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon) + cosThetaSqr = _cosTheta * _cosTheta + tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr + res = 0.5 * (torch.sqrt(1 + alphaSqr * tanThetaSqr) - 1.0) + return res + +def bsdf_masking_smith_ggx_correlated(alphaSqr, cosThetaI, cosThetaO): + lambdaI = bsdf_lambda_ggx(alphaSqr, cosThetaI) + lambdaO = bsdf_lambda_ggx(alphaSqr, cosThetaO) + return 1 / (1 + lambdaI + lambdaO) + +def bsdf_pbr_specular(col, nrm, wo, wi, alpha, min_roughness=0.08): + _alpha = torch.clamp(alpha, min=min_roughness*min_roughness, max=1.0) + alphaSqr = _alpha * _alpha + + h = _safe_normalize(wo + wi) + woDotN = _dot(wo, nrm) + wiDotN = _dot(wi, nrm) + woDotH = _dot(wo, h) + nDotH = _dot(nrm, h) + + D = bsdf_ndf_ggx(alphaSqr, nDotH) + G = bsdf_masking_smith_ggx_correlated(alphaSqr, woDotN, wiDotN) + F = bsdf_fresnel_shlick(col, 1, woDotH) + + w = F * D * G * 0.25 / torch.clamp(woDotN, min=specular_epsilon) + + frontfacing = (woDotN > specular_epsilon) & (wiDotN > specular_epsilon) + return torch.where(frontfacing, w, torch.zeros_like(w)) + +def bsdf_pbr(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF): + wo = _safe_normalize(view_pos - pos) + wi = _safe_normalize(light_pos - pos) + + spec_str = arm[..., 0:1] # x component + roughness = arm[..., 1:2] # y component + metallic = arm[..., 2:3] # z component + ks = (0.04 * (1.0 - metallic) + kd * metallic) * (1 - spec_str) + kd = kd * (1.0 - metallic) + + if BSDF == 0: + diffuse = kd * bsdf_lambert(nrm, wi) + else: + diffuse = kd * bsdf_frostbite(nrm, wi, wo, roughness) + specular = bsdf_pbr_specular(ks, nrm, wo, wi, roughness*roughness, min_roughness=min_roughness) + return diffuse + specular diff --git a/models/lrm/models/geometry/render/renderutils/c_src/bsdf.cu b/models/lrm/models/geometry/render/renderutils/c_src/bsdf.cu new file mode 100644 index 0000000000000000000000000000000000000000..c167214f9a4cb42b8d640202969e3950be8b806d --- /dev/null +++ b/models/lrm/models/geometry/render/renderutils/c_src/bsdf.cu @@ -0,0 +1,710 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include "common.h" +#include "bsdf.h" + +#define SPECULAR_EPSILON 1e-4f + +//------------------------------------------------------------------------ +// Lambert functions + +__device__ inline float fwdLambert(const vec3f nrm, const vec3f wi) +{ + return max(dot(nrm, wi) / M_PI, 0.0f); +} + +__device__ inline void bwdLambert(const vec3f nrm, const vec3f wi, vec3f& d_nrm, vec3f& d_wi, const float d_out) +{ + if (dot(nrm, wi) > 0.0f) + bwdDot(nrm, wi, d_nrm, d_wi, d_out / M_PI); +} + +//------------------------------------------------------------------------ +// Fresnel Schlick + +__device__ inline float fwdFresnelSchlick(const float f0, const float f90, const float cosTheta) +{ + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float scale = powf(1.0f - _cosTheta, 5.0f); + return f0 * (1.0f - scale) + f90 * scale; +} + +__device__ inline void bwdFresnelSchlick(const float f0, const float f90, const float cosTheta, float& d_f0, float& d_f90, float& d_cosTheta, const float d_out) +{ + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float scale = pow(max(1.0f - _cosTheta, 0.0f), 5.0f); + d_f0 += d_out * (1.0 - scale); + d_f90 += d_out * scale; + if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON) + { + d_cosTheta += d_out * (f90 - f0) * -5.0f * powf(1.0f - cosTheta, 4.0f); + } +} + +__device__ inline vec3f fwdFresnelSchlick(const vec3f f0, const vec3f f90, const float cosTheta) +{ + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float scale = powf(1.0f - _cosTheta, 5.0f); + return f0 * (1.0f - scale) + f90 * scale; +} + +__device__ inline void bwdFresnelSchlick(const vec3f f0, const vec3f f90, const float cosTheta, vec3f& d_f0, vec3f& d_f90, float& d_cosTheta, const vec3f d_out) +{ + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float scale = pow(max(1.0f - _cosTheta, 0.0f), 5.0f); + d_f0 += d_out * (1.0 - scale); + d_f90 += d_out * scale; + if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON) + { + d_cosTheta += sum(d_out * (f90 - f0) * -5.0f * powf(1.0f - cosTheta, 4.0f)); + } +} + +//------------------------------------------------------------------------ +// Frostbite diffuse + +__device__ inline float fwdFrostbiteDiffuse(const vec3f nrm, const vec3f wi, const vec3f wo, float linearRoughness) +{ + float wiDotN = dot(wi, nrm); + float woDotN = dot(wo, nrm); + if (wiDotN > 0.0f && woDotN > 0.0f) + { + vec3f h = safeNormalize(wo + wi); + float wiDotH = dot(wi, h); + + float energyBias = 0.5f * linearRoughness; + float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness; + float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness; + float f0 = 1.f; + + float wiScatter = fwdFresnelSchlick(f0, f90, wiDotN); + float woScatter = fwdFresnelSchlick(f0, f90, woDotN); + + return wiScatter * woScatter * energyFactor; + } + else return 0.0f; +} + +__device__ inline void bwdFrostbiteDiffuse(const vec3f nrm, const vec3f wi, const vec3f wo, float linearRoughness, vec3f& d_nrm, vec3f& d_wi, vec3f& d_wo, float &d_linearRoughness, const float d_out) +{ + float wiDotN = dot(wi, nrm); + float woDotN = dot(wo, nrm); + + if (wiDotN > 0.0f && woDotN > 0.0f) + { + vec3f h = safeNormalize(wo + wi); + float wiDotH = dot(wi, h); + + float energyBias = 0.5f * linearRoughness; + float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness; + float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness; + float f0 = 1.f; + + float wiScatter = fwdFresnelSchlick(f0, f90, wiDotN); + float woScatter = fwdFresnelSchlick(f0, f90, woDotN); + + // -------------- BWD -------------- + // Backprop: return wiScatter * woScatter * energyFactor; + float d_wiScatter = d_out * woScatter * energyFactor; + float d_woScatter = d_out * wiScatter * energyFactor; + float d_energyFactor = d_out * wiScatter * woScatter; + + // Backprop: float woScatter = fwdFresnelSchlick(f0, f90, woDotN); + float d_woDotN = 0.0f, d_f0 = 0.0, d_f90 = 0.0f; + bwdFresnelSchlick(f0, f90, woDotN, d_f0, d_f90, d_woDotN, d_woScatter); + + // Backprop: float wiScatter = fwdFresnelSchlick(fd0, fd90, wiDotN); + float d_wiDotN = 0.0f; + bwdFresnelSchlick(f0, f90, wiDotN, d_f0, d_f90, d_wiDotN, d_wiScatter); + + // Backprop: float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness; + float d_energyBias = d_f90; + float d_wiDotH = d_f90 * 4 * wiDotH * linearRoughness; + d_linearRoughness += d_f90 * 2 * wiDotH * wiDotH; + + // Backprop: float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness; + d_linearRoughness -= (0.51f / 1.51f) * d_energyFactor; + + // Backprop: float energyBias = 0.5f * linearRoughness; + d_linearRoughness += 0.5 * d_energyBias; + + // Backprop: float wiDotH = dot(wi, h); + vec3f d_h(0); + bwdDot(wi, h, d_wi, d_h, d_wiDotH); + + // Backprop: vec3f h = safeNormalize(wo + wi); + vec3f d_wo_wi(0); + bwdSafeNormalize(wo + wi, d_wo_wi, d_h); + d_wi += d_wo_wi; d_wo += d_wo_wi; + + bwdDot(wo, nrm, d_wo, d_nrm, d_woDotN); + bwdDot(wi, nrm, d_wi, d_nrm, d_wiDotN); + } +} + +//------------------------------------------------------------------------ +// Ndf GGX + +__device__ inline float fwdNdfGGX(const float alphaSqr, const float cosTheta) +{ + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1.0f; + return alphaSqr / (d * d * M_PI); +} + +__device__ inline void bwdNdfGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out) +{ + // Torch only back propagates if clamp doesn't trigger + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float cosThetaSqr = _cosTheta * _cosTheta; + d_alphaSqr += d_out * (1.0f - (alphaSqr + 1.0f) * cosThetaSqr) / (M_PI * powf((alphaSqr - 1.0) * cosThetaSqr + 1.0f, 3.0f)); + if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON) + { + d_cosTheta += d_out * -(4.0f * (alphaSqr - 1.0f) * alphaSqr * cosTheta) / (M_PI * powf((alphaSqr - 1.0) * cosThetaSqr + 1.0f, 3.0f)); + } +} + +//------------------------------------------------------------------------ +// Lambda GGX + +__device__ inline float fwdLambdaGGX(const float alphaSqr, const float cosTheta) +{ + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float cosThetaSqr = _cosTheta * _cosTheta; + float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr; + float res = 0.5f * (sqrtf(1.0f + alphaSqr * tanThetaSqr) - 1.0f); + return res; +} + +__device__ inline void bwdLambdaGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out) +{ + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float cosThetaSqr = _cosTheta * _cosTheta; + float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr; + float res = 0.5f * (sqrtf(1.0f + alphaSqr * tanThetaSqr) - 1.0f); + + d_alphaSqr += d_out * (0.25 * tanThetaSqr) / sqrtf(alphaSqr * tanThetaSqr + 1.0f); + if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON) + d_cosTheta += d_out * -(0.5 * alphaSqr) / (powf(_cosTheta, 3.0f) * sqrtf(alphaSqr / cosThetaSqr - alphaSqr + 1.0f)); +} + +//------------------------------------------------------------------------ +// Masking GGX + +__device__ inline float fwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO) +{ + float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI); + float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO); + return 1.0f / (1.0f + lambdaI + lambdaO); +} + +__device__ inline void bwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO, float& d_alphaSqr, float& d_cosThetaI, float& d_cosThetaO, const float d_out) +{ + // FWD eval + float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI); + float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO); + + // BWD eval + float d_lambdaIO = -d_out / powf(1.0f + lambdaI + lambdaO, 2.0f); + bwdLambdaGGX(alphaSqr, cosThetaI, d_alphaSqr, d_cosThetaI, d_lambdaIO); + bwdLambdaGGX(alphaSqr, cosThetaO, d_alphaSqr, d_cosThetaO, d_lambdaIO); +} + +//------------------------------------------------------------------------ +// GGX specular + +__device__ vec3f fwdPbrSpecular(const vec3f col, const vec3f nrm, const vec3f wo, const vec3f wi, const float alpha, const float min_roughness) +{ + float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f); + float alphaSqr = _alpha * _alpha; + + vec3f h = safeNormalize(wo + wi); + float woDotN = dot(wo, nrm); + float wiDotN = dot(wi, nrm); + float woDotH = dot(wo, h); + float nDotH = dot(nrm, h); + + float D = fwdNdfGGX(alphaSqr, nDotH); + float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN); + vec3f F = fwdFresnelSchlick(col, 1.0f, woDotH); + vec3f w = F * D * G * 0.25 / woDotN; + + bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON); + return frontfacing ? w : 0.0f; +} + +__device__ void bwdPbrSpecular( + const vec3f col, const vec3f nrm, const vec3f wo, const vec3f wi, const float alpha, const float min_roughness, + vec3f& d_col, vec3f& d_nrm, vec3f& d_wo, vec3f& d_wi, float& d_alpha, const vec3f d_out) +{ + /////////////////////////////////////////////////////////////////////// + // FWD eval + + float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f); + float alphaSqr = _alpha * _alpha; + + vec3f h = safeNormalize(wo + wi); + float woDotN = dot(wo, nrm); + float wiDotN = dot(wi, nrm); + float woDotH = dot(wo, h); + float nDotH = dot(nrm, h); + + float D = fwdNdfGGX(alphaSqr, nDotH); + float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN); + vec3f F = fwdFresnelSchlick(col, 1.0f, woDotH); + vec3f w = F * D * G * 0.25 / woDotN; + bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON); + + if (frontfacing) + { + /////////////////////////////////////////////////////////////////////// + // BWD eval + + vec3f d_F = d_out * D * G * 0.25f / woDotN; + float d_D = sum(d_out * F * G * 0.25f / woDotN); + float d_G = sum(d_out * F * D * 0.25f / woDotN); + + float d_woDotN = -sum(d_out * F * D * G * 0.25f / (woDotN * woDotN)); + + vec3f d_f90(0); + float d_woDotH(0), d_wiDotN(0), d_nDotH(0), d_alphaSqr(0); + bwdFresnelSchlick(col, 1.0f, woDotH, d_col, d_f90, d_woDotH, d_F); + bwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN, d_alphaSqr, d_woDotN, d_wiDotN, d_G); + bwdNdfGGX(alphaSqr, nDotH, d_alphaSqr, d_nDotH, d_D); + + vec3f d_h(0); + bwdDot(nrm, h, d_nrm, d_h, d_nDotH); + bwdDot(wo, h, d_wo, d_h, d_woDotH); + bwdDot(wi, nrm, d_wi, d_nrm, d_wiDotN); + bwdDot(wo, nrm, d_wo, d_nrm, d_woDotN); + + vec3f d_h_unnorm(0); + bwdSafeNormalize(wo + wi, d_h_unnorm, d_h); + d_wo += d_h_unnorm; + d_wi += d_h_unnorm; + + if (alpha > min_roughness * min_roughness) + d_alpha += d_alphaSqr * 2 * alpha; + } +} + +//------------------------------------------------------------------------ +// Full PBR BSDF + +__device__ vec3f fwdPbrBSDF(const vec3f kd, const vec3f arm, const vec3f pos, const vec3f nrm, const vec3f view_pos, const vec3f light_pos, const float min_roughness, int BSDF) +{ + vec3f wo = safeNormalize(view_pos - pos); + vec3f wi = safeNormalize(light_pos - pos); + + float alpha = arm.y * arm.y; + vec3f spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x); + vec3f diff_col = kd * (1.0f - arm.z); + + float diff = 0.0f; + if (BSDF == 0) + diff = fwdLambert(nrm, wi); + else + diff = fwdFrostbiteDiffuse(nrm, wi, wo, arm.y); + vec3f diffuse = diff_col * diff; + vec3f specular = fwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness); + + return diffuse + specular; +} + +__device__ void bwdPbrBSDF( + const vec3f kd, const vec3f arm, const vec3f pos, const vec3f nrm, const vec3f view_pos, const vec3f light_pos, const float min_roughness, int BSDF, + vec3f& d_kd, vec3f& d_arm, vec3f& d_pos, vec3f& d_nrm, vec3f& d_view_pos, vec3f& d_light_pos, const vec3f d_out) +{ + //////////////////////////////////////////////////////////////////////// + // FWD + vec3f _wi = light_pos - pos; + vec3f _wo = view_pos - pos; + vec3f wi = safeNormalize(_wi); + vec3f wo = safeNormalize(_wo); + + float alpha = arm.y * arm.y; + vec3f spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x); + vec3f diff_col = kd * (1.0f - arm.z); + float diff = 0.0f; + if (BSDF == 0) + diff = fwdLambert(nrm, wi); + else + diff = fwdFrostbiteDiffuse(nrm, wi, wo, arm.y); + + //////////////////////////////////////////////////////////////////////// + // BWD + + float d_alpha(0); + vec3f d_spec_col(0), d_wi(0), d_wo(0); + bwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness, d_spec_col, d_nrm, d_wo, d_wi, d_alpha, d_out); + + float d_diff = sum(diff_col * d_out); + if (BSDF == 0) + bwdLambert(nrm, wi, d_nrm, d_wi, d_diff); + else + bwdFrostbiteDiffuse(nrm, wi, wo, arm.y, d_nrm, d_wi, d_wo, d_arm.y, d_diff); + + // Backprop: diff_col = kd * (1.0f - arm.z) + vec3f d_diff_col = d_out * diff; + d_kd += d_diff_col * (1.0f - arm.z); + d_arm.z -= sum(d_diff_col * kd); + + // Backprop: spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x) + d_kd -= d_spec_col * (arm.x - 1.0f) * arm.z; + d_arm.x += sum(d_spec_col * (arm.z * (0.04f - kd) - 0.04f)); + d_arm.z -= sum(d_spec_col * (kd - 0.04f) * (arm.x - 1.0f)); + + // Backprop: alpha = arm.y * arm.y + d_arm.y += d_alpha * 2 * arm.y; + + // Backprop: vec3f wi = safeNormalize(light_pos - pos); + vec3f d__wi(0); + bwdSafeNormalize(_wi, d__wi, d_wi); + d_light_pos += d__wi; + d_pos -= d__wi; + + // Backprop: vec3f wo = safeNormalize(view_pos - pos); + vec3f d__wo(0); + bwdSafeNormalize(_wo, d__wo, d_wo); + d_view_pos += d__wo; + d_pos -= d__wo; +} + +//------------------------------------------------------------------------ +// Kernels + +__global__ void LambertFwdKernel(LambertKernelParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f wi = p.wi.fetch3(px, py, pz); + + float res = fwdLambert(nrm, wi); + + p.out.store(px, py, pz, res); +} + +__global__ void LambertBwdKernel(LambertKernelParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f wi = p.wi.fetch3(px, py, pz); + float d_out = p.out.fetch1(px, py, pz); + + vec3f d_nrm(0), d_wi(0); + bwdLambert(nrm, wi, d_nrm, d_wi, d_out); + + p.nrm.store_grad(px, py, pz, d_nrm); + p.wi.store_grad(px, py, pz, d_wi); +} + +__global__ void FrostbiteDiffuseFwdKernel(FrostbiteDiffuseKernelParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f wi = p.wi.fetch3(px, py, pz); + vec3f wo = p.wo.fetch3(px, py, pz); + float linearRoughness = p.linearRoughness.fetch1(px, py, pz); + + float res = fwdFrostbiteDiffuse(nrm, wi, wo, linearRoughness); + + p.out.store(px, py, pz, res); +} + +__global__ void FrostbiteDiffuseBwdKernel(FrostbiteDiffuseKernelParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f wi = p.wi.fetch3(px, py, pz); + vec3f wo = p.wo.fetch3(px, py, pz); + float linearRoughness = p.linearRoughness.fetch1(px, py, pz); + float d_out = p.out.fetch1(px, py, pz); + + float d_linearRoughness = 0.0f; + vec3f d_nrm(0), d_wi(0), d_wo(0); + bwdFrostbiteDiffuse(nrm, wi, wo, linearRoughness, d_nrm, d_wi, d_wo, d_linearRoughness, d_out); + + p.nrm.store_grad(px, py, pz, d_nrm); + p.wi.store_grad(px, py, pz, d_wi); + p.wo.store_grad(px, py, pz, d_wo); + p.linearRoughness.store_grad(px, py, pz, d_linearRoughness); +} + +__global__ void FresnelShlickFwdKernel(FresnelShlickKernelParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f f0 = p.f0.fetch3(px, py, pz); + vec3f f90 = p.f90.fetch3(px, py, pz); + float cosTheta = p.cosTheta.fetch1(px, py, pz); + + vec3f res = fwdFresnelSchlick(f0, f90, cosTheta); + p.out.store(px, py, pz, res); +} + +__global__ void FresnelShlickBwdKernel(FresnelShlickKernelParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f f0 = p.f0.fetch3(px, py, pz); + vec3f f90 = p.f90.fetch3(px, py, pz); + float cosTheta = p.cosTheta.fetch1(px, py, pz); + vec3f d_out = p.out.fetch3(px, py, pz); + + vec3f d_f0(0), d_f90(0); + float d_cosTheta(0); + bwdFresnelSchlick(f0, f90, cosTheta, d_f0, d_f90, d_cosTheta, d_out); + + p.f0.store_grad(px, py, pz, d_f0); + p.f90.store_grad(px, py, pz, d_f90); + p.cosTheta.store_grad(px, py, pz, d_cosTheta); +} + +__global__ void ndfGGXFwdKernel(NdfGGXParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + float alphaSqr = p.alphaSqr.fetch1(px, py, pz); + float cosTheta = p.cosTheta.fetch1(px, py, pz); + float res = fwdNdfGGX(alphaSqr, cosTheta); + + p.out.store(px, py, pz, res); +} + +__global__ void ndfGGXBwdKernel(NdfGGXParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + float alphaSqr = p.alphaSqr.fetch1(px, py, pz); + float cosTheta = p.cosTheta.fetch1(px, py, pz); + float d_out = p.out.fetch1(px, py, pz); + + float d_alphaSqr(0), d_cosTheta(0); + bwdNdfGGX(alphaSqr, cosTheta, d_alphaSqr, d_cosTheta, d_out); + + p.alphaSqr.store_grad(px, py, pz, d_alphaSqr); + p.cosTheta.store_grad(px, py, pz, d_cosTheta); +} + +__global__ void lambdaGGXFwdKernel(NdfGGXParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + float alphaSqr = p.alphaSqr.fetch1(px, py, pz); + float cosTheta = p.cosTheta.fetch1(px, py, pz); + float res = fwdLambdaGGX(alphaSqr, cosTheta); + + p.out.store(px, py, pz, res); +} + +__global__ void lambdaGGXBwdKernel(NdfGGXParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + float alphaSqr = p.alphaSqr.fetch1(px, py, pz); + float cosTheta = p.cosTheta.fetch1(px, py, pz); + float d_out = p.out.fetch1(px, py, pz); + + float d_alphaSqr(0), d_cosTheta(0); + bwdLambdaGGX(alphaSqr, cosTheta, d_alphaSqr, d_cosTheta, d_out); + + p.alphaSqr.store_grad(px, py, pz, d_alphaSqr); + p.cosTheta.store_grad(px, py, pz, d_cosTheta); +} + +__global__ void maskingSmithFwdKernel(MaskingSmithParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + float alphaSqr = p.alphaSqr.fetch1(px, py, pz); + float cosThetaI = p.cosThetaI.fetch1(px, py, pz); + float cosThetaO = p.cosThetaO.fetch1(px, py, pz); + float res = fwdMaskingSmithGGXCorrelated(alphaSqr, cosThetaI, cosThetaO); + + p.out.store(px, py, pz, res); +} + +__global__ void maskingSmithBwdKernel(MaskingSmithParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + float alphaSqr = p.alphaSqr.fetch1(px, py, pz); + float cosThetaI = p.cosThetaI.fetch1(px, py, pz); + float cosThetaO = p.cosThetaO.fetch1(px, py, pz); + float d_out = p.out.fetch1(px, py, pz); + + float d_alphaSqr(0), d_cosThetaI(0), d_cosThetaO(0); + bwdMaskingSmithGGXCorrelated(alphaSqr, cosThetaI, cosThetaO, d_alphaSqr, d_cosThetaI, d_cosThetaO, d_out); + + p.alphaSqr.store_grad(px, py, pz, d_alphaSqr); + p.cosThetaI.store_grad(px, py, pz, d_cosThetaI); + p.cosThetaO.store_grad(px, py, pz, d_cosThetaO); +} + +__global__ void pbrSpecularFwdKernel(PbrSpecular p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f col = p.col.fetch3(px, py, pz); + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f wo = p.wo.fetch3(px, py, pz); + vec3f wi = p.wi.fetch3(px, py, pz); + float alpha = p.alpha.fetch1(px, py, pz); + + vec3f res = fwdPbrSpecular(col, nrm, wo, wi, alpha, p.min_roughness); + + p.out.store(px, py, pz, res); +} + +__global__ void pbrSpecularBwdKernel(PbrSpecular p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f col = p.col.fetch3(px, py, pz); + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f wo = p.wo.fetch3(px, py, pz); + vec3f wi = p.wi.fetch3(px, py, pz); + float alpha = p.alpha.fetch1(px, py, pz); + vec3f d_out = p.out.fetch3(px, py, pz); + + float d_alpha(0); + vec3f d_col(0), d_nrm(0), d_wo(0), d_wi(0); + bwdPbrSpecular(col, nrm, wo, wi, alpha, p.min_roughness, d_col, d_nrm, d_wo, d_wi, d_alpha, d_out); + + p.col.store_grad(px, py, pz, d_col); + p.nrm.store_grad(px, py, pz, d_nrm); + p.wo.store_grad(px, py, pz, d_wo); + p.wi.store_grad(px, py, pz, d_wi); + p.alpha.store_grad(px, py, pz, d_alpha); +} + +__global__ void pbrBSDFFwdKernel(PbrBSDF p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f kd = p.kd.fetch3(px, py, pz); + vec3f arm = p.arm.fetch3(px, py, pz); + vec3f pos = p.pos.fetch3(px, py, pz); + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f view_pos = p.view_pos.fetch3(px, py, pz); + vec3f light_pos = p.light_pos.fetch3(px, py, pz); + + vec3f res = fwdPbrBSDF(kd, arm, pos, nrm, view_pos, light_pos, p.min_roughness, p.BSDF); + + p.out.store(px, py, pz, res); +} +__global__ void pbrBSDFBwdKernel(PbrBSDF p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f kd = p.kd.fetch3(px, py, pz); + vec3f arm = p.arm.fetch3(px, py, pz); + vec3f pos = p.pos.fetch3(px, py, pz); + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f view_pos = p.view_pos.fetch3(px, py, pz); + vec3f light_pos = p.light_pos.fetch3(px, py, pz); + vec3f d_out = p.out.fetch3(px, py, pz); + + vec3f d_kd(0), d_arm(0), d_pos(0), d_nrm(0), d_view_pos(0), d_light_pos(0); + bwdPbrBSDF(kd, arm, pos, nrm, view_pos, light_pos, p.min_roughness, p.BSDF, d_kd, d_arm, d_pos, d_nrm, d_view_pos, d_light_pos, d_out); + + p.kd.store_grad(px, py, pz, d_kd); + p.arm.store_grad(px, py, pz, d_arm); + p.pos.store_grad(px, py, pz, d_pos); + p.nrm.store_grad(px, py, pz, d_nrm); + p.view_pos.store_grad(px, py, pz, d_view_pos); + p.light_pos.store_grad(px, py, pz, d_light_pos); +} diff --git a/models/lrm/models/geometry/render/renderutils/c_src/bsdf.h b/models/lrm/models/geometry/render/renderutils/c_src/bsdf.h new file mode 100644 index 0000000000000000000000000000000000000000..59adbf097490c5a643ebdcff9c3784173522e070 --- /dev/null +++ b/models/lrm/models/geometry/render/renderutils/c_src/bsdf.h @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include "common.h" + +struct LambertKernelParams +{ + Tensor nrm; + Tensor wi; + Tensor out; + dim3 gridSize; +}; + +struct FrostbiteDiffuseKernelParams +{ + Tensor nrm; + Tensor wi; + Tensor wo; + Tensor linearRoughness; + Tensor out; + dim3 gridSize; +}; + +struct FresnelShlickKernelParams +{ + Tensor f0; + Tensor f90; + Tensor cosTheta; + Tensor out; + dim3 gridSize; +}; + +struct NdfGGXParams +{ + Tensor alphaSqr; + Tensor cosTheta; + Tensor out; + dim3 gridSize; +}; + +struct MaskingSmithParams +{ + Tensor alphaSqr; + Tensor cosThetaI; + Tensor cosThetaO; + Tensor out; + dim3 gridSize; +}; + +struct PbrSpecular +{ + Tensor col; + Tensor nrm; + Tensor wo; + Tensor wi; + Tensor alpha; + Tensor out; + dim3 gridSize; + float min_roughness; +}; + +struct PbrBSDF +{ + Tensor kd; + Tensor arm; + Tensor pos; + Tensor nrm; + Tensor view_pos; + Tensor light_pos; + Tensor out; + dim3 gridSize; + float min_roughness; + int BSDF; +}; diff --git a/models/lrm/models/geometry/render/renderutils/c_src/common.cpp b/models/lrm/models/geometry/render/renderutils/c_src/common.cpp new file mode 100644 index 0000000000000000000000000000000000000000..445895e57f7d0bcd6a2812f5ba97d7be2ddfbe28 --- /dev/null +++ b/models/lrm/models/geometry/render/renderutils/c_src/common.cpp @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include +#include + +//------------------------------------------------------------------------ +// Block and grid size calculators for kernel launches. + +dim3 getLaunchBlockSize(int maxWidth, int maxHeight, dim3 dims) +{ + int maxThreads = maxWidth * maxHeight; + if (maxThreads <= 1 || (dims.x * dims.y) <= 1) + return dim3(1, 1, 1); // Degenerate. + + // Start from max size. + int bw = maxWidth; + int bh = maxHeight; + + // Optimizations for weirdly sized buffers. + if (dims.x < bw) + { + // Decrease block width to smallest power of two that covers the buffer width. + while ((bw >> 1) >= dims.x) + bw >>= 1; + + // Maximize height. + bh = maxThreads / bw; + if (bh > dims.y) + bh = dims.y; + } + else if (dims.y < bh) + { + // Halve height and double width until fits completely inside buffer vertically. + while (bh > dims.y) + { + bh >>= 1; + if (bw < dims.x) + bw <<= 1; + } + } + + // Done. + return dim3(bw, bh, 1); +} + +// returns the size of a block that can be reduced using horizontal SIMD operations (e.g. __shfl_xor_sync) +dim3 getWarpSize(dim3 blockSize) +{ + return dim3( + std::min(blockSize.x, 32u), + std::min(std::max(32u / blockSize.x, 1u), std::min(32u, blockSize.y)), + std::min(std::max(32u / (blockSize.x * blockSize.y), 1u), std::min(32u, blockSize.z)) + ); +} + +dim3 getLaunchGridSize(dim3 blockSize, dim3 dims) +{ + dim3 gridSize; + gridSize.x = (dims.x - 1) / blockSize.x + 1; + gridSize.y = (dims.y - 1) / blockSize.y + 1; + gridSize.z = (dims.z - 1) / blockSize.z + 1; + return gridSize; +} + +//------------------------------------------------------------------------ diff --git a/models/lrm/models/geometry/render/renderutils/c_src/common.h b/models/lrm/models/geometry/render/renderutils/c_src/common.h new file mode 100644 index 0000000000000000000000000000000000000000..5abaeebdd3f0a0910f7df3e9e0470a9fa682d507 --- /dev/null +++ b/models/lrm/models/geometry/render/renderutils/c_src/common.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once +#include +#include + +#include "vec3f.h" +#include "vec4f.h" +#include "tensor.h" + +dim3 getLaunchBlockSize(int maxWidth, int maxHeight, dim3 dims); +dim3 getLaunchGridSize(dim3 blockSize, dim3 dims); + +#ifdef __CUDACC__ + +#ifdef _MSC_VER +#define M_PI 3.14159265358979323846f +#endif + +__host__ __device__ static inline dim3 getWarpSize(dim3 blockSize) +{ + return dim3( + min(blockSize.x, 32u), + min(max(32u / blockSize.x, 1u), min(32u, blockSize.y)), + min(max(32u / (blockSize.x * blockSize.y), 1u), min(32u, blockSize.z)) + ); +} + +__device__ static inline float clamp(float val, float mn, float mx) { return min(max(val, mn), mx); } +#else +dim3 getWarpSize(dim3 blockSize); +#endif \ No newline at end of file diff --git a/models/lrm/models/geometry/render/renderutils/c_src/cubemap.cu b/models/lrm/models/geometry/render/renderutils/c_src/cubemap.cu new file mode 100644 index 0000000000000000000000000000000000000000..2ce21d83b2dd6759da30874cf8e01b7fd88e9217 --- /dev/null +++ b/models/lrm/models/geometry/render/renderutils/c_src/cubemap.cu @@ -0,0 +1,350 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include "common.h" +#include "cubemap.h" +#include + +// https://cgvr.cs.uni-bremen.de/teaching/cg_literatur/Spherical,%20Cubic,%20and%20Parabolic%20Environment%20Mappings.pdf +__device__ float pixel_area(int x, int y, int N) +{ + if (N > 1) + { + int H = N / 2; + x = abs(x - H); + y = abs(y - H); + float dx = atan((float)(x + 1) / (float)H) - atan((float)x / (float)H); + float dy = atan((float)(y + 1) / (float)H) - atan((float)y / (float)H); + return dx * dy; + } + else + return 1; +} + +__device__ vec3f cube_to_dir(int x, int y, int side, int N) +{ + float fx = 2.0f * (((float)x + 0.5f) / (float)N) - 1.0f; + float fy = 2.0f * (((float)y + 0.5f) / (float)N) - 1.0f; + switch (side) + { + case 0: return safeNormalize(vec3f(1, -fy, -fx)); + case 1: return safeNormalize(vec3f(-1, -fy, fx)); + case 2: return safeNormalize(vec3f(fx, 1, fy)); + case 3: return safeNormalize(vec3f(fx, -1, -fy)); + case 4: return safeNormalize(vec3f(fx, -fy, 1)); + case 5: return safeNormalize(vec3f(-fx, -fy, -1)); + } + return vec3f(0,0,0); // Unreachable +} + +__device__ vec3f dir_to_side(int side, vec3f v) +{ + switch (side) + { + case 0: return vec3f(-v.z, -v.y, v.x); + case 1: return vec3f( v.z, -v.y, -v.x); + case 2: return vec3f( v.x, v.z, v.y); + case 3: return vec3f( v.x, -v.z, -v.y); + case 4: return vec3f( v.x, -v.y, v.z); + case 5: return vec3f(-v.x, -v.y, -v.z); + } + return vec3f(0,0,0); // Unreachable +} + +__device__ void extents_1d(float x, float z, float theta, float& _min, float& _max) +{ + float l = sqrtf(x * x + z * z); + float pxr = x + z * tan(theta) * l, pzr = z - x * tan(theta) * l; + float pxl = x - z * tan(theta) * l, pzl = z + x * tan(theta) * l; + if (pzl <= 0.00001f) + _min = pxl > 0.0f ? FLT_MAX : -FLT_MAX; + else + _min = pxl / pzl; + if (pzr <= 0.00001f) + _max = pxr > 0.0f ? FLT_MAX : -FLT_MAX; + else + _max = pxr / pzr; +} + +__device__ void dir_extents(int side, int N, vec3f v, float theta, int &_xmin, int& _xmax, int& _ymin, int& _ymax) +{ + vec3f c = dir_to_side(side, v); // remap to (x,y,z) where side is at z = 1 + + if (theta < 0.785398f) // PI/4 + { + float xmin, xmax, ymin, ymax; + extents_1d(c.x, c.z, theta, xmin, xmax); + extents_1d(c.y, c.z, theta, ymin, ymax); + + if (xmin > 1.0f || xmax < -1.0f || ymin > 1.0f || ymax < -1.0f) + { + _xmin = -1; _xmax = -1; _ymin = -1; _ymax = -1; // Bad aabb + } + else + { + _xmin = (int)min(max((xmin + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1)); + _xmax = (int)min(max((xmax + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1)); + _ymin = (int)min(max((ymin + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1)); + _ymax = (int)min(max((ymax + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1)); + } + } + else + { + _xmin = 0.0f; + _xmax = (float)(N-1); + _ymin = 0.0f; + _ymax = (float)(N-1); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////////////// +// Diffuse kernel +__global__ void DiffuseCubemapFwdKernel(DiffuseCubemapKernelParams p) +{ + // Calculate pixel position. + int px = blockIdx.x * blockDim.x + threadIdx.x; + int py = blockIdx.y * blockDim.y + threadIdx.y; + int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + int Npx = p.cubemap.dims[1]; + vec3f N = cube_to_dir(px, py, pz, Npx); + + vec3f col(0); + + for (int s = 0; s < p.cubemap.dims[0]; ++s) + { + for (int y = 0; y < Npx; ++y) + { + for (int x = 0; x < Npx; ++x) + { + vec3f L = cube_to_dir(x, y, s, Npx); + float costheta = min(max(dot(N, L), 0.0f), 0.999f); + float w = costheta * pixel_area(x, y, Npx) / 3.141592f; // pi = area of positive hemisphere + col += p.cubemap.fetch3(x, y, s) * w; + } + } + } + + p.out.store(px, py, pz, col); +} + +__global__ void DiffuseCubemapBwdKernel(DiffuseCubemapKernelParams p) +{ + // Calculate pixel position. + int px = blockIdx.x * blockDim.x + threadIdx.x; + int py = blockIdx.y * blockDim.y + threadIdx.y; + int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + int Npx = p.cubemap.dims[1]; + vec3f N = cube_to_dir(px, py, pz, Npx); + vec3f grad = p.out.fetch3(px, py, pz); + + for (int s = 0; s < p.cubemap.dims[0]; ++s) + { + for (int y = 0; y < Npx; ++y) + { + for (int x = 0; x < Npx; ++x) + { + vec3f L = cube_to_dir(x, y, s, Npx); + float costheta = min(max(dot(N, L), 0.0f), 0.999f); + float w = costheta * pixel_area(x, y, Npx) / 3.141592f; // pi = area of positive hemisphere + atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 0), grad.x * w); + atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 1), grad.y * w); + atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 2), grad.z * w); + } + } + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////////////// +// GGX splitsum kernel + +__device__ inline float ndfGGX(const float alphaSqr, const float cosTheta) +{ + float _cosTheta = clamp(cosTheta, 0.0, 1.0f); + float d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1.0f; + return alphaSqr / (d * d * M_PI); +} + +__global__ void SpecularBoundsKernel(SpecularBoundsKernelParams p) +{ + int px = blockIdx.x * blockDim.x + threadIdx.x; + int py = blockIdx.y * blockDim.y + threadIdx.y; + int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + int Npx = p.gridSize.x; + vec3f VNR = cube_to_dir(px, py, pz, Npx); + + const int TILE_SIZE = 16; + + // Brute force entire cubemap and compute bounds for the cone + for (int s = 0; s < p.gridSize.z; ++s) + { + // Assume empty BBox + int _min_x = p.gridSize.x - 1, _max_x = 0; + int _min_y = p.gridSize.y - 1, _max_y = 0; + + // For each (8x8) tile + for (int tx = 0; tx < (p.gridSize.x + TILE_SIZE - 1) / TILE_SIZE; tx++) + { + for (int ty = 0; ty < (p.gridSize.y + TILE_SIZE - 1) / TILE_SIZE; ty++) + { + // Compute tile extents + int tsx = tx * TILE_SIZE, tsy = ty * TILE_SIZE; + int tex = min((tx + 1) * TILE_SIZE, p.gridSize.x), tey = min((ty + 1) * TILE_SIZE, p.gridSize.y); + + // Use some blunt interval arithmetics to cull tiles + vec3f L0 = cube_to_dir(tsx, tsy, s, Npx), L1 = cube_to_dir(tex, tsy, s, Npx); + vec3f L2 = cube_to_dir(tsx, tey, s, Npx), L3 = cube_to_dir(tex, tey, s, Npx); + + float minx = min(min(L0.x, L1.x), min(L2.x, L3.x)), maxx = max(max(L0.x, L1.x), max(L2.x, L3.x)); + float miny = min(min(L0.y, L1.y), min(L2.y, L3.y)), maxy = max(max(L0.y, L1.y), max(L2.y, L3.y)); + float minz = min(min(L0.z, L1.z), min(L2.z, L3.z)), maxz = max(max(L0.z, L1.z), max(L2.z, L3.z)); + + float maxdp = max(minx * VNR.x, maxx * VNR.x) + max(miny * VNR.y, maxy * VNR.y) + max(minz * VNR.z, maxz * VNR.z); + if (maxdp >= p.costheta_cutoff) + { + // Test all pixels in tile. + for (int y = tsy; y < tey; ++y) + { + for (int x = tsx; x < tex; ++x) + { + vec3f L = cube_to_dir(x, y, s, Npx); + if (dot(L, VNR) >= p.costheta_cutoff) + { + _min_x = min(_min_x, x); + _max_x = max(_max_x, x); + _min_y = min(_min_y, y); + _max_y = max(_max_y, y); + } + } + } + } + } + } + p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 0), _min_x); + p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 1), _max_x); + p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 2), _min_y); + p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 3), _max_y); + } +} + +__global__ void SpecularCubemapFwdKernel(SpecularCubemapKernelParams p) +{ + // Calculate pixel position. + int px = blockIdx.x * blockDim.x + threadIdx.x; + int py = blockIdx.y * blockDim.y + threadIdx.y; + int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + int Npx = p.cubemap.dims[1]; + vec3f VNR = cube_to_dir(px, py, pz, Npx); + + float alpha = p.roughness * p.roughness; + float alphaSqr = alpha * alpha; + + float wsum = 0.0f; + vec3f col(0); + for (int s = 0; s < p.cubemap.dims[0]; ++s) + { + int xmin, xmax, ymin, ymax; + xmin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 0)); + xmax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 1)); + ymin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 2)); + ymax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 3)); + + if (xmin <= xmax) + { + for (int y = ymin; y <= ymax; ++y) + { + for (int x = xmin; x <= xmax; ++x) + { + vec3f L = cube_to_dir(x, y, s, Npx); + if (dot(L, VNR) >= p.costheta_cutoff) + { + vec3f H = safeNormalize(L + VNR); + + float wiDotN = max(dot(L, VNR), 0.0f); + float VNRDotH = max(dot(VNR, H), 0.0f); + + float w = wiDotN * ndfGGX(alphaSqr, VNRDotH) * pixel_area(x, y, Npx) / 4.0f; + col += p.cubemap.fetch3(x, y, s) * w; + wsum += w; + } + } + } + } + } + + p.out.store(p.out._nhwcIndex(pz, py, px, 0), col.x); + p.out.store(p.out._nhwcIndex(pz, py, px, 1), col.y); + p.out.store(p.out._nhwcIndex(pz, py, px, 2), col.z); + p.out.store(p.out._nhwcIndex(pz, py, px, 3), wsum); +} + +__global__ void SpecularCubemapBwdKernel(SpecularCubemapKernelParams p) +{ + // Calculate pixel position. + int px = blockIdx.x * blockDim.x + threadIdx.x; + int py = blockIdx.y * blockDim.y + threadIdx.y; + int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + int Npx = p.cubemap.dims[1]; + vec3f VNR = cube_to_dir(px, py, pz, Npx); + + vec3f grad = p.out.fetch3(px, py, pz); + + float alpha = p.roughness * p.roughness; + float alphaSqr = alpha * alpha; + + vec3f col(0); + for (int s = 0; s < p.cubemap.dims[0]; ++s) + { + int xmin, xmax, ymin, ymax; + xmin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 0)); + xmax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 1)); + ymin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 2)); + ymax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 3)); + + if (xmin <= xmax) + { + for (int y = ymin; y <= ymax; ++y) + { + for (int x = xmin; x <= xmax; ++x) + { + vec3f L = cube_to_dir(x, y, s, Npx); + if (dot(L, VNR) >= p.costheta_cutoff) + { + vec3f H = safeNormalize(L + VNR); + + float wiDotN = max(dot(L, VNR), 0.0f); + float VNRDotH = max(dot(VNR, H), 0.0f); + + float w = wiDotN * ndfGGX(alphaSqr, VNRDotH) * pixel_area(x, y, Npx) / 4.0f; + + atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 0), grad.x * w); + atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 1), grad.y * w); + atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 2), grad.z * w); + } + } + } + } + } +} diff --git a/models/lrm/models/geometry/render/renderutils/c_src/cubemap.h b/models/lrm/models/geometry/render/renderutils/c_src/cubemap.h new file mode 100644 index 0000000000000000000000000000000000000000..f395cc237d4a46c660bcde18609068a21f3c3fea --- /dev/null +++ b/models/lrm/models/geometry/render/renderutils/c_src/cubemap.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include "common.h" + +struct DiffuseCubemapKernelParams +{ + Tensor cubemap; + Tensor out; + dim3 gridSize; +}; + +struct SpecularCubemapKernelParams +{ + Tensor cubemap; + Tensor bounds; + Tensor out; + dim3 gridSize; + float costheta_cutoff; + float roughness; +}; + +struct SpecularBoundsKernelParams +{ + float costheta_cutoff; + Tensor out; + dim3 gridSize; +}; diff --git a/models/lrm/models/geometry/render/renderutils/c_src/loss.cu b/models/lrm/models/geometry/render/renderutils/c_src/loss.cu new file mode 100644 index 0000000000000000000000000000000000000000..aae5272de3c5364c22ee0bd5fde023d908e9153d --- /dev/null +++ b/models/lrm/models/geometry/render/renderutils/c_src/loss.cu @@ -0,0 +1,210 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include + +#include "common.h" +#include "loss.h" + +//------------------------------------------------------------------------ +// Utils + +__device__ inline float bwdAbs(float x) { return x == 0.0f ? 0.0f : x < 0.0f ? -1.0f : 1.0f; } + +__device__ float warpSum(float val) { + for (int i = 1; i < 32; i *= 2) + val += __shfl_xor_sync(0xFFFFFFFF, val, i); + return val; +} + +//------------------------------------------------------------------------ +// Tonemapping + +__device__ inline float fwdSRGB(float x) +{ + return x > 0.0031308f ? powf(max(x, 0.0031308f), 1.0f / 2.4f) * 1.055f - 0.055f : 12.92f * max(x, 0.0f); +} + +__device__ inline void bwdSRGB(float x, float &d_x, float d_out) +{ + if (x > 0.0031308f) + d_x += d_out * 0.439583f / powf(x, 0.583333f); + else if (x > 0.0f) + d_x += d_out * 12.92f; +} + +__device__ inline vec3f fwdTonemapLogSRGB(vec3f x) +{ + return vec3f(fwdSRGB(logf(x.x + 1.0f)), fwdSRGB(logf(x.y + 1.0f)), fwdSRGB(logf(x.z + 1.0f))); +} + +__device__ inline void bwdTonemapLogSRGB(vec3f x, vec3f& d_x, vec3f d_out) +{ + if (x.x > 0.0f && x.x < 65535.0f) + { + bwdSRGB(logf(x.x + 1.0f), d_x.x, d_out.x); + d_x.x *= 1 / (x.x + 1.0f); + } + if (x.y > 0.0f && x.y < 65535.0f) + { + bwdSRGB(logf(x.y + 1.0f), d_x.y, d_out.y); + d_x.y *= 1 / (x.y + 1.0f); + } + if (x.z > 0.0f && x.z < 65535.0f) + { + bwdSRGB(logf(x.z + 1.0f), d_x.z, d_out.z); + d_x.z *= 1 / (x.z + 1.0f); + } +} + +__device__ inline float fwdRELMSE(float img, float target, float eps = 0.1f) +{ + return (img - target) * (img - target) / (img * img + target * target + eps); +} + +__device__ inline void bwdRELMSE(float img, float target, float &d_img, float &d_target, float d_out, float eps = 0.1f) +{ + float denom = (target * target + img * img + eps); + d_img += d_out * 2 * (img - target) * (target * (target + img) + eps) / (denom * denom); + d_target -= d_out * 2 * (img - target) * (img * (target + img) + eps) / (denom * denom); +} + +__device__ inline float fwdSMAPE(float img, float target, float eps=0.01f) +{ + return abs(img - target) / (img + target + eps); +} + +__device__ inline void bwdSMAPE(float img, float target, float& d_img, float& d_target, float d_out, float eps = 0.01f) +{ + float denom = (target + img + eps); + d_img += d_out * bwdAbs(img - target) * (2 * target + eps) / (denom * denom); + d_target -= d_out * bwdAbs(img - target) * (2 * img + eps) / (denom * denom); +} + +//------------------------------------------------------------------------ +// Kernels + +__global__ void imgLossFwdKernel(LossKernelParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + + float floss = 0.0f; + if (px < p.gridSize.x && py < p.gridSize.y && pz < p.gridSize.z) + { + vec3f img = p.img.fetch3(px, py, pz); + vec3f target = p.target.fetch3(px, py, pz); + + img = vec3f(clamp(img.x, 0.0f, 65535.0f), clamp(img.y, 0.0f, 65535.0f), clamp(img.z, 0.0f, 65535.0f)); + target = vec3f(clamp(target.x, 0.0f, 65535.0f), clamp(target.y, 0.0f, 65535.0f), clamp(target.z, 0.0f, 65535.0f)); + + if (p.tonemapper == TONEMAPPER_LOG_SRGB) + { + img = fwdTonemapLogSRGB(img); + target = fwdTonemapLogSRGB(target); + } + + vec3f vloss(0); + if (p.loss == LOSS_MSE) + vloss = (img - target) * (img - target); + else if (p.loss == LOSS_RELMSE) + vloss = vec3f(fwdRELMSE(img.x, target.x), fwdRELMSE(img.y, target.y), fwdRELMSE(img.z, target.z)); + else if (p.loss == LOSS_SMAPE) + vloss = vec3f(fwdSMAPE(img.x, target.x), fwdSMAPE(img.y, target.y), fwdSMAPE(img.z, target.z)); + else + vloss = vec3f(abs(img.x - target.x), abs(img.y - target.y), abs(img.z - target.z)); + + floss = sum(vloss) / 3.0f; + } + + floss = warpSum(floss); + + dim3 warpSize = getWarpSize(blockDim); + if (px < p.gridSize.x && py < p.gridSize.y && pz < p.gridSize.z && threadIdx.x % warpSize.x == 0 && threadIdx.y % warpSize.y == 0 && threadIdx.z % warpSize.z == 0) + p.out.store(px / warpSize.x, py / warpSize.y, pz / warpSize.z, floss); +} + +__global__ void imgLossBwdKernel(LossKernelParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + dim3 warpSize = getWarpSize(blockDim); + + vec3f _img = p.img.fetch3(px, py, pz); + vec3f _target = p.target.fetch3(px, py, pz); + float d_out = p.out.fetch1(px / warpSize.x, py / warpSize.y, pz / warpSize.z); + + ///////////////////////////////////////////////////////////////////// + // FWD + + vec3f img = _img, target = _target; + if (p.tonemapper == TONEMAPPER_LOG_SRGB) + { + img = fwdTonemapLogSRGB(img); + target = fwdTonemapLogSRGB(target); + } + + ///////////////////////////////////////////////////////////////////// + // BWD + + vec3f d_vloss = vec3f(d_out, d_out, d_out) / 3.0f; + + vec3f d_img(0), d_target(0); + if (p.loss == LOSS_MSE) + { + d_img = vec3f(d_vloss.x * 2 * (img.x - target.x), d_vloss.y * 2 * (img.y - target.y), d_vloss.x * 2 * (img.z - target.z)); + d_target = -d_img; + } + else if (p.loss == LOSS_RELMSE) + { + bwdRELMSE(img.x, target.x, d_img.x, d_target.x, d_vloss.x); + bwdRELMSE(img.y, target.y, d_img.y, d_target.y, d_vloss.y); + bwdRELMSE(img.z, target.z, d_img.z, d_target.z, d_vloss.z); + } + else if (p.loss == LOSS_SMAPE) + { + bwdSMAPE(img.x, target.x, d_img.x, d_target.x, d_vloss.x); + bwdSMAPE(img.y, target.y, d_img.y, d_target.y, d_vloss.y); + bwdSMAPE(img.z, target.z, d_img.z, d_target.z, d_vloss.z); + } + else + { + d_img = d_vloss * vec3f(bwdAbs(img.x - target.x), bwdAbs(img.y - target.y), bwdAbs(img.z - target.z)); + d_target = -d_img; + } + + + if (p.tonemapper == TONEMAPPER_LOG_SRGB) + { + vec3f d__img(0), d__target(0); + bwdTonemapLogSRGB(_img, d__img, d_img); + bwdTonemapLogSRGB(_target, d__target, d_target); + d_img = d__img; d_target = d__target; + } + + if (_img.x <= 0.0f || _img.x >= 65535.0f) d_img.x = 0; + if (_img.y <= 0.0f || _img.y >= 65535.0f) d_img.y = 0; + if (_img.z <= 0.0f || _img.z >= 65535.0f) d_img.z = 0; + if (_target.x <= 0.0f || _target.x >= 65535.0f) d_target.x = 0; + if (_target.y <= 0.0f || _target.y >= 65535.0f) d_target.y = 0; + if (_target.z <= 0.0f || _target.z >= 65535.0f) d_target.z = 0; + + p.img.store_grad(px, py, pz, d_img); + p.target.store_grad(px, py, pz, d_target); +} \ No newline at end of file diff --git a/models/lrm/models/geometry/render/renderutils/c_src/loss.h b/models/lrm/models/geometry/render/renderutils/c_src/loss.h new file mode 100644 index 0000000000000000000000000000000000000000..26790bf02de2afd9d27e541edf23d1b064f6f9a9 --- /dev/null +++ b/models/lrm/models/geometry/render/renderutils/c_src/loss.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include "common.h" + +enum TonemapperType +{ + TONEMAPPER_NONE = 0, + TONEMAPPER_LOG_SRGB = 1 +}; + +enum LossType +{ + LOSS_L1 = 0, + LOSS_MSE = 1, + LOSS_RELMSE = 2, + LOSS_SMAPE = 3 +}; + +struct LossKernelParams +{ + Tensor img; + Tensor target; + Tensor out; + dim3 gridSize; + TonemapperType tonemapper; + LossType loss; +}; diff --git a/models/lrm/models/geometry/render/renderutils/c_src/mesh.cu b/models/lrm/models/geometry/render/renderutils/c_src/mesh.cu new file mode 100644 index 0000000000000000000000000000000000000000..3690ea3621c38beae03ac9ff228cf5605d303663 --- /dev/null +++ b/models/lrm/models/geometry/render/renderutils/c_src/mesh.cu @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include +#include + +#include "common.h" +#include "mesh.h" + + +//------------------------------------------------------------------------ +// Kernels + +__global__ void xfmPointsFwdKernel(XfmKernelParams p) +{ + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int pz = blockIdx.z * blockDim.z + threadIdx.z; + + __shared__ float mtx[4][4]; + if (threadIdx.x < 16) + mtx[threadIdx.x % 4][threadIdx.x / 4] = p.matrix.fetch(p.matrix.nhwcIndex(pz, threadIdx.x / 4, threadIdx.x % 4, 0)); + __syncthreads(); + + if (px >= p.gridSize.x) + return; + + vec3f pos( + p.points.fetch(p.points.nhwcIndex(pz, px, 0, 0)), + p.points.fetch(p.points.nhwcIndex(pz, px, 1, 0)), + p.points.fetch(p.points.nhwcIndex(pz, px, 2, 0)) + ); + + if (p.isPoints) + { + p.out.store(p.out.nhwcIndex(pz, px, 0, 0), pos.x * mtx[0][0] + pos.y * mtx[1][0] + pos.z * mtx[2][0] + mtx[3][0]); + p.out.store(p.out.nhwcIndex(pz, px, 1, 0), pos.x * mtx[0][1] + pos.y * mtx[1][1] + pos.z * mtx[2][1] + mtx[3][1]); + p.out.store(p.out.nhwcIndex(pz, px, 2, 0), pos.x * mtx[0][2] + pos.y * mtx[1][2] + pos.z * mtx[2][2] + mtx[3][2]); + p.out.store(p.out.nhwcIndex(pz, px, 3, 0), pos.x * mtx[0][3] + pos.y * mtx[1][3] + pos.z * mtx[2][3] + mtx[3][3]); + } + else + { + p.out.store(p.out.nhwcIndex(pz, px, 0, 0), pos.x * mtx[0][0] + pos.y * mtx[1][0] + pos.z * mtx[2][0]); + p.out.store(p.out.nhwcIndex(pz, px, 1, 0), pos.x * mtx[0][1] + pos.y * mtx[1][1] + pos.z * mtx[2][1]); + p.out.store(p.out.nhwcIndex(pz, px, 2, 0), pos.x * mtx[0][2] + pos.y * mtx[1][2] + pos.z * mtx[2][2]); + } +} + +__global__ void xfmPointsBwdKernel(XfmKernelParams p) +{ + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int pz = blockIdx.z * blockDim.z + threadIdx.z; + + __shared__ float mtx[4][4]; + if (threadIdx.x < 16) + mtx[threadIdx.x % 4][threadIdx.x / 4] = p.matrix.fetch(p.matrix.nhwcIndex(pz, threadIdx.x / 4, threadIdx.x % 4, 0)); + __syncthreads(); + + if (px >= p.gridSize.x) + return; + + vec3f pos( + p.points.fetch(p.points.nhwcIndex(pz, px, 0, 0)), + p.points.fetch(p.points.nhwcIndex(pz, px, 1, 0)), + p.points.fetch(p.points.nhwcIndex(pz, px, 2, 0)) + ); + + vec4f d_out( + p.out.fetch(p.out.nhwcIndex(pz, px, 0, 0)), + p.out.fetch(p.out.nhwcIndex(pz, px, 1, 0)), + p.out.fetch(p.out.nhwcIndex(pz, px, 2, 0)), + p.out.fetch(p.out.nhwcIndex(pz, px, 3, 0)) + ); + + if (p.isPoints) + { + p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 0, 0), d_out.x * mtx[0][0] + d_out.y * mtx[0][1] + d_out.z * mtx[0][2] + d_out.w * mtx[0][3]); + p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 1, 0), d_out.x * mtx[1][0] + d_out.y * mtx[1][1] + d_out.z * mtx[1][2] + d_out.w * mtx[1][3]); + p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 2, 0), d_out.x * mtx[2][0] + d_out.y * mtx[2][1] + d_out.z * mtx[2][2] + d_out.w * mtx[2][3]); + } + else + { + p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 0, 0), d_out.x * mtx[0][0] + d_out.y * mtx[0][1] + d_out.z * mtx[0][2]); + p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 1, 0), d_out.x * mtx[1][0] + d_out.y * mtx[1][1] + d_out.z * mtx[1][2]); + p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 2, 0), d_out.x * mtx[2][0] + d_out.y * mtx[2][1] + d_out.z * mtx[2][2]); + } +} \ No newline at end of file diff --git a/models/lrm/models/geometry/render/renderutils/c_src/mesh.h b/models/lrm/models/geometry/render/renderutils/c_src/mesh.h new file mode 100644 index 0000000000000000000000000000000000000000..16e2166cc55f41c4482b2c5010529e9c75182d7b --- /dev/null +++ b/models/lrm/models/geometry/render/renderutils/c_src/mesh.h @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include "common.h" + +struct XfmKernelParams +{ + bool isPoints; + Tensor points; + Tensor matrix; + Tensor out; + dim3 gridSize; +}; diff --git a/models/lrm/models/geometry/render/renderutils/c_src/normal.cu b/models/lrm/models/geometry/render/renderutils/c_src/normal.cu new file mode 100644 index 0000000000000000000000000000000000000000..a50e49e6b5b4061a60ec4d5d8edca2fb0833570e --- /dev/null +++ b/models/lrm/models/geometry/render/renderutils/c_src/normal.cu @@ -0,0 +1,182 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include "common.h" +#include "normal.h" + +#define NORMAL_THRESHOLD 0.1f + +//------------------------------------------------------------------------ +// Perturb shading normal by tangent frame + +__device__ vec3f fwdPerturbNormal(const vec3f perturbed_nrm, const vec3f smooth_nrm, const vec3f smooth_tng, bool opengl) +{ + vec3f _smooth_bitng = cross(smooth_tng, smooth_nrm); + vec3f smooth_bitng = safeNormalize(_smooth_bitng); + vec3f _shading_nrm = smooth_tng * perturbed_nrm.x + (opengl ? -1 : 1) * smooth_bitng * perturbed_nrm.y + smooth_nrm * max(perturbed_nrm.z, 0.0f); + return safeNormalize(_shading_nrm); +} + +__device__ void bwdPerturbNormal(const vec3f perturbed_nrm, const vec3f smooth_nrm, const vec3f smooth_tng, vec3f &d_perturbed_nrm, vec3f &d_smooth_nrm, vec3f &d_smooth_tng, const vec3f d_out, bool opengl) +{ + //////////////////////////////////////////////////////////////////////// + // FWD + vec3f _smooth_bitng = cross(smooth_tng, smooth_nrm); + vec3f smooth_bitng = safeNormalize(_smooth_bitng); + vec3f _shading_nrm = smooth_tng * perturbed_nrm.x + (opengl ? -1 : 1) * smooth_bitng * perturbed_nrm.y + smooth_nrm * max(perturbed_nrm.z, 0.0f); + + //////////////////////////////////////////////////////////////////////// + // BWD + vec3f d_shading_nrm(0); + bwdSafeNormalize(_shading_nrm, d_shading_nrm, d_out); + + vec3f d_smooth_bitng(0); + + if (perturbed_nrm.z > 0.0f) + { + d_smooth_nrm += d_shading_nrm * perturbed_nrm.z; + d_perturbed_nrm.z += sum(d_shading_nrm * smooth_nrm); + } + + d_smooth_bitng += (opengl ? -1 : 1) * d_shading_nrm * perturbed_nrm.y; + d_perturbed_nrm.y += (opengl ? -1 : 1) * sum(d_shading_nrm * smooth_bitng); + + d_smooth_tng += d_shading_nrm * perturbed_nrm.x; + d_perturbed_nrm.x += sum(d_shading_nrm * smooth_tng); + + vec3f d__smooth_bitng(0); + bwdSafeNormalize(_smooth_bitng, d__smooth_bitng, d_smooth_bitng); + + bwdCross(smooth_tng, smooth_nrm, d_smooth_tng, d_smooth_nrm, d__smooth_bitng); +} + +//------------------------------------------------------------------------ +#define bent_nrm_eps 0.001f + +__device__ vec3f fwdBendNormal(const vec3f view_vec, const vec3f smooth_nrm, const vec3f geom_nrm) +{ + float dp = dot(view_vec, smooth_nrm); + float t = clamp(dp / NORMAL_THRESHOLD, 0.0f, 1.0f); + return geom_nrm * (1.0f - t) + smooth_nrm * t; +} + +__device__ void bwdBendNormal(const vec3f view_vec, const vec3f smooth_nrm, const vec3f geom_nrm, vec3f& d_view_vec, vec3f& d_smooth_nrm, vec3f& d_geom_nrm, const vec3f d_out) +{ + //////////////////////////////////////////////////////////////////////// + // FWD + float dp = dot(view_vec, smooth_nrm); + float t = clamp(dp / NORMAL_THRESHOLD, 0.0f, 1.0f); + + //////////////////////////////////////////////////////////////////////// + // BWD + if (dp > NORMAL_THRESHOLD) + d_smooth_nrm += d_out; + else + { + // geom_nrm * (1.0f - t) + smooth_nrm * t; + d_geom_nrm += d_out * (1.0f - t); + d_smooth_nrm += d_out * t; + float d_t = sum(d_out * (smooth_nrm - geom_nrm)); + + float d_dp = dp < 0.0f || dp > NORMAL_THRESHOLD ? 0.0f : d_t / NORMAL_THRESHOLD; + + bwdDot(view_vec, smooth_nrm, d_view_vec, d_smooth_nrm, d_dp); + } +} + +//------------------------------------------------------------------------ +// Kernels + +__global__ void PrepareShadingNormalFwdKernel(PrepareShadingNormalKernelParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f pos = p.pos.fetch3(px, py, pz); + vec3f view_pos = p.view_pos.fetch3(px, py, pz); + vec3f perturbed_nrm = p.perturbed_nrm.fetch3(px, py, pz); + vec3f _smooth_nrm = p.smooth_nrm.fetch3(px, py, pz); + vec3f _smooth_tng = p.smooth_tng.fetch3(px, py, pz); + vec3f geom_nrm = p.geom_nrm.fetch3(px, py, pz); + + vec3f smooth_nrm = safeNormalize(_smooth_nrm); + vec3f smooth_tng = safeNormalize(_smooth_tng); + vec3f view_vec = safeNormalize(view_pos - pos); + vec3f shading_nrm = fwdPerturbNormal(perturbed_nrm, smooth_nrm, smooth_tng, p.opengl); + + vec3f res; + if (p.two_sided_shading && dot(view_vec, geom_nrm) < 0.0f) + res = fwdBendNormal(view_vec, -shading_nrm, -geom_nrm); + else + res = fwdBendNormal(view_vec, shading_nrm, geom_nrm); + + p.out.store(px, py, pz, res); +} + +__global__ void PrepareShadingNormalBwdKernel(PrepareShadingNormalKernelParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f pos = p.pos.fetch3(px, py, pz); + vec3f view_pos = p.view_pos.fetch3(px, py, pz); + vec3f perturbed_nrm = p.perturbed_nrm.fetch3(px, py, pz); + vec3f _smooth_nrm = p.smooth_nrm.fetch3(px, py, pz); + vec3f _smooth_tng = p.smooth_tng.fetch3(px, py, pz); + vec3f geom_nrm = p.geom_nrm.fetch3(px, py, pz); + vec3f d_out = p.out.fetch3(px, py, pz); + + /////////////////////////////////////////////////////////////////////////////////////////////////// + // FWD + + vec3f smooth_nrm = safeNormalize(_smooth_nrm); + vec3f smooth_tng = safeNormalize(_smooth_tng); + vec3f _view_vec = view_pos - pos; + vec3f view_vec = safeNormalize(view_pos - pos); + + vec3f shading_nrm = fwdPerturbNormal(perturbed_nrm, smooth_nrm, smooth_tng, p.opengl); + + /////////////////////////////////////////////////////////////////////////////////////////////////// + // BWD + + vec3f d_view_vec(0), d_shading_nrm(0), d_geom_nrm(0); + if (p.two_sided_shading && dot(view_vec, geom_nrm) < 0.0f) + { + bwdBendNormal(view_vec, -shading_nrm, -geom_nrm, d_view_vec, d_shading_nrm, d_geom_nrm, d_out); + d_shading_nrm = -d_shading_nrm; + d_geom_nrm = -d_geom_nrm; + } + else + bwdBendNormal(view_vec, shading_nrm, geom_nrm, d_view_vec, d_shading_nrm, d_geom_nrm, d_out); + + vec3f d_perturbed_nrm(0), d_smooth_nrm(0), d_smooth_tng(0); + bwdPerturbNormal(perturbed_nrm, smooth_nrm, smooth_tng, d_perturbed_nrm, d_smooth_nrm, d_smooth_tng, d_shading_nrm, p.opengl); + + vec3f d__view_vec(0), d__smooth_nrm(0), d__smooth_tng(0); + bwdSafeNormalize(_view_vec, d__view_vec, d_view_vec); + bwdSafeNormalize(_smooth_nrm, d__smooth_nrm, d_smooth_nrm); + bwdSafeNormalize(_smooth_tng, d__smooth_tng, d_smooth_tng); + + p.pos.store_grad(px, py, pz, -d__view_vec); + p.view_pos.store_grad(px, py, pz, d__view_vec); + p.perturbed_nrm.store_grad(px, py, pz, d_perturbed_nrm); + p.smooth_nrm.store_grad(px, py, pz, d__smooth_nrm); + p.smooth_tng.store_grad(px, py, pz, d__smooth_tng); + p.geom_nrm.store_grad(px, py, pz, d_geom_nrm); +} \ No newline at end of file diff --git a/models/lrm/models/geometry/render/renderutils/c_src/normal.h b/models/lrm/models/geometry/render/renderutils/c_src/normal.h new file mode 100644 index 0000000000000000000000000000000000000000..8882c225cfba5e747462c056d6fcf0b04dd48751 --- /dev/null +++ b/models/lrm/models/geometry/render/renderutils/c_src/normal.h @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include "common.h" + +struct PrepareShadingNormalKernelParams +{ + Tensor pos; + Tensor view_pos; + Tensor perturbed_nrm; + Tensor smooth_nrm; + Tensor smooth_tng; + Tensor geom_nrm; + Tensor out; + dim3 gridSize; + bool two_sided_shading, opengl; +}; diff --git a/models/lrm/models/geometry/render/renderutils/c_src/tensor.h b/models/lrm/models/geometry/render/renderutils/c_src/tensor.h new file mode 100644 index 0000000000000000000000000000000000000000..1dfb4e85c46f0394821f2533dc98468e5b7248af --- /dev/null +++ b/models/lrm/models/geometry/render/renderutils/c_src/tensor.h @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once +#if defined(__CUDACC__) && defined(BFLOAT16) +#include // bfloat16 is float32 compatible with less mantissa bits +#endif + +//--------------------------------------------------------------------------------- +// CUDA-side Tensor class for in/out parameter parsing. Can be float32 or bfloat16 + +struct Tensor +{ + void* val; + void* d_val; + int dims[4], _dims[4]; + int strides[4]; + bool fp16; + +#if defined(__CUDA__) && !defined(__CUDA_ARCH__) + Tensor() : val(nullptr), d_val(nullptr), fp16(true), dims{ 0, 0, 0, 0 }, _dims{ 0, 0, 0, 0 }, strides{ 0, 0, 0, 0 } {} +#endif + +#ifdef __CUDACC__ + // Helpers to index and read/write a single element + __device__ inline int _nhwcIndex(int n, int h, int w, int c) const { return n * strides[0] + h * strides[1] + w * strides[2] + c * strides[3]; } + __device__ inline int nhwcIndex(int n, int h, int w, int c) const { return (dims[0] == 1 ? 0 : n * strides[0]) + (dims[1] == 1 ? 0 : h * strides[1]) + (dims[2] == 1 ? 0 : w * strides[2]) + (dims[3] == 1 ? 0 : c * strides[3]); } + __device__ inline int nhwcIndexContinuous(int n, int h, int w, int c) const { return ((n * _dims[1] + h) * _dims[2] + w) * _dims[3] + c; } +#ifdef BFLOAT16 + __device__ inline float fetch(unsigned int idx) const { return fp16 ? __bfloat162float(((__nv_bfloat16*)val)[idx]) : ((float*)val)[idx]; } + __device__ inline void store(unsigned int idx, float _val) { if (fp16) ((__nv_bfloat16*)val)[idx] = __float2bfloat16(_val); else ((float*)val)[idx] = _val; } + __device__ inline void store_grad(unsigned int idx, float _val) { if (fp16) ((__nv_bfloat16*)d_val)[idx] = __float2bfloat16(_val); else ((float*)d_val)[idx] = _val; } +#else + __device__ inline float fetch(unsigned int idx) const { return ((float*)val)[idx]; } + __device__ inline void store(unsigned int idx, float _val) { ((float*)val)[idx] = _val; } + __device__ inline void store_grad(unsigned int idx, float _val) { ((float*)d_val)[idx] = _val; } +#endif + + ////////////////////////////////////////////////////////////////////////////////////////// + // Fetch, use broadcasting for tensor dimensions of size 1 + __device__ inline float fetch1(unsigned int x, unsigned int y, unsigned int z) const + { + return fetch(nhwcIndex(z, y, x, 0)); + } + + __device__ inline vec3f fetch3(unsigned int x, unsigned int y, unsigned int z) const + { + return vec3f( + fetch(nhwcIndex(z, y, x, 0)), + fetch(nhwcIndex(z, y, x, 1)), + fetch(nhwcIndex(z, y, x, 2)) + ); + } + + ///////////////////////////////////////////////////////////////////////////////////////////////////////////// + // Store, no broadcasting here. Assume we output full res gradient and then reduce using torch.sum outside + __device__ inline void store(unsigned int x, unsigned int y, unsigned int z, float _val) + { + store(_nhwcIndex(z, y, x, 0), _val); + } + + __device__ inline void store(unsigned int x, unsigned int y, unsigned int z, vec3f _val) + { + store(_nhwcIndex(z, y, x, 0), _val.x); + store(_nhwcIndex(z, y, x, 1), _val.y); + store(_nhwcIndex(z, y, x, 2), _val.z); + } + + ///////////////////////////////////////////////////////////////////////////////////////////////////////////// + // Store gradient , no broadcasting here. Assume we output full res gradient and then reduce using torch.sum outside + __device__ inline void store_grad(unsigned int x, unsigned int y, unsigned int z, float _val) + { + store_grad(nhwcIndexContinuous(z, y, x, 0), _val); + } + + __device__ inline void store_grad(unsigned int x, unsigned int y, unsigned int z, vec3f _val) + { + store_grad(nhwcIndexContinuous(z, y, x, 0), _val.x); + store_grad(nhwcIndexContinuous(z, y, x, 1), _val.y); + store_grad(nhwcIndexContinuous(z, y, x, 2), _val.z); + } +#endif + +}; diff --git a/models/lrm/models/geometry/render/renderutils/c_src/torch_bindings.cpp b/models/lrm/models/geometry/render/renderutils/c_src/torch_bindings.cpp new file mode 100644 index 0000000000000000000000000000000000000000..64c9e70f79507944490cb978233c34ac9e3e97a6 --- /dev/null +++ b/models/lrm/models/geometry/render/renderutils/c_src/torch_bindings.cpp @@ -0,0 +1,1062 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#ifdef _MSC_VER +#pragma warning(push, 0) +#include +#pragma warning(pop) +#else +#include +#endif + +#include +#include +#include +#include + +#define NVDR_CHECK_CUDA_ERROR(CUDA_CALL) { cudaError_t err = CUDA_CALL; AT_CUDA_CHECK(cudaGetLastError()); } +#define NVDR_CHECK_GL_ERROR(GL_CALL) { GL_CALL; GLenum err = glGetError(); TORCH_CHECK(err == GL_NO_ERROR, "OpenGL error: ", getGLErrorString(err), "[", #GL_CALL, ";]"); } +#define CHECK_TENSOR(X, DIMS, CHANNELS) \ + TORCH_CHECK(X.is_cuda(), #X " must be a cuda tensor") \ + TORCH_CHECK(X.scalar_type() == torch::kFloat || X.scalar_type() == torch::kBFloat16, #X " must be fp32 or bf16") \ + TORCH_CHECK(X.dim() == DIMS, #X " must have " #DIMS " dimensions") \ + TORCH_CHECK(X.size(DIMS - 1) == CHANNELS, #X " must have " #CHANNELS " channels") + +#include "common.h" +#include "loss.h" +#include "normal.h" +#include "cubemap.h" +#include "bsdf.h" +#include "mesh.h" + +#define BLOCK_X 8 +#define BLOCK_Y 8 + +//------------------------------------------------------------------------ +// mesh.cu + +void xfmPointsFwdKernel(XfmKernelParams p); +void xfmPointsBwdKernel(XfmKernelParams p); + +//------------------------------------------------------------------------ +// loss.cu + +void imgLossFwdKernel(LossKernelParams p); +void imgLossBwdKernel(LossKernelParams p); + +//------------------------------------------------------------------------ +// normal.cu + +void PrepareShadingNormalFwdKernel(PrepareShadingNormalKernelParams p); +void PrepareShadingNormalBwdKernel(PrepareShadingNormalKernelParams p); + +//------------------------------------------------------------------------ +// cubemap.cu + +void DiffuseCubemapFwdKernel(DiffuseCubemapKernelParams p); +void DiffuseCubemapBwdKernel(DiffuseCubemapKernelParams p); +void SpecularBoundsKernel(SpecularBoundsKernelParams p); +void SpecularCubemapFwdKernel(SpecularCubemapKernelParams p); +void SpecularCubemapBwdKernel(SpecularCubemapKernelParams p); + +//------------------------------------------------------------------------ +// bsdf.cu + +void LambertFwdKernel(LambertKernelParams p); +void LambertBwdKernel(LambertKernelParams p); + +void FrostbiteDiffuseFwdKernel(FrostbiteDiffuseKernelParams p); +void FrostbiteDiffuseBwdKernel(FrostbiteDiffuseKernelParams p); + +void FresnelShlickFwdKernel(FresnelShlickKernelParams p); +void FresnelShlickBwdKernel(FresnelShlickKernelParams p); + +void ndfGGXFwdKernel(NdfGGXParams p); +void ndfGGXBwdKernel(NdfGGXParams p); + +void lambdaGGXFwdKernel(NdfGGXParams p); +void lambdaGGXBwdKernel(NdfGGXParams p); + +void maskingSmithFwdKernel(MaskingSmithParams p); +void maskingSmithBwdKernel(MaskingSmithParams p); + +void pbrSpecularFwdKernel(PbrSpecular p); +void pbrSpecularBwdKernel(PbrSpecular p); + +void pbrBSDFFwdKernel(PbrBSDF p); +void pbrBSDFBwdKernel(PbrBSDF p); + +//------------------------------------------------------------------------ +// Tensor helpers + +void update_grid(dim3 &gridSize, torch::Tensor x) +{ + gridSize.x = std::max(gridSize.x, (uint32_t)x.size(2)); + gridSize.y = std::max(gridSize.y, (uint32_t)x.size(1)); + gridSize.z = std::max(gridSize.z, (uint32_t)x.size(0)); +} + +template +void update_grid(dim3& gridSize, torch::Tensor x, Ts&&... vs) +{ + gridSize.x = std::max(gridSize.x, (uint32_t)x.size(2)); + gridSize.y = std::max(gridSize.y, (uint32_t)x.size(1)); + gridSize.z = std::max(gridSize.z, (uint32_t)x.size(0)); + update_grid(gridSize, std::forward(vs)...); +} + +Tensor make_cuda_tensor(torch::Tensor val) +{ + Tensor res; + for (int i = 0; i < val.dim(); ++i) + { + res.dims[i] = val.size(i); + res.strides[i] = val.stride(i); + } + res.fp16 = val.scalar_type() == torch::kBFloat16; + res.val = res.fp16 ? (void*)val.data_ptr() : (void*)val.data_ptr(); + res.d_val = nullptr; + return res; +} + +Tensor make_cuda_tensor(torch::Tensor val, dim3 outDims, torch::Tensor* grad = nullptr) +{ + Tensor res; + for (int i = 0; i < val.dim(); ++i) + { + res.dims[i] = val.size(i); + res.strides[i] = val.stride(i); + } + if (val.dim() == 4) + res._dims[0] = outDims.z, res._dims[1] = outDims.y, res._dims[2] = outDims.x, res._dims[3] = val.size(3); + else + res._dims[0] = outDims.z, res._dims[1] = outDims.x, res._dims[2] = val.size(2), res._dims[3] = 1; // Add a trailing one for indexing math to work out + + res.fp16 = val.scalar_type() == torch::kBFloat16; + res.val = res.fp16 ? (void*)val.data_ptr() : (void*)val.data_ptr(); + res.d_val = nullptr; + if (grad != nullptr) + { + if (val.dim() == 4) + *grad = torch::empty({ outDims.z, outDims.y, outDims.x, val.size(3) }, torch::TensorOptions().dtype(res.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA)); + else // 3 + *grad = torch::empty({ outDims.z, outDims.x, val.size(2) }, torch::TensorOptions().dtype(res.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA)); + + res.d_val = res.fp16 ? (void*)grad->data_ptr() : (void*)grad->data_ptr(); + } + return res; +} + +//------------------------------------------------------------------------ +// prepare_shading_normal + +torch::Tensor prepare_shading_normal_fwd(torch::Tensor pos, torch::Tensor view_pos, torch::Tensor perturbed_nrm, torch::Tensor smooth_nrm, torch::Tensor smooth_tng, torch::Tensor geom_nrm, bool two_sided_shading, bool opengl, bool fp16) +{ + CHECK_TENSOR(pos, 4, 3); + CHECK_TENSOR(view_pos, 4, 3); + CHECK_TENSOR(perturbed_nrm, 4, 3); + CHECK_TENSOR(smooth_nrm, 4, 3); + CHECK_TENSOR(smooth_tng, 4, 3); + CHECK_TENSOR(geom_nrm, 4, 3); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + PrepareShadingNormalKernelParams p; + p.two_sided_shading = two_sided_shading; + p.opengl = opengl; + p.out.fp16 = fp16; + update_grid(p.gridSize, pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + // Setup tensors + p.pos = make_cuda_tensor(pos, p.gridSize); + p.view_pos = make_cuda_tensor(view_pos, p.gridSize); + p.perturbed_nrm = make_cuda_tensor(perturbed_nrm, p.gridSize); + p.smooth_nrm = make_cuda_tensor(smooth_nrm, p.gridSize); + p.smooth_tng = make_cuda_tensor(smooth_tng, p.gridSize); + p.geom_nrm = make_cuda_tensor(geom_nrm, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)PrepareShadingNormalFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple prepare_shading_normal_bwd(torch::Tensor pos, torch::Tensor view_pos, torch::Tensor perturbed_nrm, torch::Tensor smooth_nrm, torch::Tensor smooth_tng, torch::Tensor geom_nrm, torch::Tensor grad, bool two_sided_shading, bool opengl) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + PrepareShadingNormalKernelParams p; + p.two_sided_shading = two_sided_shading; + p.opengl = opengl; + update_grid(p.gridSize, pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + // Setup tensors + torch::Tensor pos_grad, view_pos_grad, perturbed_nrm_grad, smooth_nrm_grad, smooth_tng_grad, geom_nrm_grad; + p.pos = make_cuda_tensor(pos, p.gridSize, &pos_grad); + p.view_pos = make_cuda_tensor(view_pos, p.gridSize, &view_pos_grad); + p.perturbed_nrm = make_cuda_tensor(perturbed_nrm, p.gridSize, &perturbed_nrm_grad); + p.smooth_nrm = make_cuda_tensor(smooth_nrm, p.gridSize, &smooth_nrm_grad); + p.smooth_tng = make_cuda_tensor(smooth_tng, p.gridSize, &smooth_tng_grad); + p.geom_nrm = make_cuda_tensor(geom_nrm, p.gridSize, &geom_nrm_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)PrepareShadingNormalBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(pos_grad, view_pos_grad, perturbed_nrm_grad, smooth_nrm_grad, smooth_tng_grad, geom_nrm_grad); +} + +//------------------------------------------------------------------------ +// lambert + +torch::Tensor lambert_fwd(torch::Tensor nrm, torch::Tensor wi, bool fp16) +{ + CHECK_TENSOR(nrm, 4, 3); + CHECK_TENSOR(wi, 4, 3); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + LambertKernelParams p; + p.out.fp16 = fp16; + update_grid(p.gridSize, nrm, wi); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + p.nrm = make_cuda_tensor(nrm, p.gridSize); + p.wi = make_cuda_tensor(wi, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)LambertFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple lambert_bwd(torch::Tensor nrm, torch::Tensor wi, torch::Tensor grad) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + LambertKernelParams p; + update_grid(p.gridSize, nrm, wi); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor nrm_grad, wi_grad; + p.nrm = make_cuda_tensor(nrm, p.gridSize, &nrm_grad); + p.wi = make_cuda_tensor(wi, p.gridSize, &wi_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)LambertBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(nrm_grad, wi_grad); +} + +//------------------------------------------------------------------------ +// frostbite diffuse + +torch::Tensor frostbite_fwd(torch::Tensor nrm, torch::Tensor wi, torch::Tensor wo, torch::Tensor linearRoughness, bool fp16) +{ + CHECK_TENSOR(nrm, 4, 3); + CHECK_TENSOR(wi, 4, 3); + CHECK_TENSOR(wo, 4, 3); + CHECK_TENSOR(linearRoughness, 4, 1); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + FrostbiteDiffuseKernelParams p; + p.out.fp16 = fp16; + update_grid(p.gridSize, nrm, wi, wo, linearRoughness); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + p.nrm = make_cuda_tensor(nrm, p.gridSize); + p.wi = make_cuda_tensor(wi, p.gridSize); + p.wo = make_cuda_tensor(wo, p.gridSize); + p.linearRoughness = make_cuda_tensor(linearRoughness, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)FrostbiteDiffuseFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple frostbite_bwd(torch::Tensor nrm, torch::Tensor wi, torch::Tensor wo, torch::Tensor linearRoughness, torch::Tensor grad) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + FrostbiteDiffuseKernelParams p; + update_grid(p.gridSize, nrm, wi, wo, linearRoughness); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor nrm_grad, wi_grad, wo_grad, linearRoughness_grad; + p.nrm = make_cuda_tensor(nrm, p.gridSize, &nrm_grad); + p.wi = make_cuda_tensor(wi, p.gridSize, &wi_grad); + p.wo = make_cuda_tensor(wo, p.gridSize, &wo_grad); + p.linearRoughness = make_cuda_tensor(linearRoughness, p.gridSize, &linearRoughness_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)FrostbiteDiffuseBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(nrm_grad, wi_grad, wo_grad, linearRoughness_grad); +} + +//------------------------------------------------------------------------ +// fresnel_shlick + +torch::Tensor fresnel_shlick_fwd(torch::Tensor f0, torch::Tensor f90, torch::Tensor cosTheta, bool fp16) +{ + CHECK_TENSOR(f0, 4, 3); + CHECK_TENSOR(f90, 4, 3); + CHECK_TENSOR(cosTheta, 4, 1); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + FresnelShlickKernelParams p; + p.out.fp16 = fp16; + update_grid(p.gridSize, f0, f90, cosTheta); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + p.f0 = make_cuda_tensor(f0, p.gridSize); + p.f90 = make_cuda_tensor(f90, p.gridSize); + p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)FresnelShlickFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple fresnel_shlick_bwd(torch::Tensor f0, torch::Tensor f90, torch::Tensor cosTheta, torch::Tensor grad) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + FresnelShlickKernelParams p; + update_grid(p.gridSize, f0, f90, cosTheta); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor f0_grad, f90_grad, cosT_grad; + p.f0 = make_cuda_tensor(f0, p.gridSize, &f0_grad); + p.f90 = make_cuda_tensor(f90, p.gridSize, &f90_grad); + p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize, &cosT_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)FresnelShlickBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(f0_grad, f90_grad, cosT_grad); +} + +//------------------------------------------------------------------------ +// ndf_ggd + +torch::Tensor ndf_ggx_fwd(torch::Tensor alphaSqr, torch::Tensor cosTheta, bool fp16) +{ + CHECK_TENSOR(alphaSqr, 4, 1); + CHECK_TENSOR(cosTheta, 4, 1); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + NdfGGXParams p; + p.out.fp16 = fp16; + update_grid(p.gridSize, alphaSqr, cosTheta); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize); + p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)ndfGGXFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple ndf_ggx_bwd(torch::Tensor alphaSqr, torch::Tensor cosTheta, torch::Tensor grad) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + NdfGGXParams p; + update_grid(p.gridSize, alphaSqr, cosTheta); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor alphaSqr_grad, cosTheta_grad; + p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize, &alphaSqr_grad); + p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize, &cosTheta_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)ndfGGXBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(alphaSqr_grad, cosTheta_grad); +} + +//------------------------------------------------------------------------ +// lambda_ggx + +torch::Tensor lambda_ggx_fwd(torch::Tensor alphaSqr, torch::Tensor cosTheta, bool fp16) +{ + CHECK_TENSOR(alphaSqr, 4, 1); + CHECK_TENSOR(cosTheta, 4, 1); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + NdfGGXParams p; + p.out.fp16 = fp16; + update_grid(p.gridSize, alphaSqr, cosTheta); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize); + p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)lambdaGGXFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple lambda_ggx_bwd(torch::Tensor alphaSqr, torch::Tensor cosTheta, torch::Tensor grad) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + NdfGGXParams p; + update_grid(p.gridSize, alphaSqr, cosTheta); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor alphaSqr_grad, cosTheta_grad; + p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize, &alphaSqr_grad); + p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize, &cosTheta_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)lambdaGGXBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(alphaSqr_grad, cosTheta_grad); +} + +//------------------------------------------------------------------------ +// masking_smith + +torch::Tensor masking_smith_fwd(torch::Tensor alphaSqr, torch::Tensor cosThetaI, torch::Tensor cosThetaO, bool fp16) +{ + CHECK_TENSOR(alphaSqr, 4, 1); + CHECK_TENSOR(cosThetaI, 4, 1); + CHECK_TENSOR(cosThetaO, 4, 1); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + MaskingSmithParams p; + p.out.fp16 = fp16; + update_grid(p.gridSize, alphaSqr, cosThetaI, cosThetaO); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize); + p.cosThetaI = make_cuda_tensor(cosThetaI, p.gridSize); + p.cosThetaO = make_cuda_tensor(cosThetaO, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)maskingSmithFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple masking_smith_bwd(torch::Tensor alphaSqr, torch::Tensor cosThetaI, torch::Tensor cosThetaO, torch::Tensor grad) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + MaskingSmithParams p; + update_grid(p.gridSize, alphaSqr, cosThetaI, cosThetaO); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor alphaSqr_grad, cosThetaI_grad, cosThetaO_grad; + p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize, &alphaSqr_grad); + p.cosThetaI = make_cuda_tensor(cosThetaI, p.gridSize, &cosThetaI_grad); + p.cosThetaO = make_cuda_tensor(cosThetaO, p.gridSize, &cosThetaO_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)maskingSmithBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(alphaSqr_grad, cosThetaI_grad, cosThetaO_grad); +} + +//------------------------------------------------------------------------ +// pbr_specular + +torch::Tensor pbr_specular_fwd(torch::Tensor col, torch::Tensor nrm, torch::Tensor wo, torch::Tensor wi, torch::Tensor alpha, float min_roughness, bool fp16) +{ + CHECK_TENSOR(col, 4, 3); + CHECK_TENSOR(nrm, 4, 3); + CHECK_TENSOR(wo, 4, 3); + CHECK_TENSOR(wi, 4, 3); + CHECK_TENSOR(alpha, 4, 1); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + PbrSpecular p; + p.out.fp16 = fp16; + p.min_roughness = min_roughness; + update_grid(p.gridSize, col, nrm, wo, wi, alpha); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + p.col = make_cuda_tensor(col, p.gridSize); + p.nrm = make_cuda_tensor(nrm, p.gridSize); + p.wo = make_cuda_tensor(wo, p.gridSize); + p.wi = make_cuda_tensor(wi, p.gridSize); + p.alpha = make_cuda_tensor(alpha, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)pbrSpecularFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple pbr_specular_bwd(torch::Tensor col, torch::Tensor nrm, torch::Tensor wo, torch::Tensor wi, torch::Tensor alpha, float min_roughness, torch::Tensor grad) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + PbrSpecular p; + update_grid(p.gridSize, col, nrm, wo, wi, alpha); + p.min_roughness = min_roughness; + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor col_grad, nrm_grad, wo_grad, wi_grad, alpha_grad; + p.col = make_cuda_tensor(col, p.gridSize, &col_grad); + p.nrm = make_cuda_tensor(nrm, p.gridSize, &nrm_grad); + p.wo = make_cuda_tensor(wo, p.gridSize, &wo_grad); + p.wi = make_cuda_tensor(wi, p.gridSize, &wi_grad); + p.alpha = make_cuda_tensor(alpha, p.gridSize, &alpha_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)pbrSpecularBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(col_grad, nrm_grad, wo_grad, wi_grad, alpha_grad); +} + +//------------------------------------------------------------------------ +// pbr_bsdf + +torch::Tensor pbr_bsdf_fwd(torch::Tensor kd, torch::Tensor arm, torch::Tensor pos, torch::Tensor nrm, torch::Tensor view_pos, torch::Tensor light_pos, float min_roughness, int BSDF, bool fp16) +{ + CHECK_TENSOR(kd, 4, 3); + CHECK_TENSOR(arm, 4, 3); + CHECK_TENSOR(pos, 4, 3); + CHECK_TENSOR(nrm, 4, 3); + CHECK_TENSOR(view_pos, 4, 3); + CHECK_TENSOR(light_pos, 4, 3); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + PbrBSDF p; + p.out.fp16 = fp16; + p.min_roughness = min_roughness; + p.BSDF = BSDF; + update_grid(p.gridSize, kd, arm, pos, nrm, view_pos, light_pos); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + p.kd = make_cuda_tensor(kd, p.gridSize); + p.arm = make_cuda_tensor(arm, p.gridSize); + p.pos = make_cuda_tensor(pos, p.gridSize); + p.nrm = make_cuda_tensor(nrm, p.gridSize); + p.view_pos = make_cuda_tensor(view_pos, p.gridSize); + p.light_pos = make_cuda_tensor(light_pos, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)pbrBSDFFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple pbr_bsdf_bwd(torch::Tensor kd, torch::Tensor arm, torch::Tensor pos, torch::Tensor nrm, torch::Tensor view_pos, torch::Tensor light_pos, float min_roughness, int BSDF, torch::Tensor grad) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + PbrBSDF p; + update_grid(p.gridSize, kd, arm, pos, nrm, view_pos, light_pos); + p.min_roughness = min_roughness; + p.BSDF = BSDF; + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor kd_grad, arm_grad, pos_grad, nrm_grad, view_pos_grad, light_pos_grad; + p.kd = make_cuda_tensor(kd, p.gridSize, &kd_grad); + p.arm = make_cuda_tensor(arm, p.gridSize, &arm_grad); + p.pos = make_cuda_tensor(pos, p.gridSize, &pos_grad); + p.nrm = make_cuda_tensor(nrm, p.gridSize, &nrm_grad); + p.view_pos = make_cuda_tensor(view_pos, p.gridSize, &view_pos_grad); + p.light_pos = make_cuda_tensor(light_pos, p.gridSize, &light_pos_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)pbrBSDFBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(kd_grad, arm_grad, pos_grad, nrm_grad, view_pos_grad, light_pos_grad); +} + +//------------------------------------------------------------------------ +// filter_cubemap + +torch::Tensor diffuse_cubemap_fwd(torch::Tensor cubemap) +{ + CHECK_TENSOR(cubemap, 4, 3); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + DiffuseCubemapKernelParams p; + update_grid(p.gridSize, cubemap); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + // Setup tensors + p.cubemap = make_cuda_tensor(cubemap, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)DiffuseCubemapFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +torch::Tensor diffuse_cubemap_bwd(torch::Tensor cubemap, torch::Tensor grad) +{ + CHECK_TENSOR(cubemap, 4, 3); + CHECK_TENSOR(grad, 4, 3); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + DiffuseCubemapKernelParams p; + update_grid(p.gridSize, cubemap); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + // Setup tensors + torch::Tensor cubemap_grad; + p.cubemap = make_cuda_tensor(cubemap, p.gridSize); + p.out = make_cuda_tensor(grad, p.gridSize); + + cubemap_grad = torch::zeros({ p.gridSize.z, p.gridSize.y, p.gridSize.x, cubemap.size(3) }, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); + p.cubemap.d_val = (void*)cubemap_grad.data_ptr(); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)DiffuseCubemapBwdKernel, gridSize, blockSize, args, 0, stream)); + + return cubemap_grad; +} + +torch::Tensor specular_bounds(int resolution, float costheta_cutoff) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + SpecularBoundsKernelParams p; + p.costheta_cutoff = costheta_cutoff; + p.gridSize = dim3(resolution, resolution, 6); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::zeros({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 6*4 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + // Setup tensors + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)SpecularBoundsKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +torch::Tensor specular_cubemap_fwd(torch::Tensor cubemap, torch::Tensor bounds, float roughness, float costheta_cutoff) +{ + CHECK_TENSOR(cubemap, 4, 3); + CHECK_TENSOR(bounds, 4, 6*4); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + SpecularCubemapKernelParams p; + p.roughness = roughness; + p.costheta_cutoff = costheta_cutoff; + update_grid(p.gridSize, cubemap); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 4 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + // Setup tensors + p.cubemap = make_cuda_tensor(cubemap, p.gridSize); + p.bounds = make_cuda_tensor(bounds, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)SpecularCubemapFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +torch::Tensor specular_cubemap_bwd(torch::Tensor cubemap, torch::Tensor bounds, torch::Tensor grad, float roughness, float costheta_cutoff) +{ + CHECK_TENSOR(cubemap, 4, 3); + CHECK_TENSOR(bounds, 4, 6*4); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + SpecularCubemapKernelParams p; + p.roughness = roughness; + p.costheta_cutoff = costheta_cutoff; + update_grid(p.gridSize, cubemap); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + // Setup tensors + torch::Tensor cubemap_grad; + p.cubemap = make_cuda_tensor(cubemap, p.gridSize); + p.bounds = make_cuda_tensor(bounds, p.gridSize); + p.out = make_cuda_tensor(grad, p.gridSize); + + cubemap_grad = torch::zeros({ p.gridSize.z, p.gridSize.y, p.gridSize.x, cubemap.size(3) }, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); + p.cubemap.d_val = (void*)cubemap_grad.data_ptr(); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)SpecularCubemapBwdKernel, gridSize, blockSize, args, 0, stream)); + + return cubemap_grad; +} + +//------------------------------------------------------------------------ +// loss function + +LossType strToLoss(std::string str) +{ + if (str == "mse") + return LOSS_MSE; + else if (str == "relmse") + return LOSS_RELMSE; + else if (str == "smape") + return LOSS_SMAPE; + else + return LOSS_L1; +} + +torch::Tensor image_loss_fwd(torch::Tensor img, torch::Tensor target, std::string loss, std::string tonemapper, bool fp16) +{ + CHECK_TENSOR(img, 4, 3); + CHECK_TENSOR(target, 4, 3); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + LossKernelParams p; + p.out.fp16 = fp16; + p.loss = strToLoss(loss); + p.tonemapper = tonemapper == "log_srgb" ? TONEMAPPER_LOG_SRGB : TONEMAPPER_NONE; + update_grid(p.gridSize, img, target); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 warpSize = getWarpSize(blockSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ (p.gridSize.z - 1)/ warpSize.z + 1, (p.gridSize.y - 1) / warpSize.y + 1, (p.gridSize.x - 1) / warpSize.x + 1, 1 }, opts); + + p.img = make_cuda_tensor(img, p.gridSize); + p.target = make_cuda_tensor(target, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)imgLossFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple image_loss_bwd(torch::Tensor img, torch::Tensor target, torch::Tensor grad, std::string loss, std::string tonemapper) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + LossKernelParams p; + p.loss = strToLoss(loss); + p.tonemapper = tonemapper == "log_srgb" ? TONEMAPPER_LOG_SRGB : TONEMAPPER_NONE; + update_grid(p.gridSize, img, target); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 warpSize = getWarpSize(blockSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor img_grad, target_grad; + p.img = make_cuda_tensor(img, p.gridSize, &img_grad); + p.target = make_cuda_tensor(target, p.gridSize, &target_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)imgLossBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(img_grad, target_grad); +} + +//------------------------------------------------------------------------ +// transform function + +torch::Tensor xfm_fwd(torch::Tensor points, torch::Tensor matrix, bool isPoints, bool fp16) +{ + CHECK_TENSOR(points, 3, 3); + CHECK_TENSOR(matrix, 3, 4); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + XfmKernelParams p; + p.out.fp16 = fp16; + p.isPoints = isPoints; + p.gridSize.x = points.size(1); + p.gridSize.y = 1; + p.gridSize.z = std::max(matrix.size(0), points.size(0)); + + // Choose launch parameters. + dim3 blockSize(BLOCK_X * BLOCK_Y, 1, 1); + dim3 warpSize = getWarpSize(blockSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = isPoints ? torch::empty({ matrix.size(0), points.size(1), 4 }, opts) : torch::empty({ matrix.size(0), points.size(1), 3 }, opts); + + p.points = make_cuda_tensor(points, p.gridSize); + p.matrix = make_cuda_tensor(matrix, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)xfmPointsFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +torch::Tensor xfm_bwd(torch::Tensor points, torch::Tensor matrix, torch::Tensor grad, bool isPoints) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + XfmKernelParams p; + p.isPoints = isPoints; + p.gridSize.x = points.size(1); + p.gridSize.y = 1; + p.gridSize.z = std::max(matrix.size(0), points.size(0)); + + // Choose launch parameters. + dim3 blockSize(BLOCK_X * BLOCK_Y, 1, 1); + dim3 warpSize = getWarpSize(blockSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor points_grad; + p.points = make_cuda_tensor(points, p.gridSize, &points_grad); + p.matrix = make_cuda_tensor(matrix, p.gridSize); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)xfmPointsBwdKernel, gridSize, blockSize, args, 0, stream)); + + return points_grad; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("prepare_shading_normal_fwd", &prepare_shading_normal_fwd, "prepare_shading_normal_fwd"); + m.def("prepare_shading_normal_bwd", &prepare_shading_normal_bwd, "prepare_shading_normal_bwd"); + m.def("lambert_fwd", &lambert_fwd, "lambert_fwd"); + m.def("lambert_bwd", &lambert_bwd, "lambert_bwd"); + m.def("frostbite_fwd", &frostbite_fwd, "frostbite_fwd"); + m.def("frostbite_bwd", &frostbite_bwd, "frostbite_bwd"); + m.def("fresnel_shlick_fwd", &fresnel_shlick_fwd, "fresnel_shlick_fwd"); + m.def("fresnel_shlick_bwd", &fresnel_shlick_bwd, "fresnel_shlick_bwd"); + m.def("ndf_ggx_fwd", &ndf_ggx_fwd, "ndf_ggx_fwd"); + m.def("ndf_ggx_bwd", &ndf_ggx_bwd, "ndf_ggx_bwd"); + m.def("lambda_ggx_fwd", &lambda_ggx_fwd, "lambda_ggx_fwd"); + m.def("lambda_ggx_bwd", &lambda_ggx_bwd, "lambda_ggx_bwd"); + m.def("masking_smith_fwd", &masking_smith_fwd, "masking_smith_fwd"); + m.def("masking_smith_bwd", &masking_smith_bwd, "masking_smith_bwd"); + m.def("pbr_specular_fwd", &pbr_specular_fwd, "pbr_specular_fwd"); + m.def("pbr_specular_bwd", &pbr_specular_bwd, "pbr_specular_bwd"); + m.def("pbr_bsdf_fwd", &pbr_bsdf_fwd, "pbr_bsdf_fwd"); + m.def("pbr_bsdf_bwd", &pbr_bsdf_bwd, "pbr_bsdf_bwd"); + m.def("diffuse_cubemap_fwd", &diffuse_cubemap_fwd, "diffuse_cubemap_fwd"); + m.def("diffuse_cubemap_bwd", &diffuse_cubemap_bwd, "diffuse_cubemap_bwd"); + m.def("specular_bounds", &specular_bounds, "specular_bounds"); + m.def("specular_cubemap_fwd", &specular_cubemap_fwd, "specular_cubemap_fwd"); + m.def("specular_cubemap_bwd", &specular_cubemap_bwd, "specular_cubemap_bwd"); + m.def("image_loss_fwd", &image_loss_fwd, "image_loss_fwd"); + m.def("image_loss_bwd", &image_loss_bwd, "image_loss_bwd"); + m.def("xfm_fwd", &xfm_fwd, "xfm_fwd"); + m.def("xfm_bwd", &xfm_bwd, "xfm_bwd"); +} \ No newline at end of file diff --git a/models/lrm/models/geometry/render/renderutils/c_src/vec3f.h b/models/lrm/models/geometry/render/renderutils/c_src/vec3f.h new file mode 100644 index 0000000000000000000000000000000000000000..7e6745430f19e9fe1834c8cd3dfeb6e68d730297 --- /dev/null +++ b/models/lrm/models/geometry/render/renderutils/c_src/vec3f.h @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +struct vec3f +{ + float x, y, z; + +#ifdef __CUDACC__ + __device__ vec3f() { } + __device__ vec3f(float v) { x = v; y = v; z = v; } + __device__ vec3f(float _x, float _y, float _z) { x = _x; y = _y; z = _z; } + __device__ vec3f(float3 v) { x = v.x; y = v.y; z = v.z; } + + __device__ inline vec3f& operator+=(const vec3f& b) { x += b.x; y += b.y; z += b.z; return *this; } + __device__ inline vec3f& operator-=(const vec3f& b) { x -= b.x; y -= b.y; z -= b.z; return *this; } + __device__ inline vec3f& operator*=(const vec3f& b) { x *= b.x; y *= b.y; z *= b.z; return *this; } + __device__ inline vec3f& operator/=(const vec3f& b) { x /= b.x; y /= b.y; z /= b.z; return *this; } +#endif +}; + +#ifdef __CUDACC__ +__device__ static inline vec3f operator+(const vec3f& a, const vec3f& b) { return vec3f(a.x + b.x, a.y + b.y, a.z + b.z); } +__device__ static inline vec3f operator-(const vec3f& a, const vec3f& b) { return vec3f(a.x - b.x, a.y - b.y, a.z - b.z); } +__device__ static inline vec3f operator*(const vec3f& a, const vec3f& b) { return vec3f(a.x * b.x, a.y * b.y, a.z * b.z); } +__device__ static inline vec3f operator/(const vec3f& a, const vec3f& b) { return vec3f(a.x / b.x, a.y / b.y, a.z / b.z); } +__device__ static inline vec3f operator-(const vec3f& a) { return vec3f(-a.x, -a.y, -a.z); } + +__device__ static inline float sum(vec3f a) +{ + return a.x + a.y + a.z; +} + +__device__ static inline vec3f cross(vec3f a, vec3f b) +{ + vec3f out; + out.x = a.y * b.z - a.z * b.y; + out.y = a.z * b.x - a.x * b.z; + out.z = a.x * b.y - a.y * b.x; + return out; +} + +__device__ static inline void bwdCross(vec3f a, vec3f b, vec3f &d_a, vec3f &d_b, vec3f d_out) +{ + d_a.x += d_out.z * b.y - d_out.y * b.z; + d_a.y += d_out.x * b.z - d_out.z * b.x; + d_a.z += d_out.y * b.x - d_out.x * b.y; + + d_b.x += d_out.y * a.z - d_out.z * a.y; + d_b.y += d_out.z * a.x - d_out.x * a.z; + d_b.z += d_out.x * a.y - d_out.y * a.x; +} + +__device__ static inline float dot(vec3f a, vec3f b) +{ + return a.x * b.x + a.y * b.y + a.z * b.z; +} + +__device__ static inline void bwdDot(vec3f a, vec3f b, vec3f& d_a, vec3f& d_b, float d_out) +{ + d_a.x += d_out * b.x; d_a.y += d_out * b.y; d_a.z += d_out * b.z; + d_b.x += d_out * a.x; d_b.y += d_out * a.y; d_b.z += d_out * a.z; +} + +__device__ static inline vec3f reflect(vec3f x, vec3f n) +{ + return n * 2.0f * dot(n, x) - x; +} + +__device__ static inline void bwdReflect(vec3f x, vec3f n, vec3f& d_x, vec3f& d_n, const vec3f d_out) +{ + d_x.x += d_out.x * (2 * n.x * n.x - 1) + d_out.y * (2 * n.x * n.y) + d_out.z * (2 * n.x * n.z); + d_x.y += d_out.x * (2 * n.x * n.y) + d_out.y * (2 * n.y * n.y - 1) + d_out.z * (2 * n.y * n.z); + d_x.z += d_out.x * (2 * n.x * n.z) + d_out.y * (2 * n.y * n.z) + d_out.z * (2 * n.z * n.z - 1); + + d_n.x += d_out.x * (2 * (2 * n.x * x.x + n.y * x.y + n.z * x.z)) + d_out.y * (2 * n.y * x.x) + d_out.z * (2 * n.z * x.x); + d_n.y += d_out.x * (2 * n.x * x.y) + d_out.y * (2 * (n.x * x.x + 2 * n.y * x.y + n.z * x.z)) + d_out.z * (2 * n.z * x.y); + d_n.z += d_out.x * (2 * n.x * x.z) + d_out.y * (2 * n.y * x.z) + d_out.z * (2 * (n.x * x.x + n.y * x.y + 2 * n.z * x.z)); +} + +__device__ static inline vec3f safeNormalize(vec3f v) +{ + float l = sqrtf(v.x * v.x + v.y * v.y + v.z * v.z); + return l > 0.0f ? (v / l) : vec3f(0.0f); +} + +__device__ static inline void bwdSafeNormalize(const vec3f v, vec3f& d_v, const vec3f d_out) +{ + + float l = sqrtf(v.x * v.x + v.y * v.y + v.z * v.z); + if (l > 0.0f) + { + float fac = 1.0 / powf(v.x * v.x + v.y * v.y + v.z * v.z, 1.5f); + d_v.x += (d_out.x * (v.y * v.y + v.z * v.z) - d_out.y * (v.x * v.y) - d_out.z * (v.x * v.z)) * fac; + d_v.y += (d_out.y * (v.x * v.x + v.z * v.z) - d_out.x * (v.y * v.x) - d_out.z * (v.y * v.z)) * fac; + d_v.z += (d_out.z * (v.x * v.x + v.y * v.y) - d_out.x * (v.z * v.x) - d_out.y * (v.z * v.y)) * fac; + } +} + +#endif \ No newline at end of file diff --git a/models/lrm/models/geometry/render/renderutils/c_src/vec4f.h b/models/lrm/models/geometry/render/renderutils/c_src/vec4f.h new file mode 100644 index 0000000000000000000000000000000000000000..e3f30776af334597475002275b8b40c584a05035 --- /dev/null +++ b/models/lrm/models/geometry/render/renderutils/c_src/vec4f.h @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +struct vec4f +{ + float x, y, z, w; + +#ifdef __CUDACC__ + __device__ vec4f() { } + __device__ vec4f(float v) { x = v; y = v; z = v; w = v; } + __device__ vec4f(float _x, float _y, float _z, float _w) { x = _x; y = _y; z = _z; w = _w; } + __device__ vec4f(float4 v) { x = v.x; y = v.y; z = v.z; w = v.w; } +#endif +}; + diff --git a/models/lrm/models/geometry/render/renderutils/loss.py b/models/lrm/models/geometry/render/renderutils/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..92a24c02885380937762698eec578eb81bc80f9e --- /dev/null +++ b/models/lrm/models/geometry/render/renderutils/loss.py @@ -0,0 +1,41 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import torch + +#---------------------------------------------------------------------------- +# HDR image losses +#---------------------------------------------------------------------------- + +def _tonemap_srgb(f): + return torch.where(f > 0.0031308, torch.pow(torch.clamp(f, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*f) + +def _SMAPE(img, target, eps=0.01): + nom = torch.abs(img - target) + denom = torch.abs(img) + torch.abs(target) + 0.01 + return torch.mean(nom / denom) + +def _RELMSE(img, target, eps=0.1): + nom = (img - target) * (img - target) + denom = img * img + target * target + 0.1 + return torch.mean(nom / denom) + +def image_loss_fn(img, target, loss, tonemapper): + if tonemapper == 'log_srgb': + img = _tonemap_srgb(torch.log(torch.clamp(img, min=0, max=65535) + 1)) + target = _tonemap_srgb(torch.log(torch.clamp(target, min=0, max=65535) + 1)) + + if loss == 'mse': + return torch.nn.functional.mse_loss(img, target) + elif loss == 'smape': + return _SMAPE(img, target) + elif loss == 'relmse': + return _RELMSE(img, target) + else: + return torch.nn.functional.l1_loss(img, target) diff --git a/models/lrm/models/geometry/render/renderutils/ops.py b/models/lrm/models/geometry/render/renderutils/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a27c72b2a57dac8d3f1a563d80661917b42d6ec9 --- /dev/null +++ b/models/lrm/models/geometry/render/renderutils/ops.py @@ -0,0 +1,554 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import numpy as np +import os +import sys +import torch +import torch.utils.cpp_extension + +from .bsdf import * +from .loss import * + +#---------------------------------------------------------------------------- +# C++/Cuda plugin compiler/loader. + +_cached_plugin = None +def _get_plugin(): + # Return cached plugin if already loaded. + global _cached_plugin + if _cached_plugin is not None: + return _cached_plugin + + # Make sure we can find the necessary compiler and libary binaries. + if os.name == 'nt': + def find_cl_path(): + import glob + for edition in ['Enterprise', 'Professional', 'BuildTools', 'Community']: + paths = sorted(glob.glob(r"C:\Program Files (x86)\Microsoft Visual Studio\*\%s\VC\Tools\MSVC\*\bin\Hostx64\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ['PATH'] += ';' + cl_path + + # Compiler options. + opts = ['-DNVDR_TORCH'] + + # Linker options. + if os.name == 'posix': + ldflags = ['-lcuda', '-lnvrtc'] + elif os.name == 'nt': + ldflags = ['cuda.lib', 'advapi32.lib', 'nvrtc.lib'] + + # List of sources. + source_files = [ + 'c_src/mesh.cu', + 'c_src/loss.cu', + 'c_src/bsdf.cu', + 'c_src/normal.cu', + 'c_src/cubemap.cu', + 'c_src/common.cpp', + 'c_src/torch_bindings.cpp' + ] + + # Some containers set this to contain old architectures that won't compile. We only need the one installed in the machine. + os.environ['TORCH_CUDA_ARCH_LIST'] = '' + + # Try to detect if a stray lock file is left in cache directory and show a warning. This sometimes happens on Windows if the build is interrupted at just the right moment. + try: + lock_fn = os.path.join(torch.utils.cpp_extension._get_build_directory('renderutils_plugin', False), 'lock') + if os.path.exists(lock_fn): + print("Warning: Lock file exists in build directory: '%s'" % lock_fn) + except: + pass + + # Compile and load. + # source_paths = [os.path.join(os.path.dirname(__file__), fn) for fn in source_files] + # torch.utils.cpp_extension.load(name='renderutils_plugin', sources=source_paths, extra_cflags=opts, + # extra_cuda_cflags=opts, extra_ldflags=ldflags, with_cuda=True, verbose=True) + + # Import, cache, and return the compiled module. + import renderutils_plugin + _cached_plugin = renderutils_plugin + return _cached_plugin + +#---------------------------------------------------------------------------- +# Internal kernels, just used for testing functionality + +class _fresnel_shlick_func(torch.autograd.Function): + @staticmethod + def forward(ctx, f0, f90, cosTheta): + out = _get_plugin().fresnel_shlick_fwd(f0, f90, cosTheta, False) + ctx.save_for_backward(f0, f90, cosTheta) + return out + + @staticmethod + def backward(ctx, dout): + f0, f90, cosTheta = ctx.saved_variables + return _get_plugin().fresnel_shlick_bwd(f0, f90, cosTheta, dout) + (None,) + +def _fresnel_shlick(f0, f90, cosTheta, use_python=False): + if use_python: + out = bsdf_fresnel_shlick(f0, f90, cosTheta) + else: + out = _fresnel_shlick_func.apply(f0, f90, cosTheta) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of _fresnel_shlick contains inf or NaN" + return out + + +class _ndf_ggx_func(torch.autograd.Function): + @staticmethod + def forward(ctx, alphaSqr, cosTheta): + out = _get_plugin().ndf_ggx_fwd(alphaSqr, cosTheta, False) + ctx.save_for_backward(alphaSqr, cosTheta) + return out + + @staticmethod + def backward(ctx, dout): + alphaSqr, cosTheta = ctx.saved_variables + return _get_plugin().ndf_ggx_bwd(alphaSqr, cosTheta, dout) + (None,) + +def _ndf_ggx(alphaSqr, cosTheta, use_python=False): + if use_python: + out = bsdf_ndf_ggx(alphaSqr, cosTheta) + else: + out = _ndf_ggx_func.apply(alphaSqr, cosTheta) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of _ndf_ggx contains inf or NaN" + return out + +class _lambda_ggx_func(torch.autograd.Function): + @staticmethod + def forward(ctx, alphaSqr, cosTheta): + out = _get_plugin().lambda_ggx_fwd(alphaSqr, cosTheta, False) + ctx.save_for_backward(alphaSqr, cosTheta) + return out + + @staticmethod + def backward(ctx, dout): + alphaSqr, cosTheta = ctx.saved_variables + return _get_plugin().lambda_ggx_bwd(alphaSqr, cosTheta, dout) + (None,) + +def _lambda_ggx(alphaSqr, cosTheta, use_python=False): + if use_python: + out = bsdf_lambda_ggx(alphaSqr, cosTheta) + else: + out = _lambda_ggx_func.apply(alphaSqr, cosTheta) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of _lambda_ggx contains inf or NaN" + return out + +class _masking_smith_func(torch.autograd.Function): + @staticmethod + def forward(ctx, alphaSqr, cosThetaI, cosThetaO): + ctx.save_for_backward(alphaSqr, cosThetaI, cosThetaO) + out = _get_plugin().masking_smith_fwd(alphaSqr, cosThetaI, cosThetaO, False) + return out + + @staticmethod + def backward(ctx, dout): + alphaSqr, cosThetaI, cosThetaO = ctx.saved_variables + return _get_plugin().masking_smith_bwd(alphaSqr, cosThetaI, cosThetaO, dout) + (None,) + +def _masking_smith(alphaSqr, cosThetaI, cosThetaO, use_python=False): + if use_python: + out = bsdf_masking_smith_ggx_correlated(alphaSqr, cosThetaI, cosThetaO) + else: + out = _masking_smith_func.apply(alphaSqr, cosThetaI, cosThetaO) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of _masking_smith contains inf or NaN" + return out + +#---------------------------------------------------------------------------- +# Shading normal setup (bump mapping + bent normals) + +class _prepare_shading_normal_func(torch.autograd.Function): + @staticmethod + def forward(ctx, pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl): + ctx.two_sided_shading, ctx.opengl = two_sided_shading, opengl + out = _get_plugin().prepare_shading_normal_fwd(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl, False) + ctx.save_for_backward(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm) + return out + + @staticmethod + def backward(ctx, dout): + pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm = ctx.saved_variables + return _get_plugin().prepare_shading_normal_bwd(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, dout, ctx.two_sided_shading, ctx.opengl) + (None, None, None) + +def prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading=True, opengl=True, use_python=False): + '''Takes care of all corner cases and produces a final normal used for shading: + - Constructs tangent space + - Flips normal direction based on geometric normal for two sided Shading + - Perturbs shading normal by normal map + - Bends backfacing normals towards the camera to avoid shading artifacts + + All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent. + + Args: + pos: World space g-buffer position. + view_pos: Camera position in world space (typically using broadcasting). + perturbed_nrm: Trangent-space normal perturbation from normal map lookup. + smooth_nrm: Interpolated vertex normals. + smooth_tng: Interpolated vertex tangents. + geom_nrm: Geometric (face) normals. + two_sided_shading: Use one/two sided shading + opengl: Use OpenGL/DirectX normal map conventions + use_python: Use PyTorch implementation (for validation) + Returns: + Final shading normal + ''' + + if perturbed_nrm is None: + perturbed_nrm = torch.tensor([0, 0, 1], dtype=torch.float32, device='cuda', requires_grad=False)[None, None, None, ...] + + if use_python: + out = bsdf_prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl) + else: + out = _prepare_shading_normal_func.apply(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of prepare_shading_normal contains inf or NaN" + return out + +#---------------------------------------------------------------------------- +# BSDF functions + +class _lambert_func(torch.autograd.Function): + @staticmethod + def forward(ctx, nrm, wi): + out = _get_plugin().lambert_fwd(nrm, wi, False) + ctx.save_for_backward(nrm, wi) + return out + + @staticmethod + def backward(ctx, dout): + nrm, wi = ctx.saved_variables + return _get_plugin().lambert_bwd(nrm, wi, dout) + (None,) + +def lambert(nrm, wi, use_python=False): + '''Lambertian bsdf. + All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent. + + Args: + nrm: World space shading normal. + wi: World space light vector. + use_python: Use PyTorch implementation (for validation) + + Returns: + Shaded diffuse value with shape [minibatch_size, height, width, 1] + ''' + + if use_python: + out = bsdf_lambert(nrm, wi) + else: + out = _lambert_func.apply(nrm, wi) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of lambert contains inf or NaN" + return out + +class _frostbite_diffuse_func(torch.autograd.Function): + @staticmethod + def forward(ctx, nrm, wi, wo, linearRoughness): + out = _get_plugin().frostbite_fwd(nrm, wi, wo, linearRoughness, False) + ctx.save_for_backward(nrm, wi, wo, linearRoughness) + return out + + @staticmethod + def backward(ctx, dout): + nrm, wi, wo, linearRoughness = ctx.saved_variables + return _get_plugin().frostbite_bwd(nrm, wi, wo, linearRoughness, dout) + (None,) + +def frostbite_diffuse(nrm, wi, wo, linearRoughness, use_python=False): + '''Frostbite, normalized Disney Diffuse bsdf. + All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent. + + Args: + nrm: World space shading normal. + wi: World space light vector. + wo: World space camera vector. + linearRoughness: Material roughness + use_python: Use PyTorch implementation (for validation) + + Returns: + Shaded diffuse value with shape [minibatch_size, height, width, 1] + ''' + + if use_python: + out = bsdf_frostbite(nrm, wi, wo, linearRoughness) + else: + out = _frostbite_diffuse_func.apply(nrm, wi, wo, linearRoughness) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of lambert contains inf or NaN" + return out + +class _pbr_specular_func(torch.autograd.Function): + @staticmethod + def forward(ctx, col, nrm, wo, wi, alpha, min_roughness): + ctx.save_for_backward(col, nrm, wo, wi, alpha) + ctx.min_roughness = min_roughness + out = _get_plugin().pbr_specular_fwd(col, nrm, wo, wi, alpha, min_roughness, False) + return out + + @staticmethod + def backward(ctx, dout): + col, nrm, wo, wi, alpha = ctx.saved_variables + return _get_plugin().pbr_specular_bwd(col, nrm, wo, wi, alpha, ctx.min_roughness, dout) + (None, None) + +def pbr_specular(col, nrm, wo, wi, alpha, min_roughness=0.08, use_python=False): + '''Physically-based specular bsdf. + All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent unless otherwise noted. + + Args: + col: Specular lobe color + nrm: World space shading normal. + wo: World space camera vector. + wi: World space light vector + alpha: Specular roughness parameter with shape [minibatch_size, height, width, 1] + min_roughness: Scalar roughness clamping threshold + + use_python: Use PyTorch implementation (for validation) + Returns: + Shaded specular color + ''' + + if use_python: + out = bsdf_pbr_specular(col, nrm, wo, wi, alpha, min_roughness=min_roughness) + else: + out = _pbr_specular_func.apply(col, nrm, wo, wi, alpha, min_roughness) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of pbr_specular contains inf or NaN" + return out + +class _pbr_bsdf_func(torch.autograd.Function): + @staticmethod + def forward(ctx, kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF): + ctx.save_for_backward(kd, arm, pos, nrm, view_pos, light_pos) + ctx.min_roughness = min_roughness + ctx.BSDF = BSDF + out = _get_plugin().pbr_bsdf_fwd(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF, False) + return out + + @staticmethod + def backward(ctx, dout): + kd, arm, pos, nrm, view_pos, light_pos = ctx.saved_variables + return _get_plugin().pbr_bsdf_bwd(kd, arm, pos, nrm, view_pos, light_pos, ctx.min_roughness, ctx.BSDF, dout) + (None, None, None) + +def pbr_bsdf(kd, arm, pos, nrm, view_pos, light_pos, min_roughness=0.08, bsdf="lambert", use_python=False): + '''Physically-based bsdf, both diffuse & specular lobes + All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent unless otherwise noted. + + Args: + kd: Diffuse albedo. + arm: Specular parameters (attenuation, linear roughness, metalness). + pos: World space position. + nrm: World space shading normal. + view_pos: Camera position in world space, typically using broadcasting. + light_pos: Light position in world space, typically using broadcasting. + min_roughness: Scalar roughness clamping threshold + bsdf: Controls diffuse BSDF, can be either 'lambert' or 'frostbite' + + use_python: Use PyTorch implementation (for validation) + + Returns: + Shaded color. + ''' + + BSDF = 0 + if bsdf == 'frostbite': + BSDF = 1 + + if use_python: + out = bsdf_pbr(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF) + else: + out = _pbr_bsdf_func.apply(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of pbr_bsdf contains inf or NaN" + return out + +#---------------------------------------------------------------------------- +# cubemap filter with filtering across edges + +class _diffuse_cubemap_func(torch.autograd.Function): + @staticmethod + def forward(ctx, cubemap): + out = _get_plugin().diffuse_cubemap_fwd(cubemap) + ctx.save_for_backward(cubemap) + return out + + @staticmethod + def backward(ctx, dout): + cubemap, = ctx.saved_variables + cubemap_grad = _get_plugin().diffuse_cubemap_bwd(cubemap, dout) + return cubemap_grad, None + +def diffuse_cubemap(cubemap, use_python=False): + if use_python: + assert False + else: + out = _diffuse_cubemap_func.apply(cubemap) + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of diffuse_cubemap contains inf or NaN" + return out + +class _specular_cubemap(torch.autograd.Function): + @staticmethod + def forward(ctx, cubemap, roughness, costheta_cutoff, bounds): + out = _get_plugin().specular_cubemap_fwd(cubemap, bounds, roughness, costheta_cutoff) + ctx.save_for_backward(cubemap, bounds) + ctx.roughness, ctx.theta_cutoff = roughness, costheta_cutoff + return out + + @staticmethod + def backward(ctx, dout): + cubemap, bounds = ctx.saved_variables + cubemap_grad = _get_plugin().specular_cubemap_bwd(cubemap, bounds, dout, ctx.roughness, ctx.theta_cutoff) + return cubemap_grad, None, None, None + +# Compute the bounds of the GGX NDF lobe to retain "cutoff" percent of the energy +def __ndfBounds(res, roughness, cutoff): + def ndfGGX(alphaSqr, costheta): + costheta = np.clip(costheta, 0.0, 1.0) + d = (costheta * alphaSqr - costheta) * costheta + 1.0 + return alphaSqr / (d * d * np.pi) + + # Sample out cutoff angle + nSamples = 1000000 + costheta = np.cos(np.linspace(0, np.pi/2.0, nSamples)) + D = np.cumsum(ndfGGX(roughness**4, costheta)) + idx = np.argmax(D >= D[..., -1] * cutoff) + + # Brute force compute lookup table with bounds + bounds = _get_plugin().specular_bounds(res, costheta[idx]) + + return costheta[idx], bounds +__ndfBoundsDict = {} + +def specular_cubemap(cubemap, roughness, cutoff=0.99, use_python=False): + assert cubemap.shape[0] == 6 and cubemap.shape[1] == cubemap.shape[2], "Bad shape for cubemap tensor: %s" % str(cubemap.shape) + + if use_python: + assert False + else: + key = (cubemap.shape[1], roughness, cutoff) + if key not in __ndfBoundsDict: + __ndfBoundsDict[key] = __ndfBounds(*key) + out = _specular_cubemap.apply(cubemap, roughness, *__ndfBoundsDict[key]) + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of specular_cubemap contains inf or NaN" + return out[..., 0:3] / out[..., 3:] + +#---------------------------------------------------------------------------- +# Fast image loss function + +class _image_loss_func(torch.autograd.Function): + @staticmethod + def forward(ctx, img, target, loss, tonemapper): + ctx.loss, ctx.tonemapper = loss, tonemapper + ctx.save_for_backward(img, target) + out = _get_plugin().image_loss_fwd(img, target, loss, tonemapper, False) + return out + + @staticmethod + def backward(ctx, dout): + img, target = ctx.saved_variables + return _get_plugin().image_loss_bwd(img, target, dout, ctx.loss, ctx.tonemapper) + (None, None, None) + +def image_loss(img, target, loss='l1', tonemapper='none', use_python=False): + '''Compute HDR image loss. Combines tonemapping and loss into a single kernel for better perf. + All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent unless otherwise noted. + + Args: + img: Input image. + target: Target (reference) image. + loss: Type of loss. Valid options are ['l1', 'mse', 'smape', 'relmse'] + tonemapper: Tonemapping operations. Valid options are ['none', 'log_srgb'] + use_python: Use PyTorch implementation (for validation) + + Returns: + Image space loss (scalar value). + ''' + if use_python: + out = image_loss_fn(img, target, loss, tonemapper) + else: + out = _image_loss_func.apply(img, target, loss, tonemapper) + out = torch.sum(out) / (img.shape[0]*img.shape[1]*img.shape[2]) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of image_loss contains inf or NaN" + return out + +#---------------------------------------------------------------------------- +# Transform points function + +class _xfm_func(torch.autograd.Function): + @staticmethod + def forward(ctx, points, matrix, isPoints): + ctx.save_for_backward(points, matrix) + ctx.isPoints = isPoints + return _get_plugin().xfm_fwd(points, matrix, isPoints, False) + + @staticmethod + def backward(ctx, dout): + points, matrix = ctx.saved_variables + return (_get_plugin().xfm_bwd(points, matrix, dout, ctx.isPoints),) + (None, None, None) + +def xfm_points(points, matrix, use_python=False): + '''Transform points. + Args: + points: Tensor containing 3D points with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3] + matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4] + use_python: Use PyTorch's torch.matmul (for validation) + Returns: + Transformed points in homogeneous 4D with shape [minibatch_size, num_vertices, 4]. + ''' + if use_python: + out = torch.matmul(torch.nn.functional.pad(points, pad=(0,1), mode='constant', value=1.0), torch.transpose(matrix, 1, 2)) + else: + out = _xfm_func.apply(points, matrix, True) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of xfm_points contains inf or NaN" + return out + +def xfm_vectors(vectors, matrix, use_python=False): + '''Transform vectors. + Args: + vectors: Tensor containing 3D vectors with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3] + matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4] + use_python: Use PyTorch's torch.matmul (for validation) + + Returns: + Transformed vectors in homogeneous 4D with shape [minibatch_size, num_vertices, 4]. + ''' + + if use_python: + out = torch.matmul(torch.nn.functional.pad(vectors, pad=(0,1), mode='constant', value=0.0), torch.transpose(matrix, 1, 2))[..., 0:3].contiguous() + else: + out = _xfm_func.apply(vectors, matrix, False) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of xfm_vectors contains inf or NaN" + return out + + + diff --git a/models/lrm/models/geometry/render/renderutils/tests/test_bsdf.py b/models/lrm/models/geometry/render/renderutils/tests/test_bsdf.py new file mode 100644 index 0000000000000000000000000000000000000000..b0b60c350455717826c0f3edb01289b29baac27a --- /dev/null +++ b/models/lrm/models/geometry/render/renderutils/tests/test_bsdf.py @@ -0,0 +1,296 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import torch + +import os +import sys +sys.path.insert(0, os.path.join(sys.path[0], '../..')) +import renderutils as ru + +RES = 4 +DTYPE = torch.float32 + +def relative_loss(name, ref, cuda): + ref = ref.float() + cuda = cuda.float() + print(name, torch.max(torch.abs(ref - cuda) / torch.abs(ref + 1e-7)).item()) + +def test_normal(): + pos_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + pos_ref = pos_cuda.clone().detach().requires_grad_(True) + view_pos_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + view_pos_ref = view_pos_cuda.clone().detach().requires_grad_(True) + perturbed_nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + perturbed_nrm_ref = perturbed_nrm_cuda.clone().detach().requires_grad_(True) + smooth_nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + smooth_nrm_ref = smooth_nrm_cuda.clone().detach().requires_grad_(True) + smooth_tng_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + smooth_tng_ref = smooth_tng_cuda.clone().detach().requires_grad_(True) + geom_nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + geom_nrm_ref = geom_nrm_cuda.clone().detach().requires_grad_(True) + target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda') + + ref = ru.prepare_shading_normal(pos_ref, view_pos_ref, perturbed_nrm_ref, smooth_nrm_ref, smooth_tng_ref, geom_nrm_ref, True, use_python=True) + ref_loss = torch.nn.MSELoss()(ref, target) + ref_loss.backward() + + cuda = ru.prepare_shading_normal(pos_cuda, view_pos_cuda, perturbed_nrm_cuda, smooth_nrm_cuda, smooth_tng_cuda, geom_nrm_cuda, True) + cuda_loss = torch.nn.MSELoss()(cuda, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" bent normal") + print("-------------------------------------------------------------") + relative_loss("res:", ref, cuda) + relative_loss("pos:", pos_ref.grad, pos_cuda.grad) + relative_loss("view_pos:", view_pos_ref.grad, view_pos_cuda.grad) + relative_loss("perturbed_nrm:", perturbed_nrm_ref.grad, perturbed_nrm_cuda.grad) + relative_loss("smooth_nrm:", smooth_nrm_ref.grad, smooth_nrm_cuda.grad) + relative_loss("smooth_tng:", smooth_tng_ref.grad, smooth_tng_cuda.grad) + relative_loss("geom_nrm:", geom_nrm_ref.grad, geom_nrm_cuda.grad) + +def test_schlick(): + f0_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + f0_ref = f0_cuda.clone().detach().requires_grad_(True) + f90_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + f90_ref = f90_cuda.clone().detach().requires_grad_(True) + cosT_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) * 2.0 + cosT_cuda = cosT_cuda.clone().detach().requires_grad_(True) + cosT_ref = cosT_cuda.clone().detach().requires_grad_(True) + target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda') + + ref = ru._fresnel_shlick(f0_ref, f90_ref, cosT_ref, use_python=True) + ref_loss = torch.nn.MSELoss()(ref, target) + ref_loss.backward() + + cuda = ru._fresnel_shlick(f0_cuda, f90_cuda, cosT_cuda) + cuda_loss = torch.nn.MSELoss()(cuda, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" Fresnel shlick") + print("-------------------------------------------------------------") + relative_loss("res:", ref, cuda) + relative_loss("f0:", f0_ref.grad, f0_cuda.grad) + relative_loss("f90:", f90_ref.grad, f90_cuda.grad) + relative_loss("cosT:", cosT_ref.grad, cosT_cuda.grad) + +def test_ndf_ggx(): + alphaSqr_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) + alphaSqr_cuda = alphaSqr_cuda.clone().detach().requires_grad_(True) + alphaSqr_ref = alphaSqr_cuda.clone().detach().requires_grad_(True) + cosT_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) * 3.0 - 1 + cosT_cuda = cosT_cuda.clone().detach().requires_grad_(True) + cosT_ref = cosT_cuda.clone().detach().requires_grad_(True) + target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda') + + ref = ru._ndf_ggx(alphaSqr_ref, cosT_ref, use_python=True) + ref_loss = torch.nn.MSELoss()(ref, target) + ref_loss.backward() + + cuda = ru._ndf_ggx(alphaSqr_cuda, cosT_cuda) + cuda_loss = torch.nn.MSELoss()(cuda, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" Ndf GGX") + print("-------------------------------------------------------------") + relative_loss("res:", ref, cuda) + relative_loss("alpha:", alphaSqr_ref.grad, alphaSqr_cuda.grad) + relative_loss("cosT:", cosT_ref.grad, cosT_cuda.grad) + +def test_lambda_ggx(): + alphaSqr_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) + alphaSqr_ref = alphaSqr_cuda.clone().detach().requires_grad_(True) + cosT_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) * 3.0 - 1 + cosT_cuda = cosT_cuda.clone().detach().requires_grad_(True) + cosT_ref = cosT_cuda.clone().detach().requires_grad_(True) + target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda') + + ref = ru._lambda_ggx(alphaSqr_ref, cosT_ref, use_python=True) + ref_loss = torch.nn.MSELoss()(ref, target) + ref_loss.backward() + + cuda = ru._lambda_ggx(alphaSqr_cuda, cosT_cuda) + cuda_loss = torch.nn.MSELoss()(cuda, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" Lambda GGX") + print("-------------------------------------------------------------") + relative_loss("res:", ref, cuda) + relative_loss("alpha:", alphaSqr_ref.grad, alphaSqr_cuda.grad) + relative_loss("cosT:", cosT_ref.grad, cosT_cuda.grad) + +def test_masking_smith(): + alphaSqr_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) + alphaSqr_ref = alphaSqr_cuda.clone().detach().requires_grad_(True) + cosThetaI_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) + cosThetaI_ref = cosThetaI_cuda.clone().detach().requires_grad_(True) + cosThetaO_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) + cosThetaO_ref = cosThetaO_cuda.clone().detach().requires_grad_(True) + target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda') + + ref = ru._masking_smith(alphaSqr_ref, cosThetaI_ref, cosThetaO_ref, use_python=True) + ref_loss = torch.nn.MSELoss()(ref, target) + ref_loss.backward() + + cuda = ru._masking_smith(alphaSqr_cuda, cosThetaI_cuda, cosThetaO_cuda) + cuda_loss = torch.nn.MSELoss()(cuda, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" Smith masking term") + print("-------------------------------------------------------------") + relative_loss("res:", ref, cuda) + relative_loss("alpha:", alphaSqr_ref.grad, alphaSqr_cuda.grad) + relative_loss("cosThetaI:", cosThetaI_ref.grad, cosThetaI_cuda.grad) + relative_loss("cosThetaO:", cosThetaO_ref.grad, cosThetaO_cuda.grad) + +def test_lambert(): + normals_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + normals_ref = normals_cuda.clone().detach().requires_grad_(True) + wi_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + wi_ref = wi_cuda.clone().detach().requires_grad_(True) + target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda') + + ref = ru.lambert(normals_ref, wi_ref, use_python=True) + ref_loss = torch.nn.MSELoss()(ref, target) + ref_loss.backward() + + cuda = ru.lambert(normals_cuda, wi_cuda) + cuda_loss = torch.nn.MSELoss()(cuda, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" Lambert") + print("-------------------------------------------------------------") + relative_loss("res:", ref, cuda) + relative_loss("nrm:", normals_ref.grad, normals_cuda.grad) + relative_loss("wi:", wi_ref.grad, wi_cuda.grad) + +def test_frostbite(): + normals_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + normals_ref = normals_cuda.clone().detach().requires_grad_(True) + wi_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + wi_ref = wi_cuda.clone().detach().requires_grad_(True) + wo_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + wo_ref = wo_cuda.clone().detach().requires_grad_(True) + rough_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) + rough_ref = rough_cuda.clone().detach().requires_grad_(True) + target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda') + + ref = ru.frostbite_diffuse(normals_ref, wi_ref, wo_ref, rough_ref, use_python=True) + ref_loss = torch.nn.MSELoss()(ref, target) + ref_loss.backward() + + cuda = ru.frostbite_diffuse(normals_cuda, wi_cuda, wo_cuda, rough_cuda) + cuda_loss = torch.nn.MSELoss()(cuda, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" Frostbite") + print("-------------------------------------------------------------") + relative_loss("res:", ref, cuda) + relative_loss("nrm:", normals_ref.grad, normals_cuda.grad) + relative_loss("wo:", wo_ref.grad, wo_cuda.grad) + relative_loss("wi:", wi_ref.grad, wi_cuda.grad) + relative_loss("rough:", rough_ref.grad, rough_cuda.grad) + +def test_pbr_specular(): + col_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + col_ref = col_cuda.clone().detach().requires_grad_(True) + nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + nrm_ref = nrm_cuda.clone().detach().requires_grad_(True) + wi_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + wi_ref = wi_cuda.clone().detach().requires_grad_(True) + wo_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + wo_ref = wo_cuda.clone().detach().requires_grad_(True) + alpha_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) + alpha_ref = alpha_cuda.clone().detach().requires_grad_(True) + target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda') + + ref = ru.pbr_specular(col_ref, nrm_ref, wo_ref, wi_ref, alpha_ref, use_python=True) + ref_loss = torch.nn.MSELoss()(ref, target) + ref_loss.backward() + + cuda = ru.pbr_specular(col_cuda, nrm_cuda, wo_cuda, wi_cuda, alpha_cuda) + cuda_loss = torch.nn.MSELoss()(cuda, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" Pbr specular") + print("-------------------------------------------------------------") + + relative_loss("res:", ref, cuda) + if col_ref.grad is not None: + relative_loss("col:", col_ref.grad, col_cuda.grad) + if nrm_ref.grad is not None: + relative_loss("nrm:", nrm_ref.grad, nrm_cuda.grad) + if wi_ref.grad is not None: + relative_loss("wi:", wi_ref.grad, wi_cuda.grad) + if wo_ref.grad is not None: + relative_loss("wo:", wo_ref.grad, wo_cuda.grad) + if alpha_ref.grad is not None: + relative_loss("alpha:", alpha_ref.grad, alpha_cuda.grad) + +def test_pbr_bsdf(bsdf): + kd_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + kd_ref = kd_cuda.clone().detach().requires_grad_(True) + arm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + arm_ref = arm_cuda.clone().detach().requires_grad_(True) + pos_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + pos_ref = pos_cuda.clone().detach().requires_grad_(True) + nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + nrm_ref = nrm_cuda.clone().detach().requires_grad_(True) + view_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + view_ref = view_cuda.clone().detach().requires_grad_(True) + light_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + light_ref = light_cuda.clone().detach().requires_grad_(True) + target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda') + + ref = ru.pbr_bsdf(kd_ref, arm_ref, pos_ref, nrm_ref, view_ref, light_ref, use_python=True, bsdf=bsdf) + ref_loss = torch.nn.MSELoss()(ref, target) + ref_loss.backward() + + cuda = ru.pbr_bsdf(kd_cuda, arm_cuda, pos_cuda, nrm_cuda, view_cuda, light_cuda, bsdf=bsdf) + cuda_loss = torch.nn.MSELoss()(cuda, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" Pbr BSDF") + print("-------------------------------------------------------------") + + relative_loss("res:", ref, cuda) + if kd_ref.grad is not None: + relative_loss("kd:", kd_ref.grad, kd_cuda.grad) + if arm_ref.grad is not None: + relative_loss("arm:", arm_ref.grad, arm_cuda.grad) + if pos_ref.grad is not None: + relative_loss("pos:", pos_ref.grad, pos_cuda.grad) + if nrm_ref.grad is not None: + relative_loss("nrm:", nrm_ref.grad, nrm_cuda.grad) + if view_ref.grad is not None: + relative_loss("view:", view_ref.grad, view_cuda.grad) + if light_ref.grad is not None: + relative_loss("light:", light_ref.grad, light_cuda.grad) + +test_normal() + +test_schlick() +test_ndf_ggx() +test_lambda_ggx() +test_masking_smith() + +test_lambert() +test_frostbite() +test_pbr_specular() +test_pbr_bsdf('lambert') +test_pbr_bsdf('frostbite') diff --git a/models/lrm/models/geometry/render/renderutils/tests/test_loss.py b/models/lrm/models/geometry/render/renderutils/tests/test_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..7a68f3fc4528431fe405d1d6077af0cb31687d31 --- /dev/null +++ b/models/lrm/models/geometry/render/renderutils/tests/test_loss.py @@ -0,0 +1,61 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import torch + +import os +import sys +sys.path.insert(0, os.path.join(sys.path[0], '../..')) +import renderutils as ru + +RES = 8 +DTYPE = torch.float32 + +def tonemap_srgb(f): + return torch.where(f > 0.0031308, torch.pow(torch.clamp(f, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*f) + +def l1(output, target): + x = torch.clamp(output, min=0, max=65535) + r = torch.clamp(target, min=0, max=65535) + x = tonemap_srgb(torch.log(x + 1)) + r = tonemap_srgb(torch.log(r + 1)) + return torch.nn.functional.l1_loss(x,r) + +def relative_loss(name, ref, cuda): + ref = ref.float() + cuda = cuda.float() + print(name, torch.max(torch.abs(ref - cuda) / torch.abs(ref + 1e-7)).item()) + +def test_loss(loss, tonemapper): + img_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + img_ref = img_cuda.clone().detach().requires_grad_(True) + target_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + target_ref = target_cuda.clone().detach().requires_grad_(True) + + ref_loss = ru.image_loss(img_ref, target_ref, loss=loss, tonemapper=tonemapper, use_python=True) + ref_loss.backward() + + cuda_loss = ru.image_loss(img_cuda, target_cuda, loss=loss, tonemapper=tonemapper) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" Loss: %s, %s" % (loss, tonemapper)) + print("-------------------------------------------------------------") + + relative_loss("res:", ref_loss, cuda_loss) + relative_loss("img:", img_ref.grad, img_cuda.grad) + relative_loss("target:", target_ref.grad, target_cuda.grad) + + +test_loss('l1', 'none') +test_loss('l1', 'log_srgb') +test_loss('mse', 'log_srgb') +test_loss('smape', 'none') +test_loss('relmse', 'none') +test_loss('mse', 'none') \ No newline at end of file diff --git a/models/lrm/models/geometry/render/renderutils/tests/test_mesh.py b/models/lrm/models/geometry/render/renderutils/tests/test_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..4856c5ce07e2d6cd5f1fd463c1d3628791eafccc --- /dev/null +++ b/models/lrm/models/geometry/render/renderutils/tests/test_mesh.py @@ -0,0 +1,90 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import torch + +import os +import sys +sys.path.insert(0, os.path.join(sys.path[0], '../..')) +import renderutils as ru + +BATCH = 8 +RES = 1024 +DTYPE = torch.float32 + +torch.manual_seed(0) + +def tonemap_srgb(f): + return torch.where(f > 0.0031308, torch.pow(torch.clamp(f, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*f) + +def l1(output, target): + x = torch.clamp(output, min=0, max=65535) + r = torch.clamp(target, min=0, max=65535) + x = tonemap_srgb(torch.log(x + 1)) + r = tonemap_srgb(torch.log(r + 1)) + return torch.nn.functional.l1_loss(x,r) + +def relative_loss(name, ref, cuda): + ref = ref.float() + cuda = cuda.float() + print(name, torch.max(torch.abs(ref - cuda) / torch.abs(ref)).item()) + +def test_xfm_points(): + points_cuda = torch.rand(1, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + points_ref = points_cuda.clone().detach().requires_grad_(True) + mtx_cuda = torch.rand(BATCH, 4, 4, dtype=DTYPE, device='cuda', requires_grad=False) + mtx_ref = mtx_cuda.clone().detach().requires_grad_(True) + target = torch.rand(BATCH, RES, 4, dtype=DTYPE, device='cuda', requires_grad=True) + + ref_out = ru.xfm_points(points_ref, mtx_ref, use_python=True) + ref_loss = torch.nn.MSELoss()(ref_out, target) + ref_loss.backward() + + cuda_out = ru.xfm_points(points_cuda, mtx_cuda) + cuda_loss = torch.nn.MSELoss()(cuda_out, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + + relative_loss("res:", ref_out, cuda_out) + relative_loss("points:", points_ref.grad, points_cuda.grad) + +def test_xfm_vectors(): + points_cuda = torch.rand(1, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + points_ref = points_cuda.clone().detach().requires_grad_(True) + points_cuda_p = points_cuda.clone().detach().requires_grad_(True) + points_ref_p = points_cuda.clone().detach().requires_grad_(True) + mtx_cuda = torch.rand(BATCH, 4, 4, dtype=DTYPE, device='cuda', requires_grad=False) + mtx_ref = mtx_cuda.clone().detach().requires_grad_(True) + target = torch.rand(BATCH, RES, 4, dtype=DTYPE, device='cuda', requires_grad=True) + + ref_out = ru.xfm_vectors(points_ref.contiguous(), mtx_ref, use_python=True) + ref_loss = torch.nn.MSELoss()(ref_out, target[..., 0:3]) + ref_loss.backward() + + cuda_out = ru.xfm_vectors(points_cuda.contiguous(), mtx_cuda) + cuda_loss = torch.nn.MSELoss()(cuda_out, target[..., 0:3]) + cuda_loss.backward() + + ref_out_p = ru.xfm_points(points_ref_p.contiguous(), mtx_ref, use_python=True) + ref_loss_p = torch.nn.MSELoss()(ref_out_p, target) + ref_loss_p.backward() + + cuda_out_p = ru.xfm_points(points_cuda_p.contiguous(), mtx_cuda) + cuda_loss_p = torch.nn.MSELoss()(cuda_out_p, target) + cuda_loss_p.backward() + + print("-------------------------------------------------------------") + + relative_loss("res:", ref_out, cuda_out) + relative_loss("points:", points_ref.grad, points_cuda.grad) + relative_loss("points_p:", points_ref_p.grad, points_cuda_p.grad) + +test_xfm_points() +test_xfm_vectors() diff --git a/models/lrm/models/geometry/render/renderutils/tests/test_perf.py b/models/lrm/models/geometry/render/renderutils/tests/test_perf.py new file mode 100644 index 0000000000000000000000000000000000000000..ffc143e3004c0fd0a42a1941896823bc2bef939a --- /dev/null +++ b/models/lrm/models/geometry/render/renderutils/tests/test_perf.py @@ -0,0 +1,57 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import torch + +import os +import sys +sys.path.insert(0, os.path.join(sys.path[0], '../..')) +import renderutils as ru + +DTYPE=torch.float32 + +def test_bsdf(BATCH, RES, ITR): + kd_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + kd_ref = kd_cuda.clone().detach().requires_grad_(True) + arm_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + arm_ref = arm_cuda.clone().detach().requires_grad_(True) + pos_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + pos_ref = pos_cuda.clone().detach().requires_grad_(True) + nrm_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + nrm_ref = nrm_cuda.clone().detach().requires_grad_(True) + view_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + view_ref = view_cuda.clone().detach().requires_grad_(True) + light_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + light_ref = light_cuda.clone().detach().requires_grad_(True) + target = torch.rand(BATCH, RES, RES, 3, device='cuda') + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + ru.pbr_bsdf(kd_cuda, arm_cuda, pos_cuda, nrm_cuda, view_cuda, light_cuda) + + print("--- Testing: [%d, %d, %d] ---" % (BATCH, RES, RES)) + + start.record() + for i in range(ITR): + ref = ru.pbr_bsdf(kd_ref, arm_ref, pos_ref, nrm_ref, view_ref, light_ref, use_python=True) + end.record() + torch.cuda.synchronize() + print("Pbr BSDF python:", start.elapsed_time(end)) + + start.record() + for i in range(ITR): + cuda = ru.pbr_bsdf(kd_cuda, arm_cuda, pos_cuda, nrm_cuda, view_cuda, light_cuda) + end.record() + torch.cuda.synchronize() + print("Pbr BSDF cuda:", start.elapsed_time(end)) + +test_bsdf(1, 512, 1000) +test_bsdf(16, 512, 1000) +test_bsdf(1, 2048, 1000) diff --git a/models/lrm/models/geometry/render/util.py b/models/lrm/models/geometry/render/util.py new file mode 100644 index 0000000000000000000000000000000000000000..e292e91cf1cdd4b05b46f2f18b8a2bb14d2165ba --- /dev/null +++ b/models/lrm/models/geometry/render/util.py @@ -0,0 +1,465 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +import numpy as np +import torch +import nvdiffrast.torch as dr +import imageio + +#---------------------------------------------------------------------------- +# Vector operations +#---------------------------------------------------------------------------- + +def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return torch.sum(x*y, -1, keepdim=True) + +def reflect(x: torch.Tensor, n: torch.Tensor) -> torch.Tensor: + return 2*dot(x, n)*n - x + +def length(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor: + return torch.sqrt(torch.clamp(dot(x,x), min=eps)) # Clamp to avoid nan gradients because grad(sqrt(0)) = NaN + +def safe_normalize(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor: + return x / length(x, eps) + +def to_hvec(x: torch.Tensor, w: float) -> torch.Tensor: + return torch.nn.functional.pad(x, pad=(0,1), mode='constant', value=w) + +#---------------------------------------------------------------------------- +# sRGB color transforms +#---------------------------------------------------------------------------- + +def _rgb_to_srgb(f: torch.Tensor) -> torch.Tensor: + return torch.where(f <= 0.0031308, f * 12.92, torch.pow(torch.clamp(f, 0.0031308), 1.0/2.4)*1.055 - 0.055) + +def rgb_to_srgb(f: torch.Tensor) -> torch.Tensor: + assert f.shape[-1] == 3 or f.shape[-1] == 4 + out = torch.cat((_rgb_to_srgb(f[..., 0:3]), f[..., 3:4]), dim=-1) if f.shape[-1] == 4 else _rgb_to_srgb(f) + assert out.shape[0] == f.shape[0] and out.shape[1] == f.shape[1] and out.shape[2] == f.shape[2] + return out + +def _srgb_to_rgb(f: torch.Tensor) -> torch.Tensor: + return torch.where(f <= 0.04045, f / 12.92, torch.pow((torch.clamp(f, 0.04045) + 0.055) / 1.055, 2.4)) + +def srgb_to_rgb(f: torch.Tensor) -> torch.Tensor: + assert f.shape[-1] == 3 or f.shape[-1] == 4 + out = torch.cat((_srgb_to_rgb(f[..., 0:3]), f[..., 3:4]), dim=-1) if f.shape[-1] == 4 else _srgb_to_rgb(f) + assert out.shape[0] == f.shape[0] and out.shape[1] == f.shape[1] and out.shape[2] == f.shape[2] + return out + +def reinhard(f: torch.Tensor) -> torch.Tensor: + return f/(1+f) + +#----------------------------------------------------------------------------------- +# Metrics (taken from jaxNerf source code, in order to replicate their measurements) +# +# https://github.com/google-research/google-research/blob/301451a62102b046bbeebff49a760ebeec9707b8/jaxnerf/nerf/utils.py#L266 +# +#----------------------------------------------------------------------------------- + +def mse_to_psnr(mse): + """Compute PSNR given an MSE (we assume the maximum pixel value is 1).""" + return -10. / np.log(10.) * np.log(mse) + +def psnr_to_mse(psnr): + """Compute MSE given a PSNR (we assume the maximum pixel value is 1).""" + return np.exp(-0.1 * np.log(10.) * psnr) + +#---------------------------------------------------------------------------- +# Displacement texture lookup +#---------------------------------------------------------------------------- + +def get_miplevels(texture: np.ndarray) -> float: + minDim = min(texture.shape[0], texture.shape[1]) + return np.floor(np.log2(minDim)) + +def tex_2d(tex_map : torch.Tensor, coords : torch.Tensor, filter='nearest') -> torch.Tensor: + tex_map = tex_map[None, ...] # Add batch dimension + tex_map = tex_map.permute(0, 3, 1, 2) # NHWC -> NCHW + tex = torch.nn.functional.grid_sample(tex_map, coords[None, None, ...] * 2 - 1, mode=filter, align_corners=False) + tex = tex.permute(0, 2, 3, 1) # NCHW -> NHWC + return tex[0, 0, ...] + +#---------------------------------------------------------------------------- +# Cubemap utility functions +#---------------------------------------------------------------------------- + +def cube_to_dir(s, x, y): + if s == 0: rx, ry, rz = torch.ones_like(x), -y, -x + elif s == 1: rx, ry, rz = -torch.ones_like(x), -y, x + elif s == 2: rx, ry, rz = x, torch.ones_like(x), y + elif s == 3: rx, ry, rz = x, -torch.ones_like(x), -y + elif s == 4: rx, ry, rz = x, -y, torch.ones_like(x) + elif s == 5: rx, ry, rz = -x, -y, -torch.ones_like(x) + return torch.stack((rx, ry, rz), dim=-1) + +def latlong_to_cubemap(latlong_map, res): + cubemap = torch.zeros(6, res[0], res[1], latlong_map.shape[-1], dtype=torch.float32, device='cuda') + for s in range(6): + gy, gx = torch.meshgrid(torch.linspace(-1.0 + 1.0 / res[0], 1.0 - 1.0 / res[0], res[0], device='cuda'), + torch.linspace(-1.0 + 1.0 / res[1], 1.0 - 1.0 / res[1], res[1], device='cuda'), + indexing='ij') + v = safe_normalize(cube_to_dir(s, gx, gy)) + + tu = torch.atan2(v[..., 0:1], -v[..., 2:3]) / (2 * np.pi) + 0.5 + tv = torch.acos(torch.clamp(v[..., 1:2], min=-1, max=1)) / np.pi + texcoord = torch.cat((tu, tv), dim=-1) + + cubemap[s, ...] = dr.texture(latlong_map[None, ...], texcoord[None, ...], filter_mode='linear')[0] + return cubemap + +def cubemap_to_latlong(cubemap, res): + gy, gx = torch.meshgrid(torch.linspace( 0.0 + 1.0 / res[0], 1.0 - 1.0 / res[0], res[0], device='cuda'), + torch.linspace(-1.0 + 1.0 / res[1], 1.0 - 1.0 / res[1], res[1], device='cuda'), + indexing='ij') + + sintheta, costheta = torch.sin(gy*np.pi), torch.cos(gy*np.pi) + sinphi, cosphi = torch.sin(gx*np.pi), torch.cos(gx*np.pi) + + reflvec = torch.stack(( + sintheta*sinphi, + costheta, + -sintheta*cosphi + ), dim=-1) + return dr.texture(cubemap[None, ...], reflvec[None, ...].contiguous(), filter_mode='linear', boundary_mode='cube')[0] + +#---------------------------------------------------------------------------- +# Image scaling +#---------------------------------------------------------------------------- + +def scale_img_hwc(x : torch.Tensor, size, mag='bilinear', min='area') -> torch.Tensor: + return scale_img_nhwc(x[None, ...], size, mag, min)[0] + +def scale_img_nhwc(x : torch.Tensor, size, mag='bilinear', min='area') -> torch.Tensor: + assert (x.shape[1] >= size[0] and x.shape[2] >= size[1]) or (x.shape[1] < size[0] and x.shape[2] < size[1]), "Trying to magnify image in one dimension and minify in the other" + y = x.permute(0, 3, 1, 2) # NHWC -> NCHW + if x.shape[1] > size[0] and x.shape[2] > size[1]: # Minification, previous size was bigger + y = torch.nn.functional.interpolate(y, size, mode=min) + else: # Magnification + if mag == 'bilinear' or mag == 'bicubic': + y = torch.nn.functional.interpolate(y, size, mode=mag, align_corners=True) + else: + y = torch.nn.functional.interpolate(y, size, mode=mag) + return y.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC + +def avg_pool_nhwc(x : torch.Tensor, size) -> torch.Tensor: + y = x.permute(0, 3, 1, 2) # NHWC -> NCHW + y = torch.nn.functional.avg_pool2d(y, size) + return y.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC + +#---------------------------------------------------------------------------- +# Behaves similar to tf.segment_sum +#---------------------------------------------------------------------------- + +def segment_sum(data: torch.Tensor, segment_ids: torch.Tensor) -> torch.Tensor: + num_segments = torch.unique_consecutive(segment_ids).shape[0] + + # Repeats ids until same dimension as data + if len(segment_ids.shape) == 1: + s = torch.prod(torch.tensor(data.shape[1:], dtype=torch.int64, device='cuda')).long() + segment_ids = segment_ids.repeat_interleave(s).view(segment_ids.shape[0], *data.shape[1:]) + + assert data.shape == segment_ids.shape, "data.shape and segment_ids.shape should be equal" + + shape = [num_segments] + list(data.shape[1:]) + result = torch.zeros(*shape, dtype=torch.float32, device='cuda') + result = result.scatter_add(0, segment_ids, data) + return result + +#---------------------------------------------------------------------------- +# Matrix helpers. +#---------------------------------------------------------------------------- + +def fovx_to_fovy(fovx, aspect): + return np.arctan(np.tan(fovx / 2) / aspect) * 2.0 + +def focal_length_to_fovy(focal_length, sensor_height): + return 2 * np.arctan(0.5 * sensor_height / focal_length) + +# Reworked so this matches gluPerspective / glm::perspective, using fovy +def perspective(fovy=0.7854, aspect=1.0, n=0.1, f=1000.0, device=None): + y = np.tan(fovy / 2) + return torch.tensor([[1/(y*aspect), 0, 0, 0], + [ 0, 1/-y, 0, 0], + [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)], + [ 0, 0, -1, 0]], dtype=torch.float32, device=device) + +# Reworked so this matches gluPerspective / glm::perspective, using fovy +def perspective_offcenter(fovy, fraction, rx, ry, aspect=1.0, n=0.1, f=1000.0, device=None): + y = np.tan(fovy / 2) + + # Full frustum + R, L = aspect*y, -aspect*y + T, B = y, -y + + # Create a randomized sub-frustum + width = (R-L)*fraction + height = (T-B)*fraction + xstart = (R-L)*rx + ystart = (T-B)*ry + + l = L + xstart + r = l + width + b = B + ystart + t = b + height + + # https://www.scratchapixel.com/lessons/3d-basic-rendering/perspective-and-orthographic-projection-matrix/opengl-perspective-projection-matrix + return torch.tensor([[2/(r-l), 0, (r+l)/(r-l), 0], + [ 0, -2/(t-b), (t+b)/(t-b), 0], + [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)], + [ 0, 0, -1, 0]], dtype=torch.float32, device=device) + +def translate(x, y, z, device=None): + return torch.tensor([[1, 0, 0, x], + [0, 1, 0, y], + [0, 0, 1, z], + [0, 0, 0, 1]], dtype=torch.float32, device=device) + +def rotate_x(a, device=None): + s, c = np.sin(a), np.cos(a) + return torch.tensor([[1, 0, 0, 0], + [0, c,-s, 0], + [0, s, c, 0], + [0, 0, 0, 1]], dtype=torch.float32, device=device) + +def rotate_y(a, device=None): + s, c = np.sin(a), np.cos(a) + return torch.tensor([[ c, 0, s, 0], + [ 0, 1, 0, 0], + [-s, 0, c, 0], + [ 0, 0, 0, 1]], dtype=torch.float32, device=device) + +def scale(s, device=None): + return torch.tensor([[ s, 0, 0, 0], + [ 0, s, 0, 0], + [ 0, 0, s, 0], + [ 0, 0, 0, 1]], dtype=torch.float32, device=device) + +def lookAt(eye, at, up): + a = eye - at + w = a / torch.linalg.norm(a) + u = torch.cross(up, w) + u = u / torch.linalg.norm(u) + v = torch.cross(w, u) + translate = torch.tensor([[1, 0, 0, -eye[0]], + [0, 1, 0, -eye[1]], + [0, 0, 1, -eye[2]], + [0, 0, 0, 1]], dtype=eye.dtype, device=eye.device) + rotate = torch.tensor([[u[0], u[1], u[2], 0], + [v[0], v[1], v[2], 0], + [w[0], w[1], w[2], 0], + [0, 0, 0, 1]], dtype=eye.dtype, device=eye.device) + return rotate @ translate + +@torch.no_grad() +def random_rotation_translation(t, device=None): + m = np.random.normal(size=[3, 3]) + m[1] = np.cross(m[0], m[2]) + m[2] = np.cross(m[0], m[1]) + m = m / np.linalg.norm(m, axis=1, keepdims=True) + m = np.pad(m, [[0, 1], [0, 1]], mode='constant') + m[3, 3] = 1.0 + m[:3, 3] = np.random.uniform(-t, t, size=[3]) + return torch.tensor(m, dtype=torch.float32, device=device) + +@torch.no_grad() +def random_rotation(device=None): + m = np.random.normal(size=[3, 3]) + m[1] = np.cross(m[0], m[2]) + m[2] = np.cross(m[0], m[1]) + m = m / np.linalg.norm(m, axis=1, keepdims=True) + m = np.pad(m, [[0, 1], [0, 1]], mode='constant') + m[3, 3] = 1.0 + m[:3, 3] = np.array([0,0,0]).astype(np.float32) + return torch.tensor(m, dtype=torch.float32, device=device) + +#---------------------------------------------------------------------------- +# Compute focal points of a set of lines using least squares. +# handy for poorly centered datasets +#---------------------------------------------------------------------------- + +def lines_focal(o, d): + d = safe_normalize(d) + I = torch.eye(3, dtype=o.dtype, device=o.device) + S = torch.sum(d[..., None] @ torch.transpose(d[..., None], 1, 2) - I[None, ...], dim=0) + C = torch.sum((d[..., None] @ torch.transpose(d[..., None], 1, 2) - I[None, ...]) @ o[..., None], dim=0).squeeze(1) + return torch.linalg.pinv(S) @ C + +#---------------------------------------------------------------------------- +# Cosine sample around a vector N +#---------------------------------------------------------------------------- +@torch.no_grad() +def cosine_sample(N, size=None): + # construct local frame + N = N/torch.linalg.norm(N) + + dx0 = torch.tensor([0, N[2], -N[1]], dtype=N.dtype, device=N.device) + dx1 = torch.tensor([-N[2], 0, N[0]], dtype=N.dtype, device=N.device) + + dx = torch.where(dot(dx0, dx0) > dot(dx1, dx1), dx0, dx1) + #dx = dx0 if np.dot(dx0,dx0) > np.dot(dx1,dx1) else dx1 + dx = dx / torch.linalg.norm(dx) + dy = torch.cross(N,dx) + dy = dy / torch.linalg.norm(dy) + + # cosine sampling in local frame + if size is None: + phi = 2.0 * np.pi * np.random.uniform() + s = np.random.uniform() + else: + phi = 2.0 * np.pi * torch.rand(*size, 1, dtype=N.dtype, device=N.device) + s = torch.rand(*size, 1, dtype=N.dtype, device=N.device) + costheta = np.sqrt(s) + sintheta = np.sqrt(1.0 - s) + + # cartesian vector in local space + x = np.cos(phi)*sintheta + y = np.sin(phi)*sintheta + z = costheta + + # local to world + return dx*x + dy*y + N*z + +#---------------------------------------------------------------------------- +# Bilinear downsample by 2x. +#---------------------------------------------------------------------------- + +def bilinear_downsample(x : torch.tensor) -> torch.Tensor: + w = torch.tensor([[1, 3, 3, 1], [3, 9, 9, 3], [3, 9, 9, 3], [1, 3, 3, 1]], dtype=torch.float32, device=x.device) / 64.0 + w = w.expand(x.shape[-1], 1, 4, 4) + x = torch.nn.functional.conv2d(x.permute(0, 3, 1, 2), w, padding=1, stride=2, groups=x.shape[-1]) + return x.permute(0, 2, 3, 1) + +#---------------------------------------------------------------------------- +# Bilinear downsample log(spp) steps +#---------------------------------------------------------------------------- + +def bilinear_downsample(x : torch.tensor, spp) -> torch.Tensor: + w = torch.tensor([[1, 3, 3, 1], [3, 9, 9, 3], [3, 9, 9, 3], [1, 3, 3, 1]], dtype=torch.float32, device=x.device) / 64.0 + g = x.shape[-1] + w = w.expand(g, 1, 4, 4) + x = x.permute(0, 3, 1, 2) # NHWC -> NCHW + steps = int(np.log2(spp)) + for _ in range(steps): + xp = torch.nn.functional.pad(x, (1,1,1,1), mode='replicate') + x = torch.nn.functional.conv2d(xp, w, padding=0, stride=2, groups=g) + return x.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC + +#---------------------------------------------------------------------------- +# Singleton initialize GLFW +#---------------------------------------------------------------------------- + +_glfw_initialized = False +def init_glfw(): + global _glfw_initialized + try: + import glfw + glfw.ERROR_REPORTING = 'raise' + glfw.default_window_hints() + glfw.window_hint(glfw.VISIBLE, glfw.FALSE) + test = glfw.create_window(8, 8, "Test", None, None) # Create a window and see if not initialized yet + except glfw.GLFWError as e: + if e.error_code == glfw.NOT_INITIALIZED: + glfw.init() + _glfw_initialized = True + +#---------------------------------------------------------------------------- +# Image display function using OpenGL. +#---------------------------------------------------------------------------- + +_glfw_window = None +def display_image(image, title=None): + # Import OpenGL + import OpenGL.GL as gl + import glfw + + # Zoom image if requested. + image = np.asarray(image[..., 0:3]) if image.shape[-1] == 4 else np.asarray(image) + height, width, channels = image.shape + + # Initialize window. + init_glfw() + if title is None: + title = 'Debug window' + global _glfw_window + if _glfw_window is None: + glfw.default_window_hints() + _glfw_window = glfw.create_window(width, height, title, None, None) + glfw.make_context_current(_glfw_window) + glfw.show_window(_glfw_window) + glfw.swap_interval(0) + else: + glfw.make_context_current(_glfw_window) + glfw.set_window_title(_glfw_window, title) + glfw.set_window_size(_glfw_window, width, height) + + # Update window. + glfw.poll_events() + gl.glClearColor(0, 0, 0, 1) + gl.glClear(gl.GL_COLOR_BUFFER_BIT) + gl.glWindowPos2f(0, 0) + gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1) + gl_format = {3: gl.GL_RGB, 2: gl.GL_RG, 1: gl.GL_LUMINANCE}[channels] + gl_dtype = {'uint8': gl.GL_UNSIGNED_BYTE, 'float32': gl.GL_FLOAT}[image.dtype.name] + gl.glDrawPixels(width, height, gl_format, gl_dtype, image[::-1]) + glfw.swap_buffers(_glfw_window) + if glfw.window_should_close(_glfw_window): + return False + return True + +#---------------------------------------------------------------------------- +# Image save/load helper. +#---------------------------------------------------------------------------- + +def save_image(fn, x : np.ndarray): + try: + if os.path.splitext(fn)[1] == ".png": + imageio.imwrite(fn, np.clip(np.rint(x * 255.0), 0, 255).astype(np.uint8), compress_level=3) # Low compression for faster saving + else: + imageio.imwrite(fn, np.clip(np.rint(x * 255.0), 0, 255).astype(np.uint8)) + except: + print("WARNING: FAILED to save image %s" % fn) + +def save_image_raw(fn, x : np.ndarray): + try: + imageio.imwrite(fn, x) + except: + print("WARNING: FAILED to save image %s" % fn) + + +def load_image_raw(fn) -> np.ndarray: + return imageio.imread(fn) + +def load_image(fn) -> np.ndarray: + img = load_image_raw(fn) + if img.dtype == np.float32: # HDR image + return img + else: # LDR image + return img.astype(np.float32) / 255 + +#---------------------------------------------------------------------------- + +def time_to_text(x): + if x > 3600: + return "%.2f h" % (x / 3600) + elif x > 60: + return "%.2f m" % (x / 60) + else: + return "%.2f s" % x + +#---------------------------------------------------------------------------- + +def checkerboard(res, checker_size) -> np.ndarray: + tiles_y = (res[0] + (checker_size*2) - 1) // (checker_size*2) + tiles_x = (res[1] + (checker_size*2) - 1) // (checker_size*2) + check = np.kron([[1, 0] * tiles_x, [0, 1] * tiles_x] * tiles_y, np.ones((checker_size, checker_size)))*0.33 + 0.33 + check = check[:res[0], :res[1]] + return np.stack((check, check, check), axis=-1) + diff --git a/models/lrm/models/geometry/rep_3d/__init__.py b/models/lrm/models/geometry/rep_3d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a3d5628a8433298477d1963f92578d47106b4a0f --- /dev/null +++ b/models/lrm/models/geometry/rep_3d/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +import numpy as np + + +class Geometry(): + def __init__(self): + pass + + def forward(self): + pass diff --git a/models/lrm/models/geometry/rep_3d/dmtet.py b/models/lrm/models/geometry/rep_3d/dmtet.py new file mode 100644 index 0000000000000000000000000000000000000000..b6a709380abac0bbf66fd1c8582485f3982223e4 --- /dev/null +++ b/models/lrm/models/geometry/rep_3d/dmtet.py @@ -0,0 +1,504 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +import numpy as np +import os +from . import Geometry +from .dmtet_utils import get_center_boundary_index +import torch.nn.functional as F + + +############################################################################### +# DMTet utility functions +############################################################################### +def create_mt_variable(device): + triangle_table = torch.tensor( + [ + [-1, -1, -1, -1, -1, -1], + [1, 0, 2, -1, -1, -1], + [4, 0, 3, -1, -1, -1], + [1, 4, 2, 1, 3, 4], + [3, 1, 5, -1, -1, -1], + [2, 3, 0, 2, 5, 3], + [1, 4, 0, 1, 5, 4], + [4, 2, 5, -1, -1, -1], + [4, 5, 2, -1, -1, -1], + [4, 1, 0, 4, 5, 1], + [3, 2, 0, 3, 5, 2], + [1, 3, 5, -1, -1, -1], + [4, 1, 2, 4, 3, 1], + [3, 0, 4, -1, -1, -1], + [2, 0, 1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1] + ], dtype=torch.long, device=device) + + num_triangles_table = torch.tensor([0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long, device=device) + base_tet_edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=device) + v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=device)) + return triangle_table, num_triangles_table, base_tet_edges, v_id + + +def sort_edges(edges_ex2): + with torch.no_grad(): + order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long() + order = order.unsqueeze(dim=1) + a = torch.gather(input=edges_ex2, index=order, dim=1) + b = torch.gather(input=edges_ex2, index=1 - order, dim=1) + return torch.stack([a, b], -1) + + +############################################################################### +# marching tetrahedrons (differentiable) +############################################################################### + +def marching_tets(pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id): + with torch.no_grad(): + occ_n = sdf_n > 0 + occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4) + occ_sum = torch.sum(occ_fx4, -1) + valid_tets = (occ_sum > 0) & (occ_sum < 4) + occ_sum = occ_sum[valid_tets] + + # find all vertices + all_edges = tet_fx4[valid_tets][:, base_tet_edges].reshape(-1, 2) + all_edges = sort_edges(all_edges) + unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True) + + unique_edges = unique_edges.long() + mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 + mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=sdf_n.device) * -1 + mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device=sdf_n.device) + idx_map = mapping[idx_map] # map edges to verts + + interp_v = unique_edges[mask_edges] # .long() + edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3) + edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1) + edges_to_interp_sdf[:, -1] *= -1 + + denominator = edges_to_interp_sdf.sum(1, keepdim=True) + + edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator + verts = (edges_to_interp * edges_to_interp_sdf).sum(1) + + idx_map = idx_map.reshape(-1, 6) + + tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1) + num_triangles = num_triangles_table[tetindex] + + # Generate triangle indices + faces = torch.cat( + ( + torch.gather( + input=idx_map[num_triangles == 1], dim=1, + index=triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1, 3), + torch.gather( + input=idx_map[num_triangles == 2], dim=1, + index=triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1, 3), + ), dim=0) + return verts, faces + + +def create_tetmesh_variables(device='cuda'): + tet_table = torch.tensor( + [[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], + [0, 4, 5, 6, -1, -1, -1, -1, -1, -1, -1, -1], + [1, 4, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1], + [1, 0, 8, 7, 0, 5, 8, 7, 0, 5, 6, 8], + [2, 5, 7, 9, -1, -1, -1, -1, -1, -1, -1, -1], + [2, 0, 9, 7, 0, 4, 9, 7, 0, 4, 6, 9], + [2, 1, 9, 5, 1, 4, 9, 5, 1, 4, 8, 9], + [6, 0, 1, 2, 6, 1, 2, 8, 6, 8, 2, 9], + [3, 6, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1], + [3, 0, 9, 8, 0, 4, 9, 8, 0, 4, 5, 9], + [3, 1, 9, 6, 1, 4, 9, 6, 1, 4, 7, 9], + [5, 0, 1, 3, 5, 1, 3, 7, 5, 7, 3, 9], + [3, 2, 8, 6, 2, 5, 8, 6, 2, 5, 7, 8], + [4, 0, 2, 3, 4, 2, 3, 7, 4, 7, 3, 8], + [4, 1, 2, 3, 4, 2, 3, 5, 4, 5, 3, 6], + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]], dtype=torch.long, device=device) + num_tets_table = torch.tensor([0, 1, 1, 3, 1, 3, 3, 3, 1, 3, 3, 3, 3, 3, 3, 0], dtype=torch.long, device=device) + return tet_table, num_tets_table + + +def marching_tets_tetmesh( + pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id, + return_tet_mesh=False, ori_v=None, num_tets_table=None, tet_table=None): + with torch.no_grad(): + occ_n = sdf_n > 0 + occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4) + occ_sum = torch.sum(occ_fx4, -1) + valid_tets = (occ_sum > 0) & (occ_sum < 4) + occ_sum = occ_sum[valid_tets] + + # find all vertices + all_edges = tet_fx4[valid_tets][:, base_tet_edges].reshape(-1, 2) + all_edges = sort_edges(all_edges) + unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True) + + unique_edges = unique_edges.long() + mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 + mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=sdf_n.device) * -1 + mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device=sdf_n.device) + idx_map = mapping[idx_map] # map edges to verts + + interp_v = unique_edges[mask_edges] # .long() + edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3) + edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1) + edges_to_interp_sdf[:, -1] *= -1 + + denominator = edges_to_interp_sdf.sum(1, keepdim=True) + + edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator + verts = (edges_to_interp * edges_to_interp_sdf).sum(1) + + idx_map = idx_map.reshape(-1, 6) + + tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1) + num_triangles = num_triangles_table[tetindex] + + # Generate triangle indices + faces = torch.cat( + ( + torch.gather( + input=idx_map[num_triangles == 1], dim=1, + index=triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1, 3), + torch.gather( + input=idx_map[num_triangles == 2], dim=1, + index=triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1, 3), + ), dim=0) + if not return_tet_mesh: + return verts, faces + occupied_verts = ori_v[occ_n] + mapping = torch.ones((pos_nx3.shape[0]), dtype=torch.long, device="cuda") * -1 + mapping[occ_n] = torch.arange(occupied_verts.shape[0], device="cuda") + tet_fx4 = mapping[tet_fx4.reshape(-1)].reshape((-1, 4)) + + idx_map = torch.cat([tet_fx4[valid_tets] + verts.shape[0], idx_map], -1) # t x 10 + tet_verts = torch.cat([verts, occupied_verts], 0) + num_tets = num_tets_table[tetindex] + + tets = torch.cat( + ( + torch.gather(input=idx_map[num_tets == 1], dim=1, index=tet_table[tetindex[num_tets == 1]][:, :4]).reshape( + -1, + 4), + torch.gather(input=idx_map[num_tets == 3], dim=1, index=tet_table[tetindex[num_tets == 3]][:, :12]).reshape( + -1, + 4), + ), dim=0) + # add fully occupied tets + fully_occupied = occ_fx4.sum(-1) == 4 + tet_fully_occupied = tet_fx4[fully_occupied] + verts.shape[0] + tets = torch.cat([tets, tet_fully_occupied]) + + return verts, faces, tet_verts, tets + + +############################################################################### +# Compact tet grid +############################################################################### + +def compact_tets(pos_nx3, sdf_n, tet_fx4): + with torch.no_grad(): + # Find surface tets + occ_n = sdf_n > 0 + occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4) + occ_sum = torch.sum(occ_fx4, -1) + valid_tets = (occ_sum > 0) & (occ_sum < 4) # one value per tet, these are the surface tets + + valid_vtx = tet_fx4[valid_tets].reshape(-1) + unique_vtx, idx_map = torch.unique(valid_vtx, dim=0, return_inverse=True) + new_pos = pos_nx3[unique_vtx] + new_sdf = sdf_n[unique_vtx] + new_tets = idx_map.reshape(-1, 4) + return new_pos, new_sdf, new_tets + + +############################################################################### +# Subdivide volume +############################################################################### + +def batch_subdivide_volume(tet_pos_bxnx3, tet_bxfx4, grid_sdf): + device = tet_pos_bxnx3.device + # get new verts + tet_fx4 = tet_bxfx4[0] + edges = [0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3] + all_edges = tet_fx4[:, edges].reshape(-1, 2) + all_edges = sort_edges(all_edges) + unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True) + idx_map = idx_map + tet_pos_bxnx3.shape[1] + all_values = torch.cat([tet_pos_bxnx3, grid_sdf], -1) + mid_points_pos = all_values[:, unique_edges.reshape(-1)].reshape( + all_values.shape[0], -1, 2, + all_values.shape[-1]).mean(2) + new_v = torch.cat([all_values, mid_points_pos], 1) + new_v, new_sdf = new_v[..., :3], new_v[..., 3] + + # get new tets + + idx_a, idx_b, idx_c, idx_d = tet_fx4[:, 0], tet_fx4[:, 1], tet_fx4[:, 2], tet_fx4[:, 3] + idx_ab = idx_map[0::6] + idx_ac = idx_map[1::6] + idx_ad = idx_map[2::6] + idx_bc = idx_map[3::6] + idx_bd = idx_map[4::6] + idx_cd = idx_map[5::6] + + tet_1 = torch.stack([idx_a, idx_ab, idx_ac, idx_ad], dim=1) + tet_2 = torch.stack([idx_b, idx_bc, idx_ab, idx_bd], dim=1) + tet_3 = torch.stack([idx_c, idx_ac, idx_bc, idx_cd], dim=1) + tet_4 = torch.stack([idx_d, idx_ad, idx_cd, idx_bd], dim=1) + tet_5 = torch.stack([idx_ab, idx_ac, idx_ad, idx_bd], dim=1) + tet_6 = torch.stack([idx_ab, idx_ac, idx_bd, idx_bc], dim=1) + tet_7 = torch.stack([idx_cd, idx_ac, idx_bd, idx_ad], dim=1) + tet_8 = torch.stack([idx_cd, idx_ac, idx_bc, idx_bd], dim=1) + + tet_np = torch.cat([tet_1, tet_2, tet_3, tet_4, tet_5, tet_6, tet_7, tet_8], dim=0) + tet_np = tet_np.reshape(1, -1, 4).expand(tet_pos_bxnx3.shape[0], -1, -1) + tet = tet_np.long().to(device) + + return new_v, tet, new_sdf + + +############################################################################### +# Adjacency +############################################################################### +def tet_to_tet_adj_sparse(tet_tx4): + # include self connection!!!!!!!!!!!!!!!!!!! + with torch.no_grad(): + t = tet_tx4.shape[0] + device = tet_tx4.device + idx_array = torch.LongTensor( + [0, 1, 2, + 1, 0, 3, + 2, 3, 0, + 3, 2, 1]).to(device).reshape(4, 3).unsqueeze(0).expand(t, -1, -1) # (t, 4, 3) + + # get all faces + all_faces = torch.gather(input=tet_tx4.unsqueeze(1).expand(-1, 4, -1), index=idx_array, dim=-1).reshape( + -1, + 3) # (tx4, 3) + all_faces_tet_idx = torch.arange(t, device=device).unsqueeze(-1).expand(-1, 4).reshape(-1) + # sort and group + all_faces_sorted, _ = torch.sort(all_faces, dim=1) + + all_faces_unique, inverse_indices, counts = torch.unique( + all_faces_sorted, dim=0, return_counts=True, + return_inverse=True) + tet_face_fx3 = all_faces_unique[counts == 2] + counts = counts[inverse_indices] # tx4 + valid = (counts == 2) + + group = inverse_indices[valid] + # print (inverse_indices.shape, group.shape, all_faces_tet_idx.shape) + _, indices = torch.sort(group) + all_faces_tet_idx_grouped = all_faces_tet_idx[valid][indices] + tet_face_tetidx_fx2 = torch.stack([all_faces_tet_idx_grouped[::2], all_faces_tet_idx_grouped[1::2]], dim=-1) + + tet_adj_idx = torch.cat([tet_face_tetidx_fx2, torch.flip(tet_face_tetidx_fx2, [1])]) + adj_self = torch.arange(t, device=tet_tx4.device) + adj_self = torch.stack([adj_self, adj_self], -1) + tet_adj_idx = torch.cat([tet_adj_idx, adj_self]) + + tet_adj_idx = torch.unique(tet_adj_idx, dim=0) + values = torch.ones( + tet_adj_idx.shape[0], device=tet_tx4.device).float() + adj_sparse = torch.sparse.FloatTensor( + tet_adj_idx.t(), values, torch.Size([t, t])) + + # normalization + neighbor_num = 1.0 / torch.sparse.sum( + adj_sparse, dim=1).to_dense() + values = torch.index_select(neighbor_num, 0, tet_adj_idx[:, 0]) + adj_sparse = torch.sparse.FloatTensor( + tet_adj_idx.t(), values, torch.Size([t, t])) + return adj_sparse + + +############################################################################### +# Compact grid +############################################################################### + +def get_tet_bxfx4x3(bxnxz, bxfx4): + n_batch, z = bxnxz.shape[0], bxnxz.shape[2] + gather_input = bxnxz.unsqueeze(2).expand( + n_batch, bxnxz.shape[1], 4, z) + gather_index = bxfx4.unsqueeze(-1).expand( + n_batch, bxfx4.shape[1], 4, z).long() + tet_bxfx4xz = torch.gather( + input=gather_input, dim=1, index=gather_index) + + return tet_bxfx4xz + + +def shrink_grid(tet_pos_bxnx3, tet_bxfx4, grid_sdf): + with torch.no_grad(): + assert tet_pos_bxnx3.shape[0] == 1 + + occ = grid_sdf[0] > 0 + occ_sum = get_tet_bxfx4x3(occ.unsqueeze(0).unsqueeze(-1), tet_bxfx4).reshape(-1, 4).sum(-1) + mask = (occ_sum > 0) & (occ_sum < 4) + + # build connectivity graph + adj_matrix = tet_to_tet_adj_sparse(tet_bxfx4[0]) + mask = mask.float().unsqueeze(-1) + + # Include a one ring of neighbors + for i in range(1): + mask = torch.sparse.mm(adj_matrix, mask) + mask = mask.squeeze(-1) > 0 + + mapping = torch.zeros((tet_pos_bxnx3.shape[1]), device=tet_pos_bxnx3.device, dtype=torch.long) + new_tet_bxfx4 = tet_bxfx4[:, mask].long() + selected_verts_idx = torch.unique(new_tet_bxfx4) + new_tet_pos_bxnx3 = tet_pos_bxnx3[:, selected_verts_idx] + mapping[selected_verts_idx] = torch.arange(selected_verts_idx.shape[0], device=tet_pos_bxnx3.device) + new_tet_bxfx4 = mapping[new_tet_bxfx4.reshape(-1)].reshape(new_tet_bxfx4.shape) + new_grid_sdf = grid_sdf[:, selected_verts_idx] + return new_tet_pos_bxnx3, new_tet_bxfx4, new_grid_sdf + + +############################################################################### +# Regularizer +############################################################################### + +def sdf_reg_loss(sdf, all_edges): + sdf_f1x6x2 = sdf[all_edges.reshape(-1)].reshape(-1, 2) + mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1]) + sdf_f1x6x2 = sdf_f1x6x2[mask] + sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits( + sdf_f1x6x2[..., 0], + (sdf_f1x6x2[..., 1] > 0).float()) + \ + torch.nn.functional.binary_cross_entropy_with_logits( + sdf_f1x6x2[..., 1], + (sdf_f1x6x2[..., 0] > 0).float()) + return sdf_diff + + +def sdf_reg_loss_batch(sdf, all_edges): + sdf_f1x6x2 = sdf[:, all_edges.reshape(-1)].reshape(sdf.shape[0], -1, 2) + mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1]) + sdf_f1x6x2 = sdf_f1x6x2[mask] + sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[..., 0], (sdf_f1x6x2[..., 1] > 0).float()) + \ + torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[..., 1], (sdf_f1x6x2[..., 0] > 0).float()) + return sdf_diff + + +############################################################################### +# Geometry interface +############################################################################### +class DMTetGeometry(Geometry): + def __init__( + self, grid_res=64, scale=2.0, device='cuda', renderer=None, + render_type='neural_render', args=None): + super(DMTetGeometry, self).__init__() + self.grid_res = grid_res + self.device = device + self.args = args + tets = np.load('data/tets/%d_compress.npz' % (grid_res)) + self.verts = torch.from_numpy(tets['vertices']).float().to(self.device) + # Make sure the tet is zero-centered and length is equal to 1 + length = self.verts.max(dim=0)[0] - self.verts.min(dim=0)[0] + length = length.max() + mid = (self.verts.max(dim=0)[0] + self.verts.min(dim=0)[0]) / 2.0 + self.verts = (self.verts - mid.unsqueeze(dim=0)) / length + if isinstance(scale, list): + self.verts[:, 0] = self.verts[:, 0] * scale[0] + self.verts[:, 1] = self.verts[:, 1] * scale[1] + self.verts[:, 2] = self.verts[:, 2] * scale[1] + else: + self.verts = self.verts * scale + self.indices = torch.from_numpy(tets['tets']).long().to(self.device) + self.triangle_table, self.num_triangles_table, self.base_tet_edges, self.v_id = create_mt_variable(self.device) + self.tet_table, self.num_tets_table = create_tetmesh_variables(self.device) + # Parameters for regularization computation + edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=self.device) + all_edges = self.indices[:, edges].reshape(-1, 2) + all_edges_sorted = torch.sort(all_edges, dim=1)[0] + self.all_edges = torch.unique(all_edges_sorted, dim=0) + + # Parameters used for fix boundary sdf + self.center_indices, self.boundary_indices = get_center_boundary_index(self.verts) + self.renderer = renderer + self.render_type = render_type + + def getAABB(self): + return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values + + def get_mesh(self, v_deformed_nx3, sdf_n, with_uv=False, indices=None): + if indices is None: + indices = self.indices + verts, faces = marching_tets( + v_deformed_nx3, sdf_n, indices, self.triangle_table, + self.num_triangles_table, self.base_tet_edges, self.v_id) + faces = torch.cat( + [faces[:, 0:1], + faces[:, 2:3], + faces[:, 1:2], ], dim=-1) + return verts, faces + + def get_tet_mesh(self, v_deformed_nx3, sdf_n, with_uv=False, indices=None): + if indices is None: + indices = self.indices + verts, faces, tet_verts, tets = marching_tets_tetmesh( + v_deformed_nx3, sdf_n, indices, self.triangle_table, + self.num_triangles_table, self.base_tet_edges, self.v_id, return_tet_mesh=True, + num_tets_table=self.num_tets_table, tet_table=self.tet_table, ori_v=v_deformed_nx3) + faces = torch.cat( + [faces[:, 0:1], + faces[:, 2:3], + faces[:, 1:2], ], dim=-1) + return verts, faces, tet_verts, tets + + def render_mesh(self, mesh_v_nx3, mesh_f_fx3, camera_mv_bx4x4, resolution=256, hierarchical_mask=False): + return_value = dict() + if self.render_type == 'neural_render': + tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth = self.renderer.render_mesh( + mesh_v_nx3.unsqueeze(dim=0), + mesh_f_fx3.int(), + camera_mv_bx4x4, + mesh_v_nx3.unsqueeze(dim=0), + resolution=resolution, + device=self.device, + hierarchical_mask=hierarchical_mask + ) + + return_value['tex_pos'] = tex_pos + return_value['mask'] = mask + return_value['hard_mask'] = hard_mask + return_value['rast'] = rast + return_value['v_pos_clip'] = v_pos_clip + return_value['mask_pyramid'] = mask_pyramid + return_value['depth'] = depth + else: + raise NotImplementedError + + return return_value + + def render(self, v_deformed_bxnx3=None, sdf_bxn=None, camera_mv_bxnviewx4x4=None, resolution=256): + # Here I assume a batch of meshes (can be different mesh and geometry), for the other shapes, the batch is 1 + v_list = [] + f_list = [] + n_batch = v_deformed_bxnx3.shape[0] + all_render_output = [] + for i_batch in range(n_batch): + verts_nx3, faces_fx3 = self.get_mesh(v_deformed_bxnx3[i_batch], sdf_bxn[i_batch]) + v_list.append(verts_nx3) + f_list.append(faces_fx3) + render_output = self.render_mesh(verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution) + all_render_output.append(render_output) + + # Concatenate all render output + return_keys = all_render_output[0].keys() + return_value = dict() + for k in return_keys: + value = [v[k] for v in all_render_output] + return_value[k] = value + # We can do concatenation outside of the render + return return_value diff --git a/models/lrm/models/geometry/rep_3d/dmtet_utils.py b/models/lrm/models/geometry/rep_3d/dmtet_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8d466a9e78c49d947c115707693aa18d759885ad --- /dev/null +++ b/models/lrm/models/geometry/rep_3d/dmtet_utils.py @@ -0,0 +1,20 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch + + +def get_center_boundary_index(verts): + length_ = torch.sum(verts ** 2, dim=-1) + center_idx = torch.argmin(length_) + boundary_neg = verts == verts.max() + boundary_pos = verts == verts.min() + boundary = torch.bitwise_or(boundary_pos, boundary_neg) + boundary = torch.sum(boundary.float(), dim=-1) + boundary_idx = torch.nonzero(boundary) + return center_idx, boundary_idx.squeeze(dim=-1) diff --git a/models/lrm/models/geometry/rep_3d/extract_texture_map.py b/models/lrm/models/geometry/rep_3d/extract_texture_map.py new file mode 100644 index 0000000000000000000000000000000000000000..aadea1f018fc00b1824e2d498f0c59504de3298f --- /dev/null +++ b/models/lrm/models/geometry/rep_3d/extract_texture_map.py @@ -0,0 +1,40 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +import xatlas +import numpy as np +import nvdiffrast.torch as dr + + +# ============================================================================================== +def interpolate(attr, rast, attr_idx, rast_db=None): + return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=None if rast_db is None else 'all') + + +def xatlas_uvmap(ctx, mesh_v, mesh_pos_idx, resolution): + vmapping, indices, uvs = xatlas.parametrize(mesh_v.detach().cpu().numpy(), mesh_pos_idx.detach().cpu().numpy()) + + # Convert to tensors + indices_int64 = indices.astype(np.uint64, casting='same_kind').view(int) + + uvs = torch.tensor(uvs, dtype=torch.float32, device=mesh_v.device) + mesh_tex_idx = torch.tensor(indices_int64, dtype=torch.int64, device=mesh_v.device) + # mesh_v_tex. ture + uv_clip = uvs[None, ...] * 2.0 - 1.0 + + # pad to four component coordinate + uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[..., 0:1]), torch.ones_like(uv_clip[..., 0:1])), dim=-1) + + # rasterize + rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), (resolution, resolution)) + + # Interpolate world space position + gb_pos, _ = interpolate(mesh_v[None, ...], rast, mesh_pos_idx.int()) + mask = rast[..., 3:4] > 0 + return uvs, mesh_tex_idx, gb_pos, mask diff --git a/models/lrm/models/geometry/rep_3d/flexicubes.py b/models/lrm/models/geometry/rep_3d/flexicubes.py new file mode 100644 index 0000000000000000000000000000000000000000..26d7b91b6266d802baaf55b64238629cd0f740d0 --- /dev/null +++ b/models/lrm/models/geometry/rep_3d/flexicubes.py @@ -0,0 +1,579 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. +import torch +from .tables import * + +__all__ = [ + 'FlexiCubes' +] + + +class FlexiCubes: + """ + This class implements the FlexiCubes method for extracting meshes from scalar fields. + It maintains a series of lookup tables and indices to support the mesh extraction process. + FlexiCubes, a differentiable variant of the Dual Marching Cubes (DMC) scheme, enhances + the geometric fidelity and mesh quality of reconstructed meshes by dynamically adjusting + the surface representation through gradient-based optimization. + + During instantiation, the class loads DMC tables from a file and transforms them into + PyTorch tensors on the specified device. + + Attributes: + device (str): Specifies the computational device (default is "cuda"). + dmc_table (torch.Tensor): Dual Marching Cubes (DMC) table that encodes the edges + associated with each dual vertex in 256 Marching Cubes (MC) configurations. + num_vd_table (torch.Tensor): Table holding the number of dual vertices in each of + the 256 MC configurations. + check_table (torch.Tensor): Table resolving ambiguity in cases C16 and C19 + of the DMC configurations. + tet_table (torch.Tensor): Lookup table used in tetrahedralizing the isosurface. + quad_split_1 (torch.Tensor): Indices for splitting a quad into two triangles + along one diagonal. + quad_split_2 (torch.Tensor): Alternative indices for splitting a quad into + two triangles along the other diagonal. + quad_split_train (torch.Tensor): Indices for splitting a quad into four triangles + during training by connecting all edges to their midpoints. + cube_corners (torch.Tensor): Defines the positions of a standard unit cube's + eight corners in 3D space, ordered starting from the origin (0,0,0), + moving along the x-axis, then y-axis, and finally z-axis. + Used as a blueprint for generating a voxel grid. + cube_corners_idx (torch.Tensor): Cube corners indexed as powers of 2, used + to retrieve the case id. + cube_edges (torch.Tensor): Edge connections in a cube, listed in pairs. + Used to retrieve edge vertices in DMC. + edge_dir_table (torch.Tensor): A mapping tensor that associates edge indices with + their corresponding axis. For instance, edge_dir_table[0] = 0 indicates that the + first edge is oriented along the x-axis. + dir_faces_table (torch.Tensor): A tensor that maps the corresponding axis of shared edges + across four adjacent cubes to the shared faces of these cubes. For instance, + dir_faces_table[0] = [5, 4] implies that for four cubes sharing an edge along + the x-axis, the first and second cubes share faces indexed as 5 and 4, respectively. + This tensor is only utilized during isosurface tetrahedralization. + adj_pairs (torch.Tensor): + A tensor containing index pairs that correspond to neighboring cubes that share the same edge. + qef_reg_scale (float): + The scaling factor applied to the regularization loss to prevent issues with singularity + when solving the QEF. This parameter is only used when a 'grad_func' is specified. + weight_scale (float): + The scale of weights in FlexiCubes. Should be between 0 and 1. + """ + + def __init__(self, device="cuda", qef_reg_scale=1e-3, weight_scale=0.99): + + self.device = device + self.dmc_table = torch.tensor(dmc_table, dtype=torch.long, device=device, requires_grad=False) + self.num_vd_table = torch.tensor(num_vd_table, + dtype=torch.long, device=device, requires_grad=False) + self.check_table = torch.tensor( + check_table, + dtype=torch.long, device=device, requires_grad=False) + + self.tet_table = torch.tensor(tet_table, dtype=torch.long, device=device, requires_grad=False) + self.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False) + self.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False) + self.quad_split_train = torch.tensor( + [0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.long, device=device, requires_grad=False) + + self.cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [ + 1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.float, device=device) + self.cube_corners_idx = torch.pow(2, torch.arange(8, requires_grad=False)) + self.cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6, + 2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, device=device, requires_grad=False) + + self.edge_dir_table = torch.tensor([0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1], + dtype=torch.long, device=device) + self.dir_faces_table = torch.tensor([ + [[5, 4], [3, 2], [4, 5], [2, 3]], + [[5, 4], [1, 0], [4, 5], [0, 1]], + [[3, 2], [1, 0], [2, 3], [0, 1]] + ], dtype=torch.long, device=device) + self.adj_pairs = torch.tensor([0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device) + self.qef_reg_scale = qef_reg_scale + self.weight_scale = weight_scale + + def construct_voxel_grid(self, res): + """ + Generates a voxel grid based on the specified resolution. + + Args: + res (int or list[int]): The resolution of the voxel grid. If an integer + is provided, it is used for all three dimensions. If a list or tuple + of 3 integers is provided, they define the resolution for the x, + y, and z dimensions respectively. + + Returns: + (torch.Tensor, torch.Tensor): Returns the vertices and the indices of the + cube corners (index into vertices) of the constructed voxel grid. + The vertices are centered at the origin, with the length of each + dimension in the grid being one. + """ + base_cube_f = torch.arange(8).to(self.device) + if isinstance(res, int): + res = (res, res, res) + voxel_grid_template = torch.ones(res, device=self.device) + + res = torch.tensor([res], dtype=torch.float, device=self.device) + coords = torch.nonzero(voxel_grid_template).float() / res # N, 3 + verts = (self.cube_corners.unsqueeze(0) / res + coords.unsqueeze(1)).reshape(-1, 3) + cubes = (base_cube_f.unsqueeze(0) + + torch.arange(coords.shape[0], device=self.device).unsqueeze(1) * 8).reshape(-1) + + verts_rounded = torch.round(verts * 10**5) / (10**5) + verts_unique, inverse_indices = torch.unique(verts_rounded, dim=0, return_inverse=True) + cubes = inverse_indices[cubes.reshape(-1)].reshape(-1, 8) + + return verts_unique - 0.5, cubes + + def __call__(self, x_nx3, s_n, cube_fx8, res, beta_fx12=None, alpha_fx8=None, + gamma_f=None, training=False, output_tetmesh=False, grad_func=None): + r""" + Main function for mesh extraction from scalar field using FlexiCubes. This function converts + discrete signed distance fields, encoded on voxel grids and additional per-cube parameters, + to triangle or tetrahedral meshes using a differentiable operation as described in + `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_. FlexiCubes enhances + mesh quality and geometric fidelity by adjusting the surface representation based on gradient + optimization. The output surface is differentiable with respect to the input vertex positions, + scalar field values, and weight parameters. + + If you intend to extract a surface mesh from a fixed Signed Distance Field without the + optimization of parameters, it is suggested to provide the "grad_func" which should + return the surface gradient at any given 3D position. When grad_func is provided, the process + to determine the dual vertex position adapts to solve a Quadratic Error Function (QEF), as + described in the `Manifold Dual Contouring`_ paper, and employs an smart splitting strategy. + Please note, this approach is non-differentiable. + + For more details and example usage in optimization, refer to the + `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_ SIGGRAPH 2023 paper. + + Args: + x_nx3 (torch.Tensor): Coordinates of the voxel grid vertices, can be deformed. + s_n (torch.Tensor): Scalar field values at each vertex of the voxel grid. Negative values + denote that the corresponding vertex resides inside the isosurface. This affects + the directions of the extracted triangle faces and volume to be tetrahedralized. + cube_fx8 (torch.Tensor): Indices of 8 vertices for each cube in the voxel grid. + res (int or list[int]): The resolution of the voxel grid. If an integer is provided, it + is used for all three dimensions. If a list or tuple of 3 integers is provided, they + specify the resolution for the x, y, and z dimensions respectively. + beta_fx12 (torch.Tensor, optional): Weight parameters for the cube edges to adjust dual + vertices positioning. Defaults to uniform value for all edges. + alpha_fx8 (torch.Tensor, optional): Weight parameters for the cube corners to adjust dual + vertices positioning. Defaults to uniform value for all vertices. + gamma_f (torch.Tensor, optional): Weight parameters to control the splitting of + quadrilaterals into triangles. Defaults to uniform value for all cubes. + training (bool, optional): If set to True, applies differentiable quad splitting for + training. Defaults to False. + output_tetmesh (bool, optional): If set to True, outputs a tetrahedral mesh, otherwise, + outputs a triangular mesh. Defaults to False. + grad_func (callable, optional): A function to compute the surface gradient at specified + 3D positions (input: Nx3 positions). The function should return gradients as an Nx3 + tensor. If None, the original FlexiCubes algorithm is utilized. Defaults to None. + + Returns: + (torch.Tensor, torch.LongTensor, torch.Tensor): Tuple containing: + - Vertices for the extracted triangular/tetrahedral mesh. + - Faces for the extracted triangular/tetrahedral mesh. + - Regularizer L_dev, computed per dual vertex. + + .. _Flexible Isosurface Extraction for Gradient-Based Mesh Optimization: + https://research.nvidia.com/labs/toronto-ai/flexicubes/ + .. _Manifold Dual Contouring: + https://people.engr.tamu.edu/schaefer/research/dualsimp_tvcg.pdf + """ + + surf_cubes, occ_fx8 = self._identify_surf_cubes(s_n, cube_fx8) + if surf_cubes.sum() == 0: + return torch.zeros( + (0, 3), + device=self.device), torch.zeros( + (0, 4), + dtype=torch.long, device=self.device) if output_tetmesh else torch.zeros( + (0, 3), + dtype=torch.long, device=self.device), torch.zeros( + (0), + device=self.device) + beta_fx12, alpha_fx8, gamma_f = self._normalize_weights(beta_fx12, alpha_fx8, gamma_f, surf_cubes) + + case_ids = self._get_case_id(occ_fx8, surf_cubes, res) + + surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges(s_n, cube_fx8, surf_cubes) + + vd, L_dev, vd_gamma, vd_idx_map = self._compute_vd( + x_nx3, cube_fx8[surf_cubes], surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func) + vertices, faces, s_edges, edge_indices = self._triangulate( + s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func) + if not output_tetmesh: + return vertices, faces, L_dev + else: + vertices, tets = self._tetrahedralize( + x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices, + surf_cubes, training) + return vertices, tets, L_dev + + def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges): + """ + Regularizer L_dev as in Equation 8 + """ + dist = torch.norm(ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1) + mean_l2 = torch.zeros_like(vd[:, 0]) + mean_l2 = (mean_l2).index_add_(0, edge_group_to_vd, dist) / vd_num_edges.squeeze(1).float() + mad = (dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)).abs() + return mad + + def _normalize_weights(self, beta_fx12, alpha_fx8, gamma_f, surf_cubes): + """ + Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones. + """ + n_cubes = surf_cubes.shape[0] + + if beta_fx12 is not None: + beta_fx12 = (torch.tanh(beta_fx12) * self.weight_scale + 1) + else: + beta_fx12 = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device) + + if alpha_fx8 is not None: + alpha_fx8 = (torch.tanh(alpha_fx8) * self.weight_scale + 1) + else: + alpha_fx8 = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device) + + if gamma_f is not None: + gamma_f = torch.sigmoid(gamma_f) * self.weight_scale + (1 - self.weight_scale)/2 + else: + gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device) + + return beta_fx12[surf_cubes], alpha_fx8[surf_cubes], gamma_f[surf_cubes] + + @torch.no_grad() + def _get_case_id(self, occ_fx8, surf_cubes, res): + """ + Obtains the ID of topology cases based on cell corner occupancy. This function resolves the + ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the + supplementary material. It should be noted that this function assumes a regular grid. + """ + case_ids = (occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0)).sum(-1) + + problem_config = self.check_table.to(self.device)[case_ids] + to_check = problem_config[..., 0] == 1 + problem_config = problem_config[to_check] + if not isinstance(res, (list, tuple)): + res = [res, res, res] + + # The 'problematic_configs' only contain configurations for surface cubes. Next, we construct a 3D array, + # 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes). + # This allows efficient checking on adjacent cubes. + problem_config_full = torch.zeros(list(res) + [5], device=self.device, dtype=torch.long) + vol_idx = torch.nonzero(problem_config_full[..., 0] == 0) # N, 3 + vol_idx_problem = vol_idx[surf_cubes][to_check] + problem_config_full[vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]] = problem_config + vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4] + + within_range = ( + vol_idx_problem_adj[..., 0] >= 0) & ( + vol_idx_problem_adj[..., 0] < res[0]) & ( + vol_idx_problem_adj[..., 1] >= 0) & ( + vol_idx_problem_adj[..., 1] < res[1]) & ( + vol_idx_problem_adj[..., 2] >= 0) & ( + vol_idx_problem_adj[..., 2] < res[2]) + + vol_idx_problem = vol_idx_problem[within_range] + vol_idx_problem_adj = vol_idx_problem_adj[within_range] + problem_config = problem_config[within_range] + problem_config_adj = problem_config_full[vol_idx_problem_adj[..., 0], + vol_idx_problem_adj[..., 1], vol_idx_problem_adj[..., 2]] + # If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted. + to_invert = (problem_config_adj[..., 0] == 1) + idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][within_range][to_invert] + case_ids.index_put_((idx,), problem_config[to_invert][..., -1]) + return case_ids + + @torch.no_grad() + def _identify_surf_edges(self, s_n, cube_fx8, surf_cubes): + """ + Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge + can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge + and marks the cube edges with this index. + """ + occ_n = s_n < 0 + all_edges = cube_fx8[surf_cubes][:, self.cube_edges].reshape(-1, 2) + unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True) + + unique_edges = unique_edges.long() + mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 + + surf_edges_mask = mask_edges[_idx_map] + counts = counts[_idx_map] + + mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=cube_fx8.device) * -1 + mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_fx8.device) + # Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index + # for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1. + idx_map = mapping[_idx_map] + surf_edges = unique_edges[mask_edges] + return surf_edges, idx_map, counts, surf_edges_mask + + @torch.no_grad() + def _identify_surf_cubes(self, s_n, cube_fx8): + """ + Identifies grid cubes that intersect with the underlying surface by checking if the signs at + all corners are not identical. + """ + occ_n = s_n < 0 + occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8) + _occ_sum = torch.sum(occ_fx8, -1) + surf_cubes = (_occ_sum > 0) & (_occ_sum < 8) + return surf_cubes, occ_fx8 + + def _linear_interp(self, edges_weight, edges_x): + """ + Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'. + """ + edge_dim = edges_weight.dim() - 2 + assert edges_weight.shape[edge_dim] == 2 + edges_weight = torch.cat([torch.index_select(input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim), - + torch.index_select(input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim)], edge_dim) + denominator = edges_weight.sum(edge_dim) + ue = (edges_x * edges_weight).sum(edge_dim) / denominator + return ue + + def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3=None): + p_bxnx3 = p_bxnx3.reshape(-1, 7, 3) + norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3) + c_bx3 = c_bx3.reshape(-1, 3) + A = norm_bxnx3 + B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True) + + A_reg = (torch.eye(3, device=p_bxnx3.device) * self.qef_reg_scale).unsqueeze(0).repeat(p_bxnx3.shape[0], 1, 1) + B_reg = (self.qef_reg_scale * c_bx3).unsqueeze(-1) + A = torch.cat([A, A_reg], 1) + B = torch.cat([B, B_reg], 1) + dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1) + return dual_verts + + def _compute_vd(self, x_nx3, surf_cubes_fx8, surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func): + """ + Computes the location of dual vertices as described in Section 4.2 + """ + alpha_nx12x2 = torch.index_select(input=alpha_fx8, index=self.cube_edges, dim=1).reshape(-1, 12, 2) + surf_edges_x = torch.index_select(input=x_nx3, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 3) + surf_edges_s = torch.index_select(input=s_n, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1) + zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x) + + idx_map = idx_map.reshape(-1, 12) + num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0) + edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = [], [], [], [], [] + + total_num_vd = 0 + vd_idx_map = torch.zeros((case_ids.shape[0], 12), dtype=torch.long, device=self.device, requires_grad=False) + if grad_func is not None: + normals = torch.nn.functional.normalize(grad_func(zero_crossing), dim=-1) + vd = [] + for num in torch.unique(num_vd): + cur_cubes = (num_vd == num) # consider cubes with the same numbers of vd emitted (for batching) + curr_num_vd = cur_cubes.sum() * num + curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(-1, num * 7) + curr_edge_group_to_vd = torch.arange( + curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) + total_num_vd + total_num_vd += curr_num_vd + curr_edge_group_to_cube = torch.arange(idx_map.shape[0], device=self.device)[ + cur_cubes].unsqueeze(-1).repeat(1, num * 7).reshape_as(curr_edge_group) + + curr_mask = (curr_edge_group != -1) + edge_group.append(torch.masked_select(curr_edge_group, curr_mask)) + edge_group_to_vd.append(torch.masked_select(curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask)) + edge_group_to_cube.append(torch.masked_select(curr_edge_group_to_cube, curr_mask)) + vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True)) + vd_gamma.append(torch.masked_select(gamma_f, cur_cubes).unsqueeze(-1).repeat(1, num).reshape(-1)) + + if grad_func is not None: + with torch.no_grad(): + cube_e_verts_idx = idx_map[cur_cubes] + curr_edge_group[~curr_mask] = 0 + + verts_group_idx = torch.gather(input=cube_e_verts_idx, dim=1, index=curr_edge_group) + verts_group_idx[verts_group_idx == -1] = 0 + verts_group_pos = torch.index_select( + input=zero_crossing, index=verts_group_idx.reshape(-1), dim=0).reshape(-1, num.item(), 7, 3) + v0 = x_nx3[surf_cubes_fx8[cur_cubes][:, 0]].reshape(-1, 1, 1, 3).repeat(1, num.item(), 1, 1) + curr_mask = curr_mask.reshape(-1, num.item(), 7, 1) + verts_centroid = (verts_group_pos * curr_mask).sum(2) / (curr_mask.sum(2)) + + normals_bx7x3 = torch.index_select(input=normals, index=verts_group_idx.reshape(-1), dim=0).reshape( + -1, num.item(), 7, + 3) + curr_mask = curr_mask.squeeze(2) + vd.append(self._solve_vd_QEF((verts_group_pos - v0) * curr_mask, normals_bx7x3 * curr_mask, + verts_centroid - v0.squeeze(2)) + v0.reshape(-1, 3)) + edge_group = torch.cat(edge_group) + edge_group_to_vd = torch.cat(edge_group_to_vd) + edge_group_to_cube = torch.cat(edge_group_to_cube) + vd_num_edges = torch.cat(vd_num_edges) + vd_gamma = torch.cat(vd_gamma) + + if grad_func is not None: + vd = torch.cat(vd) + L_dev = torch.zeros([1], device=self.device) + else: + vd = torch.zeros((total_num_vd, 3), device=self.device) + beta_sum = torch.zeros((total_num_vd, 1), device=self.device) + + idx_group = torch.gather(input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group) + + x_group = torch.index_select(input=surf_edges_x, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 3) + s_group = torch.index_select(input=surf_edges_s, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1) + + zero_crossing_group = torch.index_select( + input=zero_crossing, index=idx_group.reshape(-1), dim=0).reshape(-1, 3) + + alpha_group = torch.index_select(input=alpha_nx12x2.reshape(-1, 2), dim=0, + index=edge_group_to_cube * 12 + edge_group).reshape(-1, 2, 1) + ue_group = self._linear_interp(s_group * alpha_group, x_group) + + beta_group = torch.gather(input=beta_fx12.reshape(-1), dim=0, + index=edge_group_to_cube * 12 + edge_group).reshape(-1, 1) + beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group) + vd = vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum + L_dev = self._compute_reg_loss(vd, zero_crossing_group, edge_group_to_vd, vd_num_edges) + + v_idx = torch.arange(vd.shape[0], device=self.device) # + total_num_vd + + vd_idx_map = (vd_idx_map.reshape(-1)).scatter(dim=0, index=edge_group_to_cube * + 12 + edge_group, src=v_idx[edge_group_to_vd]) + + return vd, L_dev, vd_gamma, vd_idx_map + + def _triangulate(self, s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func): + """ + Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into + triangles based on the gamma parameter, as described in Section 4.3. + """ + with torch.no_grad(): + group_mask = (edge_counts == 4) & surf_edges_mask # surface edges shared by 4 cubes. + group = idx_map.reshape(-1)[group_mask] + vd_idx = vd_idx_map[group_mask] + edge_indices, indices = torch.sort(group, stable=True) + quad_vd_idx = vd_idx[indices].reshape(-1, 4) + + # Ensure all face directions point towards the positive SDF to maintain consistent winding. + s_edges = s_n[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)].reshape(-1, 2) + flip_mask = s_edges[:, 0] > 0 + quad_vd_idx = torch.cat((quad_vd_idx[flip_mask][:, [0, 1, 3, 2]], + quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]])) + if grad_func is not None: + # when grad_func is given, split quadrilaterals along the diagonals with more consistent gradients. + with torch.no_grad(): + vd_gamma = torch.nn.functional.normalize(grad_func(vd), dim=-1) + quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3) + gamma_02 = (quad_gamma[:, 0] * quad_gamma[:, 2]).sum(-1, keepdims=True) + gamma_13 = (quad_gamma[:, 1] * quad_gamma[:, 3]).sum(-1, keepdims=True) + else: + quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4) + gamma_02 = torch.index_select(input=quad_gamma, index=torch.tensor( + 0, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(2, device=self.device), dim=1) + gamma_13 = torch.index_select(input=quad_gamma, index=torch.tensor( + 1, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(3, device=self.device), dim=1) + if not training: + mask = (gamma_02 > gamma_13).squeeze(1) + faces = torch.zeros((quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device) + faces[mask] = quad_vd_idx[mask][:, self.quad_split_1] + faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2] + faces = faces.reshape(-1, 3) + else: + vd_quad = torch.index_select(input=vd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3) + vd_02 = (torch.index_select(input=vd_quad, index=torch.tensor(0, device=self.device), dim=1) + + torch.index_select(input=vd_quad, index=torch.tensor(2, device=self.device), dim=1)) / 2 + vd_13 = (torch.index_select(input=vd_quad, index=torch.tensor(1, device=self.device), dim=1) + + torch.index_select(input=vd_quad, index=torch.tensor(3, device=self.device), dim=1)) / 2 + weight_sum = (gamma_02 + gamma_13) + 1e-8 + vd_center = ((vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) / + weight_sum.unsqueeze(-1)).squeeze(1) + vd_center_idx = torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0] + vd = torch.cat([vd, vd_center]) + faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2) + faces = torch.cat([faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1).reshape(-1, 3) + return vd, faces, s_edges, edge_indices + + def _tetrahedralize( + self, x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices, + surf_cubes, training): + """ + Tetrahedralizes the interior volume to produce a tetrahedral mesh, as described in Section 4.5. + """ + occ_n = s_n < 0 + occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8) + occ_sum = torch.sum(occ_fx8, -1) + + inside_verts = x_nx3[occ_n] + mapping_inside_verts = torch.ones((occ_n.shape[0]), dtype=torch.long, device=self.device) * -1 + mapping_inside_verts[occ_n] = torch.arange(occ_n.sum(), device=self.device) + vertices.shape[0] + """ + For each grid edge connecting two grid vertices with different + signs, we first form a four-sided pyramid by connecting one + of the grid vertices with four mesh vertices that correspond + to the grid edge and then subdivide the pyramid into two tetrahedra + """ + inside_verts_idx = mapping_inside_verts[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1, 2)[ + s_edges < 0]] + if not training: + inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 2).reshape(-1) + else: + inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 4).reshape(-1) + + tets_surface = torch.cat([faces, inside_verts_idx.unsqueeze(-1)], -1) + """ + For each grid edge connecting two grid vertices with the + same sign, the tetrahedron is formed by the two grid vertices + and two vertices in consecutive adjacent cells + """ + inside_cubes = (occ_sum == 8) + inside_cubes_center = x_nx3[cube_fx8[inside_cubes].reshape(-1)].reshape(-1, 8, 3).mean(1) + inside_cubes_center_idx = torch.arange( + inside_cubes_center.shape[0], device=inside_cubes.device) + vertices.shape[0] + inside_verts.shape[0] + + surface_n_inside_cubes = surf_cubes | inside_cubes + edge_center_vertex_idx = torch.ones(((surface_n_inside_cubes).sum(), 13), + dtype=torch.long, device=x_nx3.device) * -1 + surf_cubes = surf_cubes[surface_n_inside_cubes] + inside_cubes = inside_cubes[surface_n_inside_cubes] + edge_center_vertex_idx[surf_cubes, :12] = vd_idx_map.reshape(-1, 12) + edge_center_vertex_idx[inside_cubes, 12] = inside_cubes_center_idx + + all_edges = cube_fx8[surface_n_inside_cubes][:, self.cube_edges].reshape(-1, 2) + unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True) + unique_edges = unique_edges.long() + mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 2 + mask = mask_edges[_idx_map] + counts = counts[_idx_map] + mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=self.device) * -1 + mapping[mask_edges] = torch.arange(mask_edges.sum(), device=self.device) + idx_map = mapping[_idx_map] + + group_mask = (counts == 4) & mask + group = idx_map.reshape(-1)[group_mask] + edge_indices, indices = torch.sort(group) + cube_idx = torch.arange((_idx_map.shape[0] // 12), dtype=torch.long, + device=self.device).unsqueeze(1).expand(-1, 12).reshape(-1)[group_mask] + edge_idx = torch.arange((12), dtype=torch.long, device=self.device).unsqueeze( + 0).expand(_idx_map.shape[0] // 12, -1).reshape(-1)[group_mask] + # Identify the face shared by the adjacent cells. + cube_idx_4 = cube_idx[indices].reshape(-1, 4) + edge_dir = self.edge_dir_table[edge_idx[indices]].reshape(-1, 4)[..., 0] + shared_faces_4x2 = self.dir_faces_table[edge_dir].reshape(-1) + cube_idx_4x2 = cube_idx_4[:, self.adj_pairs].reshape(-1) + # Identify an edge of the face with different signs and + # select the mesh vertex corresponding to the identified edge. + case_ids_expand = torch.ones((surface_n_inside_cubes).sum(), dtype=torch.long, device=x_nx3.device) * 255 + case_ids_expand[surf_cubes] = case_ids + cases = case_ids_expand[cube_idx_4x2] + quad_edge = edge_center_vertex_idx[cube_idx_4x2, self.tet_table[cases, shared_faces_4x2]].reshape(-1, 2) + mask = (quad_edge == -1).sum(-1) == 0 + inside_edge = mapping_inside_verts[unique_edges[mask_edges][edge_indices].reshape(-1)].reshape(-1, 2) + tets_inside = torch.cat([quad_edge, inside_edge], -1)[mask] + + tets = torch.cat([tets_surface, tets_inside]) + vertices = torch.cat([vertices, inside_verts, inside_cubes_center]) + return vertices, tets diff --git a/models/lrm/models/geometry/rep_3d/flexicubes_geometry.py b/models/lrm/models/geometry/rep_3d/flexicubes_geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..14af06fde64bc370dc6c6e79b458aa52758b196e --- /dev/null +++ b/models/lrm/models/geometry/rep_3d/flexicubes_geometry.py @@ -0,0 +1,225 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +import numpy as np +import os +import nvdiffrast.torch as dr +from . import Geometry +from .flexicubes import FlexiCubes # replace later +from .dmtet import sdf_reg_loss_batch +from . import mesh +import torch.nn.functional as F +from models.lrm.utils import render + +def get_center_boundary_index(grid_res, device): + v = torch.zeros((grid_res + 1, grid_res + 1, grid_res + 1), dtype=torch.bool, device=device) + v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = True + center_indices = torch.nonzero(v.reshape(-1)) + + v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = False + v[:2, ...] = True + v[-2:, ...] = True + v[:, :2, ...] = True + v[:, -2:, ...] = True + v[:, :, :2] = True + v[:, :, -2:] = True + boundary_indices = torch.nonzero(v.reshape(-1)) + return center_indices, boundary_indices + +############################################################################### +# Geometry interface +############################################################################### +class FlexiCubesGeometry(Geometry): + def __init__( + self, grid_res=64, scale=2.0, device='cuda', renderer=None, + render_type='neural_render', args=None): + super(FlexiCubesGeometry, self).__init__() + self.grid_res = grid_res + self.device = device + self.args = args + self.fc = FlexiCubes(device, weight_scale=0.5) + self.verts, self.indices = self.fc.construct_voxel_grid(grid_res) + if isinstance(scale, list): + self.verts[:, 0] = self.verts[:, 0] * scale[0] + self.verts[:, 1] = self.verts[:, 1] * scale[1] + self.verts[:, 2] = self.verts[:, 2] * scale[1] + else: + self.verts = self.verts * scale + + all_edges = self.indices[:, self.fc.cube_edges].reshape(-1, 2) + self.all_edges = torch.unique(all_edges, dim=0) + + # Parameters used for fix boundary sdf + self.center_indices, self.boundary_indices = get_center_boundary_index(self.grid_res, device) + self.renderer = renderer + self.render_type = render_type + self.ctx = dr.RasterizeCudaContext(device=device) + + def getAABB(self): + return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values + + @torch.no_grad() + def map_uv(self, face_gidx, max_idx): + N = int(np.ceil(np.sqrt((max_idx+1)//2))) + tex_y, tex_x = torch.meshgrid( + torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device="cuda"), + torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device="cuda") + ) + + pad = 0.9 / N + + uvs = torch.stack([ + tex_x , tex_y, + tex_x + pad, tex_y, + tex_x + pad, tex_y + pad, + tex_x , tex_y + pad + ], dim=-1).view(-1, 2) + + def _idx(tet_idx, N): + x = tet_idx % N + y = torch.div(tet_idx, N, rounding_mode='floor') + return y * N + x + + tet_idx = _idx(torch.div(face_gidx, N, rounding_mode='floor'), N) + tri_idx = face_gidx % 2 + + uv_idx = torch.stack(( + tet_idx * 4, tet_idx * 4 + tri_idx + 1, tet_idx * 4 + tri_idx + 2 + ), dim = -1). view(-1, 3) + + return uvs, uv_idx + + def rotate_x(self, a, device=None): + s, c = np.sin(a), np.cos(a) + return torch.tensor([[1, 0, 0, 0], + [0, c,-s, 0], + [0, s, c, 0], + [0, 0, 0, 1]], dtype=torch.float32, device=device) + def rotate_z(self, a, device=None): + s, c = np.sin(a), np.cos(a) + return torch.tensor([[ c, -s, 0, 0], + [ s, c, 0, 0], + [ 0, 0, 1, 0], + [ 0, 0, 0, 1]], dtype=torch.float32, device=device) + def rotate_y(self, a, device=None): + s, c = np.sin(a), np.cos(a) + return torch.tensor([[ c, 0, s, 0], + [ 0, 1, 0, 0], + [-s, 0, c, 0], + [ 0, 0, 0, 1]], dtype=torch.float32, device=device) + + + def get_mesh(self, v_deformed_nx3, sdf_n, weight_n=None, with_uv=False, indices=None, is_training=False): + if indices is None: + indices = self.indices + + verts, faces, v_reg_loss = self.fc(v_deformed_nx3, sdf_n, indices, self.grid_res, + beta_fx12=weight_n[:, :12], alpha_fx8=weight_n[:, 12:20], + gamma_f=weight_n[:, 20], training=is_training + ) + + face_gidx = torch.arange(faces.shape[0], dtype=torch.long, device="cuda") + uvs, uv_idx = self.map_uv(face_gidx, faces.shape[0]) + # breakpoint() + + verts = verts @ self.rotate_x(np.pi / 2, device=verts.device)[:3,:3] + verts = verts @ self.rotate_y(np.pi / 2, device=verts.device)[:3,:3] + + + + imesh = mesh.Mesh(verts, faces, v_tex=uvs, t_tex_idx=uv_idx) + + # Run mesh operations to generate tangent space + imesh = mesh.auto_normals(imesh) + imesh = mesh.compute_tangents(imesh) + + return verts, faces, v_reg_loss, imesh + + + # def render_mesh(self, mesh_v_nx3, mesh_f_fx3, mesh, camera_mv_bx4x4, camera_pos, resolution=256, hierarchical_mask=False): + # return_value = dict() + # if self.render_type == 'neural_render': + # tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal, gb_normal = self.renderer.render_mesh( + # mesh_v_nx3.unsqueeze(dim=0), + # mesh_f_fx3.int(), + # mesh, + # camera_mv_bx4x4, + # camera_pos, + # mesh_v_nx3.unsqueeze(dim=0), + # resolution=resolution, + # device=self.device, + # hierarchical_mask=hierarchical_mask + # ) + + # return_value['tex_pos'] = tex_pos + # return_value['mask'] = mask + # return_value['hard_mask'] = hard_mask + # return_value['rast'] = rast + # return_value['v_pos_clip'] = v_pos_clip + # return_value['mask_pyramid'] = mask_pyramid + # return_value['depth'] = depth + # return_value['normal'] = normal + # return_value['gb_normal'] = gb_normal + # else: + # raise NotImplementedError + + # return return_value + + def render_mesh(self, mesh_v_nx3, mesh_f_fx3, mesh, camera_mv_bx4x4, camera_pos, env, planes, kd_fn, materials, resolution=256, hierarchical_mask=False): + return_value = dict() + # if self.render_type == 'neural_render': + # tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal, gb_normal = self.renderer.render_mesh( + # mesh_v_nx3.unsqueeze(dim=0), + # mesh_f_fx3.int(), + # mesh, + # camera_mv_bx4x4, + # camera_pos, + # mesh_v_nx3.unsqueeze(dim=0), + # resolution=resolution, + # device=self.device, + # hierarchical_mask=hierarchical_mask + # ) + + # return_value['tex_pos'] = tex_pos + # return_value['mask'] = mask + # return_value['hard_mask'] = hard_mask + # return_value['rast'] = rast + # return_value['v_pos_clip'] = v_pos_clip + # return_value['mask_pyramid'] = mask_pyramid + # return_value['depth'] = depth + # return_value['normal'] = normal + # return_value['gb_normal'] = gb_normal + # else: + # raise NotImplementedError + buffer_dict = render.render_mesh(self.ctx, mesh, camera_mv_bx4x4, camera_pos, env, planes, kd_fn, materials, [resolution, resolution], spp=1, num_layers=1, msaa=True, background=None) + + return buffer_dict + + + def render(self, v_deformed_bxnx3=None, sdf_bxn=None, camera_mv_bxnviewx4x4=None, resolution=256): + # Here I assume a batch of meshes (can be different mesh and geometry), for the other shapes, the batch is 1 + v_list = [] + f_list = [] + n_batch = v_deformed_bxnx3.shape[0] + all_render_output = [] + for i_batch in range(n_batch): + verts_nx3, faces_fx3 = self.get_mesh(v_deformed_bxnx3[i_batch], sdf_bxn[i_batch]) + v_list.append(verts_nx3) + f_list.append(faces_fx3) + render_output = self.render_mesh(verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution) + all_render_output.append(render_output) + + # Concatenate all render output + return_keys = all_render_output[0].keys() + return_value = dict() + for k in return_keys: + value = [v[k] for v in all_render_output] + return_value[k] = value + # We can do concatenation outside of the render + return return_value diff --git a/models/lrm/models/geometry/rep_3d/light.py b/models/lrm/models/geometry/rep_3d/light.py new file mode 100644 index 0000000000000000000000000000000000000000..766ab0a9e4e4fc42f379ac94d765059508cff97e --- /dev/null +++ b/models/lrm/models/geometry/rep_3d/light.py @@ -0,0 +1,158 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +import numpy as np +import torch +import nvdiffrast.torch as dr + +from . import util +from . import renderutils as ru + +###################################################################################### +# Utility functions +###################################################################################### + +class cubemap_mip(torch.autograd.Function): + @staticmethod + def forward(ctx, cubemap): + return util.avg_pool_nhwc(cubemap, (2,2)) + + @staticmethod + def backward(ctx, dout): + res = dout.shape[1] * 2 + out = torch.zeros(6, res, res, dout.shape[-1], dtype=torch.float32, device="cuda") + for s in range(6): + gy, gx = torch.meshgrid(torch.linspace(-1.0 + 1.0 / res, 1.0 - 1.0 / res, res, device="cuda"), + torch.linspace(-1.0 + 1.0 / res, 1.0 - 1.0 / res, res, device="cuda"), + indexing='ij') + v = util.safe_normalize(util.cube_to_dir(s, gx, gy)) + out[s, ...] = dr.texture(dout[None, ...] * 0.25, v[None, ...].contiguous(), filter_mode='linear', boundary_mode='cube') + return out + +###################################################################################### +# Split-sum environment map light source with automatic mipmap generation +###################################################################################### + +class EnvironmentLight(torch.nn.Module): + LIGHT_MIN_RES = 16 + + MIN_ROUGHNESS = 0.08 + MAX_ROUGHNESS = 0.5 + + def __init__(self, base): + super(EnvironmentLight, self).__init__() + self.mtx = None + self.base = torch.nn.Parameter(base.clone().detach(), requires_grad=True) + self.register_parameter('env_base', self.base) + + def xfm(self, mtx): + self.mtx = mtx + + def clone(self): + return EnvironmentLight(self.base.clone().detach()) + + def clamp_(self, min=None, max=None): + self.base.clamp_(min, max) + + def get_mip(self, roughness): + return torch.where(roughness < self.MAX_ROUGHNESS + , (torch.clamp(roughness, self.MIN_ROUGHNESS, self.MAX_ROUGHNESS) - self.MIN_ROUGHNESS) / (self.MAX_ROUGHNESS - self.MIN_ROUGHNESS) * (len(self.specular) - 2) + , (torch.clamp(roughness, self.MAX_ROUGHNESS, 1.0) - self.MAX_ROUGHNESS) / (1.0 - self.MAX_ROUGHNESS) + len(self.specular) - 2) + + def build_mips(self, cutoff=0.99): + self.specular = [self.base] + while self.specular[-1].shape[1] > self.LIGHT_MIN_RES: + self.specular += [cubemap_mip.apply(self.specular[-1])] + + self.diffuse = ru.diffuse_cubemap(self.specular[-1]) + + for idx in range(len(self.specular) - 1): + roughness = (idx / (len(self.specular) - 2)) * (self.MAX_ROUGHNESS - self.MIN_ROUGHNESS) + self.MIN_ROUGHNESS + self.specular[idx] = ru.specular_cubemap(self.specular[idx], roughness, cutoff) + self.specular[-1] = ru.specular_cubemap(self.specular[-1], 1.0, cutoff) + + def regularizer(self): + white = (self.base[..., 0:1] + self.base[..., 1:2] + self.base[..., 2:3]) / 3.0 + return torch.mean(torch.abs(self.base - white)) + + def shade(self, gb_pos, gb_normal, kd, ks, view_pos, specular=True): + wo = util.safe_normalize(view_pos - gb_pos) + + if specular: + roughness = ks[..., 1:2] # y component + metallic = ks[..., 2:3] # z component + spec_col = (1.0 - metallic)*0.04 + kd * metallic + diff_col = kd * (1.0 - metallic) + else: + diff_col = kd + + reflvec = util.safe_normalize(util.reflect(wo, gb_normal)) + nrmvec = gb_normal + if self.mtx is not None: # Rotate lookup + mtx = torch.as_tensor(self.mtx, dtype=torch.float32, device='cuda') + reflvec = ru.xfm_vectors(reflvec.view(reflvec.shape[0], reflvec.shape[1] * reflvec.shape[2], reflvec.shape[3]), mtx).view(*reflvec.shape) + nrmvec = ru.xfm_vectors(nrmvec.view(nrmvec.shape[0], nrmvec.shape[1] * nrmvec.shape[2], nrmvec.shape[3]), mtx).view(*nrmvec.shape) + + # Diffuse lookup + diffuse = dr.texture(self.diffuse[None, ...], nrmvec.contiguous(), filter_mode='linear', boundary_mode='cube') + shaded_col = diffuse * diff_col + + if specular: + # Lookup FG term from lookup texture + NdotV = torch.clamp(util.dot(wo, gb_normal), min=1e-4) + fg_uv = torch.cat((NdotV, roughness), dim=-1) + if not hasattr(self, '_FG_LUT'): + self._FG_LUT = torch.as_tensor(np.fromfile('data/irrmaps/bsdf_256_256.bin', dtype=np.float32).reshape(1, 256, 256, 2), dtype=torch.float32, device='cuda') + fg_lookup = dr.texture(self._FG_LUT, fg_uv, filter_mode='linear', boundary_mode='clamp') + + # Roughness adjusted specular env lookup + miplevel = self.get_mip(roughness) + spec = dr.texture(self.specular[0][None, ...], reflvec.contiguous(), mip=list(m[None, ...] for m in self.specular[1:]), mip_level_bias=miplevel[..., 0], filter_mode='linear-mipmap-linear', boundary_mode='cube') + + # Compute aggregate lighting + reflectance = spec_col * fg_lookup[...,0:1] + fg_lookup[...,1:2] + shaded_col += spec * reflectance + + return shaded_col * (1.0 - ks[..., 0:1]) # Modulate by hemisphere visibility + +###################################################################################### +# Load and store +###################################################################################### + +# Load from latlong .HDR file +def _load_env_hdr(fn, scale=1.0): + latlong_img = torch.tensor(util.load_image(fn), dtype=torch.float32, device='cuda')*scale + cubemap = util.latlong_to_cubemap(latlong_img, [512, 512]) + + l = EnvironmentLight(cubemap) + l.build_mips() + + return l + +def load_env(fn, scale=1.0): + if os.path.splitext(fn)[1].lower() == ".hdr": + return _load_env_hdr(fn, scale) + else: + assert False, "Unknown envlight extension %s" % os.path.splitext(fn)[1] + +def save_env_map(fn, light): + assert isinstance(light, EnvironmentLight), "Can only save EnvironmentLight currently" + if isinstance(light, EnvironmentLight): + color = util.cubemap_to_latlong(light.base, [512, 1024]) + util.save_image_raw(fn, color.detach().cpu().numpy()) + +###################################################################################### +# Create trainable env map with random initialization +###################################################################################### + +def create_trainable_env_rnd(base_res, scale=0.5, bias=0.25): + base = torch.rand(6, base_res, base_res, 3, dtype=torch.float32, device='cuda') * scale + bias + return EnvironmentLight(base) + diff --git a/models/lrm/models/geometry/rep_3d/material.py b/models/lrm/models/geometry/rep_3d/material.py new file mode 100644 index 0000000000000000000000000000000000000000..64772e578493f41e5c94e432d906d9be23325221 --- /dev/null +++ b/models/lrm/models/geometry/rep_3d/material.py @@ -0,0 +1,182 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +import numpy as np +import torch + +from . import util +from . import texture + +###################################################################################### +# Wrapper to make materials behave like a python dict, but register textures as +# torch.nn.Module parameters. +###################################################################################### +class Material(torch.nn.Module): + def __init__(self, mat_dict): + super(Material, self).__init__() + self.mat_keys = set() + for key in mat_dict.keys(): + self.mat_keys.add(key) + self[key] = mat_dict[key] + + def __contains__(self, key): + return hasattr(self, key) + + def __getitem__(self, key): + return getattr(self, key) + + def __setitem__(self, key, val): + self.mat_keys.add(key) + setattr(self, key, val) + + def __delitem__(self, key): + self.mat_keys.remove(key) + delattr(self, key) + + def keys(self): + return self.mat_keys + +###################################################################################### +# .mtl material format loading / storing +###################################################################################### +@torch.no_grad() +def load_mtl(fn, clear_ks=True): + import re + mtl_path = os.path.dirname(fn) + + # Read file + with open(fn, 'r') as f: + lines = f.readlines() + + # Parse materials + materials = [] + for line in lines: + split_line = re.split(' +|\t+|\n+', line.strip()) + prefix = split_line[0].lower() + data = split_line[1:] + if 'newmtl' in prefix: + material = Material({'name' : data[0]}) + materials += [material] + elif materials: + if 'bsdf' in prefix or 'map_kd' in prefix or 'map_ks' in prefix or 'bump' in prefix: + material[prefix] = data[0] + else: + material[prefix] = torch.tensor(tuple(float(d) for d in data), dtype=torch.float32, device='cuda') + + # Convert everything to textures. Our code expects 'kd' and 'ks' to be texture maps. So replace constants with 1x1 maps + for mat in materials: + if not 'bsdf' in mat: + mat['bsdf'] = 'pbr' + + if 'map_kd' in mat: + mat['kd'] = texture.load_texture2D(os.path.join(mtl_path, mat['map_kd'])) + else: + mat['kd'] = texture.Texture2D(mat['kd']) + + if 'map_ks' in mat: + mat['ks'] = texture.load_texture2D(os.path.join(mtl_path, mat['map_ks']), channels=3) + else: + mat['ks'] = texture.Texture2D(mat['ks']) + + if 'bump' in mat: + mat['normal'] = texture.load_texture2D(os.path.join(mtl_path, mat['bump']), lambda_fn=lambda x: x * 2 - 1, channels=3) + + # Convert Kd from sRGB to linear RGB + mat['kd'] = texture.srgb_to_rgb(mat['kd']) + + if clear_ks: + # Override ORM occlusion (red) channel by zeros. We hijack this channel + for mip in mat['ks'].getMips(): + mip[..., 0] = 0.0 + + return materials + +@torch.no_grad() +def save_mtl(fn, material): + folder = os.path.dirname(fn) + with open(fn, "w") as f: + f.write('newmtl defaultMat\n') + if material is not None: + f.write('bsdf %s\n' % material['bsdf']) + if 'kd' in material.keys(): + f.write('map_Kd texture_kd.png\n') + texture.save_texture2D(os.path.join(folder, 'texture_kd.png'), texture.rgb_to_srgb(material['kd'])) + if 'ks' in material.keys(): + f.write('map_Ks texture_ks.png\n') + texture.save_texture2D(os.path.join(folder, 'texture_ks.png'), material['ks']) + if 'normal' in material.keys(): + f.write('bump texture_n.png\n') + texture.save_texture2D(os.path.join(folder, 'texture_n.png'), material['normal'], lambda_fn=lambda x:(util.safe_normalize(x)+1)*0.5) + else: + f.write('Kd 1 1 1\n') + f.write('Ks 0 0 0\n') + f.write('Ka 0 0 0\n') + f.write('Tf 1 1 1\n') + f.write('Ni 1\n') + f.write('Ns 0\n') + +###################################################################################### +# Merge multiple materials into a single uber-material +###################################################################################### + +def _upscale_replicate(x, full_res): + x = x.permute(0, 3, 1, 2) + x = torch.nn.functional.pad(x, (0, full_res[1] - x.shape[3], 0, full_res[0] - x.shape[2]), 'replicate') + return x.permute(0, 2, 3, 1).contiguous() + +def merge_materials(materials, texcoords, tfaces, mfaces): + assert len(materials) > 0 + for mat in materials: + assert mat['bsdf'] == materials[0]['bsdf'], "All materials must have the same BSDF (uber shader)" + assert ('normal' in mat) is ('normal' in materials[0]), "All materials must have either normal map enabled or disabled" + + uber_material = Material({ + 'name' : 'uber_material', + 'bsdf' : materials[0]['bsdf'], + }) + + textures = ['kd', 'ks', 'normal'] + + # Find maximum texture resolution across all materials and textures + max_res = None + for mat in materials: + for tex in textures: + tex_res = np.array(mat[tex].getRes()) if tex in mat else np.array([1, 1]) + max_res = np.maximum(max_res, tex_res) if max_res is not None else tex_res + + # Compute size of compund texture and round up to nearest PoT + full_res = 2**np.ceil(np.log2(max_res * np.array([1, len(materials)]))).astype(np.int) + + # Normalize texture resolution across all materials & combine into a single large texture + for tex in textures: + if tex in materials[0]: + tex_data = torch.cat(tuple(util.scale_img_nhwc(mat[tex].data, tuple(max_res)) for mat in materials), dim=2) # Lay out all textures horizontally, NHWC so dim2 is x + tex_data = _upscale_replicate(tex_data, full_res) + uber_material[tex] = texture.Texture2D(tex_data) + + # Compute scaling values for used / unused texture area + s_coeff = [full_res[0] / max_res[0], full_res[1] / max_res[1]] + + # Recompute texture coordinates to cooincide with new composite texture + new_tverts = {} + new_tverts_data = [] + for fi in range(len(tfaces)): + matIdx = mfaces[fi] + for vi in range(3): + ti = tfaces[fi][vi] + if not (ti in new_tverts): + new_tverts[ti] = {} + if not (matIdx in new_tverts[ti]): # create new vertex + new_tverts_data.append([(matIdx + texcoords[ti][0]) / s_coeff[1], texcoords[ti][1] / s_coeff[0]]) # Offset texture coodrinate (x direction) by material id & scale to local space. Note, texcoords are (u,v) but texture is stored (w,h) so the indexes swap here + new_tverts[ti][matIdx] = len(new_tverts_data) - 1 + tfaces[fi][vi] = new_tverts[ti][matIdx] # reindex vertex + + return uber_material, new_tverts_data, tfaces + diff --git a/models/lrm/models/geometry/rep_3d/mesh.py b/models/lrm/models/geometry/rep_3d/mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..2009b8b938dc251586fbd665bff716f11cf9616b --- /dev/null +++ b/models/lrm/models/geometry/rep_3d/mesh.py @@ -0,0 +1,238 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +import numpy as np +import torch + +from . import obj +from . import util + +###################################################################################### +# Base mesh class +###################################################################################### +class Mesh: + def __init__(self, v_pos=None, t_pos_idx=None, v_nrm=None, t_nrm_idx=None, v_tex=None, t_tex_idx=None, v_tng=None, t_tng_idx=None, material=None, base=None): + self.v_pos = v_pos + self.v_nrm = v_nrm + self.v_tex = v_tex + self.v_tng = v_tng + self.t_pos_idx = t_pos_idx + self.t_nrm_idx = t_nrm_idx + self.t_tex_idx = t_tex_idx + self.t_tng_idx = t_tng_idx + self.material = material + + if base is not None: + self.copy_none(base) + + def copy_none(self, other): + if self.v_pos is None: + self.v_pos = other.v_pos + if self.t_pos_idx is None: + self.t_pos_idx = other.t_pos_idx + if self.v_nrm is None: + self.v_nrm = other.v_nrm + if self.t_nrm_idx is None: + self.t_nrm_idx = other.t_nrm_idx + if self.v_tex is None: + self.v_tex = other.v_tex + if self.t_tex_idx is None: + self.t_tex_idx = other.t_tex_idx + if self.v_tng is None: + self.v_tng = other.v_tng + if self.t_tng_idx is None: + self.t_tng_idx = other.t_tng_idx + if self.material is None: + self.material = other.material + + def clone(self): + out = Mesh(base=self) + if out.v_pos is not None: + out.v_pos = out.v_pos.clone().detach() + if out.t_pos_idx is not None: + out.t_pos_idx = out.t_pos_idx.clone().detach() + if out.v_nrm is not None: + out.v_nrm = out.v_nrm.clone().detach() + if out.t_nrm_idx is not None: + out.t_nrm_idx = out.t_nrm_idx.clone().detach() + if out.v_tex is not None: + out.v_tex = out.v_tex.clone().detach() + if out.t_tex_idx is not None: + out.t_tex_idx = out.t_tex_idx.clone().detach() + if out.v_tng is not None: + out.v_tng = out.v_tng.clone().detach() + if out.t_tng_idx is not None: + out.t_tng_idx = out.t_tng_idx.clone().detach() + return out + +###################################################################################### +# Mesh loeading helper +###################################################################################### + +def load_mesh(filename, mtl_override=None): + name, ext = os.path.splitext(filename) + if ext == ".obj": + return obj.load_obj(filename, clear_ks=True, mtl_override=mtl_override) + assert False, "Invalid mesh file extension" + +###################################################################################### +# Compute AABB +###################################################################################### +def aabb(mesh): + return torch.min(mesh.v_pos, dim=0).values, torch.max(mesh.v_pos, dim=0).values + +###################################################################################### +# Compute unique edge list from attribute/vertex index list +###################################################################################### +def compute_edges(attr_idx, return_inverse=False): + with torch.no_grad(): + # Create all edges, packed by triangle + all_edges = torch.cat(( + torch.stack((attr_idx[:, 0], attr_idx[:, 1]), dim=-1), + torch.stack((attr_idx[:, 1], attr_idx[:, 2]), dim=-1), + torch.stack((attr_idx[:, 2], attr_idx[:, 0]), dim=-1), + ), dim=-1).view(-1, 2) + + # Swap edge order so min index is always first + order = (all_edges[:, 0] > all_edges[:, 1]).long().unsqueeze(dim=1) + sorted_edges = torch.cat(( + torch.gather(all_edges, 1, order), + torch.gather(all_edges, 1, 1 - order) + ), dim=-1) + + # Eliminate duplicates and return inverse mapping + return torch.unique(sorted_edges, dim=0, return_inverse=return_inverse) + +###################################################################################### +# Compute unique edge to face mapping from attribute/vertex index list +###################################################################################### +def compute_edge_to_face_mapping(attr_idx, return_inverse=False): + with torch.no_grad(): + # Get unique edges + # Create all edges, packed by triangle + all_edges = torch.cat(( + torch.stack((attr_idx[:, 0], attr_idx[:, 1]), dim=-1), + torch.stack((attr_idx[:, 1], attr_idx[:, 2]), dim=-1), + torch.stack((attr_idx[:, 2], attr_idx[:, 0]), dim=-1), + ), dim=-1).view(-1, 2) + + # Swap edge order so min index is always first + order = (all_edges[:, 0] > all_edges[:, 1]).long().unsqueeze(dim=1) + sorted_edges = torch.cat(( + torch.gather(all_edges, 1, order), + torch.gather(all_edges, 1, 1 - order) + ), dim=-1) + + # Elliminate duplicates and return inverse mapping + unique_edges, idx_map = torch.unique(sorted_edges, dim=0, return_inverse=True) + + tris = torch.arange(attr_idx.shape[0]).repeat_interleave(3).cuda() + + tris_per_edge = torch.zeros((unique_edges.shape[0], 2), dtype=torch.int64).cuda() + + # Compute edge to face table + mask0 = order[:,0] == 0 + mask1 = order[:,0] == 1 + tris_per_edge[idx_map[mask0], 0] = tris[mask0] + tris_per_edge[idx_map[mask1], 1] = tris[mask1] + + return tris_per_edge + +###################################################################################### +# Align base mesh to reference mesh:move & rescale to match bounding boxes. +###################################################################################### +def unit_size(mesh): + with torch.no_grad(): + vmin, vmax = aabb(mesh) + scale = 2 / torch.max(vmax - vmin).item() + v_pos = mesh.v_pos - (vmax + vmin) / 2 # Center mesh on origin + v_pos = v_pos * scale # Rescale to unit size + + return Mesh(v_pos, base=mesh) + +###################################################################################### +# Center & scale mesh for rendering +###################################################################################### +def center_by_reference(base_mesh, ref_aabb, scale): + center = (ref_aabb[0] + ref_aabb[1]) * 0.5 + scale = scale / torch.max(ref_aabb[1] - ref_aabb[0]).item() + v_pos = (base_mesh.v_pos - center[None, ...]) * scale + return Mesh(v_pos, base=base_mesh) + +###################################################################################### +# Simple smooth vertex normal computation +###################################################################################### +def auto_normals(imesh): + + i0 = imesh.t_pos_idx[:, 0] + i1 = imesh.t_pos_idx[:, 1] + i2 = imesh.t_pos_idx[:, 2] + + v0 = imesh.v_pos[i0, :] + v1 = imesh.v_pos[i1, :] + v2 = imesh.v_pos[i2, :] + + face_normals = torch.cross(v1 - v0, v2 - v0) + + # Splat face normals to vertices + v_nrm = torch.zeros_like(imesh.v_pos) + v_nrm.scatter_add_(0, i0[:, None].repeat(1,3), face_normals) + v_nrm.scatter_add_(0, i1[:, None].repeat(1,3), face_normals) + v_nrm.scatter_add_(0, i2[:, None].repeat(1,3), face_normals) + + # Normalize, replace zero (degenerated) normals with some default value + v_nrm = torch.where(util.dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device='cuda')) + v_nrm = util.safe_normalize(v_nrm) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(v_nrm)) + + return Mesh(v_nrm=v_nrm, t_nrm_idx=imesh.t_pos_idx, base=imesh) + +###################################################################################### +# Compute tangent space from texture map coordinates +# Follows http://www.mikktspace.com/ conventions +###################################################################################### +def compute_tangents(imesh): + vn_idx = [None] * 3 + pos = [None] * 3 + tex = [None] * 3 + for i in range(0,3): + pos[i] = imesh.v_pos[imesh.t_pos_idx[:, i]] + tex[i] = imesh.v_tex[imesh.t_tex_idx[:, i]] + vn_idx[i] = imesh.t_nrm_idx[:, i] + + tangents = torch.zeros_like(imesh.v_nrm) + + # Compute tangent space for each triangle + uve1 = tex[1] - tex[0] + uve2 = tex[2] - tex[0] + pe1 = pos[1] - pos[0] + pe2 = pos[2] - pos[0] + + nom = (pe1 * uve2[..., 1:2] - pe2 * uve1[..., 1:2]) + denom = (uve1[..., 0:1] * uve2[..., 1:2] - uve1[..., 1:2] * uve2[..., 0:1]) + + # Avoid division by zero for degenerated texture coordinates + tang = nom / torch.where(denom > 0.0, torch.clamp(denom, min=1e-6), torch.clamp(denom, max=-1e-6)) + + # Update all 3 vertices + for i in range(0,3): + idx = vn_idx[i][:, None].repeat(1,3) + tangents.scatter_add_(0, idx, tang) # tangents[n_i] = tangents[n_i] + tang + + # Normalize and make sure tangent is perpendicular to normal + tangents = util.safe_normalize(tangents) + tangents = util.safe_normalize(tangents - util.dot(tangents, imesh.v_nrm) * imesh.v_nrm) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(tangents)) + + return Mesh(v_tng=tangents, t_tng_idx=imesh.t_nrm_idx, base=imesh) diff --git a/models/lrm/models/geometry/rep_3d/obj.py b/models/lrm/models/geometry/rep_3d/obj.py new file mode 100644 index 0000000000000000000000000000000000000000..a33fbb9e66c69706ad39049e2ea8e5a7c425971c --- /dev/null +++ b/models/lrm/models/geometry/rep_3d/obj.py @@ -0,0 +1,176 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +import torch + +from . import texture +from . import mesh +from . import material + +###################################################################################### +# Utility functions +###################################################################################### + +def _find_mat(materials, name): + for mat in materials: + if mat['name'] == name: + return mat + return materials[0] # Materials 0 is the default + +###################################################################################### +# Create mesh object from objfile +###################################################################################### + +def load_obj(filename, clear_ks=True, mtl_override=None): + obj_path = os.path.dirname(filename) + + # Read entire file + with open(filename, 'r') as f: + lines = f.readlines() + + # Load materials + all_materials = [ + { + 'name' : '_default_mat', + 'bsdf' : 'pbr', + 'kd' : texture.Texture2D(torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32, device='cuda')), + 'ks' : texture.Texture2D(torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device='cuda')) + } + ] + if mtl_override is None: + for line in lines: + if len(line.split()) == 0: + continue + if line.split()[0] == 'mtllib': + all_materials += material.load_mtl(os.path.join(obj_path, line.split()[1]), clear_ks) # Read in entire material library + else: + all_materials += material.load_mtl(mtl_override) + + # load vertices + vertices, texcoords, normals = [], [], [] + for line in lines: + if len(line.split()) == 0: + continue + + prefix = line.split()[0].lower() + if prefix == 'v': + vertices.append([float(v) for v in line.split()[1:]]) + elif prefix == 'vt': + val = [float(v) for v in line.split()[1:]] + texcoords.append([val[0], 1.0 - val[1]]) + elif prefix == 'vn': + normals.append([float(v) for v in line.split()[1:]]) + + # load faces + activeMatIdx = None + used_materials = [] + faces, tfaces, nfaces, mfaces = [], [], [], [] + for line in lines: + if len(line.split()) == 0: + continue + + prefix = line.split()[0].lower() + if prefix == 'usemtl': # Track used materials + mat = _find_mat(all_materials, line.split()[1]) + if not mat in used_materials: + used_materials.append(mat) + activeMatIdx = used_materials.index(mat) + elif prefix == 'f': # Parse face + vs = line.split()[1:] + nv = len(vs) + vv = vs[0].split('/') + v0 = int(vv[0]) - 1 + t0 = int(vv[1]) - 1 if vv[1] != "" else -1 + n0 = int(vv[2]) - 1 if vv[2] != "" else -1 + for i in range(nv - 2): # Triangulate polygons + vv = vs[i + 1].split('/') + v1 = int(vv[0]) - 1 + t1 = int(vv[1]) - 1 if vv[1] != "" else -1 + n1 = int(vv[2]) - 1 if vv[2] != "" else -1 + vv = vs[i + 2].split('/') + v2 = int(vv[0]) - 1 + t2 = int(vv[1]) - 1 if vv[1] != "" else -1 + n2 = int(vv[2]) - 1 if vv[2] != "" else -1 + mfaces.append(activeMatIdx) + faces.append([v0, v1, v2]) + tfaces.append([t0, t1, t2]) + nfaces.append([n0, n1, n2]) + assert len(tfaces) == len(faces) and len(nfaces) == len (faces) + + # Create an "uber" material by combining all textures into a larger texture + if len(used_materials) > 1: + uber_material, texcoords, tfaces = material.merge_materials(used_materials, texcoords, tfaces, mfaces) + else: + uber_material = used_materials[0] + + vertices = torch.tensor(vertices, dtype=torch.float32, device='cuda') + texcoords = torch.tensor(texcoords, dtype=torch.float32, device='cuda') if len(texcoords) > 0 else None + normals = torch.tensor(normals, dtype=torch.float32, device='cuda') if len(normals) > 0 else None + + faces = torch.tensor(faces, dtype=torch.int64, device='cuda') + tfaces = torch.tensor(tfaces, dtype=torch.int64, device='cuda') if texcoords is not None else None + nfaces = torch.tensor(nfaces, dtype=torch.int64, device='cuda') if normals is not None else None + + return mesh.Mesh(vertices, faces, normals, nfaces, texcoords, tfaces, material=uber_material) + +###################################################################################### +# Save mesh object to objfile +###################################################################################### + +def write_obj(folder, mesh, save_material=True): + obj_file = os.path.join(folder, 'mesh.obj') + print("Writing mesh: ", obj_file) + with open(obj_file, "w") as f: + f.write("mtllib mesh.mtl\n") + f.write("g default\n") + + v_pos = mesh.v_pos.detach().cpu().numpy() if mesh.v_pos is not None else None + v_nrm = mesh.v_nrm.detach().cpu().numpy() if mesh.v_nrm is not None else None + v_tex = mesh.v_tex.detach().cpu().numpy() if mesh.v_tex is not None else None + + t_pos_idx = mesh.t_pos_idx.detach().cpu().numpy() if mesh.t_pos_idx is not None else None + t_nrm_idx = mesh.t_nrm_idx.detach().cpu().numpy() if mesh.t_nrm_idx is not None else None + t_tex_idx = mesh.t_tex_idx.detach().cpu().numpy() if mesh.t_tex_idx is not None else None + + print(" writing %d vertices" % len(v_pos)) + for v in v_pos: + f.write('v {} {} {} \n'.format(v[0], v[1], v[2])) + + if v_tex is not None: + print(" writing %d texcoords" % len(v_tex)) + assert(len(t_pos_idx) == len(t_tex_idx)) + for v in v_tex: + f.write('vt {} {} \n'.format(v[0], 1.0 - v[1])) + + if v_nrm is not None: + print(" writing %d normals" % len(v_nrm)) + assert(len(t_pos_idx) == len(t_nrm_idx)) + for v in v_nrm: + f.write('vn {} {} {}\n'.format(v[0], v[1], v[2])) + + # faces + f.write("s 1 \n") + f.write("g pMesh1\n") + f.write("usemtl defaultMat\n") + + # Write faces + print(" writing %d faces" % len(t_pos_idx)) + for i in range(len(t_pos_idx)): + f.write("f ") + for j in range(3): + f.write(' %s/%s/%s' % (str(t_pos_idx[i][j]+1), '' if v_tex is None else str(t_tex_idx[i][j]+1), '' if v_nrm is None else str(t_nrm_idx[i][j]+1))) + f.write("\n") + + if save_material: + mtl_file = os.path.join(folder, 'mesh.mtl') + print("Writing material: ", mtl_file) + material.save_mtl(mtl_file, mesh.material) + + print("Done exporting mesh") diff --git a/models/lrm/models/geometry/rep_3d/tables.py b/models/lrm/models/geometry/rep_3d/tables.py new file mode 100644 index 0000000000000000000000000000000000000000..5873e7727b5595a1e4fbc3bd10ae5be8f3d06cca --- /dev/null +++ b/models/lrm/models/geometry/rep_3d/tables.py @@ -0,0 +1,791 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. +dmc_table = [ +[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 7, 8, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 9, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 9, 10, 11, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 8, 10, 11, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 8, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 8, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 6, 7, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 5, 6, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 8, 9, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 8, 11, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 6, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1]], +[[0, 2, 4, 5, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 8, 10, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 8, 9, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 6, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 6, 7, 8, 10, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 6, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 6, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 9, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 8, 9, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1]], +[[1, 2, 4, 7, 9, 11, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 9, 10, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 8, 9, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 6, 7, 8, 10, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 6, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 9, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, 6, 7, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 6, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 6, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 6, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 4, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 4, 5, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 5, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]] +] +num_vd_table = [0, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 3, 1, 2, 2, +2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 1, 2, 3, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, +1, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 2, 3, 2, 2, 1, 1, 1, 1, +1, 1, 2, 1, 1, 1, 2, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 2, 2, 2, 2, 1, 3, 4, 2, +2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2, +3, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 3, 2, 3, 2, 4, 2, 2, 2, 2, 1, 2, 1, 2, 1, 1, +2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, +1, 2, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, +1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, +1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0] +check_table = [ +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 194], +[1, -1, 0, 0, 193], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 164], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 161], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 152], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 145], +[1, 0, 0, 1, 144], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 137], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 133], +[1, 0, 1, 0, 132], +[1, 1, 0, 0, 131], +[1, 1, 0, 0, 130], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 100], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 98], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 96], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 88], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 82], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 74], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 72], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 70], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 67], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 65], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 56], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 52], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 44], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 40], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 38], +[1, 0, -1, 0, 37], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 33], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 28], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 26], +[1, 0, 0, -1, 25], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 20], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 18], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 9], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 6], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0] +] +tet_table = [ +[-1, -1, -1, -1, -1, -1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[4, 4, 4, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, -1], +[1, 1, 1, 1, 1, 1], +[4, 4, 4, 4, 4, 4], +[0, 4, 0, 4, 4, -1], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[5, 5, 5, 5, 5, 5], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, -1, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, -1, 2, 4, 4, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 4, 4, 2], +[1, 1, 1, 1, 1, 1], +[2, 4, 2, 4, 4, 2], +[0, 4, 0, 4, 4, 0], +[2, 0, 2, 0, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 5, 2, 5, 5, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 0, 0, 2], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[0, 1, 1, -1, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[4, 1, 1, 4, 4, 1], +[0, 1, 1, 0, 0, 1], +[4, 0, 0, 4, 4, 0], +[2, 2, 2, 2, 2, 2], +[-1, 1, 1, 4, 4, 1], +[0, 1, 1, 4, 4, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[5, 1, 1, 5, 5, 1], +[0, 1, 1, 0, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[8, 8, 8, 8, 8, 8], +[1, 1, 1, 4, 4, 1], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 4, 4, 1], +[0, 4, 0, 4, 4, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 5, 5, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[5, 5, 5, 5, 5, 5], +[6, 6, 6, 6, 6, 6], +[6, -1, 0, 6, 0, 6], +[6, 0, 0, 6, 0, 6], +[6, 1, 1, 6, 1, 6], +[4, 4, 4, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[6, 4, -1, 6, 4, 6], +[6, 4, 0, 6, 4, 6], +[6, 0, 0, 6, 0, 6], +[6, 1, 1, 6, 1, 6], +[5, 5, 5, 5, 5, 5], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 2, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[2, 4, 2, 2, 4, 2], +[0, 4, 0, 4, 4, 0], +[2, 0, 2, 2, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[6, 1, 1, 6, -1, 6], +[6, 1, 1, 6, 0, 6], +[6, 0, 0, 6, 0, 6], +[6, 2, 2, 6, 2, 6], +[4, 1, 1, 4, 4, 1], +[0, 1, 1, 0, 0, 1], +[4, 0, 0, 4, 4, 4], +[2, 2, 2, 2, 2, 2], +[6, 1, 1, 6, 4, 6], +[6, 1, 1, 6, 4, 6], +[6, 0, 0, 6, 0, 6], +[6, 2, 2, 6, 2, 6], +[5, 1, 1, 5, 5, 1], +[0, 1, 1, 0, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[6, 6, 6, 6, 6, 6], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 4, 1], +[0, 4, 0, 4, 4, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 5, 0, 5, 0, 5], +[5, 5, 5, 5, 5, 5], +[5, 5, 5, 5, 5, 5], +[0, 5, 0, 5, 0, 5], +[-1, 5, 0, 5, 0, 5], +[1, 5, 1, 5, 1, 5], +[4, 5, -1, 5, 4, 5], +[0, 5, 0, 5, 0, 5], +[4, 5, 0, 5, 4, 5], +[1, 5, 1, 5, 1, 5], +[4, 4, 4, 4, 4, 4], +[0, 4, 0, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[6, 6, 6, 6, 6, 6], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 5, 2, 5, -1, 5], +[0, 5, 0, 5, 0, 5], +[2, 5, 2, 5, 0, 5], +[1, 5, 1, 5, 1, 5], +[2, 5, 2, 5, 4, 5], +[0, 5, 0, 5, 0, 5], +[2, 5, 2, 5, 4, 5], +[1, 5, 1, 5, 1, 5], +[2, 4, 2, 4, 4, 2], +[0, 4, 0, 4, 4, 4], +[2, 0, 2, 0, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 6, 2, 6, 6, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 0, 0, 2], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[0, 1, 1, 1, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[4, 1, 1, 1, 4, 1], +[0, 1, 1, 1, 0, 1], +[4, 0, 0, 4, 4, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[5, 5, 5, 5, 5, 5], +[1, 1, 1, 1, 4, 1], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[6, 0, 0, 6, 0, 6], +[0, 0, 0, 0, 0, 0], +[6, 6, 6, 6, 6, 6], +[5, 5, 5, 5, 5, 5], +[5, 5, 0, 5, 0, 5], +[5, 5, 0, 5, 0, 5], +[5, 5, 1, 5, 1, 5], +[4, 4, 4, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[4, 4, 0, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[4, 4, 4, 4, 4, 4], +[4, 4, 0, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[8, 8, 8, 8, 8, 8], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 1, 1, 4, 4, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[2, 4, 2, 4, 4, 2], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[5, 5, 5, 5, 5, 5], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[12, 12, 12, 12, 12, 12] +] diff --git a/models/lrm/models/geometry/rep_3d/texture.py b/models/lrm/models/geometry/rep_3d/texture.py new file mode 100644 index 0000000000000000000000000000000000000000..4e4a39d042dc4d356c47133efee897088b9ce5c6 --- /dev/null +++ b/models/lrm/models/geometry/rep_3d/texture.py @@ -0,0 +1,186 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +import numpy as np +import torch +import nvdiffrast.torch as dr + +from . import util + +###################################################################################### +# Smooth pooling / mip computation with linear gradient upscaling +###################################################################################### + +class texture2d_mip(torch.autograd.Function): + @staticmethod + def forward(ctx, texture): + return util.avg_pool_nhwc(texture, (2,2)) + + @staticmethod + def backward(ctx, dout): + gy, gx = torch.meshgrid(torch.linspace(0.0 + 0.25 / dout.shape[1], 1.0 - 0.25 / dout.shape[1], dout.shape[1]*2, device="cuda"), + torch.linspace(0.0 + 0.25 / dout.shape[2], 1.0 - 0.25 / dout.shape[2], dout.shape[2]*2, device="cuda"), + indexing='ij') + uv = torch.stack((gx, gy), dim=-1) + return dr.texture(dout * 0.25, uv[None, ...].contiguous(), filter_mode='linear', boundary_mode='clamp') + +######################################################################################################## +# Simple texture class. A texture can be either +# - A 3D tensor (using auto mipmaps) +# - A list of 3D tensors (full custom mip hierarchy) +######################################################################################################## + +class Texture2D(torch.nn.Module): + # Initializes a texture from image data. + # Input can be constant value (1D array) or texture (3D array) or mip hierarchy (list of 3d arrays) + def __init__(self, init, min_max=None): + super(Texture2D, self).__init__() + + if isinstance(init, np.ndarray): + init = torch.tensor(init, dtype=torch.float32, device='cuda') + elif isinstance(init, list) and len(init) == 1: + init = init[0] + + if isinstance(init, list): + self.data = list(torch.nn.Parameter(mip.clone().detach(), requires_grad=True) for mip in init) + elif len(init.shape) == 4: + self.data = torch.nn.Parameter(init.clone().detach(), requires_grad=True) + elif len(init.shape) == 3: + self.data = torch.nn.Parameter(init[None, ...].clone().detach(), requires_grad=True) + elif len(init.shape) == 1: + self.data = torch.nn.Parameter(init[None, None, None, :].clone().detach(), requires_grad=True) # Convert constant to 1x1 tensor + else: + assert False, "Invalid texture object" + + self.min_max = min_max + + # Filtered (trilinear) sample texture at a given location + def sample(self, texc, texc_deriv, filter_mode='linear-mipmap-linear'): + if isinstance(self.data, list): + out = dr.texture(self.data[0], texc, texc_deriv, mip=self.data[1:], filter_mode=filter_mode) + else: + if self.data.shape[1] > 1 and self.data.shape[2] > 1: + mips = [self.data] + while mips[-1].shape[1] > 1 and mips[-1].shape[2] > 1: + mips += [texture2d_mip.apply(mips[-1])] + out = dr.texture(mips[0], texc, texc_deriv, mip=mips[1:], filter_mode=filter_mode) + else: + out = dr.texture(self.data, texc, texc_deriv, filter_mode=filter_mode) + return out + + def getRes(self): + return self.getMips()[0].shape[1:3] + + def getChannels(self): + return self.getMips()[0].shape[3] + + def getMips(self): + if isinstance(self.data, list): + return self.data + else: + return [self.data] + + # In-place clamp with no derivative to make sure values are in valid range after training + def clamp_(self): + if self.min_max is not None: + for mip in self.getMips(): + for i in range(mip.shape[-1]): + mip[..., i].clamp_(min=self.min_max[0][i], max=self.min_max[1][i]) + + # In-place clamp with no derivative to make sure values are in valid range after training + def normalize_(self): + with torch.no_grad(): + for mip in self.getMips(): + mip = util.safe_normalize(mip) + +######################################################################################################## +# Helper function to create a trainable texture from a regular texture. The trainable weights are +# initialized with texture data as an initial guess +######################################################################################################## + +def create_trainable(init, res=None, auto_mipmaps=True, min_max=None): + with torch.no_grad(): + if isinstance(init, Texture2D): + assert isinstance(init.data, torch.Tensor) + min_max = init.min_max if min_max is None else min_max + init = init.data + elif isinstance(init, np.ndarray): + init = torch.tensor(init, dtype=torch.float32, device='cuda') + + # Pad to NHWC if needed + if len(init.shape) == 1: # Extend constant to NHWC tensor + init = init[None, None, None, :] + elif len(init.shape) == 3: + init = init[None, ...] + + # Scale input to desired resolution. + if res is not None: + init = util.scale_img_nhwc(init, res) + + # Genreate custom mipchain + if not auto_mipmaps: + mip_chain = [init.clone().detach().requires_grad_(True)] + while mip_chain[-1].shape[1] > 1 or mip_chain[-1].shape[2] > 1: + new_size = [max(mip_chain[-1].shape[1] // 2, 1), max(mip_chain[-1].shape[2] // 2, 1)] + mip_chain += [util.scale_img_nhwc(mip_chain[-1], new_size)] + return Texture2D(mip_chain, min_max=min_max) + else: + return Texture2D(init, min_max=min_max) + +######################################################################################################## +# Convert texture to and from SRGB +######################################################################################################## + +def srgb_to_rgb(texture): + return Texture2D(list(util.srgb_to_rgb(mip) for mip in texture.getMips())) + +def rgb_to_srgb(texture): + return Texture2D(list(util.rgb_to_srgb(mip) for mip in texture.getMips())) + +######################################################################################################## +# Utility functions for loading / storing a texture +######################################################################################################## + +def _load_mip2D(fn, lambda_fn=None, channels=None): + imgdata = torch.tensor(util.load_image(fn), dtype=torch.float32, device='cuda') + if channels is not None: + imgdata = imgdata[..., 0:channels] + if lambda_fn is not None: + imgdata = lambda_fn(imgdata) + return imgdata.detach().clone() + +def load_texture2D(fn, lambda_fn=None, channels=None): + base, ext = os.path.splitext(fn) + if os.path.exists(base + "_0" + ext): + mips = [] + while os.path.exists(base + ("_%d" % len(mips)) + ext): + mips += [_load_mip2D(base + ("_%d" % len(mips)) + ext, lambda_fn, channels)] + return Texture2D(mips) + else: + return Texture2D(_load_mip2D(fn, lambda_fn, channels)) + +def _save_mip2D(fn, mip, mipidx, lambda_fn): + if lambda_fn is not None: + data = lambda_fn(mip).detach().cpu().numpy() + else: + data = mip.detach().cpu().numpy() + + if mipidx is None: + util.save_image(fn, data) + else: + base, ext = os.path.splitext(fn) + util.save_image(base + ("_%d" % mipidx) + ext, data) + +def save_texture2D(fn, tex, lambda_fn=None): + if isinstance(tex.data, list): + for i, mip in enumerate(tex.data): + _save_mip2D(fn, mip[0,...], i, lambda_fn) + else: + _save_mip2D(fn, tex.data[0,...], None, lambda_fn) diff --git a/models/lrm/models/geometry/rep_3d/util.py b/models/lrm/models/geometry/rep_3d/util.py new file mode 100644 index 0000000000000000000000000000000000000000..c4e512ad110849ec3ed6344b53f9c422fc303096 --- /dev/null +++ b/models/lrm/models/geometry/rep_3d/util.py @@ -0,0 +1,466 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +import numpy as np +import torch +import nvdiffrast.torch as dr +import imageio + +#---------------------------------------------------------------------------- +# Vector operations +#---------------------------------------------------------------------------- + +def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return torch.sum(x*y, -1, keepdim=True) + +def reflect(x: torch.Tensor, n: torch.Tensor) -> torch.Tensor: + return 2*dot(x, n)*n - x + +def length(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor: + return torch.sqrt(torch.clamp(dot(x,x), min=eps)) # Clamp to avoid nan gradients because grad(sqrt(0)) = NaN + +def safe_normalize(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor: + return x / length(x, eps) + +def to_hvec(x: torch.Tensor, w: float) -> torch.Tensor: + return torch.nn.functional.pad(x, pad=(0,1), mode='constant', value=w) + +#---------------------------------------------------------------------------- +# sRGB color transforms +#---------------------------------------------------------------------------- + +def _rgb_to_srgb(f: torch.Tensor) -> torch.Tensor: + return torch.where(f <= 0.0031308, f * 12.92, torch.pow(torch.clamp(f, 0.0031308), 1.0/2.4)*1.055 - 0.055) + +def rgb_to_srgb(f: torch.Tensor) -> torch.Tensor: + assert f.shape[-1] == 3 or f.shape[-1] == 4 + out = torch.cat((_rgb_to_srgb(f[..., 0:3]), f[..., 3:4]), dim=-1) if f.shape[-1] == 4 else _rgb_to_srgb(f) + assert out.shape[0] == f.shape[0] and out.shape[1] == f.shape[1] and out.shape[2] == f.shape[2] + return out + +def _srgb_to_rgb(f: torch.Tensor) -> torch.Tensor: + return torch.where(f <= 0.04045, f / 12.92, torch.pow((torch.clamp(f, 0.04045) + 0.055) / 1.055, 2.4)) + +def srgb_to_rgb(f: torch.Tensor) -> torch.Tensor: + assert f.shape[-1] == 3 or f.shape[-1] == 4 + out = torch.cat((_srgb_to_rgb(f[..., 0:3]), f[..., 3:4]), dim=-1) if f.shape[-1] == 4 else _srgb_to_rgb(f) + assert out.shape[0] == f.shape[0] and out.shape[1] == f.shape[1] and out.shape[2] == f.shape[2] + return out + +def reinhard(f: torch.Tensor) -> torch.Tensor: + return f/(1+f) + +#----------------------------------------------------------------------------------- +# Metrics (taken from jaxNerf source code, in order to replicate their measurements) +# +# https://github.com/google-research/google-research/blob/301451a62102b046bbeebff49a760ebeec9707b8/jaxnerf/nerf/utils.py#L266 +# +#----------------------------------------------------------------------------------- + +def mse_to_psnr(mse): + """Compute PSNR given an MSE (we assume the maximum pixel value is 1).""" + return -10. / np.log(10.) * np.log(mse) + +def psnr_to_mse(psnr): + """Compute MSE given a PSNR (we assume the maximum pixel value is 1).""" + return np.exp(-0.1 * np.log(10.) * psnr) + +#---------------------------------------------------------------------------- +# Displacement texture lookup +#---------------------------------------------------------------------------- + +def get_miplevels(texture: np.ndarray) -> float: + minDim = min(texture.shape[0], texture.shape[1]) + return np.floor(np.log2(minDim)) + +def tex_2d(tex_map : torch.Tensor, coords : torch.Tensor, filter='nearest') -> torch.Tensor: + tex_map = tex_map[None, ...] # Add batch dimension + tex_map = tex_map.permute(0, 3, 1, 2) # NHWC -> NCHW + tex = torch.nn.functional.grid_sample(tex_map, coords[None, None, ...] * 2 - 1, mode=filter, align_corners=False) + tex = tex.permute(0, 2, 3, 1) # NCHW -> NHWC + return tex[0, 0, ...] + +#---------------------------------------------------------------------------- +# Cubemap utility functions +#---------------------------------------------------------------------------- + +def cube_to_dir(s, x, y): + if s == 0: rx, ry, rz = torch.ones_like(x), -y, -x + elif s == 1: rx, ry, rz = -torch.ones_like(x), -y, x + elif s == 2: rx, ry, rz = x, torch.ones_like(x), y + elif s == 3: rx, ry, rz = x, -torch.ones_like(x), -y + elif s == 4: rx, ry, rz = x, -y, torch.ones_like(x) + elif s == 5: rx, ry, rz = -x, -y, -torch.ones_like(x) + return torch.stack((rx, ry, rz), dim=-1) + +def latlong_to_cubemap(latlong_map, res): + cubemap = torch.zeros(6, res[0], res[1], latlong_map.shape[-1], dtype=torch.float32, device='cuda') + for s in range(6): + gy, gx = torch.meshgrid(torch.linspace(-1.0 + 1.0 / res[0], 1.0 - 1.0 / res[0], res[0], device='cuda'), + torch.linspace(-1.0 + 1.0 / res[1], 1.0 - 1.0 / res[1], res[1], device='cuda'), + indexing='ij') + v = safe_normalize(cube_to_dir(s, gx, gy)) + + tu = torch.atan2(v[..., 0:1], -v[..., 2:3]) / (2 * np.pi) + 0.5 + tv = torch.acos(torch.clamp(v[..., 1:2], min=-1, max=1)) / np.pi + texcoord = torch.cat((tu, tv), dim=-1) + + cubemap[s, ...] = dr.texture(latlong_map[None, ...], texcoord[None, ...], filter_mode='linear')[0] + return cubemap + +def cubemap_to_latlong(cubemap, res): + gy, gx = torch.meshgrid(torch.linspace( 0.0 + 1.0 / res[0], 1.0 - 1.0 / res[0], res[0], device='cuda'), + torch.linspace(-1.0 + 1.0 / res[1], 1.0 - 1.0 / res[1], res[1], device='cuda'), + indexing='ij') + + sintheta, costheta = torch.sin(gy*np.pi), torch.cos(gy*np.pi) + sinphi, cosphi = torch.sin(gx*np.pi), torch.cos(gx*np.pi) + + reflvec = torch.stack(( + sintheta*sinphi, + costheta, + -sintheta*cosphi + ), dim=-1) + return dr.texture(cubemap[None, ...], reflvec[None, ...].contiguous(), filter_mode='linear', boundary_mode='cube')[0] + +#---------------------------------------------------------------------------- +# Image scaling +#---------------------------------------------------------------------------- + +def scale_img_hwc(x : torch.Tensor, size, mag='bilinear', min='area') -> torch.Tensor: + return scale_img_nhwc(x[None, ...], size, mag, min)[0] + +def scale_img_nhwc(x : torch.Tensor, size, mag='bilinear', min='area') -> torch.Tensor: + size = tuple(int(s) for s in size) + assert (x.shape[1] >= size[0] and x.shape[2] >= size[1]) or (x.shape[1] < size[0] and x.shape[2] < size[1]), "Trying to magnify image in one dimension and minify in the other" + y = x.permute(0, 3, 1, 2) # NHWC -> NCHW + if x.shape[1] > size[0] and x.shape[2] > size[1]: # Minification, previous size was bigger + y = torch.nn.functional.interpolate(y, size, mode=min) + else: # Magnification + if mag == 'bilinear' or mag == 'bicubic': + y = torch.nn.functional.interpolate(y, size, mode=mag, align_corners=True) + else: + y = torch.nn.functional.interpolate(y, size, mode=mag) + return y.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC + +def avg_pool_nhwc(x : torch.Tensor, size) -> torch.Tensor: + y = x.permute(0, 3, 1, 2) # NHWC -> NCHW + y = torch.nn.functional.avg_pool2d(y, size) + return y.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC + +#---------------------------------------------------------------------------- +# Behaves similar to tf.segment_sum +#---------------------------------------------------------------------------- + +def segment_sum(data: torch.Tensor, segment_ids: torch.Tensor) -> torch.Tensor: + num_segments = torch.unique_consecutive(segment_ids).shape[0] + + # Repeats ids until same dimension as data + if len(segment_ids.shape) == 1: + s = torch.prod(torch.tensor(data.shape[1:], dtype=torch.int64, device='cuda')).long() + segment_ids = segment_ids.repeat_interleave(s).view(segment_ids.shape[0], *data.shape[1:]) + + assert data.shape == segment_ids.shape, "data.shape and segment_ids.shape should be equal" + + shape = [num_segments] + list(data.shape[1:]) + result = torch.zeros(*shape, dtype=torch.float32, device='cuda') + result = result.scatter_add(0, segment_ids, data) + return result + +#---------------------------------------------------------------------------- +# Matrix helpers. +#---------------------------------------------------------------------------- + +def fovx_to_fovy(fovx, aspect): + return np.arctan(np.tan(fovx / 2) / aspect) * 2.0 + +def focal_length_to_fovy(focal_length, sensor_height): + return 2 * np.arctan(0.5 * sensor_height / focal_length) + +# Reworked so this matches gluPerspective / glm::perspective, using fovy +def perspective(fovy=0.7854, aspect=1.0, n=0.1, f=1000.0, device=None): + y = np.tan(fovy / 2) + return torch.tensor([[1/(y*aspect), 0, 0, 0], + [ 0, 1/-y, 0, 0], + [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)], + [ 0, 0, -1, 0]], dtype=torch.float32, device=device) + +# Reworked so this matches gluPerspective / glm::perspective, using fovy +def perspective_offcenter(fovy, fraction, rx, ry, aspect=1.0, n=0.1, f=1000.0, device=None): + y = np.tan(fovy / 2) + + # Full frustum + R, L = aspect*y, -aspect*y + T, B = y, -y + + # Create a randomized sub-frustum + width = (R-L)*fraction + height = (T-B)*fraction + xstart = (R-L)*rx + ystart = (T-B)*ry + + l = L + xstart + r = l + width + b = B + ystart + t = b + height + + # https://www.scratchapixel.com/lessons/3d-basic-rendering/perspective-and-orthographic-projection-matrix/opengl-perspective-projection-matrix + return torch.tensor([[2/(r-l), 0, (r+l)/(r-l), 0], + [ 0, -2/(t-b), (t+b)/(t-b), 0], + [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)], + [ 0, 0, -1, 0]], dtype=torch.float32, device=device) + +def translate(x, y, z, device=None): + return torch.tensor([[1, 0, 0, x], + [0, 1, 0, y], + [0, 0, 1, z], + [0, 0, 0, 1]], dtype=torch.float32, device=device) + +def rotate_x(a, device=None): + s, c = np.sin(a), np.cos(a) + return torch.tensor([[1, 0, 0, 0], + [0, c,-s, 0], + [0, s, c, 0], + [0, 0, 0, 1]], dtype=torch.float32, device=device) + +def rotate_y(a, device=None): + s, c = np.sin(a), np.cos(a) + return torch.tensor([[ c, 0, s, 0], + [ 0, 1, 0, 0], + [-s, 0, c, 0], + [ 0, 0, 0, 1]], dtype=torch.float32, device=device) + +def scale(s, device=None): + return torch.tensor([[ s, 0, 0, 0], + [ 0, s, 0, 0], + [ 0, 0, s, 0], + [ 0, 0, 0, 1]], dtype=torch.float32, device=device) + +def lookAt(eye, at, up): + a = eye - at + w = a / torch.linalg.norm(a) + u = torch.cross(up, w) + u = u / torch.linalg.norm(u) + v = torch.cross(w, u) + translate = torch.tensor([[1, 0, 0, -eye[0]], + [0, 1, 0, -eye[1]], + [0, 0, 1, -eye[2]], + [0, 0, 0, 1]], dtype=eye.dtype, device=eye.device) + rotate = torch.tensor([[u[0], u[1], u[2], 0], + [v[0], v[1], v[2], 0], + [w[0], w[1], w[2], 0], + [0, 0, 0, 1]], dtype=eye.dtype, device=eye.device) + return rotate @ translate + +@torch.no_grad() +def random_rotation_translation(t, device=None): + m = np.random.normal(size=[3, 3]) + m[1] = np.cross(m[0], m[2]) + m[2] = np.cross(m[0], m[1]) + m = m / np.linalg.norm(m, axis=1, keepdims=True) + m = np.pad(m, [[0, 1], [0, 1]], mode='constant') + m[3, 3] = 1.0 + m[:3, 3] = np.random.uniform(-t, t, size=[3]) + return torch.tensor(m, dtype=torch.float32, device=device) + +@torch.no_grad() +def random_rotation(device=None): + m = np.random.normal(size=[3, 3]) + m[1] = np.cross(m[0], m[2]) + m[2] = np.cross(m[0], m[1]) + m = m / np.linalg.norm(m, axis=1, keepdims=True) + m = np.pad(m, [[0, 1], [0, 1]], mode='constant') + m[3, 3] = 1.0 + m[:3, 3] = np.array([0,0,0]).astype(np.float32) + return torch.tensor(m, dtype=torch.float32, device=device) + +#---------------------------------------------------------------------------- +# Compute focal points of a set of lines using least squares. +# handy for poorly centered datasets +#---------------------------------------------------------------------------- + +def lines_focal(o, d): + d = safe_normalize(d) + I = torch.eye(3, dtype=o.dtype, device=o.device) + S = torch.sum(d[..., None] @ torch.transpose(d[..., None], 1, 2) - I[None, ...], dim=0) + C = torch.sum((d[..., None] @ torch.transpose(d[..., None], 1, 2) - I[None, ...]) @ o[..., None], dim=0).squeeze(1) + return torch.linalg.pinv(S) @ C + +#---------------------------------------------------------------------------- +# Cosine sample around a vector N +#---------------------------------------------------------------------------- +@torch.no_grad() +def cosine_sample(N, size=None): + # construct local frame + N = N/torch.linalg.norm(N) + + dx0 = torch.tensor([0, N[2], -N[1]], dtype=N.dtype, device=N.device) + dx1 = torch.tensor([-N[2], 0, N[0]], dtype=N.dtype, device=N.device) + + dx = torch.where(dot(dx0, dx0) > dot(dx1, dx1), dx0, dx1) + #dx = dx0 if np.dot(dx0,dx0) > np.dot(dx1,dx1) else dx1 + dx = dx / torch.linalg.norm(dx) + dy = torch.cross(N,dx) + dy = dy / torch.linalg.norm(dy) + + # cosine sampling in local frame + if size is None: + phi = 2.0 * np.pi * np.random.uniform() + s = np.random.uniform() + else: + phi = 2.0 * np.pi * torch.rand(*size, 1, dtype=N.dtype, device=N.device) + s = torch.rand(*size, 1, dtype=N.dtype, device=N.device) + costheta = np.sqrt(s) + sintheta = np.sqrt(1.0 - s) + + # cartesian vector in local space + x = np.cos(phi)*sintheta + y = np.sin(phi)*sintheta + z = costheta + + # local to world + return dx*x + dy*y + N*z + +#---------------------------------------------------------------------------- +# Bilinear downsample by 2x. +#---------------------------------------------------------------------------- + +def bilinear_downsample(x : torch.tensor) -> torch.Tensor: + w = torch.tensor([[1, 3, 3, 1], [3, 9, 9, 3], [3, 9, 9, 3], [1, 3, 3, 1]], dtype=torch.float32, device=x.device) / 64.0 + w = w.expand(x.shape[-1], 1, 4, 4) + x = torch.nn.functional.conv2d(x.permute(0, 3, 1, 2), w, padding=1, stride=2, groups=x.shape[-1]) + return x.permute(0, 2, 3, 1) + +#---------------------------------------------------------------------------- +# Bilinear downsample log(spp) steps +#---------------------------------------------------------------------------- + +def bilinear_downsample(x : torch.tensor, spp) -> torch.Tensor: + w = torch.tensor([[1, 3, 3, 1], [3, 9, 9, 3], [3, 9, 9, 3], [1, 3, 3, 1]], dtype=torch.float32, device=x.device) / 64.0 + g = x.shape[-1] + w = w.expand(g, 1, 4, 4) + x = x.permute(0, 3, 1, 2) # NHWC -> NCHW + steps = int(np.log2(spp)) + for _ in range(steps): + xp = torch.nn.functional.pad(x, (1,1,1,1), mode='replicate') + x = torch.nn.functional.conv2d(xp, w, padding=0, stride=2, groups=g) + return x.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC + +#---------------------------------------------------------------------------- +# Singleton initialize GLFW +#---------------------------------------------------------------------------- + +_glfw_initialized = False +def init_glfw(): + global _glfw_initialized + try: + import glfw + glfw.ERROR_REPORTING = 'raise' + glfw.default_window_hints() + glfw.window_hint(glfw.VISIBLE, glfw.FALSE) + test = glfw.create_window(8, 8, "Test", None, None) # Create a window and see if not initialized yet + except glfw.GLFWError as e: + if e.error_code == glfw.NOT_INITIALIZED: + glfw.init() + _glfw_initialized = True + +#---------------------------------------------------------------------------- +# Image display function using OpenGL. +#---------------------------------------------------------------------------- + +_glfw_window = None +def display_image(image, title=None): + # Import OpenGL + import OpenGL.GL as gl + import glfw + + # Zoom image if requested. + image = np.asarray(image[..., 0:3]) if image.shape[-1] == 4 else np.asarray(image) + height, width, channels = image.shape + + # Initialize window. + init_glfw() + if title is None: + title = 'Debug window' + global _glfw_window + if _glfw_window is None: + glfw.default_window_hints() + _glfw_window = glfw.create_window(width, height, title, None, None) + glfw.make_context_current(_glfw_window) + glfw.show_window(_glfw_window) + glfw.swap_interval(0) + else: + glfw.make_context_current(_glfw_window) + glfw.set_window_title(_glfw_window, title) + glfw.set_window_size(_glfw_window, width, height) + + # Update window. + glfw.poll_events() + gl.glClearColor(0, 0, 0, 1) + gl.glClear(gl.GL_COLOR_BUFFER_BIT) + gl.glWindowPos2f(0, 0) + gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1) + gl_format = {3: gl.GL_RGB, 2: gl.GL_RG, 1: gl.GL_LUMINANCE}[channels] + gl_dtype = {'uint8': gl.GL_UNSIGNED_BYTE, 'float32': gl.GL_FLOAT}[image.dtype.name] + gl.glDrawPixels(width, height, gl_format, gl_dtype, image[::-1]) + glfw.swap_buffers(_glfw_window) + if glfw.window_should_close(_glfw_window): + return False + return True + +#---------------------------------------------------------------------------- +# Image save/load helper. +#---------------------------------------------------------------------------- + +def save_image(fn, x : np.ndarray): + try: + if os.path.splitext(fn)[1] == ".png": + imageio.imwrite(fn, np.clip(np.rint(x * 255.0), 0, 255).astype(np.uint8), compress_level=3) # Low compression for faster saving + else: + imageio.imwrite(fn, np.clip(np.rint(x * 255.0), 0, 255).astype(np.uint8)) + except: + print("WARNING: FAILED to save image %s" % fn) + +def save_image_raw(fn, x : np.ndarray): + try: + imageio.imwrite(fn, x) + except: + print("WARNING: FAILED to save image %s" % fn) + + +def load_image_raw(fn) -> np.ndarray: + return imageio.imread(fn) + +def load_image(fn) -> np.ndarray: + img = load_image_raw(fn) + if img.dtype == np.float32: # HDR image + return img + else: # LDR image + return img.astype(np.float32) / 255 + +#---------------------------------------------------------------------------- + +def time_to_text(x): + if x > 3600: + return "%.2f h" % (x / 3600) + elif x > 60: + return "%.2f m" % (x / 60) + else: + return "%.2f s" % x + +#---------------------------------------------------------------------------- + +def checkerboard(res, checker_size) -> np.ndarray: + tiles_y = (res[0] + (checker_size*2) - 1) // (checker_size*2) + tiles_x = (res[1] + (checker_size*2) - 1) // (checker_size*2) + check = np.kron([[1, 0] * tiles_x, [0, 1] * tiles_x] * tiles_y, np.ones((checker_size, checker_size)))*0.33 + 0.33 + check = check[:res[0], :res[1]] + return np.stack((check, check, check), axis=-1) + diff --git a/models/lrm/models/lrm.py b/models/lrm/models/lrm.py new file mode 100644 index 0000000000000000000000000000000000000000..2128bac1ed99ccf3a6d99cd4491950e56360cfd0 --- /dev/null +++ b/models/lrm/models/lrm.py @@ -0,0 +1,209 @@ +# Copyright (c) 2023, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch +import torch.nn as nn +import mcubes +import nvdiffrast.torch as dr +from einops import rearrange, repeat + +from .encoder.dino_wrapper import DinoWrapper +from .decoder.transformer import TriplaneTransformer +from .renderer.synthesizer import TriplaneSynthesizer +from ..utils.mesh_util import xatlas_uvmap + + +class InstantNeRF(nn.Module): + """ + Full model of the large reconstruction model. + """ + def __init__( + self, + encoder_freeze: bool = False, + encoder_model_name: str = 'facebook/dino-vitb16', + encoder_feat_dim: int = 768, + transformer_dim: int = 1024, + transformer_layers: int = 16, + transformer_heads: int = 16, + triplane_low_res: int = 32, + triplane_high_res: int = 64, + triplane_dim: int = 80, + rendering_samples_per_ray: int = 128, + ): + super().__init__() + + # modules + self.encoder = DinoWrapper( + model_name=encoder_model_name, + freeze=encoder_freeze, + ) + + self.transformer = TriplaneTransformer( + inner_dim=transformer_dim, + num_layers=transformer_layers, + num_heads=transformer_heads, + image_feat_dim=encoder_feat_dim, + triplane_low_res=triplane_low_res, + triplane_high_res=triplane_high_res, + triplane_dim=triplane_dim, + ) + + self.synthesizer = TriplaneSynthesizer( + triplane_dim=triplane_dim, + samples_per_ray=rendering_samples_per_ray, + ) + + def forward_planes(self, images, cameras): + # images: [B, V, C_img, H_img, W_img] + # cameras: [B, V, 16] + B = images.shape[0] + + # encode images + image_feats = self.encoder(images, cameras) + image_feats = rearrange(image_feats, '(b v) l d -> b (v l) d', b=B) + + # transformer generating planes + planes = self.transformer(image_feats) + + return planes + + def forward_synthesizer(self, planes, render_cameras, render_size: int): + render_results = self.synthesizer( + planes, + render_cameras, + render_size, + ) + return render_results + + def forward(self, images, cameras, render_cameras, render_size: int): + # images: [B, V, C_img, H_img, W_img] + # cameras: [B, V, 16] + # render_cameras: [B, M, D_cam_render] + # render_size: int + B, M = render_cameras.shape[:2] + + planes = self.forward_planes(images, cameras) + + # render target views + render_results = self.synthesizer(planes, render_cameras, render_size) + + return { + 'planes': planes, + **render_results, + } + + def get_texture_prediction(self, planes, tex_pos, hard_mask=None): + ''' + Predict Texture given triplanes + :param planes: the triplane feature map + :param tex_pos: Position we want to query the texture field + :param hard_mask: 2D silhoueete of the rendered image + ''' + tex_pos = torch.cat(tex_pos, dim=0) + if not hard_mask is None: + tex_pos = tex_pos * hard_mask.float() + batch_size = tex_pos.shape[0] + tex_pos = tex_pos.reshape(batch_size, -1, 3) + ################### + # We use mask to get the texture location (to save the memory) + if hard_mask is not None: + n_point_list = torch.sum(hard_mask.long().reshape(hard_mask.shape[0], -1), dim=-1) + sample_tex_pose_list = [] + max_point = n_point_list.max() + expanded_hard_mask = hard_mask.reshape(batch_size, -1, 1).expand(-1, -1, 3) > 0.5 + for i in range(tex_pos.shape[0]): + tex_pos_one_shape = tex_pos[i][expanded_hard_mask[i]].reshape(1, -1, 3) + if tex_pos_one_shape.shape[1] < max_point: + tex_pos_one_shape = torch.cat( + [tex_pos_one_shape, torch.zeros( + 1, max_point - tex_pos_one_shape.shape[1], 3, + device=tex_pos_one_shape.device, dtype=torch.float32)], dim=1) + sample_tex_pose_list.append(tex_pos_one_shape) + tex_pos = torch.cat(sample_tex_pose_list, dim=0) + + tex_feat = torch.utils.checkpoint.checkpoint( + self.synthesizer.forward_points, + planes, + tex_pos, + use_reentrant=False, + )['rgb'] + + if hard_mask is not None: + final_tex_feat = torch.zeros( + planes.shape[0], hard_mask.shape[1] * hard_mask.shape[2], tex_feat.shape[-1], device=tex_feat.device) + expanded_hard_mask = hard_mask.reshape(hard_mask.shape[0], -1, 1).expand(-1, -1, final_tex_feat.shape[-1]) > 0.5 + for i in range(planes.shape[0]): + final_tex_feat[i][expanded_hard_mask[i]] = tex_feat[i][:n_point_list[i]].reshape(-1) + tex_feat = final_tex_feat + + return tex_feat.reshape(planes.shape[0], hard_mask.shape[1], hard_mask.shape[2], tex_feat.shape[-1]) + + def extract_mesh( + self, + planes: torch.Tensor, + mesh_resolution: int = 256, + mesh_threshold: int = 10.0, + use_texture_map: bool = False, + texture_resolution: int = 1024, + **kwargs, + ): + ''' + Extract a 3D mesh from triplane nerf. Only support batch_size 1. + :param planes: triplane features + :param mesh_resolution: marching cubes resolution + :param mesh_threshold: iso-surface threshold + :param use_texture_map: use texture map or vertex color + :param texture_resolution: the resolution of texture map + ''' + assert planes.shape[0] == 1 + device = planes.device + + grid_out = self.synthesizer.forward_grid( + planes=planes, + grid_size=mesh_resolution, + ) + + vertices, faces = mcubes.marching_cubes( + grid_out['sigma'].squeeze(0).squeeze(-1).cpu().numpy(), + mesh_threshold, + ) + vertices = vertices / (mesh_resolution - 1) * 2 - 1 + + if not use_texture_map: + # query vertex colors + vertices_tensor = torch.tensor(vertices, dtype=torch.float32, device=device).unsqueeze(0) + vertices_colors = self.synthesizer.forward_points( + planes, vertices_tensor)['rgb'].squeeze(0).cpu().numpy() + vertices_colors = (vertices_colors * 255).astype(np.uint8) + + return vertices, faces, vertices_colors + + # use x-atlas to get uv mapping for the mesh + vertices = torch.tensor(vertices, dtype=torch.float32, device=device) + faces = torch.tensor(faces.astype(int), dtype=torch.long, device=device) + + ctx = dr.RasterizeCudaContext(device=device) + uvs, mesh_tex_idx, gb_pos, tex_hard_mask = xatlas_uvmap( + ctx, vertices, faces, resolution=texture_resolution) + tex_hard_mask = tex_hard_mask.float() + + # query the texture field to get the RGB color for texture map + tex_feat = self.get_texture_prediction( + planes, [gb_pos], tex_hard_mask) + background_feature = torch.zeros_like(tex_feat) + img_feat = torch.lerp(background_feature, tex_feat, tex_hard_mask) + texture_map = img_feat.permute(0, 3, 1, 2).squeeze(0) + + return vertices, faces, uvs, mesh_tex_idx, texture_map \ No newline at end of file diff --git a/models/lrm/models/lrm_mesh.py b/models/lrm/models/lrm_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..7a86a04c6f22fb70c28791ce98696ac76a562f70 --- /dev/null +++ b/models/lrm/models/lrm_mesh.py @@ -0,0 +1,413 @@ +# Copyright (c) 2023, Tencent Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch +import torch.nn as nn +import nvdiffrast.torch as dr +from einops import rearrange, repeat + +from .encoder.dino_wrapper import DinoWrapper +from .decoder.transformer import TriplaneTransformer +from .renderer.synthesizer_mesh import TriplaneSynthesizer +from .geometry.camera.perspective_camera import PerspectiveCamera +from .geometry.render.neural_render import NeuralRender +from .geometry.rep_3d.flexicubes_geometry import FlexiCubesGeometry +from ..utils.mesh_util import xatlas_uvmap +from .geometry.rep_3d import util +import trimesh +from PIL import Image +from models.lrm.utils import render +from models.lrm.utils.render_utils import rotate_x, rotate_y + +class PRM(nn.Module): + """ + Full model of the large reconstruction model. + """ + def __init__( + self, + encoder_freeze: bool = False, + encoder_model_name: str = 'facebook/dino-vitb16', + encoder_feat_dim: int = 768, + transformer_dim: int = 1024, + transformer_layers: int = 16, + transformer_heads: int = 16, + triplane_low_res: int = 32, + triplane_high_res: int = 64, + triplane_dim: int = 80, + rendering_samples_per_ray: int = 128, + grid_res: int = 128, + grid_scale: float = 2.0, + ): + super().__init__() + + # attributes + self.grid_res = grid_res + self.grid_scale = grid_scale + self.deformation_multiplier = 4.0 + + # modules + self.encoder = DinoWrapper( + model_name=encoder_model_name, + freeze=encoder_freeze, + ) + + self.transformer = TriplaneTransformer( + inner_dim=transformer_dim, + num_layers=transformer_layers, + num_heads=transformer_heads, + image_feat_dim=encoder_feat_dim, + triplane_low_res=triplane_low_res, + triplane_high_res=triplane_high_res, + triplane_dim=triplane_dim, + ) + + self.synthesizer = TriplaneSynthesizer( + triplane_dim=triplane_dim, + samples_per_ray=rendering_samples_per_ray, + ) + + def init_flexicubes_geometry(self, device, fovy=50.0): + camera = PerspectiveCamera(fovy=fovy, device=device) + renderer = NeuralRender(device, camera_model=camera) + self.geometry = FlexiCubesGeometry( + grid_res=self.grid_res, + scale=self.grid_scale, + renderer=renderer, + render_type='neural_render', + device=device, + ) + + def forward_planes(self, images, cameras): + # images: [B, V, C_img, H_img, W_img] + # cameras: [B, V, 16] + B = images.shape[0] + + # encode images + image_feats = self.encoder(images, cameras) + image_feats = rearrange(image_feats, '(b v) l d -> b (v l) d', b=B) + + # decode triplanes + planes = self.transformer(image_feats) + + return planes + + def get_sdf_deformation_prediction(self, planes): + ''' + Predict SDF and deformation for tetrahedron vertices + :param planes: triplane feature map for the geometry + ''' + init_position = self.geometry.verts.unsqueeze(0).expand(planes.shape[0], -1, -1) + + # Step 1: predict the SDF and deformation + sdf, deformation, weight = torch.utils.checkpoint.checkpoint( + self.synthesizer.get_geometry_prediction, + planes, + init_position, + self.geometry.indices, + use_reentrant=False, + ) + + # Step 2: Normalize the deformation to avoid the flipped triangles. + deformation = 1.0 / (self.grid_res * self.deformation_multiplier) * torch.tanh(deformation) + sdf_reg_loss = torch.zeros(sdf.shape[0], device=sdf.device, dtype=torch.float32) + + #### + # Step 3: Fix some sdf if we observe empty shape (full positive or full negative) + sdf_bxnxnxn = sdf.reshape((sdf.shape[0], self.grid_res + 1, self.grid_res + 1, self.grid_res + 1)) + sdf_less_boundary = sdf_bxnxnxn[:, 1:-1, 1:-1, 1:-1].reshape(sdf.shape[0], -1) + pos_shape = torch.sum((sdf_less_boundary > 0).int(), dim=-1) + neg_shape = torch.sum((sdf_less_boundary < 0).int(), dim=-1) + zero_surface = torch.bitwise_or(pos_shape == 0, neg_shape == 0) + if torch.sum(zero_surface).item() > 0: + update_sdf = torch.zeros_like(sdf[0:1]) + max_sdf = sdf.max() + min_sdf = sdf.min() + update_sdf[:, self.geometry.center_indices] += (1.0 - min_sdf) # greater than zero + update_sdf[:, self.geometry.boundary_indices] += (-1 - max_sdf) # smaller than zero + new_sdf = torch.zeros_like(sdf) + for i_batch in range(zero_surface.shape[0]): + if zero_surface[i_batch]: + new_sdf[i_batch:i_batch + 1] += update_sdf + update_mask = (new_sdf == 0).float() + # Regulraization here is used to push the sdf to be a different sign (make it not fully positive or fully negative) + sdf_reg_loss = torch.abs(sdf).mean(dim=-1).mean(dim=-1) + sdf_reg_loss = sdf_reg_loss * zero_surface.float() + sdf = sdf * update_mask + new_sdf * (1 - update_mask) + + # Step 4: Here we remove the gradient for the bad sdf (full positive or full negative) + final_sdf = [] + final_def = [] + for i_batch in range(zero_surface.shape[0]): + if zero_surface[i_batch]: + final_sdf.append(sdf[i_batch: i_batch + 1].detach()) + final_def.append(deformation[i_batch: i_batch + 1].detach()) + else: + final_sdf.append(sdf[i_batch: i_batch + 1]) + final_def.append(deformation[i_batch: i_batch + 1]) + sdf = torch.cat(final_sdf, dim=0) + deformation = torch.cat(final_def, dim=0) + return sdf, deformation, sdf_reg_loss, weight + + def get_geometry_prediction(self, planes=None): + ''' + Function to generate mesh with give triplanes + :param planes: triplane features + ''' + # Step 1: first get the sdf and deformation value for each vertices in the tetrahedon grid. + sdf, deformation, sdf_reg_loss, weight = self.get_sdf_deformation_prediction(planes) + v_deformed = self.geometry.verts.unsqueeze(dim=0).expand(sdf.shape[0], -1, -1) + deformation + tets = self.geometry.indices + n_batch = planes.shape[0] + v_list = [] + f_list = [] + imesh_list = [] + flexicubes_surface_reg_list = [] + + # Step 2: Using marching tet to obtain the mesh + for i_batch in range(n_batch): + verts, faces, flexicubes_surface_reg, imesh = self.geometry.get_mesh( + v_deformed[i_batch], + sdf[i_batch].squeeze(dim=-1), + with_uv=False, + indices=tets, + weight_n=weight[i_batch].squeeze(dim=-1), + is_training=self.training, + ) + flexicubes_surface_reg_list.append(flexicubes_surface_reg) + v_list.append(verts) + f_list.append(faces) + imesh_list.append(imesh) + + flexicubes_surface_reg = torch.cat(flexicubes_surface_reg_list).mean() + flexicubes_weight_reg = (weight ** 2).mean() + + return v_list, f_list, imesh_list, sdf, deformation, v_deformed, (sdf_reg_loss, flexicubes_surface_reg, flexicubes_weight_reg) + + def get_texture_prediction(self, planes, tex_pos, hard_mask=None, gb_normal=None, training=True): + ''' + Predict Texture given triplanes + :param planes: the triplane feature map + :param tex_pos: Position we want to query the texture field + :param hard_mask: 2D silhoueete of the rendered image + ''' + tex_pos = torch.cat(tex_pos, dim=0) + shape = tex_pos.shape + flat_pos = tex_pos.view(-1, 3) + if training: + with torch.no_grad(): + flat_pos = flat_pos @ rotate_y(-np.pi / 2, device=flat_pos.device)[:3, :3] + flat_pos = flat_pos @ rotate_x(-np.pi / 2, device=flat_pos.device)[:3, :3] + tex_pos = flat_pos.reshape(*shape) + if not hard_mask is None: + tex_pos = tex_pos * hard_mask.float() + batch_size = tex_pos.shape[0] + tex_pos = tex_pos.reshape(batch_size, -1, 3) + ################### + # We use mask to get the texture location (to save the memory) + if hard_mask is not None: + n_point_list = torch.sum(hard_mask.long().reshape(hard_mask.shape[0], -1), dim=-1) + sample_tex_pose_list = [] + max_point = n_point_list.max() + expanded_hard_mask = hard_mask.reshape(batch_size, -1, 1).expand(-1, -1, 3) > 0.5 + for i in range(tex_pos.shape[0]): + tex_pos_one_shape = tex_pos[i][expanded_hard_mask[i]].reshape(1, -1, 3) + if tex_pos_one_shape.shape[1] < max_point: + tex_pos_one_shape = torch.cat( + [tex_pos_one_shape, torch.zeros( + 1, max_point - tex_pos_one_shape.shape[1], 3, + device=tex_pos_one_shape.device, dtype=torch.float32)], dim=1) + sample_tex_pose_list.append(tex_pos_one_shape) + tex_pos = torch.cat(sample_tex_pose_list, dim=0) + + tex_feat, metalic_feat, roughness_feat = torch.utils.checkpoint.checkpoint( + self.synthesizer.get_texture_prediction, + planes, + tex_pos, + use_reentrant=False, + ) + metalic_feat, roughness_feat = metalic_feat[..., None], roughness_feat[..., None] + + if hard_mask is not None: + final_tex_feat = torch.zeros( + planes.shape[0], hard_mask.shape[1] * hard_mask.shape[2], tex_feat.shape[-1], device=tex_feat.device) + final_matallic_feat = torch.zeros( + planes.shape[0], hard_mask.shape[1] * hard_mask.shape[2], metalic_feat.shape[-1], device=metalic_feat.device) + final_roughness_feat = torch.zeros( + planes.shape[0], hard_mask.shape[1] * hard_mask.shape[2], roughness_feat.shape[-1], device=roughness_feat.device) + + expanded_hard_mask = hard_mask.reshape(hard_mask.shape[0], -1, 1).expand(-1, -1, final_tex_feat.shape[-1]) > 0.5 + expanded_hard_mask_m = hard_mask.reshape(hard_mask.shape[0], -1, 1).expand(-1, -1, final_matallic_feat.shape[-1]) > 0.5 + expanded_hard_mask_r = hard_mask.reshape(hard_mask.shape[0], -1, 1).expand(-1, -1, final_roughness_feat.shape[-1]) > 0.5 + + for i in range(planes.shape[0]): + final_tex_feat[i][expanded_hard_mask[i]] = tex_feat[i][:n_point_list[i]].reshape(-1) + final_matallic_feat[i][expanded_hard_mask_m[i]] = metalic_feat[i][:n_point_list[i]].reshape(-1) + final_roughness_feat[i][expanded_hard_mask_r[i]] = roughness_feat[i][:n_point_list[i]].reshape(-1) + + tex_feat = final_tex_feat + metalic_feat = final_matallic_feat + roughness_feat = final_roughness_feat + + return tex_feat.reshape(planes.shape[0], hard_mask.shape[1], hard_mask.shape[2], tex_feat.shape[-1]), metalic_feat.reshape(planes.shape[0], hard_mask.shape[1], hard_mask.shape[2], metalic_feat.shape[-1]), roughness_feat.reshape(planes.shape[0], hard_mask.shape[1], hard_mask.shape[2], roughness_feat.shape[-1]) + + + def render_mesh(self, mesh_v, mesh_f, imesh, cam_mv, camera_pos, env, planes, materials, render_size=256, gt_albedo_map=None, single=False): + ''' + Function to render a generated mesh with nvdiffrast + :param mesh_v: List of vertices for the mesh + :param mesh_f: List of faces for the mesh + :param cam_mv: 4x4 rotation matrix + :return: + ''' + return_value_list = [] + for i_mesh in range(len(mesh_v)): + return_value = self.geometry.render_mesh( + mesh_v[i_mesh], + mesh_f[i_mesh].int(), + imesh[i_mesh], + cam_mv[i_mesh], + camera_pos[i_mesh], + env[i_mesh], + planes[i_mesh], + self.get_texture_prediction, + materials[i_mesh], + resolution=render_size, + hierarchical_mask=False, + # gt_albedo_map=gt_albedo_map, + ) + return_value_list.append(return_value) + return_keys = return_value_list[0].keys() + return_value = dict() + for k in return_keys: + value = [v[k] for v in return_value_list] + return_value[k] = value + # mask = torch.cat(return_value['mask'], dim=0) + hard_mask = torch.cat(return_value['mask'], dim=0) + # tex_pos = return_value['tex_pos'] + rgb = torch.cat(return_value['shaded'], dim=0) + spec_light = torch.cat(return_value['spec_light'], dim=0) + diff_light = torch.cat(return_value['diff_light'], dim=0) + albedo = torch.cat(return_value['albedo'], dim=0) + depth = torch.cat(return_value['depth'], dim=0) + normal = torch.cat(return_value['normal'], dim=0) + gb_normal = torch.cat(return_value['gb_normal'], dim=0) + return rgb, spec_light, diff_light, albedo, depth, normal, gb_normal, hard_mask + + + def forward_geometry(self, planes, render_cameras, camera_pos, env, materials, albedo_map=None, render_size=256, sample_points=None, gt_albedo_map=None, single=False): + ''' + Main function of our Generator. It first generate 3D mesh, then render it into 2D image + with given `render_cameras`. + :param planes: triplane features + :param render_cameras: cameras to render generated 3D shape + ''' + B, NV = render_cameras.shape[:2] + + # Generate 3D mesh first + mesh_v, mesh_f, imesh, sdf, deformation, v_deformed, sdf_reg_loss = self.get_geometry_prediction(planes) + predict_sample_points = None + + # Render the mesh into 2D image (get 3d position of each image plane) + cam_mv = render_cameras + + rgb, spec_light, diff_light, albedo, depth, normal, gb_normal, mask = self.render_mesh(mesh_v, mesh_f, imesh, cam_mv, camera_pos, env, planes, materials, + render_size=render_size, gt_albedo_map=gt_albedo_map, single=single) + albedo = albedo[...,:3].clamp(0, 1).permute(0, 3, 1, 2).unflatten(0, (B, NV)) + pbr_img = rgb[...,:3].clamp(0, 1).permute(0, 3, 1, 2).unflatten(0, (B, NV)) + normal_img = gb_normal[...,:3].permute(0, 3, 1, 2).unflatten(0, (B, NV)) + pbr_spec_light = spec_light[...,:3].clamp(0, 1).permute(0, 3, 1, 2).unflatten(0, (B, NV)) + pbr_diffuse_light = diff_light[...,:3].clamp(0, 1).permute(0, 3, 1, 2).unflatten(0, (B, NV)) + antilias_mask = mask[...,:3].permute(0, 3, 1, 2).unflatten(0, (B, NV)) + depth = depth[...,:3].permute(0, 3, 1, 2).unflatten(0, (B, NV)) # transform negative depth to positive + + out = { + 'albedo': albedo, + 'pbr_img': pbr_img, + 'normal_img': normal_img, + 'pbr_spec_light': pbr_spec_light, + 'pbr_diffuse_light': pbr_diffuse_light, + 'depth': depth, + 'normal': gb_normal, + 'mask': antilias_mask, + 'sdf': sdf, + 'mesh_v': mesh_v, + 'mesh_f': mesh_f, + 'sdf_reg_loss': sdf_reg_loss, + 'triplane': planes, + 'sample_points': predict_sample_points + } + return out + + def forward(self, images, cameras, render_cameras, render_size: int): + # images: [B, V, C_img, H_img, W_img] + # cameras: [B, V, 16] + # render_cameras: [B, M, D_cam_render] + # render_size: int + B, M = render_cameras.shape[:2] + + planes = self.forward_planes(images, cameras) + out = self.forward_geometry(planes, render_cameras, render_size=render_size) + + return { + 'planes': planes, + **out + } + + def extract_mesh( + self, + planes: torch.Tensor, + use_texture_map: bool = False, + texture_resolution: int = 1024, + **kwargs, + ): + ''' + Extract a 3D mesh from FlexiCubes. Only support batch_size 1. + :param planes: triplane features + :param use_texture_map: use texture map or vertex color + :param texture_resolution: the resolution of texure map + ''' + assert planes.shape[0] == 1 + device = planes.device + + # predict geometry first + mesh_v, mesh_f, imesh, sdf, deformation, v_deformed, sdf_reg_loss = self.get_geometry_prediction(planes) + vertices, faces = mesh_v[0], mesh_f[0] + with torch.no_grad(): + vertices = vertices @ rotate_y(-np.pi / 2, device=vertices.device)[:3, :3] + vertices = vertices @ rotate_x(-np.pi / 2, device=vertices.device)[:3, :3] + if not use_texture_map: + # query vertex colors + vertices_tensor = vertices.unsqueeze(0) + vertices_colors, matellic, roughness = self.synthesizer.get_texture_prediction( + planes, vertices_tensor) + vertices_colors = vertices_colors.clamp(0, 1).squeeze(0).cpu().numpy() + vertices_colors = (vertices_colors * 255).astype(np.uint8) + + return vertices.cpu().numpy(), faces.cpu().numpy(), vertices_colors + + # use x-atlas to get uv mapping for the mesh + ctx = dr.RasterizeCudaContext(device=device) + uvs, mesh_tex_idx, gb_pos, tex_hard_mask = xatlas_uvmap( + self.geometry.renderer.ctx, vertices, faces, resolution=texture_resolution) + tex_hard_mask = tex_hard_mask.float() + + # query the texture field to get the RGB color for texture map + tex_feat, _, _ = self.get_texture_prediction( + planes, [gb_pos], tex_hard_mask, training=False) + background_feature = torch.zeros_like(tex_feat) + img_feat = torch.lerp(background_feature, tex_feat, tex_hard_mask) + texture_map = img_feat.permute(0, 3, 1, 2).squeeze(0) + + return vertices, faces, uvs, mesh_tex_idx, texture_map diff --git a/models/lrm/models/renderer/__init__.py b/models/lrm/models/renderer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2c772e4fa331c678cfff50884be94d7d31835b34 --- /dev/null +++ b/models/lrm/models/renderer/__init__.py @@ -0,0 +1,9 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. diff --git a/models/lrm/models/renderer/synthesizer.py b/models/lrm/models/renderer/synthesizer.py new file mode 100644 index 0000000000000000000000000000000000000000..8db9fbdb1703b566117d227c8e4eef04157ccc93 --- /dev/null +++ b/models/lrm/models/renderer/synthesizer.py @@ -0,0 +1,203 @@ +# ORIGINAL LICENSE +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Modified by Jiale Xu +# The modifications are subject to the same license as the original. + + +import itertools +import torch +import torch.nn as nn + +from .utils.renderer import ImportanceRenderer +from .utils.ray_sampler import RaySampler + + +class OSGDecoder(nn.Module): + """ + Triplane decoder that gives RGB and sigma values from sampled features. + Using ReLU here instead of Softplus in the original implementation. + + Reference: + EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112 + """ + def __init__(self, n_features: int, + hidden_dim: int = 64, num_layers: int = 4, activation: nn.Module = nn.ReLU): + super().__init__() + self.net = nn.Sequential( + nn.Linear(3 * n_features, hidden_dim), + activation(), + *itertools.chain(*[[ + nn.Linear(hidden_dim, hidden_dim), + activation(), + ] for _ in range(num_layers - 2)]), + nn.Linear(hidden_dim, 1 + 3), + ) + # init all bias to zero + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.zeros_(m.bias) + + def forward(self, sampled_features, ray_directions): + # Aggregate features by mean + # sampled_features = sampled_features.mean(1) + # Aggregate features by concatenation + _N, n_planes, _M, _C = sampled_features.shape + sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C) + x = sampled_features + + N, M, C = x.shape + x = x.contiguous().view(N*M, C) + + x = self.net(x) + x = x.view(N, M, -1) + rgb = torch.sigmoid(x[..., 1:])*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF + sigma = x[..., 0:1] + + return {'rgb': rgb, 'sigma': sigma} + + +class TriplaneSynthesizer(nn.Module): + """ + Synthesizer that renders a triplane volume with planes and a camera. + + Reference: + EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19 + """ + + DEFAULT_RENDERING_KWARGS = { + 'ray_start': 'auto', + 'ray_end': 'auto', + 'box_warp': 2., + 'white_back': True, + 'disparity_space_sampling': False, + 'clamp_mode': 'softplus', + 'sampler_bbox_min': -1., + 'sampler_bbox_max': 1., + } + + def __init__(self, triplane_dim: int, samples_per_ray: int): + super().__init__() + + # attributes + self.triplane_dim = triplane_dim + self.rendering_kwargs = { + **self.DEFAULT_RENDERING_KWARGS, + 'depth_resolution': samples_per_ray // 2, + 'depth_resolution_importance': samples_per_ray // 2, + } + + # renderings + self.renderer = ImportanceRenderer() + self.ray_sampler = RaySampler() + + # modules + self.decoder = OSGDecoder(n_features=triplane_dim) + + def forward(self, planes, cameras, render_size=128, crop_params=None): + # planes: (N, 3, D', H', W') + # cameras: (N, M, D_cam) + # render_size: int + assert planes.shape[0] == cameras.shape[0], "Batch size mismatch for planes and cameras" + N, M = cameras.shape[:2] + + cam2world_matrix = cameras[..., :16].view(N, M, 4, 4) + intrinsics = cameras[..., 16:25].view(N, M, 3, 3) + + # Create a batch of rays for volume rendering + ray_origins, ray_directions = self.ray_sampler( + cam2world_matrix=cam2world_matrix.reshape(-1, 4, 4), + intrinsics=intrinsics.reshape(-1, 3, 3), + render_size=render_size, + ) + assert N*M == ray_origins.shape[0], "Batch size mismatch for ray_origins" + assert ray_origins.dim() == 3, "ray_origins should be 3-dimensional" + + # Crop rays if crop_params is available + if crop_params is not None: + ray_origins = ray_origins.reshape(N*M, render_size, render_size, 3) + ray_directions = ray_directions.reshape(N*M, render_size, render_size, 3) + i, j, h, w = crop_params + ray_origins = ray_origins[:, i:i+h, j:j+w, :].reshape(N*M, -1, 3) + ray_directions = ray_directions[:, i:i+h, j:j+w, :].reshape(N*M, -1, 3) + + # Perform volume rendering + rgb_samples, depth_samples, weights_samples = self.renderer( + planes.repeat_interleave(M, dim=0), self.decoder, ray_origins, ray_directions, self.rendering_kwargs, + ) + + # Reshape into 'raw' neural-rendered image + if crop_params is not None: + Himg, Wimg = crop_params[2:] + else: + Himg = Wimg = render_size + rgb_images = rgb_samples.permute(0, 2, 1).reshape(N, M, rgb_samples.shape[-1], Himg, Wimg).contiguous() + depth_images = depth_samples.permute(0, 2, 1).reshape(N, M, 1, Himg, Wimg) + weight_images = weights_samples.permute(0, 2, 1).reshape(N, M, 1, Himg, Wimg) + + out = { + 'images_rgb': rgb_images, + 'images_depth': depth_images, + 'images_weight': weight_images, + } + return out + + def forward_grid(self, planes, grid_size: int, aabb: torch.Tensor = None): + # planes: (N, 3, D', H', W') + # grid_size: int + # aabb: (N, 2, 3) + if aabb is None: + aabb = torch.tensor([ + [self.rendering_kwargs['sampler_bbox_min']] * 3, + [self.rendering_kwargs['sampler_bbox_max']] * 3, + ], device=planes.device, dtype=planes.dtype).unsqueeze(0).repeat(planes.shape[0], 1, 1) + assert planes.shape[0] == aabb.shape[0], "Batch size mismatch for planes and aabb" + N = planes.shape[0] + + # create grid points for triplane query + grid_points = [] + for i in range(N): + grid_points.append(torch.stack(torch.meshgrid( + torch.linspace(aabb[i, 0, 0], aabb[i, 1, 0], grid_size, device=planes.device), + torch.linspace(aabb[i, 0, 1], aabb[i, 1, 1], grid_size, device=planes.device), + torch.linspace(aabb[i, 0, 2], aabb[i, 1, 2], grid_size, device=planes.device), + indexing='ij', + ), dim=-1).reshape(-1, 3)) + cube_grid = torch.stack(grid_points, dim=0).to(planes.device) + + features = self.forward_points(planes, cube_grid) + + # reshape into grid + features = { + k: v.reshape(N, grid_size, grid_size, grid_size, -1) + for k, v in features.items() + } + return features + + def forward_points(self, planes, points: torch.Tensor, chunk_size: int = 2**20): + # planes: (N, 3, D', H', W') + # points: (N, P, 3) + N, P = points.shape[:2] + + # query triplane in chunks + outs = [] + for i in range(0, points.shape[1], chunk_size): + chunk_points = points[:, i:i+chunk_size] + + # query triplane + chunk_out = self.renderer.run_model_activated( + planes=planes, + decoder=self.decoder, + sample_coordinates=chunk_points, + sample_directions=torch.zeros_like(chunk_points), + options=self.rendering_kwargs, + ) + outs.append(chunk_out) + + # concatenate the outputs + point_features = { + k: torch.cat([out[k] for out in outs], dim=1) + for k in outs[0].keys() + } + return point_features diff --git a/models/lrm/models/renderer/synthesizer_mesh.py b/models/lrm/models/renderer/synthesizer_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..a4bc9f555049bc9c02934343434e1fa262e55762 --- /dev/null +++ b/models/lrm/models/renderer/synthesizer_mesh.py @@ -0,0 +1,156 @@ +# ORIGINAL LICENSE +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Modified by Jiale Xu +# The modifications are subject to the same license as the original. + +import itertools +import torch +import torch.nn as nn + +from .utils.renderer import generate_planes, project_onto_planes, sample_from_planes + + +class OSGDecoder(nn.Module): + """ + Triplane decoder that gives RGB and sigma values from sampled features. + Using ReLU here instead of Softplus in the original implementation. + + Reference: + EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112 + """ + def __init__(self, n_features: int, + hidden_dim: int = 64, num_layers: int = 4, activation: nn.Module = nn.ReLU): + super().__init__() + + self.net_sdf = nn.Sequential( + nn.Linear(3 * n_features, hidden_dim), + activation(), + *itertools.chain(*[[ + nn.Linear(hidden_dim, hidden_dim), + activation(), + ] for _ in range(num_layers - 2)]), + nn.Linear(hidden_dim, 1), + ) + self.net_rgb = nn.Sequential( + nn.Linear(3 * n_features, hidden_dim), + activation(), + *itertools.chain(*[[ + nn.Linear(hidden_dim, hidden_dim), + activation(), + ] for _ in range(num_layers - 2)]), + nn.Linear(hidden_dim, 3), + ) + self.net_material = nn.Sequential( + nn.Linear(3 * n_features, hidden_dim), + activation(), + *itertools.chain(*[[ + nn.Linear(hidden_dim, hidden_dim), + activation(), + ] for _ in range(num_layers - 2)]), + nn.Linear(hidden_dim, 2), + ) + self.net_deformation = nn.Sequential( + nn.Linear(3 * n_features, hidden_dim), + activation(), + *itertools.chain(*[[ + nn.Linear(hidden_dim, hidden_dim), + activation(), + ] for _ in range(num_layers - 2)]), + nn.Linear(hidden_dim, 3), + ) + self.net_weight = nn.Sequential( + nn.Linear(8 * 3 * n_features, hidden_dim), + activation(), + *itertools.chain(*[[ + nn.Linear(hidden_dim, hidden_dim), + activation(), + ] for _ in range(num_layers - 2)]), + nn.Linear(hidden_dim, 21), + ) + + # init all bias to zero + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.zeros_(m.bias) + + def get_geometry_prediction(self, sampled_features, flexicubes_indices): + _N, n_planes, _M, _C = sampled_features.shape + sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C) + + sdf = self.net_sdf(sampled_features) + deformation = self.net_deformation(sampled_features) + + grid_features = torch.index_select(input=sampled_features, index=flexicubes_indices.reshape(-1), dim=1) + grid_features = grid_features.reshape( + sampled_features.shape[0], flexicubes_indices.shape[0], flexicubes_indices.shape[1] * sampled_features.shape[-1]) + weight = self.net_weight(grid_features) * 0.1 + + return sdf, deformation, weight + + def get_texture_prediction(self, sampled_features): + _N, n_planes, _M, _C = sampled_features.shape + sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C) + + rgb = self.net_rgb(sampled_features) + rgb = torch.sigmoid(rgb)*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF + + materials = self.net_material(sampled_features) + materials = torch.sigmoid(materials) + metallic, roughness = materials[...,0], materials[...,1] + rmax, rmin = 1.0, 0.04 ** 2 + roughness = roughness * (rmax - rmin) + rmin + + return rgb, metallic, roughness + + +class TriplaneSynthesizer(nn.Module): + """ + Synthesizer that renders a triplane volume with planes and a camera. + + Reference: + EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19 + """ + + DEFAULT_RENDERING_KWARGS = { + 'ray_start': 'auto', + 'ray_end': 'auto', + 'box_warp': 2., + 'white_back': True, + 'disparity_space_sampling': False, + 'clamp_mode': 'softplus', + 'sampler_bbox_min': -1., + 'sampler_bbox_max': 1., + } + + def __init__(self, triplane_dim: int, samples_per_ray: int): + super().__init__() + + # attributes + self.triplane_dim = triplane_dim + self.rendering_kwargs = { + **self.DEFAULT_RENDERING_KWARGS, + 'depth_resolution': samples_per_ray // 2, + 'depth_resolution_importance': samples_per_ray // 2, + } + + # modules + self.plane_axes = generate_planes() + self.decoder = OSGDecoder(n_features=triplane_dim) + + def get_geometry_prediction(self, planes, sample_coordinates, flexicubes_indices): + plane_axes = self.plane_axes.to(planes.device) + sampled_features = sample_from_planes( + plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=self.rendering_kwargs['box_warp']) + + sdf, deformation, weight = self.decoder.get_geometry_prediction(sampled_features, flexicubes_indices) + return sdf, deformation, weight + + def get_texture_prediction(self, planes, sample_coordinates): + plane_axes = self.plane_axes.to(planes.device) + sampled_features = sample_from_planes( + plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=self.rendering_kwargs['box_warp']) + + rgb, matellic, roughness = self.decoder.get_texture_prediction(sampled_features) + return rgb, matellic, roughness diff --git a/models/lrm/models/renderer/utils/__init__.py b/models/lrm/models/renderer/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2c772e4fa331c678cfff50884be94d7d31835b34 --- /dev/null +++ b/models/lrm/models/renderer/utils/__init__.py @@ -0,0 +1,9 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. diff --git a/models/lrm/models/renderer/utils/math_utils.py b/models/lrm/models/renderer/utils/math_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4cf9d2b811e0acbc7923bc9126e010b52cb1a8af --- /dev/null +++ b/models/lrm/models/renderer/utils/math_utils.py @@ -0,0 +1,118 @@ +# MIT License + +# Copyright (c) 2022 Petr Kellnhofer + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import torch + +def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor: + """ + Left-multiplies MxM @ NxM. Returns NxM. + """ + res = torch.matmul(vectors4, matrix.T) + return res + + +def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor: + """ + Normalize vector lengths. + """ + return vectors / (torch.norm(vectors, dim=-1, keepdim=True)) + +def torch_dot(x: torch.Tensor, y: torch.Tensor): + """ + Dot product of two tensors. + """ + return (x * y).sum(-1) + + +def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length): + """ + Author: Petr Kellnhofer + Intersects rays with the [-1, 1] NDC volume. + Returns min and max distance of entry. + Returns -1 for no intersection. + https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection + """ + o_shape = rays_o.shape + rays_o = rays_o.detach().reshape(-1, 3) + rays_d = rays_d.detach().reshape(-1, 3) + + + bb_min = [-1*(box_side_length/2), -1*(box_side_length/2), -1*(box_side_length/2)] + bb_max = [1*(box_side_length/2), 1*(box_side_length/2), 1*(box_side_length/2)] + bounds = torch.tensor([bb_min, bb_max], dtype=rays_o.dtype, device=rays_o.device) + is_valid = torch.ones(rays_o.shape[:-1], dtype=bool, device=rays_o.device) + + # Precompute inverse for stability. + invdir = 1 / rays_d + sign = (invdir < 0).long() + + # Intersect with YZ plane. + tmin = (bounds.index_select(0, sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0] + tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0] + + # Intersect with XZ plane. + tymin = (bounds.index_select(0, sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1] + tymax = (bounds.index_select(0, 1 - sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1] + + # Resolve parallel rays. + is_valid[torch.logical_or(tmin > tymax, tymin > tmax)] = False + + # Use the shortest intersection. + tmin = torch.max(tmin, tymin) + tmax = torch.min(tmax, tymax) + + # Intersect with XY plane. + tzmin = (bounds.index_select(0, sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2] + tzmax = (bounds.index_select(0, 1 - sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2] + + # Resolve parallel rays. + is_valid[torch.logical_or(tmin > tzmax, tzmin > tmax)] = False + + # Use the shortest intersection. + tmin = torch.max(tmin, tzmin) + tmax = torch.min(tmax, tzmax) + + # Mark invalid. + tmin[torch.logical_not(is_valid)] = -1 + tmax[torch.logical_not(is_valid)] = -2 + + return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1) + + +def linspace(start: torch.Tensor, stop: torch.Tensor, num: int): + """ + Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to end, inclusive. + Replicates but the multi-dimensional bahaviour of numpy.linspace in PyTorch. + """ + # create a tensor of 'num' steps from 0 to 1 + steps = torch.arange(num, dtype=torch.float32, device=start.device) / (num - 1) + + # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcastings + # - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript + # "cannot statically infer the expected size of a list in this contex", hence the code below + for i in range(start.ndim): + steps = steps.unsqueeze(-1) + + # the output starts at 'start' and increments until 'stop' in each dimension + out = start[None] + steps * (stop - start)[None] + + return out diff --git a/models/lrm/models/renderer/utils/ray_marcher.py b/models/lrm/models/renderer/utils/ray_marcher.py new file mode 100644 index 0000000000000000000000000000000000000000..ea1db43478de703509cdd04c684f92f8e283c5ad --- /dev/null +++ b/models/lrm/models/renderer/utils/ray_marcher.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. +# +# Modified by Jiale Xu +# The modifications are subject to the same license as the original. + + +""" +The ray marcher takes the raw output of the implicit representation and uses the volume rendering equation to produce composited colors and depths. +Based off of the implementation in MipNeRF (this one doesn't do any cone tracing though!) +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class MipRayMarcher2(nn.Module): + def __init__(self, activation_factory): + super().__init__() + self.activation_factory = activation_factory + + def run_forward(self, colors, densities, depths, rendering_options, normals=None): + dtype = colors.dtype + deltas = depths[:, :, 1:] - depths[:, :, :-1] + colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2 + densities_mid = (densities[:, :, :-1] + densities[:, :, 1:]) / 2 + depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2 + + # using factory mode for better usability + densities_mid = self.activation_factory(rendering_options)(densities_mid).to(dtype) + + density_delta = densities_mid * deltas + + alpha = 1 - torch.exp(-density_delta).to(dtype) + + alpha_shifted = torch.cat([torch.ones_like(alpha[:, :, :1]), 1-alpha + 1e-10], -2) + weights = alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1] + weights = weights.to(dtype) + + composite_rgb = torch.sum(weights * colors_mid, -2) + weight_total = weights.sum(2) + # composite_depth = torch.sum(weights * depths_mid, -2) / weight_total + composite_depth = torch.sum(weights * depths_mid, -2) + + # clip the composite to min/max range of depths + composite_depth = torch.nan_to_num(composite_depth, float('inf')).to(dtype) + composite_depth = torch.clamp(composite_depth, torch.min(depths), torch.max(depths)) + + if rendering_options.get('white_back', False): + composite_rgb = composite_rgb + 1 - weight_total + + # rendered value scale is 0-1, comment out original mipnerf scaling + # composite_rgb = composite_rgb * 2 - 1 # Scale to (-1, 1) + + return composite_rgb, composite_depth, weights + + + def forward(self, colors, densities, depths, rendering_options, normals=None): + if normals is not None: + composite_rgb, composite_depth, composite_normals, weights = self.run_forward(colors, densities, depths, rendering_options, normals) + return composite_rgb, composite_depth, composite_normals, weights + + composite_rgb, composite_depth, weights = self.run_forward(colors, densities, depths, rendering_options) + return composite_rgb, composite_depth, weights diff --git a/models/lrm/models/renderer/utils/ray_sampler.py b/models/lrm/models/renderer/utils/ray_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..ae5151dda467e826ce346986bd486d4465c906f2 --- /dev/null +++ b/models/lrm/models/renderer/utils/ray_sampler.py @@ -0,0 +1,141 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. +# +# Modified by Jiale Xu +# The modifications are subject to the same license as the original. + + +""" +The ray sampler is a module that takes in camera matrices and resolution and batches of rays. +Expects cam2world matrices that use the OpenCV camera coordinate system conventions. +""" + +import torch + +class RaySampler(torch.nn.Module): + def __init__(self): + super().__init__() + self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None + + + def forward(self, cam2world_matrix, intrinsics, render_size): + """ + Create batches of rays and return origins and directions. + + cam2world_matrix: (N, 4, 4) + intrinsics: (N, 3, 3) + render_size: int + + ray_origins: (N, M, 3) + ray_dirs: (N, M, 2) + """ + + dtype = cam2world_matrix.dtype + device = cam2world_matrix.device + N, M = cam2world_matrix.shape[0], render_size**2 + cam_locs_world = cam2world_matrix[:, :3, 3] + fx = intrinsics[:, 0, 0] + fy = intrinsics[:, 1, 1] + cx = intrinsics[:, 0, 2] + cy = intrinsics[:, 1, 2] + sk = intrinsics[:, 0, 1] + + uv = torch.stack(torch.meshgrid( + torch.arange(render_size, dtype=dtype, device=device), + torch.arange(render_size, dtype=dtype, device=device), + indexing='ij', + )) + uv = uv.flip(0).reshape(2, -1).transpose(1, 0) + uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1) + + x_cam = uv[:, :, 0].view(N, -1) * (1./render_size) + (0.5/render_size) + y_cam = uv[:, :, 1].view(N, -1) * (1./render_size) + (0.5/render_size) + z_cam = torch.ones((N, M), dtype=dtype, device=device) + + x_lift = (x_cam - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y_cam/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z_cam + y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam + + cam_rel_points = torch.stack((x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1).to(dtype) + + _opencv2blender = torch.tensor([ + [1, 0, 0, 0], + [0, -1, 0, 0], + [0, 0, -1, 0], + [0, 0, 0, 1], + ], dtype=dtype, device=device).unsqueeze(0).repeat(N, 1, 1) + + cam2world_matrix = torch.bmm(cam2world_matrix, _opencv2blender) + + world_rel_points = torch.bmm(cam2world_matrix, cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3] + + ray_dirs = world_rel_points - cam_locs_world[:, None, :] + ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2).to(dtype) + + ray_origins = cam_locs_world.unsqueeze(1).repeat(1, ray_dirs.shape[1], 1) + + return ray_origins, ray_dirs + + +class OrthoRaySampler(torch.nn.Module): + def __init__(self): + super().__init__() + self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None + + + def forward(self, cam2world_matrix, ortho_scale, render_size): + """ + Create batches of rays and return origins and directions. + + cam2world_matrix: (N, 4, 4) + ortho_scale: float + render_size: int + + ray_origins: (N, M, 3) + ray_dirs: (N, M, 3) + """ + + N, M = cam2world_matrix.shape[0], render_size**2 + + uv = torch.stack(torch.meshgrid( + torch.arange(render_size, dtype=torch.float32, device=cam2world_matrix.device), + torch.arange(render_size, dtype=torch.float32, device=cam2world_matrix.device), + indexing='ij', + )) + uv = uv.flip(0).reshape(2, -1).transpose(1, 0) + uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1) + + x_cam = uv[:, :, 0].view(N, -1) * (1./render_size) + (0.5/render_size) + y_cam = uv[:, :, 1].view(N, -1) * (1./render_size) + (0.5/render_size) + z_cam = torch.zeros((N, M), device=cam2world_matrix.device) + + x_lift = (x_cam - 0.5) * ortho_scale + y_lift = (y_cam - 0.5) * ortho_scale + + cam_rel_points = torch.stack((x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1) + + _opencv2blender = torch.tensor([ + [1, 0, 0, 0], + [0, -1, 0, 0], + [0, 0, -1, 0], + [0, 0, 0, 1], + ], dtype=torch.float32, device=cam2world_matrix.device).unsqueeze(0).repeat(N, 1, 1) + + cam2world_matrix = torch.bmm(cam2world_matrix, _opencv2blender) + + ray_origins = torch.bmm(cam2world_matrix, cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3] + + ray_dirs_cam = torch.stack([ + torch.zeros((N, M), device=cam2world_matrix.device), + torch.zeros((N, M), device=cam2world_matrix.device), + torch.ones((N, M), device=cam2world_matrix.device), + ], dim=-1) + ray_dirs = torch.bmm(cam2world_matrix[:, :3, :3], ray_dirs_cam.permute(0, 2, 1)).permute(0, 2, 1) + + return ray_origins, ray_dirs diff --git a/models/lrm/models/renderer/utils/renderer.py b/models/lrm/models/renderer/utils/renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..95c4c728efbd0283b8ddd7dc6a1b28d1510efa97 --- /dev/null +++ b/models/lrm/models/renderer/utils/renderer.py @@ -0,0 +1,323 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. +# +# Modified by Jiale Xu +# The modifications are subject to the same license as the original. + + +""" +The renderer is a module that takes in rays, decides where to sample along each +ray, and computes pixel colors using the volume rendering equation. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .ray_marcher import MipRayMarcher2 +from . import math_utils + + +def generate_planes(): + """ + Defines planes by the three vectors that form the "axes" of the + plane. Should work with arbitrary number of planes and planes of + arbitrary orientation. + + Bugfix reference: https://github.com/NVlabs/eg3d/issues/67 + """ + return torch.tensor([[[1, 0, 0], + [0, 1, 0], + [0, 0, 1]], + [[1, 0, 0], + [0, 0, 1], + [0, 1, 0]], + [[0, 0, 1], + [0, 1, 0], + [1, 0, 0]]], dtype=torch.float32) + +def project_onto_planes(planes, coordinates): + """ + Does a projection of a 3D point onto a batch of 2D planes, + returning 2D plane coordinates. + + Takes plane axes of shape n_planes, 3, 3 + # Takes coordinates of shape N, M, 3 + # returns projections of shape N*n_planes, M, 2 + """ + N, M, C = coordinates.shape + n_planes, _, _ = planes.shape + coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N*n_planes, M, 3) + inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(N, -1, -1, -1).reshape(N*n_planes, 3, 3) + projections = torch.bmm(coordinates, inv_planes) + return projections[..., :2] + +def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None): + assert padding_mode == 'zeros' + N, n_planes, C, H, W = plane_features.shape + _, M, _ = coordinates.shape + plane_features = plane_features.view(N*n_planes, C, H, W) + dtype = plane_features.dtype + + coordinates = (2/box_warp) * coordinates # add specific box bounds + + projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1) + output_features = torch.nn.functional.grid_sample( + plane_features, + projected_coordinates.to(dtype), + mode=mode, + padding_mode=padding_mode, + align_corners=False, + ).permute(0, 3, 2, 1).reshape(N, n_planes, M, C) + return output_features + +def sample_from_3dgrid(grid, coordinates): + """ + Expects coordinates in shape (batch_size, num_points_per_batch, 3) + Expects grid in shape (1, channels, H, W, D) + (Also works if grid has batch size) + Returns sampled features of shape (batch_size, num_points_per_batch, feature_channels) + """ + batch_size, n_coords, n_dims = coordinates.shape + sampled_features = torch.nn.functional.grid_sample( + grid.expand(batch_size, -1, -1, -1, -1), + coordinates.reshape(batch_size, 1, 1, -1, n_dims), + mode='bilinear', + padding_mode='zeros', + align_corners=False, + ) + N, C, H, W, D = sampled_features.shape + sampled_features = sampled_features.permute(0, 4, 3, 2, 1).reshape(N, H*W*D, C) + return sampled_features + +class ImportanceRenderer(torch.nn.Module): + """ + Modified original version to filter out-of-box samples as TensoRF does. + + Reference: + TensoRF: https://github.com/apchenstu/TensoRF/blob/main/models/tensorBase.py#L277 + """ + def __init__(self): + super().__init__() + self.activation_factory = self._build_activation_factory() + self.ray_marcher = MipRayMarcher2(self.activation_factory) + self.plane_axes = generate_planes() + + def _build_activation_factory(self): + def activation_factory(options: dict): + if options['clamp_mode'] == 'softplus': + return lambda x: F.softplus(x - 1) # activation bias of -1 makes things initialize better + else: + assert False, "Renderer only supports `clamp_mode`=`softplus`!" + return activation_factory + + def _forward_pass(self, depths: torch.Tensor, ray_directions: torch.Tensor, ray_origins: torch.Tensor, + planes: torch.Tensor, decoder: nn.Module, rendering_options: dict): + """ + Additional filtering is applied to filter out-of-box samples. + Modifications made by Zexin He. + """ + + # context related variables + batch_size, num_rays, samples_per_ray, _ = depths.shape + device = depths.device + + # define sample points with depths + sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, samples_per_ray, -1).reshape(batch_size, -1, 3) + sample_coordinates = (ray_origins.unsqueeze(-2) + depths * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3) + + # filter out-of-box samples + mask_inbox = \ + (rendering_options['sampler_bbox_min'] <= sample_coordinates) & \ + (sample_coordinates <= rendering_options['sampler_bbox_max']) + mask_inbox = mask_inbox.all(-1) + + # forward model according to all samples + _out = self.run_model(planes, decoder, sample_coordinates, sample_directions, rendering_options) + + # set out-of-box samples to zeros(rgb) & -inf(sigma) + SAFE_GUARD = 3 + DATA_TYPE = _out['sigma'].dtype + colors_pass = torch.zeros(batch_size, num_rays * samples_per_ray, 3, device=device, dtype=DATA_TYPE) + densities_pass = torch.nan_to_num(torch.full((batch_size, num_rays * samples_per_ray, 1), -float('inf'), device=device, dtype=DATA_TYPE)) / SAFE_GUARD + colors_pass[mask_inbox], densities_pass[mask_inbox] = _out['rgb'][mask_inbox], _out['sigma'][mask_inbox] + + # reshape back + colors_pass = colors_pass.reshape(batch_size, num_rays, samples_per_ray, colors_pass.shape[-1]) + densities_pass = densities_pass.reshape(batch_size, num_rays, samples_per_ray, densities_pass.shape[-1]) + + return colors_pass, densities_pass + + def forward(self, planes, decoder, ray_origins, ray_directions, rendering_options): + # self.plane_axes = self.plane_axes.to(ray_origins.device) + + if rendering_options['ray_start'] == rendering_options['ray_end'] == 'auto': + ray_start, ray_end = math_utils.get_ray_limits_box(ray_origins, ray_directions, box_side_length=rendering_options['box_warp']) + is_ray_valid = ray_end > ray_start + if torch.any(is_ray_valid).item(): + ray_start[~is_ray_valid] = ray_start[is_ray_valid].min() + ray_end[~is_ray_valid] = ray_start[is_ray_valid].max() + depths_coarse = self.sample_stratified(ray_origins, ray_start, ray_end, rendering_options['depth_resolution'], rendering_options['disparity_space_sampling']) + else: + # Create stratified depth samples + depths_coarse = self.sample_stratified(ray_origins, rendering_options['ray_start'], rendering_options['ray_end'], rendering_options['depth_resolution'], rendering_options['disparity_space_sampling']) + + # Coarse Pass + colors_coarse, densities_coarse = self._forward_pass( + depths=depths_coarse, ray_directions=ray_directions, ray_origins=ray_origins, + planes=planes, decoder=decoder, rendering_options=rendering_options) + + # Fine Pass + N_importance = rendering_options['depth_resolution_importance'] + if N_importance > 0: + _, _, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options) + + depths_fine = self.sample_importance(depths_coarse, weights, N_importance) + + colors_fine, densities_fine = self._forward_pass( + depths=depths_fine, ray_directions=ray_directions, ray_origins=ray_origins, + planes=planes, decoder=decoder, rendering_options=rendering_options) + + all_depths, all_colors, all_densities = self.unify_samples(depths_coarse, colors_coarse, densities_coarse, + depths_fine, colors_fine, densities_fine) + + rgb_final, depth_final, weights = self.ray_marcher(all_colors, all_densities, all_depths, rendering_options) + else: + rgb_final, depth_final, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options) + + return rgb_final, depth_final, weights.sum(2) + + def run_model(self, planes, decoder, sample_coordinates, sample_directions, options): + plane_axes = self.plane_axes.to(planes.device) + sampled_features = sample_from_planes(plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=options['box_warp']) + + out = decoder(sampled_features, sample_directions) + if options.get('density_noise', 0) > 0: + out['sigma'] += torch.randn_like(out['sigma']) * options['density_noise'] + return out + + def run_model_activated(self, planes, decoder, sample_coordinates, sample_directions, options): + out = self.run_model(planes, decoder, sample_coordinates, sample_directions, options) + out['sigma'] = self.activation_factory(options)(out['sigma']) + return out + + def sort_samples(self, all_depths, all_colors, all_densities): + _, indices = torch.sort(all_depths, dim=-2) + all_depths = torch.gather(all_depths, -2, indices) + all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1])) + all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1)) + return all_depths, all_colors, all_densities + + def unify_samples(self, depths1, colors1, densities1, depths2, colors2, densities2, normals1=None, normals2=None): + all_depths = torch.cat([depths1, depths2], dim = -2) + all_colors = torch.cat([colors1, colors2], dim = -2) + all_densities = torch.cat([densities1, densities2], dim = -2) + + if normals1 is not None and normals2 is not None: + all_normals = torch.cat([normals1, normals2], dim = -2) + else: + all_normals = None + + _, indices = torch.sort(all_depths, dim=-2) + all_depths = torch.gather(all_depths, -2, indices) + all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1])) + all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1)) + + if all_normals is not None: + all_normals = torch.gather(all_normals, -2, indices.expand(-1, -1, -1, all_normals.shape[-1])) + return all_depths, all_colors, all_normals, all_densities + + return all_depths, all_colors, all_densities + + def sample_stratified(self, ray_origins, ray_start, ray_end, depth_resolution, disparity_space_sampling=False): + """ + Return depths of approximately uniformly spaced samples along rays. + """ + N, M, _ = ray_origins.shape + if disparity_space_sampling: + depths_coarse = torch.linspace(0, + 1, + depth_resolution, + device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1) + depth_delta = 1/(depth_resolution - 1) + depths_coarse += torch.rand_like(depths_coarse) * depth_delta + depths_coarse = 1./(1./ray_start * (1. - depths_coarse) + 1./ray_end * depths_coarse) + else: + if type(ray_start) == torch.Tensor: + depths_coarse = math_utils.linspace(ray_start, ray_end, depth_resolution).permute(1,2,0,3) + depth_delta = (ray_end - ray_start) / (depth_resolution - 1) + depths_coarse += torch.rand_like(depths_coarse) * depth_delta[..., None] + else: + depths_coarse = torch.linspace(ray_start, ray_end, depth_resolution, device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1) + depth_delta = (ray_end - ray_start)/(depth_resolution - 1) + depths_coarse += torch.rand_like(depths_coarse) * depth_delta + + return depths_coarse + + def sample_importance(self, z_vals, weights, N_importance): + """ + Return depths of importance sampled points along rays. See NeRF importance sampling for more. + """ + with torch.no_grad(): + batch_size, num_rays, samples_per_ray, _ = z_vals.shape + + z_vals = z_vals.reshape(batch_size * num_rays, samples_per_ray) + weights = weights.reshape(batch_size * num_rays, -1) # -1 to account for loss of 1 sample in MipRayMarcher + + # smooth weights + weights = torch.nn.functional.max_pool1d(weights.unsqueeze(1), 2, 1, padding=1) + weights = torch.nn.functional.avg_pool1d(weights, 2, 1).squeeze() + weights = weights + 0.01 + + z_vals_mid = 0.5 * (z_vals[: ,:-1] + z_vals[: ,1:]) + importance_z_vals = self.sample_pdf(z_vals_mid, weights[:, 1:-1], + N_importance).detach().reshape(batch_size, num_rays, N_importance, 1) + return importance_z_vals + + def sample_pdf(self, bins, weights, N_importance, det=False, eps=1e-5): + """ + Sample @N_importance samples from @bins with distribution defined by @weights. + Inputs: + bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2" + weights: (N_rays, N_samples_) + N_importance: the number of samples to draw from the distribution + det: deterministic or not + eps: a small number to prevent division by zero + Outputs: + samples: the sampled samples + """ + N_rays, N_samples_ = weights.shape + weights = weights + eps # prevent division by zero (don't do inplace op!) + pdf = weights / torch.sum(weights, -1, keepdim=True) # (N_rays, N_samples_) + cdf = torch.cumsum(pdf, -1) # (N_rays, N_samples), cumulative distribution function + cdf = torch.cat([torch.zeros_like(cdf[: ,:1]), cdf], -1) # (N_rays, N_samples_+1) + # padded to 0~1 inclusive + + if det: + u = torch.linspace(0, 1, N_importance, device=bins.device) + u = u.expand(N_rays, N_importance) + else: + u = torch.rand(N_rays, N_importance, device=bins.device) + u = u.contiguous() + + inds = torch.searchsorted(cdf, u, right=True) + below = torch.clamp_min(inds-1, 0) + above = torch.clamp_max(inds, N_samples_) + + inds_sampled = torch.stack([below, above], -1).view(N_rays, 2*N_importance) + cdf_g = torch.gather(cdf, 1, inds_sampled).view(N_rays, N_importance, 2) + bins_g = torch.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2) + + denom = cdf_g[...,1]-cdf_g[...,0] + denom[denom= 10: + # exit() diff --git a/models/lrm/online_render/render_single.py b/models/lrm/online_render/render_single.py new file mode 100755 index 0000000000000000000000000000000000000000..c719fdb88f75ed301d60870d9e8bf0ebac9030a5 --- /dev/null +++ b/models/lrm/online_render/render_single.py @@ -0,0 +1,156 @@ +import os, sys +import math +import json +import importlib +import time +from .data.online_render_dataloader import load_obj +import glm +from pathlib import Path + +import cv2 +import torchvision +import random +from tqdm import tqdm +import numpy as np +from PIL import Image +import open3d as o3d +import sys +# from .src.utils.mesh import Mesh +import nvdiffrast.torch as dr +from .src.utils import obj, mesh, render_utils, render +import torch +import torch.nn.functional as F +import random +from kiui.cam import orbit_camera +import itertools +# from .src.utils.material import Material +# from .utils.camera_util import ( +# FOV_to_intrinsics, +# center_looking_at_camera_pose, +# get_circular_camera_poses, +# ) +os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1" +import re + +def sample_spherical(phi, theta, cam_radius): + theta = np.deg2rad(theta) + phi = np.deg2rad(phi) + + z = cam_radius * np.cos(phi) * np.sin(theta) + x = cam_radius * np.sin(phi) * np.sin(theta) + y = cam_radius * np.cos(theta) + + return x, y, z + +def load_mipmap(env_path): + diffuse_path = os.path.join(env_path, "diffuse.pth") + diffuse = torch.load(diffuse_path, map_location=torch.device('cpu')) + + specular = [] + for i in range(6): + specular_path = os.path.join(env_path, f"specular_{i}.pth") + specular_tensor = torch.load(specular_path, map_location=torch.device('cpu')) + specular.append(specular_tensor) + return [specular, diffuse] + +ENV = load_mipmap("models/lrm/env_mipmap/6") +materials = (0.0,0.9) +GLCTX = dr.RasterizeCudaContext() + +def random_scene(): + train_res = [512, 512] + cam_near_far = [0.1, 1000.0] + fovy = np.deg2rad(50) + spp = 1 + cam_radius = 3.5 + layers = 1 + iter_res = 512 + proj_mtx = render_utils.perspective(fovy, train_res[1] / train_res[0], cam_near_far[0], cam_near_far[1]) + + all_azimuths = np.array([0, 90, 180, 270]) + all_elevations = np.array([60, 90, 90, 120]) + + # all_azimuths = np.array([0]) + # all_elevations = np.array([60]) + all_mv = [] + all_campos = [] + all_mvp = [] + for index, (azimuths, elevations) in enumerate(zip(all_azimuths, all_elevations)): + x, y, z = sample_spherical(azimuths, elevations, cam_radius) + eye = glm.vec3(x, y, z) + at = glm.vec3(0.0, 0.0, 0.0) + up = glm.vec3(0.0, 1.0, 0.0) + view_matrix = glm.lookAt(eye, at, up) + mv = torch.from_numpy(np.array(view_matrix)) + mvp = proj_mtx @ (mv) #w2c + campos = torch.linalg.inv(mv)[:3, 3] + all_mv.append(mv[None, ...].cuda()) + all_campos.append(campos[None, ...].cuda()) + all_mvp.append(mvp[None, ...].cuda()) + return all_mv, all_mvp, all_campos + +def rendering(ref_mesh): + all_mv, all_mvp, all_campos = random_scene() + iter_res = [512, 512] + iter_spp = 1 + layers = 1 + all_albedo = [] + all_alpha = [] + all_image = [] + all_ccm = [] + all_depth = [] + all_normal = [] + for i in range(len(all_mv)): + mvp = all_mvp[i] + campos = all_campos[i] + + with torch.no_grad(): + buffer_dict = render.render_mesh(GLCTX, ref_mesh, mvp, campos, [ENV], None, None, + materials, iter_res, spp=iter_spp, num_layers=layers, msaa=True, + background=None, gt_render=True) + image = buffer_dict['shaded'][0] + albedo = (buffer_dict['albedo'][0]).clamp(0., 1.) + alpha = buffer_dict['mask'][0][:, :, 3:] + ccm = buffer_dict['ccm'][0][...,:3] + alpha = buffer_dict['mask'][0][...,:3] + albedo = buffer_dict['albedo'][0].clamp(0., 1.) + # breakpoint() + ccm = ccm * alpha + depth = buffer_dict['depth'][0] + normal = buffer_dict['gb_normal'][0] + all_image.append(image) + all_albedo.append(albedo) + all_alpha.append(alpha) + all_ccm.append(ccm) + all_depth.append(depth) + all_normal.append(normal) + all_albedo = torch.stack(all_albedo) + all_alpha = torch.stack(all_alpha) + all_ccm = torch.stack(all_ccm) + all_normal = torch.stack(all_normal) + + all_image = torch.stack(all_image) + all_depth = torch.stack(all_depth) + + # breakpoint() + return all_image.detach(), all_albedo.detach(), all_alpha.detach(), all_ccm.detach(), all_depth.detach(), all_normal.detach() + +def render_mesh(mesh_path): + ref_mesh = load_obj(mesh_path, return_attributes=False) + ref_mesh = mesh.auto_normals(ref_mesh) + ref_mesh = mesh.compute_tangents(ref_mesh) + ref_mesh.rotate_x_90() + # print(f"start ==> {mesh_path}") + rgb, albedo, alpha, ccm, depth, normal = rendering(ref_mesh) + depth = depth[...,:3] * alpha + # breakpoint() + torchvision.utils.save_image(rgb.permute(0, 3, 1, 2), f"debug_image/{mesh_path.split('/')[-1].split('.')[0]}_rgb.png") + torchvision.utils.save_image(albedo.permute(0, 3, 1, 2), f"debug_image/{mesh_path.split('/')[-1].split('.')[0]}_albedo.png") + torchvision.utils.save_image(alpha.permute(0, 3, 1, 2), f"debug_image/{mesh_path.split('/')[-1].split('.')[0]}_alpha.png") + torchvision.utils.save_image(ccm.permute(0, 3, 1, 2), f"debug_image/{mesh_path.split('/')[-1].split('.')[0]}_ccm.png") + torchvision.utils.save_image(depth.permute(0, 3, 1, 2), f"debug_image/{mesh_path.split('/')[-1].split('.')[0]}_depth.png", normalize=True) + torchvision.utils.save_image(normal.permute(0, 3, 1, 2), f"debug_image/{mesh_path.split('/')[-1].split('.')[0]}_normal.png") + print(f"end ==> {mesh_path}") + +if __name__ == '__main__': + render_mesh("./meshes_online/bubble_mart_blue/bubble_mart_blue.obj") diff --git a/models/lrm/online_render/src/models/__init__.py b/models/lrm/online_render/src/models/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/lrm/online_render/src/models/geometry/__init__.py b/models/lrm/online_render/src/models/geometry/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..89e9a6c2fffe82a55693885dae78c1a630924389 --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. diff --git a/models/lrm/online_render/src/models/geometry/camera/__init__.py b/models/lrm/online_render/src/models/geometry/camera/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..c5c7082e47c65a08e25489b3c3fd010d07ad9758 --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/camera/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +from torch import nn + + +class Camera(nn.Module): + def __init__(self): + super(Camera, self).__init__() + pass diff --git a/models/lrm/online_render/src/models/geometry/camera/perspective_camera.py b/models/lrm/online_render/src/models/geometry/camera/perspective_camera.py new file mode 100755 index 0000000000000000000000000000000000000000..7dcab0d2a321a77a5d3c2d4c3f40ba2cc32f6dfa --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/camera/perspective_camera.py @@ -0,0 +1,35 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +from . import Camera +import numpy as np + + +def projection(x=0.1, n=1.0, f=50.0, near_plane=None): + if near_plane is None: + near_plane = n + return np.array( + [[n / x, 0, 0, 0], + [0, n / -x, 0, 0], + [0, 0, -(f + near_plane) / (f - near_plane), -(2 * f * near_plane) / (f - near_plane)], + [0, 0, -1, 0]]).astype(np.float32) + + +class PerspectiveCamera(Camera): + def __init__(self, fovy=49.0, device='cuda'): + super(PerspectiveCamera, self).__init__() + self.device = device + focal = np.tan(fovy / 180.0 * np.pi * 0.5) + self.proj_mtx = torch.from_numpy(projection(x=focal, f=1000.0, n=1.0, near_plane=0.1)).to(self.device).unsqueeze(dim=0) + + def project(self, points_bxnx4): + out = torch.matmul( + points_bxnx4, + torch.transpose(self.proj_mtx, 1, 2)) + return out diff --git a/models/lrm/online_render/src/models/geometry/render/__init__.py b/models/lrm/online_render/src/models/geometry/render/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..483cfabbf395853f1ca3e67b856d5f17b9889d1b --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/render/__init__.py @@ -0,0 +1,8 @@ +import torch + +class Renderer(): + def __init__(self): + pass + + def forward(self): + pass \ No newline at end of file diff --git a/models/lrm/online_render/src/models/geometry/render/neural_render.py b/models/lrm/online_render/src/models/geometry/render/neural_render.py new file mode 100755 index 0000000000000000000000000000000000000000..5d86fcc3f752fa4fcc7e7088438e0f980d6cf64a --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/render/neural_render.py @@ -0,0 +1,293 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +import torch.nn.functional as F +import nvdiffrast.torch as dr +from . import Renderer +from . import util +from . import renderutils as ru +_FG_LUT = None + + +def interpolate(attr, rast, attr_idx, rast_db=None): + return dr.interpolate( + attr.contiguous(), rast, attr_idx, rast_db=rast_db, + diff_attrs=None if rast_db is None else 'all') + + +def xfm_points(points, matrix, use_python=True): + '''Transform points. + Args: + points: Tensor containing 3D points with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3] + matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4] + use_python: Use PyTorch's torch.matmul (for validation) + Returns: + Transformed points in homogeneous 4D with shape [minibatch_size, num_vertices, 4]. + ''' + out = torch.matmul(torch.nn.functional.pad(points, pad=(0, 1), mode='constant', value=1.0), torch.transpose(matrix, 1, 2)) + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of xfm_points contains inf or NaN" + return out + + +def dot(x, y): + return torch.sum(x * y, -1, keepdim=True) + + +def compute_vertex_normal(v_pos, t_pos_idx): + i0 = t_pos_idx[:, 0] + i1 = t_pos_idx[:, 1] + i2 = t_pos_idx[:, 2] + + v0 = v_pos[i0, :] + v1 = v_pos[i1, :] + v2 = v_pos[i2, :] + + face_normals = torch.cross(v1 - v0, v2 - v0) + + # Splat face normals to vertices + v_nrm = torch.zeros_like(v_pos) + v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals) + v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals) + v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals) + + # Normalize, replace zero (degenerated) normals with some default value + v_nrm = torch.where( + dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm) + ) + v_nrm = F.normalize(v_nrm, dim=1) + assert torch.all(torch.isfinite(v_nrm)) + + return v_nrm + + +class NeuralRender(Renderer): + def __init__(self, device='cuda', camera_model=None): + super(NeuralRender, self).__init__() + self.device = device + self.ctx = dr.RasterizeCudaContext(device=device) + self.projection_mtx = None + self.camera = camera_model + + # ============================================================================================== + # pixel shader + # ============================================================================================== + # def shade( + # self, + # gb_pos, + # gb_geometric_normal, + # gb_normal, + # gb_tangent, + # gb_texc, + # gb_texc_deriv, + # view_pos, + # ): + + # ################################################################################ + # # Texture lookups + # ################################################################################ + # breakpoint() + # # Separate kd into alpha and color, default alpha = 1 + # alpha = kd[..., 3:4] if kd.shape[-1] == 4 else torch.ones_like(kd[..., 0:1]) + # kd = kd[..., 0:3] + + # ################################################################################ + # # Normal perturbation & normal bend + # ################################################################################ + + # perturbed_nrm = None + + # gb_normal = ru.prepare_shading_normal(gb_pos, view_pos, perturbed_nrm, gb_normal, gb_tangent, gb_geometric_normal, two_sided_shading=True, opengl=True) + + # ################################################################################ + # # Evaluate BSDF + # ################################################################################ + + # assert 'bsdf' in material or bsdf is not None, "Material must specify a BSDF type" + # bsdf = material['bsdf'] if bsdf is None else bsdf + # if bsdf == 'pbr': + # if isinstance(lgt, light.EnvironmentLight): + # shaded_col = lgt.shade(gb_pos, gb_normal, kd, ks, view_pos, specular=True) + # else: + # assert False, "Invalid light type" + # elif bsdf == 'diffuse': + # if isinstance(lgt, light.EnvironmentLight): + # shaded_col = lgt.shade(gb_pos, gb_normal, kd, ks, view_pos, specular=False) + # else: + # assert False, "Invalid light type" + # elif bsdf == 'normal': + # shaded_col = (gb_normal + 1.0)*0.5 + # elif bsdf == 'tangent': + # shaded_col = (gb_tangent + 1.0)*0.5 + # elif bsdf == 'kd': + # shaded_col = kd + # elif bsdf == 'ks': + # shaded_col = ks + # else: + # assert False, "Invalid BSDF '%s'" % bsdf + + # # Return multiple buffers + # buffers = { + # 'shaded' : torch.cat((shaded_col, alpha), dim=-1), + # 'kd_grad' : torch.cat((kd_grad, alpha), dim=-1), + # 'occlusion' : torch.cat((ks[..., :1], alpha), dim=-1) + # } + # return buffers + + # ============================================================================================== + # Render a depth slice of the mesh (scene), some limitations: + # - Single mesh + # - Single light + # - Single material + # ============================================================================================== + def render_layer( + self, + rast, + rast_deriv, + mesh, + view_pos, + resolution, + spp, + msaa + ): + + # Scale down to shading resolution when MSAA is enabled, otherwise shade at full resolution + rast_out_s = rast + rast_out_deriv_s = rast_deriv + + ################################################################################ + # Interpolate attributes + ################################################################################ + + # Interpolate world space position + gb_pos, _ = interpolate(mesh.v_pos[None, ...], rast_out_s, mesh.t_pos_idx.int()) + + # Compute geometric normals. We need those because of bent normals trick (for bump mapping) + v0 = mesh.v_pos[mesh.t_pos_idx[:, 0], :] + v1 = mesh.v_pos[mesh.t_pos_idx[:, 1], :] + v2 = mesh.v_pos[mesh.t_pos_idx[:, 2], :] + face_normals = util.safe_normalize(torch.cross(v1 - v0, v2 - v0)) + face_normal_indices = (torch.arange(0, face_normals.shape[0], dtype=torch.int64, device='cuda')[:, None]).repeat(1, 3) + gb_geometric_normal, _ = interpolate(face_normals[None, ...], rast_out_s, face_normal_indices.int()) + + # Compute tangent space + assert mesh.v_nrm is not None and mesh.v_tng is not None + gb_normal, _ = interpolate(mesh.v_nrm[None, ...], rast_out_s, mesh.t_nrm_idx.int()) + gb_tangent, _ = interpolate(mesh.v_tng[None, ...], rast_out_s, mesh.t_tng_idx.int()) # Interpolate tangents + + # Texture coordinate + # assert mesh.v_tex is not None + # gb_texc, gb_texc_deriv = interpolate(mesh.v_tex[None, ...], rast_out_s, mesh.t_tex_idx.int(), rast_db=rast_out_deriv_s) + perturbed_nrm = None + gb_normal = ru.prepare_shading_normal(gb_pos, view_pos[:,None,None,:], perturbed_nrm, gb_normal, gb_tangent, gb_geometric_normal, two_sided_shading=True, opengl=True) + + return gb_pos, gb_normal + + def render_mesh( + self, + mesh_v_pos_bxnx3, + mesh_t_pos_idx_fx3, + mesh, + camera_mv_bx4x4, + camera_pos, + mesh_v_feat_bxnxd, + resolution=256, + spp=1, + device='cuda', + hierarchical_mask=False + ): + assert not hierarchical_mask + + mtx_in = torch.tensor(camera_mv_bx4x4, dtype=torch.float32, device=device) if not torch.is_tensor(camera_mv_bx4x4) else camera_mv_bx4x4 + v_pos = xfm_points(mesh_v_pos_bxnx3, mtx_in) # Rotate it to camera coordinates + v_pos_clip = self.camera.project(v_pos) # Projection in the camera + + # view_pos = torch.linalg.inv(mtx_in)[:, :3, 3] + view_pos = camera_pos + v_nrm = mesh.v_nrm #compute_vertex_normal(mesh_v_pos_bxnx3[0], mesh_t_pos_idx_fx3.long()) # vertex normals in world coordinates + + # Render the image, + # Here we only return the feature (3D location) at each pixel, which will be used as the input for neural render + num_layers = 1 + mask_pyramid = None + assert mesh_t_pos_idx_fx3.shape[0] > 0 # Make sure we have shapes + + mesh_v_feat_bxnxd = torch.cat([mesh_v_feat_bxnxd.repeat(v_pos.shape[0], 1, 1), v_pos], dim=-1) # Concatenate the pos [org_pos, clip space pose for rasterization] + + layers = [] + with dr.DepthPeeler(self.ctx, v_pos_clip, mesh.t_pos_idx.int(), [resolution * spp, resolution * spp]) as peeler: + for _ in range(num_layers): + rast, db = peeler.rasterize_next_layer() + gb_pos, gb_normal = self.render_layer(rast, db, mesh, view_pos, resolution, spp, msaa=False) + + with dr.DepthPeeler(self.ctx, v_pos_clip, mesh_t_pos_idx_fx3, [resolution * spp, resolution * spp]) as peeler: + for _ in range(num_layers): + rast, db = peeler.rasterize_next_layer() + gb_feat, _ = interpolate(mesh_v_feat_bxnxd, rast, mesh_t_pos_idx_fx3) + + hard_mask = torch.clamp(rast[..., -1:], 0, 1) + antialias_mask = dr.antialias( + hard_mask.clone().contiguous(), rast, v_pos_clip, + mesh_t_pos_idx_fx3) + + depth = gb_feat[..., -2:-1] + ori_mesh_feature = gb_feat[..., :-4] + + normal, _ = interpolate(v_nrm[None, ...], rast, mesh_t_pos_idx_fx3) + normal = dr.antialias(normal.clone().contiguous(), rast, v_pos_clip, mesh_t_pos_idx_fx3) + # normal = F.normalize(normal, dim=-1) + # normal = torch.lerp(torch.zeros_like(normal), (normal + 1.0) / 2.0, hard_mask.float()) # black background + return ori_mesh_feature, antialias_mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal, gb_normal + + def render_mesh_light( + self, + mesh_v_pos_bxnx3, + mesh_t_pos_idx_fx3, + mesh, + camera_mv_bx4x4, + mesh_v_feat_bxnxd, + resolution=256, + spp=1, + device='cuda', + hierarchical_mask=False + ): + assert not hierarchical_mask + + mtx_in = torch.tensor(camera_mv_bx4x4, dtype=torch.float32, device=device) if not torch.is_tensor(camera_mv_bx4x4) else camera_mv_bx4x4 + v_pos = xfm_points(mesh_v_pos_bxnx3, mtx_in) # Rotate it to camera coordinates + v_pos_clip = self.camera.project(v_pos) # Projection in the camera + + v_nrm = compute_vertex_normal(mesh_v_pos_bxnx3[0], mesh_t_pos_idx_fx3.long()) # vertex normals in world coordinates + + # Render the image, + # Here we only return the feature (3D location) at each pixel, which will be used as the input for neural render + num_layers = 1 + mask_pyramid = None + assert mesh_t_pos_idx_fx3.shape[0] > 0 # Make sure we have shapes + mesh_v_feat_bxnxd = torch.cat([mesh_v_feat_bxnxd.repeat(v_pos.shape[0], 1, 1), v_pos], dim=-1) # Concatenate the pos + + with dr.DepthPeeler(self.ctx, v_pos_clip, mesh_t_pos_idx_fx3, [resolution * spp, resolution * spp]) as peeler: + for _ in range(num_layers): + rast, db = peeler.rasterize_next_layer() + gb_feat, _ = interpolate(mesh_v_feat_bxnxd, rast, mesh_t_pos_idx_fx3) + + hard_mask = torch.clamp(rast[..., -1:], 0, 1) + antialias_mask = dr.antialias( + hard_mask.clone().contiguous(), rast, v_pos_clip, + mesh_t_pos_idx_fx3) + + depth = gb_feat[..., -2:-1] + ori_mesh_feature = gb_feat[..., :-4] + + normal, _ = interpolate(v_nrm[None, ...], rast, mesh_t_pos_idx_fx3) + normal = dr.antialias(normal.clone().contiguous(), rast, v_pos_clip, mesh_t_pos_idx_fx3) + normal = F.normalize(normal, dim=-1) + normal = torch.lerp(torch.zeros_like(normal), (normal + 1.0) / 2.0, hard_mask.float()) # black background + + return ori_mesh_feature, antialias_mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal diff --git a/models/lrm/online_render/src/models/geometry/render/renderutils/__init__.py b/models/lrm/online_render/src/models/geometry/render/renderutils/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..f29739f961e48de71c58b4bbc45801654df49a70 --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/render/renderutils/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +from .ops import xfm_points, xfm_vectors, image_loss, diffuse_cubemap, specular_cubemap, prepare_shading_normal, lambert, frostbite_diffuse, pbr_specular, pbr_bsdf, _fresnel_shlick, _ndf_ggx, _lambda_ggx, _masking_smith +__all__ = ["xfm_vectors", "xfm_points", "image_loss", "diffuse_cubemap","specular_cubemap", "prepare_shading_normal", "lambert", "frostbite_diffuse", "pbr_specular", "pbr_bsdf", "_fresnel_shlick", "_ndf_ggx", "_lambda_ggx", "_masking_smith", ] diff --git a/models/lrm/online_render/src/models/geometry/render/renderutils/bsdf.py b/models/lrm/online_render/src/models/geometry/render/renderutils/bsdf.py new file mode 100755 index 0000000000000000000000000000000000000000..38457ed58ee447cdf74bb780eb7457d4db1f7f92 --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/render/renderutils/bsdf.py @@ -0,0 +1,151 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import math +import torch + +NORMAL_THRESHOLD = 0.1 + +################################################################################ +# Vector utility functions +################################################################################ + +def _dot(x, y): + return torch.sum(x*y, -1, keepdim=True) + +def _reflect(x, n): + return 2*_dot(x, n)*n - x + +def _safe_normalize(x): + return torch.nn.functional.normalize(x, dim = -1) + +def _bend_normal(view_vec, smooth_nrm, geom_nrm, two_sided_shading): + # Swap normal direction for backfacing surfaces + if two_sided_shading: + smooth_nrm = torch.where(_dot(geom_nrm, view_vec) > 0, smooth_nrm, -smooth_nrm) + geom_nrm = torch.where(_dot(geom_nrm, view_vec) > 0, geom_nrm, -geom_nrm) + + t = torch.clamp(_dot(view_vec, smooth_nrm) / NORMAL_THRESHOLD, min=0, max=1) + return torch.lerp(geom_nrm, smooth_nrm, t) + + +def _perturb_normal(perturbed_nrm, smooth_nrm, smooth_tng, opengl): + smooth_bitang = _safe_normalize(torch.cross(smooth_tng, smooth_nrm)) + if opengl: + shading_nrm = smooth_tng * perturbed_nrm[..., 0:1] - smooth_bitang * perturbed_nrm[..., 1:2] + smooth_nrm * torch.clamp(perturbed_nrm[..., 2:3], min=0.0) + else: + shading_nrm = smooth_tng * perturbed_nrm[..., 0:1] + smooth_bitang * perturbed_nrm[..., 1:2] + smooth_nrm * torch.clamp(perturbed_nrm[..., 2:3], min=0.0) + return _safe_normalize(shading_nrm) + +def bsdf_prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl): + smooth_nrm = _safe_normalize(smooth_nrm) + smooth_tng = _safe_normalize(smooth_tng) + view_vec = _safe_normalize(view_pos - pos) + shading_nrm = _perturb_normal(perturbed_nrm, smooth_nrm, smooth_tng, opengl) + return _bend_normal(view_vec, shading_nrm, geom_nrm, two_sided_shading) + +################################################################################ +# Simple lambertian diffuse BSDF +################################################################################ + +def bsdf_lambert(nrm, wi): + return torch.clamp(_dot(nrm, wi), min=0.0) / math.pi + +################################################################################ +# Frostbite diffuse +################################################################################ + +def bsdf_frostbite(nrm, wi, wo, linearRoughness): + wiDotN = _dot(wi, nrm) + woDotN = _dot(wo, nrm) + + h = _safe_normalize(wo + wi) + wiDotH = _dot(wi, h) + + energyBias = 0.5 * linearRoughness + energyFactor = 1.0 - (0.51 / 1.51) * linearRoughness + f90 = energyBias + 2.0 * wiDotH * wiDotH * linearRoughness + f0 = 1.0 + + wiScatter = bsdf_fresnel_shlick(f0, f90, wiDotN) + woScatter = bsdf_fresnel_shlick(f0, f90, woDotN) + res = wiScatter * woScatter * energyFactor + return torch.where((wiDotN > 0.0) & (woDotN > 0.0), res, torch.zeros_like(res)) + +################################################################################ +# Phong specular, loosely based on mitsuba implementation +################################################################################ + +def bsdf_phong(nrm, wo, wi, N): + dp_r = torch.clamp(_dot(_reflect(wo, nrm), wi), min=0.0, max=1.0) + dp_l = torch.clamp(_dot(nrm, wi), min=0.0, max=1.0) + return (dp_r ** N) * dp_l * (N + 2) / (2 * math.pi) + +################################################################################ +# PBR's implementation of GGX specular +################################################################################ + +specular_epsilon = 1e-4 + +def bsdf_fresnel_shlick(f0, f90, cosTheta): + _cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon) + return f0 + (f90 - f0) * (1.0 - _cosTheta) ** 5.0 + +def bsdf_ndf_ggx(alphaSqr, cosTheta): + _cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon) + d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1 + return alphaSqr / (d * d * math.pi) + +def bsdf_lambda_ggx(alphaSqr, cosTheta): + _cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon) + cosThetaSqr = _cosTheta * _cosTheta + tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr + res = 0.5 * (torch.sqrt(1 + alphaSqr * tanThetaSqr) - 1.0) + return res + +def bsdf_masking_smith_ggx_correlated(alphaSqr, cosThetaI, cosThetaO): + lambdaI = bsdf_lambda_ggx(alphaSqr, cosThetaI) + lambdaO = bsdf_lambda_ggx(alphaSqr, cosThetaO) + return 1 / (1 + lambdaI + lambdaO) + +def bsdf_pbr_specular(col, nrm, wo, wi, alpha, min_roughness=0.08): + _alpha = torch.clamp(alpha, min=min_roughness*min_roughness, max=1.0) + alphaSqr = _alpha * _alpha + + h = _safe_normalize(wo + wi) + woDotN = _dot(wo, nrm) + wiDotN = _dot(wi, nrm) + woDotH = _dot(wo, h) + nDotH = _dot(nrm, h) + + D = bsdf_ndf_ggx(alphaSqr, nDotH) + G = bsdf_masking_smith_ggx_correlated(alphaSqr, woDotN, wiDotN) + F = bsdf_fresnel_shlick(col, 1, woDotH) + + w = F * D * G * 0.25 / torch.clamp(woDotN, min=specular_epsilon) + + frontfacing = (woDotN > specular_epsilon) & (wiDotN > specular_epsilon) + return torch.where(frontfacing, w, torch.zeros_like(w)) + +def bsdf_pbr(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF): + wo = _safe_normalize(view_pos - pos) + wi = _safe_normalize(light_pos - pos) + + spec_str = arm[..., 0:1] # x component + roughness = arm[..., 1:2] # y component + metallic = arm[..., 2:3] # z component + ks = (0.04 * (1.0 - metallic) + kd * metallic) * (1 - spec_str) + kd = kd * (1.0 - metallic) + + if BSDF == 0: + diffuse = kd * bsdf_lambert(nrm, wi) + else: + diffuse = kd * bsdf_frostbite(nrm, wi, wo, roughness) + specular = bsdf_pbr_specular(ks, nrm, wo, wi, roughness*roughness, min_roughness=min_roughness) + return diffuse + specular diff --git a/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/bsdf.cu b/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/bsdf.cu new file mode 100755 index 0000000000000000000000000000000000000000..c167214f9a4cb42b8d640202969e3950be8b806d --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/bsdf.cu @@ -0,0 +1,710 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include "common.h" +#include "bsdf.h" + +#define SPECULAR_EPSILON 1e-4f + +//------------------------------------------------------------------------ +// Lambert functions + +__device__ inline float fwdLambert(const vec3f nrm, const vec3f wi) +{ + return max(dot(nrm, wi) / M_PI, 0.0f); +} + +__device__ inline void bwdLambert(const vec3f nrm, const vec3f wi, vec3f& d_nrm, vec3f& d_wi, const float d_out) +{ + if (dot(nrm, wi) > 0.0f) + bwdDot(nrm, wi, d_nrm, d_wi, d_out / M_PI); +} + +//------------------------------------------------------------------------ +// Fresnel Schlick + +__device__ inline float fwdFresnelSchlick(const float f0, const float f90, const float cosTheta) +{ + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float scale = powf(1.0f - _cosTheta, 5.0f); + return f0 * (1.0f - scale) + f90 * scale; +} + +__device__ inline void bwdFresnelSchlick(const float f0, const float f90, const float cosTheta, float& d_f0, float& d_f90, float& d_cosTheta, const float d_out) +{ + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float scale = pow(max(1.0f - _cosTheta, 0.0f), 5.0f); + d_f0 += d_out * (1.0 - scale); + d_f90 += d_out * scale; + if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON) + { + d_cosTheta += d_out * (f90 - f0) * -5.0f * powf(1.0f - cosTheta, 4.0f); + } +} + +__device__ inline vec3f fwdFresnelSchlick(const vec3f f0, const vec3f f90, const float cosTheta) +{ + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float scale = powf(1.0f - _cosTheta, 5.0f); + return f0 * (1.0f - scale) + f90 * scale; +} + +__device__ inline void bwdFresnelSchlick(const vec3f f0, const vec3f f90, const float cosTheta, vec3f& d_f0, vec3f& d_f90, float& d_cosTheta, const vec3f d_out) +{ + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float scale = pow(max(1.0f - _cosTheta, 0.0f), 5.0f); + d_f0 += d_out * (1.0 - scale); + d_f90 += d_out * scale; + if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON) + { + d_cosTheta += sum(d_out * (f90 - f0) * -5.0f * powf(1.0f - cosTheta, 4.0f)); + } +} + +//------------------------------------------------------------------------ +// Frostbite diffuse + +__device__ inline float fwdFrostbiteDiffuse(const vec3f nrm, const vec3f wi, const vec3f wo, float linearRoughness) +{ + float wiDotN = dot(wi, nrm); + float woDotN = dot(wo, nrm); + if (wiDotN > 0.0f && woDotN > 0.0f) + { + vec3f h = safeNormalize(wo + wi); + float wiDotH = dot(wi, h); + + float energyBias = 0.5f * linearRoughness; + float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness; + float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness; + float f0 = 1.f; + + float wiScatter = fwdFresnelSchlick(f0, f90, wiDotN); + float woScatter = fwdFresnelSchlick(f0, f90, woDotN); + + return wiScatter * woScatter * energyFactor; + } + else return 0.0f; +} + +__device__ inline void bwdFrostbiteDiffuse(const vec3f nrm, const vec3f wi, const vec3f wo, float linearRoughness, vec3f& d_nrm, vec3f& d_wi, vec3f& d_wo, float &d_linearRoughness, const float d_out) +{ + float wiDotN = dot(wi, nrm); + float woDotN = dot(wo, nrm); + + if (wiDotN > 0.0f && woDotN > 0.0f) + { + vec3f h = safeNormalize(wo + wi); + float wiDotH = dot(wi, h); + + float energyBias = 0.5f * linearRoughness; + float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness; + float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness; + float f0 = 1.f; + + float wiScatter = fwdFresnelSchlick(f0, f90, wiDotN); + float woScatter = fwdFresnelSchlick(f0, f90, woDotN); + + // -------------- BWD -------------- + // Backprop: return wiScatter * woScatter * energyFactor; + float d_wiScatter = d_out * woScatter * energyFactor; + float d_woScatter = d_out * wiScatter * energyFactor; + float d_energyFactor = d_out * wiScatter * woScatter; + + // Backprop: float woScatter = fwdFresnelSchlick(f0, f90, woDotN); + float d_woDotN = 0.0f, d_f0 = 0.0, d_f90 = 0.0f; + bwdFresnelSchlick(f0, f90, woDotN, d_f0, d_f90, d_woDotN, d_woScatter); + + // Backprop: float wiScatter = fwdFresnelSchlick(fd0, fd90, wiDotN); + float d_wiDotN = 0.0f; + bwdFresnelSchlick(f0, f90, wiDotN, d_f0, d_f90, d_wiDotN, d_wiScatter); + + // Backprop: float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness; + float d_energyBias = d_f90; + float d_wiDotH = d_f90 * 4 * wiDotH * linearRoughness; + d_linearRoughness += d_f90 * 2 * wiDotH * wiDotH; + + // Backprop: float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness; + d_linearRoughness -= (0.51f / 1.51f) * d_energyFactor; + + // Backprop: float energyBias = 0.5f * linearRoughness; + d_linearRoughness += 0.5 * d_energyBias; + + // Backprop: float wiDotH = dot(wi, h); + vec3f d_h(0); + bwdDot(wi, h, d_wi, d_h, d_wiDotH); + + // Backprop: vec3f h = safeNormalize(wo + wi); + vec3f d_wo_wi(0); + bwdSafeNormalize(wo + wi, d_wo_wi, d_h); + d_wi += d_wo_wi; d_wo += d_wo_wi; + + bwdDot(wo, nrm, d_wo, d_nrm, d_woDotN); + bwdDot(wi, nrm, d_wi, d_nrm, d_wiDotN); + } +} + +//------------------------------------------------------------------------ +// Ndf GGX + +__device__ inline float fwdNdfGGX(const float alphaSqr, const float cosTheta) +{ + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1.0f; + return alphaSqr / (d * d * M_PI); +} + +__device__ inline void bwdNdfGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out) +{ + // Torch only back propagates if clamp doesn't trigger + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float cosThetaSqr = _cosTheta * _cosTheta; + d_alphaSqr += d_out * (1.0f - (alphaSqr + 1.0f) * cosThetaSqr) / (M_PI * powf((alphaSqr - 1.0) * cosThetaSqr + 1.0f, 3.0f)); + if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON) + { + d_cosTheta += d_out * -(4.0f * (alphaSqr - 1.0f) * alphaSqr * cosTheta) / (M_PI * powf((alphaSqr - 1.0) * cosThetaSqr + 1.0f, 3.0f)); + } +} + +//------------------------------------------------------------------------ +// Lambda GGX + +__device__ inline float fwdLambdaGGX(const float alphaSqr, const float cosTheta) +{ + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float cosThetaSqr = _cosTheta * _cosTheta; + float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr; + float res = 0.5f * (sqrtf(1.0f + alphaSqr * tanThetaSqr) - 1.0f); + return res; +} + +__device__ inline void bwdLambdaGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out) +{ + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float cosThetaSqr = _cosTheta * _cosTheta; + float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr; + float res = 0.5f * (sqrtf(1.0f + alphaSqr * tanThetaSqr) - 1.0f); + + d_alphaSqr += d_out * (0.25 * tanThetaSqr) / sqrtf(alphaSqr * tanThetaSqr + 1.0f); + if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON) + d_cosTheta += d_out * -(0.5 * alphaSqr) / (powf(_cosTheta, 3.0f) * sqrtf(alphaSqr / cosThetaSqr - alphaSqr + 1.0f)); +} + +//------------------------------------------------------------------------ +// Masking GGX + +__device__ inline float fwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO) +{ + float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI); + float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO); + return 1.0f / (1.0f + lambdaI + lambdaO); +} + +__device__ inline void bwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO, float& d_alphaSqr, float& d_cosThetaI, float& d_cosThetaO, const float d_out) +{ + // FWD eval + float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI); + float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO); + + // BWD eval + float d_lambdaIO = -d_out / powf(1.0f + lambdaI + lambdaO, 2.0f); + bwdLambdaGGX(alphaSqr, cosThetaI, d_alphaSqr, d_cosThetaI, d_lambdaIO); + bwdLambdaGGX(alphaSqr, cosThetaO, d_alphaSqr, d_cosThetaO, d_lambdaIO); +} + +//------------------------------------------------------------------------ +// GGX specular + +__device__ vec3f fwdPbrSpecular(const vec3f col, const vec3f nrm, const vec3f wo, const vec3f wi, const float alpha, const float min_roughness) +{ + float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f); + float alphaSqr = _alpha * _alpha; + + vec3f h = safeNormalize(wo + wi); + float woDotN = dot(wo, nrm); + float wiDotN = dot(wi, nrm); + float woDotH = dot(wo, h); + float nDotH = dot(nrm, h); + + float D = fwdNdfGGX(alphaSqr, nDotH); + float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN); + vec3f F = fwdFresnelSchlick(col, 1.0f, woDotH); + vec3f w = F * D * G * 0.25 / woDotN; + + bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON); + return frontfacing ? w : 0.0f; +} + +__device__ void bwdPbrSpecular( + const vec3f col, const vec3f nrm, const vec3f wo, const vec3f wi, const float alpha, const float min_roughness, + vec3f& d_col, vec3f& d_nrm, vec3f& d_wo, vec3f& d_wi, float& d_alpha, const vec3f d_out) +{ + /////////////////////////////////////////////////////////////////////// + // FWD eval + + float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f); + float alphaSqr = _alpha * _alpha; + + vec3f h = safeNormalize(wo + wi); + float woDotN = dot(wo, nrm); + float wiDotN = dot(wi, nrm); + float woDotH = dot(wo, h); + float nDotH = dot(nrm, h); + + float D = fwdNdfGGX(alphaSqr, nDotH); + float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN); + vec3f F = fwdFresnelSchlick(col, 1.0f, woDotH); + vec3f w = F * D * G * 0.25 / woDotN; + bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON); + + if (frontfacing) + { + /////////////////////////////////////////////////////////////////////// + // BWD eval + + vec3f d_F = d_out * D * G * 0.25f / woDotN; + float d_D = sum(d_out * F * G * 0.25f / woDotN); + float d_G = sum(d_out * F * D * 0.25f / woDotN); + + float d_woDotN = -sum(d_out * F * D * G * 0.25f / (woDotN * woDotN)); + + vec3f d_f90(0); + float d_woDotH(0), d_wiDotN(0), d_nDotH(0), d_alphaSqr(0); + bwdFresnelSchlick(col, 1.0f, woDotH, d_col, d_f90, d_woDotH, d_F); + bwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN, d_alphaSqr, d_woDotN, d_wiDotN, d_G); + bwdNdfGGX(alphaSqr, nDotH, d_alphaSqr, d_nDotH, d_D); + + vec3f d_h(0); + bwdDot(nrm, h, d_nrm, d_h, d_nDotH); + bwdDot(wo, h, d_wo, d_h, d_woDotH); + bwdDot(wi, nrm, d_wi, d_nrm, d_wiDotN); + bwdDot(wo, nrm, d_wo, d_nrm, d_woDotN); + + vec3f d_h_unnorm(0); + bwdSafeNormalize(wo + wi, d_h_unnorm, d_h); + d_wo += d_h_unnorm; + d_wi += d_h_unnorm; + + if (alpha > min_roughness * min_roughness) + d_alpha += d_alphaSqr * 2 * alpha; + } +} + +//------------------------------------------------------------------------ +// Full PBR BSDF + +__device__ vec3f fwdPbrBSDF(const vec3f kd, const vec3f arm, const vec3f pos, const vec3f nrm, const vec3f view_pos, const vec3f light_pos, const float min_roughness, int BSDF) +{ + vec3f wo = safeNormalize(view_pos - pos); + vec3f wi = safeNormalize(light_pos - pos); + + float alpha = arm.y * arm.y; + vec3f spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x); + vec3f diff_col = kd * (1.0f - arm.z); + + float diff = 0.0f; + if (BSDF == 0) + diff = fwdLambert(nrm, wi); + else + diff = fwdFrostbiteDiffuse(nrm, wi, wo, arm.y); + vec3f diffuse = diff_col * diff; + vec3f specular = fwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness); + + return diffuse + specular; +} + +__device__ void bwdPbrBSDF( + const vec3f kd, const vec3f arm, const vec3f pos, const vec3f nrm, const vec3f view_pos, const vec3f light_pos, const float min_roughness, int BSDF, + vec3f& d_kd, vec3f& d_arm, vec3f& d_pos, vec3f& d_nrm, vec3f& d_view_pos, vec3f& d_light_pos, const vec3f d_out) +{ + //////////////////////////////////////////////////////////////////////// + // FWD + vec3f _wi = light_pos - pos; + vec3f _wo = view_pos - pos; + vec3f wi = safeNormalize(_wi); + vec3f wo = safeNormalize(_wo); + + float alpha = arm.y * arm.y; + vec3f spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x); + vec3f diff_col = kd * (1.0f - arm.z); + float diff = 0.0f; + if (BSDF == 0) + diff = fwdLambert(nrm, wi); + else + diff = fwdFrostbiteDiffuse(nrm, wi, wo, arm.y); + + //////////////////////////////////////////////////////////////////////// + // BWD + + float d_alpha(0); + vec3f d_spec_col(0), d_wi(0), d_wo(0); + bwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness, d_spec_col, d_nrm, d_wo, d_wi, d_alpha, d_out); + + float d_diff = sum(diff_col * d_out); + if (BSDF == 0) + bwdLambert(nrm, wi, d_nrm, d_wi, d_diff); + else + bwdFrostbiteDiffuse(nrm, wi, wo, arm.y, d_nrm, d_wi, d_wo, d_arm.y, d_diff); + + // Backprop: diff_col = kd * (1.0f - arm.z) + vec3f d_diff_col = d_out * diff; + d_kd += d_diff_col * (1.0f - arm.z); + d_arm.z -= sum(d_diff_col * kd); + + // Backprop: spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x) + d_kd -= d_spec_col * (arm.x - 1.0f) * arm.z; + d_arm.x += sum(d_spec_col * (arm.z * (0.04f - kd) - 0.04f)); + d_arm.z -= sum(d_spec_col * (kd - 0.04f) * (arm.x - 1.0f)); + + // Backprop: alpha = arm.y * arm.y + d_arm.y += d_alpha * 2 * arm.y; + + // Backprop: vec3f wi = safeNormalize(light_pos - pos); + vec3f d__wi(0); + bwdSafeNormalize(_wi, d__wi, d_wi); + d_light_pos += d__wi; + d_pos -= d__wi; + + // Backprop: vec3f wo = safeNormalize(view_pos - pos); + vec3f d__wo(0); + bwdSafeNormalize(_wo, d__wo, d_wo); + d_view_pos += d__wo; + d_pos -= d__wo; +} + +//------------------------------------------------------------------------ +// Kernels + +__global__ void LambertFwdKernel(LambertKernelParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f wi = p.wi.fetch3(px, py, pz); + + float res = fwdLambert(nrm, wi); + + p.out.store(px, py, pz, res); +} + +__global__ void LambertBwdKernel(LambertKernelParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f wi = p.wi.fetch3(px, py, pz); + float d_out = p.out.fetch1(px, py, pz); + + vec3f d_nrm(0), d_wi(0); + bwdLambert(nrm, wi, d_nrm, d_wi, d_out); + + p.nrm.store_grad(px, py, pz, d_nrm); + p.wi.store_grad(px, py, pz, d_wi); +} + +__global__ void FrostbiteDiffuseFwdKernel(FrostbiteDiffuseKernelParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f wi = p.wi.fetch3(px, py, pz); + vec3f wo = p.wo.fetch3(px, py, pz); + float linearRoughness = p.linearRoughness.fetch1(px, py, pz); + + float res = fwdFrostbiteDiffuse(nrm, wi, wo, linearRoughness); + + p.out.store(px, py, pz, res); +} + +__global__ void FrostbiteDiffuseBwdKernel(FrostbiteDiffuseKernelParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f wi = p.wi.fetch3(px, py, pz); + vec3f wo = p.wo.fetch3(px, py, pz); + float linearRoughness = p.linearRoughness.fetch1(px, py, pz); + float d_out = p.out.fetch1(px, py, pz); + + float d_linearRoughness = 0.0f; + vec3f d_nrm(0), d_wi(0), d_wo(0); + bwdFrostbiteDiffuse(nrm, wi, wo, linearRoughness, d_nrm, d_wi, d_wo, d_linearRoughness, d_out); + + p.nrm.store_grad(px, py, pz, d_nrm); + p.wi.store_grad(px, py, pz, d_wi); + p.wo.store_grad(px, py, pz, d_wo); + p.linearRoughness.store_grad(px, py, pz, d_linearRoughness); +} + +__global__ void FresnelShlickFwdKernel(FresnelShlickKernelParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f f0 = p.f0.fetch3(px, py, pz); + vec3f f90 = p.f90.fetch3(px, py, pz); + float cosTheta = p.cosTheta.fetch1(px, py, pz); + + vec3f res = fwdFresnelSchlick(f0, f90, cosTheta); + p.out.store(px, py, pz, res); +} + +__global__ void FresnelShlickBwdKernel(FresnelShlickKernelParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f f0 = p.f0.fetch3(px, py, pz); + vec3f f90 = p.f90.fetch3(px, py, pz); + float cosTheta = p.cosTheta.fetch1(px, py, pz); + vec3f d_out = p.out.fetch3(px, py, pz); + + vec3f d_f0(0), d_f90(0); + float d_cosTheta(0); + bwdFresnelSchlick(f0, f90, cosTheta, d_f0, d_f90, d_cosTheta, d_out); + + p.f0.store_grad(px, py, pz, d_f0); + p.f90.store_grad(px, py, pz, d_f90); + p.cosTheta.store_grad(px, py, pz, d_cosTheta); +} + +__global__ void ndfGGXFwdKernel(NdfGGXParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + float alphaSqr = p.alphaSqr.fetch1(px, py, pz); + float cosTheta = p.cosTheta.fetch1(px, py, pz); + float res = fwdNdfGGX(alphaSqr, cosTheta); + + p.out.store(px, py, pz, res); +} + +__global__ void ndfGGXBwdKernel(NdfGGXParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + float alphaSqr = p.alphaSqr.fetch1(px, py, pz); + float cosTheta = p.cosTheta.fetch1(px, py, pz); + float d_out = p.out.fetch1(px, py, pz); + + float d_alphaSqr(0), d_cosTheta(0); + bwdNdfGGX(alphaSqr, cosTheta, d_alphaSqr, d_cosTheta, d_out); + + p.alphaSqr.store_grad(px, py, pz, d_alphaSqr); + p.cosTheta.store_grad(px, py, pz, d_cosTheta); +} + +__global__ void lambdaGGXFwdKernel(NdfGGXParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + float alphaSqr = p.alphaSqr.fetch1(px, py, pz); + float cosTheta = p.cosTheta.fetch1(px, py, pz); + float res = fwdLambdaGGX(alphaSqr, cosTheta); + + p.out.store(px, py, pz, res); +} + +__global__ void lambdaGGXBwdKernel(NdfGGXParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + float alphaSqr = p.alphaSqr.fetch1(px, py, pz); + float cosTheta = p.cosTheta.fetch1(px, py, pz); + float d_out = p.out.fetch1(px, py, pz); + + float d_alphaSqr(0), d_cosTheta(0); + bwdLambdaGGX(alphaSqr, cosTheta, d_alphaSqr, d_cosTheta, d_out); + + p.alphaSqr.store_grad(px, py, pz, d_alphaSqr); + p.cosTheta.store_grad(px, py, pz, d_cosTheta); +} + +__global__ void maskingSmithFwdKernel(MaskingSmithParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + float alphaSqr = p.alphaSqr.fetch1(px, py, pz); + float cosThetaI = p.cosThetaI.fetch1(px, py, pz); + float cosThetaO = p.cosThetaO.fetch1(px, py, pz); + float res = fwdMaskingSmithGGXCorrelated(alphaSqr, cosThetaI, cosThetaO); + + p.out.store(px, py, pz, res); +} + +__global__ void maskingSmithBwdKernel(MaskingSmithParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + float alphaSqr = p.alphaSqr.fetch1(px, py, pz); + float cosThetaI = p.cosThetaI.fetch1(px, py, pz); + float cosThetaO = p.cosThetaO.fetch1(px, py, pz); + float d_out = p.out.fetch1(px, py, pz); + + float d_alphaSqr(0), d_cosThetaI(0), d_cosThetaO(0); + bwdMaskingSmithGGXCorrelated(alphaSqr, cosThetaI, cosThetaO, d_alphaSqr, d_cosThetaI, d_cosThetaO, d_out); + + p.alphaSqr.store_grad(px, py, pz, d_alphaSqr); + p.cosThetaI.store_grad(px, py, pz, d_cosThetaI); + p.cosThetaO.store_grad(px, py, pz, d_cosThetaO); +} + +__global__ void pbrSpecularFwdKernel(PbrSpecular p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f col = p.col.fetch3(px, py, pz); + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f wo = p.wo.fetch3(px, py, pz); + vec3f wi = p.wi.fetch3(px, py, pz); + float alpha = p.alpha.fetch1(px, py, pz); + + vec3f res = fwdPbrSpecular(col, nrm, wo, wi, alpha, p.min_roughness); + + p.out.store(px, py, pz, res); +} + +__global__ void pbrSpecularBwdKernel(PbrSpecular p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f col = p.col.fetch3(px, py, pz); + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f wo = p.wo.fetch3(px, py, pz); + vec3f wi = p.wi.fetch3(px, py, pz); + float alpha = p.alpha.fetch1(px, py, pz); + vec3f d_out = p.out.fetch3(px, py, pz); + + float d_alpha(0); + vec3f d_col(0), d_nrm(0), d_wo(0), d_wi(0); + bwdPbrSpecular(col, nrm, wo, wi, alpha, p.min_roughness, d_col, d_nrm, d_wo, d_wi, d_alpha, d_out); + + p.col.store_grad(px, py, pz, d_col); + p.nrm.store_grad(px, py, pz, d_nrm); + p.wo.store_grad(px, py, pz, d_wo); + p.wi.store_grad(px, py, pz, d_wi); + p.alpha.store_grad(px, py, pz, d_alpha); +} + +__global__ void pbrBSDFFwdKernel(PbrBSDF p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f kd = p.kd.fetch3(px, py, pz); + vec3f arm = p.arm.fetch3(px, py, pz); + vec3f pos = p.pos.fetch3(px, py, pz); + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f view_pos = p.view_pos.fetch3(px, py, pz); + vec3f light_pos = p.light_pos.fetch3(px, py, pz); + + vec3f res = fwdPbrBSDF(kd, arm, pos, nrm, view_pos, light_pos, p.min_roughness, p.BSDF); + + p.out.store(px, py, pz, res); +} +__global__ void pbrBSDFBwdKernel(PbrBSDF p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f kd = p.kd.fetch3(px, py, pz); + vec3f arm = p.arm.fetch3(px, py, pz); + vec3f pos = p.pos.fetch3(px, py, pz); + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f view_pos = p.view_pos.fetch3(px, py, pz); + vec3f light_pos = p.light_pos.fetch3(px, py, pz); + vec3f d_out = p.out.fetch3(px, py, pz); + + vec3f d_kd(0), d_arm(0), d_pos(0), d_nrm(0), d_view_pos(0), d_light_pos(0); + bwdPbrBSDF(kd, arm, pos, nrm, view_pos, light_pos, p.min_roughness, p.BSDF, d_kd, d_arm, d_pos, d_nrm, d_view_pos, d_light_pos, d_out); + + p.kd.store_grad(px, py, pz, d_kd); + p.arm.store_grad(px, py, pz, d_arm); + p.pos.store_grad(px, py, pz, d_pos); + p.nrm.store_grad(px, py, pz, d_nrm); + p.view_pos.store_grad(px, py, pz, d_view_pos); + p.light_pos.store_grad(px, py, pz, d_light_pos); +} diff --git a/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/bsdf.h b/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/bsdf.h new file mode 100755 index 0000000000000000000000000000000000000000..59adbf097490c5a643ebdcff9c3784173522e070 --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/bsdf.h @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include "common.h" + +struct LambertKernelParams +{ + Tensor nrm; + Tensor wi; + Tensor out; + dim3 gridSize; +}; + +struct FrostbiteDiffuseKernelParams +{ + Tensor nrm; + Tensor wi; + Tensor wo; + Tensor linearRoughness; + Tensor out; + dim3 gridSize; +}; + +struct FresnelShlickKernelParams +{ + Tensor f0; + Tensor f90; + Tensor cosTheta; + Tensor out; + dim3 gridSize; +}; + +struct NdfGGXParams +{ + Tensor alphaSqr; + Tensor cosTheta; + Tensor out; + dim3 gridSize; +}; + +struct MaskingSmithParams +{ + Tensor alphaSqr; + Tensor cosThetaI; + Tensor cosThetaO; + Tensor out; + dim3 gridSize; +}; + +struct PbrSpecular +{ + Tensor col; + Tensor nrm; + Tensor wo; + Tensor wi; + Tensor alpha; + Tensor out; + dim3 gridSize; + float min_roughness; +}; + +struct PbrBSDF +{ + Tensor kd; + Tensor arm; + Tensor pos; + Tensor nrm; + Tensor view_pos; + Tensor light_pos; + Tensor out; + dim3 gridSize; + float min_roughness; + int BSDF; +}; diff --git a/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/common.cpp b/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/common.cpp new file mode 100755 index 0000000000000000000000000000000000000000..445895e57f7d0bcd6a2812f5ba97d7be2ddfbe28 --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/common.cpp @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include +#include + +//------------------------------------------------------------------------ +// Block and grid size calculators for kernel launches. + +dim3 getLaunchBlockSize(int maxWidth, int maxHeight, dim3 dims) +{ + int maxThreads = maxWidth * maxHeight; + if (maxThreads <= 1 || (dims.x * dims.y) <= 1) + return dim3(1, 1, 1); // Degenerate. + + // Start from max size. + int bw = maxWidth; + int bh = maxHeight; + + // Optimizations for weirdly sized buffers. + if (dims.x < bw) + { + // Decrease block width to smallest power of two that covers the buffer width. + while ((bw >> 1) >= dims.x) + bw >>= 1; + + // Maximize height. + bh = maxThreads / bw; + if (bh > dims.y) + bh = dims.y; + } + else if (dims.y < bh) + { + // Halve height and double width until fits completely inside buffer vertically. + while (bh > dims.y) + { + bh >>= 1; + if (bw < dims.x) + bw <<= 1; + } + } + + // Done. + return dim3(bw, bh, 1); +} + +// returns the size of a block that can be reduced using horizontal SIMD operations (e.g. __shfl_xor_sync) +dim3 getWarpSize(dim3 blockSize) +{ + return dim3( + std::min(blockSize.x, 32u), + std::min(std::max(32u / blockSize.x, 1u), std::min(32u, blockSize.y)), + std::min(std::max(32u / (blockSize.x * blockSize.y), 1u), std::min(32u, blockSize.z)) + ); +} + +dim3 getLaunchGridSize(dim3 blockSize, dim3 dims) +{ + dim3 gridSize; + gridSize.x = (dims.x - 1) / blockSize.x + 1; + gridSize.y = (dims.y - 1) / blockSize.y + 1; + gridSize.z = (dims.z - 1) / blockSize.z + 1; + return gridSize; +} + +//------------------------------------------------------------------------ diff --git a/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/common.h b/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/common.h new file mode 100755 index 0000000000000000000000000000000000000000..5abaeebdd3f0a0910f7df3e9e0470a9fa682d507 --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/common.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once +#include +#include + +#include "vec3f.h" +#include "vec4f.h" +#include "tensor.h" + +dim3 getLaunchBlockSize(int maxWidth, int maxHeight, dim3 dims); +dim3 getLaunchGridSize(dim3 blockSize, dim3 dims); + +#ifdef __CUDACC__ + +#ifdef _MSC_VER +#define M_PI 3.14159265358979323846f +#endif + +__host__ __device__ static inline dim3 getWarpSize(dim3 blockSize) +{ + return dim3( + min(blockSize.x, 32u), + min(max(32u / blockSize.x, 1u), min(32u, blockSize.y)), + min(max(32u / (blockSize.x * blockSize.y), 1u), min(32u, blockSize.z)) + ); +} + +__device__ static inline float clamp(float val, float mn, float mx) { return min(max(val, mn), mx); } +#else +dim3 getWarpSize(dim3 blockSize); +#endif \ No newline at end of file diff --git a/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/cubemap.cu b/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/cubemap.cu new file mode 100755 index 0000000000000000000000000000000000000000..2ce21d83b2dd6759da30874cf8e01b7fd88e9217 --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/cubemap.cu @@ -0,0 +1,350 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include "common.h" +#include "cubemap.h" +#include + +// https://cgvr.cs.uni-bremen.de/teaching/cg_literatur/Spherical,%20Cubic,%20and%20Parabolic%20Environment%20Mappings.pdf +__device__ float pixel_area(int x, int y, int N) +{ + if (N > 1) + { + int H = N / 2; + x = abs(x - H); + y = abs(y - H); + float dx = atan((float)(x + 1) / (float)H) - atan((float)x / (float)H); + float dy = atan((float)(y + 1) / (float)H) - atan((float)y / (float)H); + return dx * dy; + } + else + return 1; +} + +__device__ vec3f cube_to_dir(int x, int y, int side, int N) +{ + float fx = 2.0f * (((float)x + 0.5f) / (float)N) - 1.0f; + float fy = 2.0f * (((float)y + 0.5f) / (float)N) - 1.0f; + switch (side) + { + case 0: return safeNormalize(vec3f(1, -fy, -fx)); + case 1: return safeNormalize(vec3f(-1, -fy, fx)); + case 2: return safeNormalize(vec3f(fx, 1, fy)); + case 3: return safeNormalize(vec3f(fx, -1, -fy)); + case 4: return safeNormalize(vec3f(fx, -fy, 1)); + case 5: return safeNormalize(vec3f(-fx, -fy, -1)); + } + return vec3f(0,0,0); // Unreachable +} + +__device__ vec3f dir_to_side(int side, vec3f v) +{ + switch (side) + { + case 0: return vec3f(-v.z, -v.y, v.x); + case 1: return vec3f( v.z, -v.y, -v.x); + case 2: return vec3f( v.x, v.z, v.y); + case 3: return vec3f( v.x, -v.z, -v.y); + case 4: return vec3f( v.x, -v.y, v.z); + case 5: return vec3f(-v.x, -v.y, -v.z); + } + return vec3f(0,0,0); // Unreachable +} + +__device__ void extents_1d(float x, float z, float theta, float& _min, float& _max) +{ + float l = sqrtf(x * x + z * z); + float pxr = x + z * tan(theta) * l, pzr = z - x * tan(theta) * l; + float pxl = x - z * tan(theta) * l, pzl = z + x * tan(theta) * l; + if (pzl <= 0.00001f) + _min = pxl > 0.0f ? FLT_MAX : -FLT_MAX; + else + _min = pxl / pzl; + if (pzr <= 0.00001f) + _max = pxr > 0.0f ? FLT_MAX : -FLT_MAX; + else + _max = pxr / pzr; +} + +__device__ void dir_extents(int side, int N, vec3f v, float theta, int &_xmin, int& _xmax, int& _ymin, int& _ymax) +{ + vec3f c = dir_to_side(side, v); // remap to (x,y,z) where side is at z = 1 + + if (theta < 0.785398f) // PI/4 + { + float xmin, xmax, ymin, ymax; + extents_1d(c.x, c.z, theta, xmin, xmax); + extents_1d(c.y, c.z, theta, ymin, ymax); + + if (xmin > 1.0f || xmax < -1.0f || ymin > 1.0f || ymax < -1.0f) + { + _xmin = -1; _xmax = -1; _ymin = -1; _ymax = -1; // Bad aabb + } + else + { + _xmin = (int)min(max((xmin + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1)); + _xmax = (int)min(max((xmax + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1)); + _ymin = (int)min(max((ymin + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1)); + _ymax = (int)min(max((ymax + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1)); + } + } + else + { + _xmin = 0.0f; + _xmax = (float)(N-1); + _ymin = 0.0f; + _ymax = (float)(N-1); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////////////// +// Diffuse kernel +__global__ void DiffuseCubemapFwdKernel(DiffuseCubemapKernelParams p) +{ + // Calculate pixel position. + int px = blockIdx.x * blockDim.x + threadIdx.x; + int py = blockIdx.y * blockDim.y + threadIdx.y; + int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + int Npx = p.cubemap.dims[1]; + vec3f N = cube_to_dir(px, py, pz, Npx); + + vec3f col(0); + + for (int s = 0; s < p.cubemap.dims[0]; ++s) + { + for (int y = 0; y < Npx; ++y) + { + for (int x = 0; x < Npx; ++x) + { + vec3f L = cube_to_dir(x, y, s, Npx); + float costheta = min(max(dot(N, L), 0.0f), 0.999f); + float w = costheta * pixel_area(x, y, Npx) / 3.141592f; // pi = area of positive hemisphere + col += p.cubemap.fetch3(x, y, s) * w; + } + } + } + + p.out.store(px, py, pz, col); +} + +__global__ void DiffuseCubemapBwdKernel(DiffuseCubemapKernelParams p) +{ + // Calculate pixel position. + int px = blockIdx.x * blockDim.x + threadIdx.x; + int py = blockIdx.y * blockDim.y + threadIdx.y; + int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + int Npx = p.cubemap.dims[1]; + vec3f N = cube_to_dir(px, py, pz, Npx); + vec3f grad = p.out.fetch3(px, py, pz); + + for (int s = 0; s < p.cubemap.dims[0]; ++s) + { + for (int y = 0; y < Npx; ++y) + { + for (int x = 0; x < Npx; ++x) + { + vec3f L = cube_to_dir(x, y, s, Npx); + float costheta = min(max(dot(N, L), 0.0f), 0.999f); + float w = costheta * pixel_area(x, y, Npx) / 3.141592f; // pi = area of positive hemisphere + atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 0), grad.x * w); + atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 1), grad.y * w); + atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 2), grad.z * w); + } + } + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////////////// +// GGX splitsum kernel + +__device__ inline float ndfGGX(const float alphaSqr, const float cosTheta) +{ + float _cosTheta = clamp(cosTheta, 0.0, 1.0f); + float d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1.0f; + return alphaSqr / (d * d * M_PI); +} + +__global__ void SpecularBoundsKernel(SpecularBoundsKernelParams p) +{ + int px = blockIdx.x * blockDim.x + threadIdx.x; + int py = blockIdx.y * blockDim.y + threadIdx.y; + int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + int Npx = p.gridSize.x; + vec3f VNR = cube_to_dir(px, py, pz, Npx); + + const int TILE_SIZE = 16; + + // Brute force entire cubemap and compute bounds for the cone + for (int s = 0; s < p.gridSize.z; ++s) + { + // Assume empty BBox + int _min_x = p.gridSize.x - 1, _max_x = 0; + int _min_y = p.gridSize.y - 1, _max_y = 0; + + // For each (8x8) tile + for (int tx = 0; tx < (p.gridSize.x + TILE_SIZE - 1) / TILE_SIZE; tx++) + { + for (int ty = 0; ty < (p.gridSize.y + TILE_SIZE - 1) / TILE_SIZE; ty++) + { + // Compute tile extents + int tsx = tx * TILE_SIZE, tsy = ty * TILE_SIZE; + int tex = min((tx + 1) * TILE_SIZE, p.gridSize.x), tey = min((ty + 1) * TILE_SIZE, p.gridSize.y); + + // Use some blunt interval arithmetics to cull tiles + vec3f L0 = cube_to_dir(tsx, tsy, s, Npx), L1 = cube_to_dir(tex, tsy, s, Npx); + vec3f L2 = cube_to_dir(tsx, tey, s, Npx), L3 = cube_to_dir(tex, tey, s, Npx); + + float minx = min(min(L0.x, L1.x), min(L2.x, L3.x)), maxx = max(max(L0.x, L1.x), max(L2.x, L3.x)); + float miny = min(min(L0.y, L1.y), min(L2.y, L3.y)), maxy = max(max(L0.y, L1.y), max(L2.y, L3.y)); + float minz = min(min(L0.z, L1.z), min(L2.z, L3.z)), maxz = max(max(L0.z, L1.z), max(L2.z, L3.z)); + + float maxdp = max(minx * VNR.x, maxx * VNR.x) + max(miny * VNR.y, maxy * VNR.y) + max(minz * VNR.z, maxz * VNR.z); + if (maxdp >= p.costheta_cutoff) + { + // Test all pixels in tile. + for (int y = tsy; y < tey; ++y) + { + for (int x = tsx; x < tex; ++x) + { + vec3f L = cube_to_dir(x, y, s, Npx); + if (dot(L, VNR) >= p.costheta_cutoff) + { + _min_x = min(_min_x, x); + _max_x = max(_max_x, x); + _min_y = min(_min_y, y); + _max_y = max(_max_y, y); + } + } + } + } + } + } + p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 0), _min_x); + p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 1), _max_x); + p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 2), _min_y); + p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 3), _max_y); + } +} + +__global__ void SpecularCubemapFwdKernel(SpecularCubemapKernelParams p) +{ + // Calculate pixel position. + int px = blockIdx.x * blockDim.x + threadIdx.x; + int py = blockIdx.y * blockDim.y + threadIdx.y; + int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + int Npx = p.cubemap.dims[1]; + vec3f VNR = cube_to_dir(px, py, pz, Npx); + + float alpha = p.roughness * p.roughness; + float alphaSqr = alpha * alpha; + + float wsum = 0.0f; + vec3f col(0); + for (int s = 0; s < p.cubemap.dims[0]; ++s) + { + int xmin, xmax, ymin, ymax; + xmin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 0)); + xmax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 1)); + ymin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 2)); + ymax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 3)); + + if (xmin <= xmax) + { + for (int y = ymin; y <= ymax; ++y) + { + for (int x = xmin; x <= xmax; ++x) + { + vec3f L = cube_to_dir(x, y, s, Npx); + if (dot(L, VNR) >= p.costheta_cutoff) + { + vec3f H = safeNormalize(L + VNR); + + float wiDotN = max(dot(L, VNR), 0.0f); + float VNRDotH = max(dot(VNR, H), 0.0f); + + float w = wiDotN * ndfGGX(alphaSqr, VNRDotH) * pixel_area(x, y, Npx) / 4.0f; + col += p.cubemap.fetch3(x, y, s) * w; + wsum += w; + } + } + } + } + } + + p.out.store(p.out._nhwcIndex(pz, py, px, 0), col.x); + p.out.store(p.out._nhwcIndex(pz, py, px, 1), col.y); + p.out.store(p.out._nhwcIndex(pz, py, px, 2), col.z); + p.out.store(p.out._nhwcIndex(pz, py, px, 3), wsum); +} + +__global__ void SpecularCubemapBwdKernel(SpecularCubemapKernelParams p) +{ + // Calculate pixel position. + int px = blockIdx.x * blockDim.x + threadIdx.x; + int py = blockIdx.y * blockDim.y + threadIdx.y; + int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + int Npx = p.cubemap.dims[1]; + vec3f VNR = cube_to_dir(px, py, pz, Npx); + + vec3f grad = p.out.fetch3(px, py, pz); + + float alpha = p.roughness * p.roughness; + float alphaSqr = alpha * alpha; + + vec3f col(0); + for (int s = 0; s < p.cubemap.dims[0]; ++s) + { + int xmin, xmax, ymin, ymax; + xmin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 0)); + xmax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 1)); + ymin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 2)); + ymax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 3)); + + if (xmin <= xmax) + { + for (int y = ymin; y <= ymax; ++y) + { + for (int x = xmin; x <= xmax; ++x) + { + vec3f L = cube_to_dir(x, y, s, Npx); + if (dot(L, VNR) >= p.costheta_cutoff) + { + vec3f H = safeNormalize(L + VNR); + + float wiDotN = max(dot(L, VNR), 0.0f); + float VNRDotH = max(dot(VNR, H), 0.0f); + + float w = wiDotN * ndfGGX(alphaSqr, VNRDotH) * pixel_area(x, y, Npx) / 4.0f; + + atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 0), grad.x * w); + atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 1), grad.y * w); + atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 2), grad.z * w); + } + } + } + } + } +} diff --git a/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/cubemap.h b/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/cubemap.h new file mode 100755 index 0000000000000000000000000000000000000000..f395cc237d4a46c660bcde18609068a21f3c3fea --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/cubemap.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include "common.h" + +struct DiffuseCubemapKernelParams +{ + Tensor cubemap; + Tensor out; + dim3 gridSize; +}; + +struct SpecularCubemapKernelParams +{ + Tensor cubemap; + Tensor bounds; + Tensor out; + dim3 gridSize; + float costheta_cutoff; + float roughness; +}; + +struct SpecularBoundsKernelParams +{ + float costheta_cutoff; + Tensor out; + dim3 gridSize; +}; diff --git a/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/loss.cu b/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/loss.cu new file mode 100755 index 0000000000000000000000000000000000000000..aae5272de3c5364c22ee0bd5fde023d908e9153d --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/loss.cu @@ -0,0 +1,210 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include + +#include "common.h" +#include "loss.h" + +//------------------------------------------------------------------------ +// Utils + +__device__ inline float bwdAbs(float x) { return x == 0.0f ? 0.0f : x < 0.0f ? -1.0f : 1.0f; } + +__device__ float warpSum(float val) { + for (int i = 1; i < 32; i *= 2) + val += __shfl_xor_sync(0xFFFFFFFF, val, i); + return val; +} + +//------------------------------------------------------------------------ +// Tonemapping + +__device__ inline float fwdSRGB(float x) +{ + return x > 0.0031308f ? powf(max(x, 0.0031308f), 1.0f / 2.4f) * 1.055f - 0.055f : 12.92f * max(x, 0.0f); +} + +__device__ inline void bwdSRGB(float x, float &d_x, float d_out) +{ + if (x > 0.0031308f) + d_x += d_out * 0.439583f / powf(x, 0.583333f); + else if (x > 0.0f) + d_x += d_out * 12.92f; +} + +__device__ inline vec3f fwdTonemapLogSRGB(vec3f x) +{ + return vec3f(fwdSRGB(logf(x.x + 1.0f)), fwdSRGB(logf(x.y + 1.0f)), fwdSRGB(logf(x.z + 1.0f))); +} + +__device__ inline void bwdTonemapLogSRGB(vec3f x, vec3f& d_x, vec3f d_out) +{ + if (x.x > 0.0f && x.x < 65535.0f) + { + bwdSRGB(logf(x.x + 1.0f), d_x.x, d_out.x); + d_x.x *= 1 / (x.x + 1.0f); + } + if (x.y > 0.0f && x.y < 65535.0f) + { + bwdSRGB(logf(x.y + 1.0f), d_x.y, d_out.y); + d_x.y *= 1 / (x.y + 1.0f); + } + if (x.z > 0.0f && x.z < 65535.0f) + { + bwdSRGB(logf(x.z + 1.0f), d_x.z, d_out.z); + d_x.z *= 1 / (x.z + 1.0f); + } +} + +__device__ inline float fwdRELMSE(float img, float target, float eps = 0.1f) +{ + return (img - target) * (img - target) / (img * img + target * target + eps); +} + +__device__ inline void bwdRELMSE(float img, float target, float &d_img, float &d_target, float d_out, float eps = 0.1f) +{ + float denom = (target * target + img * img + eps); + d_img += d_out * 2 * (img - target) * (target * (target + img) + eps) / (denom * denom); + d_target -= d_out * 2 * (img - target) * (img * (target + img) + eps) / (denom * denom); +} + +__device__ inline float fwdSMAPE(float img, float target, float eps=0.01f) +{ + return abs(img - target) / (img + target + eps); +} + +__device__ inline void bwdSMAPE(float img, float target, float& d_img, float& d_target, float d_out, float eps = 0.01f) +{ + float denom = (target + img + eps); + d_img += d_out * bwdAbs(img - target) * (2 * target + eps) / (denom * denom); + d_target -= d_out * bwdAbs(img - target) * (2 * img + eps) / (denom * denom); +} + +//------------------------------------------------------------------------ +// Kernels + +__global__ void imgLossFwdKernel(LossKernelParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + + float floss = 0.0f; + if (px < p.gridSize.x && py < p.gridSize.y && pz < p.gridSize.z) + { + vec3f img = p.img.fetch3(px, py, pz); + vec3f target = p.target.fetch3(px, py, pz); + + img = vec3f(clamp(img.x, 0.0f, 65535.0f), clamp(img.y, 0.0f, 65535.0f), clamp(img.z, 0.0f, 65535.0f)); + target = vec3f(clamp(target.x, 0.0f, 65535.0f), clamp(target.y, 0.0f, 65535.0f), clamp(target.z, 0.0f, 65535.0f)); + + if (p.tonemapper == TONEMAPPER_LOG_SRGB) + { + img = fwdTonemapLogSRGB(img); + target = fwdTonemapLogSRGB(target); + } + + vec3f vloss(0); + if (p.loss == LOSS_MSE) + vloss = (img - target) * (img - target); + else if (p.loss == LOSS_RELMSE) + vloss = vec3f(fwdRELMSE(img.x, target.x), fwdRELMSE(img.y, target.y), fwdRELMSE(img.z, target.z)); + else if (p.loss == LOSS_SMAPE) + vloss = vec3f(fwdSMAPE(img.x, target.x), fwdSMAPE(img.y, target.y), fwdSMAPE(img.z, target.z)); + else + vloss = vec3f(abs(img.x - target.x), abs(img.y - target.y), abs(img.z - target.z)); + + floss = sum(vloss) / 3.0f; + } + + floss = warpSum(floss); + + dim3 warpSize = getWarpSize(blockDim); + if (px < p.gridSize.x && py < p.gridSize.y && pz < p.gridSize.z && threadIdx.x % warpSize.x == 0 && threadIdx.y % warpSize.y == 0 && threadIdx.z % warpSize.z == 0) + p.out.store(px / warpSize.x, py / warpSize.y, pz / warpSize.z, floss); +} + +__global__ void imgLossBwdKernel(LossKernelParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + dim3 warpSize = getWarpSize(blockDim); + + vec3f _img = p.img.fetch3(px, py, pz); + vec3f _target = p.target.fetch3(px, py, pz); + float d_out = p.out.fetch1(px / warpSize.x, py / warpSize.y, pz / warpSize.z); + + ///////////////////////////////////////////////////////////////////// + // FWD + + vec3f img = _img, target = _target; + if (p.tonemapper == TONEMAPPER_LOG_SRGB) + { + img = fwdTonemapLogSRGB(img); + target = fwdTonemapLogSRGB(target); + } + + ///////////////////////////////////////////////////////////////////// + // BWD + + vec3f d_vloss = vec3f(d_out, d_out, d_out) / 3.0f; + + vec3f d_img(0), d_target(0); + if (p.loss == LOSS_MSE) + { + d_img = vec3f(d_vloss.x * 2 * (img.x - target.x), d_vloss.y * 2 * (img.y - target.y), d_vloss.x * 2 * (img.z - target.z)); + d_target = -d_img; + } + else if (p.loss == LOSS_RELMSE) + { + bwdRELMSE(img.x, target.x, d_img.x, d_target.x, d_vloss.x); + bwdRELMSE(img.y, target.y, d_img.y, d_target.y, d_vloss.y); + bwdRELMSE(img.z, target.z, d_img.z, d_target.z, d_vloss.z); + } + else if (p.loss == LOSS_SMAPE) + { + bwdSMAPE(img.x, target.x, d_img.x, d_target.x, d_vloss.x); + bwdSMAPE(img.y, target.y, d_img.y, d_target.y, d_vloss.y); + bwdSMAPE(img.z, target.z, d_img.z, d_target.z, d_vloss.z); + } + else + { + d_img = d_vloss * vec3f(bwdAbs(img.x - target.x), bwdAbs(img.y - target.y), bwdAbs(img.z - target.z)); + d_target = -d_img; + } + + + if (p.tonemapper == TONEMAPPER_LOG_SRGB) + { + vec3f d__img(0), d__target(0); + bwdTonemapLogSRGB(_img, d__img, d_img); + bwdTonemapLogSRGB(_target, d__target, d_target); + d_img = d__img; d_target = d__target; + } + + if (_img.x <= 0.0f || _img.x >= 65535.0f) d_img.x = 0; + if (_img.y <= 0.0f || _img.y >= 65535.0f) d_img.y = 0; + if (_img.z <= 0.0f || _img.z >= 65535.0f) d_img.z = 0; + if (_target.x <= 0.0f || _target.x >= 65535.0f) d_target.x = 0; + if (_target.y <= 0.0f || _target.y >= 65535.0f) d_target.y = 0; + if (_target.z <= 0.0f || _target.z >= 65535.0f) d_target.z = 0; + + p.img.store_grad(px, py, pz, d_img); + p.target.store_grad(px, py, pz, d_target); +} \ No newline at end of file diff --git a/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/loss.h b/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/loss.h new file mode 100755 index 0000000000000000000000000000000000000000..26790bf02de2afd9d27e541edf23d1b064f6f9a9 --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/loss.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include "common.h" + +enum TonemapperType +{ + TONEMAPPER_NONE = 0, + TONEMAPPER_LOG_SRGB = 1 +}; + +enum LossType +{ + LOSS_L1 = 0, + LOSS_MSE = 1, + LOSS_RELMSE = 2, + LOSS_SMAPE = 3 +}; + +struct LossKernelParams +{ + Tensor img; + Tensor target; + Tensor out; + dim3 gridSize; + TonemapperType tonemapper; + LossType loss; +}; diff --git a/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/mesh.cu b/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/mesh.cu new file mode 100755 index 0000000000000000000000000000000000000000..3690ea3621c38beae03ac9ff228cf5605d303663 --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/mesh.cu @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include +#include + +#include "common.h" +#include "mesh.h" + + +//------------------------------------------------------------------------ +// Kernels + +__global__ void xfmPointsFwdKernel(XfmKernelParams p) +{ + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int pz = blockIdx.z * blockDim.z + threadIdx.z; + + __shared__ float mtx[4][4]; + if (threadIdx.x < 16) + mtx[threadIdx.x % 4][threadIdx.x / 4] = p.matrix.fetch(p.matrix.nhwcIndex(pz, threadIdx.x / 4, threadIdx.x % 4, 0)); + __syncthreads(); + + if (px >= p.gridSize.x) + return; + + vec3f pos( + p.points.fetch(p.points.nhwcIndex(pz, px, 0, 0)), + p.points.fetch(p.points.nhwcIndex(pz, px, 1, 0)), + p.points.fetch(p.points.nhwcIndex(pz, px, 2, 0)) + ); + + if (p.isPoints) + { + p.out.store(p.out.nhwcIndex(pz, px, 0, 0), pos.x * mtx[0][0] + pos.y * mtx[1][0] + pos.z * mtx[2][0] + mtx[3][0]); + p.out.store(p.out.nhwcIndex(pz, px, 1, 0), pos.x * mtx[0][1] + pos.y * mtx[1][1] + pos.z * mtx[2][1] + mtx[3][1]); + p.out.store(p.out.nhwcIndex(pz, px, 2, 0), pos.x * mtx[0][2] + pos.y * mtx[1][2] + pos.z * mtx[2][2] + mtx[3][2]); + p.out.store(p.out.nhwcIndex(pz, px, 3, 0), pos.x * mtx[0][3] + pos.y * mtx[1][3] + pos.z * mtx[2][3] + mtx[3][3]); + } + else + { + p.out.store(p.out.nhwcIndex(pz, px, 0, 0), pos.x * mtx[0][0] + pos.y * mtx[1][0] + pos.z * mtx[2][0]); + p.out.store(p.out.nhwcIndex(pz, px, 1, 0), pos.x * mtx[0][1] + pos.y * mtx[1][1] + pos.z * mtx[2][1]); + p.out.store(p.out.nhwcIndex(pz, px, 2, 0), pos.x * mtx[0][2] + pos.y * mtx[1][2] + pos.z * mtx[2][2]); + } +} + +__global__ void xfmPointsBwdKernel(XfmKernelParams p) +{ + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int pz = blockIdx.z * blockDim.z + threadIdx.z; + + __shared__ float mtx[4][4]; + if (threadIdx.x < 16) + mtx[threadIdx.x % 4][threadIdx.x / 4] = p.matrix.fetch(p.matrix.nhwcIndex(pz, threadIdx.x / 4, threadIdx.x % 4, 0)); + __syncthreads(); + + if (px >= p.gridSize.x) + return; + + vec3f pos( + p.points.fetch(p.points.nhwcIndex(pz, px, 0, 0)), + p.points.fetch(p.points.nhwcIndex(pz, px, 1, 0)), + p.points.fetch(p.points.nhwcIndex(pz, px, 2, 0)) + ); + + vec4f d_out( + p.out.fetch(p.out.nhwcIndex(pz, px, 0, 0)), + p.out.fetch(p.out.nhwcIndex(pz, px, 1, 0)), + p.out.fetch(p.out.nhwcIndex(pz, px, 2, 0)), + p.out.fetch(p.out.nhwcIndex(pz, px, 3, 0)) + ); + + if (p.isPoints) + { + p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 0, 0), d_out.x * mtx[0][0] + d_out.y * mtx[0][1] + d_out.z * mtx[0][2] + d_out.w * mtx[0][3]); + p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 1, 0), d_out.x * mtx[1][0] + d_out.y * mtx[1][1] + d_out.z * mtx[1][2] + d_out.w * mtx[1][3]); + p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 2, 0), d_out.x * mtx[2][0] + d_out.y * mtx[2][1] + d_out.z * mtx[2][2] + d_out.w * mtx[2][3]); + } + else + { + p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 0, 0), d_out.x * mtx[0][0] + d_out.y * mtx[0][1] + d_out.z * mtx[0][2]); + p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 1, 0), d_out.x * mtx[1][0] + d_out.y * mtx[1][1] + d_out.z * mtx[1][2]); + p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 2, 0), d_out.x * mtx[2][0] + d_out.y * mtx[2][1] + d_out.z * mtx[2][2]); + } +} \ No newline at end of file diff --git a/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/mesh.h b/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/mesh.h new file mode 100755 index 0000000000000000000000000000000000000000..16e2166cc55f41c4482b2c5010529e9c75182d7b --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/mesh.h @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include "common.h" + +struct XfmKernelParams +{ + bool isPoints; + Tensor points; + Tensor matrix; + Tensor out; + dim3 gridSize; +}; diff --git a/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/normal.cu b/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/normal.cu new file mode 100755 index 0000000000000000000000000000000000000000..a50e49e6b5b4061a60ec4d5d8edca2fb0833570e --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/normal.cu @@ -0,0 +1,182 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include "common.h" +#include "normal.h" + +#define NORMAL_THRESHOLD 0.1f + +//------------------------------------------------------------------------ +// Perturb shading normal by tangent frame + +__device__ vec3f fwdPerturbNormal(const vec3f perturbed_nrm, const vec3f smooth_nrm, const vec3f smooth_tng, bool opengl) +{ + vec3f _smooth_bitng = cross(smooth_tng, smooth_nrm); + vec3f smooth_bitng = safeNormalize(_smooth_bitng); + vec3f _shading_nrm = smooth_tng * perturbed_nrm.x + (opengl ? -1 : 1) * smooth_bitng * perturbed_nrm.y + smooth_nrm * max(perturbed_nrm.z, 0.0f); + return safeNormalize(_shading_nrm); +} + +__device__ void bwdPerturbNormal(const vec3f perturbed_nrm, const vec3f smooth_nrm, const vec3f smooth_tng, vec3f &d_perturbed_nrm, vec3f &d_smooth_nrm, vec3f &d_smooth_tng, const vec3f d_out, bool opengl) +{ + //////////////////////////////////////////////////////////////////////// + // FWD + vec3f _smooth_bitng = cross(smooth_tng, smooth_nrm); + vec3f smooth_bitng = safeNormalize(_smooth_bitng); + vec3f _shading_nrm = smooth_tng * perturbed_nrm.x + (opengl ? -1 : 1) * smooth_bitng * perturbed_nrm.y + smooth_nrm * max(perturbed_nrm.z, 0.0f); + + //////////////////////////////////////////////////////////////////////// + // BWD + vec3f d_shading_nrm(0); + bwdSafeNormalize(_shading_nrm, d_shading_nrm, d_out); + + vec3f d_smooth_bitng(0); + + if (perturbed_nrm.z > 0.0f) + { + d_smooth_nrm += d_shading_nrm * perturbed_nrm.z; + d_perturbed_nrm.z += sum(d_shading_nrm * smooth_nrm); + } + + d_smooth_bitng += (opengl ? -1 : 1) * d_shading_nrm * perturbed_nrm.y; + d_perturbed_nrm.y += (opengl ? -1 : 1) * sum(d_shading_nrm * smooth_bitng); + + d_smooth_tng += d_shading_nrm * perturbed_nrm.x; + d_perturbed_nrm.x += sum(d_shading_nrm * smooth_tng); + + vec3f d__smooth_bitng(0); + bwdSafeNormalize(_smooth_bitng, d__smooth_bitng, d_smooth_bitng); + + bwdCross(smooth_tng, smooth_nrm, d_smooth_tng, d_smooth_nrm, d__smooth_bitng); +} + +//------------------------------------------------------------------------ +#define bent_nrm_eps 0.001f + +__device__ vec3f fwdBendNormal(const vec3f view_vec, const vec3f smooth_nrm, const vec3f geom_nrm) +{ + float dp = dot(view_vec, smooth_nrm); + float t = clamp(dp / NORMAL_THRESHOLD, 0.0f, 1.0f); + return geom_nrm * (1.0f - t) + smooth_nrm * t; +} + +__device__ void bwdBendNormal(const vec3f view_vec, const vec3f smooth_nrm, const vec3f geom_nrm, vec3f& d_view_vec, vec3f& d_smooth_nrm, vec3f& d_geom_nrm, const vec3f d_out) +{ + //////////////////////////////////////////////////////////////////////// + // FWD + float dp = dot(view_vec, smooth_nrm); + float t = clamp(dp / NORMAL_THRESHOLD, 0.0f, 1.0f); + + //////////////////////////////////////////////////////////////////////// + // BWD + if (dp > NORMAL_THRESHOLD) + d_smooth_nrm += d_out; + else + { + // geom_nrm * (1.0f - t) + smooth_nrm * t; + d_geom_nrm += d_out * (1.0f - t); + d_smooth_nrm += d_out * t; + float d_t = sum(d_out * (smooth_nrm - geom_nrm)); + + float d_dp = dp < 0.0f || dp > NORMAL_THRESHOLD ? 0.0f : d_t / NORMAL_THRESHOLD; + + bwdDot(view_vec, smooth_nrm, d_view_vec, d_smooth_nrm, d_dp); + } +} + +//------------------------------------------------------------------------ +// Kernels + +__global__ void PrepareShadingNormalFwdKernel(PrepareShadingNormalKernelParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f pos = p.pos.fetch3(px, py, pz); + vec3f view_pos = p.view_pos.fetch3(px, py, pz); + vec3f perturbed_nrm = p.perturbed_nrm.fetch3(px, py, pz); + vec3f _smooth_nrm = p.smooth_nrm.fetch3(px, py, pz); + vec3f _smooth_tng = p.smooth_tng.fetch3(px, py, pz); + vec3f geom_nrm = p.geom_nrm.fetch3(px, py, pz); + + vec3f smooth_nrm = safeNormalize(_smooth_nrm); + vec3f smooth_tng = safeNormalize(_smooth_tng); + vec3f view_vec = safeNormalize(view_pos - pos); + vec3f shading_nrm = fwdPerturbNormal(perturbed_nrm, smooth_nrm, smooth_tng, p.opengl); + + vec3f res; + if (p.two_sided_shading && dot(view_vec, geom_nrm) < 0.0f) + res = fwdBendNormal(view_vec, -shading_nrm, -geom_nrm); + else + res = fwdBendNormal(view_vec, shading_nrm, geom_nrm); + + p.out.store(px, py, pz, res); +} + +__global__ void PrepareShadingNormalBwdKernel(PrepareShadingNormalKernelParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f pos = p.pos.fetch3(px, py, pz); + vec3f view_pos = p.view_pos.fetch3(px, py, pz); + vec3f perturbed_nrm = p.perturbed_nrm.fetch3(px, py, pz); + vec3f _smooth_nrm = p.smooth_nrm.fetch3(px, py, pz); + vec3f _smooth_tng = p.smooth_tng.fetch3(px, py, pz); + vec3f geom_nrm = p.geom_nrm.fetch3(px, py, pz); + vec3f d_out = p.out.fetch3(px, py, pz); + + /////////////////////////////////////////////////////////////////////////////////////////////////// + // FWD + + vec3f smooth_nrm = safeNormalize(_smooth_nrm); + vec3f smooth_tng = safeNormalize(_smooth_tng); + vec3f _view_vec = view_pos - pos; + vec3f view_vec = safeNormalize(view_pos - pos); + + vec3f shading_nrm = fwdPerturbNormal(perturbed_nrm, smooth_nrm, smooth_tng, p.opengl); + + /////////////////////////////////////////////////////////////////////////////////////////////////// + // BWD + + vec3f d_view_vec(0), d_shading_nrm(0), d_geom_nrm(0); + if (p.two_sided_shading && dot(view_vec, geom_nrm) < 0.0f) + { + bwdBendNormal(view_vec, -shading_nrm, -geom_nrm, d_view_vec, d_shading_nrm, d_geom_nrm, d_out); + d_shading_nrm = -d_shading_nrm; + d_geom_nrm = -d_geom_nrm; + } + else + bwdBendNormal(view_vec, shading_nrm, geom_nrm, d_view_vec, d_shading_nrm, d_geom_nrm, d_out); + + vec3f d_perturbed_nrm(0), d_smooth_nrm(0), d_smooth_tng(0); + bwdPerturbNormal(perturbed_nrm, smooth_nrm, smooth_tng, d_perturbed_nrm, d_smooth_nrm, d_smooth_tng, d_shading_nrm, p.opengl); + + vec3f d__view_vec(0), d__smooth_nrm(0), d__smooth_tng(0); + bwdSafeNormalize(_view_vec, d__view_vec, d_view_vec); + bwdSafeNormalize(_smooth_nrm, d__smooth_nrm, d_smooth_nrm); + bwdSafeNormalize(_smooth_tng, d__smooth_tng, d_smooth_tng); + + p.pos.store_grad(px, py, pz, -d__view_vec); + p.view_pos.store_grad(px, py, pz, d__view_vec); + p.perturbed_nrm.store_grad(px, py, pz, d_perturbed_nrm); + p.smooth_nrm.store_grad(px, py, pz, d__smooth_nrm); + p.smooth_tng.store_grad(px, py, pz, d__smooth_tng); + p.geom_nrm.store_grad(px, py, pz, d_geom_nrm); +} \ No newline at end of file diff --git a/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/normal.h b/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/normal.h new file mode 100755 index 0000000000000000000000000000000000000000..8882c225cfba5e747462c056d6fcf0b04dd48751 --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/normal.h @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include "common.h" + +struct PrepareShadingNormalKernelParams +{ + Tensor pos; + Tensor view_pos; + Tensor perturbed_nrm; + Tensor smooth_nrm; + Tensor smooth_tng; + Tensor geom_nrm; + Tensor out; + dim3 gridSize; + bool two_sided_shading, opengl; +}; diff --git a/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/tensor.h b/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/tensor.h new file mode 100755 index 0000000000000000000000000000000000000000..1dfb4e85c46f0394821f2533dc98468e5b7248af --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/tensor.h @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once +#if defined(__CUDACC__) && defined(BFLOAT16) +#include // bfloat16 is float32 compatible with less mantissa bits +#endif + +//--------------------------------------------------------------------------------- +// CUDA-side Tensor class for in/out parameter parsing. Can be float32 or bfloat16 + +struct Tensor +{ + void* val; + void* d_val; + int dims[4], _dims[4]; + int strides[4]; + bool fp16; + +#if defined(__CUDA__) && !defined(__CUDA_ARCH__) + Tensor() : val(nullptr), d_val(nullptr), fp16(true), dims{ 0, 0, 0, 0 }, _dims{ 0, 0, 0, 0 }, strides{ 0, 0, 0, 0 } {} +#endif + +#ifdef __CUDACC__ + // Helpers to index and read/write a single element + __device__ inline int _nhwcIndex(int n, int h, int w, int c) const { return n * strides[0] + h * strides[1] + w * strides[2] + c * strides[3]; } + __device__ inline int nhwcIndex(int n, int h, int w, int c) const { return (dims[0] == 1 ? 0 : n * strides[0]) + (dims[1] == 1 ? 0 : h * strides[1]) + (dims[2] == 1 ? 0 : w * strides[2]) + (dims[3] == 1 ? 0 : c * strides[3]); } + __device__ inline int nhwcIndexContinuous(int n, int h, int w, int c) const { return ((n * _dims[1] + h) * _dims[2] + w) * _dims[3] + c; } +#ifdef BFLOAT16 + __device__ inline float fetch(unsigned int idx) const { return fp16 ? __bfloat162float(((__nv_bfloat16*)val)[idx]) : ((float*)val)[idx]; } + __device__ inline void store(unsigned int idx, float _val) { if (fp16) ((__nv_bfloat16*)val)[idx] = __float2bfloat16(_val); else ((float*)val)[idx] = _val; } + __device__ inline void store_grad(unsigned int idx, float _val) { if (fp16) ((__nv_bfloat16*)d_val)[idx] = __float2bfloat16(_val); else ((float*)d_val)[idx] = _val; } +#else + __device__ inline float fetch(unsigned int idx) const { return ((float*)val)[idx]; } + __device__ inline void store(unsigned int idx, float _val) { ((float*)val)[idx] = _val; } + __device__ inline void store_grad(unsigned int idx, float _val) { ((float*)d_val)[idx] = _val; } +#endif + + ////////////////////////////////////////////////////////////////////////////////////////// + // Fetch, use broadcasting for tensor dimensions of size 1 + __device__ inline float fetch1(unsigned int x, unsigned int y, unsigned int z) const + { + return fetch(nhwcIndex(z, y, x, 0)); + } + + __device__ inline vec3f fetch3(unsigned int x, unsigned int y, unsigned int z) const + { + return vec3f( + fetch(nhwcIndex(z, y, x, 0)), + fetch(nhwcIndex(z, y, x, 1)), + fetch(nhwcIndex(z, y, x, 2)) + ); + } + + ///////////////////////////////////////////////////////////////////////////////////////////////////////////// + // Store, no broadcasting here. Assume we output full res gradient and then reduce using torch.sum outside + __device__ inline void store(unsigned int x, unsigned int y, unsigned int z, float _val) + { + store(_nhwcIndex(z, y, x, 0), _val); + } + + __device__ inline void store(unsigned int x, unsigned int y, unsigned int z, vec3f _val) + { + store(_nhwcIndex(z, y, x, 0), _val.x); + store(_nhwcIndex(z, y, x, 1), _val.y); + store(_nhwcIndex(z, y, x, 2), _val.z); + } + + ///////////////////////////////////////////////////////////////////////////////////////////////////////////// + // Store gradient , no broadcasting here. Assume we output full res gradient and then reduce using torch.sum outside + __device__ inline void store_grad(unsigned int x, unsigned int y, unsigned int z, float _val) + { + store_grad(nhwcIndexContinuous(z, y, x, 0), _val); + } + + __device__ inline void store_grad(unsigned int x, unsigned int y, unsigned int z, vec3f _val) + { + store_grad(nhwcIndexContinuous(z, y, x, 0), _val.x); + store_grad(nhwcIndexContinuous(z, y, x, 1), _val.y); + store_grad(nhwcIndexContinuous(z, y, x, 2), _val.z); + } +#endif + +}; diff --git a/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/torch_bindings.cpp b/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/torch_bindings.cpp new file mode 100755 index 0000000000000000000000000000000000000000..64c9e70f79507944490cb978233c34ac9e3e97a6 --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/torch_bindings.cpp @@ -0,0 +1,1062 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#ifdef _MSC_VER +#pragma warning(push, 0) +#include +#pragma warning(pop) +#else +#include +#endif + +#include +#include +#include +#include + +#define NVDR_CHECK_CUDA_ERROR(CUDA_CALL) { cudaError_t err = CUDA_CALL; AT_CUDA_CHECK(cudaGetLastError()); } +#define NVDR_CHECK_GL_ERROR(GL_CALL) { GL_CALL; GLenum err = glGetError(); TORCH_CHECK(err == GL_NO_ERROR, "OpenGL error: ", getGLErrorString(err), "[", #GL_CALL, ";]"); } +#define CHECK_TENSOR(X, DIMS, CHANNELS) \ + TORCH_CHECK(X.is_cuda(), #X " must be a cuda tensor") \ + TORCH_CHECK(X.scalar_type() == torch::kFloat || X.scalar_type() == torch::kBFloat16, #X " must be fp32 or bf16") \ + TORCH_CHECK(X.dim() == DIMS, #X " must have " #DIMS " dimensions") \ + TORCH_CHECK(X.size(DIMS - 1) == CHANNELS, #X " must have " #CHANNELS " channels") + +#include "common.h" +#include "loss.h" +#include "normal.h" +#include "cubemap.h" +#include "bsdf.h" +#include "mesh.h" + +#define BLOCK_X 8 +#define BLOCK_Y 8 + +//------------------------------------------------------------------------ +// mesh.cu + +void xfmPointsFwdKernel(XfmKernelParams p); +void xfmPointsBwdKernel(XfmKernelParams p); + +//------------------------------------------------------------------------ +// loss.cu + +void imgLossFwdKernel(LossKernelParams p); +void imgLossBwdKernel(LossKernelParams p); + +//------------------------------------------------------------------------ +// normal.cu + +void PrepareShadingNormalFwdKernel(PrepareShadingNormalKernelParams p); +void PrepareShadingNormalBwdKernel(PrepareShadingNormalKernelParams p); + +//------------------------------------------------------------------------ +// cubemap.cu + +void DiffuseCubemapFwdKernel(DiffuseCubemapKernelParams p); +void DiffuseCubemapBwdKernel(DiffuseCubemapKernelParams p); +void SpecularBoundsKernel(SpecularBoundsKernelParams p); +void SpecularCubemapFwdKernel(SpecularCubemapKernelParams p); +void SpecularCubemapBwdKernel(SpecularCubemapKernelParams p); + +//------------------------------------------------------------------------ +// bsdf.cu + +void LambertFwdKernel(LambertKernelParams p); +void LambertBwdKernel(LambertKernelParams p); + +void FrostbiteDiffuseFwdKernel(FrostbiteDiffuseKernelParams p); +void FrostbiteDiffuseBwdKernel(FrostbiteDiffuseKernelParams p); + +void FresnelShlickFwdKernel(FresnelShlickKernelParams p); +void FresnelShlickBwdKernel(FresnelShlickKernelParams p); + +void ndfGGXFwdKernel(NdfGGXParams p); +void ndfGGXBwdKernel(NdfGGXParams p); + +void lambdaGGXFwdKernel(NdfGGXParams p); +void lambdaGGXBwdKernel(NdfGGXParams p); + +void maskingSmithFwdKernel(MaskingSmithParams p); +void maskingSmithBwdKernel(MaskingSmithParams p); + +void pbrSpecularFwdKernel(PbrSpecular p); +void pbrSpecularBwdKernel(PbrSpecular p); + +void pbrBSDFFwdKernel(PbrBSDF p); +void pbrBSDFBwdKernel(PbrBSDF p); + +//------------------------------------------------------------------------ +// Tensor helpers + +void update_grid(dim3 &gridSize, torch::Tensor x) +{ + gridSize.x = std::max(gridSize.x, (uint32_t)x.size(2)); + gridSize.y = std::max(gridSize.y, (uint32_t)x.size(1)); + gridSize.z = std::max(gridSize.z, (uint32_t)x.size(0)); +} + +template +void update_grid(dim3& gridSize, torch::Tensor x, Ts&&... vs) +{ + gridSize.x = std::max(gridSize.x, (uint32_t)x.size(2)); + gridSize.y = std::max(gridSize.y, (uint32_t)x.size(1)); + gridSize.z = std::max(gridSize.z, (uint32_t)x.size(0)); + update_grid(gridSize, std::forward(vs)...); +} + +Tensor make_cuda_tensor(torch::Tensor val) +{ + Tensor res; + for (int i = 0; i < val.dim(); ++i) + { + res.dims[i] = val.size(i); + res.strides[i] = val.stride(i); + } + res.fp16 = val.scalar_type() == torch::kBFloat16; + res.val = res.fp16 ? (void*)val.data_ptr() : (void*)val.data_ptr(); + res.d_val = nullptr; + return res; +} + +Tensor make_cuda_tensor(torch::Tensor val, dim3 outDims, torch::Tensor* grad = nullptr) +{ + Tensor res; + for (int i = 0; i < val.dim(); ++i) + { + res.dims[i] = val.size(i); + res.strides[i] = val.stride(i); + } + if (val.dim() == 4) + res._dims[0] = outDims.z, res._dims[1] = outDims.y, res._dims[2] = outDims.x, res._dims[3] = val.size(3); + else + res._dims[0] = outDims.z, res._dims[1] = outDims.x, res._dims[2] = val.size(2), res._dims[3] = 1; // Add a trailing one for indexing math to work out + + res.fp16 = val.scalar_type() == torch::kBFloat16; + res.val = res.fp16 ? (void*)val.data_ptr() : (void*)val.data_ptr(); + res.d_val = nullptr; + if (grad != nullptr) + { + if (val.dim() == 4) + *grad = torch::empty({ outDims.z, outDims.y, outDims.x, val.size(3) }, torch::TensorOptions().dtype(res.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA)); + else // 3 + *grad = torch::empty({ outDims.z, outDims.x, val.size(2) }, torch::TensorOptions().dtype(res.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA)); + + res.d_val = res.fp16 ? (void*)grad->data_ptr() : (void*)grad->data_ptr(); + } + return res; +} + +//------------------------------------------------------------------------ +// prepare_shading_normal + +torch::Tensor prepare_shading_normal_fwd(torch::Tensor pos, torch::Tensor view_pos, torch::Tensor perturbed_nrm, torch::Tensor smooth_nrm, torch::Tensor smooth_tng, torch::Tensor geom_nrm, bool two_sided_shading, bool opengl, bool fp16) +{ + CHECK_TENSOR(pos, 4, 3); + CHECK_TENSOR(view_pos, 4, 3); + CHECK_TENSOR(perturbed_nrm, 4, 3); + CHECK_TENSOR(smooth_nrm, 4, 3); + CHECK_TENSOR(smooth_tng, 4, 3); + CHECK_TENSOR(geom_nrm, 4, 3); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + PrepareShadingNormalKernelParams p; + p.two_sided_shading = two_sided_shading; + p.opengl = opengl; + p.out.fp16 = fp16; + update_grid(p.gridSize, pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + // Setup tensors + p.pos = make_cuda_tensor(pos, p.gridSize); + p.view_pos = make_cuda_tensor(view_pos, p.gridSize); + p.perturbed_nrm = make_cuda_tensor(perturbed_nrm, p.gridSize); + p.smooth_nrm = make_cuda_tensor(smooth_nrm, p.gridSize); + p.smooth_tng = make_cuda_tensor(smooth_tng, p.gridSize); + p.geom_nrm = make_cuda_tensor(geom_nrm, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)PrepareShadingNormalFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple prepare_shading_normal_bwd(torch::Tensor pos, torch::Tensor view_pos, torch::Tensor perturbed_nrm, torch::Tensor smooth_nrm, torch::Tensor smooth_tng, torch::Tensor geom_nrm, torch::Tensor grad, bool two_sided_shading, bool opengl) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + PrepareShadingNormalKernelParams p; + p.two_sided_shading = two_sided_shading; + p.opengl = opengl; + update_grid(p.gridSize, pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + // Setup tensors + torch::Tensor pos_grad, view_pos_grad, perturbed_nrm_grad, smooth_nrm_grad, smooth_tng_grad, geom_nrm_grad; + p.pos = make_cuda_tensor(pos, p.gridSize, &pos_grad); + p.view_pos = make_cuda_tensor(view_pos, p.gridSize, &view_pos_grad); + p.perturbed_nrm = make_cuda_tensor(perturbed_nrm, p.gridSize, &perturbed_nrm_grad); + p.smooth_nrm = make_cuda_tensor(smooth_nrm, p.gridSize, &smooth_nrm_grad); + p.smooth_tng = make_cuda_tensor(smooth_tng, p.gridSize, &smooth_tng_grad); + p.geom_nrm = make_cuda_tensor(geom_nrm, p.gridSize, &geom_nrm_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)PrepareShadingNormalBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(pos_grad, view_pos_grad, perturbed_nrm_grad, smooth_nrm_grad, smooth_tng_grad, geom_nrm_grad); +} + +//------------------------------------------------------------------------ +// lambert + +torch::Tensor lambert_fwd(torch::Tensor nrm, torch::Tensor wi, bool fp16) +{ + CHECK_TENSOR(nrm, 4, 3); + CHECK_TENSOR(wi, 4, 3); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + LambertKernelParams p; + p.out.fp16 = fp16; + update_grid(p.gridSize, nrm, wi); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + p.nrm = make_cuda_tensor(nrm, p.gridSize); + p.wi = make_cuda_tensor(wi, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)LambertFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple lambert_bwd(torch::Tensor nrm, torch::Tensor wi, torch::Tensor grad) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + LambertKernelParams p; + update_grid(p.gridSize, nrm, wi); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor nrm_grad, wi_grad; + p.nrm = make_cuda_tensor(nrm, p.gridSize, &nrm_grad); + p.wi = make_cuda_tensor(wi, p.gridSize, &wi_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)LambertBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(nrm_grad, wi_grad); +} + +//------------------------------------------------------------------------ +// frostbite diffuse + +torch::Tensor frostbite_fwd(torch::Tensor nrm, torch::Tensor wi, torch::Tensor wo, torch::Tensor linearRoughness, bool fp16) +{ + CHECK_TENSOR(nrm, 4, 3); + CHECK_TENSOR(wi, 4, 3); + CHECK_TENSOR(wo, 4, 3); + CHECK_TENSOR(linearRoughness, 4, 1); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + FrostbiteDiffuseKernelParams p; + p.out.fp16 = fp16; + update_grid(p.gridSize, nrm, wi, wo, linearRoughness); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + p.nrm = make_cuda_tensor(nrm, p.gridSize); + p.wi = make_cuda_tensor(wi, p.gridSize); + p.wo = make_cuda_tensor(wo, p.gridSize); + p.linearRoughness = make_cuda_tensor(linearRoughness, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)FrostbiteDiffuseFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple frostbite_bwd(torch::Tensor nrm, torch::Tensor wi, torch::Tensor wo, torch::Tensor linearRoughness, torch::Tensor grad) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + FrostbiteDiffuseKernelParams p; + update_grid(p.gridSize, nrm, wi, wo, linearRoughness); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor nrm_grad, wi_grad, wo_grad, linearRoughness_grad; + p.nrm = make_cuda_tensor(nrm, p.gridSize, &nrm_grad); + p.wi = make_cuda_tensor(wi, p.gridSize, &wi_grad); + p.wo = make_cuda_tensor(wo, p.gridSize, &wo_grad); + p.linearRoughness = make_cuda_tensor(linearRoughness, p.gridSize, &linearRoughness_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)FrostbiteDiffuseBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(nrm_grad, wi_grad, wo_grad, linearRoughness_grad); +} + +//------------------------------------------------------------------------ +// fresnel_shlick + +torch::Tensor fresnel_shlick_fwd(torch::Tensor f0, torch::Tensor f90, torch::Tensor cosTheta, bool fp16) +{ + CHECK_TENSOR(f0, 4, 3); + CHECK_TENSOR(f90, 4, 3); + CHECK_TENSOR(cosTheta, 4, 1); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + FresnelShlickKernelParams p; + p.out.fp16 = fp16; + update_grid(p.gridSize, f0, f90, cosTheta); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + p.f0 = make_cuda_tensor(f0, p.gridSize); + p.f90 = make_cuda_tensor(f90, p.gridSize); + p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)FresnelShlickFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple fresnel_shlick_bwd(torch::Tensor f0, torch::Tensor f90, torch::Tensor cosTheta, torch::Tensor grad) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + FresnelShlickKernelParams p; + update_grid(p.gridSize, f0, f90, cosTheta); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor f0_grad, f90_grad, cosT_grad; + p.f0 = make_cuda_tensor(f0, p.gridSize, &f0_grad); + p.f90 = make_cuda_tensor(f90, p.gridSize, &f90_grad); + p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize, &cosT_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)FresnelShlickBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(f0_grad, f90_grad, cosT_grad); +} + +//------------------------------------------------------------------------ +// ndf_ggd + +torch::Tensor ndf_ggx_fwd(torch::Tensor alphaSqr, torch::Tensor cosTheta, bool fp16) +{ + CHECK_TENSOR(alphaSqr, 4, 1); + CHECK_TENSOR(cosTheta, 4, 1); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + NdfGGXParams p; + p.out.fp16 = fp16; + update_grid(p.gridSize, alphaSqr, cosTheta); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize); + p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)ndfGGXFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple ndf_ggx_bwd(torch::Tensor alphaSqr, torch::Tensor cosTheta, torch::Tensor grad) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + NdfGGXParams p; + update_grid(p.gridSize, alphaSqr, cosTheta); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor alphaSqr_grad, cosTheta_grad; + p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize, &alphaSqr_grad); + p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize, &cosTheta_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)ndfGGXBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(alphaSqr_grad, cosTheta_grad); +} + +//------------------------------------------------------------------------ +// lambda_ggx + +torch::Tensor lambda_ggx_fwd(torch::Tensor alphaSqr, torch::Tensor cosTheta, bool fp16) +{ + CHECK_TENSOR(alphaSqr, 4, 1); + CHECK_TENSOR(cosTheta, 4, 1); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + NdfGGXParams p; + p.out.fp16 = fp16; + update_grid(p.gridSize, alphaSqr, cosTheta); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize); + p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)lambdaGGXFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple lambda_ggx_bwd(torch::Tensor alphaSqr, torch::Tensor cosTheta, torch::Tensor grad) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + NdfGGXParams p; + update_grid(p.gridSize, alphaSqr, cosTheta); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor alphaSqr_grad, cosTheta_grad; + p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize, &alphaSqr_grad); + p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize, &cosTheta_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)lambdaGGXBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(alphaSqr_grad, cosTheta_grad); +} + +//------------------------------------------------------------------------ +// masking_smith + +torch::Tensor masking_smith_fwd(torch::Tensor alphaSqr, torch::Tensor cosThetaI, torch::Tensor cosThetaO, bool fp16) +{ + CHECK_TENSOR(alphaSqr, 4, 1); + CHECK_TENSOR(cosThetaI, 4, 1); + CHECK_TENSOR(cosThetaO, 4, 1); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + MaskingSmithParams p; + p.out.fp16 = fp16; + update_grid(p.gridSize, alphaSqr, cosThetaI, cosThetaO); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize); + p.cosThetaI = make_cuda_tensor(cosThetaI, p.gridSize); + p.cosThetaO = make_cuda_tensor(cosThetaO, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)maskingSmithFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple masking_smith_bwd(torch::Tensor alphaSqr, torch::Tensor cosThetaI, torch::Tensor cosThetaO, torch::Tensor grad) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + MaskingSmithParams p; + update_grid(p.gridSize, alphaSqr, cosThetaI, cosThetaO); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor alphaSqr_grad, cosThetaI_grad, cosThetaO_grad; + p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize, &alphaSqr_grad); + p.cosThetaI = make_cuda_tensor(cosThetaI, p.gridSize, &cosThetaI_grad); + p.cosThetaO = make_cuda_tensor(cosThetaO, p.gridSize, &cosThetaO_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)maskingSmithBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(alphaSqr_grad, cosThetaI_grad, cosThetaO_grad); +} + +//------------------------------------------------------------------------ +// pbr_specular + +torch::Tensor pbr_specular_fwd(torch::Tensor col, torch::Tensor nrm, torch::Tensor wo, torch::Tensor wi, torch::Tensor alpha, float min_roughness, bool fp16) +{ + CHECK_TENSOR(col, 4, 3); + CHECK_TENSOR(nrm, 4, 3); + CHECK_TENSOR(wo, 4, 3); + CHECK_TENSOR(wi, 4, 3); + CHECK_TENSOR(alpha, 4, 1); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + PbrSpecular p; + p.out.fp16 = fp16; + p.min_roughness = min_roughness; + update_grid(p.gridSize, col, nrm, wo, wi, alpha); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + p.col = make_cuda_tensor(col, p.gridSize); + p.nrm = make_cuda_tensor(nrm, p.gridSize); + p.wo = make_cuda_tensor(wo, p.gridSize); + p.wi = make_cuda_tensor(wi, p.gridSize); + p.alpha = make_cuda_tensor(alpha, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)pbrSpecularFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple pbr_specular_bwd(torch::Tensor col, torch::Tensor nrm, torch::Tensor wo, torch::Tensor wi, torch::Tensor alpha, float min_roughness, torch::Tensor grad) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + PbrSpecular p; + update_grid(p.gridSize, col, nrm, wo, wi, alpha); + p.min_roughness = min_roughness; + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor col_grad, nrm_grad, wo_grad, wi_grad, alpha_grad; + p.col = make_cuda_tensor(col, p.gridSize, &col_grad); + p.nrm = make_cuda_tensor(nrm, p.gridSize, &nrm_grad); + p.wo = make_cuda_tensor(wo, p.gridSize, &wo_grad); + p.wi = make_cuda_tensor(wi, p.gridSize, &wi_grad); + p.alpha = make_cuda_tensor(alpha, p.gridSize, &alpha_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)pbrSpecularBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(col_grad, nrm_grad, wo_grad, wi_grad, alpha_grad); +} + +//------------------------------------------------------------------------ +// pbr_bsdf + +torch::Tensor pbr_bsdf_fwd(torch::Tensor kd, torch::Tensor arm, torch::Tensor pos, torch::Tensor nrm, torch::Tensor view_pos, torch::Tensor light_pos, float min_roughness, int BSDF, bool fp16) +{ + CHECK_TENSOR(kd, 4, 3); + CHECK_TENSOR(arm, 4, 3); + CHECK_TENSOR(pos, 4, 3); + CHECK_TENSOR(nrm, 4, 3); + CHECK_TENSOR(view_pos, 4, 3); + CHECK_TENSOR(light_pos, 4, 3); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + PbrBSDF p; + p.out.fp16 = fp16; + p.min_roughness = min_roughness; + p.BSDF = BSDF; + update_grid(p.gridSize, kd, arm, pos, nrm, view_pos, light_pos); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + p.kd = make_cuda_tensor(kd, p.gridSize); + p.arm = make_cuda_tensor(arm, p.gridSize); + p.pos = make_cuda_tensor(pos, p.gridSize); + p.nrm = make_cuda_tensor(nrm, p.gridSize); + p.view_pos = make_cuda_tensor(view_pos, p.gridSize); + p.light_pos = make_cuda_tensor(light_pos, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)pbrBSDFFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple pbr_bsdf_bwd(torch::Tensor kd, torch::Tensor arm, torch::Tensor pos, torch::Tensor nrm, torch::Tensor view_pos, torch::Tensor light_pos, float min_roughness, int BSDF, torch::Tensor grad) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + PbrBSDF p; + update_grid(p.gridSize, kd, arm, pos, nrm, view_pos, light_pos); + p.min_roughness = min_roughness; + p.BSDF = BSDF; + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor kd_grad, arm_grad, pos_grad, nrm_grad, view_pos_grad, light_pos_grad; + p.kd = make_cuda_tensor(kd, p.gridSize, &kd_grad); + p.arm = make_cuda_tensor(arm, p.gridSize, &arm_grad); + p.pos = make_cuda_tensor(pos, p.gridSize, &pos_grad); + p.nrm = make_cuda_tensor(nrm, p.gridSize, &nrm_grad); + p.view_pos = make_cuda_tensor(view_pos, p.gridSize, &view_pos_grad); + p.light_pos = make_cuda_tensor(light_pos, p.gridSize, &light_pos_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)pbrBSDFBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(kd_grad, arm_grad, pos_grad, nrm_grad, view_pos_grad, light_pos_grad); +} + +//------------------------------------------------------------------------ +// filter_cubemap + +torch::Tensor diffuse_cubemap_fwd(torch::Tensor cubemap) +{ + CHECK_TENSOR(cubemap, 4, 3); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + DiffuseCubemapKernelParams p; + update_grid(p.gridSize, cubemap); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + // Setup tensors + p.cubemap = make_cuda_tensor(cubemap, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)DiffuseCubemapFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +torch::Tensor diffuse_cubemap_bwd(torch::Tensor cubemap, torch::Tensor grad) +{ + CHECK_TENSOR(cubemap, 4, 3); + CHECK_TENSOR(grad, 4, 3); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + DiffuseCubemapKernelParams p; + update_grid(p.gridSize, cubemap); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + // Setup tensors + torch::Tensor cubemap_grad; + p.cubemap = make_cuda_tensor(cubemap, p.gridSize); + p.out = make_cuda_tensor(grad, p.gridSize); + + cubemap_grad = torch::zeros({ p.gridSize.z, p.gridSize.y, p.gridSize.x, cubemap.size(3) }, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); + p.cubemap.d_val = (void*)cubemap_grad.data_ptr(); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)DiffuseCubemapBwdKernel, gridSize, blockSize, args, 0, stream)); + + return cubemap_grad; +} + +torch::Tensor specular_bounds(int resolution, float costheta_cutoff) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + SpecularBoundsKernelParams p; + p.costheta_cutoff = costheta_cutoff; + p.gridSize = dim3(resolution, resolution, 6); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::zeros({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 6*4 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + // Setup tensors + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)SpecularBoundsKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +torch::Tensor specular_cubemap_fwd(torch::Tensor cubemap, torch::Tensor bounds, float roughness, float costheta_cutoff) +{ + CHECK_TENSOR(cubemap, 4, 3); + CHECK_TENSOR(bounds, 4, 6*4); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + SpecularCubemapKernelParams p; + p.roughness = roughness; + p.costheta_cutoff = costheta_cutoff; + update_grid(p.gridSize, cubemap); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 4 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + // Setup tensors + p.cubemap = make_cuda_tensor(cubemap, p.gridSize); + p.bounds = make_cuda_tensor(bounds, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)SpecularCubemapFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +torch::Tensor specular_cubemap_bwd(torch::Tensor cubemap, torch::Tensor bounds, torch::Tensor grad, float roughness, float costheta_cutoff) +{ + CHECK_TENSOR(cubemap, 4, 3); + CHECK_TENSOR(bounds, 4, 6*4); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + SpecularCubemapKernelParams p; + p.roughness = roughness; + p.costheta_cutoff = costheta_cutoff; + update_grid(p.gridSize, cubemap); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + // Setup tensors + torch::Tensor cubemap_grad; + p.cubemap = make_cuda_tensor(cubemap, p.gridSize); + p.bounds = make_cuda_tensor(bounds, p.gridSize); + p.out = make_cuda_tensor(grad, p.gridSize); + + cubemap_grad = torch::zeros({ p.gridSize.z, p.gridSize.y, p.gridSize.x, cubemap.size(3) }, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); + p.cubemap.d_val = (void*)cubemap_grad.data_ptr(); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)SpecularCubemapBwdKernel, gridSize, blockSize, args, 0, stream)); + + return cubemap_grad; +} + +//------------------------------------------------------------------------ +// loss function + +LossType strToLoss(std::string str) +{ + if (str == "mse") + return LOSS_MSE; + else if (str == "relmse") + return LOSS_RELMSE; + else if (str == "smape") + return LOSS_SMAPE; + else + return LOSS_L1; +} + +torch::Tensor image_loss_fwd(torch::Tensor img, torch::Tensor target, std::string loss, std::string tonemapper, bool fp16) +{ + CHECK_TENSOR(img, 4, 3); + CHECK_TENSOR(target, 4, 3); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + LossKernelParams p; + p.out.fp16 = fp16; + p.loss = strToLoss(loss); + p.tonemapper = tonemapper == "log_srgb" ? TONEMAPPER_LOG_SRGB : TONEMAPPER_NONE; + update_grid(p.gridSize, img, target); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 warpSize = getWarpSize(blockSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ (p.gridSize.z - 1)/ warpSize.z + 1, (p.gridSize.y - 1) / warpSize.y + 1, (p.gridSize.x - 1) / warpSize.x + 1, 1 }, opts); + + p.img = make_cuda_tensor(img, p.gridSize); + p.target = make_cuda_tensor(target, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)imgLossFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple image_loss_bwd(torch::Tensor img, torch::Tensor target, torch::Tensor grad, std::string loss, std::string tonemapper) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + LossKernelParams p; + p.loss = strToLoss(loss); + p.tonemapper = tonemapper == "log_srgb" ? TONEMAPPER_LOG_SRGB : TONEMAPPER_NONE; + update_grid(p.gridSize, img, target); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 warpSize = getWarpSize(blockSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor img_grad, target_grad; + p.img = make_cuda_tensor(img, p.gridSize, &img_grad); + p.target = make_cuda_tensor(target, p.gridSize, &target_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)imgLossBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(img_grad, target_grad); +} + +//------------------------------------------------------------------------ +// transform function + +torch::Tensor xfm_fwd(torch::Tensor points, torch::Tensor matrix, bool isPoints, bool fp16) +{ + CHECK_TENSOR(points, 3, 3); + CHECK_TENSOR(matrix, 3, 4); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + XfmKernelParams p; + p.out.fp16 = fp16; + p.isPoints = isPoints; + p.gridSize.x = points.size(1); + p.gridSize.y = 1; + p.gridSize.z = std::max(matrix.size(0), points.size(0)); + + // Choose launch parameters. + dim3 blockSize(BLOCK_X * BLOCK_Y, 1, 1); + dim3 warpSize = getWarpSize(blockSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = isPoints ? torch::empty({ matrix.size(0), points.size(1), 4 }, opts) : torch::empty({ matrix.size(0), points.size(1), 3 }, opts); + + p.points = make_cuda_tensor(points, p.gridSize); + p.matrix = make_cuda_tensor(matrix, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)xfmPointsFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +torch::Tensor xfm_bwd(torch::Tensor points, torch::Tensor matrix, torch::Tensor grad, bool isPoints) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + XfmKernelParams p; + p.isPoints = isPoints; + p.gridSize.x = points.size(1); + p.gridSize.y = 1; + p.gridSize.z = std::max(matrix.size(0), points.size(0)); + + // Choose launch parameters. + dim3 blockSize(BLOCK_X * BLOCK_Y, 1, 1); + dim3 warpSize = getWarpSize(blockSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor points_grad; + p.points = make_cuda_tensor(points, p.gridSize, &points_grad); + p.matrix = make_cuda_tensor(matrix, p.gridSize); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)xfmPointsBwdKernel, gridSize, blockSize, args, 0, stream)); + + return points_grad; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("prepare_shading_normal_fwd", &prepare_shading_normal_fwd, "prepare_shading_normal_fwd"); + m.def("prepare_shading_normal_bwd", &prepare_shading_normal_bwd, "prepare_shading_normal_bwd"); + m.def("lambert_fwd", &lambert_fwd, "lambert_fwd"); + m.def("lambert_bwd", &lambert_bwd, "lambert_bwd"); + m.def("frostbite_fwd", &frostbite_fwd, "frostbite_fwd"); + m.def("frostbite_bwd", &frostbite_bwd, "frostbite_bwd"); + m.def("fresnel_shlick_fwd", &fresnel_shlick_fwd, "fresnel_shlick_fwd"); + m.def("fresnel_shlick_bwd", &fresnel_shlick_bwd, "fresnel_shlick_bwd"); + m.def("ndf_ggx_fwd", &ndf_ggx_fwd, "ndf_ggx_fwd"); + m.def("ndf_ggx_bwd", &ndf_ggx_bwd, "ndf_ggx_bwd"); + m.def("lambda_ggx_fwd", &lambda_ggx_fwd, "lambda_ggx_fwd"); + m.def("lambda_ggx_bwd", &lambda_ggx_bwd, "lambda_ggx_bwd"); + m.def("masking_smith_fwd", &masking_smith_fwd, "masking_smith_fwd"); + m.def("masking_smith_bwd", &masking_smith_bwd, "masking_smith_bwd"); + m.def("pbr_specular_fwd", &pbr_specular_fwd, "pbr_specular_fwd"); + m.def("pbr_specular_bwd", &pbr_specular_bwd, "pbr_specular_bwd"); + m.def("pbr_bsdf_fwd", &pbr_bsdf_fwd, "pbr_bsdf_fwd"); + m.def("pbr_bsdf_bwd", &pbr_bsdf_bwd, "pbr_bsdf_bwd"); + m.def("diffuse_cubemap_fwd", &diffuse_cubemap_fwd, "diffuse_cubemap_fwd"); + m.def("diffuse_cubemap_bwd", &diffuse_cubemap_bwd, "diffuse_cubemap_bwd"); + m.def("specular_bounds", &specular_bounds, "specular_bounds"); + m.def("specular_cubemap_fwd", &specular_cubemap_fwd, "specular_cubemap_fwd"); + m.def("specular_cubemap_bwd", &specular_cubemap_bwd, "specular_cubemap_bwd"); + m.def("image_loss_fwd", &image_loss_fwd, "image_loss_fwd"); + m.def("image_loss_bwd", &image_loss_bwd, "image_loss_bwd"); + m.def("xfm_fwd", &xfm_fwd, "xfm_fwd"); + m.def("xfm_bwd", &xfm_bwd, "xfm_bwd"); +} \ No newline at end of file diff --git a/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/vec3f.h b/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/vec3f.h new file mode 100755 index 0000000000000000000000000000000000000000..7e6745430f19e9fe1834c8cd3dfeb6e68d730297 --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/vec3f.h @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +struct vec3f +{ + float x, y, z; + +#ifdef __CUDACC__ + __device__ vec3f() { } + __device__ vec3f(float v) { x = v; y = v; z = v; } + __device__ vec3f(float _x, float _y, float _z) { x = _x; y = _y; z = _z; } + __device__ vec3f(float3 v) { x = v.x; y = v.y; z = v.z; } + + __device__ inline vec3f& operator+=(const vec3f& b) { x += b.x; y += b.y; z += b.z; return *this; } + __device__ inline vec3f& operator-=(const vec3f& b) { x -= b.x; y -= b.y; z -= b.z; return *this; } + __device__ inline vec3f& operator*=(const vec3f& b) { x *= b.x; y *= b.y; z *= b.z; return *this; } + __device__ inline vec3f& operator/=(const vec3f& b) { x /= b.x; y /= b.y; z /= b.z; return *this; } +#endif +}; + +#ifdef __CUDACC__ +__device__ static inline vec3f operator+(const vec3f& a, const vec3f& b) { return vec3f(a.x + b.x, a.y + b.y, a.z + b.z); } +__device__ static inline vec3f operator-(const vec3f& a, const vec3f& b) { return vec3f(a.x - b.x, a.y - b.y, a.z - b.z); } +__device__ static inline vec3f operator*(const vec3f& a, const vec3f& b) { return vec3f(a.x * b.x, a.y * b.y, a.z * b.z); } +__device__ static inline vec3f operator/(const vec3f& a, const vec3f& b) { return vec3f(a.x / b.x, a.y / b.y, a.z / b.z); } +__device__ static inline vec3f operator-(const vec3f& a) { return vec3f(-a.x, -a.y, -a.z); } + +__device__ static inline float sum(vec3f a) +{ + return a.x + a.y + a.z; +} + +__device__ static inline vec3f cross(vec3f a, vec3f b) +{ + vec3f out; + out.x = a.y * b.z - a.z * b.y; + out.y = a.z * b.x - a.x * b.z; + out.z = a.x * b.y - a.y * b.x; + return out; +} + +__device__ static inline void bwdCross(vec3f a, vec3f b, vec3f &d_a, vec3f &d_b, vec3f d_out) +{ + d_a.x += d_out.z * b.y - d_out.y * b.z; + d_a.y += d_out.x * b.z - d_out.z * b.x; + d_a.z += d_out.y * b.x - d_out.x * b.y; + + d_b.x += d_out.y * a.z - d_out.z * a.y; + d_b.y += d_out.z * a.x - d_out.x * a.z; + d_b.z += d_out.x * a.y - d_out.y * a.x; +} + +__device__ static inline float dot(vec3f a, vec3f b) +{ + return a.x * b.x + a.y * b.y + a.z * b.z; +} + +__device__ static inline void bwdDot(vec3f a, vec3f b, vec3f& d_a, vec3f& d_b, float d_out) +{ + d_a.x += d_out * b.x; d_a.y += d_out * b.y; d_a.z += d_out * b.z; + d_b.x += d_out * a.x; d_b.y += d_out * a.y; d_b.z += d_out * a.z; +} + +__device__ static inline vec3f reflect(vec3f x, vec3f n) +{ + return n * 2.0f * dot(n, x) - x; +} + +__device__ static inline void bwdReflect(vec3f x, vec3f n, vec3f& d_x, vec3f& d_n, const vec3f d_out) +{ + d_x.x += d_out.x * (2 * n.x * n.x - 1) + d_out.y * (2 * n.x * n.y) + d_out.z * (2 * n.x * n.z); + d_x.y += d_out.x * (2 * n.x * n.y) + d_out.y * (2 * n.y * n.y - 1) + d_out.z * (2 * n.y * n.z); + d_x.z += d_out.x * (2 * n.x * n.z) + d_out.y * (2 * n.y * n.z) + d_out.z * (2 * n.z * n.z - 1); + + d_n.x += d_out.x * (2 * (2 * n.x * x.x + n.y * x.y + n.z * x.z)) + d_out.y * (2 * n.y * x.x) + d_out.z * (2 * n.z * x.x); + d_n.y += d_out.x * (2 * n.x * x.y) + d_out.y * (2 * (n.x * x.x + 2 * n.y * x.y + n.z * x.z)) + d_out.z * (2 * n.z * x.y); + d_n.z += d_out.x * (2 * n.x * x.z) + d_out.y * (2 * n.y * x.z) + d_out.z * (2 * (n.x * x.x + n.y * x.y + 2 * n.z * x.z)); +} + +__device__ static inline vec3f safeNormalize(vec3f v) +{ + float l = sqrtf(v.x * v.x + v.y * v.y + v.z * v.z); + return l > 0.0f ? (v / l) : vec3f(0.0f); +} + +__device__ static inline void bwdSafeNormalize(const vec3f v, vec3f& d_v, const vec3f d_out) +{ + + float l = sqrtf(v.x * v.x + v.y * v.y + v.z * v.z); + if (l > 0.0f) + { + float fac = 1.0 / powf(v.x * v.x + v.y * v.y + v.z * v.z, 1.5f); + d_v.x += (d_out.x * (v.y * v.y + v.z * v.z) - d_out.y * (v.x * v.y) - d_out.z * (v.x * v.z)) * fac; + d_v.y += (d_out.y * (v.x * v.x + v.z * v.z) - d_out.x * (v.y * v.x) - d_out.z * (v.y * v.z)) * fac; + d_v.z += (d_out.z * (v.x * v.x + v.y * v.y) - d_out.x * (v.z * v.x) - d_out.y * (v.z * v.y)) * fac; + } +} + +#endif \ No newline at end of file diff --git a/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/vec4f.h b/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/vec4f.h new file mode 100755 index 0000000000000000000000000000000000000000..e3f30776af334597475002275b8b40c584a05035 --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/render/renderutils/c_src/vec4f.h @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +struct vec4f +{ + float x, y, z, w; + +#ifdef __CUDACC__ + __device__ vec4f() { } + __device__ vec4f(float v) { x = v; y = v; z = v; w = v; } + __device__ vec4f(float _x, float _y, float _z, float _w) { x = _x; y = _y; z = _z; w = _w; } + __device__ vec4f(float4 v) { x = v.x; y = v.y; z = v.z; w = v.w; } +#endif +}; + diff --git a/models/lrm/online_render/src/models/geometry/render/renderutils/loss.py b/models/lrm/online_render/src/models/geometry/render/renderutils/loss.py new file mode 100755 index 0000000000000000000000000000000000000000..92a24c02885380937762698eec578eb81bc80f9e --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/render/renderutils/loss.py @@ -0,0 +1,41 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import torch + +#---------------------------------------------------------------------------- +# HDR image losses +#---------------------------------------------------------------------------- + +def _tonemap_srgb(f): + return torch.where(f > 0.0031308, torch.pow(torch.clamp(f, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*f) + +def _SMAPE(img, target, eps=0.01): + nom = torch.abs(img - target) + denom = torch.abs(img) + torch.abs(target) + 0.01 + return torch.mean(nom / denom) + +def _RELMSE(img, target, eps=0.1): + nom = (img - target) * (img - target) + denom = img * img + target * target + 0.1 + return torch.mean(nom / denom) + +def image_loss_fn(img, target, loss, tonemapper): + if tonemapper == 'log_srgb': + img = _tonemap_srgb(torch.log(torch.clamp(img, min=0, max=65535) + 1)) + target = _tonemap_srgb(torch.log(torch.clamp(target, min=0, max=65535) + 1)) + + if loss == 'mse': + return torch.nn.functional.mse_loss(img, target) + elif loss == 'smape': + return _SMAPE(img, target) + elif loss == 'relmse': + return _RELMSE(img, target) + else: + return torch.nn.functional.l1_loss(img, target) diff --git a/models/lrm/online_render/src/models/geometry/render/renderutils/ops.py b/models/lrm/online_render/src/models/geometry/render/renderutils/ops.py new file mode 100755 index 0000000000000000000000000000000000000000..a27c72b2a57dac8d3f1a563d80661917b42d6ec9 --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/render/renderutils/ops.py @@ -0,0 +1,554 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import numpy as np +import os +import sys +import torch +import torch.utils.cpp_extension + +from .bsdf import * +from .loss import * + +#---------------------------------------------------------------------------- +# C++/Cuda plugin compiler/loader. + +_cached_plugin = None +def _get_plugin(): + # Return cached plugin if already loaded. + global _cached_plugin + if _cached_plugin is not None: + return _cached_plugin + + # Make sure we can find the necessary compiler and libary binaries. + if os.name == 'nt': + def find_cl_path(): + import glob + for edition in ['Enterprise', 'Professional', 'BuildTools', 'Community']: + paths = sorted(glob.glob(r"C:\Program Files (x86)\Microsoft Visual Studio\*\%s\VC\Tools\MSVC\*\bin\Hostx64\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ['PATH'] += ';' + cl_path + + # Compiler options. + opts = ['-DNVDR_TORCH'] + + # Linker options. + if os.name == 'posix': + ldflags = ['-lcuda', '-lnvrtc'] + elif os.name == 'nt': + ldflags = ['cuda.lib', 'advapi32.lib', 'nvrtc.lib'] + + # List of sources. + source_files = [ + 'c_src/mesh.cu', + 'c_src/loss.cu', + 'c_src/bsdf.cu', + 'c_src/normal.cu', + 'c_src/cubemap.cu', + 'c_src/common.cpp', + 'c_src/torch_bindings.cpp' + ] + + # Some containers set this to contain old architectures that won't compile. We only need the one installed in the machine. + os.environ['TORCH_CUDA_ARCH_LIST'] = '' + + # Try to detect if a stray lock file is left in cache directory and show a warning. This sometimes happens on Windows if the build is interrupted at just the right moment. + try: + lock_fn = os.path.join(torch.utils.cpp_extension._get_build_directory('renderutils_plugin', False), 'lock') + if os.path.exists(lock_fn): + print("Warning: Lock file exists in build directory: '%s'" % lock_fn) + except: + pass + + # Compile and load. + # source_paths = [os.path.join(os.path.dirname(__file__), fn) for fn in source_files] + # torch.utils.cpp_extension.load(name='renderutils_plugin', sources=source_paths, extra_cflags=opts, + # extra_cuda_cflags=opts, extra_ldflags=ldflags, with_cuda=True, verbose=True) + + # Import, cache, and return the compiled module. + import renderutils_plugin + _cached_plugin = renderutils_plugin + return _cached_plugin + +#---------------------------------------------------------------------------- +# Internal kernels, just used for testing functionality + +class _fresnel_shlick_func(torch.autograd.Function): + @staticmethod + def forward(ctx, f0, f90, cosTheta): + out = _get_plugin().fresnel_shlick_fwd(f0, f90, cosTheta, False) + ctx.save_for_backward(f0, f90, cosTheta) + return out + + @staticmethod + def backward(ctx, dout): + f0, f90, cosTheta = ctx.saved_variables + return _get_plugin().fresnel_shlick_bwd(f0, f90, cosTheta, dout) + (None,) + +def _fresnel_shlick(f0, f90, cosTheta, use_python=False): + if use_python: + out = bsdf_fresnel_shlick(f0, f90, cosTheta) + else: + out = _fresnel_shlick_func.apply(f0, f90, cosTheta) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of _fresnel_shlick contains inf or NaN" + return out + + +class _ndf_ggx_func(torch.autograd.Function): + @staticmethod + def forward(ctx, alphaSqr, cosTheta): + out = _get_plugin().ndf_ggx_fwd(alphaSqr, cosTheta, False) + ctx.save_for_backward(alphaSqr, cosTheta) + return out + + @staticmethod + def backward(ctx, dout): + alphaSqr, cosTheta = ctx.saved_variables + return _get_plugin().ndf_ggx_bwd(alphaSqr, cosTheta, dout) + (None,) + +def _ndf_ggx(alphaSqr, cosTheta, use_python=False): + if use_python: + out = bsdf_ndf_ggx(alphaSqr, cosTheta) + else: + out = _ndf_ggx_func.apply(alphaSqr, cosTheta) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of _ndf_ggx contains inf or NaN" + return out + +class _lambda_ggx_func(torch.autograd.Function): + @staticmethod + def forward(ctx, alphaSqr, cosTheta): + out = _get_plugin().lambda_ggx_fwd(alphaSqr, cosTheta, False) + ctx.save_for_backward(alphaSqr, cosTheta) + return out + + @staticmethod + def backward(ctx, dout): + alphaSqr, cosTheta = ctx.saved_variables + return _get_plugin().lambda_ggx_bwd(alphaSqr, cosTheta, dout) + (None,) + +def _lambda_ggx(alphaSqr, cosTheta, use_python=False): + if use_python: + out = bsdf_lambda_ggx(alphaSqr, cosTheta) + else: + out = _lambda_ggx_func.apply(alphaSqr, cosTheta) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of _lambda_ggx contains inf or NaN" + return out + +class _masking_smith_func(torch.autograd.Function): + @staticmethod + def forward(ctx, alphaSqr, cosThetaI, cosThetaO): + ctx.save_for_backward(alphaSqr, cosThetaI, cosThetaO) + out = _get_plugin().masking_smith_fwd(alphaSqr, cosThetaI, cosThetaO, False) + return out + + @staticmethod + def backward(ctx, dout): + alphaSqr, cosThetaI, cosThetaO = ctx.saved_variables + return _get_plugin().masking_smith_bwd(alphaSqr, cosThetaI, cosThetaO, dout) + (None,) + +def _masking_smith(alphaSqr, cosThetaI, cosThetaO, use_python=False): + if use_python: + out = bsdf_masking_smith_ggx_correlated(alphaSqr, cosThetaI, cosThetaO) + else: + out = _masking_smith_func.apply(alphaSqr, cosThetaI, cosThetaO) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of _masking_smith contains inf or NaN" + return out + +#---------------------------------------------------------------------------- +# Shading normal setup (bump mapping + bent normals) + +class _prepare_shading_normal_func(torch.autograd.Function): + @staticmethod + def forward(ctx, pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl): + ctx.two_sided_shading, ctx.opengl = two_sided_shading, opengl + out = _get_plugin().prepare_shading_normal_fwd(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl, False) + ctx.save_for_backward(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm) + return out + + @staticmethod + def backward(ctx, dout): + pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm = ctx.saved_variables + return _get_plugin().prepare_shading_normal_bwd(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, dout, ctx.two_sided_shading, ctx.opengl) + (None, None, None) + +def prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading=True, opengl=True, use_python=False): + '''Takes care of all corner cases and produces a final normal used for shading: + - Constructs tangent space + - Flips normal direction based on geometric normal for two sided Shading + - Perturbs shading normal by normal map + - Bends backfacing normals towards the camera to avoid shading artifacts + + All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent. + + Args: + pos: World space g-buffer position. + view_pos: Camera position in world space (typically using broadcasting). + perturbed_nrm: Trangent-space normal perturbation from normal map lookup. + smooth_nrm: Interpolated vertex normals. + smooth_tng: Interpolated vertex tangents. + geom_nrm: Geometric (face) normals. + two_sided_shading: Use one/two sided shading + opengl: Use OpenGL/DirectX normal map conventions + use_python: Use PyTorch implementation (for validation) + Returns: + Final shading normal + ''' + + if perturbed_nrm is None: + perturbed_nrm = torch.tensor([0, 0, 1], dtype=torch.float32, device='cuda', requires_grad=False)[None, None, None, ...] + + if use_python: + out = bsdf_prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl) + else: + out = _prepare_shading_normal_func.apply(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of prepare_shading_normal contains inf or NaN" + return out + +#---------------------------------------------------------------------------- +# BSDF functions + +class _lambert_func(torch.autograd.Function): + @staticmethod + def forward(ctx, nrm, wi): + out = _get_plugin().lambert_fwd(nrm, wi, False) + ctx.save_for_backward(nrm, wi) + return out + + @staticmethod + def backward(ctx, dout): + nrm, wi = ctx.saved_variables + return _get_plugin().lambert_bwd(nrm, wi, dout) + (None,) + +def lambert(nrm, wi, use_python=False): + '''Lambertian bsdf. + All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent. + + Args: + nrm: World space shading normal. + wi: World space light vector. + use_python: Use PyTorch implementation (for validation) + + Returns: + Shaded diffuse value with shape [minibatch_size, height, width, 1] + ''' + + if use_python: + out = bsdf_lambert(nrm, wi) + else: + out = _lambert_func.apply(nrm, wi) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of lambert contains inf or NaN" + return out + +class _frostbite_diffuse_func(torch.autograd.Function): + @staticmethod + def forward(ctx, nrm, wi, wo, linearRoughness): + out = _get_plugin().frostbite_fwd(nrm, wi, wo, linearRoughness, False) + ctx.save_for_backward(nrm, wi, wo, linearRoughness) + return out + + @staticmethod + def backward(ctx, dout): + nrm, wi, wo, linearRoughness = ctx.saved_variables + return _get_plugin().frostbite_bwd(nrm, wi, wo, linearRoughness, dout) + (None,) + +def frostbite_diffuse(nrm, wi, wo, linearRoughness, use_python=False): + '''Frostbite, normalized Disney Diffuse bsdf. + All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent. + + Args: + nrm: World space shading normal. + wi: World space light vector. + wo: World space camera vector. + linearRoughness: Material roughness + use_python: Use PyTorch implementation (for validation) + + Returns: + Shaded diffuse value with shape [minibatch_size, height, width, 1] + ''' + + if use_python: + out = bsdf_frostbite(nrm, wi, wo, linearRoughness) + else: + out = _frostbite_diffuse_func.apply(nrm, wi, wo, linearRoughness) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of lambert contains inf or NaN" + return out + +class _pbr_specular_func(torch.autograd.Function): + @staticmethod + def forward(ctx, col, nrm, wo, wi, alpha, min_roughness): + ctx.save_for_backward(col, nrm, wo, wi, alpha) + ctx.min_roughness = min_roughness + out = _get_plugin().pbr_specular_fwd(col, nrm, wo, wi, alpha, min_roughness, False) + return out + + @staticmethod + def backward(ctx, dout): + col, nrm, wo, wi, alpha = ctx.saved_variables + return _get_plugin().pbr_specular_bwd(col, nrm, wo, wi, alpha, ctx.min_roughness, dout) + (None, None) + +def pbr_specular(col, nrm, wo, wi, alpha, min_roughness=0.08, use_python=False): + '''Physically-based specular bsdf. + All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent unless otherwise noted. + + Args: + col: Specular lobe color + nrm: World space shading normal. + wo: World space camera vector. + wi: World space light vector + alpha: Specular roughness parameter with shape [minibatch_size, height, width, 1] + min_roughness: Scalar roughness clamping threshold + + use_python: Use PyTorch implementation (for validation) + Returns: + Shaded specular color + ''' + + if use_python: + out = bsdf_pbr_specular(col, nrm, wo, wi, alpha, min_roughness=min_roughness) + else: + out = _pbr_specular_func.apply(col, nrm, wo, wi, alpha, min_roughness) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of pbr_specular contains inf or NaN" + return out + +class _pbr_bsdf_func(torch.autograd.Function): + @staticmethod + def forward(ctx, kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF): + ctx.save_for_backward(kd, arm, pos, nrm, view_pos, light_pos) + ctx.min_roughness = min_roughness + ctx.BSDF = BSDF + out = _get_plugin().pbr_bsdf_fwd(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF, False) + return out + + @staticmethod + def backward(ctx, dout): + kd, arm, pos, nrm, view_pos, light_pos = ctx.saved_variables + return _get_plugin().pbr_bsdf_bwd(kd, arm, pos, nrm, view_pos, light_pos, ctx.min_roughness, ctx.BSDF, dout) + (None, None, None) + +def pbr_bsdf(kd, arm, pos, nrm, view_pos, light_pos, min_roughness=0.08, bsdf="lambert", use_python=False): + '''Physically-based bsdf, both diffuse & specular lobes + All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent unless otherwise noted. + + Args: + kd: Diffuse albedo. + arm: Specular parameters (attenuation, linear roughness, metalness). + pos: World space position. + nrm: World space shading normal. + view_pos: Camera position in world space, typically using broadcasting. + light_pos: Light position in world space, typically using broadcasting. + min_roughness: Scalar roughness clamping threshold + bsdf: Controls diffuse BSDF, can be either 'lambert' or 'frostbite' + + use_python: Use PyTorch implementation (for validation) + + Returns: + Shaded color. + ''' + + BSDF = 0 + if bsdf == 'frostbite': + BSDF = 1 + + if use_python: + out = bsdf_pbr(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF) + else: + out = _pbr_bsdf_func.apply(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of pbr_bsdf contains inf or NaN" + return out + +#---------------------------------------------------------------------------- +# cubemap filter with filtering across edges + +class _diffuse_cubemap_func(torch.autograd.Function): + @staticmethod + def forward(ctx, cubemap): + out = _get_plugin().diffuse_cubemap_fwd(cubemap) + ctx.save_for_backward(cubemap) + return out + + @staticmethod + def backward(ctx, dout): + cubemap, = ctx.saved_variables + cubemap_grad = _get_plugin().diffuse_cubemap_bwd(cubemap, dout) + return cubemap_grad, None + +def diffuse_cubemap(cubemap, use_python=False): + if use_python: + assert False + else: + out = _diffuse_cubemap_func.apply(cubemap) + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of diffuse_cubemap contains inf or NaN" + return out + +class _specular_cubemap(torch.autograd.Function): + @staticmethod + def forward(ctx, cubemap, roughness, costheta_cutoff, bounds): + out = _get_plugin().specular_cubemap_fwd(cubemap, bounds, roughness, costheta_cutoff) + ctx.save_for_backward(cubemap, bounds) + ctx.roughness, ctx.theta_cutoff = roughness, costheta_cutoff + return out + + @staticmethod + def backward(ctx, dout): + cubemap, bounds = ctx.saved_variables + cubemap_grad = _get_plugin().specular_cubemap_bwd(cubemap, bounds, dout, ctx.roughness, ctx.theta_cutoff) + return cubemap_grad, None, None, None + +# Compute the bounds of the GGX NDF lobe to retain "cutoff" percent of the energy +def __ndfBounds(res, roughness, cutoff): + def ndfGGX(alphaSqr, costheta): + costheta = np.clip(costheta, 0.0, 1.0) + d = (costheta * alphaSqr - costheta) * costheta + 1.0 + return alphaSqr / (d * d * np.pi) + + # Sample out cutoff angle + nSamples = 1000000 + costheta = np.cos(np.linspace(0, np.pi/2.0, nSamples)) + D = np.cumsum(ndfGGX(roughness**4, costheta)) + idx = np.argmax(D >= D[..., -1] * cutoff) + + # Brute force compute lookup table with bounds + bounds = _get_plugin().specular_bounds(res, costheta[idx]) + + return costheta[idx], bounds +__ndfBoundsDict = {} + +def specular_cubemap(cubemap, roughness, cutoff=0.99, use_python=False): + assert cubemap.shape[0] == 6 and cubemap.shape[1] == cubemap.shape[2], "Bad shape for cubemap tensor: %s" % str(cubemap.shape) + + if use_python: + assert False + else: + key = (cubemap.shape[1], roughness, cutoff) + if key not in __ndfBoundsDict: + __ndfBoundsDict[key] = __ndfBounds(*key) + out = _specular_cubemap.apply(cubemap, roughness, *__ndfBoundsDict[key]) + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of specular_cubemap contains inf or NaN" + return out[..., 0:3] / out[..., 3:] + +#---------------------------------------------------------------------------- +# Fast image loss function + +class _image_loss_func(torch.autograd.Function): + @staticmethod + def forward(ctx, img, target, loss, tonemapper): + ctx.loss, ctx.tonemapper = loss, tonemapper + ctx.save_for_backward(img, target) + out = _get_plugin().image_loss_fwd(img, target, loss, tonemapper, False) + return out + + @staticmethod + def backward(ctx, dout): + img, target = ctx.saved_variables + return _get_plugin().image_loss_bwd(img, target, dout, ctx.loss, ctx.tonemapper) + (None, None, None) + +def image_loss(img, target, loss='l1', tonemapper='none', use_python=False): + '''Compute HDR image loss. Combines tonemapping and loss into a single kernel for better perf. + All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent unless otherwise noted. + + Args: + img: Input image. + target: Target (reference) image. + loss: Type of loss. Valid options are ['l1', 'mse', 'smape', 'relmse'] + tonemapper: Tonemapping operations. Valid options are ['none', 'log_srgb'] + use_python: Use PyTorch implementation (for validation) + + Returns: + Image space loss (scalar value). + ''' + if use_python: + out = image_loss_fn(img, target, loss, tonemapper) + else: + out = _image_loss_func.apply(img, target, loss, tonemapper) + out = torch.sum(out) / (img.shape[0]*img.shape[1]*img.shape[2]) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of image_loss contains inf or NaN" + return out + +#---------------------------------------------------------------------------- +# Transform points function + +class _xfm_func(torch.autograd.Function): + @staticmethod + def forward(ctx, points, matrix, isPoints): + ctx.save_for_backward(points, matrix) + ctx.isPoints = isPoints + return _get_plugin().xfm_fwd(points, matrix, isPoints, False) + + @staticmethod + def backward(ctx, dout): + points, matrix = ctx.saved_variables + return (_get_plugin().xfm_bwd(points, matrix, dout, ctx.isPoints),) + (None, None, None) + +def xfm_points(points, matrix, use_python=False): + '''Transform points. + Args: + points: Tensor containing 3D points with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3] + matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4] + use_python: Use PyTorch's torch.matmul (for validation) + Returns: + Transformed points in homogeneous 4D with shape [minibatch_size, num_vertices, 4]. + ''' + if use_python: + out = torch.matmul(torch.nn.functional.pad(points, pad=(0,1), mode='constant', value=1.0), torch.transpose(matrix, 1, 2)) + else: + out = _xfm_func.apply(points, matrix, True) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of xfm_points contains inf or NaN" + return out + +def xfm_vectors(vectors, matrix, use_python=False): + '''Transform vectors. + Args: + vectors: Tensor containing 3D vectors with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3] + matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4] + use_python: Use PyTorch's torch.matmul (for validation) + + Returns: + Transformed vectors in homogeneous 4D with shape [minibatch_size, num_vertices, 4]. + ''' + + if use_python: + out = torch.matmul(torch.nn.functional.pad(vectors, pad=(0,1), mode='constant', value=0.0), torch.transpose(matrix, 1, 2))[..., 0:3].contiguous() + else: + out = _xfm_func.apply(vectors, matrix, False) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of xfm_vectors contains inf or NaN" + return out + + + diff --git a/models/lrm/online_render/src/models/geometry/render/renderutils/tests/test_bsdf.py b/models/lrm/online_render/src/models/geometry/render/renderutils/tests/test_bsdf.py new file mode 100755 index 0000000000000000000000000000000000000000..b0b60c350455717826c0f3edb01289b29baac27a --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/render/renderutils/tests/test_bsdf.py @@ -0,0 +1,296 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import torch + +import os +import sys +sys.path.insert(0, os.path.join(sys.path[0], '../..')) +import renderutils as ru + +RES = 4 +DTYPE = torch.float32 + +def relative_loss(name, ref, cuda): + ref = ref.float() + cuda = cuda.float() + print(name, torch.max(torch.abs(ref - cuda) / torch.abs(ref + 1e-7)).item()) + +def test_normal(): + pos_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + pos_ref = pos_cuda.clone().detach().requires_grad_(True) + view_pos_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + view_pos_ref = view_pos_cuda.clone().detach().requires_grad_(True) + perturbed_nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + perturbed_nrm_ref = perturbed_nrm_cuda.clone().detach().requires_grad_(True) + smooth_nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + smooth_nrm_ref = smooth_nrm_cuda.clone().detach().requires_grad_(True) + smooth_tng_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + smooth_tng_ref = smooth_tng_cuda.clone().detach().requires_grad_(True) + geom_nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + geom_nrm_ref = geom_nrm_cuda.clone().detach().requires_grad_(True) + target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda') + + ref = ru.prepare_shading_normal(pos_ref, view_pos_ref, perturbed_nrm_ref, smooth_nrm_ref, smooth_tng_ref, geom_nrm_ref, True, use_python=True) + ref_loss = torch.nn.MSELoss()(ref, target) + ref_loss.backward() + + cuda = ru.prepare_shading_normal(pos_cuda, view_pos_cuda, perturbed_nrm_cuda, smooth_nrm_cuda, smooth_tng_cuda, geom_nrm_cuda, True) + cuda_loss = torch.nn.MSELoss()(cuda, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" bent normal") + print("-------------------------------------------------------------") + relative_loss("res:", ref, cuda) + relative_loss("pos:", pos_ref.grad, pos_cuda.grad) + relative_loss("view_pos:", view_pos_ref.grad, view_pos_cuda.grad) + relative_loss("perturbed_nrm:", perturbed_nrm_ref.grad, perturbed_nrm_cuda.grad) + relative_loss("smooth_nrm:", smooth_nrm_ref.grad, smooth_nrm_cuda.grad) + relative_loss("smooth_tng:", smooth_tng_ref.grad, smooth_tng_cuda.grad) + relative_loss("geom_nrm:", geom_nrm_ref.grad, geom_nrm_cuda.grad) + +def test_schlick(): + f0_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + f0_ref = f0_cuda.clone().detach().requires_grad_(True) + f90_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + f90_ref = f90_cuda.clone().detach().requires_grad_(True) + cosT_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) * 2.0 + cosT_cuda = cosT_cuda.clone().detach().requires_grad_(True) + cosT_ref = cosT_cuda.clone().detach().requires_grad_(True) + target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda') + + ref = ru._fresnel_shlick(f0_ref, f90_ref, cosT_ref, use_python=True) + ref_loss = torch.nn.MSELoss()(ref, target) + ref_loss.backward() + + cuda = ru._fresnel_shlick(f0_cuda, f90_cuda, cosT_cuda) + cuda_loss = torch.nn.MSELoss()(cuda, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" Fresnel shlick") + print("-------------------------------------------------------------") + relative_loss("res:", ref, cuda) + relative_loss("f0:", f0_ref.grad, f0_cuda.grad) + relative_loss("f90:", f90_ref.grad, f90_cuda.grad) + relative_loss("cosT:", cosT_ref.grad, cosT_cuda.grad) + +def test_ndf_ggx(): + alphaSqr_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) + alphaSqr_cuda = alphaSqr_cuda.clone().detach().requires_grad_(True) + alphaSqr_ref = alphaSqr_cuda.clone().detach().requires_grad_(True) + cosT_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) * 3.0 - 1 + cosT_cuda = cosT_cuda.clone().detach().requires_grad_(True) + cosT_ref = cosT_cuda.clone().detach().requires_grad_(True) + target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda') + + ref = ru._ndf_ggx(alphaSqr_ref, cosT_ref, use_python=True) + ref_loss = torch.nn.MSELoss()(ref, target) + ref_loss.backward() + + cuda = ru._ndf_ggx(alphaSqr_cuda, cosT_cuda) + cuda_loss = torch.nn.MSELoss()(cuda, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" Ndf GGX") + print("-------------------------------------------------------------") + relative_loss("res:", ref, cuda) + relative_loss("alpha:", alphaSqr_ref.grad, alphaSqr_cuda.grad) + relative_loss("cosT:", cosT_ref.grad, cosT_cuda.grad) + +def test_lambda_ggx(): + alphaSqr_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) + alphaSqr_ref = alphaSqr_cuda.clone().detach().requires_grad_(True) + cosT_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) * 3.0 - 1 + cosT_cuda = cosT_cuda.clone().detach().requires_grad_(True) + cosT_ref = cosT_cuda.clone().detach().requires_grad_(True) + target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda') + + ref = ru._lambda_ggx(alphaSqr_ref, cosT_ref, use_python=True) + ref_loss = torch.nn.MSELoss()(ref, target) + ref_loss.backward() + + cuda = ru._lambda_ggx(alphaSqr_cuda, cosT_cuda) + cuda_loss = torch.nn.MSELoss()(cuda, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" Lambda GGX") + print("-------------------------------------------------------------") + relative_loss("res:", ref, cuda) + relative_loss("alpha:", alphaSqr_ref.grad, alphaSqr_cuda.grad) + relative_loss("cosT:", cosT_ref.grad, cosT_cuda.grad) + +def test_masking_smith(): + alphaSqr_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) + alphaSqr_ref = alphaSqr_cuda.clone().detach().requires_grad_(True) + cosThetaI_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) + cosThetaI_ref = cosThetaI_cuda.clone().detach().requires_grad_(True) + cosThetaO_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) + cosThetaO_ref = cosThetaO_cuda.clone().detach().requires_grad_(True) + target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda') + + ref = ru._masking_smith(alphaSqr_ref, cosThetaI_ref, cosThetaO_ref, use_python=True) + ref_loss = torch.nn.MSELoss()(ref, target) + ref_loss.backward() + + cuda = ru._masking_smith(alphaSqr_cuda, cosThetaI_cuda, cosThetaO_cuda) + cuda_loss = torch.nn.MSELoss()(cuda, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" Smith masking term") + print("-------------------------------------------------------------") + relative_loss("res:", ref, cuda) + relative_loss("alpha:", alphaSqr_ref.grad, alphaSqr_cuda.grad) + relative_loss("cosThetaI:", cosThetaI_ref.grad, cosThetaI_cuda.grad) + relative_loss("cosThetaO:", cosThetaO_ref.grad, cosThetaO_cuda.grad) + +def test_lambert(): + normals_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + normals_ref = normals_cuda.clone().detach().requires_grad_(True) + wi_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + wi_ref = wi_cuda.clone().detach().requires_grad_(True) + target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda') + + ref = ru.lambert(normals_ref, wi_ref, use_python=True) + ref_loss = torch.nn.MSELoss()(ref, target) + ref_loss.backward() + + cuda = ru.lambert(normals_cuda, wi_cuda) + cuda_loss = torch.nn.MSELoss()(cuda, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" Lambert") + print("-------------------------------------------------------------") + relative_loss("res:", ref, cuda) + relative_loss("nrm:", normals_ref.grad, normals_cuda.grad) + relative_loss("wi:", wi_ref.grad, wi_cuda.grad) + +def test_frostbite(): + normals_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + normals_ref = normals_cuda.clone().detach().requires_grad_(True) + wi_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + wi_ref = wi_cuda.clone().detach().requires_grad_(True) + wo_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + wo_ref = wo_cuda.clone().detach().requires_grad_(True) + rough_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) + rough_ref = rough_cuda.clone().detach().requires_grad_(True) + target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda') + + ref = ru.frostbite_diffuse(normals_ref, wi_ref, wo_ref, rough_ref, use_python=True) + ref_loss = torch.nn.MSELoss()(ref, target) + ref_loss.backward() + + cuda = ru.frostbite_diffuse(normals_cuda, wi_cuda, wo_cuda, rough_cuda) + cuda_loss = torch.nn.MSELoss()(cuda, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" Frostbite") + print("-------------------------------------------------------------") + relative_loss("res:", ref, cuda) + relative_loss("nrm:", normals_ref.grad, normals_cuda.grad) + relative_loss("wo:", wo_ref.grad, wo_cuda.grad) + relative_loss("wi:", wi_ref.grad, wi_cuda.grad) + relative_loss("rough:", rough_ref.grad, rough_cuda.grad) + +def test_pbr_specular(): + col_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + col_ref = col_cuda.clone().detach().requires_grad_(True) + nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + nrm_ref = nrm_cuda.clone().detach().requires_grad_(True) + wi_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + wi_ref = wi_cuda.clone().detach().requires_grad_(True) + wo_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + wo_ref = wo_cuda.clone().detach().requires_grad_(True) + alpha_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) + alpha_ref = alpha_cuda.clone().detach().requires_grad_(True) + target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda') + + ref = ru.pbr_specular(col_ref, nrm_ref, wo_ref, wi_ref, alpha_ref, use_python=True) + ref_loss = torch.nn.MSELoss()(ref, target) + ref_loss.backward() + + cuda = ru.pbr_specular(col_cuda, nrm_cuda, wo_cuda, wi_cuda, alpha_cuda) + cuda_loss = torch.nn.MSELoss()(cuda, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" Pbr specular") + print("-------------------------------------------------------------") + + relative_loss("res:", ref, cuda) + if col_ref.grad is not None: + relative_loss("col:", col_ref.grad, col_cuda.grad) + if nrm_ref.grad is not None: + relative_loss("nrm:", nrm_ref.grad, nrm_cuda.grad) + if wi_ref.grad is not None: + relative_loss("wi:", wi_ref.grad, wi_cuda.grad) + if wo_ref.grad is not None: + relative_loss("wo:", wo_ref.grad, wo_cuda.grad) + if alpha_ref.grad is not None: + relative_loss("alpha:", alpha_ref.grad, alpha_cuda.grad) + +def test_pbr_bsdf(bsdf): + kd_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + kd_ref = kd_cuda.clone().detach().requires_grad_(True) + arm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + arm_ref = arm_cuda.clone().detach().requires_grad_(True) + pos_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + pos_ref = pos_cuda.clone().detach().requires_grad_(True) + nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + nrm_ref = nrm_cuda.clone().detach().requires_grad_(True) + view_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + view_ref = view_cuda.clone().detach().requires_grad_(True) + light_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + light_ref = light_cuda.clone().detach().requires_grad_(True) + target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda') + + ref = ru.pbr_bsdf(kd_ref, arm_ref, pos_ref, nrm_ref, view_ref, light_ref, use_python=True, bsdf=bsdf) + ref_loss = torch.nn.MSELoss()(ref, target) + ref_loss.backward() + + cuda = ru.pbr_bsdf(kd_cuda, arm_cuda, pos_cuda, nrm_cuda, view_cuda, light_cuda, bsdf=bsdf) + cuda_loss = torch.nn.MSELoss()(cuda, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" Pbr BSDF") + print("-------------------------------------------------------------") + + relative_loss("res:", ref, cuda) + if kd_ref.grad is not None: + relative_loss("kd:", kd_ref.grad, kd_cuda.grad) + if arm_ref.grad is not None: + relative_loss("arm:", arm_ref.grad, arm_cuda.grad) + if pos_ref.grad is not None: + relative_loss("pos:", pos_ref.grad, pos_cuda.grad) + if nrm_ref.grad is not None: + relative_loss("nrm:", nrm_ref.grad, nrm_cuda.grad) + if view_ref.grad is not None: + relative_loss("view:", view_ref.grad, view_cuda.grad) + if light_ref.grad is not None: + relative_loss("light:", light_ref.grad, light_cuda.grad) + +test_normal() + +test_schlick() +test_ndf_ggx() +test_lambda_ggx() +test_masking_smith() + +test_lambert() +test_frostbite() +test_pbr_specular() +test_pbr_bsdf('lambert') +test_pbr_bsdf('frostbite') diff --git a/models/lrm/online_render/src/models/geometry/render/renderutils/tests/test_loss.py b/models/lrm/online_render/src/models/geometry/render/renderutils/tests/test_loss.py new file mode 100755 index 0000000000000000000000000000000000000000..7a68f3fc4528431fe405d1d6077af0cb31687d31 --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/render/renderutils/tests/test_loss.py @@ -0,0 +1,61 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import torch + +import os +import sys +sys.path.insert(0, os.path.join(sys.path[0], '../..')) +import renderutils as ru + +RES = 8 +DTYPE = torch.float32 + +def tonemap_srgb(f): + return torch.where(f > 0.0031308, torch.pow(torch.clamp(f, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*f) + +def l1(output, target): + x = torch.clamp(output, min=0, max=65535) + r = torch.clamp(target, min=0, max=65535) + x = tonemap_srgb(torch.log(x + 1)) + r = tonemap_srgb(torch.log(r + 1)) + return torch.nn.functional.l1_loss(x,r) + +def relative_loss(name, ref, cuda): + ref = ref.float() + cuda = cuda.float() + print(name, torch.max(torch.abs(ref - cuda) / torch.abs(ref + 1e-7)).item()) + +def test_loss(loss, tonemapper): + img_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + img_ref = img_cuda.clone().detach().requires_grad_(True) + target_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + target_ref = target_cuda.clone().detach().requires_grad_(True) + + ref_loss = ru.image_loss(img_ref, target_ref, loss=loss, tonemapper=tonemapper, use_python=True) + ref_loss.backward() + + cuda_loss = ru.image_loss(img_cuda, target_cuda, loss=loss, tonemapper=tonemapper) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" Loss: %s, %s" % (loss, tonemapper)) + print("-------------------------------------------------------------") + + relative_loss("res:", ref_loss, cuda_loss) + relative_loss("img:", img_ref.grad, img_cuda.grad) + relative_loss("target:", target_ref.grad, target_cuda.grad) + + +test_loss('l1', 'none') +test_loss('l1', 'log_srgb') +test_loss('mse', 'log_srgb') +test_loss('smape', 'none') +test_loss('relmse', 'none') +test_loss('mse', 'none') \ No newline at end of file diff --git a/models/lrm/online_render/src/models/geometry/render/renderutils/tests/test_mesh.py b/models/lrm/online_render/src/models/geometry/render/renderutils/tests/test_mesh.py new file mode 100755 index 0000000000000000000000000000000000000000..4856c5ce07e2d6cd5f1fd463c1d3628791eafccc --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/render/renderutils/tests/test_mesh.py @@ -0,0 +1,90 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import torch + +import os +import sys +sys.path.insert(0, os.path.join(sys.path[0], '../..')) +import renderutils as ru + +BATCH = 8 +RES = 1024 +DTYPE = torch.float32 + +torch.manual_seed(0) + +def tonemap_srgb(f): + return torch.where(f > 0.0031308, torch.pow(torch.clamp(f, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*f) + +def l1(output, target): + x = torch.clamp(output, min=0, max=65535) + r = torch.clamp(target, min=0, max=65535) + x = tonemap_srgb(torch.log(x + 1)) + r = tonemap_srgb(torch.log(r + 1)) + return torch.nn.functional.l1_loss(x,r) + +def relative_loss(name, ref, cuda): + ref = ref.float() + cuda = cuda.float() + print(name, torch.max(torch.abs(ref - cuda) / torch.abs(ref)).item()) + +def test_xfm_points(): + points_cuda = torch.rand(1, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + points_ref = points_cuda.clone().detach().requires_grad_(True) + mtx_cuda = torch.rand(BATCH, 4, 4, dtype=DTYPE, device='cuda', requires_grad=False) + mtx_ref = mtx_cuda.clone().detach().requires_grad_(True) + target = torch.rand(BATCH, RES, 4, dtype=DTYPE, device='cuda', requires_grad=True) + + ref_out = ru.xfm_points(points_ref, mtx_ref, use_python=True) + ref_loss = torch.nn.MSELoss()(ref_out, target) + ref_loss.backward() + + cuda_out = ru.xfm_points(points_cuda, mtx_cuda) + cuda_loss = torch.nn.MSELoss()(cuda_out, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + + relative_loss("res:", ref_out, cuda_out) + relative_loss("points:", points_ref.grad, points_cuda.grad) + +def test_xfm_vectors(): + points_cuda = torch.rand(1, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + points_ref = points_cuda.clone().detach().requires_grad_(True) + points_cuda_p = points_cuda.clone().detach().requires_grad_(True) + points_ref_p = points_cuda.clone().detach().requires_grad_(True) + mtx_cuda = torch.rand(BATCH, 4, 4, dtype=DTYPE, device='cuda', requires_grad=False) + mtx_ref = mtx_cuda.clone().detach().requires_grad_(True) + target = torch.rand(BATCH, RES, 4, dtype=DTYPE, device='cuda', requires_grad=True) + + ref_out = ru.xfm_vectors(points_ref.contiguous(), mtx_ref, use_python=True) + ref_loss = torch.nn.MSELoss()(ref_out, target[..., 0:3]) + ref_loss.backward() + + cuda_out = ru.xfm_vectors(points_cuda.contiguous(), mtx_cuda) + cuda_loss = torch.nn.MSELoss()(cuda_out, target[..., 0:3]) + cuda_loss.backward() + + ref_out_p = ru.xfm_points(points_ref_p.contiguous(), mtx_ref, use_python=True) + ref_loss_p = torch.nn.MSELoss()(ref_out_p, target) + ref_loss_p.backward() + + cuda_out_p = ru.xfm_points(points_cuda_p.contiguous(), mtx_cuda) + cuda_loss_p = torch.nn.MSELoss()(cuda_out_p, target) + cuda_loss_p.backward() + + print("-------------------------------------------------------------") + + relative_loss("res:", ref_out, cuda_out) + relative_loss("points:", points_ref.grad, points_cuda.grad) + relative_loss("points_p:", points_ref_p.grad, points_cuda_p.grad) + +test_xfm_points() +test_xfm_vectors() diff --git a/models/lrm/online_render/src/models/geometry/render/renderutils/tests/test_perf.py b/models/lrm/online_render/src/models/geometry/render/renderutils/tests/test_perf.py new file mode 100755 index 0000000000000000000000000000000000000000..ffc143e3004c0fd0a42a1941896823bc2bef939a --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/render/renderutils/tests/test_perf.py @@ -0,0 +1,57 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import torch + +import os +import sys +sys.path.insert(0, os.path.join(sys.path[0], '../..')) +import renderutils as ru + +DTYPE=torch.float32 + +def test_bsdf(BATCH, RES, ITR): + kd_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + kd_ref = kd_cuda.clone().detach().requires_grad_(True) + arm_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + arm_ref = arm_cuda.clone().detach().requires_grad_(True) + pos_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + pos_ref = pos_cuda.clone().detach().requires_grad_(True) + nrm_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + nrm_ref = nrm_cuda.clone().detach().requires_grad_(True) + view_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + view_ref = view_cuda.clone().detach().requires_grad_(True) + light_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + light_ref = light_cuda.clone().detach().requires_grad_(True) + target = torch.rand(BATCH, RES, RES, 3, device='cuda') + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + ru.pbr_bsdf(kd_cuda, arm_cuda, pos_cuda, nrm_cuda, view_cuda, light_cuda) + + print("--- Testing: [%d, %d, %d] ---" % (BATCH, RES, RES)) + + start.record() + for i in range(ITR): + ref = ru.pbr_bsdf(kd_ref, arm_ref, pos_ref, nrm_ref, view_ref, light_ref, use_python=True) + end.record() + torch.cuda.synchronize() + print("Pbr BSDF python:", start.elapsed_time(end)) + + start.record() + for i in range(ITR): + cuda = ru.pbr_bsdf(kd_cuda, arm_cuda, pos_cuda, nrm_cuda, view_cuda, light_cuda) + end.record() + torch.cuda.synchronize() + print("Pbr BSDF cuda:", start.elapsed_time(end)) + +test_bsdf(1, 512, 1000) +test_bsdf(16, 512, 1000) +test_bsdf(1, 2048, 1000) diff --git a/models/lrm/online_render/src/models/geometry/render/util.py b/models/lrm/online_render/src/models/geometry/render/util.py new file mode 100755 index 0000000000000000000000000000000000000000..e292e91cf1cdd4b05b46f2f18b8a2bb14d2165ba --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/render/util.py @@ -0,0 +1,465 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +import numpy as np +import torch +import nvdiffrast.torch as dr +import imageio + +#---------------------------------------------------------------------------- +# Vector operations +#---------------------------------------------------------------------------- + +def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return torch.sum(x*y, -1, keepdim=True) + +def reflect(x: torch.Tensor, n: torch.Tensor) -> torch.Tensor: + return 2*dot(x, n)*n - x + +def length(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor: + return torch.sqrt(torch.clamp(dot(x,x), min=eps)) # Clamp to avoid nan gradients because grad(sqrt(0)) = NaN + +def safe_normalize(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor: + return x / length(x, eps) + +def to_hvec(x: torch.Tensor, w: float) -> torch.Tensor: + return torch.nn.functional.pad(x, pad=(0,1), mode='constant', value=w) + +#---------------------------------------------------------------------------- +# sRGB color transforms +#---------------------------------------------------------------------------- + +def _rgb_to_srgb(f: torch.Tensor) -> torch.Tensor: + return torch.where(f <= 0.0031308, f * 12.92, torch.pow(torch.clamp(f, 0.0031308), 1.0/2.4)*1.055 - 0.055) + +def rgb_to_srgb(f: torch.Tensor) -> torch.Tensor: + assert f.shape[-1] == 3 or f.shape[-1] == 4 + out = torch.cat((_rgb_to_srgb(f[..., 0:3]), f[..., 3:4]), dim=-1) if f.shape[-1] == 4 else _rgb_to_srgb(f) + assert out.shape[0] == f.shape[0] and out.shape[1] == f.shape[1] and out.shape[2] == f.shape[2] + return out + +def _srgb_to_rgb(f: torch.Tensor) -> torch.Tensor: + return torch.where(f <= 0.04045, f / 12.92, torch.pow((torch.clamp(f, 0.04045) + 0.055) / 1.055, 2.4)) + +def srgb_to_rgb(f: torch.Tensor) -> torch.Tensor: + assert f.shape[-1] == 3 or f.shape[-1] == 4 + out = torch.cat((_srgb_to_rgb(f[..., 0:3]), f[..., 3:4]), dim=-1) if f.shape[-1] == 4 else _srgb_to_rgb(f) + assert out.shape[0] == f.shape[0] and out.shape[1] == f.shape[1] and out.shape[2] == f.shape[2] + return out + +def reinhard(f: torch.Tensor) -> torch.Tensor: + return f/(1+f) + +#----------------------------------------------------------------------------------- +# Metrics (taken from jaxNerf source code, in order to replicate their measurements) +# +# https://github.com/google-research/google-research/blob/301451a62102b046bbeebff49a760ebeec9707b8/jaxnerf/nerf/utils.py#L266 +# +#----------------------------------------------------------------------------------- + +def mse_to_psnr(mse): + """Compute PSNR given an MSE (we assume the maximum pixel value is 1).""" + return -10. / np.log(10.) * np.log(mse) + +def psnr_to_mse(psnr): + """Compute MSE given a PSNR (we assume the maximum pixel value is 1).""" + return np.exp(-0.1 * np.log(10.) * psnr) + +#---------------------------------------------------------------------------- +# Displacement texture lookup +#---------------------------------------------------------------------------- + +def get_miplevels(texture: np.ndarray) -> float: + minDim = min(texture.shape[0], texture.shape[1]) + return np.floor(np.log2(minDim)) + +def tex_2d(tex_map : torch.Tensor, coords : torch.Tensor, filter='nearest') -> torch.Tensor: + tex_map = tex_map[None, ...] # Add batch dimension + tex_map = tex_map.permute(0, 3, 1, 2) # NHWC -> NCHW + tex = torch.nn.functional.grid_sample(tex_map, coords[None, None, ...] * 2 - 1, mode=filter, align_corners=False) + tex = tex.permute(0, 2, 3, 1) # NCHW -> NHWC + return tex[0, 0, ...] + +#---------------------------------------------------------------------------- +# Cubemap utility functions +#---------------------------------------------------------------------------- + +def cube_to_dir(s, x, y): + if s == 0: rx, ry, rz = torch.ones_like(x), -y, -x + elif s == 1: rx, ry, rz = -torch.ones_like(x), -y, x + elif s == 2: rx, ry, rz = x, torch.ones_like(x), y + elif s == 3: rx, ry, rz = x, -torch.ones_like(x), -y + elif s == 4: rx, ry, rz = x, -y, torch.ones_like(x) + elif s == 5: rx, ry, rz = -x, -y, -torch.ones_like(x) + return torch.stack((rx, ry, rz), dim=-1) + +def latlong_to_cubemap(latlong_map, res): + cubemap = torch.zeros(6, res[0], res[1], latlong_map.shape[-1], dtype=torch.float32, device='cuda') + for s in range(6): + gy, gx = torch.meshgrid(torch.linspace(-1.0 + 1.0 / res[0], 1.0 - 1.0 / res[0], res[0], device='cuda'), + torch.linspace(-1.0 + 1.0 / res[1], 1.0 - 1.0 / res[1], res[1], device='cuda'), + indexing='ij') + v = safe_normalize(cube_to_dir(s, gx, gy)) + + tu = torch.atan2(v[..., 0:1], -v[..., 2:3]) / (2 * np.pi) + 0.5 + tv = torch.acos(torch.clamp(v[..., 1:2], min=-1, max=1)) / np.pi + texcoord = torch.cat((tu, tv), dim=-1) + + cubemap[s, ...] = dr.texture(latlong_map[None, ...], texcoord[None, ...], filter_mode='linear')[0] + return cubemap + +def cubemap_to_latlong(cubemap, res): + gy, gx = torch.meshgrid(torch.linspace( 0.0 + 1.0 / res[0], 1.0 - 1.0 / res[0], res[0], device='cuda'), + torch.linspace(-1.0 + 1.0 / res[1], 1.0 - 1.0 / res[1], res[1], device='cuda'), + indexing='ij') + + sintheta, costheta = torch.sin(gy*np.pi), torch.cos(gy*np.pi) + sinphi, cosphi = torch.sin(gx*np.pi), torch.cos(gx*np.pi) + + reflvec = torch.stack(( + sintheta*sinphi, + costheta, + -sintheta*cosphi + ), dim=-1) + return dr.texture(cubemap[None, ...], reflvec[None, ...].contiguous(), filter_mode='linear', boundary_mode='cube')[0] + +#---------------------------------------------------------------------------- +# Image scaling +#---------------------------------------------------------------------------- + +def scale_img_hwc(x : torch.Tensor, size, mag='bilinear', min='area') -> torch.Tensor: + return scale_img_nhwc(x[None, ...], size, mag, min)[0] + +def scale_img_nhwc(x : torch.Tensor, size, mag='bilinear', min='area') -> torch.Tensor: + assert (x.shape[1] >= size[0] and x.shape[2] >= size[1]) or (x.shape[1] < size[0] and x.shape[2] < size[1]), "Trying to magnify image in one dimension and minify in the other" + y = x.permute(0, 3, 1, 2) # NHWC -> NCHW + if x.shape[1] > size[0] and x.shape[2] > size[1]: # Minification, previous size was bigger + y = torch.nn.functional.interpolate(y, size, mode=min) + else: # Magnification + if mag == 'bilinear' or mag == 'bicubic': + y = torch.nn.functional.interpolate(y, size, mode=mag, align_corners=True) + else: + y = torch.nn.functional.interpolate(y, size, mode=mag) + return y.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC + +def avg_pool_nhwc(x : torch.Tensor, size) -> torch.Tensor: + y = x.permute(0, 3, 1, 2) # NHWC -> NCHW + y = torch.nn.functional.avg_pool2d(y, size) + return y.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC + +#---------------------------------------------------------------------------- +# Behaves similar to tf.segment_sum +#---------------------------------------------------------------------------- + +def segment_sum(data: torch.Tensor, segment_ids: torch.Tensor) -> torch.Tensor: + num_segments = torch.unique_consecutive(segment_ids).shape[0] + + # Repeats ids until same dimension as data + if len(segment_ids.shape) == 1: + s = torch.prod(torch.tensor(data.shape[1:], dtype=torch.int64, device='cuda')).long() + segment_ids = segment_ids.repeat_interleave(s).view(segment_ids.shape[0], *data.shape[1:]) + + assert data.shape == segment_ids.shape, "data.shape and segment_ids.shape should be equal" + + shape = [num_segments] + list(data.shape[1:]) + result = torch.zeros(*shape, dtype=torch.float32, device='cuda') + result = result.scatter_add(0, segment_ids, data) + return result + +#---------------------------------------------------------------------------- +# Matrix helpers. +#---------------------------------------------------------------------------- + +def fovx_to_fovy(fovx, aspect): + return np.arctan(np.tan(fovx / 2) / aspect) * 2.0 + +def focal_length_to_fovy(focal_length, sensor_height): + return 2 * np.arctan(0.5 * sensor_height / focal_length) + +# Reworked so this matches gluPerspective / glm::perspective, using fovy +def perspective(fovy=0.7854, aspect=1.0, n=0.1, f=1000.0, device=None): + y = np.tan(fovy / 2) + return torch.tensor([[1/(y*aspect), 0, 0, 0], + [ 0, 1/-y, 0, 0], + [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)], + [ 0, 0, -1, 0]], dtype=torch.float32, device=device) + +# Reworked so this matches gluPerspective / glm::perspective, using fovy +def perspective_offcenter(fovy, fraction, rx, ry, aspect=1.0, n=0.1, f=1000.0, device=None): + y = np.tan(fovy / 2) + + # Full frustum + R, L = aspect*y, -aspect*y + T, B = y, -y + + # Create a randomized sub-frustum + width = (R-L)*fraction + height = (T-B)*fraction + xstart = (R-L)*rx + ystart = (T-B)*ry + + l = L + xstart + r = l + width + b = B + ystart + t = b + height + + # https://www.scratchapixel.com/lessons/3d-basic-rendering/perspective-and-orthographic-projection-matrix/opengl-perspective-projection-matrix + return torch.tensor([[2/(r-l), 0, (r+l)/(r-l), 0], + [ 0, -2/(t-b), (t+b)/(t-b), 0], + [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)], + [ 0, 0, -1, 0]], dtype=torch.float32, device=device) + +def translate(x, y, z, device=None): + return torch.tensor([[1, 0, 0, x], + [0, 1, 0, y], + [0, 0, 1, z], + [0, 0, 0, 1]], dtype=torch.float32, device=device) + +def rotate_x(a, device=None): + s, c = np.sin(a), np.cos(a) + return torch.tensor([[1, 0, 0, 0], + [0, c,-s, 0], + [0, s, c, 0], + [0, 0, 0, 1]], dtype=torch.float32, device=device) + +def rotate_y(a, device=None): + s, c = np.sin(a), np.cos(a) + return torch.tensor([[ c, 0, s, 0], + [ 0, 1, 0, 0], + [-s, 0, c, 0], + [ 0, 0, 0, 1]], dtype=torch.float32, device=device) + +def scale(s, device=None): + return torch.tensor([[ s, 0, 0, 0], + [ 0, s, 0, 0], + [ 0, 0, s, 0], + [ 0, 0, 0, 1]], dtype=torch.float32, device=device) + +def lookAt(eye, at, up): + a = eye - at + w = a / torch.linalg.norm(a) + u = torch.cross(up, w) + u = u / torch.linalg.norm(u) + v = torch.cross(w, u) + translate = torch.tensor([[1, 0, 0, -eye[0]], + [0, 1, 0, -eye[1]], + [0, 0, 1, -eye[2]], + [0, 0, 0, 1]], dtype=eye.dtype, device=eye.device) + rotate = torch.tensor([[u[0], u[1], u[2], 0], + [v[0], v[1], v[2], 0], + [w[0], w[1], w[2], 0], + [0, 0, 0, 1]], dtype=eye.dtype, device=eye.device) + return rotate @ translate + +@torch.no_grad() +def random_rotation_translation(t, device=None): + m = np.random.normal(size=[3, 3]) + m[1] = np.cross(m[0], m[2]) + m[2] = np.cross(m[0], m[1]) + m = m / np.linalg.norm(m, axis=1, keepdims=True) + m = np.pad(m, [[0, 1], [0, 1]], mode='constant') + m[3, 3] = 1.0 + m[:3, 3] = np.random.uniform(-t, t, size=[3]) + return torch.tensor(m, dtype=torch.float32, device=device) + +@torch.no_grad() +def random_rotation(device=None): + m = np.random.normal(size=[3, 3]) + m[1] = np.cross(m[0], m[2]) + m[2] = np.cross(m[0], m[1]) + m = m / np.linalg.norm(m, axis=1, keepdims=True) + m = np.pad(m, [[0, 1], [0, 1]], mode='constant') + m[3, 3] = 1.0 + m[:3, 3] = np.array([0,0,0]).astype(np.float32) + return torch.tensor(m, dtype=torch.float32, device=device) + +#---------------------------------------------------------------------------- +# Compute focal points of a set of lines using least squares. +# handy for poorly centered datasets +#---------------------------------------------------------------------------- + +def lines_focal(o, d): + d = safe_normalize(d) + I = torch.eye(3, dtype=o.dtype, device=o.device) + S = torch.sum(d[..., None] @ torch.transpose(d[..., None], 1, 2) - I[None, ...], dim=0) + C = torch.sum((d[..., None] @ torch.transpose(d[..., None], 1, 2) - I[None, ...]) @ o[..., None], dim=0).squeeze(1) + return torch.linalg.pinv(S) @ C + +#---------------------------------------------------------------------------- +# Cosine sample around a vector N +#---------------------------------------------------------------------------- +@torch.no_grad() +def cosine_sample(N, size=None): + # construct local frame + N = N/torch.linalg.norm(N) + + dx0 = torch.tensor([0, N[2], -N[1]], dtype=N.dtype, device=N.device) + dx1 = torch.tensor([-N[2], 0, N[0]], dtype=N.dtype, device=N.device) + + dx = torch.where(dot(dx0, dx0) > dot(dx1, dx1), dx0, dx1) + #dx = dx0 if np.dot(dx0,dx0) > np.dot(dx1,dx1) else dx1 + dx = dx / torch.linalg.norm(dx) + dy = torch.cross(N,dx) + dy = dy / torch.linalg.norm(dy) + + # cosine sampling in local frame + if size is None: + phi = 2.0 * np.pi * np.random.uniform() + s = np.random.uniform() + else: + phi = 2.0 * np.pi * torch.rand(*size, 1, dtype=N.dtype, device=N.device) + s = torch.rand(*size, 1, dtype=N.dtype, device=N.device) + costheta = np.sqrt(s) + sintheta = np.sqrt(1.0 - s) + + # cartesian vector in local space + x = np.cos(phi)*sintheta + y = np.sin(phi)*sintheta + z = costheta + + # local to world + return dx*x + dy*y + N*z + +#---------------------------------------------------------------------------- +# Bilinear downsample by 2x. +#---------------------------------------------------------------------------- + +def bilinear_downsample(x : torch.tensor) -> torch.Tensor: + w = torch.tensor([[1, 3, 3, 1], [3, 9, 9, 3], [3, 9, 9, 3], [1, 3, 3, 1]], dtype=torch.float32, device=x.device) / 64.0 + w = w.expand(x.shape[-1], 1, 4, 4) + x = torch.nn.functional.conv2d(x.permute(0, 3, 1, 2), w, padding=1, stride=2, groups=x.shape[-1]) + return x.permute(0, 2, 3, 1) + +#---------------------------------------------------------------------------- +# Bilinear downsample log(spp) steps +#---------------------------------------------------------------------------- + +def bilinear_downsample(x : torch.tensor, spp) -> torch.Tensor: + w = torch.tensor([[1, 3, 3, 1], [3, 9, 9, 3], [3, 9, 9, 3], [1, 3, 3, 1]], dtype=torch.float32, device=x.device) / 64.0 + g = x.shape[-1] + w = w.expand(g, 1, 4, 4) + x = x.permute(0, 3, 1, 2) # NHWC -> NCHW + steps = int(np.log2(spp)) + for _ in range(steps): + xp = torch.nn.functional.pad(x, (1,1,1,1), mode='replicate') + x = torch.nn.functional.conv2d(xp, w, padding=0, stride=2, groups=g) + return x.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC + +#---------------------------------------------------------------------------- +# Singleton initialize GLFW +#---------------------------------------------------------------------------- + +_glfw_initialized = False +def init_glfw(): + global _glfw_initialized + try: + import glfw + glfw.ERROR_REPORTING = 'raise' + glfw.default_window_hints() + glfw.window_hint(glfw.VISIBLE, glfw.FALSE) + test = glfw.create_window(8, 8, "Test", None, None) # Create a window and see if not initialized yet + except glfw.GLFWError as e: + if e.error_code == glfw.NOT_INITIALIZED: + glfw.init() + _glfw_initialized = True + +#---------------------------------------------------------------------------- +# Image display function using OpenGL. +#---------------------------------------------------------------------------- + +_glfw_window = None +def display_image(image, title=None): + # Import OpenGL + import OpenGL.GL as gl + import glfw + + # Zoom image if requested. + image = np.asarray(image[..., 0:3]) if image.shape[-1] == 4 else np.asarray(image) + height, width, channels = image.shape + + # Initialize window. + init_glfw() + if title is None: + title = 'Debug window' + global _glfw_window + if _glfw_window is None: + glfw.default_window_hints() + _glfw_window = glfw.create_window(width, height, title, None, None) + glfw.make_context_current(_glfw_window) + glfw.show_window(_glfw_window) + glfw.swap_interval(0) + else: + glfw.make_context_current(_glfw_window) + glfw.set_window_title(_glfw_window, title) + glfw.set_window_size(_glfw_window, width, height) + + # Update window. + glfw.poll_events() + gl.glClearColor(0, 0, 0, 1) + gl.glClear(gl.GL_COLOR_BUFFER_BIT) + gl.glWindowPos2f(0, 0) + gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1) + gl_format = {3: gl.GL_RGB, 2: gl.GL_RG, 1: gl.GL_LUMINANCE}[channels] + gl_dtype = {'uint8': gl.GL_UNSIGNED_BYTE, 'float32': gl.GL_FLOAT}[image.dtype.name] + gl.glDrawPixels(width, height, gl_format, gl_dtype, image[::-1]) + glfw.swap_buffers(_glfw_window) + if glfw.window_should_close(_glfw_window): + return False + return True + +#---------------------------------------------------------------------------- +# Image save/load helper. +#---------------------------------------------------------------------------- + +def save_image(fn, x : np.ndarray): + try: + if os.path.splitext(fn)[1] == ".png": + imageio.imwrite(fn, np.clip(np.rint(x * 255.0), 0, 255).astype(np.uint8), compress_level=3) # Low compression for faster saving + else: + imageio.imwrite(fn, np.clip(np.rint(x * 255.0), 0, 255).astype(np.uint8)) + except: + print("WARNING: FAILED to save image %s" % fn) + +def save_image_raw(fn, x : np.ndarray): + try: + imageio.imwrite(fn, x) + except: + print("WARNING: FAILED to save image %s" % fn) + + +def load_image_raw(fn) -> np.ndarray: + return imageio.imread(fn) + +def load_image(fn) -> np.ndarray: + img = load_image_raw(fn) + if img.dtype == np.float32: # HDR image + return img + else: # LDR image + return img.astype(np.float32) / 255 + +#---------------------------------------------------------------------------- + +def time_to_text(x): + if x > 3600: + return "%.2f h" % (x / 3600) + elif x > 60: + return "%.2f m" % (x / 60) + else: + return "%.2f s" % x + +#---------------------------------------------------------------------------- + +def checkerboard(res, checker_size) -> np.ndarray: + tiles_y = (res[0] + (checker_size*2) - 1) // (checker_size*2) + tiles_x = (res[1] + (checker_size*2) - 1) // (checker_size*2) + check = np.kron([[1, 0] * tiles_x, [0, 1] * tiles_x] * tiles_y, np.ones((checker_size, checker_size)))*0.33 + 0.33 + check = check[:res[0], :res[1]] + return np.stack((check, check, check), axis=-1) + diff --git a/models/lrm/online_render/src/models/geometry/rep_3d/__init__.py b/models/lrm/online_render/src/models/geometry/rep_3d/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..a3d5628a8433298477d1963f92578d47106b4a0f --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/rep_3d/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +import numpy as np + + +class Geometry(): + def __init__(self): + pass + + def forward(self): + pass diff --git a/models/lrm/online_render/src/models/geometry/rep_3d/dmtet.py b/models/lrm/online_render/src/models/geometry/rep_3d/dmtet.py new file mode 100755 index 0000000000000000000000000000000000000000..b6a709380abac0bbf66fd1c8582485f3982223e4 --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/rep_3d/dmtet.py @@ -0,0 +1,504 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +import numpy as np +import os +from . import Geometry +from .dmtet_utils import get_center_boundary_index +import torch.nn.functional as F + + +############################################################################### +# DMTet utility functions +############################################################################### +def create_mt_variable(device): + triangle_table = torch.tensor( + [ + [-1, -1, -1, -1, -1, -1], + [1, 0, 2, -1, -1, -1], + [4, 0, 3, -1, -1, -1], + [1, 4, 2, 1, 3, 4], + [3, 1, 5, -1, -1, -1], + [2, 3, 0, 2, 5, 3], + [1, 4, 0, 1, 5, 4], + [4, 2, 5, -1, -1, -1], + [4, 5, 2, -1, -1, -1], + [4, 1, 0, 4, 5, 1], + [3, 2, 0, 3, 5, 2], + [1, 3, 5, -1, -1, -1], + [4, 1, 2, 4, 3, 1], + [3, 0, 4, -1, -1, -1], + [2, 0, 1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1] + ], dtype=torch.long, device=device) + + num_triangles_table = torch.tensor([0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long, device=device) + base_tet_edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=device) + v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=device)) + return triangle_table, num_triangles_table, base_tet_edges, v_id + + +def sort_edges(edges_ex2): + with torch.no_grad(): + order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long() + order = order.unsqueeze(dim=1) + a = torch.gather(input=edges_ex2, index=order, dim=1) + b = torch.gather(input=edges_ex2, index=1 - order, dim=1) + return torch.stack([a, b], -1) + + +############################################################################### +# marching tetrahedrons (differentiable) +############################################################################### + +def marching_tets(pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id): + with torch.no_grad(): + occ_n = sdf_n > 0 + occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4) + occ_sum = torch.sum(occ_fx4, -1) + valid_tets = (occ_sum > 0) & (occ_sum < 4) + occ_sum = occ_sum[valid_tets] + + # find all vertices + all_edges = tet_fx4[valid_tets][:, base_tet_edges].reshape(-1, 2) + all_edges = sort_edges(all_edges) + unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True) + + unique_edges = unique_edges.long() + mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 + mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=sdf_n.device) * -1 + mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device=sdf_n.device) + idx_map = mapping[idx_map] # map edges to verts + + interp_v = unique_edges[mask_edges] # .long() + edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3) + edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1) + edges_to_interp_sdf[:, -1] *= -1 + + denominator = edges_to_interp_sdf.sum(1, keepdim=True) + + edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator + verts = (edges_to_interp * edges_to_interp_sdf).sum(1) + + idx_map = idx_map.reshape(-1, 6) + + tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1) + num_triangles = num_triangles_table[tetindex] + + # Generate triangle indices + faces = torch.cat( + ( + torch.gather( + input=idx_map[num_triangles == 1], dim=1, + index=triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1, 3), + torch.gather( + input=idx_map[num_triangles == 2], dim=1, + index=triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1, 3), + ), dim=0) + return verts, faces + + +def create_tetmesh_variables(device='cuda'): + tet_table = torch.tensor( + [[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], + [0, 4, 5, 6, -1, -1, -1, -1, -1, -1, -1, -1], + [1, 4, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1], + [1, 0, 8, 7, 0, 5, 8, 7, 0, 5, 6, 8], + [2, 5, 7, 9, -1, -1, -1, -1, -1, -1, -1, -1], + [2, 0, 9, 7, 0, 4, 9, 7, 0, 4, 6, 9], + [2, 1, 9, 5, 1, 4, 9, 5, 1, 4, 8, 9], + [6, 0, 1, 2, 6, 1, 2, 8, 6, 8, 2, 9], + [3, 6, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1], + [3, 0, 9, 8, 0, 4, 9, 8, 0, 4, 5, 9], + [3, 1, 9, 6, 1, 4, 9, 6, 1, 4, 7, 9], + [5, 0, 1, 3, 5, 1, 3, 7, 5, 7, 3, 9], + [3, 2, 8, 6, 2, 5, 8, 6, 2, 5, 7, 8], + [4, 0, 2, 3, 4, 2, 3, 7, 4, 7, 3, 8], + [4, 1, 2, 3, 4, 2, 3, 5, 4, 5, 3, 6], + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]], dtype=torch.long, device=device) + num_tets_table = torch.tensor([0, 1, 1, 3, 1, 3, 3, 3, 1, 3, 3, 3, 3, 3, 3, 0], dtype=torch.long, device=device) + return tet_table, num_tets_table + + +def marching_tets_tetmesh( + pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id, + return_tet_mesh=False, ori_v=None, num_tets_table=None, tet_table=None): + with torch.no_grad(): + occ_n = sdf_n > 0 + occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4) + occ_sum = torch.sum(occ_fx4, -1) + valid_tets = (occ_sum > 0) & (occ_sum < 4) + occ_sum = occ_sum[valid_tets] + + # find all vertices + all_edges = tet_fx4[valid_tets][:, base_tet_edges].reshape(-1, 2) + all_edges = sort_edges(all_edges) + unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True) + + unique_edges = unique_edges.long() + mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 + mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=sdf_n.device) * -1 + mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device=sdf_n.device) + idx_map = mapping[idx_map] # map edges to verts + + interp_v = unique_edges[mask_edges] # .long() + edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3) + edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1) + edges_to_interp_sdf[:, -1] *= -1 + + denominator = edges_to_interp_sdf.sum(1, keepdim=True) + + edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator + verts = (edges_to_interp * edges_to_interp_sdf).sum(1) + + idx_map = idx_map.reshape(-1, 6) + + tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1) + num_triangles = num_triangles_table[tetindex] + + # Generate triangle indices + faces = torch.cat( + ( + torch.gather( + input=idx_map[num_triangles == 1], dim=1, + index=triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1, 3), + torch.gather( + input=idx_map[num_triangles == 2], dim=1, + index=triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1, 3), + ), dim=0) + if not return_tet_mesh: + return verts, faces + occupied_verts = ori_v[occ_n] + mapping = torch.ones((pos_nx3.shape[0]), dtype=torch.long, device="cuda") * -1 + mapping[occ_n] = torch.arange(occupied_verts.shape[0], device="cuda") + tet_fx4 = mapping[tet_fx4.reshape(-1)].reshape((-1, 4)) + + idx_map = torch.cat([tet_fx4[valid_tets] + verts.shape[0], idx_map], -1) # t x 10 + tet_verts = torch.cat([verts, occupied_verts], 0) + num_tets = num_tets_table[tetindex] + + tets = torch.cat( + ( + torch.gather(input=idx_map[num_tets == 1], dim=1, index=tet_table[tetindex[num_tets == 1]][:, :4]).reshape( + -1, + 4), + torch.gather(input=idx_map[num_tets == 3], dim=1, index=tet_table[tetindex[num_tets == 3]][:, :12]).reshape( + -1, + 4), + ), dim=0) + # add fully occupied tets + fully_occupied = occ_fx4.sum(-1) == 4 + tet_fully_occupied = tet_fx4[fully_occupied] + verts.shape[0] + tets = torch.cat([tets, tet_fully_occupied]) + + return verts, faces, tet_verts, tets + + +############################################################################### +# Compact tet grid +############################################################################### + +def compact_tets(pos_nx3, sdf_n, tet_fx4): + with torch.no_grad(): + # Find surface tets + occ_n = sdf_n > 0 + occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4) + occ_sum = torch.sum(occ_fx4, -1) + valid_tets = (occ_sum > 0) & (occ_sum < 4) # one value per tet, these are the surface tets + + valid_vtx = tet_fx4[valid_tets].reshape(-1) + unique_vtx, idx_map = torch.unique(valid_vtx, dim=0, return_inverse=True) + new_pos = pos_nx3[unique_vtx] + new_sdf = sdf_n[unique_vtx] + new_tets = idx_map.reshape(-1, 4) + return new_pos, new_sdf, new_tets + + +############################################################################### +# Subdivide volume +############################################################################### + +def batch_subdivide_volume(tet_pos_bxnx3, tet_bxfx4, grid_sdf): + device = tet_pos_bxnx3.device + # get new verts + tet_fx4 = tet_bxfx4[0] + edges = [0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3] + all_edges = tet_fx4[:, edges].reshape(-1, 2) + all_edges = sort_edges(all_edges) + unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True) + idx_map = idx_map + tet_pos_bxnx3.shape[1] + all_values = torch.cat([tet_pos_bxnx3, grid_sdf], -1) + mid_points_pos = all_values[:, unique_edges.reshape(-1)].reshape( + all_values.shape[0], -1, 2, + all_values.shape[-1]).mean(2) + new_v = torch.cat([all_values, mid_points_pos], 1) + new_v, new_sdf = new_v[..., :3], new_v[..., 3] + + # get new tets + + idx_a, idx_b, idx_c, idx_d = tet_fx4[:, 0], tet_fx4[:, 1], tet_fx4[:, 2], tet_fx4[:, 3] + idx_ab = idx_map[0::6] + idx_ac = idx_map[1::6] + idx_ad = idx_map[2::6] + idx_bc = idx_map[3::6] + idx_bd = idx_map[4::6] + idx_cd = idx_map[5::6] + + tet_1 = torch.stack([idx_a, idx_ab, idx_ac, idx_ad], dim=1) + tet_2 = torch.stack([idx_b, idx_bc, idx_ab, idx_bd], dim=1) + tet_3 = torch.stack([idx_c, idx_ac, idx_bc, idx_cd], dim=1) + tet_4 = torch.stack([idx_d, idx_ad, idx_cd, idx_bd], dim=1) + tet_5 = torch.stack([idx_ab, idx_ac, idx_ad, idx_bd], dim=1) + tet_6 = torch.stack([idx_ab, idx_ac, idx_bd, idx_bc], dim=1) + tet_7 = torch.stack([idx_cd, idx_ac, idx_bd, idx_ad], dim=1) + tet_8 = torch.stack([idx_cd, idx_ac, idx_bc, idx_bd], dim=1) + + tet_np = torch.cat([tet_1, tet_2, tet_3, tet_4, tet_5, tet_6, tet_7, tet_8], dim=0) + tet_np = tet_np.reshape(1, -1, 4).expand(tet_pos_bxnx3.shape[0], -1, -1) + tet = tet_np.long().to(device) + + return new_v, tet, new_sdf + + +############################################################################### +# Adjacency +############################################################################### +def tet_to_tet_adj_sparse(tet_tx4): + # include self connection!!!!!!!!!!!!!!!!!!! + with torch.no_grad(): + t = tet_tx4.shape[0] + device = tet_tx4.device + idx_array = torch.LongTensor( + [0, 1, 2, + 1, 0, 3, + 2, 3, 0, + 3, 2, 1]).to(device).reshape(4, 3).unsqueeze(0).expand(t, -1, -1) # (t, 4, 3) + + # get all faces + all_faces = torch.gather(input=tet_tx4.unsqueeze(1).expand(-1, 4, -1), index=idx_array, dim=-1).reshape( + -1, + 3) # (tx4, 3) + all_faces_tet_idx = torch.arange(t, device=device).unsqueeze(-1).expand(-1, 4).reshape(-1) + # sort and group + all_faces_sorted, _ = torch.sort(all_faces, dim=1) + + all_faces_unique, inverse_indices, counts = torch.unique( + all_faces_sorted, dim=0, return_counts=True, + return_inverse=True) + tet_face_fx3 = all_faces_unique[counts == 2] + counts = counts[inverse_indices] # tx4 + valid = (counts == 2) + + group = inverse_indices[valid] + # print (inverse_indices.shape, group.shape, all_faces_tet_idx.shape) + _, indices = torch.sort(group) + all_faces_tet_idx_grouped = all_faces_tet_idx[valid][indices] + tet_face_tetidx_fx2 = torch.stack([all_faces_tet_idx_grouped[::2], all_faces_tet_idx_grouped[1::2]], dim=-1) + + tet_adj_idx = torch.cat([tet_face_tetidx_fx2, torch.flip(tet_face_tetidx_fx2, [1])]) + adj_self = torch.arange(t, device=tet_tx4.device) + adj_self = torch.stack([adj_self, adj_self], -1) + tet_adj_idx = torch.cat([tet_adj_idx, adj_self]) + + tet_adj_idx = torch.unique(tet_adj_idx, dim=0) + values = torch.ones( + tet_adj_idx.shape[0], device=tet_tx4.device).float() + adj_sparse = torch.sparse.FloatTensor( + tet_adj_idx.t(), values, torch.Size([t, t])) + + # normalization + neighbor_num = 1.0 / torch.sparse.sum( + adj_sparse, dim=1).to_dense() + values = torch.index_select(neighbor_num, 0, tet_adj_idx[:, 0]) + adj_sparse = torch.sparse.FloatTensor( + tet_adj_idx.t(), values, torch.Size([t, t])) + return adj_sparse + + +############################################################################### +# Compact grid +############################################################################### + +def get_tet_bxfx4x3(bxnxz, bxfx4): + n_batch, z = bxnxz.shape[0], bxnxz.shape[2] + gather_input = bxnxz.unsqueeze(2).expand( + n_batch, bxnxz.shape[1], 4, z) + gather_index = bxfx4.unsqueeze(-1).expand( + n_batch, bxfx4.shape[1], 4, z).long() + tet_bxfx4xz = torch.gather( + input=gather_input, dim=1, index=gather_index) + + return tet_bxfx4xz + + +def shrink_grid(tet_pos_bxnx3, tet_bxfx4, grid_sdf): + with torch.no_grad(): + assert tet_pos_bxnx3.shape[0] == 1 + + occ = grid_sdf[0] > 0 + occ_sum = get_tet_bxfx4x3(occ.unsqueeze(0).unsqueeze(-1), tet_bxfx4).reshape(-1, 4).sum(-1) + mask = (occ_sum > 0) & (occ_sum < 4) + + # build connectivity graph + adj_matrix = tet_to_tet_adj_sparse(tet_bxfx4[0]) + mask = mask.float().unsqueeze(-1) + + # Include a one ring of neighbors + for i in range(1): + mask = torch.sparse.mm(adj_matrix, mask) + mask = mask.squeeze(-1) > 0 + + mapping = torch.zeros((tet_pos_bxnx3.shape[1]), device=tet_pos_bxnx3.device, dtype=torch.long) + new_tet_bxfx4 = tet_bxfx4[:, mask].long() + selected_verts_idx = torch.unique(new_tet_bxfx4) + new_tet_pos_bxnx3 = tet_pos_bxnx3[:, selected_verts_idx] + mapping[selected_verts_idx] = torch.arange(selected_verts_idx.shape[0], device=tet_pos_bxnx3.device) + new_tet_bxfx4 = mapping[new_tet_bxfx4.reshape(-1)].reshape(new_tet_bxfx4.shape) + new_grid_sdf = grid_sdf[:, selected_verts_idx] + return new_tet_pos_bxnx3, new_tet_bxfx4, new_grid_sdf + + +############################################################################### +# Regularizer +############################################################################### + +def sdf_reg_loss(sdf, all_edges): + sdf_f1x6x2 = sdf[all_edges.reshape(-1)].reshape(-1, 2) + mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1]) + sdf_f1x6x2 = sdf_f1x6x2[mask] + sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits( + sdf_f1x6x2[..., 0], + (sdf_f1x6x2[..., 1] > 0).float()) + \ + torch.nn.functional.binary_cross_entropy_with_logits( + sdf_f1x6x2[..., 1], + (sdf_f1x6x2[..., 0] > 0).float()) + return sdf_diff + + +def sdf_reg_loss_batch(sdf, all_edges): + sdf_f1x6x2 = sdf[:, all_edges.reshape(-1)].reshape(sdf.shape[0], -1, 2) + mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1]) + sdf_f1x6x2 = sdf_f1x6x2[mask] + sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[..., 0], (sdf_f1x6x2[..., 1] > 0).float()) + \ + torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[..., 1], (sdf_f1x6x2[..., 0] > 0).float()) + return sdf_diff + + +############################################################################### +# Geometry interface +############################################################################### +class DMTetGeometry(Geometry): + def __init__( + self, grid_res=64, scale=2.0, device='cuda', renderer=None, + render_type='neural_render', args=None): + super(DMTetGeometry, self).__init__() + self.grid_res = grid_res + self.device = device + self.args = args + tets = np.load('data/tets/%d_compress.npz' % (grid_res)) + self.verts = torch.from_numpy(tets['vertices']).float().to(self.device) + # Make sure the tet is zero-centered and length is equal to 1 + length = self.verts.max(dim=0)[0] - self.verts.min(dim=0)[0] + length = length.max() + mid = (self.verts.max(dim=0)[0] + self.verts.min(dim=0)[0]) / 2.0 + self.verts = (self.verts - mid.unsqueeze(dim=0)) / length + if isinstance(scale, list): + self.verts[:, 0] = self.verts[:, 0] * scale[0] + self.verts[:, 1] = self.verts[:, 1] * scale[1] + self.verts[:, 2] = self.verts[:, 2] * scale[1] + else: + self.verts = self.verts * scale + self.indices = torch.from_numpy(tets['tets']).long().to(self.device) + self.triangle_table, self.num_triangles_table, self.base_tet_edges, self.v_id = create_mt_variable(self.device) + self.tet_table, self.num_tets_table = create_tetmesh_variables(self.device) + # Parameters for regularization computation + edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=self.device) + all_edges = self.indices[:, edges].reshape(-1, 2) + all_edges_sorted = torch.sort(all_edges, dim=1)[0] + self.all_edges = torch.unique(all_edges_sorted, dim=0) + + # Parameters used for fix boundary sdf + self.center_indices, self.boundary_indices = get_center_boundary_index(self.verts) + self.renderer = renderer + self.render_type = render_type + + def getAABB(self): + return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values + + def get_mesh(self, v_deformed_nx3, sdf_n, with_uv=False, indices=None): + if indices is None: + indices = self.indices + verts, faces = marching_tets( + v_deformed_nx3, sdf_n, indices, self.triangle_table, + self.num_triangles_table, self.base_tet_edges, self.v_id) + faces = torch.cat( + [faces[:, 0:1], + faces[:, 2:3], + faces[:, 1:2], ], dim=-1) + return verts, faces + + def get_tet_mesh(self, v_deformed_nx3, sdf_n, with_uv=False, indices=None): + if indices is None: + indices = self.indices + verts, faces, tet_verts, tets = marching_tets_tetmesh( + v_deformed_nx3, sdf_n, indices, self.triangle_table, + self.num_triangles_table, self.base_tet_edges, self.v_id, return_tet_mesh=True, + num_tets_table=self.num_tets_table, tet_table=self.tet_table, ori_v=v_deformed_nx3) + faces = torch.cat( + [faces[:, 0:1], + faces[:, 2:3], + faces[:, 1:2], ], dim=-1) + return verts, faces, tet_verts, tets + + def render_mesh(self, mesh_v_nx3, mesh_f_fx3, camera_mv_bx4x4, resolution=256, hierarchical_mask=False): + return_value = dict() + if self.render_type == 'neural_render': + tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth = self.renderer.render_mesh( + mesh_v_nx3.unsqueeze(dim=0), + mesh_f_fx3.int(), + camera_mv_bx4x4, + mesh_v_nx3.unsqueeze(dim=0), + resolution=resolution, + device=self.device, + hierarchical_mask=hierarchical_mask + ) + + return_value['tex_pos'] = tex_pos + return_value['mask'] = mask + return_value['hard_mask'] = hard_mask + return_value['rast'] = rast + return_value['v_pos_clip'] = v_pos_clip + return_value['mask_pyramid'] = mask_pyramid + return_value['depth'] = depth + else: + raise NotImplementedError + + return return_value + + def render(self, v_deformed_bxnx3=None, sdf_bxn=None, camera_mv_bxnviewx4x4=None, resolution=256): + # Here I assume a batch of meshes (can be different mesh and geometry), for the other shapes, the batch is 1 + v_list = [] + f_list = [] + n_batch = v_deformed_bxnx3.shape[0] + all_render_output = [] + for i_batch in range(n_batch): + verts_nx3, faces_fx3 = self.get_mesh(v_deformed_bxnx3[i_batch], sdf_bxn[i_batch]) + v_list.append(verts_nx3) + f_list.append(faces_fx3) + render_output = self.render_mesh(verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution) + all_render_output.append(render_output) + + # Concatenate all render output + return_keys = all_render_output[0].keys() + return_value = dict() + for k in return_keys: + value = [v[k] for v in all_render_output] + return_value[k] = value + # We can do concatenation outside of the render + return return_value diff --git a/models/lrm/online_render/src/models/geometry/rep_3d/dmtet_utils.py b/models/lrm/online_render/src/models/geometry/rep_3d/dmtet_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..8d466a9e78c49d947c115707693aa18d759885ad --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/rep_3d/dmtet_utils.py @@ -0,0 +1,20 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch + + +def get_center_boundary_index(verts): + length_ = torch.sum(verts ** 2, dim=-1) + center_idx = torch.argmin(length_) + boundary_neg = verts == verts.max() + boundary_pos = verts == verts.min() + boundary = torch.bitwise_or(boundary_pos, boundary_neg) + boundary = torch.sum(boundary.float(), dim=-1) + boundary_idx = torch.nonzero(boundary) + return center_idx, boundary_idx.squeeze(dim=-1) diff --git a/models/lrm/online_render/src/models/geometry/rep_3d/extract_texture_map.py b/models/lrm/online_render/src/models/geometry/rep_3d/extract_texture_map.py new file mode 100755 index 0000000000000000000000000000000000000000..aadea1f018fc00b1824e2d498f0c59504de3298f --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/rep_3d/extract_texture_map.py @@ -0,0 +1,40 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +import xatlas +import numpy as np +import nvdiffrast.torch as dr + + +# ============================================================================================== +def interpolate(attr, rast, attr_idx, rast_db=None): + return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=None if rast_db is None else 'all') + + +def xatlas_uvmap(ctx, mesh_v, mesh_pos_idx, resolution): + vmapping, indices, uvs = xatlas.parametrize(mesh_v.detach().cpu().numpy(), mesh_pos_idx.detach().cpu().numpy()) + + # Convert to tensors + indices_int64 = indices.astype(np.uint64, casting='same_kind').view(int) + + uvs = torch.tensor(uvs, dtype=torch.float32, device=mesh_v.device) + mesh_tex_idx = torch.tensor(indices_int64, dtype=torch.int64, device=mesh_v.device) + # mesh_v_tex. ture + uv_clip = uvs[None, ...] * 2.0 - 1.0 + + # pad to four component coordinate + uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[..., 0:1]), torch.ones_like(uv_clip[..., 0:1])), dim=-1) + + # rasterize + rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), (resolution, resolution)) + + # Interpolate world space position + gb_pos, _ = interpolate(mesh_v[None, ...], rast, mesh_pos_idx.int()) + mask = rast[..., 3:4] > 0 + return uvs, mesh_tex_idx, gb_pos, mask diff --git a/models/lrm/online_render/src/models/geometry/rep_3d/flexicubes.py b/models/lrm/online_render/src/models/geometry/rep_3d/flexicubes.py new file mode 100755 index 0000000000000000000000000000000000000000..26d7b91b6266d802baaf55b64238629cd0f740d0 --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/rep_3d/flexicubes.py @@ -0,0 +1,579 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. +import torch +from .tables import * + +__all__ = [ + 'FlexiCubes' +] + + +class FlexiCubes: + """ + This class implements the FlexiCubes method for extracting meshes from scalar fields. + It maintains a series of lookup tables and indices to support the mesh extraction process. + FlexiCubes, a differentiable variant of the Dual Marching Cubes (DMC) scheme, enhances + the geometric fidelity and mesh quality of reconstructed meshes by dynamically adjusting + the surface representation through gradient-based optimization. + + During instantiation, the class loads DMC tables from a file and transforms them into + PyTorch tensors on the specified device. + + Attributes: + device (str): Specifies the computational device (default is "cuda"). + dmc_table (torch.Tensor): Dual Marching Cubes (DMC) table that encodes the edges + associated with each dual vertex in 256 Marching Cubes (MC) configurations. + num_vd_table (torch.Tensor): Table holding the number of dual vertices in each of + the 256 MC configurations. + check_table (torch.Tensor): Table resolving ambiguity in cases C16 and C19 + of the DMC configurations. + tet_table (torch.Tensor): Lookup table used in tetrahedralizing the isosurface. + quad_split_1 (torch.Tensor): Indices for splitting a quad into two triangles + along one diagonal. + quad_split_2 (torch.Tensor): Alternative indices for splitting a quad into + two triangles along the other diagonal. + quad_split_train (torch.Tensor): Indices for splitting a quad into four triangles + during training by connecting all edges to their midpoints. + cube_corners (torch.Tensor): Defines the positions of a standard unit cube's + eight corners in 3D space, ordered starting from the origin (0,0,0), + moving along the x-axis, then y-axis, and finally z-axis. + Used as a blueprint for generating a voxel grid. + cube_corners_idx (torch.Tensor): Cube corners indexed as powers of 2, used + to retrieve the case id. + cube_edges (torch.Tensor): Edge connections in a cube, listed in pairs. + Used to retrieve edge vertices in DMC. + edge_dir_table (torch.Tensor): A mapping tensor that associates edge indices with + their corresponding axis. For instance, edge_dir_table[0] = 0 indicates that the + first edge is oriented along the x-axis. + dir_faces_table (torch.Tensor): A tensor that maps the corresponding axis of shared edges + across four adjacent cubes to the shared faces of these cubes. For instance, + dir_faces_table[0] = [5, 4] implies that for four cubes sharing an edge along + the x-axis, the first and second cubes share faces indexed as 5 and 4, respectively. + This tensor is only utilized during isosurface tetrahedralization. + adj_pairs (torch.Tensor): + A tensor containing index pairs that correspond to neighboring cubes that share the same edge. + qef_reg_scale (float): + The scaling factor applied to the regularization loss to prevent issues with singularity + when solving the QEF. This parameter is only used when a 'grad_func' is specified. + weight_scale (float): + The scale of weights in FlexiCubes. Should be between 0 and 1. + """ + + def __init__(self, device="cuda", qef_reg_scale=1e-3, weight_scale=0.99): + + self.device = device + self.dmc_table = torch.tensor(dmc_table, dtype=torch.long, device=device, requires_grad=False) + self.num_vd_table = torch.tensor(num_vd_table, + dtype=torch.long, device=device, requires_grad=False) + self.check_table = torch.tensor( + check_table, + dtype=torch.long, device=device, requires_grad=False) + + self.tet_table = torch.tensor(tet_table, dtype=torch.long, device=device, requires_grad=False) + self.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False) + self.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False) + self.quad_split_train = torch.tensor( + [0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.long, device=device, requires_grad=False) + + self.cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [ + 1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.float, device=device) + self.cube_corners_idx = torch.pow(2, torch.arange(8, requires_grad=False)) + self.cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6, + 2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, device=device, requires_grad=False) + + self.edge_dir_table = torch.tensor([0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1], + dtype=torch.long, device=device) + self.dir_faces_table = torch.tensor([ + [[5, 4], [3, 2], [4, 5], [2, 3]], + [[5, 4], [1, 0], [4, 5], [0, 1]], + [[3, 2], [1, 0], [2, 3], [0, 1]] + ], dtype=torch.long, device=device) + self.adj_pairs = torch.tensor([0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device) + self.qef_reg_scale = qef_reg_scale + self.weight_scale = weight_scale + + def construct_voxel_grid(self, res): + """ + Generates a voxel grid based on the specified resolution. + + Args: + res (int or list[int]): The resolution of the voxel grid. If an integer + is provided, it is used for all three dimensions. If a list or tuple + of 3 integers is provided, they define the resolution for the x, + y, and z dimensions respectively. + + Returns: + (torch.Tensor, torch.Tensor): Returns the vertices and the indices of the + cube corners (index into vertices) of the constructed voxel grid. + The vertices are centered at the origin, with the length of each + dimension in the grid being one. + """ + base_cube_f = torch.arange(8).to(self.device) + if isinstance(res, int): + res = (res, res, res) + voxel_grid_template = torch.ones(res, device=self.device) + + res = torch.tensor([res], dtype=torch.float, device=self.device) + coords = torch.nonzero(voxel_grid_template).float() / res # N, 3 + verts = (self.cube_corners.unsqueeze(0) / res + coords.unsqueeze(1)).reshape(-1, 3) + cubes = (base_cube_f.unsqueeze(0) + + torch.arange(coords.shape[0], device=self.device).unsqueeze(1) * 8).reshape(-1) + + verts_rounded = torch.round(verts * 10**5) / (10**5) + verts_unique, inverse_indices = torch.unique(verts_rounded, dim=0, return_inverse=True) + cubes = inverse_indices[cubes.reshape(-1)].reshape(-1, 8) + + return verts_unique - 0.5, cubes + + def __call__(self, x_nx3, s_n, cube_fx8, res, beta_fx12=None, alpha_fx8=None, + gamma_f=None, training=False, output_tetmesh=False, grad_func=None): + r""" + Main function for mesh extraction from scalar field using FlexiCubes. This function converts + discrete signed distance fields, encoded on voxel grids and additional per-cube parameters, + to triangle or tetrahedral meshes using a differentiable operation as described in + `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_. FlexiCubes enhances + mesh quality and geometric fidelity by adjusting the surface representation based on gradient + optimization. The output surface is differentiable with respect to the input vertex positions, + scalar field values, and weight parameters. + + If you intend to extract a surface mesh from a fixed Signed Distance Field without the + optimization of parameters, it is suggested to provide the "grad_func" which should + return the surface gradient at any given 3D position. When grad_func is provided, the process + to determine the dual vertex position adapts to solve a Quadratic Error Function (QEF), as + described in the `Manifold Dual Contouring`_ paper, and employs an smart splitting strategy. + Please note, this approach is non-differentiable. + + For more details and example usage in optimization, refer to the + `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_ SIGGRAPH 2023 paper. + + Args: + x_nx3 (torch.Tensor): Coordinates of the voxel grid vertices, can be deformed. + s_n (torch.Tensor): Scalar field values at each vertex of the voxel grid. Negative values + denote that the corresponding vertex resides inside the isosurface. This affects + the directions of the extracted triangle faces and volume to be tetrahedralized. + cube_fx8 (torch.Tensor): Indices of 8 vertices for each cube in the voxel grid. + res (int or list[int]): The resolution of the voxel grid. If an integer is provided, it + is used for all three dimensions. If a list or tuple of 3 integers is provided, they + specify the resolution for the x, y, and z dimensions respectively. + beta_fx12 (torch.Tensor, optional): Weight parameters for the cube edges to adjust dual + vertices positioning. Defaults to uniform value for all edges. + alpha_fx8 (torch.Tensor, optional): Weight parameters for the cube corners to adjust dual + vertices positioning. Defaults to uniform value for all vertices. + gamma_f (torch.Tensor, optional): Weight parameters to control the splitting of + quadrilaterals into triangles. Defaults to uniform value for all cubes. + training (bool, optional): If set to True, applies differentiable quad splitting for + training. Defaults to False. + output_tetmesh (bool, optional): If set to True, outputs a tetrahedral mesh, otherwise, + outputs a triangular mesh. Defaults to False. + grad_func (callable, optional): A function to compute the surface gradient at specified + 3D positions (input: Nx3 positions). The function should return gradients as an Nx3 + tensor. If None, the original FlexiCubes algorithm is utilized. Defaults to None. + + Returns: + (torch.Tensor, torch.LongTensor, torch.Tensor): Tuple containing: + - Vertices for the extracted triangular/tetrahedral mesh. + - Faces for the extracted triangular/tetrahedral mesh. + - Regularizer L_dev, computed per dual vertex. + + .. _Flexible Isosurface Extraction for Gradient-Based Mesh Optimization: + https://research.nvidia.com/labs/toronto-ai/flexicubes/ + .. _Manifold Dual Contouring: + https://people.engr.tamu.edu/schaefer/research/dualsimp_tvcg.pdf + """ + + surf_cubes, occ_fx8 = self._identify_surf_cubes(s_n, cube_fx8) + if surf_cubes.sum() == 0: + return torch.zeros( + (0, 3), + device=self.device), torch.zeros( + (0, 4), + dtype=torch.long, device=self.device) if output_tetmesh else torch.zeros( + (0, 3), + dtype=torch.long, device=self.device), torch.zeros( + (0), + device=self.device) + beta_fx12, alpha_fx8, gamma_f = self._normalize_weights(beta_fx12, alpha_fx8, gamma_f, surf_cubes) + + case_ids = self._get_case_id(occ_fx8, surf_cubes, res) + + surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges(s_n, cube_fx8, surf_cubes) + + vd, L_dev, vd_gamma, vd_idx_map = self._compute_vd( + x_nx3, cube_fx8[surf_cubes], surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func) + vertices, faces, s_edges, edge_indices = self._triangulate( + s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func) + if not output_tetmesh: + return vertices, faces, L_dev + else: + vertices, tets = self._tetrahedralize( + x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices, + surf_cubes, training) + return vertices, tets, L_dev + + def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges): + """ + Regularizer L_dev as in Equation 8 + """ + dist = torch.norm(ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1) + mean_l2 = torch.zeros_like(vd[:, 0]) + mean_l2 = (mean_l2).index_add_(0, edge_group_to_vd, dist) / vd_num_edges.squeeze(1).float() + mad = (dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)).abs() + return mad + + def _normalize_weights(self, beta_fx12, alpha_fx8, gamma_f, surf_cubes): + """ + Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones. + """ + n_cubes = surf_cubes.shape[0] + + if beta_fx12 is not None: + beta_fx12 = (torch.tanh(beta_fx12) * self.weight_scale + 1) + else: + beta_fx12 = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device) + + if alpha_fx8 is not None: + alpha_fx8 = (torch.tanh(alpha_fx8) * self.weight_scale + 1) + else: + alpha_fx8 = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device) + + if gamma_f is not None: + gamma_f = torch.sigmoid(gamma_f) * self.weight_scale + (1 - self.weight_scale)/2 + else: + gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device) + + return beta_fx12[surf_cubes], alpha_fx8[surf_cubes], gamma_f[surf_cubes] + + @torch.no_grad() + def _get_case_id(self, occ_fx8, surf_cubes, res): + """ + Obtains the ID of topology cases based on cell corner occupancy. This function resolves the + ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the + supplementary material. It should be noted that this function assumes a regular grid. + """ + case_ids = (occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0)).sum(-1) + + problem_config = self.check_table.to(self.device)[case_ids] + to_check = problem_config[..., 0] == 1 + problem_config = problem_config[to_check] + if not isinstance(res, (list, tuple)): + res = [res, res, res] + + # The 'problematic_configs' only contain configurations for surface cubes. Next, we construct a 3D array, + # 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes). + # This allows efficient checking on adjacent cubes. + problem_config_full = torch.zeros(list(res) + [5], device=self.device, dtype=torch.long) + vol_idx = torch.nonzero(problem_config_full[..., 0] == 0) # N, 3 + vol_idx_problem = vol_idx[surf_cubes][to_check] + problem_config_full[vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]] = problem_config + vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4] + + within_range = ( + vol_idx_problem_adj[..., 0] >= 0) & ( + vol_idx_problem_adj[..., 0] < res[0]) & ( + vol_idx_problem_adj[..., 1] >= 0) & ( + vol_idx_problem_adj[..., 1] < res[1]) & ( + vol_idx_problem_adj[..., 2] >= 0) & ( + vol_idx_problem_adj[..., 2] < res[2]) + + vol_idx_problem = vol_idx_problem[within_range] + vol_idx_problem_adj = vol_idx_problem_adj[within_range] + problem_config = problem_config[within_range] + problem_config_adj = problem_config_full[vol_idx_problem_adj[..., 0], + vol_idx_problem_adj[..., 1], vol_idx_problem_adj[..., 2]] + # If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted. + to_invert = (problem_config_adj[..., 0] == 1) + idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][within_range][to_invert] + case_ids.index_put_((idx,), problem_config[to_invert][..., -1]) + return case_ids + + @torch.no_grad() + def _identify_surf_edges(self, s_n, cube_fx8, surf_cubes): + """ + Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge + can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge + and marks the cube edges with this index. + """ + occ_n = s_n < 0 + all_edges = cube_fx8[surf_cubes][:, self.cube_edges].reshape(-1, 2) + unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True) + + unique_edges = unique_edges.long() + mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 + + surf_edges_mask = mask_edges[_idx_map] + counts = counts[_idx_map] + + mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=cube_fx8.device) * -1 + mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_fx8.device) + # Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index + # for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1. + idx_map = mapping[_idx_map] + surf_edges = unique_edges[mask_edges] + return surf_edges, idx_map, counts, surf_edges_mask + + @torch.no_grad() + def _identify_surf_cubes(self, s_n, cube_fx8): + """ + Identifies grid cubes that intersect with the underlying surface by checking if the signs at + all corners are not identical. + """ + occ_n = s_n < 0 + occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8) + _occ_sum = torch.sum(occ_fx8, -1) + surf_cubes = (_occ_sum > 0) & (_occ_sum < 8) + return surf_cubes, occ_fx8 + + def _linear_interp(self, edges_weight, edges_x): + """ + Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'. + """ + edge_dim = edges_weight.dim() - 2 + assert edges_weight.shape[edge_dim] == 2 + edges_weight = torch.cat([torch.index_select(input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim), - + torch.index_select(input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim)], edge_dim) + denominator = edges_weight.sum(edge_dim) + ue = (edges_x * edges_weight).sum(edge_dim) / denominator + return ue + + def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3=None): + p_bxnx3 = p_bxnx3.reshape(-1, 7, 3) + norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3) + c_bx3 = c_bx3.reshape(-1, 3) + A = norm_bxnx3 + B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True) + + A_reg = (torch.eye(3, device=p_bxnx3.device) * self.qef_reg_scale).unsqueeze(0).repeat(p_bxnx3.shape[0], 1, 1) + B_reg = (self.qef_reg_scale * c_bx3).unsqueeze(-1) + A = torch.cat([A, A_reg], 1) + B = torch.cat([B, B_reg], 1) + dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1) + return dual_verts + + def _compute_vd(self, x_nx3, surf_cubes_fx8, surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func): + """ + Computes the location of dual vertices as described in Section 4.2 + """ + alpha_nx12x2 = torch.index_select(input=alpha_fx8, index=self.cube_edges, dim=1).reshape(-1, 12, 2) + surf_edges_x = torch.index_select(input=x_nx3, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 3) + surf_edges_s = torch.index_select(input=s_n, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1) + zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x) + + idx_map = idx_map.reshape(-1, 12) + num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0) + edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = [], [], [], [], [] + + total_num_vd = 0 + vd_idx_map = torch.zeros((case_ids.shape[0], 12), dtype=torch.long, device=self.device, requires_grad=False) + if grad_func is not None: + normals = torch.nn.functional.normalize(grad_func(zero_crossing), dim=-1) + vd = [] + for num in torch.unique(num_vd): + cur_cubes = (num_vd == num) # consider cubes with the same numbers of vd emitted (for batching) + curr_num_vd = cur_cubes.sum() * num + curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(-1, num * 7) + curr_edge_group_to_vd = torch.arange( + curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) + total_num_vd + total_num_vd += curr_num_vd + curr_edge_group_to_cube = torch.arange(idx_map.shape[0], device=self.device)[ + cur_cubes].unsqueeze(-1).repeat(1, num * 7).reshape_as(curr_edge_group) + + curr_mask = (curr_edge_group != -1) + edge_group.append(torch.masked_select(curr_edge_group, curr_mask)) + edge_group_to_vd.append(torch.masked_select(curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask)) + edge_group_to_cube.append(torch.masked_select(curr_edge_group_to_cube, curr_mask)) + vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True)) + vd_gamma.append(torch.masked_select(gamma_f, cur_cubes).unsqueeze(-1).repeat(1, num).reshape(-1)) + + if grad_func is not None: + with torch.no_grad(): + cube_e_verts_idx = idx_map[cur_cubes] + curr_edge_group[~curr_mask] = 0 + + verts_group_idx = torch.gather(input=cube_e_verts_idx, dim=1, index=curr_edge_group) + verts_group_idx[verts_group_idx == -1] = 0 + verts_group_pos = torch.index_select( + input=zero_crossing, index=verts_group_idx.reshape(-1), dim=0).reshape(-1, num.item(), 7, 3) + v0 = x_nx3[surf_cubes_fx8[cur_cubes][:, 0]].reshape(-1, 1, 1, 3).repeat(1, num.item(), 1, 1) + curr_mask = curr_mask.reshape(-1, num.item(), 7, 1) + verts_centroid = (verts_group_pos * curr_mask).sum(2) / (curr_mask.sum(2)) + + normals_bx7x3 = torch.index_select(input=normals, index=verts_group_idx.reshape(-1), dim=0).reshape( + -1, num.item(), 7, + 3) + curr_mask = curr_mask.squeeze(2) + vd.append(self._solve_vd_QEF((verts_group_pos - v0) * curr_mask, normals_bx7x3 * curr_mask, + verts_centroid - v0.squeeze(2)) + v0.reshape(-1, 3)) + edge_group = torch.cat(edge_group) + edge_group_to_vd = torch.cat(edge_group_to_vd) + edge_group_to_cube = torch.cat(edge_group_to_cube) + vd_num_edges = torch.cat(vd_num_edges) + vd_gamma = torch.cat(vd_gamma) + + if grad_func is not None: + vd = torch.cat(vd) + L_dev = torch.zeros([1], device=self.device) + else: + vd = torch.zeros((total_num_vd, 3), device=self.device) + beta_sum = torch.zeros((total_num_vd, 1), device=self.device) + + idx_group = torch.gather(input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group) + + x_group = torch.index_select(input=surf_edges_x, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 3) + s_group = torch.index_select(input=surf_edges_s, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1) + + zero_crossing_group = torch.index_select( + input=zero_crossing, index=idx_group.reshape(-1), dim=0).reshape(-1, 3) + + alpha_group = torch.index_select(input=alpha_nx12x2.reshape(-1, 2), dim=0, + index=edge_group_to_cube * 12 + edge_group).reshape(-1, 2, 1) + ue_group = self._linear_interp(s_group * alpha_group, x_group) + + beta_group = torch.gather(input=beta_fx12.reshape(-1), dim=0, + index=edge_group_to_cube * 12 + edge_group).reshape(-1, 1) + beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group) + vd = vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum + L_dev = self._compute_reg_loss(vd, zero_crossing_group, edge_group_to_vd, vd_num_edges) + + v_idx = torch.arange(vd.shape[0], device=self.device) # + total_num_vd + + vd_idx_map = (vd_idx_map.reshape(-1)).scatter(dim=0, index=edge_group_to_cube * + 12 + edge_group, src=v_idx[edge_group_to_vd]) + + return vd, L_dev, vd_gamma, vd_idx_map + + def _triangulate(self, s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func): + """ + Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into + triangles based on the gamma parameter, as described in Section 4.3. + """ + with torch.no_grad(): + group_mask = (edge_counts == 4) & surf_edges_mask # surface edges shared by 4 cubes. + group = idx_map.reshape(-1)[group_mask] + vd_idx = vd_idx_map[group_mask] + edge_indices, indices = torch.sort(group, stable=True) + quad_vd_idx = vd_idx[indices].reshape(-1, 4) + + # Ensure all face directions point towards the positive SDF to maintain consistent winding. + s_edges = s_n[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)].reshape(-1, 2) + flip_mask = s_edges[:, 0] > 0 + quad_vd_idx = torch.cat((quad_vd_idx[flip_mask][:, [0, 1, 3, 2]], + quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]])) + if grad_func is not None: + # when grad_func is given, split quadrilaterals along the diagonals with more consistent gradients. + with torch.no_grad(): + vd_gamma = torch.nn.functional.normalize(grad_func(vd), dim=-1) + quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3) + gamma_02 = (quad_gamma[:, 0] * quad_gamma[:, 2]).sum(-1, keepdims=True) + gamma_13 = (quad_gamma[:, 1] * quad_gamma[:, 3]).sum(-1, keepdims=True) + else: + quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4) + gamma_02 = torch.index_select(input=quad_gamma, index=torch.tensor( + 0, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(2, device=self.device), dim=1) + gamma_13 = torch.index_select(input=quad_gamma, index=torch.tensor( + 1, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(3, device=self.device), dim=1) + if not training: + mask = (gamma_02 > gamma_13).squeeze(1) + faces = torch.zeros((quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device) + faces[mask] = quad_vd_idx[mask][:, self.quad_split_1] + faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2] + faces = faces.reshape(-1, 3) + else: + vd_quad = torch.index_select(input=vd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3) + vd_02 = (torch.index_select(input=vd_quad, index=torch.tensor(0, device=self.device), dim=1) + + torch.index_select(input=vd_quad, index=torch.tensor(2, device=self.device), dim=1)) / 2 + vd_13 = (torch.index_select(input=vd_quad, index=torch.tensor(1, device=self.device), dim=1) + + torch.index_select(input=vd_quad, index=torch.tensor(3, device=self.device), dim=1)) / 2 + weight_sum = (gamma_02 + gamma_13) + 1e-8 + vd_center = ((vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) / + weight_sum.unsqueeze(-1)).squeeze(1) + vd_center_idx = torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0] + vd = torch.cat([vd, vd_center]) + faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2) + faces = torch.cat([faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1).reshape(-1, 3) + return vd, faces, s_edges, edge_indices + + def _tetrahedralize( + self, x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices, + surf_cubes, training): + """ + Tetrahedralizes the interior volume to produce a tetrahedral mesh, as described in Section 4.5. + """ + occ_n = s_n < 0 + occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8) + occ_sum = torch.sum(occ_fx8, -1) + + inside_verts = x_nx3[occ_n] + mapping_inside_verts = torch.ones((occ_n.shape[0]), dtype=torch.long, device=self.device) * -1 + mapping_inside_verts[occ_n] = torch.arange(occ_n.sum(), device=self.device) + vertices.shape[0] + """ + For each grid edge connecting two grid vertices with different + signs, we first form a four-sided pyramid by connecting one + of the grid vertices with four mesh vertices that correspond + to the grid edge and then subdivide the pyramid into two tetrahedra + """ + inside_verts_idx = mapping_inside_verts[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1, 2)[ + s_edges < 0]] + if not training: + inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 2).reshape(-1) + else: + inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 4).reshape(-1) + + tets_surface = torch.cat([faces, inside_verts_idx.unsqueeze(-1)], -1) + """ + For each grid edge connecting two grid vertices with the + same sign, the tetrahedron is formed by the two grid vertices + and two vertices in consecutive adjacent cells + """ + inside_cubes = (occ_sum == 8) + inside_cubes_center = x_nx3[cube_fx8[inside_cubes].reshape(-1)].reshape(-1, 8, 3).mean(1) + inside_cubes_center_idx = torch.arange( + inside_cubes_center.shape[0], device=inside_cubes.device) + vertices.shape[0] + inside_verts.shape[0] + + surface_n_inside_cubes = surf_cubes | inside_cubes + edge_center_vertex_idx = torch.ones(((surface_n_inside_cubes).sum(), 13), + dtype=torch.long, device=x_nx3.device) * -1 + surf_cubes = surf_cubes[surface_n_inside_cubes] + inside_cubes = inside_cubes[surface_n_inside_cubes] + edge_center_vertex_idx[surf_cubes, :12] = vd_idx_map.reshape(-1, 12) + edge_center_vertex_idx[inside_cubes, 12] = inside_cubes_center_idx + + all_edges = cube_fx8[surface_n_inside_cubes][:, self.cube_edges].reshape(-1, 2) + unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True) + unique_edges = unique_edges.long() + mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 2 + mask = mask_edges[_idx_map] + counts = counts[_idx_map] + mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=self.device) * -1 + mapping[mask_edges] = torch.arange(mask_edges.sum(), device=self.device) + idx_map = mapping[_idx_map] + + group_mask = (counts == 4) & mask + group = idx_map.reshape(-1)[group_mask] + edge_indices, indices = torch.sort(group) + cube_idx = torch.arange((_idx_map.shape[0] // 12), dtype=torch.long, + device=self.device).unsqueeze(1).expand(-1, 12).reshape(-1)[group_mask] + edge_idx = torch.arange((12), dtype=torch.long, device=self.device).unsqueeze( + 0).expand(_idx_map.shape[0] // 12, -1).reshape(-1)[group_mask] + # Identify the face shared by the adjacent cells. + cube_idx_4 = cube_idx[indices].reshape(-1, 4) + edge_dir = self.edge_dir_table[edge_idx[indices]].reshape(-1, 4)[..., 0] + shared_faces_4x2 = self.dir_faces_table[edge_dir].reshape(-1) + cube_idx_4x2 = cube_idx_4[:, self.adj_pairs].reshape(-1) + # Identify an edge of the face with different signs and + # select the mesh vertex corresponding to the identified edge. + case_ids_expand = torch.ones((surface_n_inside_cubes).sum(), dtype=torch.long, device=x_nx3.device) * 255 + case_ids_expand[surf_cubes] = case_ids + cases = case_ids_expand[cube_idx_4x2] + quad_edge = edge_center_vertex_idx[cube_idx_4x2, self.tet_table[cases, shared_faces_4x2]].reshape(-1, 2) + mask = (quad_edge == -1).sum(-1) == 0 + inside_edge = mapping_inside_verts[unique_edges[mask_edges][edge_indices].reshape(-1)].reshape(-1, 2) + tets_inside = torch.cat([quad_edge, inside_edge], -1)[mask] + + tets = torch.cat([tets_surface, tets_inside]) + vertices = torch.cat([vertices, inside_verts, inside_cubes_center]) + return vertices, tets diff --git a/models/lrm/online_render/src/models/geometry/rep_3d/flexicubes_geometry.py b/models/lrm/online_render/src/models/geometry/rep_3d/flexicubes_geometry.py new file mode 100755 index 0000000000000000000000000000000000000000..e1242119d22c0177578f1ef95be7ee7cf5da9b8c --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/rep_3d/flexicubes_geometry.py @@ -0,0 +1,225 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +import numpy as np +import os +import nvdiffrast.torch as dr +from . import Geometry +from .flexicubes import FlexiCubes # replace later +from .dmtet import sdf_reg_loss_batch +from . import mesh +import torch.nn.functional as F +from src.utils import render + +def get_center_boundary_index(grid_res, device): + v = torch.zeros((grid_res + 1, grid_res + 1, grid_res + 1), dtype=torch.bool, device=device) + v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = True + center_indices = torch.nonzero(v.reshape(-1)) + + v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = False + v[:2, ...] = True + v[-2:, ...] = True + v[:, :2, ...] = True + v[:, -2:, ...] = True + v[:, :, :2] = True + v[:, :, -2:] = True + boundary_indices = torch.nonzero(v.reshape(-1)) + return center_indices, boundary_indices + +############################################################################### +# Geometry interface +############################################################################### +class FlexiCubesGeometry(Geometry): + def __init__( + self, grid_res=64, scale=2.0, device='cuda', renderer=None, + render_type='neural_render', args=None): + super(FlexiCubesGeometry, self).__init__() + self.grid_res = grid_res + self.device = device + self.args = args + self.fc = FlexiCubes(device, weight_scale=0.5) + self.verts, self.indices = self.fc.construct_voxel_grid(grid_res) + if isinstance(scale, list): + self.verts[:, 0] = self.verts[:, 0] * scale[0] + self.verts[:, 1] = self.verts[:, 1] * scale[1] + self.verts[:, 2] = self.verts[:, 2] * scale[1] + else: + self.verts = self.verts * scale + + all_edges = self.indices[:, self.fc.cube_edges].reshape(-1, 2) + self.all_edges = torch.unique(all_edges, dim=0) + + # Parameters used for fix boundary sdf + self.center_indices, self.boundary_indices = get_center_boundary_index(self.grid_res, device) + self.renderer = renderer + self.render_type = render_type + self.ctx = dr.RasterizeCudaContext(device=device) + + def getAABB(self): + return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values + + @torch.no_grad() + def map_uv(self, face_gidx, max_idx): + N = int(np.ceil(np.sqrt((max_idx+1)//2))) + tex_y, tex_x = torch.meshgrid( + torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device="cuda"), + torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device="cuda") + ) + + pad = 0.9 / N + + uvs = torch.stack([ + tex_x , tex_y, + tex_x + pad, tex_y, + tex_x + pad, tex_y + pad, + tex_x , tex_y + pad + ], dim=-1).view(-1, 2) + + def _idx(tet_idx, N): + x = tet_idx % N + y = torch.div(tet_idx, N, rounding_mode='floor') + return y * N + x + + tet_idx = _idx(torch.div(face_gidx, N, rounding_mode='floor'), N) + tri_idx = face_gidx % 2 + + uv_idx = torch.stack(( + tet_idx * 4, tet_idx * 4 + tri_idx + 1, tet_idx * 4 + tri_idx + 2 + ), dim = -1). view(-1, 3) + + return uvs, uv_idx + + def rotate_x(self, a, device=None): + s, c = np.sin(a), np.cos(a) + return torch.tensor([[1, 0, 0, 0], + [0, c,-s, 0], + [0, s, c, 0], + [0, 0, 0, 1]], dtype=torch.float32, device=device) + def rotate_z(self, a, device=None): + s, c = np.sin(a), np.cos(a) + return torch.tensor([[ c, -s, 0, 0], + [ s, c, 0, 0], + [ 0, 0, 1, 0], + [ 0, 0, 0, 1]], dtype=torch.float32, device=device) + def rotate_y(self, a, device=None): + s, c = np.sin(a), np.cos(a) + return torch.tensor([[ c, 0, s, 0], + [ 0, 1, 0, 0], + [-s, 0, c, 0], + [ 0, 0, 0, 1]], dtype=torch.float32, device=device) + + + def get_mesh(self, v_deformed_nx3, sdf_n, weight_n=None, with_uv=False, indices=None, is_training=False): + if indices is None: + indices = self.indices + + verts, faces, v_reg_loss = self.fc(v_deformed_nx3, sdf_n, indices, self.grid_res, + beta_fx12=weight_n[:, :12], alpha_fx8=weight_n[:, 12:20], + gamma_f=weight_n[:, 20], training=is_training + ) + + face_gidx = torch.arange(faces.shape[0], dtype=torch.long, device="cuda") + uvs, uv_idx = self.map_uv(face_gidx, faces.shape[0]) + # breakpoint() + + verts = verts @ self.rotate_x(np.pi / 2, device=verts.device)[:3,:3] + verts = verts @ self.rotate_y(np.pi / 2, device=verts.device)[:3,:3] + + + + imesh = mesh.Mesh(verts, faces, v_tex=uvs, t_tex_idx=uv_idx) + + # Run mesh operations to generate tangent space + imesh = mesh.auto_normals(imesh) + imesh = mesh.compute_tangents(imesh) + + return verts, faces, v_reg_loss, imesh + + + # def render_mesh(self, mesh_v_nx3, mesh_f_fx3, mesh, camera_mv_bx4x4, camera_pos, resolution=256, hierarchical_mask=False): + # return_value = dict() + # if self.render_type == 'neural_render': + # tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal, gb_normal = self.renderer.render_mesh( + # mesh_v_nx3.unsqueeze(dim=0), + # mesh_f_fx3.int(), + # mesh, + # camera_mv_bx4x4, + # camera_pos, + # mesh_v_nx3.unsqueeze(dim=0), + # resolution=resolution, + # device=self.device, + # hierarchical_mask=hierarchical_mask + # ) + + # return_value['tex_pos'] = tex_pos + # return_value['mask'] = mask + # return_value['hard_mask'] = hard_mask + # return_value['rast'] = rast + # return_value['v_pos_clip'] = v_pos_clip + # return_value['mask_pyramid'] = mask_pyramid + # return_value['depth'] = depth + # return_value['normal'] = normal + # return_value['gb_normal'] = gb_normal + # else: + # raise NotImplementedError + + # return return_value + + def render_mesh(self, mesh_v_nx3, mesh_f_fx3, mesh, camera_mv_bx4x4, camera_pos, env, planes, kd_fn, materials, resolution=256, hierarchical_mask=False): + return_value = dict() + # if self.render_type == 'neural_render': + # tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal, gb_normal = self.renderer.render_mesh( + # mesh_v_nx3.unsqueeze(dim=0), + # mesh_f_fx3.int(), + # mesh, + # camera_mv_bx4x4, + # camera_pos, + # mesh_v_nx3.unsqueeze(dim=0), + # resolution=resolution, + # device=self.device, + # hierarchical_mask=hierarchical_mask + # ) + + # return_value['tex_pos'] = tex_pos + # return_value['mask'] = mask + # return_value['hard_mask'] = hard_mask + # return_value['rast'] = rast + # return_value['v_pos_clip'] = v_pos_clip + # return_value['mask_pyramid'] = mask_pyramid + # return_value['depth'] = depth + # return_value['normal'] = normal + # return_value['gb_normal'] = gb_normal + # else: + # raise NotImplementedError + buffer_dict = render.render_mesh(self.ctx, mesh, camera_mv_bx4x4, camera_pos, env, planes, kd_fn, materials, [resolution, resolution], spp=1, num_layers=1, msaa=True, background=None) + + return buffer_dict + + + def render(self, v_deformed_bxnx3=None, sdf_bxn=None, camera_mv_bxnviewx4x4=None, resolution=256): + # Here I assume a batch of meshes (can be different mesh and geometry), for the other shapes, the batch is 1 + v_list = [] + f_list = [] + n_batch = v_deformed_bxnx3.shape[0] + all_render_output = [] + for i_batch in range(n_batch): + verts_nx3, faces_fx3 = self.get_mesh(v_deformed_bxnx3[i_batch], sdf_bxn[i_batch]) + v_list.append(verts_nx3) + f_list.append(faces_fx3) + render_output = self.render_mesh(verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution) + all_render_output.append(render_output) + + # Concatenate all render output + return_keys = all_render_output[0].keys() + return_value = dict() + for k in return_keys: + value = [v[k] for v in all_render_output] + return_value[k] = value + # We can do concatenation outside of the render + return return_value diff --git a/models/lrm/online_render/src/models/geometry/rep_3d/light.py b/models/lrm/online_render/src/models/geometry/rep_3d/light.py new file mode 100755 index 0000000000000000000000000000000000000000..766ab0a9e4e4fc42f379ac94d765059508cff97e --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/rep_3d/light.py @@ -0,0 +1,158 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +import numpy as np +import torch +import nvdiffrast.torch as dr + +from . import util +from . import renderutils as ru + +###################################################################################### +# Utility functions +###################################################################################### + +class cubemap_mip(torch.autograd.Function): + @staticmethod + def forward(ctx, cubemap): + return util.avg_pool_nhwc(cubemap, (2,2)) + + @staticmethod + def backward(ctx, dout): + res = dout.shape[1] * 2 + out = torch.zeros(6, res, res, dout.shape[-1], dtype=torch.float32, device="cuda") + for s in range(6): + gy, gx = torch.meshgrid(torch.linspace(-1.0 + 1.0 / res, 1.0 - 1.0 / res, res, device="cuda"), + torch.linspace(-1.0 + 1.0 / res, 1.0 - 1.0 / res, res, device="cuda"), + indexing='ij') + v = util.safe_normalize(util.cube_to_dir(s, gx, gy)) + out[s, ...] = dr.texture(dout[None, ...] * 0.25, v[None, ...].contiguous(), filter_mode='linear', boundary_mode='cube') + return out + +###################################################################################### +# Split-sum environment map light source with automatic mipmap generation +###################################################################################### + +class EnvironmentLight(torch.nn.Module): + LIGHT_MIN_RES = 16 + + MIN_ROUGHNESS = 0.08 + MAX_ROUGHNESS = 0.5 + + def __init__(self, base): + super(EnvironmentLight, self).__init__() + self.mtx = None + self.base = torch.nn.Parameter(base.clone().detach(), requires_grad=True) + self.register_parameter('env_base', self.base) + + def xfm(self, mtx): + self.mtx = mtx + + def clone(self): + return EnvironmentLight(self.base.clone().detach()) + + def clamp_(self, min=None, max=None): + self.base.clamp_(min, max) + + def get_mip(self, roughness): + return torch.where(roughness < self.MAX_ROUGHNESS + , (torch.clamp(roughness, self.MIN_ROUGHNESS, self.MAX_ROUGHNESS) - self.MIN_ROUGHNESS) / (self.MAX_ROUGHNESS - self.MIN_ROUGHNESS) * (len(self.specular) - 2) + , (torch.clamp(roughness, self.MAX_ROUGHNESS, 1.0) - self.MAX_ROUGHNESS) / (1.0 - self.MAX_ROUGHNESS) + len(self.specular) - 2) + + def build_mips(self, cutoff=0.99): + self.specular = [self.base] + while self.specular[-1].shape[1] > self.LIGHT_MIN_RES: + self.specular += [cubemap_mip.apply(self.specular[-1])] + + self.diffuse = ru.diffuse_cubemap(self.specular[-1]) + + for idx in range(len(self.specular) - 1): + roughness = (idx / (len(self.specular) - 2)) * (self.MAX_ROUGHNESS - self.MIN_ROUGHNESS) + self.MIN_ROUGHNESS + self.specular[idx] = ru.specular_cubemap(self.specular[idx], roughness, cutoff) + self.specular[-1] = ru.specular_cubemap(self.specular[-1], 1.0, cutoff) + + def regularizer(self): + white = (self.base[..., 0:1] + self.base[..., 1:2] + self.base[..., 2:3]) / 3.0 + return torch.mean(torch.abs(self.base - white)) + + def shade(self, gb_pos, gb_normal, kd, ks, view_pos, specular=True): + wo = util.safe_normalize(view_pos - gb_pos) + + if specular: + roughness = ks[..., 1:2] # y component + metallic = ks[..., 2:3] # z component + spec_col = (1.0 - metallic)*0.04 + kd * metallic + diff_col = kd * (1.0 - metallic) + else: + diff_col = kd + + reflvec = util.safe_normalize(util.reflect(wo, gb_normal)) + nrmvec = gb_normal + if self.mtx is not None: # Rotate lookup + mtx = torch.as_tensor(self.mtx, dtype=torch.float32, device='cuda') + reflvec = ru.xfm_vectors(reflvec.view(reflvec.shape[0], reflvec.shape[1] * reflvec.shape[2], reflvec.shape[3]), mtx).view(*reflvec.shape) + nrmvec = ru.xfm_vectors(nrmvec.view(nrmvec.shape[0], nrmvec.shape[1] * nrmvec.shape[2], nrmvec.shape[3]), mtx).view(*nrmvec.shape) + + # Diffuse lookup + diffuse = dr.texture(self.diffuse[None, ...], nrmvec.contiguous(), filter_mode='linear', boundary_mode='cube') + shaded_col = diffuse * diff_col + + if specular: + # Lookup FG term from lookup texture + NdotV = torch.clamp(util.dot(wo, gb_normal), min=1e-4) + fg_uv = torch.cat((NdotV, roughness), dim=-1) + if not hasattr(self, '_FG_LUT'): + self._FG_LUT = torch.as_tensor(np.fromfile('data/irrmaps/bsdf_256_256.bin', dtype=np.float32).reshape(1, 256, 256, 2), dtype=torch.float32, device='cuda') + fg_lookup = dr.texture(self._FG_LUT, fg_uv, filter_mode='linear', boundary_mode='clamp') + + # Roughness adjusted specular env lookup + miplevel = self.get_mip(roughness) + spec = dr.texture(self.specular[0][None, ...], reflvec.contiguous(), mip=list(m[None, ...] for m in self.specular[1:]), mip_level_bias=miplevel[..., 0], filter_mode='linear-mipmap-linear', boundary_mode='cube') + + # Compute aggregate lighting + reflectance = spec_col * fg_lookup[...,0:1] + fg_lookup[...,1:2] + shaded_col += spec * reflectance + + return shaded_col * (1.0 - ks[..., 0:1]) # Modulate by hemisphere visibility + +###################################################################################### +# Load and store +###################################################################################### + +# Load from latlong .HDR file +def _load_env_hdr(fn, scale=1.0): + latlong_img = torch.tensor(util.load_image(fn), dtype=torch.float32, device='cuda')*scale + cubemap = util.latlong_to_cubemap(latlong_img, [512, 512]) + + l = EnvironmentLight(cubemap) + l.build_mips() + + return l + +def load_env(fn, scale=1.0): + if os.path.splitext(fn)[1].lower() == ".hdr": + return _load_env_hdr(fn, scale) + else: + assert False, "Unknown envlight extension %s" % os.path.splitext(fn)[1] + +def save_env_map(fn, light): + assert isinstance(light, EnvironmentLight), "Can only save EnvironmentLight currently" + if isinstance(light, EnvironmentLight): + color = util.cubemap_to_latlong(light.base, [512, 1024]) + util.save_image_raw(fn, color.detach().cpu().numpy()) + +###################################################################################### +# Create trainable env map with random initialization +###################################################################################### + +def create_trainable_env_rnd(base_res, scale=0.5, bias=0.25): + base = torch.rand(6, base_res, base_res, 3, dtype=torch.float32, device='cuda') * scale + bias + return EnvironmentLight(base) + diff --git a/models/lrm/online_render/src/models/geometry/rep_3d/material.py b/models/lrm/online_render/src/models/geometry/rep_3d/material.py new file mode 100755 index 0000000000000000000000000000000000000000..64772e578493f41e5c94e432d906d9be23325221 --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/rep_3d/material.py @@ -0,0 +1,182 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +import numpy as np +import torch + +from . import util +from . import texture + +###################################################################################### +# Wrapper to make materials behave like a python dict, but register textures as +# torch.nn.Module parameters. +###################################################################################### +class Material(torch.nn.Module): + def __init__(self, mat_dict): + super(Material, self).__init__() + self.mat_keys = set() + for key in mat_dict.keys(): + self.mat_keys.add(key) + self[key] = mat_dict[key] + + def __contains__(self, key): + return hasattr(self, key) + + def __getitem__(self, key): + return getattr(self, key) + + def __setitem__(self, key, val): + self.mat_keys.add(key) + setattr(self, key, val) + + def __delitem__(self, key): + self.mat_keys.remove(key) + delattr(self, key) + + def keys(self): + return self.mat_keys + +###################################################################################### +# .mtl material format loading / storing +###################################################################################### +@torch.no_grad() +def load_mtl(fn, clear_ks=True): + import re + mtl_path = os.path.dirname(fn) + + # Read file + with open(fn, 'r') as f: + lines = f.readlines() + + # Parse materials + materials = [] + for line in lines: + split_line = re.split(' +|\t+|\n+', line.strip()) + prefix = split_line[0].lower() + data = split_line[1:] + if 'newmtl' in prefix: + material = Material({'name' : data[0]}) + materials += [material] + elif materials: + if 'bsdf' in prefix or 'map_kd' in prefix or 'map_ks' in prefix or 'bump' in prefix: + material[prefix] = data[0] + else: + material[prefix] = torch.tensor(tuple(float(d) for d in data), dtype=torch.float32, device='cuda') + + # Convert everything to textures. Our code expects 'kd' and 'ks' to be texture maps. So replace constants with 1x1 maps + for mat in materials: + if not 'bsdf' in mat: + mat['bsdf'] = 'pbr' + + if 'map_kd' in mat: + mat['kd'] = texture.load_texture2D(os.path.join(mtl_path, mat['map_kd'])) + else: + mat['kd'] = texture.Texture2D(mat['kd']) + + if 'map_ks' in mat: + mat['ks'] = texture.load_texture2D(os.path.join(mtl_path, mat['map_ks']), channels=3) + else: + mat['ks'] = texture.Texture2D(mat['ks']) + + if 'bump' in mat: + mat['normal'] = texture.load_texture2D(os.path.join(mtl_path, mat['bump']), lambda_fn=lambda x: x * 2 - 1, channels=3) + + # Convert Kd from sRGB to linear RGB + mat['kd'] = texture.srgb_to_rgb(mat['kd']) + + if clear_ks: + # Override ORM occlusion (red) channel by zeros. We hijack this channel + for mip in mat['ks'].getMips(): + mip[..., 0] = 0.0 + + return materials + +@torch.no_grad() +def save_mtl(fn, material): + folder = os.path.dirname(fn) + with open(fn, "w") as f: + f.write('newmtl defaultMat\n') + if material is not None: + f.write('bsdf %s\n' % material['bsdf']) + if 'kd' in material.keys(): + f.write('map_Kd texture_kd.png\n') + texture.save_texture2D(os.path.join(folder, 'texture_kd.png'), texture.rgb_to_srgb(material['kd'])) + if 'ks' in material.keys(): + f.write('map_Ks texture_ks.png\n') + texture.save_texture2D(os.path.join(folder, 'texture_ks.png'), material['ks']) + if 'normal' in material.keys(): + f.write('bump texture_n.png\n') + texture.save_texture2D(os.path.join(folder, 'texture_n.png'), material['normal'], lambda_fn=lambda x:(util.safe_normalize(x)+1)*0.5) + else: + f.write('Kd 1 1 1\n') + f.write('Ks 0 0 0\n') + f.write('Ka 0 0 0\n') + f.write('Tf 1 1 1\n') + f.write('Ni 1\n') + f.write('Ns 0\n') + +###################################################################################### +# Merge multiple materials into a single uber-material +###################################################################################### + +def _upscale_replicate(x, full_res): + x = x.permute(0, 3, 1, 2) + x = torch.nn.functional.pad(x, (0, full_res[1] - x.shape[3], 0, full_res[0] - x.shape[2]), 'replicate') + return x.permute(0, 2, 3, 1).contiguous() + +def merge_materials(materials, texcoords, tfaces, mfaces): + assert len(materials) > 0 + for mat in materials: + assert mat['bsdf'] == materials[0]['bsdf'], "All materials must have the same BSDF (uber shader)" + assert ('normal' in mat) is ('normal' in materials[0]), "All materials must have either normal map enabled or disabled" + + uber_material = Material({ + 'name' : 'uber_material', + 'bsdf' : materials[0]['bsdf'], + }) + + textures = ['kd', 'ks', 'normal'] + + # Find maximum texture resolution across all materials and textures + max_res = None + for mat in materials: + for tex in textures: + tex_res = np.array(mat[tex].getRes()) if tex in mat else np.array([1, 1]) + max_res = np.maximum(max_res, tex_res) if max_res is not None else tex_res + + # Compute size of compund texture and round up to nearest PoT + full_res = 2**np.ceil(np.log2(max_res * np.array([1, len(materials)]))).astype(np.int) + + # Normalize texture resolution across all materials & combine into a single large texture + for tex in textures: + if tex in materials[0]: + tex_data = torch.cat(tuple(util.scale_img_nhwc(mat[tex].data, tuple(max_res)) for mat in materials), dim=2) # Lay out all textures horizontally, NHWC so dim2 is x + tex_data = _upscale_replicate(tex_data, full_res) + uber_material[tex] = texture.Texture2D(tex_data) + + # Compute scaling values for used / unused texture area + s_coeff = [full_res[0] / max_res[0], full_res[1] / max_res[1]] + + # Recompute texture coordinates to cooincide with new composite texture + new_tverts = {} + new_tverts_data = [] + for fi in range(len(tfaces)): + matIdx = mfaces[fi] + for vi in range(3): + ti = tfaces[fi][vi] + if not (ti in new_tverts): + new_tverts[ti] = {} + if not (matIdx in new_tverts[ti]): # create new vertex + new_tverts_data.append([(matIdx + texcoords[ti][0]) / s_coeff[1], texcoords[ti][1] / s_coeff[0]]) # Offset texture coodrinate (x direction) by material id & scale to local space. Note, texcoords are (u,v) but texture is stored (w,h) so the indexes swap here + new_tverts[ti][matIdx] = len(new_tverts_data) - 1 + tfaces[fi][vi] = new_tverts[ti][matIdx] # reindex vertex + + return uber_material, new_tverts_data, tfaces + diff --git a/models/lrm/online_render/src/models/geometry/rep_3d/mesh.py b/models/lrm/online_render/src/models/geometry/rep_3d/mesh.py new file mode 100755 index 0000000000000000000000000000000000000000..2009b8b938dc251586fbd665bff716f11cf9616b --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/rep_3d/mesh.py @@ -0,0 +1,238 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +import numpy as np +import torch + +from . import obj +from . import util + +###################################################################################### +# Base mesh class +###################################################################################### +class Mesh: + def __init__(self, v_pos=None, t_pos_idx=None, v_nrm=None, t_nrm_idx=None, v_tex=None, t_tex_idx=None, v_tng=None, t_tng_idx=None, material=None, base=None): + self.v_pos = v_pos + self.v_nrm = v_nrm + self.v_tex = v_tex + self.v_tng = v_tng + self.t_pos_idx = t_pos_idx + self.t_nrm_idx = t_nrm_idx + self.t_tex_idx = t_tex_idx + self.t_tng_idx = t_tng_idx + self.material = material + + if base is not None: + self.copy_none(base) + + def copy_none(self, other): + if self.v_pos is None: + self.v_pos = other.v_pos + if self.t_pos_idx is None: + self.t_pos_idx = other.t_pos_idx + if self.v_nrm is None: + self.v_nrm = other.v_nrm + if self.t_nrm_idx is None: + self.t_nrm_idx = other.t_nrm_idx + if self.v_tex is None: + self.v_tex = other.v_tex + if self.t_tex_idx is None: + self.t_tex_idx = other.t_tex_idx + if self.v_tng is None: + self.v_tng = other.v_tng + if self.t_tng_idx is None: + self.t_tng_idx = other.t_tng_idx + if self.material is None: + self.material = other.material + + def clone(self): + out = Mesh(base=self) + if out.v_pos is not None: + out.v_pos = out.v_pos.clone().detach() + if out.t_pos_idx is not None: + out.t_pos_idx = out.t_pos_idx.clone().detach() + if out.v_nrm is not None: + out.v_nrm = out.v_nrm.clone().detach() + if out.t_nrm_idx is not None: + out.t_nrm_idx = out.t_nrm_idx.clone().detach() + if out.v_tex is not None: + out.v_tex = out.v_tex.clone().detach() + if out.t_tex_idx is not None: + out.t_tex_idx = out.t_tex_idx.clone().detach() + if out.v_tng is not None: + out.v_tng = out.v_tng.clone().detach() + if out.t_tng_idx is not None: + out.t_tng_idx = out.t_tng_idx.clone().detach() + return out + +###################################################################################### +# Mesh loeading helper +###################################################################################### + +def load_mesh(filename, mtl_override=None): + name, ext = os.path.splitext(filename) + if ext == ".obj": + return obj.load_obj(filename, clear_ks=True, mtl_override=mtl_override) + assert False, "Invalid mesh file extension" + +###################################################################################### +# Compute AABB +###################################################################################### +def aabb(mesh): + return torch.min(mesh.v_pos, dim=0).values, torch.max(mesh.v_pos, dim=0).values + +###################################################################################### +# Compute unique edge list from attribute/vertex index list +###################################################################################### +def compute_edges(attr_idx, return_inverse=False): + with torch.no_grad(): + # Create all edges, packed by triangle + all_edges = torch.cat(( + torch.stack((attr_idx[:, 0], attr_idx[:, 1]), dim=-1), + torch.stack((attr_idx[:, 1], attr_idx[:, 2]), dim=-1), + torch.stack((attr_idx[:, 2], attr_idx[:, 0]), dim=-1), + ), dim=-1).view(-1, 2) + + # Swap edge order so min index is always first + order = (all_edges[:, 0] > all_edges[:, 1]).long().unsqueeze(dim=1) + sorted_edges = torch.cat(( + torch.gather(all_edges, 1, order), + torch.gather(all_edges, 1, 1 - order) + ), dim=-1) + + # Eliminate duplicates and return inverse mapping + return torch.unique(sorted_edges, dim=0, return_inverse=return_inverse) + +###################################################################################### +# Compute unique edge to face mapping from attribute/vertex index list +###################################################################################### +def compute_edge_to_face_mapping(attr_idx, return_inverse=False): + with torch.no_grad(): + # Get unique edges + # Create all edges, packed by triangle + all_edges = torch.cat(( + torch.stack((attr_idx[:, 0], attr_idx[:, 1]), dim=-1), + torch.stack((attr_idx[:, 1], attr_idx[:, 2]), dim=-1), + torch.stack((attr_idx[:, 2], attr_idx[:, 0]), dim=-1), + ), dim=-1).view(-1, 2) + + # Swap edge order so min index is always first + order = (all_edges[:, 0] > all_edges[:, 1]).long().unsqueeze(dim=1) + sorted_edges = torch.cat(( + torch.gather(all_edges, 1, order), + torch.gather(all_edges, 1, 1 - order) + ), dim=-1) + + # Elliminate duplicates and return inverse mapping + unique_edges, idx_map = torch.unique(sorted_edges, dim=0, return_inverse=True) + + tris = torch.arange(attr_idx.shape[0]).repeat_interleave(3).cuda() + + tris_per_edge = torch.zeros((unique_edges.shape[0], 2), dtype=torch.int64).cuda() + + # Compute edge to face table + mask0 = order[:,0] == 0 + mask1 = order[:,0] == 1 + tris_per_edge[idx_map[mask0], 0] = tris[mask0] + tris_per_edge[idx_map[mask1], 1] = tris[mask1] + + return tris_per_edge + +###################################################################################### +# Align base mesh to reference mesh:move & rescale to match bounding boxes. +###################################################################################### +def unit_size(mesh): + with torch.no_grad(): + vmin, vmax = aabb(mesh) + scale = 2 / torch.max(vmax - vmin).item() + v_pos = mesh.v_pos - (vmax + vmin) / 2 # Center mesh on origin + v_pos = v_pos * scale # Rescale to unit size + + return Mesh(v_pos, base=mesh) + +###################################################################################### +# Center & scale mesh for rendering +###################################################################################### +def center_by_reference(base_mesh, ref_aabb, scale): + center = (ref_aabb[0] + ref_aabb[1]) * 0.5 + scale = scale / torch.max(ref_aabb[1] - ref_aabb[0]).item() + v_pos = (base_mesh.v_pos - center[None, ...]) * scale + return Mesh(v_pos, base=base_mesh) + +###################################################################################### +# Simple smooth vertex normal computation +###################################################################################### +def auto_normals(imesh): + + i0 = imesh.t_pos_idx[:, 0] + i1 = imesh.t_pos_idx[:, 1] + i2 = imesh.t_pos_idx[:, 2] + + v0 = imesh.v_pos[i0, :] + v1 = imesh.v_pos[i1, :] + v2 = imesh.v_pos[i2, :] + + face_normals = torch.cross(v1 - v0, v2 - v0) + + # Splat face normals to vertices + v_nrm = torch.zeros_like(imesh.v_pos) + v_nrm.scatter_add_(0, i0[:, None].repeat(1,3), face_normals) + v_nrm.scatter_add_(0, i1[:, None].repeat(1,3), face_normals) + v_nrm.scatter_add_(0, i2[:, None].repeat(1,3), face_normals) + + # Normalize, replace zero (degenerated) normals with some default value + v_nrm = torch.where(util.dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device='cuda')) + v_nrm = util.safe_normalize(v_nrm) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(v_nrm)) + + return Mesh(v_nrm=v_nrm, t_nrm_idx=imesh.t_pos_idx, base=imesh) + +###################################################################################### +# Compute tangent space from texture map coordinates +# Follows http://www.mikktspace.com/ conventions +###################################################################################### +def compute_tangents(imesh): + vn_idx = [None] * 3 + pos = [None] * 3 + tex = [None] * 3 + for i in range(0,3): + pos[i] = imesh.v_pos[imesh.t_pos_idx[:, i]] + tex[i] = imesh.v_tex[imesh.t_tex_idx[:, i]] + vn_idx[i] = imesh.t_nrm_idx[:, i] + + tangents = torch.zeros_like(imesh.v_nrm) + + # Compute tangent space for each triangle + uve1 = tex[1] - tex[0] + uve2 = tex[2] - tex[0] + pe1 = pos[1] - pos[0] + pe2 = pos[2] - pos[0] + + nom = (pe1 * uve2[..., 1:2] - pe2 * uve1[..., 1:2]) + denom = (uve1[..., 0:1] * uve2[..., 1:2] - uve1[..., 1:2] * uve2[..., 0:1]) + + # Avoid division by zero for degenerated texture coordinates + tang = nom / torch.where(denom > 0.0, torch.clamp(denom, min=1e-6), torch.clamp(denom, max=-1e-6)) + + # Update all 3 vertices + for i in range(0,3): + idx = vn_idx[i][:, None].repeat(1,3) + tangents.scatter_add_(0, idx, tang) # tangents[n_i] = tangents[n_i] + tang + + # Normalize and make sure tangent is perpendicular to normal + tangents = util.safe_normalize(tangents) + tangents = util.safe_normalize(tangents - util.dot(tangents, imesh.v_nrm) * imesh.v_nrm) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(tangents)) + + return Mesh(v_tng=tangents, t_tng_idx=imesh.t_nrm_idx, base=imesh) diff --git a/models/lrm/online_render/src/models/geometry/rep_3d/obj.py b/models/lrm/online_render/src/models/geometry/rep_3d/obj.py new file mode 100755 index 0000000000000000000000000000000000000000..a33fbb9e66c69706ad39049e2ea8e5a7c425971c --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/rep_3d/obj.py @@ -0,0 +1,176 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +import torch + +from . import texture +from . import mesh +from . import material + +###################################################################################### +# Utility functions +###################################################################################### + +def _find_mat(materials, name): + for mat in materials: + if mat['name'] == name: + return mat + return materials[0] # Materials 0 is the default + +###################################################################################### +# Create mesh object from objfile +###################################################################################### + +def load_obj(filename, clear_ks=True, mtl_override=None): + obj_path = os.path.dirname(filename) + + # Read entire file + with open(filename, 'r') as f: + lines = f.readlines() + + # Load materials + all_materials = [ + { + 'name' : '_default_mat', + 'bsdf' : 'pbr', + 'kd' : texture.Texture2D(torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32, device='cuda')), + 'ks' : texture.Texture2D(torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device='cuda')) + } + ] + if mtl_override is None: + for line in lines: + if len(line.split()) == 0: + continue + if line.split()[0] == 'mtllib': + all_materials += material.load_mtl(os.path.join(obj_path, line.split()[1]), clear_ks) # Read in entire material library + else: + all_materials += material.load_mtl(mtl_override) + + # load vertices + vertices, texcoords, normals = [], [], [] + for line in lines: + if len(line.split()) == 0: + continue + + prefix = line.split()[0].lower() + if prefix == 'v': + vertices.append([float(v) for v in line.split()[1:]]) + elif prefix == 'vt': + val = [float(v) for v in line.split()[1:]] + texcoords.append([val[0], 1.0 - val[1]]) + elif prefix == 'vn': + normals.append([float(v) for v in line.split()[1:]]) + + # load faces + activeMatIdx = None + used_materials = [] + faces, tfaces, nfaces, mfaces = [], [], [], [] + for line in lines: + if len(line.split()) == 0: + continue + + prefix = line.split()[0].lower() + if prefix == 'usemtl': # Track used materials + mat = _find_mat(all_materials, line.split()[1]) + if not mat in used_materials: + used_materials.append(mat) + activeMatIdx = used_materials.index(mat) + elif prefix == 'f': # Parse face + vs = line.split()[1:] + nv = len(vs) + vv = vs[0].split('/') + v0 = int(vv[0]) - 1 + t0 = int(vv[1]) - 1 if vv[1] != "" else -1 + n0 = int(vv[2]) - 1 if vv[2] != "" else -1 + for i in range(nv - 2): # Triangulate polygons + vv = vs[i + 1].split('/') + v1 = int(vv[0]) - 1 + t1 = int(vv[1]) - 1 if vv[1] != "" else -1 + n1 = int(vv[2]) - 1 if vv[2] != "" else -1 + vv = vs[i + 2].split('/') + v2 = int(vv[0]) - 1 + t2 = int(vv[1]) - 1 if vv[1] != "" else -1 + n2 = int(vv[2]) - 1 if vv[2] != "" else -1 + mfaces.append(activeMatIdx) + faces.append([v0, v1, v2]) + tfaces.append([t0, t1, t2]) + nfaces.append([n0, n1, n2]) + assert len(tfaces) == len(faces) and len(nfaces) == len (faces) + + # Create an "uber" material by combining all textures into a larger texture + if len(used_materials) > 1: + uber_material, texcoords, tfaces = material.merge_materials(used_materials, texcoords, tfaces, mfaces) + else: + uber_material = used_materials[0] + + vertices = torch.tensor(vertices, dtype=torch.float32, device='cuda') + texcoords = torch.tensor(texcoords, dtype=torch.float32, device='cuda') if len(texcoords) > 0 else None + normals = torch.tensor(normals, dtype=torch.float32, device='cuda') if len(normals) > 0 else None + + faces = torch.tensor(faces, dtype=torch.int64, device='cuda') + tfaces = torch.tensor(tfaces, dtype=torch.int64, device='cuda') if texcoords is not None else None + nfaces = torch.tensor(nfaces, dtype=torch.int64, device='cuda') if normals is not None else None + + return mesh.Mesh(vertices, faces, normals, nfaces, texcoords, tfaces, material=uber_material) + +###################################################################################### +# Save mesh object to objfile +###################################################################################### + +def write_obj(folder, mesh, save_material=True): + obj_file = os.path.join(folder, 'mesh.obj') + print("Writing mesh: ", obj_file) + with open(obj_file, "w") as f: + f.write("mtllib mesh.mtl\n") + f.write("g default\n") + + v_pos = mesh.v_pos.detach().cpu().numpy() if mesh.v_pos is not None else None + v_nrm = mesh.v_nrm.detach().cpu().numpy() if mesh.v_nrm is not None else None + v_tex = mesh.v_tex.detach().cpu().numpy() if mesh.v_tex is not None else None + + t_pos_idx = mesh.t_pos_idx.detach().cpu().numpy() if mesh.t_pos_idx is not None else None + t_nrm_idx = mesh.t_nrm_idx.detach().cpu().numpy() if mesh.t_nrm_idx is not None else None + t_tex_idx = mesh.t_tex_idx.detach().cpu().numpy() if mesh.t_tex_idx is not None else None + + print(" writing %d vertices" % len(v_pos)) + for v in v_pos: + f.write('v {} {} {} \n'.format(v[0], v[1], v[2])) + + if v_tex is not None: + print(" writing %d texcoords" % len(v_tex)) + assert(len(t_pos_idx) == len(t_tex_idx)) + for v in v_tex: + f.write('vt {} {} \n'.format(v[0], 1.0 - v[1])) + + if v_nrm is not None: + print(" writing %d normals" % len(v_nrm)) + assert(len(t_pos_idx) == len(t_nrm_idx)) + for v in v_nrm: + f.write('vn {} {} {}\n'.format(v[0], v[1], v[2])) + + # faces + f.write("s 1 \n") + f.write("g pMesh1\n") + f.write("usemtl defaultMat\n") + + # Write faces + print(" writing %d faces" % len(t_pos_idx)) + for i in range(len(t_pos_idx)): + f.write("f ") + for j in range(3): + f.write(' %s/%s/%s' % (str(t_pos_idx[i][j]+1), '' if v_tex is None else str(t_tex_idx[i][j]+1), '' if v_nrm is None else str(t_nrm_idx[i][j]+1))) + f.write("\n") + + if save_material: + mtl_file = os.path.join(folder, 'mesh.mtl') + print("Writing material: ", mtl_file) + material.save_mtl(mtl_file, mesh.material) + + print("Done exporting mesh") diff --git a/models/lrm/online_render/src/models/geometry/rep_3d/tables.py b/models/lrm/online_render/src/models/geometry/rep_3d/tables.py new file mode 100755 index 0000000000000000000000000000000000000000..5873e7727b5595a1e4fbc3bd10ae5be8f3d06cca --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/rep_3d/tables.py @@ -0,0 +1,791 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. +dmc_table = [ +[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 7, 8, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 9, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 9, 10, 11, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 8, 10, 11, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 8, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 8, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 6, 7, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 5, 6, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 8, 9, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 8, 11, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 6, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1]], +[[0, 2, 4, 5, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 8, 10, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 8, 9, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 6, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 6, 7, 8, 10, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 6, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 6, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 9, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 8, 9, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1]], +[[1, 2, 4, 7, 9, 11, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 9, 10, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 8, 9, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 6, 7, 8, 10, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 6, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 9, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, 6, 7, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 6, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 6, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 6, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 4, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 4, 5, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 5, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]] +] +num_vd_table = [0, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 3, 1, 2, 2, +2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 1, 2, 3, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, +1, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 2, 3, 2, 2, 1, 1, 1, 1, +1, 1, 2, 1, 1, 1, 2, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 2, 2, 2, 2, 1, 3, 4, 2, +2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2, +3, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 3, 2, 3, 2, 4, 2, 2, 2, 2, 1, 2, 1, 2, 1, 1, +2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, +1, 2, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, +1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, +1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0] +check_table = [ +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 194], +[1, -1, 0, 0, 193], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 164], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 161], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 152], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 145], +[1, 0, 0, 1, 144], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 137], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 133], +[1, 0, 1, 0, 132], +[1, 1, 0, 0, 131], +[1, 1, 0, 0, 130], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 100], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 98], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 96], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 88], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 82], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 74], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 72], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 70], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 67], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 65], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 56], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 52], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 44], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 40], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 38], +[1, 0, -1, 0, 37], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 33], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 28], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 26], +[1, 0, 0, -1, 25], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 20], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 18], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 9], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 6], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0] +] +tet_table = [ +[-1, -1, -1, -1, -1, -1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[4, 4, 4, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, -1], +[1, 1, 1, 1, 1, 1], +[4, 4, 4, 4, 4, 4], +[0, 4, 0, 4, 4, -1], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[5, 5, 5, 5, 5, 5], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, -1, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, -1, 2, 4, 4, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 4, 4, 2], +[1, 1, 1, 1, 1, 1], +[2, 4, 2, 4, 4, 2], +[0, 4, 0, 4, 4, 0], +[2, 0, 2, 0, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 5, 2, 5, 5, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 0, 0, 2], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[0, 1, 1, -1, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[4, 1, 1, 4, 4, 1], +[0, 1, 1, 0, 0, 1], +[4, 0, 0, 4, 4, 0], +[2, 2, 2, 2, 2, 2], +[-1, 1, 1, 4, 4, 1], +[0, 1, 1, 4, 4, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[5, 1, 1, 5, 5, 1], +[0, 1, 1, 0, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[8, 8, 8, 8, 8, 8], +[1, 1, 1, 4, 4, 1], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 4, 4, 1], +[0, 4, 0, 4, 4, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 5, 5, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[5, 5, 5, 5, 5, 5], +[6, 6, 6, 6, 6, 6], +[6, -1, 0, 6, 0, 6], +[6, 0, 0, 6, 0, 6], +[6, 1, 1, 6, 1, 6], +[4, 4, 4, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[6, 4, -1, 6, 4, 6], +[6, 4, 0, 6, 4, 6], +[6, 0, 0, 6, 0, 6], +[6, 1, 1, 6, 1, 6], +[5, 5, 5, 5, 5, 5], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 2, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[2, 4, 2, 2, 4, 2], +[0, 4, 0, 4, 4, 0], +[2, 0, 2, 2, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[6, 1, 1, 6, -1, 6], +[6, 1, 1, 6, 0, 6], +[6, 0, 0, 6, 0, 6], +[6, 2, 2, 6, 2, 6], +[4, 1, 1, 4, 4, 1], +[0, 1, 1, 0, 0, 1], +[4, 0, 0, 4, 4, 4], +[2, 2, 2, 2, 2, 2], +[6, 1, 1, 6, 4, 6], +[6, 1, 1, 6, 4, 6], +[6, 0, 0, 6, 0, 6], +[6, 2, 2, 6, 2, 6], +[5, 1, 1, 5, 5, 1], +[0, 1, 1, 0, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[6, 6, 6, 6, 6, 6], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 4, 1], +[0, 4, 0, 4, 4, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 5, 0, 5, 0, 5], +[5, 5, 5, 5, 5, 5], +[5, 5, 5, 5, 5, 5], +[0, 5, 0, 5, 0, 5], +[-1, 5, 0, 5, 0, 5], +[1, 5, 1, 5, 1, 5], +[4, 5, -1, 5, 4, 5], +[0, 5, 0, 5, 0, 5], +[4, 5, 0, 5, 4, 5], +[1, 5, 1, 5, 1, 5], +[4, 4, 4, 4, 4, 4], +[0, 4, 0, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[6, 6, 6, 6, 6, 6], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 5, 2, 5, -1, 5], +[0, 5, 0, 5, 0, 5], +[2, 5, 2, 5, 0, 5], +[1, 5, 1, 5, 1, 5], +[2, 5, 2, 5, 4, 5], +[0, 5, 0, 5, 0, 5], +[2, 5, 2, 5, 4, 5], +[1, 5, 1, 5, 1, 5], +[2, 4, 2, 4, 4, 2], +[0, 4, 0, 4, 4, 4], +[2, 0, 2, 0, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 6, 2, 6, 6, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 0, 0, 2], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[0, 1, 1, 1, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[4, 1, 1, 1, 4, 1], +[0, 1, 1, 1, 0, 1], +[4, 0, 0, 4, 4, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[5, 5, 5, 5, 5, 5], +[1, 1, 1, 1, 4, 1], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[6, 0, 0, 6, 0, 6], +[0, 0, 0, 0, 0, 0], +[6, 6, 6, 6, 6, 6], +[5, 5, 5, 5, 5, 5], +[5, 5, 0, 5, 0, 5], +[5, 5, 0, 5, 0, 5], +[5, 5, 1, 5, 1, 5], +[4, 4, 4, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[4, 4, 0, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[4, 4, 4, 4, 4, 4], +[4, 4, 0, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[8, 8, 8, 8, 8, 8], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 1, 1, 4, 4, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[2, 4, 2, 4, 4, 2], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[5, 5, 5, 5, 5, 5], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[12, 12, 12, 12, 12, 12] +] diff --git a/models/lrm/online_render/src/models/geometry/rep_3d/texture.py b/models/lrm/online_render/src/models/geometry/rep_3d/texture.py new file mode 100755 index 0000000000000000000000000000000000000000..4e4a39d042dc4d356c47133efee897088b9ce5c6 --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/rep_3d/texture.py @@ -0,0 +1,186 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +import numpy as np +import torch +import nvdiffrast.torch as dr + +from . import util + +###################################################################################### +# Smooth pooling / mip computation with linear gradient upscaling +###################################################################################### + +class texture2d_mip(torch.autograd.Function): + @staticmethod + def forward(ctx, texture): + return util.avg_pool_nhwc(texture, (2,2)) + + @staticmethod + def backward(ctx, dout): + gy, gx = torch.meshgrid(torch.linspace(0.0 + 0.25 / dout.shape[1], 1.0 - 0.25 / dout.shape[1], dout.shape[1]*2, device="cuda"), + torch.linspace(0.0 + 0.25 / dout.shape[2], 1.0 - 0.25 / dout.shape[2], dout.shape[2]*2, device="cuda"), + indexing='ij') + uv = torch.stack((gx, gy), dim=-1) + return dr.texture(dout * 0.25, uv[None, ...].contiguous(), filter_mode='linear', boundary_mode='clamp') + +######################################################################################################## +# Simple texture class. A texture can be either +# - A 3D tensor (using auto mipmaps) +# - A list of 3D tensors (full custom mip hierarchy) +######################################################################################################## + +class Texture2D(torch.nn.Module): + # Initializes a texture from image data. + # Input can be constant value (1D array) or texture (3D array) or mip hierarchy (list of 3d arrays) + def __init__(self, init, min_max=None): + super(Texture2D, self).__init__() + + if isinstance(init, np.ndarray): + init = torch.tensor(init, dtype=torch.float32, device='cuda') + elif isinstance(init, list) and len(init) == 1: + init = init[0] + + if isinstance(init, list): + self.data = list(torch.nn.Parameter(mip.clone().detach(), requires_grad=True) for mip in init) + elif len(init.shape) == 4: + self.data = torch.nn.Parameter(init.clone().detach(), requires_grad=True) + elif len(init.shape) == 3: + self.data = torch.nn.Parameter(init[None, ...].clone().detach(), requires_grad=True) + elif len(init.shape) == 1: + self.data = torch.nn.Parameter(init[None, None, None, :].clone().detach(), requires_grad=True) # Convert constant to 1x1 tensor + else: + assert False, "Invalid texture object" + + self.min_max = min_max + + # Filtered (trilinear) sample texture at a given location + def sample(self, texc, texc_deriv, filter_mode='linear-mipmap-linear'): + if isinstance(self.data, list): + out = dr.texture(self.data[0], texc, texc_deriv, mip=self.data[1:], filter_mode=filter_mode) + else: + if self.data.shape[1] > 1 and self.data.shape[2] > 1: + mips = [self.data] + while mips[-1].shape[1] > 1 and mips[-1].shape[2] > 1: + mips += [texture2d_mip.apply(mips[-1])] + out = dr.texture(mips[0], texc, texc_deriv, mip=mips[1:], filter_mode=filter_mode) + else: + out = dr.texture(self.data, texc, texc_deriv, filter_mode=filter_mode) + return out + + def getRes(self): + return self.getMips()[0].shape[1:3] + + def getChannels(self): + return self.getMips()[0].shape[3] + + def getMips(self): + if isinstance(self.data, list): + return self.data + else: + return [self.data] + + # In-place clamp with no derivative to make sure values are in valid range after training + def clamp_(self): + if self.min_max is not None: + for mip in self.getMips(): + for i in range(mip.shape[-1]): + mip[..., i].clamp_(min=self.min_max[0][i], max=self.min_max[1][i]) + + # In-place clamp with no derivative to make sure values are in valid range after training + def normalize_(self): + with torch.no_grad(): + for mip in self.getMips(): + mip = util.safe_normalize(mip) + +######################################################################################################## +# Helper function to create a trainable texture from a regular texture. The trainable weights are +# initialized with texture data as an initial guess +######################################################################################################## + +def create_trainable(init, res=None, auto_mipmaps=True, min_max=None): + with torch.no_grad(): + if isinstance(init, Texture2D): + assert isinstance(init.data, torch.Tensor) + min_max = init.min_max if min_max is None else min_max + init = init.data + elif isinstance(init, np.ndarray): + init = torch.tensor(init, dtype=torch.float32, device='cuda') + + # Pad to NHWC if needed + if len(init.shape) == 1: # Extend constant to NHWC tensor + init = init[None, None, None, :] + elif len(init.shape) == 3: + init = init[None, ...] + + # Scale input to desired resolution. + if res is not None: + init = util.scale_img_nhwc(init, res) + + # Genreate custom mipchain + if not auto_mipmaps: + mip_chain = [init.clone().detach().requires_grad_(True)] + while mip_chain[-1].shape[1] > 1 or mip_chain[-1].shape[2] > 1: + new_size = [max(mip_chain[-1].shape[1] // 2, 1), max(mip_chain[-1].shape[2] // 2, 1)] + mip_chain += [util.scale_img_nhwc(mip_chain[-1], new_size)] + return Texture2D(mip_chain, min_max=min_max) + else: + return Texture2D(init, min_max=min_max) + +######################################################################################################## +# Convert texture to and from SRGB +######################################################################################################## + +def srgb_to_rgb(texture): + return Texture2D(list(util.srgb_to_rgb(mip) for mip in texture.getMips())) + +def rgb_to_srgb(texture): + return Texture2D(list(util.rgb_to_srgb(mip) for mip in texture.getMips())) + +######################################################################################################## +# Utility functions for loading / storing a texture +######################################################################################################## + +def _load_mip2D(fn, lambda_fn=None, channels=None): + imgdata = torch.tensor(util.load_image(fn), dtype=torch.float32, device='cuda') + if channels is not None: + imgdata = imgdata[..., 0:channels] + if lambda_fn is not None: + imgdata = lambda_fn(imgdata) + return imgdata.detach().clone() + +def load_texture2D(fn, lambda_fn=None, channels=None): + base, ext = os.path.splitext(fn) + if os.path.exists(base + "_0" + ext): + mips = [] + while os.path.exists(base + ("_%d" % len(mips)) + ext): + mips += [_load_mip2D(base + ("_%d" % len(mips)) + ext, lambda_fn, channels)] + return Texture2D(mips) + else: + return Texture2D(_load_mip2D(fn, lambda_fn, channels)) + +def _save_mip2D(fn, mip, mipidx, lambda_fn): + if lambda_fn is not None: + data = lambda_fn(mip).detach().cpu().numpy() + else: + data = mip.detach().cpu().numpy() + + if mipidx is None: + util.save_image(fn, data) + else: + base, ext = os.path.splitext(fn) + util.save_image(base + ("_%d" % mipidx) + ext, data) + +def save_texture2D(fn, tex, lambda_fn=None): + if isinstance(tex.data, list): + for i, mip in enumerate(tex.data): + _save_mip2D(fn, mip[0,...], i, lambda_fn) + else: + _save_mip2D(fn, tex.data[0,...], None, lambda_fn) diff --git a/models/lrm/online_render/src/models/geometry/rep_3d/util.py b/models/lrm/online_render/src/models/geometry/rep_3d/util.py new file mode 100755 index 0000000000000000000000000000000000000000..c4e512ad110849ec3ed6344b53f9c422fc303096 --- /dev/null +++ b/models/lrm/online_render/src/models/geometry/rep_3d/util.py @@ -0,0 +1,466 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +import numpy as np +import torch +import nvdiffrast.torch as dr +import imageio + +#---------------------------------------------------------------------------- +# Vector operations +#---------------------------------------------------------------------------- + +def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return torch.sum(x*y, -1, keepdim=True) + +def reflect(x: torch.Tensor, n: torch.Tensor) -> torch.Tensor: + return 2*dot(x, n)*n - x + +def length(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor: + return torch.sqrt(torch.clamp(dot(x,x), min=eps)) # Clamp to avoid nan gradients because grad(sqrt(0)) = NaN + +def safe_normalize(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor: + return x / length(x, eps) + +def to_hvec(x: torch.Tensor, w: float) -> torch.Tensor: + return torch.nn.functional.pad(x, pad=(0,1), mode='constant', value=w) + +#---------------------------------------------------------------------------- +# sRGB color transforms +#---------------------------------------------------------------------------- + +def _rgb_to_srgb(f: torch.Tensor) -> torch.Tensor: + return torch.where(f <= 0.0031308, f * 12.92, torch.pow(torch.clamp(f, 0.0031308), 1.0/2.4)*1.055 - 0.055) + +def rgb_to_srgb(f: torch.Tensor) -> torch.Tensor: + assert f.shape[-1] == 3 or f.shape[-1] == 4 + out = torch.cat((_rgb_to_srgb(f[..., 0:3]), f[..., 3:4]), dim=-1) if f.shape[-1] == 4 else _rgb_to_srgb(f) + assert out.shape[0] == f.shape[0] and out.shape[1] == f.shape[1] and out.shape[2] == f.shape[2] + return out + +def _srgb_to_rgb(f: torch.Tensor) -> torch.Tensor: + return torch.where(f <= 0.04045, f / 12.92, torch.pow((torch.clamp(f, 0.04045) + 0.055) / 1.055, 2.4)) + +def srgb_to_rgb(f: torch.Tensor) -> torch.Tensor: + assert f.shape[-1] == 3 or f.shape[-1] == 4 + out = torch.cat((_srgb_to_rgb(f[..., 0:3]), f[..., 3:4]), dim=-1) if f.shape[-1] == 4 else _srgb_to_rgb(f) + assert out.shape[0] == f.shape[0] and out.shape[1] == f.shape[1] and out.shape[2] == f.shape[2] + return out + +def reinhard(f: torch.Tensor) -> torch.Tensor: + return f/(1+f) + +#----------------------------------------------------------------------------------- +# Metrics (taken from jaxNerf source code, in order to replicate their measurements) +# +# https://github.com/google-research/google-research/blob/301451a62102b046bbeebff49a760ebeec9707b8/jaxnerf/nerf/utils.py#L266 +# +#----------------------------------------------------------------------------------- + +def mse_to_psnr(mse): + """Compute PSNR given an MSE (we assume the maximum pixel value is 1).""" + return -10. / np.log(10.) * np.log(mse) + +def psnr_to_mse(psnr): + """Compute MSE given a PSNR (we assume the maximum pixel value is 1).""" + return np.exp(-0.1 * np.log(10.) * psnr) + +#---------------------------------------------------------------------------- +# Displacement texture lookup +#---------------------------------------------------------------------------- + +def get_miplevels(texture: np.ndarray) -> float: + minDim = min(texture.shape[0], texture.shape[1]) + return np.floor(np.log2(minDim)) + +def tex_2d(tex_map : torch.Tensor, coords : torch.Tensor, filter='nearest') -> torch.Tensor: + tex_map = tex_map[None, ...] # Add batch dimension + tex_map = tex_map.permute(0, 3, 1, 2) # NHWC -> NCHW + tex = torch.nn.functional.grid_sample(tex_map, coords[None, None, ...] * 2 - 1, mode=filter, align_corners=False) + tex = tex.permute(0, 2, 3, 1) # NCHW -> NHWC + return tex[0, 0, ...] + +#---------------------------------------------------------------------------- +# Cubemap utility functions +#---------------------------------------------------------------------------- + +def cube_to_dir(s, x, y): + if s == 0: rx, ry, rz = torch.ones_like(x), -y, -x + elif s == 1: rx, ry, rz = -torch.ones_like(x), -y, x + elif s == 2: rx, ry, rz = x, torch.ones_like(x), y + elif s == 3: rx, ry, rz = x, -torch.ones_like(x), -y + elif s == 4: rx, ry, rz = x, -y, torch.ones_like(x) + elif s == 5: rx, ry, rz = -x, -y, -torch.ones_like(x) + return torch.stack((rx, ry, rz), dim=-1) + +def latlong_to_cubemap(latlong_map, res): + cubemap = torch.zeros(6, res[0], res[1], latlong_map.shape[-1], dtype=torch.float32, device='cuda') + for s in range(6): + gy, gx = torch.meshgrid(torch.linspace(-1.0 + 1.0 / res[0], 1.0 - 1.0 / res[0], res[0], device='cuda'), + torch.linspace(-1.0 + 1.0 / res[1], 1.0 - 1.0 / res[1], res[1], device='cuda'), + indexing='ij') + v = safe_normalize(cube_to_dir(s, gx, gy)) + + tu = torch.atan2(v[..., 0:1], -v[..., 2:3]) / (2 * np.pi) + 0.5 + tv = torch.acos(torch.clamp(v[..., 1:2], min=-1, max=1)) / np.pi + texcoord = torch.cat((tu, tv), dim=-1) + + cubemap[s, ...] = dr.texture(latlong_map[None, ...], texcoord[None, ...], filter_mode='linear')[0] + return cubemap + +def cubemap_to_latlong(cubemap, res): + gy, gx = torch.meshgrid(torch.linspace( 0.0 + 1.0 / res[0], 1.0 - 1.0 / res[0], res[0], device='cuda'), + torch.linspace(-1.0 + 1.0 / res[1], 1.0 - 1.0 / res[1], res[1], device='cuda'), + indexing='ij') + + sintheta, costheta = torch.sin(gy*np.pi), torch.cos(gy*np.pi) + sinphi, cosphi = torch.sin(gx*np.pi), torch.cos(gx*np.pi) + + reflvec = torch.stack(( + sintheta*sinphi, + costheta, + -sintheta*cosphi + ), dim=-1) + return dr.texture(cubemap[None, ...], reflvec[None, ...].contiguous(), filter_mode='linear', boundary_mode='cube')[0] + +#---------------------------------------------------------------------------- +# Image scaling +#---------------------------------------------------------------------------- + +def scale_img_hwc(x : torch.Tensor, size, mag='bilinear', min='area') -> torch.Tensor: + return scale_img_nhwc(x[None, ...], size, mag, min)[0] + +def scale_img_nhwc(x : torch.Tensor, size, mag='bilinear', min='area') -> torch.Tensor: + size = tuple(int(s) for s in size) + assert (x.shape[1] >= size[0] and x.shape[2] >= size[1]) or (x.shape[1] < size[0] and x.shape[2] < size[1]), "Trying to magnify image in one dimension and minify in the other" + y = x.permute(0, 3, 1, 2) # NHWC -> NCHW + if x.shape[1] > size[0] and x.shape[2] > size[1]: # Minification, previous size was bigger + y = torch.nn.functional.interpolate(y, size, mode=min) + else: # Magnification + if mag == 'bilinear' or mag == 'bicubic': + y = torch.nn.functional.interpolate(y, size, mode=mag, align_corners=True) + else: + y = torch.nn.functional.interpolate(y, size, mode=mag) + return y.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC + +def avg_pool_nhwc(x : torch.Tensor, size) -> torch.Tensor: + y = x.permute(0, 3, 1, 2) # NHWC -> NCHW + y = torch.nn.functional.avg_pool2d(y, size) + return y.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC + +#---------------------------------------------------------------------------- +# Behaves similar to tf.segment_sum +#---------------------------------------------------------------------------- + +def segment_sum(data: torch.Tensor, segment_ids: torch.Tensor) -> torch.Tensor: + num_segments = torch.unique_consecutive(segment_ids).shape[0] + + # Repeats ids until same dimension as data + if len(segment_ids.shape) == 1: + s = torch.prod(torch.tensor(data.shape[1:], dtype=torch.int64, device='cuda')).long() + segment_ids = segment_ids.repeat_interleave(s).view(segment_ids.shape[0], *data.shape[1:]) + + assert data.shape == segment_ids.shape, "data.shape and segment_ids.shape should be equal" + + shape = [num_segments] + list(data.shape[1:]) + result = torch.zeros(*shape, dtype=torch.float32, device='cuda') + result = result.scatter_add(0, segment_ids, data) + return result + +#---------------------------------------------------------------------------- +# Matrix helpers. +#---------------------------------------------------------------------------- + +def fovx_to_fovy(fovx, aspect): + return np.arctan(np.tan(fovx / 2) / aspect) * 2.0 + +def focal_length_to_fovy(focal_length, sensor_height): + return 2 * np.arctan(0.5 * sensor_height / focal_length) + +# Reworked so this matches gluPerspective / glm::perspective, using fovy +def perspective(fovy=0.7854, aspect=1.0, n=0.1, f=1000.0, device=None): + y = np.tan(fovy / 2) + return torch.tensor([[1/(y*aspect), 0, 0, 0], + [ 0, 1/-y, 0, 0], + [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)], + [ 0, 0, -1, 0]], dtype=torch.float32, device=device) + +# Reworked so this matches gluPerspective / glm::perspective, using fovy +def perspective_offcenter(fovy, fraction, rx, ry, aspect=1.0, n=0.1, f=1000.0, device=None): + y = np.tan(fovy / 2) + + # Full frustum + R, L = aspect*y, -aspect*y + T, B = y, -y + + # Create a randomized sub-frustum + width = (R-L)*fraction + height = (T-B)*fraction + xstart = (R-L)*rx + ystart = (T-B)*ry + + l = L + xstart + r = l + width + b = B + ystart + t = b + height + + # https://www.scratchapixel.com/lessons/3d-basic-rendering/perspective-and-orthographic-projection-matrix/opengl-perspective-projection-matrix + return torch.tensor([[2/(r-l), 0, (r+l)/(r-l), 0], + [ 0, -2/(t-b), (t+b)/(t-b), 0], + [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)], + [ 0, 0, -1, 0]], dtype=torch.float32, device=device) + +def translate(x, y, z, device=None): + return torch.tensor([[1, 0, 0, x], + [0, 1, 0, y], + [0, 0, 1, z], + [0, 0, 0, 1]], dtype=torch.float32, device=device) + +def rotate_x(a, device=None): + s, c = np.sin(a), np.cos(a) + return torch.tensor([[1, 0, 0, 0], + [0, c,-s, 0], + [0, s, c, 0], + [0, 0, 0, 1]], dtype=torch.float32, device=device) + +def rotate_y(a, device=None): + s, c = np.sin(a), np.cos(a) + return torch.tensor([[ c, 0, s, 0], + [ 0, 1, 0, 0], + [-s, 0, c, 0], + [ 0, 0, 0, 1]], dtype=torch.float32, device=device) + +def scale(s, device=None): + return torch.tensor([[ s, 0, 0, 0], + [ 0, s, 0, 0], + [ 0, 0, s, 0], + [ 0, 0, 0, 1]], dtype=torch.float32, device=device) + +def lookAt(eye, at, up): + a = eye - at + w = a / torch.linalg.norm(a) + u = torch.cross(up, w) + u = u / torch.linalg.norm(u) + v = torch.cross(w, u) + translate = torch.tensor([[1, 0, 0, -eye[0]], + [0, 1, 0, -eye[1]], + [0, 0, 1, -eye[2]], + [0, 0, 0, 1]], dtype=eye.dtype, device=eye.device) + rotate = torch.tensor([[u[0], u[1], u[2], 0], + [v[0], v[1], v[2], 0], + [w[0], w[1], w[2], 0], + [0, 0, 0, 1]], dtype=eye.dtype, device=eye.device) + return rotate @ translate + +@torch.no_grad() +def random_rotation_translation(t, device=None): + m = np.random.normal(size=[3, 3]) + m[1] = np.cross(m[0], m[2]) + m[2] = np.cross(m[0], m[1]) + m = m / np.linalg.norm(m, axis=1, keepdims=True) + m = np.pad(m, [[0, 1], [0, 1]], mode='constant') + m[3, 3] = 1.0 + m[:3, 3] = np.random.uniform(-t, t, size=[3]) + return torch.tensor(m, dtype=torch.float32, device=device) + +@torch.no_grad() +def random_rotation(device=None): + m = np.random.normal(size=[3, 3]) + m[1] = np.cross(m[0], m[2]) + m[2] = np.cross(m[0], m[1]) + m = m / np.linalg.norm(m, axis=1, keepdims=True) + m = np.pad(m, [[0, 1], [0, 1]], mode='constant') + m[3, 3] = 1.0 + m[:3, 3] = np.array([0,0,0]).astype(np.float32) + return torch.tensor(m, dtype=torch.float32, device=device) + +#---------------------------------------------------------------------------- +# Compute focal points of a set of lines using least squares. +# handy for poorly centered datasets +#---------------------------------------------------------------------------- + +def lines_focal(o, d): + d = safe_normalize(d) + I = torch.eye(3, dtype=o.dtype, device=o.device) + S = torch.sum(d[..., None] @ torch.transpose(d[..., None], 1, 2) - I[None, ...], dim=0) + C = torch.sum((d[..., None] @ torch.transpose(d[..., None], 1, 2) - I[None, ...]) @ o[..., None], dim=0).squeeze(1) + return torch.linalg.pinv(S) @ C + +#---------------------------------------------------------------------------- +# Cosine sample around a vector N +#---------------------------------------------------------------------------- +@torch.no_grad() +def cosine_sample(N, size=None): + # construct local frame + N = N/torch.linalg.norm(N) + + dx0 = torch.tensor([0, N[2], -N[1]], dtype=N.dtype, device=N.device) + dx1 = torch.tensor([-N[2], 0, N[0]], dtype=N.dtype, device=N.device) + + dx = torch.where(dot(dx0, dx0) > dot(dx1, dx1), dx0, dx1) + #dx = dx0 if np.dot(dx0,dx0) > np.dot(dx1,dx1) else dx1 + dx = dx / torch.linalg.norm(dx) + dy = torch.cross(N,dx) + dy = dy / torch.linalg.norm(dy) + + # cosine sampling in local frame + if size is None: + phi = 2.0 * np.pi * np.random.uniform() + s = np.random.uniform() + else: + phi = 2.0 * np.pi * torch.rand(*size, 1, dtype=N.dtype, device=N.device) + s = torch.rand(*size, 1, dtype=N.dtype, device=N.device) + costheta = np.sqrt(s) + sintheta = np.sqrt(1.0 - s) + + # cartesian vector in local space + x = np.cos(phi)*sintheta + y = np.sin(phi)*sintheta + z = costheta + + # local to world + return dx*x + dy*y + N*z + +#---------------------------------------------------------------------------- +# Bilinear downsample by 2x. +#---------------------------------------------------------------------------- + +def bilinear_downsample(x : torch.tensor) -> torch.Tensor: + w = torch.tensor([[1, 3, 3, 1], [3, 9, 9, 3], [3, 9, 9, 3], [1, 3, 3, 1]], dtype=torch.float32, device=x.device) / 64.0 + w = w.expand(x.shape[-1], 1, 4, 4) + x = torch.nn.functional.conv2d(x.permute(0, 3, 1, 2), w, padding=1, stride=2, groups=x.shape[-1]) + return x.permute(0, 2, 3, 1) + +#---------------------------------------------------------------------------- +# Bilinear downsample log(spp) steps +#---------------------------------------------------------------------------- + +def bilinear_downsample(x : torch.tensor, spp) -> torch.Tensor: + w = torch.tensor([[1, 3, 3, 1], [3, 9, 9, 3], [3, 9, 9, 3], [1, 3, 3, 1]], dtype=torch.float32, device=x.device) / 64.0 + g = x.shape[-1] + w = w.expand(g, 1, 4, 4) + x = x.permute(0, 3, 1, 2) # NHWC -> NCHW + steps = int(np.log2(spp)) + for _ in range(steps): + xp = torch.nn.functional.pad(x, (1,1,1,1), mode='replicate') + x = torch.nn.functional.conv2d(xp, w, padding=0, stride=2, groups=g) + return x.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC + +#---------------------------------------------------------------------------- +# Singleton initialize GLFW +#---------------------------------------------------------------------------- + +_glfw_initialized = False +def init_glfw(): + global _glfw_initialized + try: + import glfw + glfw.ERROR_REPORTING = 'raise' + glfw.default_window_hints() + glfw.window_hint(glfw.VISIBLE, glfw.FALSE) + test = glfw.create_window(8, 8, "Test", None, None) # Create a window and see if not initialized yet + except glfw.GLFWError as e: + if e.error_code == glfw.NOT_INITIALIZED: + glfw.init() + _glfw_initialized = True + +#---------------------------------------------------------------------------- +# Image display function using OpenGL. +#---------------------------------------------------------------------------- + +_glfw_window = None +def display_image(image, title=None): + # Import OpenGL + import OpenGL.GL as gl + import glfw + + # Zoom image if requested. + image = np.asarray(image[..., 0:3]) if image.shape[-1] == 4 else np.asarray(image) + height, width, channels = image.shape + + # Initialize window. + init_glfw() + if title is None: + title = 'Debug window' + global _glfw_window + if _glfw_window is None: + glfw.default_window_hints() + _glfw_window = glfw.create_window(width, height, title, None, None) + glfw.make_context_current(_glfw_window) + glfw.show_window(_glfw_window) + glfw.swap_interval(0) + else: + glfw.make_context_current(_glfw_window) + glfw.set_window_title(_glfw_window, title) + glfw.set_window_size(_glfw_window, width, height) + + # Update window. + glfw.poll_events() + gl.glClearColor(0, 0, 0, 1) + gl.glClear(gl.GL_COLOR_BUFFER_BIT) + gl.glWindowPos2f(0, 0) + gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1) + gl_format = {3: gl.GL_RGB, 2: gl.GL_RG, 1: gl.GL_LUMINANCE}[channels] + gl_dtype = {'uint8': gl.GL_UNSIGNED_BYTE, 'float32': gl.GL_FLOAT}[image.dtype.name] + gl.glDrawPixels(width, height, gl_format, gl_dtype, image[::-1]) + glfw.swap_buffers(_glfw_window) + if glfw.window_should_close(_glfw_window): + return False + return True + +#---------------------------------------------------------------------------- +# Image save/load helper. +#---------------------------------------------------------------------------- + +def save_image(fn, x : np.ndarray): + try: + if os.path.splitext(fn)[1] == ".png": + imageio.imwrite(fn, np.clip(np.rint(x * 255.0), 0, 255).astype(np.uint8), compress_level=3) # Low compression for faster saving + else: + imageio.imwrite(fn, np.clip(np.rint(x * 255.0), 0, 255).astype(np.uint8)) + except: + print("WARNING: FAILED to save image %s" % fn) + +def save_image_raw(fn, x : np.ndarray): + try: + imageio.imwrite(fn, x) + except: + print("WARNING: FAILED to save image %s" % fn) + + +def load_image_raw(fn) -> np.ndarray: + return imageio.imread(fn) + +def load_image(fn) -> np.ndarray: + img = load_image_raw(fn) + if img.dtype == np.float32: # HDR image + return img + else: # LDR image + return img.astype(np.float32) / 255 + +#---------------------------------------------------------------------------- + +def time_to_text(x): + if x > 3600: + return "%.2f h" % (x / 3600) + elif x > 60: + return "%.2f m" % (x / 60) + else: + return "%.2f s" % x + +#---------------------------------------------------------------------------- + +def checkerboard(res, checker_size) -> np.ndarray: + tiles_y = (res[0] + (checker_size*2) - 1) // (checker_size*2) + tiles_x = (res[1] + (checker_size*2) - 1) // (checker_size*2) + check = np.kron([[1, 0] * tiles_x, [0, 1] * tiles_x] * tiles_y, np.ones((checker_size, checker_size)))*0.33 + 0.33 + check = check[:res[0], :res[1]] + return np.stack((check, check, check), axis=-1) + diff --git a/models/lrm/online_render/src/models/renderer/__init__.py b/models/lrm/online_render/src/models/renderer/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..2c772e4fa331c678cfff50884be94d7d31835b34 --- /dev/null +++ b/models/lrm/online_render/src/models/renderer/__init__.py @@ -0,0 +1,9 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. diff --git a/models/lrm/online_render/src/models/renderer/synthesizer.py b/models/lrm/online_render/src/models/renderer/synthesizer.py new file mode 100755 index 0000000000000000000000000000000000000000..8db9fbdb1703b566117d227c8e4eef04157ccc93 --- /dev/null +++ b/models/lrm/online_render/src/models/renderer/synthesizer.py @@ -0,0 +1,203 @@ +# ORIGINAL LICENSE +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Modified by Jiale Xu +# The modifications are subject to the same license as the original. + + +import itertools +import torch +import torch.nn as nn + +from .utils.renderer import ImportanceRenderer +from .utils.ray_sampler import RaySampler + + +class OSGDecoder(nn.Module): + """ + Triplane decoder that gives RGB and sigma values from sampled features. + Using ReLU here instead of Softplus in the original implementation. + + Reference: + EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112 + """ + def __init__(self, n_features: int, + hidden_dim: int = 64, num_layers: int = 4, activation: nn.Module = nn.ReLU): + super().__init__() + self.net = nn.Sequential( + nn.Linear(3 * n_features, hidden_dim), + activation(), + *itertools.chain(*[[ + nn.Linear(hidden_dim, hidden_dim), + activation(), + ] for _ in range(num_layers - 2)]), + nn.Linear(hidden_dim, 1 + 3), + ) + # init all bias to zero + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.zeros_(m.bias) + + def forward(self, sampled_features, ray_directions): + # Aggregate features by mean + # sampled_features = sampled_features.mean(1) + # Aggregate features by concatenation + _N, n_planes, _M, _C = sampled_features.shape + sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C) + x = sampled_features + + N, M, C = x.shape + x = x.contiguous().view(N*M, C) + + x = self.net(x) + x = x.view(N, M, -1) + rgb = torch.sigmoid(x[..., 1:])*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF + sigma = x[..., 0:1] + + return {'rgb': rgb, 'sigma': sigma} + + +class TriplaneSynthesizer(nn.Module): + """ + Synthesizer that renders a triplane volume with planes and a camera. + + Reference: + EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19 + """ + + DEFAULT_RENDERING_KWARGS = { + 'ray_start': 'auto', + 'ray_end': 'auto', + 'box_warp': 2., + 'white_back': True, + 'disparity_space_sampling': False, + 'clamp_mode': 'softplus', + 'sampler_bbox_min': -1., + 'sampler_bbox_max': 1., + } + + def __init__(self, triplane_dim: int, samples_per_ray: int): + super().__init__() + + # attributes + self.triplane_dim = triplane_dim + self.rendering_kwargs = { + **self.DEFAULT_RENDERING_KWARGS, + 'depth_resolution': samples_per_ray // 2, + 'depth_resolution_importance': samples_per_ray // 2, + } + + # renderings + self.renderer = ImportanceRenderer() + self.ray_sampler = RaySampler() + + # modules + self.decoder = OSGDecoder(n_features=triplane_dim) + + def forward(self, planes, cameras, render_size=128, crop_params=None): + # planes: (N, 3, D', H', W') + # cameras: (N, M, D_cam) + # render_size: int + assert planes.shape[0] == cameras.shape[0], "Batch size mismatch for planes and cameras" + N, M = cameras.shape[:2] + + cam2world_matrix = cameras[..., :16].view(N, M, 4, 4) + intrinsics = cameras[..., 16:25].view(N, M, 3, 3) + + # Create a batch of rays for volume rendering + ray_origins, ray_directions = self.ray_sampler( + cam2world_matrix=cam2world_matrix.reshape(-1, 4, 4), + intrinsics=intrinsics.reshape(-1, 3, 3), + render_size=render_size, + ) + assert N*M == ray_origins.shape[0], "Batch size mismatch for ray_origins" + assert ray_origins.dim() == 3, "ray_origins should be 3-dimensional" + + # Crop rays if crop_params is available + if crop_params is not None: + ray_origins = ray_origins.reshape(N*M, render_size, render_size, 3) + ray_directions = ray_directions.reshape(N*M, render_size, render_size, 3) + i, j, h, w = crop_params + ray_origins = ray_origins[:, i:i+h, j:j+w, :].reshape(N*M, -1, 3) + ray_directions = ray_directions[:, i:i+h, j:j+w, :].reshape(N*M, -1, 3) + + # Perform volume rendering + rgb_samples, depth_samples, weights_samples = self.renderer( + planes.repeat_interleave(M, dim=0), self.decoder, ray_origins, ray_directions, self.rendering_kwargs, + ) + + # Reshape into 'raw' neural-rendered image + if crop_params is not None: + Himg, Wimg = crop_params[2:] + else: + Himg = Wimg = render_size + rgb_images = rgb_samples.permute(0, 2, 1).reshape(N, M, rgb_samples.shape[-1], Himg, Wimg).contiguous() + depth_images = depth_samples.permute(0, 2, 1).reshape(N, M, 1, Himg, Wimg) + weight_images = weights_samples.permute(0, 2, 1).reshape(N, M, 1, Himg, Wimg) + + out = { + 'images_rgb': rgb_images, + 'images_depth': depth_images, + 'images_weight': weight_images, + } + return out + + def forward_grid(self, planes, grid_size: int, aabb: torch.Tensor = None): + # planes: (N, 3, D', H', W') + # grid_size: int + # aabb: (N, 2, 3) + if aabb is None: + aabb = torch.tensor([ + [self.rendering_kwargs['sampler_bbox_min']] * 3, + [self.rendering_kwargs['sampler_bbox_max']] * 3, + ], device=planes.device, dtype=planes.dtype).unsqueeze(0).repeat(planes.shape[0], 1, 1) + assert planes.shape[0] == aabb.shape[0], "Batch size mismatch for planes and aabb" + N = planes.shape[0] + + # create grid points for triplane query + grid_points = [] + for i in range(N): + grid_points.append(torch.stack(torch.meshgrid( + torch.linspace(aabb[i, 0, 0], aabb[i, 1, 0], grid_size, device=planes.device), + torch.linspace(aabb[i, 0, 1], aabb[i, 1, 1], grid_size, device=planes.device), + torch.linspace(aabb[i, 0, 2], aabb[i, 1, 2], grid_size, device=planes.device), + indexing='ij', + ), dim=-1).reshape(-1, 3)) + cube_grid = torch.stack(grid_points, dim=0).to(planes.device) + + features = self.forward_points(planes, cube_grid) + + # reshape into grid + features = { + k: v.reshape(N, grid_size, grid_size, grid_size, -1) + for k, v in features.items() + } + return features + + def forward_points(self, planes, points: torch.Tensor, chunk_size: int = 2**20): + # planes: (N, 3, D', H', W') + # points: (N, P, 3) + N, P = points.shape[:2] + + # query triplane in chunks + outs = [] + for i in range(0, points.shape[1], chunk_size): + chunk_points = points[:, i:i+chunk_size] + + # query triplane + chunk_out = self.renderer.run_model_activated( + planes=planes, + decoder=self.decoder, + sample_coordinates=chunk_points, + sample_directions=torch.zeros_like(chunk_points), + options=self.rendering_kwargs, + ) + outs.append(chunk_out) + + # concatenate the outputs + point_features = { + k: torch.cat([out[k] for out in outs], dim=1) + for k in outs[0].keys() + } + return point_features diff --git a/models/lrm/online_render/src/models/renderer/synthesizer_mesh.py b/models/lrm/online_render/src/models/renderer/synthesizer_mesh.py new file mode 100755 index 0000000000000000000000000000000000000000..a4bc9f555049bc9c02934343434e1fa262e55762 --- /dev/null +++ b/models/lrm/online_render/src/models/renderer/synthesizer_mesh.py @@ -0,0 +1,156 @@ +# ORIGINAL LICENSE +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Modified by Jiale Xu +# The modifications are subject to the same license as the original. + +import itertools +import torch +import torch.nn as nn + +from .utils.renderer import generate_planes, project_onto_planes, sample_from_planes + + +class OSGDecoder(nn.Module): + """ + Triplane decoder that gives RGB and sigma values from sampled features. + Using ReLU here instead of Softplus in the original implementation. + + Reference: + EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112 + """ + def __init__(self, n_features: int, + hidden_dim: int = 64, num_layers: int = 4, activation: nn.Module = nn.ReLU): + super().__init__() + + self.net_sdf = nn.Sequential( + nn.Linear(3 * n_features, hidden_dim), + activation(), + *itertools.chain(*[[ + nn.Linear(hidden_dim, hidden_dim), + activation(), + ] for _ in range(num_layers - 2)]), + nn.Linear(hidden_dim, 1), + ) + self.net_rgb = nn.Sequential( + nn.Linear(3 * n_features, hidden_dim), + activation(), + *itertools.chain(*[[ + nn.Linear(hidden_dim, hidden_dim), + activation(), + ] for _ in range(num_layers - 2)]), + nn.Linear(hidden_dim, 3), + ) + self.net_material = nn.Sequential( + nn.Linear(3 * n_features, hidden_dim), + activation(), + *itertools.chain(*[[ + nn.Linear(hidden_dim, hidden_dim), + activation(), + ] for _ in range(num_layers - 2)]), + nn.Linear(hidden_dim, 2), + ) + self.net_deformation = nn.Sequential( + nn.Linear(3 * n_features, hidden_dim), + activation(), + *itertools.chain(*[[ + nn.Linear(hidden_dim, hidden_dim), + activation(), + ] for _ in range(num_layers - 2)]), + nn.Linear(hidden_dim, 3), + ) + self.net_weight = nn.Sequential( + nn.Linear(8 * 3 * n_features, hidden_dim), + activation(), + *itertools.chain(*[[ + nn.Linear(hidden_dim, hidden_dim), + activation(), + ] for _ in range(num_layers - 2)]), + nn.Linear(hidden_dim, 21), + ) + + # init all bias to zero + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.zeros_(m.bias) + + def get_geometry_prediction(self, sampled_features, flexicubes_indices): + _N, n_planes, _M, _C = sampled_features.shape + sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C) + + sdf = self.net_sdf(sampled_features) + deformation = self.net_deformation(sampled_features) + + grid_features = torch.index_select(input=sampled_features, index=flexicubes_indices.reshape(-1), dim=1) + grid_features = grid_features.reshape( + sampled_features.shape[0], flexicubes_indices.shape[0], flexicubes_indices.shape[1] * sampled_features.shape[-1]) + weight = self.net_weight(grid_features) * 0.1 + + return sdf, deformation, weight + + def get_texture_prediction(self, sampled_features): + _N, n_planes, _M, _C = sampled_features.shape + sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C) + + rgb = self.net_rgb(sampled_features) + rgb = torch.sigmoid(rgb)*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF + + materials = self.net_material(sampled_features) + materials = torch.sigmoid(materials) + metallic, roughness = materials[...,0], materials[...,1] + rmax, rmin = 1.0, 0.04 ** 2 + roughness = roughness * (rmax - rmin) + rmin + + return rgb, metallic, roughness + + +class TriplaneSynthesizer(nn.Module): + """ + Synthesizer that renders a triplane volume with planes and a camera. + + Reference: + EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19 + """ + + DEFAULT_RENDERING_KWARGS = { + 'ray_start': 'auto', + 'ray_end': 'auto', + 'box_warp': 2., + 'white_back': True, + 'disparity_space_sampling': False, + 'clamp_mode': 'softplus', + 'sampler_bbox_min': -1., + 'sampler_bbox_max': 1., + } + + def __init__(self, triplane_dim: int, samples_per_ray: int): + super().__init__() + + # attributes + self.triplane_dim = triplane_dim + self.rendering_kwargs = { + **self.DEFAULT_RENDERING_KWARGS, + 'depth_resolution': samples_per_ray // 2, + 'depth_resolution_importance': samples_per_ray // 2, + } + + # modules + self.plane_axes = generate_planes() + self.decoder = OSGDecoder(n_features=triplane_dim) + + def get_geometry_prediction(self, planes, sample_coordinates, flexicubes_indices): + plane_axes = self.plane_axes.to(planes.device) + sampled_features = sample_from_planes( + plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=self.rendering_kwargs['box_warp']) + + sdf, deformation, weight = self.decoder.get_geometry_prediction(sampled_features, flexicubes_indices) + return sdf, deformation, weight + + def get_texture_prediction(self, planes, sample_coordinates): + plane_axes = self.plane_axes.to(planes.device) + sampled_features = sample_from_planes( + plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=self.rendering_kwargs['box_warp']) + + rgb, matellic, roughness = self.decoder.get_texture_prediction(sampled_features) + return rgb, matellic, roughness diff --git a/models/lrm/online_render/src/models/renderer/utils/__init__.py b/models/lrm/online_render/src/models/renderer/utils/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..2c772e4fa331c678cfff50884be94d7d31835b34 --- /dev/null +++ b/models/lrm/online_render/src/models/renderer/utils/__init__.py @@ -0,0 +1,9 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. diff --git a/models/lrm/online_render/src/models/renderer/utils/math_utils.py b/models/lrm/online_render/src/models/renderer/utils/math_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..4cf9d2b811e0acbc7923bc9126e010b52cb1a8af --- /dev/null +++ b/models/lrm/online_render/src/models/renderer/utils/math_utils.py @@ -0,0 +1,118 @@ +# MIT License + +# Copyright (c) 2022 Petr Kellnhofer + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import torch + +def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor: + """ + Left-multiplies MxM @ NxM. Returns NxM. + """ + res = torch.matmul(vectors4, matrix.T) + return res + + +def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor: + """ + Normalize vector lengths. + """ + return vectors / (torch.norm(vectors, dim=-1, keepdim=True)) + +def torch_dot(x: torch.Tensor, y: torch.Tensor): + """ + Dot product of two tensors. + """ + return (x * y).sum(-1) + + +def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length): + """ + Author: Petr Kellnhofer + Intersects rays with the [-1, 1] NDC volume. + Returns min and max distance of entry. + Returns -1 for no intersection. + https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection + """ + o_shape = rays_o.shape + rays_o = rays_o.detach().reshape(-1, 3) + rays_d = rays_d.detach().reshape(-1, 3) + + + bb_min = [-1*(box_side_length/2), -1*(box_side_length/2), -1*(box_side_length/2)] + bb_max = [1*(box_side_length/2), 1*(box_side_length/2), 1*(box_side_length/2)] + bounds = torch.tensor([bb_min, bb_max], dtype=rays_o.dtype, device=rays_o.device) + is_valid = torch.ones(rays_o.shape[:-1], dtype=bool, device=rays_o.device) + + # Precompute inverse for stability. + invdir = 1 / rays_d + sign = (invdir < 0).long() + + # Intersect with YZ plane. + tmin = (bounds.index_select(0, sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0] + tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0] + + # Intersect with XZ plane. + tymin = (bounds.index_select(0, sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1] + tymax = (bounds.index_select(0, 1 - sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1] + + # Resolve parallel rays. + is_valid[torch.logical_or(tmin > tymax, tymin > tmax)] = False + + # Use the shortest intersection. + tmin = torch.max(tmin, tymin) + tmax = torch.min(tmax, tymax) + + # Intersect with XY plane. + tzmin = (bounds.index_select(0, sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2] + tzmax = (bounds.index_select(0, 1 - sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2] + + # Resolve parallel rays. + is_valid[torch.logical_or(tmin > tzmax, tzmin > tmax)] = False + + # Use the shortest intersection. + tmin = torch.max(tmin, tzmin) + tmax = torch.min(tmax, tzmax) + + # Mark invalid. + tmin[torch.logical_not(is_valid)] = -1 + tmax[torch.logical_not(is_valid)] = -2 + + return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1) + + +def linspace(start: torch.Tensor, stop: torch.Tensor, num: int): + """ + Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to end, inclusive. + Replicates but the multi-dimensional bahaviour of numpy.linspace in PyTorch. + """ + # create a tensor of 'num' steps from 0 to 1 + steps = torch.arange(num, dtype=torch.float32, device=start.device) / (num - 1) + + # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcastings + # - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript + # "cannot statically infer the expected size of a list in this contex", hence the code below + for i in range(start.ndim): + steps = steps.unsqueeze(-1) + + # the output starts at 'start' and increments until 'stop' in each dimension + out = start[None] + steps * (stop - start)[None] + + return out diff --git a/models/lrm/online_render/src/models/renderer/utils/ray_marcher.py b/models/lrm/online_render/src/models/renderer/utils/ray_marcher.py new file mode 100755 index 0000000000000000000000000000000000000000..ea1db43478de703509cdd04c684f92f8e283c5ad --- /dev/null +++ b/models/lrm/online_render/src/models/renderer/utils/ray_marcher.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. +# +# Modified by Jiale Xu +# The modifications are subject to the same license as the original. + + +""" +The ray marcher takes the raw output of the implicit representation and uses the volume rendering equation to produce composited colors and depths. +Based off of the implementation in MipNeRF (this one doesn't do any cone tracing though!) +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class MipRayMarcher2(nn.Module): + def __init__(self, activation_factory): + super().__init__() + self.activation_factory = activation_factory + + def run_forward(self, colors, densities, depths, rendering_options, normals=None): + dtype = colors.dtype + deltas = depths[:, :, 1:] - depths[:, :, :-1] + colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2 + densities_mid = (densities[:, :, :-1] + densities[:, :, 1:]) / 2 + depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2 + + # using factory mode for better usability + densities_mid = self.activation_factory(rendering_options)(densities_mid).to(dtype) + + density_delta = densities_mid * deltas + + alpha = 1 - torch.exp(-density_delta).to(dtype) + + alpha_shifted = torch.cat([torch.ones_like(alpha[:, :, :1]), 1-alpha + 1e-10], -2) + weights = alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1] + weights = weights.to(dtype) + + composite_rgb = torch.sum(weights * colors_mid, -2) + weight_total = weights.sum(2) + # composite_depth = torch.sum(weights * depths_mid, -2) / weight_total + composite_depth = torch.sum(weights * depths_mid, -2) + + # clip the composite to min/max range of depths + composite_depth = torch.nan_to_num(composite_depth, float('inf')).to(dtype) + composite_depth = torch.clamp(composite_depth, torch.min(depths), torch.max(depths)) + + if rendering_options.get('white_back', False): + composite_rgb = composite_rgb + 1 - weight_total + + # rendered value scale is 0-1, comment out original mipnerf scaling + # composite_rgb = composite_rgb * 2 - 1 # Scale to (-1, 1) + + return composite_rgb, composite_depth, weights + + + def forward(self, colors, densities, depths, rendering_options, normals=None): + if normals is not None: + composite_rgb, composite_depth, composite_normals, weights = self.run_forward(colors, densities, depths, rendering_options, normals) + return composite_rgb, composite_depth, composite_normals, weights + + composite_rgb, composite_depth, weights = self.run_forward(colors, densities, depths, rendering_options) + return composite_rgb, composite_depth, weights diff --git a/models/lrm/online_render/src/models/renderer/utils/ray_sampler.py b/models/lrm/online_render/src/models/renderer/utils/ray_sampler.py new file mode 100755 index 0000000000000000000000000000000000000000..ae5151dda467e826ce346986bd486d4465c906f2 --- /dev/null +++ b/models/lrm/online_render/src/models/renderer/utils/ray_sampler.py @@ -0,0 +1,141 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. +# +# Modified by Jiale Xu +# The modifications are subject to the same license as the original. + + +""" +The ray sampler is a module that takes in camera matrices and resolution and batches of rays. +Expects cam2world matrices that use the OpenCV camera coordinate system conventions. +""" + +import torch + +class RaySampler(torch.nn.Module): + def __init__(self): + super().__init__() + self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None + + + def forward(self, cam2world_matrix, intrinsics, render_size): + """ + Create batches of rays and return origins and directions. + + cam2world_matrix: (N, 4, 4) + intrinsics: (N, 3, 3) + render_size: int + + ray_origins: (N, M, 3) + ray_dirs: (N, M, 2) + """ + + dtype = cam2world_matrix.dtype + device = cam2world_matrix.device + N, M = cam2world_matrix.shape[0], render_size**2 + cam_locs_world = cam2world_matrix[:, :3, 3] + fx = intrinsics[:, 0, 0] + fy = intrinsics[:, 1, 1] + cx = intrinsics[:, 0, 2] + cy = intrinsics[:, 1, 2] + sk = intrinsics[:, 0, 1] + + uv = torch.stack(torch.meshgrid( + torch.arange(render_size, dtype=dtype, device=device), + torch.arange(render_size, dtype=dtype, device=device), + indexing='ij', + )) + uv = uv.flip(0).reshape(2, -1).transpose(1, 0) + uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1) + + x_cam = uv[:, :, 0].view(N, -1) * (1./render_size) + (0.5/render_size) + y_cam = uv[:, :, 1].view(N, -1) * (1./render_size) + (0.5/render_size) + z_cam = torch.ones((N, M), dtype=dtype, device=device) + + x_lift = (x_cam - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y_cam/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z_cam + y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam + + cam_rel_points = torch.stack((x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1).to(dtype) + + _opencv2blender = torch.tensor([ + [1, 0, 0, 0], + [0, -1, 0, 0], + [0, 0, -1, 0], + [0, 0, 0, 1], + ], dtype=dtype, device=device).unsqueeze(0).repeat(N, 1, 1) + + cam2world_matrix = torch.bmm(cam2world_matrix, _opencv2blender) + + world_rel_points = torch.bmm(cam2world_matrix, cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3] + + ray_dirs = world_rel_points - cam_locs_world[:, None, :] + ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2).to(dtype) + + ray_origins = cam_locs_world.unsqueeze(1).repeat(1, ray_dirs.shape[1], 1) + + return ray_origins, ray_dirs + + +class OrthoRaySampler(torch.nn.Module): + def __init__(self): + super().__init__() + self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None + + + def forward(self, cam2world_matrix, ortho_scale, render_size): + """ + Create batches of rays and return origins and directions. + + cam2world_matrix: (N, 4, 4) + ortho_scale: float + render_size: int + + ray_origins: (N, M, 3) + ray_dirs: (N, M, 3) + """ + + N, M = cam2world_matrix.shape[0], render_size**2 + + uv = torch.stack(torch.meshgrid( + torch.arange(render_size, dtype=torch.float32, device=cam2world_matrix.device), + torch.arange(render_size, dtype=torch.float32, device=cam2world_matrix.device), + indexing='ij', + )) + uv = uv.flip(0).reshape(2, -1).transpose(1, 0) + uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1) + + x_cam = uv[:, :, 0].view(N, -1) * (1./render_size) + (0.5/render_size) + y_cam = uv[:, :, 1].view(N, -1) * (1./render_size) + (0.5/render_size) + z_cam = torch.zeros((N, M), device=cam2world_matrix.device) + + x_lift = (x_cam - 0.5) * ortho_scale + y_lift = (y_cam - 0.5) * ortho_scale + + cam_rel_points = torch.stack((x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1) + + _opencv2blender = torch.tensor([ + [1, 0, 0, 0], + [0, -1, 0, 0], + [0, 0, -1, 0], + [0, 0, 0, 1], + ], dtype=torch.float32, device=cam2world_matrix.device).unsqueeze(0).repeat(N, 1, 1) + + cam2world_matrix = torch.bmm(cam2world_matrix, _opencv2blender) + + ray_origins = torch.bmm(cam2world_matrix, cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3] + + ray_dirs_cam = torch.stack([ + torch.zeros((N, M), device=cam2world_matrix.device), + torch.zeros((N, M), device=cam2world_matrix.device), + torch.ones((N, M), device=cam2world_matrix.device), + ], dim=-1) + ray_dirs = torch.bmm(cam2world_matrix[:, :3, :3], ray_dirs_cam.permute(0, 2, 1)).permute(0, 2, 1) + + return ray_origins, ray_dirs diff --git a/models/lrm/online_render/src/models/renderer/utils/renderer.py b/models/lrm/online_render/src/models/renderer/utils/renderer.py new file mode 100755 index 0000000000000000000000000000000000000000..95c4c728efbd0283b8ddd7dc6a1b28d1510efa97 --- /dev/null +++ b/models/lrm/online_render/src/models/renderer/utils/renderer.py @@ -0,0 +1,323 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. +# +# Modified by Jiale Xu +# The modifications are subject to the same license as the original. + + +""" +The renderer is a module that takes in rays, decides where to sample along each +ray, and computes pixel colors using the volume rendering equation. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .ray_marcher import MipRayMarcher2 +from . import math_utils + + +def generate_planes(): + """ + Defines planes by the three vectors that form the "axes" of the + plane. Should work with arbitrary number of planes and planes of + arbitrary orientation. + + Bugfix reference: https://github.com/NVlabs/eg3d/issues/67 + """ + return torch.tensor([[[1, 0, 0], + [0, 1, 0], + [0, 0, 1]], + [[1, 0, 0], + [0, 0, 1], + [0, 1, 0]], + [[0, 0, 1], + [0, 1, 0], + [1, 0, 0]]], dtype=torch.float32) + +def project_onto_planes(planes, coordinates): + """ + Does a projection of a 3D point onto a batch of 2D planes, + returning 2D plane coordinates. + + Takes plane axes of shape n_planes, 3, 3 + # Takes coordinates of shape N, M, 3 + # returns projections of shape N*n_planes, M, 2 + """ + N, M, C = coordinates.shape + n_planes, _, _ = planes.shape + coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N*n_planes, M, 3) + inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(N, -1, -1, -1).reshape(N*n_planes, 3, 3) + projections = torch.bmm(coordinates, inv_planes) + return projections[..., :2] + +def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None): + assert padding_mode == 'zeros' + N, n_planes, C, H, W = plane_features.shape + _, M, _ = coordinates.shape + plane_features = plane_features.view(N*n_planes, C, H, W) + dtype = plane_features.dtype + + coordinates = (2/box_warp) * coordinates # add specific box bounds + + projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1) + output_features = torch.nn.functional.grid_sample( + plane_features, + projected_coordinates.to(dtype), + mode=mode, + padding_mode=padding_mode, + align_corners=False, + ).permute(0, 3, 2, 1).reshape(N, n_planes, M, C) + return output_features + +def sample_from_3dgrid(grid, coordinates): + """ + Expects coordinates in shape (batch_size, num_points_per_batch, 3) + Expects grid in shape (1, channels, H, W, D) + (Also works if grid has batch size) + Returns sampled features of shape (batch_size, num_points_per_batch, feature_channels) + """ + batch_size, n_coords, n_dims = coordinates.shape + sampled_features = torch.nn.functional.grid_sample( + grid.expand(batch_size, -1, -1, -1, -1), + coordinates.reshape(batch_size, 1, 1, -1, n_dims), + mode='bilinear', + padding_mode='zeros', + align_corners=False, + ) + N, C, H, W, D = sampled_features.shape + sampled_features = sampled_features.permute(0, 4, 3, 2, 1).reshape(N, H*W*D, C) + return sampled_features + +class ImportanceRenderer(torch.nn.Module): + """ + Modified original version to filter out-of-box samples as TensoRF does. + + Reference: + TensoRF: https://github.com/apchenstu/TensoRF/blob/main/models/tensorBase.py#L277 + """ + def __init__(self): + super().__init__() + self.activation_factory = self._build_activation_factory() + self.ray_marcher = MipRayMarcher2(self.activation_factory) + self.plane_axes = generate_planes() + + def _build_activation_factory(self): + def activation_factory(options: dict): + if options['clamp_mode'] == 'softplus': + return lambda x: F.softplus(x - 1) # activation bias of -1 makes things initialize better + else: + assert False, "Renderer only supports `clamp_mode`=`softplus`!" + return activation_factory + + def _forward_pass(self, depths: torch.Tensor, ray_directions: torch.Tensor, ray_origins: torch.Tensor, + planes: torch.Tensor, decoder: nn.Module, rendering_options: dict): + """ + Additional filtering is applied to filter out-of-box samples. + Modifications made by Zexin He. + """ + + # context related variables + batch_size, num_rays, samples_per_ray, _ = depths.shape + device = depths.device + + # define sample points with depths + sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, samples_per_ray, -1).reshape(batch_size, -1, 3) + sample_coordinates = (ray_origins.unsqueeze(-2) + depths * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3) + + # filter out-of-box samples + mask_inbox = \ + (rendering_options['sampler_bbox_min'] <= sample_coordinates) & \ + (sample_coordinates <= rendering_options['sampler_bbox_max']) + mask_inbox = mask_inbox.all(-1) + + # forward model according to all samples + _out = self.run_model(planes, decoder, sample_coordinates, sample_directions, rendering_options) + + # set out-of-box samples to zeros(rgb) & -inf(sigma) + SAFE_GUARD = 3 + DATA_TYPE = _out['sigma'].dtype + colors_pass = torch.zeros(batch_size, num_rays * samples_per_ray, 3, device=device, dtype=DATA_TYPE) + densities_pass = torch.nan_to_num(torch.full((batch_size, num_rays * samples_per_ray, 1), -float('inf'), device=device, dtype=DATA_TYPE)) / SAFE_GUARD + colors_pass[mask_inbox], densities_pass[mask_inbox] = _out['rgb'][mask_inbox], _out['sigma'][mask_inbox] + + # reshape back + colors_pass = colors_pass.reshape(batch_size, num_rays, samples_per_ray, colors_pass.shape[-1]) + densities_pass = densities_pass.reshape(batch_size, num_rays, samples_per_ray, densities_pass.shape[-1]) + + return colors_pass, densities_pass + + def forward(self, planes, decoder, ray_origins, ray_directions, rendering_options): + # self.plane_axes = self.plane_axes.to(ray_origins.device) + + if rendering_options['ray_start'] == rendering_options['ray_end'] == 'auto': + ray_start, ray_end = math_utils.get_ray_limits_box(ray_origins, ray_directions, box_side_length=rendering_options['box_warp']) + is_ray_valid = ray_end > ray_start + if torch.any(is_ray_valid).item(): + ray_start[~is_ray_valid] = ray_start[is_ray_valid].min() + ray_end[~is_ray_valid] = ray_start[is_ray_valid].max() + depths_coarse = self.sample_stratified(ray_origins, ray_start, ray_end, rendering_options['depth_resolution'], rendering_options['disparity_space_sampling']) + else: + # Create stratified depth samples + depths_coarse = self.sample_stratified(ray_origins, rendering_options['ray_start'], rendering_options['ray_end'], rendering_options['depth_resolution'], rendering_options['disparity_space_sampling']) + + # Coarse Pass + colors_coarse, densities_coarse = self._forward_pass( + depths=depths_coarse, ray_directions=ray_directions, ray_origins=ray_origins, + planes=planes, decoder=decoder, rendering_options=rendering_options) + + # Fine Pass + N_importance = rendering_options['depth_resolution_importance'] + if N_importance > 0: + _, _, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options) + + depths_fine = self.sample_importance(depths_coarse, weights, N_importance) + + colors_fine, densities_fine = self._forward_pass( + depths=depths_fine, ray_directions=ray_directions, ray_origins=ray_origins, + planes=planes, decoder=decoder, rendering_options=rendering_options) + + all_depths, all_colors, all_densities = self.unify_samples(depths_coarse, colors_coarse, densities_coarse, + depths_fine, colors_fine, densities_fine) + + rgb_final, depth_final, weights = self.ray_marcher(all_colors, all_densities, all_depths, rendering_options) + else: + rgb_final, depth_final, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options) + + return rgb_final, depth_final, weights.sum(2) + + def run_model(self, planes, decoder, sample_coordinates, sample_directions, options): + plane_axes = self.plane_axes.to(planes.device) + sampled_features = sample_from_planes(plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=options['box_warp']) + + out = decoder(sampled_features, sample_directions) + if options.get('density_noise', 0) > 0: + out['sigma'] += torch.randn_like(out['sigma']) * options['density_noise'] + return out + + def run_model_activated(self, planes, decoder, sample_coordinates, sample_directions, options): + out = self.run_model(planes, decoder, sample_coordinates, sample_directions, options) + out['sigma'] = self.activation_factory(options)(out['sigma']) + return out + + def sort_samples(self, all_depths, all_colors, all_densities): + _, indices = torch.sort(all_depths, dim=-2) + all_depths = torch.gather(all_depths, -2, indices) + all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1])) + all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1)) + return all_depths, all_colors, all_densities + + def unify_samples(self, depths1, colors1, densities1, depths2, colors2, densities2, normals1=None, normals2=None): + all_depths = torch.cat([depths1, depths2], dim = -2) + all_colors = torch.cat([colors1, colors2], dim = -2) + all_densities = torch.cat([densities1, densities2], dim = -2) + + if normals1 is not None and normals2 is not None: + all_normals = torch.cat([normals1, normals2], dim = -2) + else: + all_normals = None + + _, indices = torch.sort(all_depths, dim=-2) + all_depths = torch.gather(all_depths, -2, indices) + all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1])) + all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1)) + + if all_normals is not None: + all_normals = torch.gather(all_normals, -2, indices.expand(-1, -1, -1, all_normals.shape[-1])) + return all_depths, all_colors, all_normals, all_densities + + return all_depths, all_colors, all_densities + + def sample_stratified(self, ray_origins, ray_start, ray_end, depth_resolution, disparity_space_sampling=False): + """ + Return depths of approximately uniformly spaced samples along rays. + """ + N, M, _ = ray_origins.shape + if disparity_space_sampling: + depths_coarse = torch.linspace(0, + 1, + depth_resolution, + device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1) + depth_delta = 1/(depth_resolution - 1) + depths_coarse += torch.rand_like(depths_coarse) * depth_delta + depths_coarse = 1./(1./ray_start * (1. - depths_coarse) + 1./ray_end * depths_coarse) + else: + if type(ray_start) == torch.Tensor: + depths_coarse = math_utils.linspace(ray_start, ray_end, depth_resolution).permute(1,2,0,3) + depth_delta = (ray_end - ray_start) / (depth_resolution - 1) + depths_coarse += torch.rand_like(depths_coarse) * depth_delta[..., None] + else: + depths_coarse = torch.linspace(ray_start, ray_end, depth_resolution, device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1) + depth_delta = (ray_end - ray_start)/(depth_resolution - 1) + depths_coarse += torch.rand_like(depths_coarse) * depth_delta + + return depths_coarse + + def sample_importance(self, z_vals, weights, N_importance): + """ + Return depths of importance sampled points along rays. See NeRF importance sampling for more. + """ + with torch.no_grad(): + batch_size, num_rays, samples_per_ray, _ = z_vals.shape + + z_vals = z_vals.reshape(batch_size * num_rays, samples_per_ray) + weights = weights.reshape(batch_size * num_rays, -1) # -1 to account for loss of 1 sample in MipRayMarcher + + # smooth weights + weights = torch.nn.functional.max_pool1d(weights.unsqueeze(1), 2, 1, padding=1) + weights = torch.nn.functional.avg_pool1d(weights, 2, 1).squeeze() + weights = weights + 0.01 + + z_vals_mid = 0.5 * (z_vals[: ,:-1] + z_vals[: ,1:]) + importance_z_vals = self.sample_pdf(z_vals_mid, weights[:, 1:-1], + N_importance).detach().reshape(batch_size, num_rays, N_importance, 1) + return importance_z_vals + + def sample_pdf(self, bins, weights, N_importance, det=False, eps=1e-5): + """ + Sample @N_importance samples from @bins with distribution defined by @weights. + Inputs: + bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2" + weights: (N_rays, N_samples_) + N_importance: the number of samples to draw from the distribution + det: deterministic or not + eps: a small number to prevent division by zero + Outputs: + samples: the sampled samples + """ + N_rays, N_samples_ = weights.shape + weights = weights + eps # prevent division by zero (don't do inplace op!) + pdf = weights / torch.sum(weights, -1, keepdim=True) # (N_rays, N_samples_) + cdf = torch.cumsum(pdf, -1) # (N_rays, N_samples), cumulative distribution function + cdf = torch.cat([torch.zeros_like(cdf[: ,:1]), cdf], -1) # (N_rays, N_samples_+1) + # padded to 0~1 inclusive + + if det: + u = torch.linspace(0, 1, N_importance, device=bins.device) + u = u.expand(N_rays, N_importance) + else: + u = torch.rand(N_rays, N_importance, device=bins.device) + u = u.contiguous() + + inds = torch.searchsorted(cdf, u, right=True) + below = torch.clamp_min(inds-1, 0) + above = torch.clamp_max(inds, N_samples_) + + inds_sampled = torch.stack([below, above], -1).view(N_rays, 2*N_importance) + cdf_g = torch.gather(cdf, 1, inds_sampled).view(N_rays, N_importance, 2) + bins_g = torch.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2) + + denom = cdf_g[...,1]-cdf_g[...,0] + denom[denom 0 and radius > 0 + + elevation = np.deg2rad(elevation) + + camera_positions = [] + for i in range(M): + azimuth = 2 * np.pi * i / M + x = radius * np.cos(elevation) * np.cos(azimuth) + y = radius * np.cos(elevation) * np.sin(azimuth) + z = radius * np.sin(elevation) + camera_positions.append([x, y, z]) + camera_positions = np.array(camera_positions) + camera_positions = torch.from_numpy(camera_positions).float() + extrinsics = center_looking_at_camera_pose(camera_positions) + return extrinsics + + +def FOV_to_intrinsics(fov, device='cpu'): + """ + Creates a 3x3 camera intrinsics matrix from the camera field of view, specified in degrees. + Note the intrinsics are returned as normalized by image size, rather than in pixel units. + Assumes principal point is at image center. + """ + focal_length = 0.5 / np.tan(np.deg2rad(fov) * 0.5) + intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device) + return intrinsics + + +def get_zero123plus_input_cameras(batch_size=1, radius=4.0, fov=30.0): + """ + Get the input camera parameters. + """ + azimuths = np.array([30, 90, 150, 210, 270, 330]).astype(float) + elevations = np.array([20, -10, 20, -10, 20, -10]).astype(float) + + c2ws = spherical_camera_pose(azimuths, elevations, radius) + c2ws = c2ws.float().flatten(-2) + + Ks = FOV_to_intrinsics(fov).unsqueeze(0).repeat(6, 1, 1).float().flatten(-2) + + extrinsics = c2ws[:, :12] + intrinsics = torch.stack([Ks[:, 0], Ks[:, 4], Ks[:, 2], Ks[:, 5]], dim=-1) + cameras = torch.cat([extrinsics, intrinsics], dim=-1) + + return cameras.unsqueeze(0).repeat(batch_size, 1, 1) diff --git a/models/lrm/online_render/src/utils/infer_util.py b/models/lrm/online_render/src/utils/infer_util.py new file mode 100755 index 0000000000000000000000000000000000000000..f2faf2bf3b12d4af7b33cb2292da2b5ed62eb52e --- /dev/null +++ b/models/lrm/online_render/src/utils/infer_util.py @@ -0,0 +1,97 @@ +import os +import imageio +import rembg +import torch +import numpy as np +import PIL.Image +from PIL import Image +from typing import Any + + +def remove_background(image: PIL.Image.Image, + rembg_session: Any = None, + force: bool = False, + **rembg_kwargs, +) -> PIL.Image.Image: + do_remove = True + if image.mode == "RGBA" and image.getextrema()[3][0] < 255: + do_remove = False + do_remove = do_remove or force + if do_remove: + image = rembg.remove(image, session=rembg_session, **rembg_kwargs) + return image + + +def resize_foreground( + image: PIL.Image.Image, + ratio: float, +) -> PIL.Image.Image: + image = np.array(image) + assert image.shape[-1] == 4 + alpha = np.where(image[..., 3] > 0) + y1, y2, x1, x2 = ( + alpha[0].min(), + alpha[0].max(), + alpha[1].min(), + alpha[1].max(), + ) + # crop the foreground + fg = image[y1:y2, x1:x2] + # pad to square + size = max(fg.shape[0], fg.shape[1]) + ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2 + ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0 + new_image = np.pad( + fg, + ((ph0, ph1), (pw0, pw1), (0, 0)), + mode="constant", + constant_values=((0, 0), (0, 0), (0, 0)), + ) + + # compute padding according to the ratio + new_size = int(new_image.shape[0] / ratio) + # pad to size, double side + ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2 + ph1, pw1 = new_size - size - ph0, new_size - size - pw0 + new_image = np.pad( + new_image, + ((ph0, ph1), (pw0, pw1), (0, 0)), + mode="constant", + constant_values=((0, 0), (0, 0), (0, 0)), + ) + new_image = PIL.Image.fromarray(new_image) + return new_image + + +def images_to_video( + images: torch.Tensor, + output_path: str, + fps: int = 30, +) -> None: + # images: (N, C, H, W) + video_dir = os.path.dirname(output_path) + video_name = os.path.basename(output_path) + os.makedirs(video_dir, exist_ok=True) + + frames = [] + for i in range(len(images)): + frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) + assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \ + f"Frame shape mismatch: {frame.shape} vs {images.shape}" + assert frame.min() >= 0 and frame.max() <= 255, \ + f"Frame value out of range: {frame.min()} ~ {frame.max()}" + frames.append(frame) + imageio.mimwrite(output_path, np.stack(frames), fps=fps, quality=10) + + +def save_video( + frames: torch.Tensor, + output_path: str, + fps: int = 30, +) -> None: + # images: (N, C, H, W) + frames = [(frame.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) for frame in frames] + writer = imageio.get_writer(output_path, fps=fps) + for frame in frames: + writer.append_data(frame) + writer.close() \ No newline at end of file diff --git a/models/lrm/online_render/src/utils/material.py b/models/lrm/online_render/src/utils/material.py new file mode 100755 index 0000000000000000000000000000000000000000..8e465efc7f86f37f71e6b17fb25bbf069b82a38b --- /dev/null +++ b/models/lrm/online_render/src/utils/material.py @@ -0,0 +1,197 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +import numpy as np +import torch + +from ..models.geometry.rep_3d import util +from . import texture + +###################################################################################### +# Wrapper to make materials behave like a python dict, but register textures as +# torch.nn.Module parameters. +###################################################################################### +class Material(torch.nn.Module): + def __init__(self, mat_dict): + super(Material, self).__init__() + self.mat_keys = set() + for key in mat_dict.keys(): + self.mat_keys.add(key) + self[key] = mat_dict[key] + + def __contains__(self, key): + return hasattr(self, key) + + def __getitem__(self, key): + return getattr(self, key) + + def __setitem__(self, key, val): + self.mat_keys.add(key) + setattr(self, key, val) + + def __delitem__(self, key): + self.mat_keys.remove(key) + delattr(self, key) + + def keys(self): + return self.mat_keys + +###################################################################################### +# .mtl material format loading / storing +###################################################################################### +@torch.no_grad() +def load_mtl(fn, clear_ks=True): + import re + mtl_path = os.path.dirname(fn) + + # Read file + with open(fn, 'r') as f: + lines = f.readlines() + + # Parse materials + materials = [] + for line in lines: + split_line = re.split(' +|\t+|\n+', line.strip()) + prefix = split_line[0].lower() + data = split_line[1:] + if 'newmtl' in prefix: + material = Material({'name' : data[0]}) + materials += [material] + elif materials: + if 'map_d' in prefix: + # 设置透明度为1.0,即完全不透明 + material['d'] = torch.tensor(1.0, dtype=torch.float32, device='cuda') + elif 'map_ke' in prefix: + # 设置自发光为0 + material['Ke'] = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device='cuda') + elif 'bsdf' in prefix or 'map_kd' in prefix or 'map_ks' in prefix or 'bump' in prefix: + material[prefix] = data[0] + else: + material[prefix] = torch.tensor(tuple(float(d) for d in data), dtype=torch.float32, device='cuda') + + # Convert everything to textures. Our code expects 'kd' and 'ks' to be texture maps. So replace constants with 1x1 maps + for mat in materials: + if not 'bsdf' in mat: + mat['bsdf'] = 'pbr' + + if 'map_kd' in mat: + mat['kd'] = texture.load_texture2D(os.path.join(mtl_path, mat['map_kd'])) + else: + mat['kd'] = texture.Texture2D(mat['kd']) + + if 'map_ks' in mat: + mat['ks'] = texture.load_texture2D(os.path.join(mtl_path, mat['map_ks']), channels=3) + else: + mat['ks'] = texture.Texture2D(mat['ks']) + + if 'bump' in mat: + mat['normal'] = texture.load_texture2D(os.path.join(mtl_path, mat['bump']), lambda_fn=lambda x: x * 2 - 1, channels=3) + + # Convert Kd from sRGB to linear RGB + mat['kd'] = texture.srgb_to_rgb(mat['kd']) + + if clear_ks: + # Override ORM occlusion (red) channel by zeros. We hijack this channel + for mip in mat['ks'].getMips(): + mip[..., 0] = 0.0 + + return materials + +@torch.no_grad() +def save_mtl(fn, material): + folder = os.path.dirname(fn) + with open(fn, "w") as f: + f.write('newmtl defaultMat\n') + if material is not None: + f.write('bsdf %s\n' % material['bsdf']) + if 'kd' in material.keys(): + f.write('map_Kd texture_kd.png\n') + texture.save_texture2D(os.path.join(folder, 'texture_kd.png'), texture.rgb_to_srgb(material['kd'])) + if 'ks' in material.keys(): + f.write('map_Ks texture_ks.png\n') + texture.save_texture2D(os.path.join(folder, 'texture_ks.png'), material['ks']) + if 'normal' in material.keys(): + f.write('bump texture_n.png\n') + texture.save_texture2D(os.path.join(folder, 'texture_n.png'), material['normal'], lambda_fn=lambda x:(util.safe_normalize(x)+1)*0.5) + else: + f.write('Kd 1 1 1\n') + f.write('Ks 0 0 0\n') + f.write('Ka 0 0 0\n') + f.write('Tf 1 1 1\n') + f.write('Ni 1\n') + f.write('Ns 0\n') + +###################################################################################### +# Merge multiple materials into a single uber-material +###################################################################################### + +def _upscale_replicate(x, full_res): + x = x.permute(0, 3, 1, 2) + x = torch.nn.functional.pad(x, (0, full_res[1] - x.shape[3], 0, full_res[0] - x.shape[2]), 'replicate') + return x.permute(0, 2, 3, 1).contiguous() + +def merge_materials(materials, texcoords, tfaces, mfaces): + assert len(materials) > 0 + for mat in materials: + assert mat['bsdf'] == materials[0]['bsdf'], "All materials must have the same BSDF (uber shader)" + assert ('normal' in mat) is ('normal' in materials[0]), "All materials must have either normal map enabled or disabled" + + uber_material = Material({ + 'name' : 'uber_material', + 'bsdf' : materials[0]['bsdf'], + }) + + textures = ['kd', 'ks', 'normal'] + + # Find maximum texture resolution across all materials and textures + max_res = None + for mat in materials: + for tex in textures: + tex_res = np.array(mat[tex].getRes()) if tex in mat else np.array([1, 1]) + max_res = np.maximum(max_res, tex_res) if max_res is not None else tex_res + + # Compute size of compund texture and round up to nearest PoT + full_res = 2**np.ceil(np.log2(max_res * np.array([1, len(materials)]))).astype(int) + + # Normalize texture resolution across all materials & combine into a single large texture + for tex in textures: + if tex in materials[0]: + # breakpoint() + tex_data_list = [] + for mat in materials: + if tex in mat: + scaled_tex = util.scale_img_nhwc(mat[tex].data, tuple(max_res)) + if scaled_tex.shape[-1] != 3: + scaled_tex = scaled_tex[:, :, :, :3] + tex_data_list.append(scaled_tex) + # tex_data = torch.cat(tuple(util.scale_img_nhwc(mat[tex].data, tuple(max_res)) for mat in materials), dim=2) # Lay out all textures horizontally, NHWC so dim2 is x + tex_data = torch.cat(tuple(tex_data_list), dim=2) # 将所有纹理水平排列,NHWC 的 dim2 是 x 轴 + tex_data = _upscale_replicate(tex_data, full_res) + uber_material[tex] = texture.Texture2D(tex_data) + + # Compute scaling values for used / unused texture area + s_coeff = [full_res[0] / max_res[0], full_res[1] / max_res[1]] + + # Recompute texture coordinates to cooincide with new composite texture + new_tverts = {} + new_tverts_data = [] + for fi in range(len(tfaces)): + matIdx = mfaces[fi] + for vi in range(3): + ti = tfaces[fi][vi] + if not (ti in new_tverts): + new_tverts[ti] = {} + if not (matIdx in new_tverts[ti]): # create new vertex + new_tverts_data.append([(matIdx + texcoords[ti][0]) / s_coeff[1], texcoords[ti][1] / s_coeff[0]]) # Offset texture coodrinate (x direction) by material id & scale to local space. Note, texcoords are (u,v) but texture is stored (w,h) so the indexes swap here + new_tverts[ti][matIdx] = len(new_tverts_data) - 1 + tfaces[fi][vi] = new_tverts[ti][matIdx] # reindex vertex + + return uber_material, new_tverts_data, tfaces + diff --git a/models/lrm/online_render/src/utils/mesh.py b/models/lrm/online_render/src/utils/mesh.py new file mode 100755 index 0000000000000000000000000000000000000000..165147b80a51d097ecf124177ee7d17172865f80 --- /dev/null +++ b/models/lrm/online_render/src/utils/mesh.py @@ -0,0 +1,255 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +import numpy as np +import torch + +from . import obj +from ..models.geometry.rep_3d import util + +###################################################################################### +# Base mesh class +###################################################################################### +class Mesh: + def __init__(self, v_pos=None, t_pos_idx=None, v_nrm=None, t_nrm_idx=None, v_tex=None, t_tex_idx=None, v_tng=None, t_tng_idx=None, material=None, base=None): + self.v_pos = v_pos + self.v_nrm = v_nrm + self.v_tex = v_tex + self.v_tng = v_tng + self.t_pos_idx = t_pos_idx + self.t_nrm_idx = t_nrm_idx + self.t_tex_idx = t_tex_idx + self.t_tng_idx = t_tng_idx + self.material = material + + if base is not None: + self.copy_none(base) + + def copy_none(self, other): + if self.v_pos is None: + self.v_pos = other.v_pos + if self.t_pos_idx is None: + self.t_pos_idx = other.t_pos_idx + if self.v_nrm is None: + self.v_nrm = other.v_nrm + if self.t_nrm_idx is None: + self.t_nrm_idx = other.t_nrm_idx + if self.v_tex is None: + self.v_tex = other.v_tex + if self.t_tex_idx is None: + self.t_tex_idx = other.t_tex_idx + if self.v_tng is None: + self.v_tng = other.v_tng + if self.t_tng_idx is None: + self.t_tng_idx = other.t_tng_idx + if self.material is None: + self.material = other.material + + def clone(self): + out = Mesh(base=self) + if out.v_pos is not None: + out.v_pos = out.v_pos.clone().detach() + if out.t_pos_idx is not None: + out.t_pos_idx = out.t_pos_idx.clone().detach() + if out.v_nrm is not None: + out.v_nrm = out.v_nrm.clone().detach() + if out.t_nrm_idx is not None: + out.t_nrm_idx = out.t_nrm_idx.clone().detach() + if out.v_tex is not None: + out.v_tex = out.v_tex.clone().detach() + if out.t_tex_idx is not None: + out.t_tex_idx = out.t_tex_idx.clone().detach() + if out.v_tng is not None: + out.v_tng = out.v_tng.clone().detach() + if out.t_tng_idx is not None: + out.t_tng_idx = out.t_tng_idx.clone().detach() + return out + def rotate_x_90(self): + # 定义绕X轴旋转90度的旋转矩阵 + rotate_x = torch.tensor([[1, 0, 0, 0], + [0, 0, 1, 0], + [0, -1, 0, 0], + [0, 0, 0, 1]], dtype=torch.float32, device=self.v_pos.device) + + # 将旋转矩阵应用到顶点坐标 + if self.v_pos is not None: + v_pos_homo = torch.cat((self.v_pos, torch.ones(self.v_pos.shape[0], 1, device=self.v_pos.device)), dim=1) + v_pos_rotated = v_pos_homo @ rotate_x.T + self.v_pos = v_pos_rotated[:, :3] + + # 将旋转矩阵应用到法线 + if self.v_nrm is not None: + v_nrm_homo = torch.cat((self.v_nrm, torch.zeros(self.v_nrm.shape[0], 1, device=self.v_nrm.device)), dim=1) + v_nrm_rotated = v_nrm_homo @ rotate_x.T + self.v_nrm = v_nrm_rotated[:, :3] +###################################################################################### +# Mesh loeading helper +###################################################################################### + +def load_mesh(filename, mtl_override=None): + name, ext = os.path.splitext(filename) + if ext == ".obj": + return obj.load_obj(filename, clear_ks=True, mtl_override=mtl_override) + assert False, "Invalid mesh file extension" + +###################################################################################### +# Compute AABB +###################################################################################### +def aabb(mesh): + return torch.min(mesh.v_pos, dim=0).values, torch.max(mesh.v_pos, dim=0).values + +###################################################################################### +# Compute unique edge list from attribute/vertex index list +###################################################################################### +def compute_edges(attr_idx, return_inverse=False): + with torch.no_grad(): + # Create all edges, packed by triangle + all_edges = torch.cat(( + torch.stack((attr_idx[:, 0], attr_idx[:, 1]), dim=-1), + torch.stack((attr_idx[:, 1], attr_idx[:, 2]), dim=-1), + torch.stack((attr_idx[:, 2], attr_idx[:, 0]), dim=-1), + ), dim=-1).view(-1, 2) + + # Swap edge order so min index is always first + order = (all_edges[:, 0] > all_edges[:, 1]).long().unsqueeze(dim=1) + sorted_edges = torch.cat(( + torch.gather(all_edges, 1, order), + torch.gather(all_edges, 1, 1 - order) + ), dim=-1) + + # Eliminate duplicates and return inverse mapping + return torch.unique(sorted_edges, dim=0, return_inverse=return_inverse) + +###################################################################################### +# Compute unique edge to face mapping from attribute/vertex index list +###################################################################################### +def compute_edge_to_face_mapping(attr_idx, return_inverse=False): + with torch.no_grad(): + # Get unique edges + # Create all edges, packed by triangle + all_edges = torch.cat(( + torch.stack((attr_idx[:, 0], attr_idx[:, 1]), dim=-1), + torch.stack((attr_idx[:, 1], attr_idx[:, 2]), dim=-1), + torch.stack((attr_idx[:, 2], attr_idx[:, 0]), dim=-1), + ), dim=-1).view(-1, 2) + + # Swap edge order so min index is always first + order = (all_edges[:, 0] > all_edges[:, 1]).long().unsqueeze(dim=1) + sorted_edges = torch.cat(( + torch.gather(all_edges, 1, order), + torch.gather(all_edges, 1, 1 - order) + ), dim=-1) + + # Elliminate duplicates and return inverse mapping + unique_edges, idx_map = torch.unique(sorted_edges, dim=0, return_inverse=True) + + tris = torch.arange(attr_idx.shape[0]).repeat_interleave(3).cuda() + + tris_per_edge = torch.zeros((unique_edges.shape[0], 2), dtype=torch.int64).cuda() + + # Compute edge to face table + mask0 = order[:,0] == 0 + mask1 = order[:,0] == 1 + tris_per_edge[idx_map[mask0], 0] = tris[mask0] + tris_per_edge[idx_map[mask1], 1] = tris[mask1] + + return tris_per_edge + +###################################################################################### +# Align base mesh to reference mesh:move & rescale to match bounding boxes. +###################################################################################### +def unit_size(mesh): + with torch.no_grad(): + vmin, vmax = aabb(mesh) + scale = 2 / torch.max(vmax - vmin).item() + v_pos = mesh.v_pos - (vmax + vmin) / 2 # Center mesh on origin + v_pos = v_pos * scale # Rescale to unit size + + return Mesh(v_pos, base=mesh) + +###################################################################################### +# Center & scale mesh for rendering +###################################################################################### +def center_by_reference(base_mesh, ref_aabb, scale): + center = (ref_aabb[0] + ref_aabb[1]) * 0.5 + scale = scale / torch.max(ref_aabb[1] - ref_aabb[0]).item() + v_pos = (base_mesh.v_pos - center[None, ...]) * scale + return Mesh(v_pos, base=base_mesh) + +###################################################################################### +# Simple smooth vertex normal computation +###################################################################################### +def auto_normals(imesh): + + i0 = imesh.t_pos_idx[:, 0] + i1 = imesh.t_pos_idx[:, 1] + i2 = imesh.t_pos_idx[:, 2] + + v0 = imesh.v_pos[i0, :] + v1 = imesh.v_pos[i1, :] + v2 = imesh.v_pos[i2, :] + + face_normals = torch.cross(v1 - v0, v2 - v0) + + # Splat face normals to vertices + v_nrm = torch.zeros_like(imesh.v_pos) + v_nrm.scatter_add_(0, i0[:, None].repeat(1,3), face_normals) + v_nrm.scatter_add_(0, i1[:, None].repeat(1,3), face_normals) + v_nrm.scatter_add_(0, i2[:, None].repeat(1,3), face_normals) + + # Normalize, replace zero (degenerated) normals with some default value + v_nrm = torch.where(util.dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device='cuda')) + v_nrm = util.safe_normalize(v_nrm) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(v_nrm)) + + return Mesh(v_nrm=v_nrm, t_nrm_idx=imesh.t_pos_idx, base=imesh) + +###################################################################################### +# Compute tangent space from texture map coordinates +# Follows http://www.mikktspace.com/ conventions +###################################################################################### +def compute_tangents(imesh): + vn_idx = [None] * 3 + pos = [None] * 3 + tex = [None] * 3 + for i in range(0,3): + pos[i] = imesh.v_pos[imesh.t_pos_idx[:, i]] + tex[i] = imesh.v_tex[imesh.t_tex_idx[:, i]] + vn_idx[i] = imesh.t_nrm_idx[:, i] + + tangents = torch.zeros_like(imesh.v_nrm) + + # Compute tangent space for each triangle + uve1 = tex[1] - tex[0] + uve2 = tex[2] - tex[0] + pe1 = pos[1] - pos[0] + pe2 = pos[2] - pos[0] + + nom = (pe1 * uve2[..., 1:2] - pe2 * uve1[..., 1:2]) + denom = (uve1[..., 0:1] * uve2[..., 1:2] - uve1[..., 1:2] * uve2[..., 0:1]) + + # Avoid division by zero for degenerated texture coordinates + tang = nom / torch.where(denom > 0.0, torch.clamp(denom, min=1e-6), torch.clamp(denom, max=-1e-6)) + + # Update all 3 vertices + for i in range(0,3): + idx = vn_idx[i][:, None].repeat(1,3) + tangents.scatter_add_(0, idx, tang) # tangents[n_i] = tangents[n_i] + tang + + # Normalize and make sure tangent is perpendicular to normal + tangents = util.safe_normalize(tangents) + tangents = util.safe_normalize(tangents - util.dot(tangents, imesh.v_nrm) * imesh.v_nrm) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(tangents)) + + return Mesh(v_tng=tangents, t_tng_idx=imesh.t_nrm_idx, base=imesh) diff --git a/models/lrm/online_render/src/utils/mesh_util.py b/models/lrm/online_render/src/utils/mesh_util.py new file mode 100755 index 0000000000000000000000000000000000000000..0ec4663eeaa5c54209e08771969ec4f2a739c0b4 --- /dev/null +++ b/models/lrm/online_render/src/utils/mesh_util.py @@ -0,0 +1,181 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +import xatlas +import trimesh +import cv2 +import numpy as np +import nvdiffrast.torch as dr +from PIL import Image + + +def save_obj(pointnp_px3, facenp_fx3, colornp_px3, fpath): + + pointnp_px3 = pointnp_px3 @ np.array([[1, 0, 0], [0, 1, 0], [0, 0, -1]]) + facenp_fx3 = facenp_fx3[:, [2, 1, 0]] + + mesh = trimesh.Trimesh( + vertices=pointnp_px3, + faces=facenp_fx3, + vertex_colors=colornp_px3, + ) + mesh.export(fpath, 'obj') + + +def save_glb(pointnp_px3, facenp_fx3, colornp_px3, fpath): + + pointnp_px3 = pointnp_px3 @ np.array([[-1, 0, 0], [0, 1, 0], [0, 0, -1]]) + + mesh = trimesh.Trimesh( + vertices=pointnp_px3, + faces=facenp_fx3, + vertex_colors=colornp_px3, + ) + mesh.export(fpath, 'glb') + + +def save_obj_with_mtl(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, texmap_hxwx3, fname): + import os + fol, na = os.path.split(fname) + na, _ = os.path.splitext(na) + + matname = '%s/%s.mtl' % (fol, na) + fid = open(matname, 'w') + fid.write('newmtl material_0\n') + fid.write('Kd 1 1 1\n') + fid.write('Ka 0 0 0\n') + fid.write('Ks 0.4 0.4 0.4\n') + fid.write('Ns 10\n') + fid.write('illum 2\n') + fid.write('map_Kd %s.png\n' % na) + fid.close() + #### + + fid = open(fname, 'w') + fid.write('mtllib %s.mtl\n' % na) + + for pidx, p in enumerate(pointnp_px3): + pp = p + fid.write('v %f %f %f\n' % (pp[0], pp[1], pp[2])) + + for pidx, p in enumerate(tcoords_px2): + pp = p + fid.write('vt %f %f\n' % (pp[0], pp[1])) + + fid.write('usemtl material_0\n') + for i, f in enumerate(facenp_fx3): + f1 = f + 1 + f2 = facetex_fx3[i] + 1 + fid.write('f %d/%d %d/%d %d/%d\n' % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2])) + fid.close() + + # save texture map + lo, hi = 0, 1 + img = np.asarray(texmap_hxwx3, dtype=np.float32) + img = (img - lo) * (255 / (hi - lo)) + img = img.clip(0, 255) + mask = np.sum(img.astype(np.float32), axis=-1, keepdims=True) + mask = (mask <= 3.0).astype(np.float32) + kernel = np.ones((3, 3), 'uint8') + dilate_img = cv2.dilate(img, kernel, iterations=1) + img = img * (1 - mask) + dilate_img * mask + img = img.clip(0, 255).astype(np.uint8) + Image.fromarray(np.ascontiguousarray(img[::-1, :, :]), 'RGB').save(f'{fol}/{na}.png') + + +def loadobj(meshfile): + v = [] + f = [] + meshfp = open(meshfile, 'r') + for line in meshfp.readlines(): + data = line.strip().split(' ') + data = [da for da in data if len(da) > 0] + if len(data) != 4: + continue + if data[0] == 'v': + v.append([float(d) for d in data[1:]]) + if data[0] == 'f': + data = [da.split('/')[0] for da in data] + f.append([int(d) for d in data[1:]]) + meshfp.close() + + # torch need int64 + facenp_fx3 = np.array(f, dtype=np.int64) - 1 + pointnp_px3 = np.array(v, dtype=np.float32) + return pointnp_px3, facenp_fx3 + + +def loadobjtex(meshfile): + v = [] + vt = [] + f = [] + ft = [] + meshfp = open(meshfile, 'r') + for line in meshfp.readlines(): + data = line.strip().split(' ') + data = [da for da in data if len(da) > 0] + if not ((len(data) == 3) or (len(data) == 4) or (len(data) == 5)): + continue + if data[0] == 'v': + assert len(data) == 4 + + v.append([float(d) for d in data[1:]]) + if data[0] == 'vt': + if len(data) == 3 or len(data) == 4: + vt.append([float(d) for d in data[1:3]]) + if data[0] == 'f': + data = [da.split('/') for da in data] + if len(data) == 4: + f.append([int(d[0]) for d in data[1:]]) + ft.append([int(d[1]) for d in data[1:]]) + elif len(data) == 5: + idx1 = [1, 2, 3] + data1 = [data[i] for i in idx1] + f.append([int(d[0]) for d in data1]) + ft.append([int(d[1]) for d in data1]) + idx2 = [1, 3, 4] + data2 = [data[i] for i in idx2] + f.append([int(d[0]) for d in data2]) + ft.append([int(d[1]) for d in data2]) + meshfp.close() + + # torch need int64 + facenp_fx3 = np.array(f, dtype=np.int64) - 1 + ftnp_fx3 = np.array(ft, dtype=np.int64) - 1 + pointnp_px3 = np.array(v, dtype=np.float32) + uvs = np.array(vt, dtype=np.float32) + return pointnp_px3, facenp_fx3, uvs, ftnp_fx3 + + +# ============================================================================================== +def interpolate(attr, rast, attr_idx, rast_db=None): + return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=None if rast_db is None else 'all') + + +def xatlas_uvmap(ctx, mesh_v, mesh_pos_idx, resolution): + vmapping, indices, uvs = xatlas.parametrize(mesh_v.detach().cpu().numpy(), mesh_pos_idx.detach().cpu().numpy()) + + # Convert to tensors + indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64) + + uvs = torch.tensor(uvs, dtype=torch.float32, device=mesh_v.device) + mesh_tex_idx = torch.tensor(indices_int64, dtype=torch.int64, device=mesh_v.device) + # mesh_v_tex. ture + uv_clip = uvs[None, ...] * 2.0 - 1.0 + + # pad to four component coordinate + uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[..., 0:1]), torch.ones_like(uv_clip[..., 0:1])), dim=-1) + + # rasterize + rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), (resolution, resolution)) + + # Interpolate world space position + gb_pos, _ = interpolate(mesh_v[None, ...], rast, mesh_pos_idx.int()) + mask = rast[..., 3:4] > 0 + return uvs, mesh_tex_idx, gb_pos, mask diff --git a/models/lrm/online_render/src/utils/obj.py b/models/lrm/online_render/src/utils/obj.py new file mode 100755 index 0000000000000000000000000000000000000000..0748581d3d2370ae80d31f456566cc192d0ac876 --- /dev/null +++ b/models/lrm/online_render/src/utils/obj.py @@ -0,0 +1,225 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +import torch + +from . import texture +from . import mesh +from . import material + +###################################################################################### +# Utility functions +###################################################################################### + +def _find_mat(materials, name): + for mat in materials: + if mat['name'] == name: + return mat + return materials[0] # Materials 0 is the default + + +def normalize_mesh(vertices): + # 计算边界框 + min_vals, _ = torch.min(vertices, dim=0) + max_vals, _ = torch.max(vertices, dim=0) + + # 计算中心点 + center = (max_vals + min_vals) / 2 + + # 平移顶点 + vertices = vertices - center + + # 计算缩放因子 + max_extent = torch.max(max_vals - min_vals) + scale = 2.0 / max_extent + + # 缩放顶点 + vertices = vertices * scale + + return vertices + +###################################################################################### +# Create mesh object from objfile +###################################################################################### +def rotate_y_90(v_pos): + # 定义绕X轴旋转90度的旋转矩阵 + rotate_y = torch.tensor([[0, 0, 1, 0], + [0, 1, 0, 0], + [-1, 0, 0, 0], + [0, 0, 0, 1]], dtype=torch.float32, device=v_pos.device) + return rotate_y + +def load_obj(filename, clear_ks=True, mtl_override=None, return_attributes=False, path_is_attributrs=False): + read_normal = True + obj_path = os.path.dirname(filename) + + # Read entire file + with open(filename, 'r') as f: + lines = f.readlines() + + # Load materials + all_materials = [ + { + 'name' : '_default_mat', + 'bsdf' : 'pbr', + 'kd' : texture.Texture2D(torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32, device='cuda')), + 'ks' : texture.Texture2D(torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device='cuda')) + } + ] + if mtl_override is None: + for line in lines: + if len(line.split()) == 0: + continue + if line.split()[0] == 'mtllib': + all_materials += material.load_mtl(os.path.join(obj_path, line.split()[1]), clear_ks) # Read in entire material library + else: + all_materials += material.load_mtl(mtl_override) + + # load vertices + vertices, texcoords, normals = [], [], [] + for line in lines: + if len(line.split()) == 0: + continue + + prefix = line.split()[0].lower() + if prefix == 'v': + vertices.append([float(v) for v in line.split()[1:]]) + elif prefix == 'vt': + val = [float(v) for v in line.split()[1:]] + texcoords.append([val[0], 1.0 - val[1]]) + elif prefix == 'vn': + normals.append([float(v) for v in line.split()[1:]]) + + # load faces + activeMatIdx = None + used_materials = [] + faces, tfaces, nfaces, mfaces = [], [], [], [] + for line in lines: + if len(line.split()) == 0: + continue + + prefix = line.split()[0].lower() + if prefix == 'usemtl': # Track used materials + mat = _find_mat(all_materials, line.split()[1]) + if not mat in used_materials: + used_materials.append(mat) + activeMatIdx = used_materials.index(mat) + elif prefix == 'f': # Parse face + vs = line.split()[1:] + nv = len(vs) + vv = vs[0].split('/') + v0 = int(vv[0]) - 1 + t0 = int(vv[1]) - 1 if vv[1] != "" else -1 + try: + n0 = int(vv[2]) - 1 if vv[2] != "" else -1 + except: + read_normal = False + for i in range(nv - 2): # Triangulate polygons + vv = vs[i + 1].split('/') + v1 = int(vv[0]) - 1 + t1 = int(vv[1]) - 1 if vv[1] != "" else -1 + if read_normal==False: + pass + else: + n1 = int(vv[2]) - 1 if vv[2] != "" else -1 + vv = vs[i + 2].split('/') + v2 = int(vv[0]) - 1 + t2 = int(vv[1]) - 1 if vv[1] != "" else -1 + if read_normal==False: + pass + else: + n2 = int(vv[2]) - 1 if vv[2] != "" else -1 + mfaces.append(activeMatIdx) + faces.append([v0, v1, v2]) + tfaces.append([t0, t1, t2]) + if read_normal==False: + pass + else: + nfaces.append([n0, n1, n2]) + if read_normal==False: + pass + else: + assert len(tfaces) == len(faces) and len(nfaces) == len (faces) + + # Create an "uber" material by combining all textures into a larger texture + if len(used_materials) > 1: + uber_material, texcoords, tfaces = material.merge_materials(used_materials, texcoords, tfaces, mfaces) + else: + uber_material = used_materials[0] + + vertices = torch.tensor(vertices, dtype=torch.float32, device='cuda') + texcoords = torch.tensor(texcoords, dtype=torch.float32, device='cuda') if len(texcoords) > 0 else None + normals = torch.tensor(normals, dtype=torch.float32, device='cuda') if len(normals) > 0 else None + + faces = torch.tensor(faces, dtype=torch.int64, device='cuda') + tfaces = torch.tensor(tfaces, dtype=torch.int64, device='cuda') if texcoords is not None else None + nfaces = torch.tensor(nfaces, dtype=torch.int64, device='cuda') if normals is not None else None + + vertices = normalize_mesh(vertices) + # vertices = vertices @ rotate_y_90(vertices)[:3,:3] + + if return_attributes: + return mesh.Mesh(vertices, faces, normals, nfaces, texcoords, tfaces, material=uber_material), vertices, faces, normals, nfaces, texcoords, tfaces, uber_material + return mesh.Mesh(vertices, faces, normals, nfaces, texcoords, tfaces, material=uber_material) + +###################################################################################### +# Save mesh object to objfile +###################################################################################### + +def write_obj(folder, mesh, save_material=True): + obj_file = os.path.join(folder, 'mesh.obj') + print("Writing mesh: ", obj_file) + with open(obj_file, "w") as f: + f.write("mtllib mesh.mtl\n") + f.write("g default\n") + + v_pos = mesh.v_pos.detach().cpu().numpy() if mesh.v_pos is not None else None + v_nrm = mesh.v_nrm.detach().cpu().numpy() if mesh.v_nrm is not None else None + v_tex = mesh.v_tex.detach().cpu().numpy() if mesh.v_tex is not None else None + + t_pos_idx = mesh.t_pos_idx.detach().cpu().numpy() if mesh.t_pos_idx is not None else None + t_nrm_idx = mesh.t_nrm_idx.detach().cpu().numpy() if mesh.t_nrm_idx is not None else None + t_tex_idx = mesh.t_tex_idx.detach().cpu().numpy() if mesh.t_tex_idx is not None else None + + print(" writing %d vertices" % len(v_pos)) + for v in v_pos: + f.write('v {} {} {} \n'.format(v[0], v[1], v[2])) + + if v_tex is not None: + print(" writing %d texcoords" % len(v_tex)) + assert(len(t_pos_idx) == len(t_tex_idx)) + for v in v_tex: + f.write('vt {} {} \n'.format(v[0], 1.0 - v[1])) + + if v_nrm is not None: + print(" writing %d normals" % len(v_nrm)) + assert(len(t_pos_idx) == len(t_nrm_idx)) + for v in v_nrm: + f.write('vn {} {} {}\n'.format(v[0], v[1], v[2])) + + # faces + f.write("s 1 \n") + f.write("g pMesh1\n") + f.write("usemtl defaultMat\n") + + # Write faces + print(" writing %d faces" % len(t_pos_idx)) + for i in range(len(t_pos_idx)): + f.write("f ") + for j in range(3): + f.write(' %s/%s/%s' % (str(t_pos_idx[i][j]+1), '' if v_tex is None else str(t_tex_idx[i][j]+1), '' if v_nrm is None else str(t_nrm_idx[i][j]+1))) + f.write("\n") + + if save_material: + mtl_file = os.path.join(folder, 'mesh.mtl') + print("Writing material: ", mtl_file) + material.save_mtl(mtl_file, mesh.material) + + print("Done exporting mesh") diff --git a/models/lrm/online_render/src/utils/render.py b/models/lrm/online_render/src/utils/render.py new file mode 100755 index 0000000000000000000000000000000000000000..a59e97f2c99f2dd91ed89d3cf5c8c6f402b689d1 --- /dev/null +++ b/models/lrm/online_render/src/utils/render.py @@ -0,0 +1,386 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import torch +import nvdiffrast.torch as dr + +from . import render_utils +from ..models.geometry.render import renderutils as ru +import numpy as np +from PIL import Image + +# ============================================================================================== +# Helper functions +# ============================================================================================== +def interpolate(attr, rast, attr_idx, rast_db=None): + return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=None if rast_db is None else 'all') + + +def get_mip(roughness): + return torch.where(roughness < 1.0 + , (torch.clamp(roughness, 0.04, 1.0) - 0.04) / (1.0 - 0.04) * (6 - 2) + , (torch.clamp(roughness, 1.0, 1.0) - 1.0) / (1.0 - 1.0) + 6 - 2) + +def shade_with_env(gb_pos, gb_normal, kd, metallic, roughness, view_pos, run_n_view, env, metallic_gt, roughness_gt, use_material_gt=True, gt_render=False): + + #mask = mask[..., 0] + view_pos = view_pos.expand(-1, gb_pos.shape[1], gb_pos.shape[2], -1) #.reshape(1, 512, 10240, 3) + + wo = render_utils.safe_normalize(view_pos - gb_pos) + + #roughness = roughness.reshape(roughness.shape[0], roughness.shape[1], roughness.shape[1], run_n_view, 1) + #metallic = metallic.reshape(metallic.shape[0], metallic.shape[1], metallic.shape[1], run_n_view, 1) + #kd = kd.reshape(kd.shape[0], kd.shape[1], kd.shape[1], run_n_view, 3) + # if len(diffuse_light) != 10: + # diffuse_light = [diffuse_light[0] for _ in range(10)] + # specular_light = [specular_light[0] for _ in range(10)] + + # metallic_gt = torch.zeros((8, metallic_gt.shape[1], metallic_gt.shape[2], 1)).cuda() + # roughness_gt = torch.zeros((8, roughness_gt.shape[1], roughness_gt.shape[2], 1)).cuda() + + #if use_material_gt: + spec_col = (1.0 - metallic_gt)*0.04 + kd * metallic_gt + diff_col = kd * (1.0 - metallic_gt) + # else: + + # spec_col = (1.0 - metallic)*0.04 + kd * metallic + # diff_col = kd * (1.0 - metallic) + + nrmvec = gb_normal + reflvec = render_utils.safe_normalize(render_utils.reflect(wo, nrmvec)) + + prb_rendered_list = [] + pbr_specular_light_list = [] + pbr_diffuse_light_list = [] + # pbr_specular_color_list = [] + # pbr_diffuse_color_list = [] + for i in range(run_n_view): + specular_light, diffuse_light = env[i] + diffuse_light = diffuse_light.cuda() + specular_light_new = [] + for split_specular_light in specular_light: + specular_light_new.append(split_specular_light.cuda()) + specular_light = specular_light_new + + shaded_col = torch.ones((gb_pos.shape[1], gb_pos.shape[2], 3)).cuda() + + diffuse = dr.texture(diffuse_light[None, ...], nrmvec[i,:,:,:][None, ...].contiguous(), filter_mode='linear', boundary_mode='cube') + diffuse_comp = diffuse * diff_col[i,:,:,:][None, ...] + + # Lookup FG term from lookup texture + NdotV = torch.clamp(render_utils.dot(wo[i,:,:,:], nrmvec[i,:,:,:]), min=1e-4) + fg_uv = torch.cat((NdotV, roughness_gt[i,:,:,:]), dim=-1) + #if not hasattr(self, '_FG_LUT'): + _FG_LUT = torch.as_tensor(np.fromfile('/root/ComfyUI/custom_nodes/ComfyUI-3D-Pack/Gen_3D_Modules/LRM/online_render/src/data/irrmaps/bsdf_256_256.bin', dtype=np.float32).reshape(1, 256, 256, 2), dtype=torch.float32, device='cuda') + fg_lookup = dr.texture(_FG_LUT, fg_uv[None, ...], filter_mode='linear', boundary_mode='clamp') + # Roughness adjusted specular env lookup + + miplevel = get_mip(roughness_gt[i,:,:,:]) + miplevel = miplevel[None, ...] + spec = dr.texture(specular_light[0][None, ...], reflvec[i,:,:,:][None, ...].contiguous(), mip=list(m[None, ...] for m in specular_light[1:]), mip_level_bias=miplevel[..., 0], filter_mode='linear-mipmap-linear', boundary_mode='cube') + + # Compute aggregate lighting + reflectance = spec_col[i,:,:,:][None, ...] * fg_lookup[...,0:1] + fg_lookup[...,1:2] + specular_comp = spec * reflectance + #shaded_col += spec * reflectance + shaded_col = (specular_comp[0] + diffuse_comp[0]) + + prb_rendered_list.append(shaded_col) + pbr_specular_light_list.append(spec[0]) + pbr_diffuse_light_list.append(diffuse[0]) + + # pbr_specular_color_list.append(metallic_gt[i].repeat(1,1,3)) + # pbr_diffuse_color_list.append(roughness_gt[i].repeat(1,1,3)) + + + shaded_col_all = torch.stack(prb_rendered_list, dim=0) + pbr_specular_light = torch.stack(pbr_specular_light_list, dim=0) + pbr_diffuse_light = torch.stack(pbr_diffuse_light_list, dim=0) + # pbr_specular_color = torch.stack(pbr_specular_color_list, dim=0) + # pbr_diffuse_color = torch.stack(pbr_diffuse_color_list, dim=0) + + #shaded_col_all = shaded_col_all.reshape(shaded_col_all.shape[0], shaded_col_all.shape[1], shaded_col_all.shape[1]*run_n_view, 3) + shaded_col_all = render_utils.rgb_to_srgb(shaded_col_all).clamp(0.,1.) + pbr_specular_light = render_utils.rgb_to_srgb(pbr_specular_light).clamp(0.,1.) + pbr_diffuse_light = render_utils.rgb_to_srgb(pbr_diffuse_light).clamp(0.,1.) + # pbr_specular_color = render_utils.rgb_to_srgb(pbr_specular_color).clamp(0.,1.) + # pbr_diffuse_color = render_utils.rgb_to_srgb(pbr_diffuse_color).clamp(0.,1.) + + return shaded_col_all, pbr_specular_light, pbr_diffuse_light #, pbr_specular_color, pbr_diffuse_color + +# ============================================================================================== +# pixel shader +# ============================================================================================== +def shade( + gb_pos, + gb_geometric_normal, + gb_normal, + gb_tangent, + gb_texc, + gb_texc_deriv, + view_pos, + env, + planes, + kd_fn, + materials, + material, + mask, + gt_render + ): + + ################################################################################ + # Texture lookups + ################################################################################ + perturbed_nrm = None + resolution = gb_pos.shape[1] + N_views = view_pos.shape[0] + + if planes is None: + kd = material['kd'].sample(gb_texc, gb_texc_deriv) + + matellic_gt, roughness_gt = (materials[0] * torch.ones(*kd.shape[:-1])).unsqueeze(-1).cuda(), (materials[1] * torch.ones(*kd.shape[:-1])).unsqueeze(-1).cuda() + matellic, roughness = None, None + else: + # predict kd with MLP and interpolated feature + gb_pos_interp, mask = [gb_pos], [mask] + gb_pos_interp = [torch.cat([pos[i_view:i_view + 1] for i_view in range(N_views)], dim=2) for pos in gb_pos_interp] + mask = [torch.cat([ma[i_view:i_view + 1] for i_view in range(N_views)], dim=2) for ma in mask] + kd, matellic, roughness = kd_fn( planes[None,...], gb_pos_interp, mask[0]) + kd = torch.cat( [torch.cat([kd[i:i + 1, :, resolution * i_view: resolution * (i_view + 1)]for i_view in range(N_views)], dim=0) for i in range(len(kd))], dim=0) + + matellic_val = [x[0] for x in materials] + roughness_val = [y[1] for y in materials] + + matellic_gt = torch.full((N_views, resolution, resolution, 1), fill_value=0, dtype=torch.float32) + roughness_gt = torch.full((N_views, resolution, resolution, 1), fill_value=0, dtype=torch.float32) + + for i in range(len(matellic_gt)): + matellic_gt[i, :, :, 0].fill_(matellic_val[i]) + roughness_gt[i, :, :, 0].fill_(roughness_val[i]) + + matellic_gt = matellic_gt.cuda() + roughness_gt = roughness_gt.cuda() + + # Separate kd into alpha and color, default alpha = 1 + alpha = kd[..., 3:4] if kd.shape[-1] == 4 else torch.ones_like(kd[..., 0:1]) + kd = kd[..., 0:3].clamp(0., 1.) + + ################################################################################ + # Normal perturbation & normal bend + ################################################################################ + #if 'no_perturbed_nrm' in material and material['no_perturbed_nrm']: + perturbed_nrm = None + + gb_normal_ = ru.prepare_shading_normal(gb_pos, view_pos, perturbed_nrm, gb_normal, gb_tangent, gb_geometric_normal, two_sided_shading=True, opengl=True) + + ################################################################################ + # Evaluate BSDF + ################################################################################ + + shaded_col, spec_light, diff_light = shade_with_env(gb_pos, gb_normal_, kd, matellic, roughness, view_pos, N_views, env, matellic_gt, roughness_gt, use_material_gt=True, gt_render=gt_render) + + buffers = { + 'shaded' : torch.cat((shaded_col, alpha), dim=-1), + 'spec_light': torch.cat((spec_light, alpha), dim=-1), + 'diff_light': torch.cat((diff_light, alpha), dim=-1), + 'gb_normal' : torch.cat((gb_normal_, alpha), dim=-1), + 'normal' : torch.cat((gb_normal, alpha), dim=-1), + 'albedo' : torch.cat((kd, alpha), dim=-1), + # 'spec_albedo': torch.cat((spec_albedo, alpha), dim=-1), + # 'diff_albedo': torch.cat((diff_albedo, alpha), dim=-1), + } + return buffers + +# ============================================================================================== +# Render a depth slice of the mesh (scene), some limitations: +# - Single mesh +# - Single light +# - Single material +# ============================================================================================== +def render_layer( + rast, + rast_deriv, + mesh, + view_pos, + env, + planes, + kd_fn, + materials, + v_pos_clip, + resolution, + spp, + msaa, + gt_render + ): + + full_res = [resolution[0]*spp, resolution[1]*spp] + + ################################################################################ + # Rasterize + ################################################################################ + + # Scale down to shading resolution when MSAA is enabled, otherwise shade at full resolution + if spp > 1 and msaa: + rast_out_s = render_utils.scale_img_nhwc(rast, resolution, mag='nearest', min='nearest') + rast_out_deriv_s = render_utils.scale_img_nhwc(rast_deriv, resolution, mag='nearest', min='nearest') * spp + else: + rast_out_s = rast + rast_out_deriv_s = rast_deriv + + ################################################################################ + # Interpolate attributes + ################################################################################ + + # Interpolate world space position + gb_pos, _ = interpolate(mesh.v_pos[None, ...], rast_out_s, mesh.t_pos_idx.int()) + + # Compute geometric normals. We need those because of bent normals trick (for bump mapping) + v0 = mesh.v_pos[mesh.t_pos_idx[:, 0], :] + v1 = mesh.v_pos[mesh.t_pos_idx[:, 1], :] + v2 = mesh.v_pos[mesh.t_pos_idx[:, 2], :] + face_normals = render_utils.safe_normalize(torch.cross(v1 - v0, v2 - v0)) + face_normal_indices = (torch.arange(0, face_normals.shape[0], dtype=torch.int64, device='cuda')[:, None]).repeat(1, 3) + gb_geometric_normal, _ = interpolate(face_normals[None, ...], rast_out_s, face_normal_indices.int()) + + # Compute tangent space + assert mesh.v_nrm is not None and mesh.v_tng is not None + gb_normal, _ = interpolate(mesh.v_nrm[None, ...], rast_out_s, mesh.t_nrm_idx.int()) + gb_tangent, _ = interpolate(mesh.v_tng[None, ...], rast_out_s, mesh.t_tng_idx.int()) # Interpolate tangents + + # Texture coordinate + assert mesh.v_tex is not None + gb_texc, gb_texc_deriv = interpolate(mesh.v_tex[None, ...], rast_out_s, mesh.t_tex_idx.int(), rast_db=rast_out_deriv_s) + + # render depth + depth = torch.linalg.norm(view_pos.expand_as(gb_pos) - gb_pos, dim=-1) + + mask = torch.clamp(rast[..., -1:], 0, 1) + antialias_mask = dr.antialias(mask.clone().contiguous(), rast, v_pos_clip,mesh.t_pos_idx.int()) + + ################################################################################ + # Shade + ################################################################################ + + buffers = shade(gb_pos, gb_geometric_normal, gb_normal, gb_tangent, gb_texc, gb_texc_deriv, view_pos, env, planes, kd_fn, materials, mesh.material, mask, gt_render) + buffers['depth'] = torch.cat((depth.unsqueeze(-1).repeat(1,1,1,3), torch.ones_like(gb_pos[..., 0:1])), dim=-1 ) + # print(gb_pos.shape) + buffers['ccm'] = torch.cat((gb_pos, torch.ones_like(gb_pos[..., 0:1])), dim=-1 ) + buffers['mask'] = torch.cat((antialias_mask.repeat(1,1,1,3), torch.ones_like(gb_pos[..., 0:1])), dim=-1 ) + ################################################################################ + # Prepare output + ################################################################################ + + # Scale back up to visibility resolution if using MSAA + if spp > 1 and msaa: + for key in buffers.keys(): + buffers[key] = render_utils.scale_img_nhwc(buffers[key], full_res, mag='nearest', min='nearest') + + # Return buffers + return buffers + +# ============================================================================================== +# Render a depth peeled mesh (scene), some limitations: +# - Single mesh +# - Single light +# - Single material +# ============================================================================================== +def render_mesh( + ctx, + mesh, + mtx_in, + view_pos, + env, + planes, + kd_fn, + materials, + resolution, + spp = 1, + num_layers = 1, + msaa = False, + background = None, + gt_render = False + ): + + def prepare_input_vector(x): + x = torch.tensor(x, dtype=torch.float32, device='cuda') if not torch.is_tensor(x) else x + return x[:, None, None, :] if len(x.shape) == 2 else x + + def composite_buffer(key, layers, background, antialias): + accum = background + for buffers, rast in reversed(layers): + alpha = (rast[..., -1:] > 0).float() * buffers[key][..., -1:] + accum = torch.lerp(accum, torch.cat((buffers[key][..., :-1], torch.ones_like(buffers[key][..., -1:])), dim=-1), alpha) + if antialias: + accum = dr.antialias(accum.contiguous(), rast, v_pos_clip, mesh.t_pos_idx.int()) + return accum + + assert mesh.t_pos_idx.shape[0] > 0, "Got empty training triangle mesh (unrecoverable discontinuity)" + assert background is None or (background.shape[1] == resolution[0] and background.shape[2] == resolution[1]) + + full_res = [resolution[0]*spp, resolution[1]*spp] + + # Convert numpy arrays to torch tensors + mtx_in = torch.tensor(mtx_in, dtype=torch.float32, device='cuda') if not torch.is_tensor(mtx_in) else mtx_in + view_pos = prepare_input_vector(view_pos) + + # clip space transform + v_pos_clip = ru.xfm_points(mesh.v_pos[None, ...], mtx_in) + + # Render all layers front-to-back + layers = [] + with dr.DepthPeeler(ctx, v_pos_clip, mesh.t_pos_idx.int(), full_res) as peeler: + for _ in range(num_layers): + rast, db = peeler.rasterize_next_layer() + layers += [(render_layer(rast, db, mesh, view_pos, env, planes, kd_fn, materials, v_pos_clip, resolution, spp, msaa, gt_render), rast)] + + # Setup background + if background is not None: + if spp > 1: + background = render_utils.scale_img_nhwc(background, full_res, mag='nearest', min='nearest') + background = torch.cat((background, torch.zeros_like(background[..., 0:1])), dim=-1) + else: + background = torch.ones(1, full_res[0], full_res[1], 4, dtype=torch.float32, device='cuda') + background_black = torch.zeros(1, full_res[0], full_res[1], 4, dtype=torch.float32, device='cuda') + + # Composite layers front-to-back + out_buffers = {} + + for key in layers[0][0].keys(): + if key == 'mask': + accum = composite_buffer(key, layers, background_black, True) + else: + accum = composite_buffer(key, layers, background, True) + + # Downscale to framebuffer resolution. Use avg pooling + out_buffers[key] = render_utils.avg_pool_nhwc(accum, spp) if spp > 1 else accum + + return out_buffers + +# ============================================================================================== +# Render UVs +# ============================================================================================== +def render_uv(ctx, mesh, resolution, mlp_texture): + + # clip space transform + uv_clip = mesh.v_tex[None, ...]*2.0 - 1.0 + + # pad to four component coordinate + uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[...,0:1]), torch.ones_like(uv_clip[...,0:1])), dim = -1) + + # rasterize + rast, _ = dr.rasterize(ctx, uv_clip4, mesh.t_tex_idx.int(), resolution) + + # Interpolate world space position + gb_pos, _ = interpolate(mesh.v_pos[None, ...], rast, mesh.t_pos_idx.int()) + + # Sample out textures from MLP + all_tex = mlp_texture.sample(gb_pos) + assert all_tex.shape[-1] == 9 or all_tex.shape[-1] == 10, "Combined kd_ks_normal must be 9 or 10 channels" + perturbed_nrm = all_tex[..., -3:] + return (rast[..., -1:] > 0).float(), all_tex[..., :-6], all_tex[..., -6:-3], render_utils.safe_normalize(perturbed_nrm) diff --git a/models/lrm/online_render/src/utils/render_utils.py b/models/lrm/online_render/src/utils/render_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..0232bdac180d814d1ce10cea31b2e8db616617e0 --- /dev/null +++ b/models/lrm/online_render/src/utils/render_utils.py @@ -0,0 +1,507 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +import numpy as np +import torch +import nvdiffrast.torch as dr +import imageio +# imageio.plugins.freeimage.download() +os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1" +#---------------------------------------------------------------------------- +# Vector operations +#---------------------------------------------------------------------------- + +def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return torch.sum(x*y, -1, keepdim=True) + +def reflect(x: torch.Tensor, n: torch.Tensor) -> torch.Tensor: + return 2*dot(x, n)*n - x + +def length(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor: + return torch.sqrt(torch.clamp(dot(x,x), min=eps)) # Clamp to avoid nan gradients because grad(sqrt(0)) = NaN + +def safe_normalize(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor: + return x / length(x, eps) + +def to_hvec(x: torch.Tensor, w: float) -> torch.Tensor: + return torch.nn.functional.pad(x, pad=(0,1), mode='constant', value=w) + +#---------------------------------------------------------------------------- +# sRGB color transforms +#---------------------------------------------------------------------------- + +def _rgb_to_srgb(f: torch.Tensor) -> torch.Tensor: + return torch.where(f <= 0.0031308, f * 12.92, torch.pow(torch.clamp(f, 0.0031308), 1.0/2.4)*1.055 - 0.055) + +def rgb_to_srgb(f: torch.Tensor) -> torch.Tensor: + assert f.shape[-1] == 3 or f.shape[-1] == 4 + out = torch.cat((_rgb_to_srgb(f[..., 0:3]), f[..., 3:4]), dim=-1) if f.shape[-1] == 4 else _rgb_to_srgb(f) + assert out.shape[0] == f.shape[0] and out.shape[1] == f.shape[1] and out.shape[2] == f.shape[2] + return out + +def _srgb_to_rgb(f: torch.Tensor) -> torch.Tensor: + return torch.where(f <= 0.04045, f / 12.92, torch.pow((torch.clamp(f, 0.04045) + 0.055) / 1.055, 2.4)) + +def srgb_to_rgb(f: torch.Tensor) -> torch.Tensor: + assert f.shape[-1] == 3 or f.shape[-1] == 4 + out = torch.cat((_srgb_to_rgb(f[..., 0:3]), f[..., 3:4]), dim=-1) if f.shape[-1] == 4 else _srgb_to_rgb(f) + assert out.shape[0] == f.shape[0] and out.shape[1] == f.shape[1] and out.shape[2] == f.shape[2] + return out + +def reinhard(f: torch.Tensor) -> torch.Tensor: + return f/(1+f) + +#----------------------------------------------------------------------------------- +# Metrics (taken from jaxNerf source code, in order to replicate their measurements) +# +# https://github.com/google-research/google-research/blob/301451a62102b046bbeebff49a760ebeec9707b8/jaxnerf/nerf/utils.py#L266 +# +#----------------------------------------------------------------------------------- + +def mse_to_psnr(mse): + """Compute PSNR given an MSE (we assume the maximum pixel value is 1).""" + return -10. / np.log(10.) * np.log(mse) + +def psnr_to_mse(psnr): + """Compute MSE given a PSNR (we assume the maximum pixel value is 1).""" + return np.exp(-0.1 * np.log(10.) * psnr) + +#---------------------------------------------------------------------------- +# Displacement texture lookup +#---------------------------------------------------------------------------- + +def get_miplevels(texture: np.ndarray) -> float: + minDim = min(texture.shape[0], texture.shape[1]) + return np.floor(np.log2(minDim)) + +def tex_2d(tex_map : torch.Tensor, coords : torch.Tensor, filter='nearest') -> torch.Tensor: + tex_map = tex_map[None, ...] # Add batch dimension + tex_map = tex_map.permute(0, 3, 1, 2) # NHWC -> NCHW + tex = torch.nn.functional.grid_sample(tex_map, coords[None, None, ...] * 2 - 1, mode=filter, align_corners=False) + tex = tex.permute(0, 2, 3, 1) # NCHW -> NHWC + return tex[0, 0, ...] + +#---------------------------------------------------------------------------- +# Cubemap utility functions +#---------------------------------------------------------------------------- + +def cube_to_dir(s, x, y): + if s == 0: rx, ry, rz = torch.ones_like(x), -y, -x + elif s == 1: rx, ry, rz = -torch.ones_like(x), -y, x + elif s == 2: rx, ry, rz = x, torch.ones_like(x), y + elif s == 3: rx, ry, rz = x, -torch.ones_like(x), -y + elif s == 4: rx, ry, rz = x, -y, torch.ones_like(x) + elif s == 5: rx, ry, rz = -x, -y, -torch.ones_like(x) + return torch.stack((rx, ry, rz), dim=-1) + +def latlong_to_cubemap(latlong_map, res): + cubemap = torch.zeros(6, res[0], res[1], latlong_map.shape[-1], dtype=torch.float32, device='cuda') + for s in range(6): + gy, gx = torch.meshgrid(torch.linspace(-1.0 + 1.0 / res[0], 1.0 - 1.0 / res[0], res[0], device='cuda'), + torch.linspace(-1.0 + 1.0 / res[1], 1.0 - 1.0 / res[1], res[1], device='cuda'), + indexing='ij') + v = safe_normalize(cube_to_dir(s, gx, gy)) + + tu = torch.atan2(v[..., 0:1], -v[..., 2:3]) / (2 * np.pi) + 0.5 + tv = torch.acos(torch.clamp(v[..., 1:2], min=-1, max=1)) / np.pi + texcoord = torch.cat((tu, tv), dim=-1) + + cubemap[s, ...] = dr.texture(latlong_map[None, ...], texcoord[None, ...], filter_mode='linear')[0] + return cubemap + +def cubemap_to_latlong(cubemap, res): + gy, gx = torch.meshgrid(torch.linspace( 0.0 + 1.0 / res[0], 1.0 - 1.0 / res[0], res[0], device='cuda'), + torch.linspace(-1.0 + 1.0 / res[1], 1.0 - 1.0 / res[1], res[1], device='cuda'), + indexing='ij') + + sintheta, costheta = torch.sin(gy*np.pi), torch.cos(gy*np.pi) + sinphi, cosphi = torch.sin(gx*np.pi), torch.cos(gx*np.pi) + + reflvec = torch.stack(( + sintheta*sinphi, + costheta, + -sintheta*cosphi + ), dim=-1) + return dr.texture(cubemap[None, ...], reflvec[None, ...].contiguous(), filter_mode='linear', boundary_mode='cube')[0] + +#---------------------------------------------------------------------------- +# Image scaling +#---------------------------------------------------------------------------- + +def scale_img_hwc(x : torch.Tensor, size, mag='bilinear', min='area') -> torch.Tensor: + return scale_img_nhwc(x[None, ...], size, mag, min)[0] + +def scale_img_nhwc(x : torch.Tensor, size, mag='bilinear', min='area') -> torch.Tensor: + assert (x.shape[1] >= size[0] and x.shape[2] >= size[1]) or (x.shape[1] < size[0] and x.shape[2] < size[1]), "Trying to magnify image in one dimension and minify in the other" + y = x.permute(0, 3, 1, 2) # NHWC -> NCHW + if x.shape[1] > size[0] and x.shape[2] > size[1]: # Minification, previous size was bigger + y = torch.nn.functional.interpolate(y, size, mode=min) + else: # Magnification + if mag == 'bilinear' or mag == 'bicubic': + y = torch.nn.functional.interpolate(y, size, mode=mag, align_corners=True) + else: + y = torch.nn.functional.interpolate(y, size, mode=mag) + return y.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC + +def avg_pool_nhwc(x : torch.Tensor, size) -> torch.Tensor: + y = x.permute(0, 3, 1, 2) # NHWC -> NCHW + y = torch.nn.functional.avg_pool2d(y, size) + return y.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC + +#---------------------------------------------------------------------------- +# Behaves similar to tf.segment_sum +#---------------------------------------------------------------------------- + +def segment_sum(data: torch.Tensor, segment_ids: torch.Tensor) -> torch.Tensor: + num_segments = torch.unique_consecutive(segment_ids).shape[0] + + # Repeats ids until same dimension as data + if len(segment_ids.shape) == 1: + s = torch.prod(torch.tensor(data.shape[1:], dtype=torch.int64, device='cuda')).long() + segment_ids = segment_ids.repeat_interleave(s).view(segment_ids.shape[0], *data.shape[1:]) + + assert data.shape == segment_ids.shape, "data.shape and segment_ids.shape should be equal" + + shape = [num_segments] + list(data.shape[1:]) + result = torch.zeros(*shape, dtype=torch.float32, device='cuda') + result = result.scatter_add(0, segment_ids, data) + return result + +#---------------------------------------------------------------------------- +# Matrix helpers. +#---------------------------------------------------------------------------- + +def fovx_to_fovy(fovx, aspect): + return np.arctan(np.tan(fovx / 2) / aspect) * 2.0 + +def focal_length_to_fovy(focal_length, sensor_height): + return 2 * np.arctan(0.5 * sensor_height / focal_length) + +# Reworked so this matches gluPerspective / glm::perspective, using fovy +def perspective(fovy=0.7854, aspect=1.0, n=0.1, f=1000.0, device=None): + y = np.tan(fovy / 2) + return torch.tensor([[1/(y*aspect), 0, 0, 0], + [ 0, 1/-y, 0, 0], + [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)], + [ 0, 0, -1, 0]], dtype=torch.float32, device=device) + +# Reworked so this matches gluPerspective / glm::perspective, using fovy +def perspective_offcenter(fovy, fraction, rx, ry, aspect=1.0, n=0.1, f=1000.0, device=None): + y = np.tan(fovy / 2) + + # Full frustum + R, L = aspect*y, -aspect*y + T, B = y, -y + + # Create a randomized sub-frustum + width = (R-L)*fraction + height = (T-B)*fraction + xstart = (R-L)*rx + ystart = (T-B)*ry + + l = L + xstart + r = l + width + b = B + ystart + t = b + height + + # https://www.scratchapixel.com/lessons/3d-basic-rendering/perspective-and-orthographic-projection-matrix/opengl-perspective-projection-matrix + return torch.tensor([[2/(r-l), 0, (r+l)/(r-l), 0], + [ 0, -2/(t-b), (t+b)/(t-b), 0], + [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)], + [ 0, 0, -1, 0]], dtype=torch.float32, device=device) + +def translate(x, y, z, device=None): + return torch.tensor([[1, 0, 0, x], + [0, 1, 0, y], + [0, 0, 1, z], + [0, 0, 0, 1]], dtype=torch.float32, device=device) + +def rotate_x(a, device=None): + s, c = np.sin(a), np.cos(a) + return torch.tensor([[1, 0, 0, 0], + [0, c,-s, 0], + [0, s, c, 0], + [0, 0, 0, 1]], dtype=torch.float32, device=device) + +def rotate_y(a, device=None): + s, c = np.sin(a), np.cos(a) + return torch.tensor([[ c, 0, s, 0], + [ 0, 1, 0, 0], + [-s, 0, c, 0], + [ 0, 0, 0, 1]], dtype=torch.float32, device=device) + +def scale(s, device=None): + return torch.tensor([[ s, 0, 0, 0], + [ 0, s, 0, 0], + [ 0, 0, s, 0], + [ 0, 0, 0, 1]], dtype=torch.float32, device=device) + +def lookAt(eye, at, up): + a = eye - at + w = a / torch.linalg.norm(a) + u = torch.cross(up, w) + u = u / torch.linalg.norm(u) + v = torch.cross(w, u) + translate = torch.tensor([[1, 0, 0, -eye[0]], + [0, 1, 0, -eye[1]], + [0, 0, 1, -eye[2]], + [0, 0, 0, 1]], dtype=eye.dtype, device=eye.device) + rotate = torch.tensor([[u[0], u[1], u[2], 0], + [v[0], v[1], v[2], 0], + [w[0], w[1], w[2], 0], + [0, 0, 0, 1]], dtype=eye.dtype, device=eye.device) + return rotate @ translate +# def lookAt(eye, center, up): +# f = (center - eye) +# f = f / torch.norm(f) + +# u = up / torch.norm(up) +# s = torch.cross(f, u) +# u = torch.cross(s, f) + +# result = torch.eye(4) +# result[0, 0:3] = s +# result[1, 0:3] = u +# result[2, 0:3] = -f +# result[0, 3] = -torch.dot(s, eye) +# result[1, 3] = -torch.dot(u, eye) +# result[2, 3] = torch.dot(f, eye) + +# return result + +def look_at_opengl(eye, at, up): + # 计算前向量 + forward = (at - eye) + forward = forward / torch.norm(forward) + + # 计算右向量 + right = torch.cross(up, forward) + right = right / torch.norm(right) + + # 计算实际的上向量 + up = torch.cross(forward, right) + + # 构建视图矩阵 + view_matrix = torch.eye(4) + view_matrix[0, :3] = right + view_matrix[1, :3] = up + view_matrix[2, :3] = -forward + view_matrix[:3, 3] = -eye + + # 计算 c2w 矩阵 + # c2w = torch.inverse(view_matrix) + + return view_matrix + +@torch.no_grad() +def random_rotation_translation(t, device=None): + m = np.random.normal(size=[3, 3]) + m[1] = np.cross(m[0], m[2]) + m[2] = np.cross(m[0], m[1]) + m = m / np.linalg.norm(m, axis=1, keepdims=True) + m = np.pad(m, [[0, 1], [0, 1]], mode='constant') + m[3, 3] = 1.0 + m[:3, 3] = np.random.uniform(-t, t, size=[3]) + return torch.tensor(m, dtype=torch.float32, device=device) + +@torch.no_grad() +def random_rotation(device=None): + m = np.random.normal(size=[3, 3]) + m[1] = np.cross(m[0], m[2]) + m[2] = np.cross(m[0], m[1]) + m = m / np.linalg.norm(m, axis=1, keepdims=True) + m = np.pad(m, [[0, 1], [0, 1]], mode='constant') + m[3, 3] = 1.0 + m[:3, 3] = np.array([0,0,0]).astype(np.float32) + return torch.tensor(m, dtype=torch.float32, device=device) + +#---------------------------------------------------------------------------- +# Compute focal points of a set of lines using least squares. +# handy for poorly centered datasets +#---------------------------------------------------------------------------- + +def lines_focal(o, d): + d = safe_normalize(d) + I = torch.eye(3, dtype=o.dtype, device=o.device) + S = torch.sum(d[..., None] @ torch.transpose(d[..., None], 1, 2) - I[None, ...], dim=0) + C = torch.sum((d[..., None] @ torch.transpose(d[..., None], 1, 2) - I[None, ...]) @ o[..., None], dim=0).squeeze(1) + return torch.linalg.pinv(S) @ C + +#---------------------------------------------------------------------------- +# Cosine sample around a vector N +#---------------------------------------------------------------------------- +@torch.no_grad() +def cosine_sample(N, size=None): + # construct local frame + N = N/torch.linalg.norm(N) + + dx0 = torch.tensor([0, N[2], -N[1]], dtype=N.dtype, device=N.device) + dx1 = torch.tensor([-N[2], 0, N[0]], dtype=N.dtype, device=N.device) + + dx = torch.where(dot(dx0, dx0) > dot(dx1, dx1), dx0, dx1) + #dx = dx0 if np.dot(dx0,dx0) > np.dot(dx1,dx1) else dx1 + dx = dx / torch.linalg.norm(dx) + dy = torch.cross(N,dx) + dy = dy / torch.linalg.norm(dy) + + # cosine sampling in local frame + if size is None: + phi = 2.0 * np.pi * np.random.uniform() + s = np.random.uniform() + else: + phi = 2.0 * np.pi * torch.rand(*size, 1, dtype=N.dtype, device=N.device) + s = torch.rand(*size, 1, dtype=N.dtype, device=N.device) + costheta = np.sqrt(s) + sintheta = np.sqrt(1.0 - s) + + # cartesian vector in local space + x = np.cos(phi)*sintheta + y = np.sin(phi)*sintheta + z = costheta + + # local to world + return dx*x + dy*y + N*z + +#---------------------------------------------------------------------------- +# Bilinear downsample by 2x. +#---------------------------------------------------------------------------- + +def bilinear_downsample(x : torch.tensor) -> torch.Tensor: + w = torch.tensor([[1, 3, 3, 1], [3, 9, 9, 3], [3, 9, 9, 3], [1, 3, 3, 1]], dtype=torch.float32, device=x.device) / 64.0 + w = w.expand(x.shape[-1], 1, 4, 4) + x = torch.nn.functional.conv2d(x.permute(0, 3, 1, 2), w, padding=1, stride=2, groups=x.shape[-1]) + return x.permute(0, 2, 3, 1) + +#---------------------------------------------------------------------------- +# Bilinear downsample log(spp) steps +#---------------------------------------------------------------------------- + +def bilinear_downsample(x : torch.tensor, spp) -> torch.Tensor: + w = torch.tensor([[1, 3, 3, 1], [3, 9, 9, 3], [3, 9, 9, 3], [1, 3, 3, 1]], dtype=torch.float32, device=x.device) / 64.0 + g = x.shape[-1] + w = w.expand(g, 1, 4, 4) + x = x.permute(0, 3, 1, 2) # NHWC -> NCHW + steps = int(np.log2(spp)) + for _ in range(steps): + xp = torch.nn.functional.pad(x, (1,1,1,1), mode='replicate') + x = torch.nn.functional.conv2d(xp, w, padding=0, stride=2, groups=g) + return x.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC + +#---------------------------------------------------------------------------- +# Singleton initialize GLFW +#---------------------------------------------------------------------------- + +_glfw_initialized = False +def init_glfw(): + global _glfw_initialized + try: + import glfw + glfw.ERROR_REPORTING = 'raise' + glfw.default_window_hints() + glfw.window_hint(glfw.VISIBLE, glfw.FALSE) + test = glfw.create_window(8, 8, "Test", None, None) # Create a window and see if not initialized yet + except glfw.GLFWError as e: + if e.error_code == glfw.NOT_INITIALIZED: + glfw.init() + _glfw_initialized = True + +#---------------------------------------------------------------------------- +# Image display function using OpenGL. +#---------------------------------------------------------------------------- + +_glfw_window = None +def display_image(image, title=None): + # Import OpenGL + import OpenGL.GL as gl + import glfw + + # Zoom image if requested. + image = np.asarray(image[..., 0:3]) if image.shape[-1] == 4 else np.asarray(image) + height, width, channels = image.shape + + # Initialize window. + init_glfw() + if title is None: + title = 'Debug window' + global _glfw_window + if _glfw_window is None: + glfw.default_window_hints() + _glfw_window = glfw.create_window(width, height, title, None, None) + glfw.make_context_current(_glfw_window) + glfw.show_window(_glfw_window) + glfw.swap_interval(0) + else: + glfw.make_context_current(_glfw_window) + glfw.set_window_title(_glfw_window, title) + glfw.set_window_size(_glfw_window, width, height) + + # Update window. + glfw.poll_events() + gl.glClearColor(0, 0, 0, 1) + gl.glClear(gl.GL_COLOR_BUFFER_BIT) + gl.glWindowPos2f(0, 0) + gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1) + gl_format = {3: gl.GL_RGB, 2: gl.GL_RG, 1: gl.GL_LUMINANCE}[channels] + gl_dtype = {'uint8': gl.GL_UNSIGNED_BYTE, 'float32': gl.GL_FLOAT}[image.dtype.name] + gl.glDrawPixels(width, height, gl_format, gl_dtype, image[::-1]) + glfw.swap_buffers(_glfw_window) + if glfw.window_should_close(_glfw_window): + return False + return True + +#---------------------------------------------------------------------------- +# Image save/load helper. +#---------------------------------------------------------------------------- + +def save_image(fn, x : np.ndarray): + try: + if os.path.splitext(fn)[1] == ".png": + imageio.imwrite(fn, np.clip(np.rint(x * 255.0), 0, 255).astype(np.uint8), compress_level=3) # Low compression for faster saving + else: + imageio.imwrite(fn, np.clip(np.rint(x * 255.0), 0, 255).astype(np.uint8)) + except: + print("WARNING: FAILED to save image %s" % fn) + +def save_image_raw(fn, x : np.ndarray): + try: + imageio.imwrite(fn, x) + except: + print("WARNING: FAILED to save image %s" % fn) + + +def load_image_raw(fn) -> np.ndarray: + return imageio.imread(fn) + +def load_image(fn) -> np.ndarray: + img = load_image_raw(fn) + if img.dtype == np.float32: # HDR image + return img + else: # LDR image + return img.astype(np.float32) / 255 + +#---------------------------------------------------------------------------- + +def time_to_text(x): + if x > 3600: + return "%.2f h" % (x / 3600) + elif x > 60: + return "%.2f m" % (x / 60) + else: + return "%.2f s" % x + +#---------------------------------------------------------------------------- + +def checkerboard(res, checker_size) -> np.ndarray: + tiles_y = (res[0] + (checker_size*2) - 1) // (checker_size*2) + tiles_x = (res[1] + (checker_size*2) - 1) // (checker_size*2) + check = np.kron([[1, 0] * tiles_x, [0, 1] * tiles_x] * tiles_y, np.ones((checker_size, checker_size)))*0.33 + 0.33 + check = check[:res[0], :res[1]] + return np.stack((check, check, check), axis=-1) + diff --git a/models/lrm/online_render/src/utils/texture.py b/models/lrm/online_render/src/utils/texture.py new file mode 100755 index 0000000000000000000000000000000000000000..307769ca9fe8a2fb9d3278df05f1d006d0aa5d5e --- /dev/null +++ b/models/lrm/online_render/src/utils/texture.py @@ -0,0 +1,189 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +import numpy as np +import torch +import nvdiffrast.torch as dr + +from ..models.geometry.rep_3d import util + +###################################################################################### +# Smooth pooling / mip computation with linear gradient upscaling +###################################################################################### + +class texture2d_mip(torch.autograd.Function): + @staticmethod + def forward(ctx, texture): + return util.avg_pool_nhwc(texture, (2,2)) + + @staticmethod + def backward(ctx, dout): + gy, gx = torch.meshgrid(torch.linspace(0.0 + 0.25 / dout.shape[1], 1.0 - 0.25 / dout.shape[1], dout.shape[1]*2, device="cuda"), + torch.linspace(0.0 + 0.25 / dout.shape[2], 1.0 - 0.25 / dout.shape[2], dout.shape[2]*2, device="cuda"), + indexing='ij') + uv = torch.stack((gx, gy), dim=-1) + return dr.texture(dout * 0.25, uv[None, ...].contiguous(), filter_mode='linear', boundary_mode='clamp') + +######################################################################################################## +# Simple texture class. A texture can be either +# - A 3D tensor (using auto mipmaps) +# - A list of 3D tensors (full custom mip hierarchy) +######################################################################################################## + +class Texture2D(torch.nn.Module): + # Initializes a texture from image data. + # Input can be constant value (1D array) or texture (3D array) or mip hierarchy (list of 3d arrays) + def __init__(self, init, min_max=None): + super(Texture2D, self).__init__() + + if isinstance(init, np.ndarray): + init = torch.tensor(init, dtype=torch.float32, device='cuda') + elif isinstance(init, list) and len(init) == 1: + init = init[0] + + if isinstance(init, list): + self.data = list(torch.nn.Parameter(mip.clone().detach(), requires_grad=True) for mip in init) + elif len(init.shape) == 4: + self.data = torch.nn.Parameter(init.clone().detach(), requires_grad=True) + elif len(init.shape) == 3: + self.data = torch.nn.Parameter(init[None, ...].clone().detach(), requires_grad=True) + elif len(init.shape) == 2: + self.data = torch.nn.Parameter(init[None, :, :, None].repeat(1,1,1,3).clone().detach(), requires_grad=True) + # breakpoint() + elif len(init.shape) == 1: + self.data = torch.nn.Parameter(init[None, None, None, :].clone().detach(), requires_grad=True) # Convert constant to 1x1 tensor + else: + assert False, "Invalid texture object" + + self.min_max = min_max + + # Filtered (trilinear) sample texture at a given location + def sample(self, texc, texc_deriv, filter_mode='linear-mipmap-linear'): + if isinstance(self.data, list): + out = dr.texture(self.data[0], texc, texc_deriv, mip=self.data[1:], filter_mode=filter_mode) + else: + if self.data.shape[1] > 1 and self.data.shape[2] > 1: + mips = [self.data] + while mips[-1].shape[1] > 1 and mips[-1].shape[2] > 1: + mips += [texture2d_mip.apply(mips[-1])] + out = dr.texture(mips[0], texc, texc_deriv, mip=mips[1:], filter_mode=filter_mode) + else: + out = dr.texture(self.data, texc, texc_deriv, filter_mode=filter_mode) + return out + + def getRes(self): + return self.getMips()[0].shape[1:3] + + def getChannels(self): + return self.getMips()[0].shape[3] + + def getMips(self): + if isinstance(self.data, list): + return self.data + else: + return [self.data] + + # In-place clamp with no derivative to make sure values are in valid range after training + def clamp_(self): + if self.min_max is not None: + for mip in self.getMips(): + for i in range(mip.shape[-1]): + mip[..., i].clamp_(min=self.min_max[0][i], max=self.min_max[1][i]) + + # In-place clamp with no derivative to make sure values are in valid range after training + def normalize_(self): + with torch.no_grad(): + for mip in self.getMips(): + mip = util.safe_normalize(mip) + +######################################################################################################## +# Helper function to create a trainable texture from a regular texture. The trainable weights are +# initialized with texture data as an initial guess +######################################################################################################## + +def create_trainable(init, res=None, auto_mipmaps=True, min_max=None): + with torch.no_grad(): + if isinstance(init, Texture2D): + assert isinstance(init.data, torch.Tensor) + min_max = init.min_max if min_max is None else min_max + init = init.data + elif isinstance(init, np.ndarray): + init = torch.tensor(init, dtype=torch.float32, device='cuda') + + # Pad to NHWC if needed + if len(init.shape) == 1: # Extend constant to NHWC tensor + init = init[None, None, None, :] + elif len(init.shape) == 3: + init = init[None, ...] + + # Scale input to desired resolution. + if res is not None: + init = util.scale_img_nhwc(init, res) + + # Genreate custom mipchain + if not auto_mipmaps: + mip_chain = [init.clone().detach().requires_grad_(True)] + while mip_chain[-1].shape[1] > 1 or mip_chain[-1].shape[2] > 1: + new_size = [max(mip_chain[-1].shape[1] // 2, 1), max(mip_chain[-1].shape[2] // 2, 1)] + mip_chain += [util.scale_img_nhwc(mip_chain[-1], new_size)] + return Texture2D(mip_chain, min_max=min_max) + else: + return Texture2D(init, min_max=min_max) + +######################################################################################################## +# Convert texture to and from SRGB +######################################################################################################## + +def srgb_to_rgb(texture): + return Texture2D(list(util.srgb_to_rgb(mip) for mip in texture.getMips())) + +def rgb_to_srgb(texture): + return Texture2D(list(util.rgb_to_srgb(mip) for mip in texture.getMips())) + +######################################################################################################## +# Utility functions for loading / storing a texture +######################################################################################################## + +def _load_mip2D(fn, lambda_fn=None, channels=None): + imgdata = torch.tensor(util.load_image(fn), dtype=torch.float32, device='cuda') + if channels is not None: + imgdata = imgdata[..., 0:channels] + if lambda_fn is not None: + imgdata = lambda_fn(imgdata) + return imgdata.detach().clone() + +def load_texture2D(fn, lambda_fn=None, channels=None): + base, ext = os.path.splitext(fn) + if os.path.exists(base + "_0" + ext): + mips = [] + while os.path.exists(base + ("_%d" % len(mips)) + ext): + mips += [_load_mip2D(base + ("_%d" % len(mips)) + ext, lambda_fn, channels)] + return Texture2D(mips) + else: + return Texture2D(_load_mip2D(fn, lambda_fn, channels)) + +def _save_mip2D(fn, mip, mipidx, lambda_fn): + if lambda_fn is not None: + data = lambda_fn(mip).detach().cpu().numpy() + else: + data = mip.detach().cpu().numpy() + + if mipidx is None: + util.save_image(fn, data) + else: + base, ext = os.path.splitext(fn) + util.save_image(base + ("_%d" % mipidx) + ext, data) + +def save_texture2D(fn, tex, lambda_fn=None): + if isinstance(tex.data, list): + for i, mip in enumerate(tex.data): + _save_mip2D(fn, mip[0,...], i, lambda_fn) + else: + _save_mip2D(fn, tex.data[0,...], None, lambda_fn) diff --git a/models/lrm/online_render/src/utils/train_util.py b/models/lrm/online_render/src/utils/train_util.py new file mode 100755 index 0000000000000000000000000000000000000000..2e65421bffa8cc42c1517e86f2dfd8183caf52ab --- /dev/null +++ b/models/lrm/online_render/src/utils/train_util.py @@ -0,0 +1,26 @@ +import importlib + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) diff --git a/models/lrm/utils/__init__.py b/models/lrm/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/lrm/utils/camera_util.py b/models/lrm/utils/camera_util.py new file mode 100644 index 0000000000000000000000000000000000000000..335bcba4957b69c5f58b7e0f3524121c218bf2e6 --- /dev/null +++ b/models/lrm/utils/camera_util.py @@ -0,0 +1,149 @@ +import torch +import torch.nn.functional as F +import numpy as np + + +def pad_camera_extrinsics_4x4(extrinsics): + if extrinsics.shape[-2] == 4: + return extrinsics + padding = torch.tensor([[0, 0, 0, 1]]).to(extrinsics) + if extrinsics.ndim == 3: + padding = padding.unsqueeze(0).repeat(extrinsics.shape[0], 1, 1) + extrinsics = torch.cat([extrinsics, padding], dim=-2) + return extrinsics + + +def center_looking_at_camera_pose(camera_position: torch.Tensor, look_at: torch.Tensor = None, up_world: torch.Tensor = None): + """ + Create OpenGL camera extrinsics from camera locations and look-at position. + + camera_position: (M, 3) or (3,) + look_at: (3) + up_world: (3) + return: (M, 3, 4) or (3, 4) + """ + # by default, looking at the origin and world up is z-axis + if look_at is None: + look_at = torch.tensor([0, 0, 0], dtype=torch.float32) + if up_world is None: + up_world = torch.tensor([0, 0, 1], dtype=torch.float32) + if camera_position.ndim == 2: + look_at = look_at.unsqueeze(0).repeat(camera_position.shape[0], 1) + up_world = up_world.unsqueeze(0).repeat(camera_position.shape[0], 1) + + # OpenGL camera: z-backward, x-right, y-up + z_axis = camera_position - look_at + z_axis = F.normalize(z_axis, dim=-1).float() + x_axis = torch.linalg.cross(up_world, z_axis, dim=-1) + x_axis = F.normalize(x_axis, dim=-1).float() + y_axis = torch.linalg.cross(z_axis, x_axis, dim=-1) + y_axis = F.normalize(y_axis, dim=-1).float() + + extrinsics = torch.stack([x_axis, y_axis, z_axis, camera_position], dim=-1) + extrinsics = pad_camera_extrinsics_4x4(extrinsics) + return extrinsics + + +def spherical_camera_pose(azimuths: np.ndarray, elevations: np.ndarray, radius=2.5): + azimuths = np.deg2rad(azimuths) + elevations = np.deg2rad(elevations) + + xs = radius * np.cos(elevations) * np.cos(azimuths) + ys = radius * np.cos(elevations) * np.sin(azimuths) + zs = radius * np.sin(elevations) + + cam_locations = np.stack([xs, ys, zs], axis=-1) + cam_locations = torch.from_numpy(cam_locations).float() + + c2ws = center_looking_at_camera_pose(cam_locations) + return c2ws + + +def get_circular_camera_poses(M=120, radius=2.5, elevation=30.0): + # M: number of circular views + # radius: camera dist to center + # elevation: elevation degrees of the camera + # return: (M, 4, 4) + assert M > 0 and radius > 0 + + elevation = np.deg2rad(elevation) + + camera_positions = [] + for i in range(M): + azimuth = 2 * np.pi * i / M + x = radius * np.cos(elevation) * np.cos(azimuth) + y = radius * np.cos(elevation) * np.sin(azimuth) + z = radius * np.sin(elevation) + camera_positions.append([x, y, z]) + camera_positions = np.array(camera_positions) + camera_positions = torch.from_numpy(camera_positions).float() + extrinsics = center_looking_at_camera_pose(camera_positions) + return extrinsics + + +def FOV_to_intrinsics(fov, device='cpu'): + """ + Creates a 3x3 camera intrinsics matrix from the camera field of view, specified in degrees. + Note the intrinsics are returned as normalized by image size, rather than in pixel units. + Assumes principal point is at image center. + """ + focal_length = 0.5 / np.tan(np.deg2rad(fov) * 0.5) + intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device) + return intrinsics + + +def get_zero123plus_input_cameras(batch_size=1, radius=4.0, fov=30.0): + """ + Get the input camera parameters. + """ + azimuths = np.array([30, 90, 150, 210, 270, 330]).astype(float) + elevations = np.array([20, -10, 20, -10, 20, -10]).astype(float) + + c2ws = spherical_camera_pose(azimuths, elevations, radius) + c2ws = c2ws.float().flatten(-2) + + Ks = FOV_to_intrinsics(fov).unsqueeze(0).repeat(6, 1, 1).float().flatten(-2) + + extrinsics = c2ws[:, :12] + intrinsics = torch.stack([Ks[:, 0], Ks[:, 4], Ks[:, 2], Ks[:, 5]], dim=-1) + cameras = torch.cat([extrinsics, intrinsics], dim=-1) + + return cameras.unsqueeze(0).repeat(batch_size, 1, 1) + + +def get_flux_input_cameras(batch_size=1, radius=4.0, fov=30.0): + """ + Get the input camera parameters. + """ + azimuths = np.array([0, 90, 180, 270]).astype(float) + elevations = np.array([5, 5, 5, 5]).astype(float) + + c2ws = spherical_camera_pose(azimuths, elevations, radius) + c2ws = c2ws.float().flatten(-2) + + Ks = FOV_to_intrinsics(fov).unsqueeze(0).repeat(4, 1, 1).float().flatten(-2) + + extrinsics = c2ws[:, :12] + intrinsics = torch.stack([Ks[:, 0], Ks[:, 4], Ks[:, 2], Ks[:, 5]], dim=-1) + cameras = torch.cat([extrinsics, intrinsics], dim=-1) + + return cameras.unsqueeze(0).repeat(batch_size, 1, 1) + +def get_custom_zero123plus_input_cameras(batch_size=1, radius=4.0, fov=30.0): + """ + Get the input camera parameters. + """ + azimuths = np.array([0, 90, 180, 270]).astype(float) + # azimuths = np.array([270, 180, 90, 0]).astype(float) + elevations = np.array([5, 5, 5, 5]).astype(float) + + c2ws = spherical_camera_pose(azimuths, elevations, radius) + c2ws = c2ws.float().flatten(-2) + + Ks = FOV_to_intrinsics(fov).unsqueeze(0).repeat(4, 1, 1).float().flatten(-2) + + extrinsics = c2ws[:, :12] + intrinsics = torch.stack([Ks[:, 0], Ks[:, 4], Ks[:, 2], Ks[:, 5]], dim=-1) + cameras = torch.cat([extrinsics, intrinsics], dim=-1) + + return cameras.unsqueeze(0).repeat(batch_size, 1, 1) diff --git a/models/lrm/utils/infer_util.py b/models/lrm/utils/infer_util.py new file mode 100644 index 0000000000000000000000000000000000000000..f2faf2bf3b12d4af7b33cb2292da2b5ed62eb52e --- /dev/null +++ b/models/lrm/utils/infer_util.py @@ -0,0 +1,97 @@ +import os +import imageio +import rembg +import torch +import numpy as np +import PIL.Image +from PIL import Image +from typing import Any + + +def remove_background(image: PIL.Image.Image, + rembg_session: Any = None, + force: bool = False, + **rembg_kwargs, +) -> PIL.Image.Image: + do_remove = True + if image.mode == "RGBA" and image.getextrema()[3][0] < 255: + do_remove = False + do_remove = do_remove or force + if do_remove: + image = rembg.remove(image, session=rembg_session, **rembg_kwargs) + return image + + +def resize_foreground( + image: PIL.Image.Image, + ratio: float, +) -> PIL.Image.Image: + image = np.array(image) + assert image.shape[-1] == 4 + alpha = np.where(image[..., 3] > 0) + y1, y2, x1, x2 = ( + alpha[0].min(), + alpha[0].max(), + alpha[1].min(), + alpha[1].max(), + ) + # crop the foreground + fg = image[y1:y2, x1:x2] + # pad to square + size = max(fg.shape[0], fg.shape[1]) + ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2 + ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0 + new_image = np.pad( + fg, + ((ph0, ph1), (pw0, pw1), (0, 0)), + mode="constant", + constant_values=((0, 0), (0, 0), (0, 0)), + ) + + # compute padding according to the ratio + new_size = int(new_image.shape[0] / ratio) + # pad to size, double side + ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2 + ph1, pw1 = new_size - size - ph0, new_size - size - pw0 + new_image = np.pad( + new_image, + ((ph0, ph1), (pw0, pw1), (0, 0)), + mode="constant", + constant_values=((0, 0), (0, 0), (0, 0)), + ) + new_image = PIL.Image.fromarray(new_image) + return new_image + + +def images_to_video( + images: torch.Tensor, + output_path: str, + fps: int = 30, +) -> None: + # images: (N, C, H, W) + video_dir = os.path.dirname(output_path) + video_name = os.path.basename(output_path) + os.makedirs(video_dir, exist_ok=True) + + frames = [] + for i in range(len(images)): + frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) + assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \ + f"Frame shape mismatch: {frame.shape} vs {images.shape}" + assert frame.min() >= 0 and frame.max() <= 255, \ + f"Frame value out of range: {frame.min()} ~ {frame.max()}" + frames.append(frame) + imageio.mimwrite(output_path, np.stack(frames), fps=fps, quality=10) + + +def save_video( + frames: torch.Tensor, + output_path: str, + fps: int = 30, +) -> None: + # images: (N, C, H, W) + frames = [(frame.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) for frame in frames] + writer = imageio.get_writer(output_path, fps=fps) + for frame in frames: + writer.append_data(frame) + writer.close() \ No newline at end of file diff --git a/models/lrm/utils/material.py b/models/lrm/utils/material.py new file mode 100644 index 0000000000000000000000000000000000000000..c03e26e83471bf7498b60012bf21c29f97b625dd --- /dev/null +++ b/models/lrm/utils/material.py @@ -0,0 +1,197 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +import numpy as np +import torch + +from models.lrm.models.geometry.rep_3d import util +from . import texture + +###################################################################################### +# Wrapper to make materials behave like a python dict, but register textures as +# torch.nn.Module parameters. +###################################################################################### +class Material(torch.nn.Module): + def __init__(self, mat_dict): + super(Material, self).__init__() + self.mat_keys = set() + for key in mat_dict.keys(): + self.mat_keys.add(key) + self[key] = mat_dict[key] + + def __contains__(self, key): + return hasattr(self, key) + + def __getitem__(self, key): + return getattr(self, key) + + def __setitem__(self, key, val): + self.mat_keys.add(key) + setattr(self, key, val) + + def __delitem__(self, key): + self.mat_keys.remove(key) + delattr(self, key) + + def keys(self): + return self.mat_keys + +###################################################################################### +# .mtl material format loading / storing +###################################################################################### +@torch.no_grad() +def load_mtl(fn, clear_ks=True): + import re + mtl_path = os.path.dirname(fn) + + # Read file + with open(fn, 'r') as f: + lines = f.readlines() + + # Parse materials + materials = [] + for line in lines: + split_line = re.split(' +|\t+|\n+', line.strip()) + prefix = split_line[0].lower() + data = split_line[1:] + if 'newmtl' in prefix: + material = Material({'name' : data[0]}) + materials += [material] + elif materials: + if 'map_d' in prefix: + # 设置透明度为1.0,即完全不透明 + material['d'] = torch.tensor(1.0, dtype=torch.float32, device='cuda') + elif 'map_ke' in prefix: + # 设置自发光为0 + material['Ke'] = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device='cuda') + elif 'bsdf' in prefix or 'map_kd' in prefix or 'map_ks' in prefix or 'bump' in prefix: + material[prefix] = data[0] + else: + material[prefix] = torch.tensor(tuple(float(d) for d in data), dtype=torch.float32, device='cuda') + + # Convert everything to textures. Our code expects 'kd' and 'ks' to be texture maps. So replace constants with 1x1 maps + for mat in materials: + if not 'bsdf' in mat: + mat['bsdf'] = 'pbr' + + if 'map_kd' in mat: + mat['kd'] = texture.load_texture2D(os.path.join(mtl_path, mat['map_kd'])) + else: + mat['kd'] = texture.Texture2D(mat['kd']) + + if 'map_ks' in mat: + mat['ks'] = texture.load_texture2D(os.path.join(mtl_path, mat['map_ks']), channels=3) + else: + mat['ks'] = texture.Texture2D(mat['ks']) + + if 'bump' in mat: + mat['normal'] = texture.load_texture2D(os.path.join(mtl_path, mat['bump']), lambda_fn=lambda x: x * 2 - 1, channels=3) + + # Convert Kd from sRGB to linear RGB + mat['kd'] = texture.srgb_to_rgb(mat['kd']) + + if clear_ks: + # Override ORM occlusion (red) channel by zeros. We hijack this channel + for mip in mat['ks'].getMips(): + mip[..., 0] = 0.0 + + return materials + +@torch.no_grad() +def save_mtl(fn, material): + folder = os.path.dirname(fn) + with open(fn, "w") as f: + f.write('newmtl defaultMat\n') + if material is not None: + f.write('bsdf %s\n' % material['bsdf']) + if 'kd' in material.keys(): + f.write('map_Kd texture_kd.png\n') + texture.save_texture2D(os.path.join(folder, 'texture_kd.png'), texture.rgb_to_srgb(material['kd'])) + if 'ks' in material.keys(): + f.write('map_Ks texture_ks.png\n') + texture.save_texture2D(os.path.join(folder, 'texture_ks.png'), material['ks']) + if 'normal' in material.keys(): + f.write('bump texture_n.png\n') + texture.save_texture2D(os.path.join(folder, 'texture_n.png'), material['normal'], lambda_fn=lambda x:(util.safe_normalize(x)+1)*0.5) + else: + f.write('Kd 1 1 1\n') + f.write('Ks 0 0 0\n') + f.write('Ka 0 0 0\n') + f.write('Tf 1 1 1\n') + f.write('Ni 1\n') + f.write('Ns 0\n') + +###################################################################################### +# Merge multiple materials into a single uber-material +###################################################################################### + +def _upscale_replicate(x, full_res): + x = x.permute(0, 3, 1, 2) + x = torch.nn.functional.pad(x, (0, full_res[1] - x.shape[3], 0, full_res[0] - x.shape[2]), 'replicate') + return x.permute(0, 2, 3, 1).contiguous() + +def merge_materials(materials, texcoords, tfaces, mfaces): + assert len(materials) > 0 + for mat in materials: + assert mat['bsdf'] == materials[0]['bsdf'], "All materials must have the same BSDF (uber shader)" + assert ('normal' in mat) is ('normal' in materials[0]), "All materials must have either normal map enabled or disabled" + + uber_material = Material({ + 'name' : 'uber_material', + 'bsdf' : materials[0]['bsdf'], + }) + + textures = ['kd', 'ks', 'normal'] + + # Find maximum texture resolution across all materials and textures + max_res = None + for mat in materials: + for tex in textures: + tex_res = np.array(mat[tex].getRes()) if tex in mat else np.array([1, 1]) + max_res = np.maximum(max_res, tex_res) if max_res is not None else tex_res + + # Compute size of compund texture and round up to nearest PoT + full_res = 2**np.ceil(np.log2(max_res * np.array([1, len(materials)]))).astype(int) + + # Normalize texture resolution across all materials & combine into a single large texture + for tex in textures: + if tex in materials[0]: + # breakpoint() + tex_data_list = [] + for mat in materials: + if tex in mat: + scaled_tex = util.scale_img_nhwc(mat[tex].data, tuple(max_res)) + if scaled_tex.shape[-1] != 3: + scaled_tex = scaled_tex[:, :, :, :3] + tex_data_list.append(scaled_tex) + # tex_data = torch.cat(tuple(util.scale_img_nhwc(mat[tex].data, tuple(max_res)) for mat in materials), dim=2) # Lay out all textures horizontally, NHWC so dim2 is x + tex_data = torch.cat(tuple(tex_data_list), dim=2) # 将所有纹理水平排列,NHWC 的 dim2 是 x 轴 + tex_data = _upscale_replicate(tex_data, full_res) + uber_material[tex] = texture.Texture2D(tex_data) + + # Compute scaling values for used / unused texture area + s_coeff = [full_res[0] / max_res[0], full_res[1] / max_res[1]] + + # Recompute texture coordinates to cooincide with new composite texture + new_tverts = {} + new_tverts_data = [] + for fi in range(len(tfaces)): + matIdx = mfaces[fi] + for vi in range(3): + ti = tfaces[fi][vi] + if not (ti in new_tverts): + new_tverts[ti] = {} + if not (matIdx in new_tverts[ti]): # create new vertex + new_tverts_data.append([(matIdx + texcoords[ti][0]) / s_coeff[1], texcoords[ti][1] / s_coeff[0]]) # Offset texture coodrinate (x direction) by material id & scale to local space. Note, texcoords are (u,v) but texture is stored (w,h) so the indexes swap here + new_tverts[ti][matIdx] = len(new_tverts_data) - 1 + tfaces[fi][vi] = new_tverts[ti][matIdx] # reindex vertex + + return uber_material, new_tverts_data, tfaces + diff --git a/models/lrm/utils/mesh.py b/models/lrm/utils/mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..b2738d5f3a2b6fdf125ce83578abb1025d8af494 --- /dev/null +++ b/models/lrm/utils/mesh.py @@ -0,0 +1,255 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +import numpy as np +import torch + +from . import obj +from models.lrm.models.geometry.rep_3d import util + +###################################################################################### +# Base mesh class +###################################################################################### +class Mesh: + def __init__(self, v_pos=None, t_pos_idx=None, v_nrm=None, t_nrm_idx=None, v_tex=None, t_tex_idx=None, v_tng=None, t_tng_idx=None, material=None, base=None): + self.v_pos = v_pos + self.v_nrm = v_nrm + self.v_tex = v_tex + self.v_tng = v_tng + self.t_pos_idx = t_pos_idx + self.t_nrm_idx = t_nrm_idx + self.t_tex_idx = t_tex_idx + self.t_tng_idx = t_tng_idx + self.material = material + + if base is not None: + self.copy_none(base) + + def copy_none(self, other): + if self.v_pos is None: + self.v_pos = other.v_pos + if self.t_pos_idx is None: + self.t_pos_idx = other.t_pos_idx + if self.v_nrm is None: + self.v_nrm = other.v_nrm + if self.t_nrm_idx is None: + self.t_nrm_idx = other.t_nrm_idx + if self.v_tex is None: + self.v_tex = other.v_tex + if self.t_tex_idx is None: + self.t_tex_idx = other.t_tex_idx + if self.v_tng is None: + self.v_tng = other.v_tng + if self.t_tng_idx is None: + self.t_tng_idx = other.t_tng_idx + if self.material is None: + self.material = other.material + + def clone(self): + out = Mesh(base=self) + if out.v_pos is not None: + out.v_pos = out.v_pos.clone().detach() + if out.t_pos_idx is not None: + out.t_pos_idx = out.t_pos_idx.clone().detach() + if out.v_nrm is not None: + out.v_nrm = out.v_nrm.clone().detach() + if out.t_nrm_idx is not None: + out.t_nrm_idx = out.t_nrm_idx.clone().detach() + if out.v_tex is not None: + out.v_tex = out.v_tex.clone().detach() + if out.t_tex_idx is not None: + out.t_tex_idx = out.t_tex_idx.clone().detach() + if out.v_tng is not None: + out.v_tng = out.v_tng.clone().detach() + if out.t_tng_idx is not None: + out.t_tng_idx = out.t_tng_idx.clone().detach() + return out + def rotate_x_90(self): + # 定义绕X轴旋转90度的旋转矩阵 + rotate_x = torch.tensor([[1, 0, 0, 0], + [0, 0, 1, 0], + [0, -1, 0, 0], + [0, 0, 0, 1]], dtype=torch.float32, device=self.v_pos.device) + + # 将旋转矩阵应用到顶点坐标 + if self.v_pos is not None: + v_pos_homo = torch.cat((self.v_pos, torch.ones(self.v_pos.shape[0], 1, device=self.v_pos.device)), dim=1) + v_pos_rotated = v_pos_homo @ rotate_x.T + self.v_pos = v_pos_rotated[:, :3] + + # 将旋转矩阵应用到法线 + if self.v_nrm is not None: + v_nrm_homo = torch.cat((self.v_nrm, torch.zeros(self.v_nrm.shape[0], 1, device=self.v_nrm.device)), dim=1) + v_nrm_rotated = v_nrm_homo @ rotate_x.T + self.v_nrm = v_nrm_rotated[:, :3] +###################################################################################### +# Mesh loeading helper +###################################################################################### + +def load_mesh(filename, mtl_override=None): + name, ext = os.path.splitext(filename) + if ext == ".obj": + return obj.load_obj(filename, clear_ks=True, mtl_override=mtl_override) + assert False, "Invalid mesh file extension" + +###################################################################################### +# Compute AABB +###################################################################################### +def aabb(mesh): + return torch.min(mesh.v_pos, dim=0).values, torch.max(mesh.v_pos, dim=0).values + +###################################################################################### +# Compute unique edge list from attribute/vertex index list +###################################################################################### +def compute_edges(attr_idx, return_inverse=False): + with torch.no_grad(): + # Create all edges, packed by triangle + all_edges = torch.cat(( + torch.stack((attr_idx[:, 0], attr_idx[:, 1]), dim=-1), + torch.stack((attr_idx[:, 1], attr_idx[:, 2]), dim=-1), + torch.stack((attr_idx[:, 2], attr_idx[:, 0]), dim=-1), + ), dim=-1).view(-1, 2) + + # Swap edge order so min index is always first + order = (all_edges[:, 0] > all_edges[:, 1]).long().unsqueeze(dim=1) + sorted_edges = torch.cat(( + torch.gather(all_edges, 1, order), + torch.gather(all_edges, 1, 1 - order) + ), dim=-1) + + # Eliminate duplicates and return inverse mapping + return torch.unique(sorted_edges, dim=0, return_inverse=return_inverse) + +###################################################################################### +# Compute unique edge to face mapping from attribute/vertex index list +###################################################################################### +def compute_edge_to_face_mapping(attr_idx, return_inverse=False): + with torch.no_grad(): + # Get unique edges + # Create all edges, packed by triangle + all_edges = torch.cat(( + torch.stack((attr_idx[:, 0], attr_idx[:, 1]), dim=-1), + torch.stack((attr_idx[:, 1], attr_idx[:, 2]), dim=-1), + torch.stack((attr_idx[:, 2], attr_idx[:, 0]), dim=-1), + ), dim=-1).view(-1, 2) + + # Swap edge order so min index is always first + order = (all_edges[:, 0] > all_edges[:, 1]).long().unsqueeze(dim=1) + sorted_edges = torch.cat(( + torch.gather(all_edges, 1, order), + torch.gather(all_edges, 1, 1 - order) + ), dim=-1) + + # Elliminate duplicates and return inverse mapping + unique_edges, idx_map = torch.unique(sorted_edges, dim=0, return_inverse=True) + + tris = torch.arange(attr_idx.shape[0]).repeat_interleave(3).cuda() + + tris_per_edge = torch.zeros((unique_edges.shape[0], 2), dtype=torch.int64).cuda() + + # Compute edge to face table + mask0 = order[:,0] == 0 + mask1 = order[:,0] == 1 + tris_per_edge[idx_map[mask0], 0] = tris[mask0] + tris_per_edge[idx_map[mask1], 1] = tris[mask1] + + return tris_per_edge + +###################################################################################### +# Align base mesh to reference mesh:move & rescale to match bounding boxes. +###################################################################################### +def unit_size(mesh): + with torch.no_grad(): + vmin, vmax = aabb(mesh) + scale = 2 / torch.max(vmax - vmin).item() + v_pos = mesh.v_pos - (vmax + vmin) / 2 # Center mesh on origin + v_pos = v_pos * scale # Rescale to unit size + + return Mesh(v_pos, base=mesh) + +###################################################################################### +# Center & scale mesh for rendering +###################################################################################### +def center_by_reference(base_mesh, ref_aabb, scale): + center = (ref_aabb[0] + ref_aabb[1]) * 0.5 + scale = scale / torch.max(ref_aabb[1] - ref_aabb[0]).item() + v_pos = (base_mesh.v_pos - center[None, ...]) * scale + return Mesh(v_pos, base=base_mesh) + +###################################################################################### +# Simple smooth vertex normal computation +###################################################################################### +def auto_normals(imesh): + + i0 = imesh.t_pos_idx[:, 0] + i1 = imesh.t_pos_idx[:, 1] + i2 = imesh.t_pos_idx[:, 2] + + v0 = imesh.v_pos[i0, :] + v1 = imesh.v_pos[i1, :] + v2 = imesh.v_pos[i2, :] + + face_normals = torch.cross(v1 - v0, v2 - v0) + + # Splat face normals to vertices + v_nrm = torch.zeros_like(imesh.v_pos) + v_nrm.scatter_add_(0, i0[:, None].repeat(1,3), face_normals) + v_nrm.scatter_add_(0, i1[:, None].repeat(1,3), face_normals) + v_nrm.scatter_add_(0, i2[:, None].repeat(1,3), face_normals) + + # Normalize, replace zero (degenerated) normals with some default value + v_nrm = torch.where(util.dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device='cuda')) + v_nrm = util.safe_normalize(v_nrm) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(v_nrm)) + + return Mesh(v_nrm=v_nrm, t_nrm_idx=imesh.t_pos_idx, base=imesh) + +###################################################################################### +# Compute tangent space from texture map coordinates +# Follows http://www.mikktspace.com/ conventions +###################################################################################### +def compute_tangents(imesh): + vn_idx = [None] * 3 + pos = [None] * 3 + tex = [None] * 3 + for i in range(0,3): + pos[i] = imesh.v_pos[imesh.t_pos_idx[:, i]] + tex[i] = imesh.v_tex[imesh.t_tex_idx[:, i]] + vn_idx[i] = imesh.t_nrm_idx[:, i] + + tangents = torch.zeros_like(imesh.v_nrm) + + # Compute tangent space for each triangle + uve1 = tex[1] - tex[0] + uve2 = tex[2] - tex[0] + pe1 = pos[1] - pos[0] + pe2 = pos[2] - pos[0] + + nom = (pe1 * uve2[..., 1:2] - pe2 * uve1[..., 1:2]) + denom = (uve1[..., 0:1] * uve2[..., 1:2] - uve1[..., 1:2] * uve2[..., 0:1]) + + # Avoid division by zero for degenerated texture coordinates + tang = nom / torch.where(denom > 0.0, torch.clamp(denom, min=1e-6), torch.clamp(denom, max=-1e-6)) + + # Update all 3 vertices + for i in range(0,3): + idx = vn_idx[i][:, None].repeat(1,3) + tangents.scatter_add_(0, idx, tang) # tangents[n_i] = tangents[n_i] + tang + + # Normalize and make sure tangent is perpendicular to normal + tangents = util.safe_normalize(tangents) + tangents = util.safe_normalize(tangents - util.dot(tangents, imesh.v_nrm) * imesh.v_nrm) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(tangents)) + + return Mesh(v_tng=tangents, t_tng_idx=imesh.t_nrm_idx, base=imesh) diff --git a/models/lrm/utils/mesh_util.py b/models/lrm/utils/mesh_util.py new file mode 100644 index 0000000000000000000000000000000000000000..f6659b3f7e2015b3de3f751e7f035bdcb9dae416 --- /dev/null +++ b/models/lrm/utils/mesh_util.py @@ -0,0 +1,233 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +import xatlas +import trimesh +import cv2 +import numpy as np +import nvdiffrast.torch as dr +from PIL import Image + +def generate_obj_with_mtl(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, texmap_hxwx3, fname): + pointnp_px3 = pointnp_px3 @ np.array([[1, 0, 0], [0, 1, 0], [0, 0, -1]]) + import os + fol, na = os.path.split(fname) + na, _ = os.path.splitext(na) + + matname = '%s/%s.mtl' % (fol, na) + fid = open(matname, 'w') + fid.write('newmtl material_0\n') + fid.write('Kd 1 1 1\n') + fid.write('Ka 0 0 0\n') + fid.write('Ks 0.4 0.4 0.4\n') + fid.write('Ns 10\n') + fid.write('illum 2\n') + fid.write('map_Kd %s.png\n' % na) + fid.close() + #### + + fid = open(fname, 'w') + fid.write('mtllib %s.mtl\n' % na) + + for pidx, p in enumerate(pointnp_px3): + pp = p + fid.write('v %f %f %f\n' % (pp[0], pp[1], pp[2])) + + for pidx, p in enumerate(tcoords_px2): + pp = p + fid.write('vt %f %f\n' % (pp[0], pp[1])) + + fid.write('usemtl material_0\n') + for i, f in enumerate(facenp_fx3): + f1 = f + 1 + f2 = facetex_fx3[i] + 1 + fid.write('f %d/%d %d/%d %d/%d\n' % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2])) + fid.close() + + # save texture map + lo, hi = 0, 1 + img = np.asarray(texmap_hxwx3, dtype=np.float32) + img = (img - lo) * (255 / (hi - lo)) + img = img.clip(0, 255) + mask = np.sum(img.astype(np.float32), axis=-1, keepdims=True) + mask = (mask <= 3.0).astype(np.float32) + kernel = np.ones((3, 3), 'uint8') + dilate_img = cv2.dilate(img, kernel, iterations=1) + img = img * (1 - mask) + dilate_img * mask + img = img.clip(0, 255).astype(np.uint8) + Image.fromarray(np.ascontiguousarray(img[::-1, :, :]), 'RGB').save(f'{fol}/{na}.png') + + mesh = trimesh.load(fname) + return mesh + + +def save_obj(pointnp_px3, facenp_fx3, colornp_px3, fpath): + + pointnp_px3 = pointnp_px3 @ np.array([[1, 0, 0], [0, 1, 0], [0, 0, -1]]) + facenp_fx3 = facenp_fx3[:, [2, 1, 0]] + + mesh = trimesh.Trimesh( + vertices=pointnp_px3, + faces=facenp_fx3, + vertex_colors=colornp_px3, + ) + mesh.export(fpath, 'obj') + + +def save_glb(pointnp_px3, facenp_fx3, colornp_px3, fpath): + + pointnp_px3 = pointnp_px3 @ np.array([[-1, 0, 0], [0, 1, 0], [0, 0, -1]]) + + mesh = trimesh.Trimesh( + vertices=pointnp_px3, + faces=facenp_fx3, + vertex_colors=colornp_px3, + ) + mesh.export(fpath, 'glb') + + +def save_obj_with_mtl(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, texmap_hxwx3, fname): + import os + fol, na = os.path.split(fname) + na, _ = os.path.splitext(na) + + matname = '%s/%s.mtl' % (fol, na) + fid = open(matname, 'w') + fid.write('newmtl material_0\n') + fid.write('Kd 1 1 1\n') + fid.write('Ka 0 0 0\n') + fid.write('Ks 0.4 0.4 0.4\n') + fid.write('Ns 10\n') + fid.write('illum 2\n') + fid.write('map_Kd %s.png\n' % na) + fid.close() + #### + + fid = open(fname, 'w') + fid.write('mtllib %s.mtl\n' % na) + + for pidx, p in enumerate(pointnp_px3): + pp = p + fid.write('v %f %f %f\n' % (pp[0], pp[1], pp[2])) + + for pidx, p in enumerate(tcoords_px2): + pp = p + fid.write('vt %f %f\n' % (pp[0], pp[1])) + + fid.write('usemtl material_0\n') + for i, f in enumerate(facenp_fx3): + f1 = f + 1 + f2 = facetex_fx3[i] + 1 + fid.write('f %d/%d %d/%d %d/%d\n' % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2])) + fid.close() + + # save texture map + lo, hi = 0, 1 + img = np.asarray(texmap_hxwx3, dtype=np.float32) + img = (img - lo) * (255 / (hi - lo)) + img = img.clip(0, 255) + mask = np.sum(img.astype(np.float32), axis=-1, keepdims=True) + mask = (mask <= 3.0).astype(np.float32) + kernel = np.ones((3, 3), 'uint8') + dilate_img = cv2.dilate(img, kernel, iterations=1) + img = img * (1 - mask) + dilate_img * mask + img = img.clip(0, 255).astype(np.uint8) + Image.fromarray(np.ascontiguousarray(img[::-1, :, :]), 'RGB').save(f'{fol}/{na}.png') + + +def loadobj(meshfile): + v = [] + f = [] + meshfp = open(meshfile, 'r') + for line in meshfp.readlines(): + data = line.strip().split(' ') + data = [da for da in data if len(da) > 0] + if len(data) != 4: + continue + if data[0] == 'v': + v.append([float(d) for d in data[1:]]) + if data[0] == 'f': + data = [da.split('/')[0] for da in data] + f.append([int(d) for d in data[1:]]) + meshfp.close() + + # torch need int64 + facenp_fx3 = np.array(f, dtype=np.int64) - 1 + pointnp_px3 = np.array(v, dtype=np.float32) + return pointnp_px3, facenp_fx3 + + +def loadobjtex(meshfile): + v = [] + vt = [] + f = [] + ft = [] + meshfp = open(meshfile, 'r') + for line in meshfp.readlines(): + data = line.strip().split(' ') + data = [da for da in data if len(da) > 0] + if not ((len(data) == 3) or (len(data) == 4) or (len(data) == 5)): + continue + if data[0] == 'v': + assert len(data) == 4 + + v.append([float(d) for d in data[1:]]) + if data[0] == 'vt': + if len(data) == 3 or len(data) == 4: + vt.append([float(d) for d in data[1:3]]) + if data[0] == 'f': + data = [da.split('/') for da in data] + if len(data) == 4: + f.append([int(d[0]) for d in data[1:]]) + ft.append([int(d[1]) for d in data[1:]]) + elif len(data) == 5: + idx1 = [1, 2, 3] + data1 = [data[i] for i in idx1] + f.append([int(d[0]) for d in data1]) + ft.append([int(d[1]) for d in data1]) + idx2 = [1, 3, 4] + data2 = [data[i] for i in idx2] + f.append([int(d[0]) for d in data2]) + ft.append([int(d[1]) for d in data2]) + meshfp.close() + + # torch need int64 + facenp_fx3 = np.array(f, dtype=np.int64) - 1 + ftnp_fx3 = np.array(ft, dtype=np.int64) - 1 + pointnp_px3 = np.array(v, dtype=np.float32) + uvs = np.array(vt, dtype=np.float32) + return pointnp_px3, facenp_fx3, uvs, ftnp_fx3 + + +# ============================================================================================== +def interpolate(attr, rast, attr_idx, rast_db=None): + return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=None if rast_db is None else 'all') + + +def xatlas_uvmap(ctx, mesh_v, mesh_pos_idx, resolution): + vmapping, indices, uvs = xatlas.parametrize(mesh_v.detach().cpu().numpy(), mesh_pos_idx.detach().cpu().numpy()) + + # Convert to tensors + indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64) + + uvs = torch.tensor(uvs, dtype=torch.float32, device=mesh_v.device) + mesh_tex_idx = torch.tensor(indices_int64, dtype=torch.int64, device=mesh_v.device) + # mesh_v_tex. ture + uv_clip = uvs[None, ...] * 2.0 - 1.0 + + # pad to four component coordinate + uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[..., 0:1]), torch.ones_like(uv_clip[..., 0:1])), dim=-1) + + # rasterize + rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), (resolution, resolution)) + + # Interpolate world space position + gb_pos, _ = interpolate(mesh_v[None, ...], rast, mesh_pos_idx.int()) + mask = rast[..., 3:4] > 0 + return uvs, mesh_tex_idx, gb_pos, mask diff --git a/models/lrm/utils/obj.py b/models/lrm/utils/obj.py new file mode 100644 index 0000000000000000000000000000000000000000..2678cf41e6b2f048afb64e8a420c6363dc5c0ce9 --- /dev/null +++ b/models/lrm/utils/obj.py @@ -0,0 +1,209 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +import torch + +from . import texture +from . import mesh +from . import material + +###################################################################################### +# Utility functions +###################################################################################### + +def _find_mat(materials, name): + for mat in materials: + if mat['name'] == name: + return mat + return materials[0] # Materials 0 is the default + + +def normalize_mesh(vertices): + # 计算边界框 + min_vals, _ = torch.min(vertices, dim=0) + max_vals, _ = torch.max(vertices, dim=0) + + # 计算中心点 + center = (max_vals + min_vals) / 2 + + # 平移顶点 + vertices = vertices - center + + # 计算缩放因子 + max_extent = torch.max(max_vals - min_vals) + scale = 2.0 / max_extent + + # 缩放顶点 + vertices = vertices * scale + + return vertices + +###################################################################################### +# Create mesh object from objfile +###################################################################################### +def rotate_y_90(v_pos): + # 定义绕X轴旋转90度的旋转矩阵 + rotate_y = torch.tensor([[0, 0, 1, 0], + [0, 1, 0, 0], + [-1, 0, 0, 0], + [0, 0, 0, 1]], dtype=torch.float32, device=v_pos.device) + return rotate_y + +def load_obj(filename, clear_ks=True, mtl_override=None, return_attributes=False, path_is_attributrs=False): + obj_path = os.path.dirname(filename) + + # Read entire file + with open(filename, 'r') as f: + lines = f.readlines() + + # Load materials + all_materials = [ + { + 'name' : '_default_mat', + 'bsdf' : 'pbr', + 'kd' : texture.Texture2D(torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32, device='cuda')), + 'ks' : texture.Texture2D(torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device='cuda')) + } + ] + if mtl_override is None: + for line in lines: + if len(line.split()) == 0: + continue + if line.split()[0] == 'mtllib': + all_materials += material.load_mtl(os.path.join(obj_path, line.split()[1]), clear_ks) # Read in entire material library + else: + all_materials += material.load_mtl(mtl_override) + + # load vertices + vertices, texcoords, normals = [], [], [] + for line in lines: + if len(line.split()) == 0: + continue + + prefix = line.split()[0].lower() + if prefix == 'v': + vertices.append([float(v) for v in line.split()[1:]]) + elif prefix == 'vt': + val = [float(v) for v in line.split()[1:]] + texcoords.append([val[0], 1.0 - val[1]]) + elif prefix == 'vn': + normals.append([float(v) for v in line.split()[1:]]) + + # load faces + activeMatIdx = None + used_materials = [] + faces, tfaces, nfaces, mfaces = [], [], [], [] + for line in lines: + if len(line.split()) == 0: + continue + + prefix = line.split()[0].lower() + if prefix == 'usemtl': # Track used materials + mat = _find_mat(all_materials, line.split()[1]) + if not mat in used_materials: + used_materials.append(mat) + activeMatIdx = used_materials.index(mat) + elif prefix == 'f': # Parse face + vs = line.split()[1:] + nv = len(vs) + vv = vs[0].split('/') + v0 = int(vv[0]) - 1 + t0 = int(vv[1]) - 1 if vv[1] != "" else -1 + n0 = int(vv[2]) - 1 if vv[2] != "" else -1 + for i in range(nv - 2): # Triangulate polygons + vv = vs[i + 1].split('/') + v1 = int(vv[0]) - 1 + t1 = int(vv[1]) - 1 if vv[1] != "" else -1 + n1 = int(vv[2]) - 1 if vv[2] != "" else -1 + vv = vs[i + 2].split('/') + v2 = int(vv[0]) - 1 + t2 = int(vv[1]) - 1 if vv[1] != "" else -1 + n2 = int(vv[2]) - 1 if vv[2] != "" else -1 + mfaces.append(activeMatIdx) + faces.append([v0, v1, v2]) + tfaces.append([t0, t1, t2]) + nfaces.append([n0, n1, n2]) + assert len(tfaces) == len(faces) and len(nfaces) == len (faces) + + # Create an "uber" material by combining all textures into a larger texture + if len(used_materials) > 1: + uber_material, texcoords, tfaces = material.merge_materials(used_materials, texcoords, tfaces, mfaces) + else: + uber_material = used_materials[0] + + vertices = torch.tensor(vertices, dtype=torch.float32, device='cuda') + texcoords = torch.tensor(texcoords, dtype=torch.float32, device='cuda') if len(texcoords) > 0 else None + normals = torch.tensor(normals, dtype=torch.float32, device='cuda') if len(normals) > 0 else None + + faces = torch.tensor(faces, dtype=torch.int64, device='cuda') + tfaces = torch.tensor(tfaces, dtype=torch.int64, device='cuda') if texcoords is not None else None + nfaces = torch.tensor(nfaces, dtype=torch.int64, device='cuda') if normals is not None else None + + vertices = normalize_mesh(vertices) + # vertices = vertices @ rotate_y_90(vertices)[:3,:3] + + if return_attributes: + return mesh.Mesh(vertices, faces, normals, nfaces, texcoords, tfaces, material=uber_material), vertices, faces, normals, nfaces, texcoords, tfaces, uber_material + return mesh.Mesh(vertices, faces, normals, nfaces, texcoords, tfaces, material=uber_material) + +###################################################################################### +# Save mesh object to objfile +###################################################################################### + +def write_obj(folder, mesh, save_material=True): + obj_file = os.path.join(folder, 'mesh.obj') + print("Writing mesh: ", obj_file) + with open(obj_file, "w") as f: + f.write("mtllib mesh.mtl\n") + f.write("g default\n") + + v_pos = mesh.v_pos.detach().cpu().numpy() if mesh.v_pos is not None else None + v_nrm = mesh.v_nrm.detach().cpu().numpy() if mesh.v_nrm is not None else None + v_tex = mesh.v_tex.detach().cpu().numpy() if mesh.v_tex is not None else None + + t_pos_idx = mesh.t_pos_idx.detach().cpu().numpy() if mesh.t_pos_idx is not None else None + t_nrm_idx = mesh.t_nrm_idx.detach().cpu().numpy() if mesh.t_nrm_idx is not None else None + t_tex_idx = mesh.t_tex_idx.detach().cpu().numpy() if mesh.t_tex_idx is not None else None + + print(" writing %d vertices" % len(v_pos)) + for v in v_pos: + f.write('v {} {} {} \n'.format(v[0], v[1], v[2])) + + if v_tex is not None: + print(" writing %d texcoords" % len(v_tex)) + assert(len(t_pos_idx) == len(t_tex_idx)) + for v in v_tex: + f.write('vt {} {} \n'.format(v[0], 1.0 - v[1])) + + if v_nrm is not None: + print(" writing %d normals" % len(v_nrm)) + assert(len(t_pos_idx) == len(t_nrm_idx)) + for v in v_nrm: + f.write('vn {} {} {}\n'.format(v[0], v[1], v[2])) + + # faces + f.write("s 1 \n") + f.write("g pMesh1\n") + f.write("usemtl defaultMat\n") + + # Write faces + print(" writing %d faces" % len(t_pos_idx)) + for i in range(len(t_pos_idx)): + f.write("f ") + for j in range(3): + f.write(' %s/%s/%s' % (str(t_pos_idx[i][j]+1), '' if v_tex is None else str(t_tex_idx[i][j]+1), '' if v_nrm is None else str(t_nrm_idx[i][j]+1))) + f.write("\n") + + if save_material: + mtl_file = os.path.join(folder, 'mesh.mtl') + print("Writing material: ", mtl_file) + material.save_mtl(mtl_file, mesh.material) + + print("Done exporting mesh") diff --git a/models/lrm/utils/render.py b/models/lrm/utils/render.py new file mode 100644 index 0000000000000000000000000000000000000000..5468bd8e5c8be7f0b19baec24d841a658fb39408 --- /dev/null +++ b/models/lrm/utils/render.py @@ -0,0 +1,386 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import torch +import nvdiffrast.torch as dr + +from . import render_utils +from models.lrm.models.geometry.render import renderutils as ru +import numpy as np +from PIL import Image + +# ============================================================================================== +# Helper functions +# ============================================================================================== +def interpolate(attr, rast, attr_idx, rast_db=None): + return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=None if rast_db is None else 'all') + + +def get_mip(roughness): + return torch.where(roughness < 1.0 + , (torch.clamp(roughness, 0.04, 1.0) - 0.04) / (1.0 - 0.04) * (6 - 2) + , (torch.clamp(roughness, 1.0, 1.0) - 1.0) / (1.0 - 1.0) + 6 - 2) + +def shade_with_env(gb_pos, gb_normal, kd, metallic, roughness, view_pos, run_n_view, env, metallic_gt, roughness_gt, use_material_gt=True, gt_render=False): + + #mask = mask[..., 0] + view_pos = view_pos.expand(-1, gb_pos.shape[1], gb_pos.shape[2], -1) #.reshape(1, 512, 10240, 3) + + wo = render_utils.safe_normalize(view_pos - gb_pos) + + #roughness = roughness.reshape(roughness.shape[0], roughness.shape[1], roughness.shape[1], run_n_view, 1) + #metallic = metallic.reshape(metallic.shape[0], metallic.shape[1], metallic.shape[1], run_n_view, 1) + #kd = kd.reshape(kd.shape[0], kd.shape[1], kd.shape[1], run_n_view, 3) + # if len(diffuse_light) != 10: + # diffuse_light = [diffuse_light[0] for _ in range(10)] + # specular_light = [specular_light[0] for _ in range(10)] + + # metallic_gt = torch.zeros((8, metallic_gt.shape[1], metallic_gt.shape[2], 1)).cuda() + # roughness_gt = torch.zeros((8, roughness_gt.shape[1], roughness_gt.shape[2], 1)).cuda() + + #if use_material_gt: + spec_col = (1.0 - metallic_gt)*0.04 + kd * metallic_gt + diff_col = kd * (1.0 - metallic_gt) + # else: + + # spec_col = (1.0 - metallic)*0.04 + kd * metallic + # diff_col = kd * (1.0 - metallic) + + nrmvec = gb_normal + reflvec = render_utils.safe_normalize(render_utils.reflect(wo, nrmvec)) + + prb_rendered_list = [] + pbr_specular_light_list = [] + pbr_diffuse_light_list = [] + # pbr_specular_color_list = [] + # pbr_diffuse_color_list = [] + for i in range(run_n_view): + specular_light, diffuse_light = env[i] + diffuse_light = diffuse_light.cuda() + specular_light_new = [] + for split_specular_light in specular_light: + specular_light_new.append(split_specular_light.cuda()) + specular_light = specular_light_new + + shaded_col = torch.ones((gb_pos.shape[1], gb_pos.shape[2], 3)).cuda() + + diffuse = dr.texture(diffuse_light[None, ...], nrmvec[i,:,:,:][None, ...].contiguous(), filter_mode='linear', boundary_mode='cube') + diffuse_comp = diffuse * diff_col[i,:,:,:][None, ...] + + # Lookup FG term from lookup texture + NdotV = torch.clamp(render_utils.dot(wo[i,:,:,:], nrmvec[i,:,:,:]), min=1e-4) + fg_uv = torch.cat((NdotV, roughness_gt[i,:,:,:]), dim=-1) + #if not hasattr(self, '_FG_LUT'): + _FG_LUT = torch.as_tensor(np.fromfile('./models/lrm/data/irrmaps/bsdf_256_256.bin', dtype=np.float32).reshape(1, 256, 256, 2), dtype=torch.float32, device='cuda') + fg_lookup = dr.texture(_FG_LUT, fg_uv[None, ...], filter_mode='linear', boundary_mode='clamp') + # Roughness adjusted specular env lookup + + miplevel = get_mip(roughness_gt[i,:,:,:]) + miplevel = miplevel[None, ...] + spec = dr.texture(specular_light[0][None, ...], reflvec[i,:,:,:][None, ...].contiguous(), mip=list(m[None, ...] for m in specular_light[1:]), mip_level_bias=miplevel[..., 0], filter_mode='linear-mipmap-linear', boundary_mode='cube') + + # Compute aggregate lighting + reflectance = spec_col[i,:,:,:][None, ...] * fg_lookup[...,0:1] + fg_lookup[...,1:2] + specular_comp = spec * reflectance + #shaded_col += spec * reflectance + shaded_col = (specular_comp[0] + diffuse_comp[0]) + + prb_rendered_list.append(shaded_col) + pbr_specular_light_list.append(spec[0]) + pbr_diffuse_light_list.append(diffuse[0]) + + # pbr_specular_color_list.append(metallic_gt[i].repeat(1,1,3)) + # pbr_diffuse_color_list.append(roughness_gt[i].repeat(1,1,3)) + + + shaded_col_all = torch.stack(prb_rendered_list, dim=0) + pbr_specular_light = torch.stack(pbr_specular_light_list, dim=0) + pbr_diffuse_light = torch.stack(pbr_diffuse_light_list, dim=0) + # pbr_specular_color = torch.stack(pbr_specular_color_list, dim=0) + # pbr_diffuse_color = torch.stack(pbr_diffuse_color_list, dim=0) + + #shaded_col_all = shaded_col_all.reshape(shaded_col_all.shape[0], shaded_col_all.shape[1], shaded_col_all.shape[1]*run_n_view, 3) + shaded_col_all = render_utils.rgb_to_srgb(shaded_col_all).clamp(0.,1.) + pbr_specular_light = render_utils.rgb_to_srgb(pbr_specular_light).clamp(0.,1.) + pbr_diffuse_light = render_utils.rgb_to_srgb(pbr_diffuse_light).clamp(0.,1.) + # pbr_specular_color = render_utils.rgb_to_srgb(pbr_specular_color).clamp(0.,1.) + # pbr_diffuse_color = render_utils.rgb_to_srgb(pbr_diffuse_color).clamp(0.,1.) + + return shaded_col_all, pbr_specular_light, pbr_diffuse_light #, pbr_specular_color, pbr_diffuse_color + +# ============================================================================================== +# pixel shader +# ============================================================================================== +def shade( + gb_pos, + gb_geometric_normal, + gb_normal, + gb_tangent, + gb_texc, + gb_texc_deriv, + view_pos, + env, + planes, + kd_fn, + materials, + material, + mask, + gt_render + ): + + ################################################################################ + # Texture lookups + ################################################################################ + perturbed_nrm = None + resolution = gb_pos.shape[1] + N_views = view_pos.shape[0] + + if planes is None: + kd = material['kd'].sample(gb_texc, gb_texc_deriv) + + matellic_gt, roughness_gt = (materials[0] * torch.ones(*kd.shape[:-1])).unsqueeze(-1).cuda(), (materials[1] * torch.ones(*kd.shape[:-1])).unsqueeze(-1).cuda() + matellic, roughness = None, None + else: + # predict kd with MLP and interpolated feature + gb_pos_interp, mask = [gb_pos], [mask] + gb_pos_interp = [torch.cat([pos[i_view:i_view + 1] for i_view in range(N_views)], dim=2) for pos in gb_pos_interp] + mask = [torch.cat([ma[i_view:i_view + 1] for i_view in range(N_views)], dim=2) for ma in mask] + kd, matellic, roughness = kd_fn( planes[None,...], gb_pos_interp, mask[0]) + kd = torch.cat( [torch.cat([kd[i:i + 1, :, resolution * i_view: resolution * (i_view + 1)]for i_view in range(N_views)], dim=0) for i in range(len(kd))], dim=0) + + matellic_val = [x[0] for x in materials] + roughness_val = [y[1] for y in materials] + + matellic_gt = torch.full((N_views, resolution, resolution, 1), fill_value=0, dtype=torch.float32) + roughness_gt = torch.full((N_views, resolution, resolution, 1), fill_value=0, dtype=torch.float32) + + for i in range(len(matellic_gt)): + matellic_gt[i, :, :, 0].fill_(matellic_val[i]) + roughness_gt[i, :, :, 0].fill_(roughness_val[i]) + + matellic_gt = matellic_gt.cuda() + roughness_gt = roughness_gt.cuda() + + # Separate kd into alpha and color, default alpha = 1 + alpha = kd[..., 3:4] if kd.shape[-1] == 4 else torch.ones_like(kd[..., 0:1]) + kd = kd[..., 0:3].clamp(0., 1.) + + ################################################################################ + # Normal perturbation & normal bend + ################################################################################ + #if 'no_perturbed_nrm' in material and material['no_perturbed_nrm']: + perturbed_nrm = None + + gb_normal_ = ru.prepare_shading_normal(gb_pos, view_pos, perturbed_nrm, gb_normal, gb_tangent, gb_geometric_normal, two_sided_shading=True, opengl=True) + + ################################################################################ + # Evaluate BSDF + ################################################################################ + + shaded_col, spec_light, diff_light = shade_with_env(gb_pos, gb_normal_, kd, matellic, roughness, view_pos, N_views, env, matellic_gt, roughness_gt, use_material_gt=True, gt_render=gt_render) + + buffers = { + 'shaded' : torch.cat((shaded_col, alpha), dim=-1), + 'spec_light': torch.cat((spec_light, alpha), dim=-1), + 'diff_light': torch.cat((diff_light, alpha), dim=-1), + 'gb_normal' : torch.cat((gb_normal_, alpha), dim=-1), + 'normal' : torch.cat((gb_normal, alpha), dim=-1), + 'albedo' : torch.cat((kd, alpha), dim=-1), + # 'spec_albedo': torch.cat((spec_albedo, alpha), dim=-1), + # 'diff_albedo': torch.cat((diff_albedo, alpha), dim=-1), + } + return buffers + +# ============================================================================================== +# Render a depth slice of the mesh (scene), some limitations: +# - Single mesh +# - Single light +# - Single material +# ============================================================================================== +def render_layer( + rast, + rast_deriv, + mesh, + view_pos, + env, + planes, + kd_fn, + materials, + v_pos_clip, + resolution, + spp, + msaa, + gt_render + ): + + full_res = [resolution[0]*spp, resolution[1]*spp] + + ################################################################################ + # Rasterize + ################################################################################ + + # Scale down to shading resolution when MSAA is enabled, otherwise shade at full resolution + if spp > 1 and msaa: + rast_out_s = render_utils.scale_img_nhwc(rast, resolution, mag='nearest', min='nearest') + rast_out_deriv_s = render_utils.scale_img_nhwc(rast_deriv, resolution, mag='nearest', min='nearest') * spp + else: + rast_out_s = rast + rast_out_deriv_s = rast_deriv + + ################################################################################ + # Interpolate attributes + ################################################################################ + + # Interpolate world space position + gb_pos, _ = interpolate(mesh.v_pos[None, ...], rast_out_s, mesh.t_pos_idx.int()) + + # Compute geometric normals. We need those because of bent normals trick (for bump mapping) + v0 = mesh.v_pos[mesh.t_pos_idx[:, 0], :] + v1 = mesh.v_pos[mesh.t_pos_idx[:, 1], :] + v2 = mesh.v_pos[mesh.t_pos_idx[:, 2], :] + face_normals = render_utils.safe_normalize(torch.cross(v1 - v0, v2 - v0)) + face_normal_indices = (torch.arange(0, face_normals.shape[0], dtype=torch.int64, device='cuda')[:, None]).repeat(1, 3) + gb_geometric_normal, _ = interpolate(face_normals[None, ...], rast_out_s, face_normal_indices.int()) + + # Compute tangent space + assert mesh.v_nrm is not None and mesh.v_tng is not None + gb_normal, _ = interpolate(mesh.v_nrm[None, ...], rast_out_s, mesh.t_nrm_idx.int()) + gb_tangent, _ = interpolate(mesh.v_tng[None, ...], rast_out_s, mesh.t_tng_idx.int()) # Interpolate tangents + + # Texture coordinate + assert mesh.v_tex is not None + gb_texc, gb_texc_deriv = interpolate(mesh.v_tex[None, ...], rast_out_s, mesh.t_tex_idx.int(), rast_db=rast_out_deriv_s) + + # render depth + depth = torch.linalg.norm(view_pos.expand_as(gb_pos) - gb_pos, dim=-1) + + mask = torch.clamp(rast[..., -1:], 0, 1) + antialias_mask = dr.antialias(mask.clone().contiguous(), rast, v_pos_clip,mesh.t_pos_idx.int()) + + ################################################################################ + # Shade + ################################################################################ + + buffers = shade(gb_pos, gb_geometric_normal, gb_normal, gb_tangent, gb_texc, gb_texc_deriv, view_pos, env, planes, kd_fn, materials, mesh.material, mask, gt_render) + buffers['depth'] = torch.cat((depth.unsqueeze(-1).repeat(1,1,1,3), torch.ones_like(gb_pos[..., 0:1])), dim=-1 ) + # print(gb_pos.shape) + buffers['ccm'] = torch.cat((gb_pos, torch.ones_like(gb_pos[..., 0:1])), dim=-1 ) + buffers['mask'] = torch.cat((antialias_mask.repeat(1,1,1,3), torch.ones_like(gb_pos[..., 0:1])), dim=-1 ) + ################################################################################ + # Prepare output + ################################################################################ + + # Scale back up to visibility resolution if using MSAA + if spp > 1 and msaa: + for key in buffers.keys(): + buffers[key] = render_utils.scale_img_nhwc(buffers[key], full_res, mag='nearest', min='nearest') + + # Return buffers + return buffers + +# ============================================================================================== +# Render a depth peeled mesh (scene), some limitations: +# - Single mesh +# - Single light +# - Single material +# ============================================================================================== +def render_mesh( + ctx, + mesh, + mtx_in, + view_pos, + env, + planes, + kd_fn, + materials, + resolution, + spp = 1, + num_layers = 1, + msaa = False, + background = None, + gt_render = False + ): + + def prepare_input_vector(x): + x = torch.tensor(x, dtype=torch.float32, device='cuda') if not torch.is_tensor(x) else x + return x[:, None, None, :] if len(x.shape) == 2 else x + + def composite_buffer(key, layers, background, antialias): + accum = background + for buffers, rast in reversed(layers): + alpha = (rast[..., -1:] > 0).float() * buffers[key][..., -1:] + accum = torch.lerp(accum, torch.cat((buffers[key][..., :-1], torch.ones_like(buffers[key][..., -1:])), dim=-1), alpha) + if antialias: + accum = dr.antialias(accum.contiguous(), rast, v_pos_clip, mesh.t_pos_idx.int()) + return accum + + assert mesh.t_pos_idx.shape[0] > 0, "Got empty training triangle mesh (unrecoverable discontinuity)" + assert background is None or (background.shape[1] == resolution[0] and background.shape[2] == resolution[1]) + + full_res = [resolution[0]*spp, resolution[1]*spp] + + # Convert numpy arrays to torch tensors + mtx_in = torch.tensor(mtx_in, dtype=torch.float32, device='cuda') if not torch.is_tensor(mtx_in) else mtx_in + view_pos = prepare_input_vector(view_pos) + + # clip space transform + v_pos_clip = ru.xfm_points(mesh.v_pos[None, ...], mtx_in) + + # Render all layers front-to-back + layers = [] + with dr.DepthPeeler(ctx, v_pos_clip, mesh.t_pos_idx.int(), full_res) as peeler: + for _ in range(num_layers): + rast, db = peeler.rasterize_next_layer() + layers += [(render_layer(rast, db, mesh, view_pos, env, planes, kd_fn, materials, v_pos_clip, resolution, spp, msaa, gt_render), rast)] + + # Setup background + if background is not None: + if spp > 1: + background = render_utils.scale_img_nhwc(background, full_res, mag='nearest', min='nearest') + background = torch.cat((background, torch.zeros_like(background[..., 0:1])), dim=-1) + else: + background = torch.ones(1, full_res[0], full_res[1], 4, dtype=torch.float32, device='cuda') + background_black = torch.zeros(1, full_res[0], full_res[1], 4, dtype=torch.float32, device='cuda') + + # Composite layers front-to-back + out_buffers = {} + + for key in layers[0][0].keys(): + if key == 'mask': + accum = composite_buffer(key, layers, background_black, True) + else: + accum = composite_buffer(key, layers, background, True) + + # Downscale to framebuffer resolution. Use avg pooling + out_buffers[key] = render_utils.avg_pool_nhwc(accum, spp) if spp > 1 else accum + + return out_buffers + +# ============================================================================================== +# Render UVs +# ============================================================================================== +def render_uv(ctx, mesh, resolution, mlp_texture): + + # clip space transform + uv_clip = mesh.v_tex[None, ...]*2.0 - 1.0 + + # pad to four component coordinate + uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[...,0:1]), torch.ones_like(uv_clip[...,0:1])), dim = -1) + + # rasterize + rast, _ = dr.rasterize(ctx, uv_clip4, mesh.t_tex_idx.int(), resolution) + + # Interpolate world space position + gb_pos, _ = interpolate(mesh.v_pos[None, ...], rast, mesh.t_pos_idx.int()) + + # Sample out textures from MLP + all_tex = mlp_texture.sample(gb_pos) + assert all_tex.shape[-1] == 9 or all_tex.shape[-1] == 10, "Combined kd_ks_normal must be 9 or 10 channels" + perturbed_nrm = all_tex[..., -3:] + return (rast[..., -1:] > 0).float(), all_tex[..., :-6], all_tex[..., -6:-3], render_utils.safe_normalize(perturbed_nrm) diff --git a/models/lrm/utils/render_utils.py b/models/lrm/utils/render_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8b550fd8595e9f36a5d758c3a646ebb1d8757015 --- /dev/null +++ b/models/lrm/utils/render_utils.py @@ -0,0 +1,507 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +import numpy as np +import torch +import nvdiffrast.torch as dr +import imageio +imageio.plugins.freeimage.download() +os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1" +#---------------------------------------------------------------------------- +# Vector operations +#---------------------------------------------------------------------------- + +def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return torch.sum(x*y, -1, keepdim=True) + +def reflect(x: torch.Tensor, n: torch.Tensor) -> torch.Tensor: + return 2*dot(x, n)*n - x + +def length(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor: + return torch.sqrt(torch.clamp(dot(x,x), min=eps)) # Clamp to avoid nan gradients because grad(sqrt(0)) = NaN + +def safe_normalize(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor: + return x / length(x, eps) + +def to_hvec(x: torch.Tensor, w: float) -> torch.Tensor: + return torch.nn.functional.pad(x, pad=(0,1), mode='constant', value=w) + +#---------------------------------------------------------------------------- +# sRGB color transforms +#---------------------------------------------------------------------------- + +def _rgb_to_srgb(f: torch.Tensor) -> torch.Tensor: + return torch.where(f <= 0.0031308, f * 12.92, torch.pow(torch.clamp(f, 0.0031308), 1.0/2.4)*1.055 - 0.055) + +def rgb_to_srgb(f: torch.Tensor) -> torch.Tensor: + assert f.shape[-1] == 3 or f.shape[-1] == 4 + out = torch.cat((_rgb_to_srgb(f[..., 0:3]), f[..., 3:4]), dim=-1) if f.shape[-1] == 4 else _rgb_to_srgb(f) + assert out.shape[0] == f.shape[0] and out.shape[1] == f.shape[1] and out.shape[2] == f.shape[2] + return out + +def _srgb_to_rgb(f: torch.Tensor) -> torch.Tensor: + return torch.where(f <= 0.04045, f / 12.92, torch.pow((torch.clamp(f, 0.04045) + 0.055) / 1.055, 2.4)) + +def srgb_to_rgb(f: torch.Tensor) -> torch.Tensor: + assert f.shape[-1] == 3 or f.shape[-1] == 4 + out = torch.cat((_srgb_to_rgb(f[..., 0:3]), f[..., 3:4]), dim=-1) if f.shape[-1] == 4 else _srgb_to_rgb(f) + assert out.shape[0] == f.shape[0] and out.shape[1] == f.shape[1] and out.shape[2] == f.shape[2] + return out + +def reinhard(f: torch.Tensor) -> torch.Tensor: + return f/(1+f) + +#----------------------------------------------------------------------------------- +# Metrics (taken from jaxNerf source code, in order to replicate their measurements) +# +# https://github.com/google-research/google-research/blob/301451a62102b046bbeebff49a760ebeec9707b8/jaxnerf/nerf/utils.py#L266 +# +#----------------------------------------------------------------------------------- + +def mse_to_psnr(mse): + """Compute PSNR given an MSE (we assume the maximum pixel value is 1).""" + return -10. / np.log(10.) * np.log(mse) + +def psnr_to_mse(psnr): + """Compute MSE given a PSNR (we assume the maximum pixel value is 1).""" + return np.exp(-0.1 * np.log(10.) * psnr) + +#---------------------------------------------------------------------------- +# Displacement texture lookup +#---------------------------------------------------------------------------- + +def get_miplevels(texture: np.ndarray) -> float: + minDim = min(texture.shape[0], texture.shape[1]) + return np.floor(np.log2(minDim)) + +def tex_2d(tex_map : torch.Tensor, coords : torch.Tensor, filter='nearest') -> torch.Tensor: + tex_map = tex_map[None, ...] # Add batch dimension + tex_map = tex_map.permute(0, 3, 1, 2) # NHWC -> NCHW + tex = torch.nn.functional.grid_sample(tex_map, coords[None, None, ...] * 2 - 1, mode=filter, align_corners=False) + tex = tex.permute(0, 2, 3, 1) # NCHW -> NHWC + return tex[0, 0, ...] + +#---------------------------------------------------------------------------- +# Cubemap utility functions +#---------------------------------------------------------------------------- + +def cube_to_dir(s, x, y): + if s == 0: rx, ry, rz = torch.ones_like(x), -y, -x + elif s == 1: rx, ry, rz = -torch.ones_like(x), -y, x + elif s == 2: rx, ry, rz = x, torch.ones_like(x), y + elif s == 3: rx, ry, rz = x, -torch.ones_like(x), -y + elif s == 4: rx, ry, rz = x, -y, torch.ones_like(x) + elif s == 5: rx, ry, rz = -x, -y, -torch.ones_like(x) + return torch.stack((rx, ry, rz), dim=-1) + +def latlong_to_cubemap(latlong_map, res): + cubemap = torch.zeros(6, res[0], res[1], latlong_map.shape[-1], dtype=torch.float32, device='cuda') + for s in range(6): + gy, gx = torch.meshgrid(torch.linspace(-1.0 + 1.0 / res[0], 1.0 - 1.0 / res[0], res[0], device='cuda'), + torch.linspace(-1.0 + 1.0 / res[1], 1.0 - 1.0 / res[1], res[1], device='cuda'), + indexing='ij') + v = safe_normalize(cube_to_dir(s, gx, gy)) + + tu = torch.atan2(v[..., 0:1], -v[..., 2:3]) / (2 * np.pi) + 0.5 + tv = torch.acos(torch.clamp(v[..., 1:2], min=-1, max=1)) / np.pi + texcoord = torch.cat((tu, tv), dim=-1) + + cubemap[s, ...] = dr.texture(latlong_map[None, ...], texcoord[None, ...], filter_mode='linear')[0] + return cubemap + +def cubemap_to_latlong(cubemap, res): + gy, gx = torch.meshgrid(torch.linspace( 0.0 + 1.0 / res[0], 1.0 - 1.0 / res[0], res[0], device='cuda'), + torch.linspace(-1.0 + 1.0 / res[1], 1.0 - 1.0 / res[1], res[1], device='cuda'), + indexing='ij') + + sintheta, costheta = torch.sin(gy*np.pi), torch.cos(gy*np.pi) + sinphi, cosphi = torch.sin(gx*np.pi), torch.cos(gx*np.pi) + + reflvec = torch.stack(( + sintheta*sinphi, + costheta, + -sintheta*cosphi + ), dim=-1) + return dr.texture(cubemap[None, ...], reflvec[None, ...].contiguous(), filter_mode='linear', boundary_mode='cube')[0] + +#---------------------------------------------------------------------------- +# Image scaling +#---------------------------------------------------------------------------- + +def scale_img_hwc(x : torch.Tensor, size, mag='bilinear', min='area') -> torch.Tensor: + return scale_img_nhwc(x[None, ...], size, mag, min)[0] + +def scale_img_nhwc(x : torch.Tensor, size, mag='bilinear', min='area') -> torch.Tensor: + assert (x.shape[1] >= size[0] and x.shape[2] >= size[1]) or (x.shape[1] < size[0] and x.shape[2] < size[1]), "Trying to magnify image in one dimension and minify in the other" + y = x.permute(0, 3, 1, 2) # NHWC -> NCHW + if x.shape[1] > size[0] and x.shape[2] > size[1]: # Minification, previous size was bigger + y = torch.nn.functional.interpolate(y, size, mode=min) + else: # Magnification + if mag == 'bilinear' or mag == 'bicubic': + y = torch.nn.functional.interpolate(y, size, mode=mag, align_corners=True) + else: + y = torch.nn.functional.interpolate(y, size, mode=mag) + return y.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC + +def avg_pool_nhwc(x : torch.Tensor, size) -> torch.Tensor: + y = x.permute(0, 3, 1, 2) # NHWC -> NCHW + y = torch.nn.functional.avg_pool2d(y, size) + return y.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC + +#---------------------------------------------------------------------------- +# Behaves similar to tf.segment_sum +#---------------------------------------------------------------------------- + +def segment_sum(data: torch.Tensor, segment_ids: torch.Tensor) -> torch.Tensor: + num_segments = torch.unique_consecutive(segment_ids).shape[0] + + # Repeats ids until same dimension as data + if len(segment_ids.shape) == 1: + s = torch.prod(torch.tensor(data.shape[1:], dtype=torch.int64, device='cuda')).long() + segment_ids = segment_ids.repeat_interleave(s).view(segment_ids.shape[0], *data.shape[1:]) + + assert data.shape == segment_ids.shape, "data.shape and segment_ids.shape should be equal" + + shape = [num_segments] + list(data.shape[1:]) + result = torch.zeros(*shape, dtype=torch.float32, device='cuda') + result = result.scatter_add(0, segment_ids, data) + return result + +#---------------------------------------------------------------------------- +# Matrix helpers. +#---------------------------------------------------------------------------- + +def fovx_to_fovy(fovx, aspect): + return np.arctan(np.tan(fovx / 2) / aspect) * 2.0 + +def focal_length_to_fovy(focal_length, sensor_height): + return 2 * np.arctan(0.5 * sensor_height / focal_length) + +# Reworked so this matches gluPerspective / glm::perspective, using fovy +def perspective(fovy=0.7854, aspect=1.0, n=0.1, f=1000.0, device=None): + y = np.tan(fovy / 2) + return torch.tensor([[1/(y*aspect), 0, 0, 0], + [ 0, 1/-y, 0, 0], + [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)], + [ 0, 0, -1, 0]], dtype=torch.float32, device=device) + +# Reworked so this matches gluPerspective / glm::perspective, using fovy +def perspective_offcenter(fovy, fraction, rx, ry, aspect=1.0, n=0.1, f=1000.0, device=None): + y = np.tan(fovy / 2) + + # Full frustum + R, L = aspect*y, -aspect*y + T, B = y, -y + + # Create a randomized sub-frustum + width = (R-L)*fraction + height = (T-B)*fraction + xstart = (R-L)*rx + ystart = (T-B)*ry + + l = L + xstart + r = l + width + b = B + ystart + t = b + height + + # https://www.scratchapixel.com/lessons/3d-basic-rendering/perspective-and-orthographic-projection-matrix/opengl-perspective-projection-matrix + return torch.tensor([[2/(r-l), 0, (r+l)/(r-l), 0], + [ 0, -2/(t-b), (t+b)/(t-b), 0], + [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)], + [ 0, 0, -1, 0]], dtype=torch.float32, device=device) + +def translate(x, y, z, device=None): + return torch.tensor([[1, 0, 0, x], + [0, 1, 0, y], + [0, 0, 1, z], + [0, 0, 0, 1]], dtype=torch.float32, device=device) + +def rotate_x(a, device=None): + s, c = np.sin(a), np.cos(a) + return torch.tensor([[1, 0, 0, 0], + [0, c,-s, 0], + [0, s, c, 0], + [0, 0, 0, 1]], dtype=torch.float32, device=device) + +def rotate_y(a, device=None): + s, c = np.sin(a), np.cos(a) + return torch.tensor([[ c, 0, s, 0], + [ 0, 1, 0, 0], + [-s, 0, c, 0], + [ 0, 0, 0, 1]], dtype=torch.float32, device=device) + +def scale(s, device=None): + return torch.tensor([[ s, 0, 0, 0], + [ 0, s, 0, 0], + [ 0, 0, s, 0], + [ 0, 0, 0, 1]], dtype=torch.float32, device=device) + +def lookAt(eye, at, up): + a = eye - at + w = a / torch.linalg.norm(a) + u = torch.cross(up, w) + u = u / torch.linalg.norm(u) + v = torch.cross(w, u) + translate = torch.tensor([[1, 0, 0, -eye[0]], + [0, 1, 0, -eye[1]], + [0, 0, 1, -eye[2]], + [0, 0, 0, 1]], dtype=eye.dtype, device=eye.device) + rotate = torch.tensor([[u[0], u[1], u[2], 0], + [v[0], v[1], v[2], 0], + [w[0], w[1], w[2], 0], + [0, 0, 0, 1]], dtype=eye.dtype, device=eye.device) + return rotate @ translate +# def lookAt(eye, center, up): +# f = (center - eye) +# f = f / torch.norm(f) + +# u = up / torch.norm(up) +# s = torch.cross(f, u) +# u = torch.cross(s, f) + +# result = torch.eye(4) +# result[0, 0:3] = s +# result[1, 0:3] = u +# result[2, 0:3] = -f +# result[0, 3] = -torch.dot(s, eye) +# result[1, 3] = -torch.dot(u, eye) +# result[2, 3] = torch.dot(f, eye) + +# return result + +def look_at_opengl(eye, at, up): + # 计算前向量 + forward = (at - eye) + forward = forward / torch.norm(forward) + + # 计算右向量 + right = torch.cross(up, forward) + right = right / torch.norm(right) + + # 计算实际的上向量 + up = torch.cross(forward, right) + + # 构建视图矩阵 + view_matrix = torch.eye(4) + view_matrix[0, :3] = right + view_matrix[1, :3] = up + view_matrix[2, :3] = -forward + view_matrix[:3, 3] = -eye + + # 计算 c2w 矩阵 + # c2w = torch.inverse(view_matrix) + + return view_matrix + +@torch.no_grad() +def random_rotation_translation(t, device=None): + m = np.random.normal(size=[3, 3]) + m[1] = np.cross(m[0], m[2]) + m[2] = np.cross(m[0], m[1]) + m = m / np.linalg.norm(m, axis=1, keepdims=True) + m = np.pad(m, [[0, 1], [0, 1]], mode='constant') + m[3, 3] = 1.0 + m[:3, 3] = np.random.uniform(-t, t, size=[3]) + return torch.tensor(m, dtype=torch.float32, device=device) + +@torch.no_grad() +def random_rotation(device=None): + m = np.random.normal(size=[3, 3]) + m[1] = np.cross(m[0], m[2]) + m[2] = np.cross(m[0], m[1]) + m = m / np.linalg.norm(m, axis=1, keepdims=True) + m = np.pad(m, [[0, 1], [0, 1]], mode='constant') + m[3, 3] = 1.0 + m[:3, 3] = np.array([0,0,0]).astype(np.float32) + return torch.tensor(m, dtype=torch.float32, device=device) + +#---------------------------------------------------------------------------- +# Compute focal points of a set of lines using least squares. +# handy for poorly centered datasets +#---------------------------------------------------------------------------- + +def lines_focal(o, d): + d = safe_normalize(d) + I = torch.eye(3, dtype=o.dtype, device=o.device) + S = torch.sum(d[..., None] @ torch.transpose(d[..., None], 1, 2) - I[None, ...], dim=0) + C = torch.sum((d[..., None] @ torch.transpose(d[..., None], 1, 2) - I[None, ...]) @ o[..., None], dim=0).squeeze(1) + return torch.linalg.pinv(S) @ C + +#---------------------------------------------------------------------------- +# Cosine sample around a vector N +#---------------------------------------------------------------------------- +@torch.no_grad() +def cosine_sample(N, size=None): + # construct local frame + N = N/torch.linalg.norm(N) + + dx0 = torch.tensor([0, N[2], -N[1]], dtype=N.dtype, device=N.device) + dx1 = torch.tensor([-N[2], 0, N[0]], dtype=N.dtype, device=N.device) + + dx = torch.where(dot(dx0, dx0) > dot(dx1, dx1), dx0, dx1) + #dx = dx0 if np.dot(dx0,dx0) > np.dot(dx1,dx1) else dx1 + dx = dx / torch.linalg.norm(dx) + dy = torch.cross(N,dx) + dy = dy / torch.linalg.norm(dy) + + # cosine sampling in local frame + if size is None: + phi = 2.0 * np.pi * np.random.uniform() + s = np.random.uniform() + else: + phi = 2.0 * np.pi * torch.rand(*size, 1, dtype=N.dtype, device=N.device) + s = torch.rand(*size, 1, dtype=N.dtype, device=N.device) + costheta = np.sqrt(s) + sintheta = np.sqrt(1.0 - s) + + # cartesian vector in local space + x = np.cos(phi)*sintheta + y = np.sin(phi)*sintheta + z = costheta + + # local to world + return dx*x + dy*y + N*z + +#---------------------------------------------------------------------------- +# Bilinear downsample by 2x. +#---------------------------------------------------------------------------- + +def bilinear_downsample(x : torch.tensor) -> torch.Tensor: + w = torch.tensor([[1, 3, 3, 1], [3, 9, 9, 3], [3, 9, 9, 3], [1, 3, 3, 1]], dtype=torch.float32, device=x.device) / 64.0 + w = w.expand(x.shape[-1], 1, 4, 4) + x = torch.nn.functional.conv2d(x.permute(0, 3, 1, 2), w, padding=1, stride=2, groups=x.shape[-1]) + return x.permute(0, 2, 3, 1) + +#---------------------------------------------------------------------------- +# Bilinear downsample log(spp) steps +#---------------------------------------------------------------------------- + +def bilinear_downsample(x : torch.tensor, spp) -> torch.Tensor: + w = torch.tensor([[1, 3, 3, 1], [3, 9, 9, 3], [3, 9, 9, 3], [1, 3, 3, 1]], dtype=torch.float32, device=x.device) / 64.0 + g = x.shape[-1] + w = w.expand(g, 1, 4, 4) + x = x.permute(0, 3, 1, 2) # NHWC -> NCHW + steps = int(np.log2(spp)) + for _ in range(steps): + xp = torch.nn.functional.pad(x, (1,1,1,1), mode='replicate') + x = torch.nn.functional.conv2d(xp, w, padding=0, stride=2, groups=g) + return x.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC + +#---------------------------------------------------------------------------- +# Singleton initialize GLFW +#---------------------------------------------------------------------------- + +_glfw_initialized = False +def init_glfw(): + global _glfw_initialized + try: + import glfw + glfw.ERROR_REPORTING = 'raise' + glfw.default_window_hints() + glfw.window_hint(glfw.VISIBLE, glfw.FALSE) + test = glfw.create_window(8, 8, "Test", None, None) # Create a window and see if not initialized yet + except glfw.GLFWError as e: + if e.error_code == glfw.NOT_INITIALIZED: + glfw.init() + _glfw_initialized = True + +#---------------------------------------------------------------------------- +# Image display function using OpenGL. +#---------------------------------------------------------------------------- + +_glfw_window = None +def display_image(image, title=None): + # Import OpenGL + import OpenGL.GL as gl + import glfw + + # Zoom image if requested. + image = np.asarray(image[..., 0:3]) if image.shape[-1] == 4 else np.asarray(image) + height, width, channels = image.shape + + # Initialize window. + init_glfw() + if title is None: + title = 'Debug window' + global _glfw_window + if _glfw_window is None: + glfw.default_window_hints() + _glfw_window = glfw.create_window(width, height, title, None, None) + glfw.make_context_current(_glfw_window) + glfw.show_window(_glfw_window) + glfw.swap_interval(0) + else: + glfw.make_context_current(_glfw_window) + glfw.set_window_title(_glfw_window, title) + glfw.set_window_size(_glfw_window, width, height) + + # Update window. + glfw.poll_events() + gl.glClearColor(0, 0, 0, 1) + gl.glClear(gl.GL_COLOR_BUFFER_BIT) + gl.glWindowPos2f(0, 0) + gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1) + gl_format = {3: gl.GL_RGB, 2: gl.GL_RG, 1: gl.GL_LUMINANCE}[channels] + gl_dtype = {'uint8': gl.GL_UNSIGNED_BYTE, 'float32': gl.GL_FLOAT}[image.dtype.name] + gl.glDrawPixels(width, height, gl_format, gl_dtype, image[::-1]) + glfw.swap_buffers(_glfw_window) + if glfw.window_should_close(_glfw_window): + return False + return True + +#---------------------------------------------------------------------------- +# Image save/load helper. +#---------------------------------------------------------------------------- + +def save_image(fn, x : np.ndarray): + try: + if os.path.splitext(fn)[1] == ".png": + imageio.imwrite(fn, np.clip(np.rint(x * 255.0), 0, 255).astype(np.uint8), compress_level=3) # Low compression for faster saving + else: + imageio.imwrite(fn, np.clip(np.rint(x * 255.0), 0, 255).astype(np.uint8)) + except: + print("WARNING: FAILED to save image %s" % fn) + +def save_image_raw(fn, x : np.ndarray): + try: + imageio.imwrite(fn, x) + except: + print("WARNING: FAILED to save image %s" % fn) + + +def load_image_raw(fn) -> np.ndarray: + return imageio.imread(fn) + +def load_image(fn) -> np.ndarray: + img = load_image_raw(fn) + if img.dtype == np.float32: # HDR image + return img + else: # LDR image + return img.astype(np.float32) / 255 + +#---------------------------------------------------------------------------- + +def time_to_text(x): + if x > 3600: + return "%.2f h" % (x / 3600) + elif x > 60: + return "%.2f m" % (x / 60) + else: + return "%.2f s" % x + +#---------------------------------------------------------------------------- + +def checkerboard(res, checker_size) -> np.ndarray: + tiles_y = (res[0] + (checker_size*2) - 1) // (checker_size*2) + tiles_x = (res[1] + (checker_size*2) - 1) // (checker_size*2) + check = np.kron([[1, 0] * tiles_x, [0, 1] * tiles_x] * tiles_y, np.ones((checker_size, checker_size)))*0.33 + 0.33 + check = check[:res[0], :res[1]] + return np.stack((check, check, check), axis=-1) + diff --git a/models/lrm/utils/texture.py b/models/lrm/utils/texture.py new file mode 100644 index 0000000000000000000000000000000000000000..fe5820a56e838a35c6c0eed5af2284dc33d7767c --- /dev/null +++ b/models/lrm/utils/texture.py @@ -0,0 +1,189 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +import numpy as np +import torch +import nvdiffrast.torch as dr + +from models.lrm.models.geometry.rep_3d import util + +###################################################################################### +# Smooth pooling / mip computation with linear gradient upscaling +###################################################################################### + +class texture2d_mip(torch.autograd.Function): + @staticmethod + def forward(ctx, texture): + return util.avg_pool_nhwc(texture, (2,2)) + + @staticmethod + def backward(ctx, dout): + gy, gx = torch.meshgrid(torch.linspace(0.0 + 0.25 / dout.shape[1], 1.0 - 0.25 / dout.shape[1], dout.shape[1]*2, device="cuda"), + torch.linspace(0.0 + 0.25 / dout.shape[2], 1.0 - 0.25 / dout.shape[2], dout.shape[2]*2, device="cuda"), + indexing='ij') + uv = torch.stack((gx, gy), dim=-1) + return dr.texture(dout * 0.25, uv[None, ...].contiguous(), filter_mode='linear', boundary_mode='clamp') + +######################################################################################################## +# Simple texture class. A texture can be either +# - A 3D tensor (using auto mipmaps) +# - A list of 3D tensors (full custom mip hierarchy) +######################################################################################################## + +class Texture2D(torch.nn.Module): + # Initializes a texture from image data. + # Input can be constant value (1D array) or texture (3D array) or mip hierarchy (list of 3d arrays) + def __init__(self, init, min_max=None): + super(Texture2D, self).__init__() + + if isinstance(init, np.ndarray): + init = torch.tensor(init, dtype=torch.float32, device='cuda') + elif isinstance(init, list) and len(init) == 1: + init = init[0] + + if isinstance(init, list): + self.data = list(torch.nn.Parameter(mip.clone().detach(), requires_grad=True) for mip in init) + elif len(init.shape) == 4: + self.data = torch.nn.Parameter(init.clone().detach(), requires_grad=True) + elif len(init.shape) == 3: + self.data = torch.nn.Parameter(init[None, ...].clone().detach(), requires_grad=True) + elif len(init.shape) == 2: + self.data = torch.nn.Parameter(init[None, :, :, None].repeat(1,1,1,3).clone().detach(), requires_grad=True) + # breakpoint() + elif len(init.shape) == 1: + self.data = torch.nn.Parameter(init[None, None, None, :].clone().detach(), requires_grad=True) # Convert constant to 1x1 tensor + else: + assert False, "Invalid texture object" + + self.min_max = min_max + + # Filtered (trilinear) sample texture at a given location + def sample(self, texc, texc_deriv, filter_mode='linear-mipmap-linear'): + if isinstance(self.data, list): + out = dr.texture(self.data[0], texc, texc_deriv, mip=self.data[1:], filter_mode=filter_mode) + else: + if self.data.shape[1] > 1 and self.data.shape[2] > 1: + mips = [self.data] + while mips[-1].shape[1] > 1 and mips[-1].shape[2] > 1: + mips += [texture2d_mip.apply(mips[-1])] + out = dr.texture(mips[0], texc, texc_deriv, mip=mips[1:], filter_mode=filter_mode) + else: + out = dr.texture(self.data, texc, texc_deriv, filter_mode=filter_mode) + return out + + def getRes(self): + return self.getMips()[0].shape[1:3] + + def getChannels(self): + return self.getMips()[0].shape[3] + + def getMips(self): + if isinstance(self.data, list): + return self.data + else: + return [self.data] + + # In-place clamp with no derivative to make sure values are in valid range after training + def clamp_(self): + if self.min_max is not None: + for mip in self.getMips(): + for i in range(mip.shape[-1]): + mip[..., i].clamp_(min=self.min_max[0][i], max=self.min_max[1][i]) + + # In-place clamp with no derivative to make sure values are in valid range after training + def normalize_(self): + with torch.no_grad(): + for mip in self.getMips(): + mip = util.safe_normalize(mip) + +######################################################################################################## +# Helper function to create a trainable texture from a regular texture. The trainable weights are +# initialized with texture data as an initial guess +######################################################################################################## + +def create_trainable(init, res=None, auto_mipmaps=True, min_max=None): + with torch.no_grad(): + if isinstance(init, Texture2D): + assert isinstance(init.data, torch.Tensor) + min_max = init.min_max if min_max is None else min_max + init = init.data + elif isinstance(init, np.ndarray): + init = torch.tensor(init, dtype=torch.float32, device='cuda') + + # Pad to NHWC if needed + if len(init.shape) == 1: # Extend constant to NHWC tensor + init = init[None, None, None, :] + elif len(init.shape) == 3: + init = init[None, ...] + + # Scale input to desired resolution. + if res is not None: + init = util.scale_img_nhwc(init, res) + + # Genreate custom mipchain + if not auto_mipmaps: + mip_chain = [init.clone().detach().requires_grad_(True)] + while mip_chain[-1].shape[1] > 1 or mip_chain[-1].shape[2] > 1: + new_size = [max(mip_chain[-1].shape[1] // 2, 1), max(mip_chain[-1].shape[2] // 2, 1)] + mip_chain += [util.scale_img_nhwc(mip_chain[-1], new_size)] + return Texture2D(mip_chain, min_max=min_max) + else: + return Texture2D(init, min_max=min_max) + +######################################################################################################## +# Convert texture to and from SRGB +######################################################################################################## + +def srgb_to_rgb(texture): + return Texture2D(list(util.srgb_to_rgb(mip) for mip in texture.getMips())) + +def rgb_to_srgb(texture): + return Texture2D(list(util.rgb_to_srgb(mip) for mip in texture.getMips())) + +######################################################################################################## +# Utility functions for loading / storing a texture +######################################################################################################## + +def _load_mip2D(fn, lambda_fn=None, channels=None): + imgdata = torch.tensor(util.load_image(fn), dtype=torch.float32, device='cuda') + if channels is not None: + imgdata = imgdata[..., 0:channels] + if lambda_fn is not None: + imgdata = lambda_fn(imgdata) + return imgdata.detach().clone() + +def load_texture2D(fn, lambda_fn=None, channels=None): + base, ext = os.path.splitext(fn) + if os.path.exists(base + "_0" + ext): + mips = [] + while os.path.exists(base + ("_%d" % len(mips)) + ext): + mips += [_load_mip2D(base + ("_%d" % len(mips)) + ext, lambda_fn, channels)] + return Texture2D(mips) + else: + return Texture2D(_load_mip2D(fn, lambda_fn, channels)) + +def _save_mip2D(fn, mip, mipidx, lambda_fn): + if lambda_fn is not None: + data = lambda_fn(mip).detach().cpu().numpy() + else: + data = mip.detach().cpu().numpy() + + if mipidx is None: + util.save_image(fn, data) + else: + base, ext = os.path.splitext(fn) + util.save_image(base + ("_%d" % mipidx) + ext, data) + +def save_texture2D(fn, tex, lambda_fn=None): + if isinstance(tex.data, list): + for i, mip in enumerate(tex.data): + _save_mip2D(fn, mip[0,...], i, lambda_fn) + else: + _save_mip2D(fn, tex.data[0,...], None, lambda_fn) diff --git a/models/lrm/utils/train_util.py b/models/lrm/utils/train_util.py new file mode 100644 index 0000000000000000000000000000000000000000..2e65421bffa8cc42c1517e86f2dfd8183caf52ab --- /dev/null +++ b/models/lrm/utils/train_util.py @@ -0,0 +1,26 @@ +import importlib + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) diff --git a/models/zero123plus/model.py b/models/zero123plus/model.py new file mode 100644 index 0000000000000000000000000000000000000000..1655c45f2df23640d9a9270b6240b3453557599e --- /dev/null +++ b/models/zero123plus/model.py @@ -0,0 +1,272 @@ +import os +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import pytorch_lightning as pl +from tqdm import tqdm +from torchvision.transforms import v2 +from torchvision.utils import make_grid, save_image +from einops import rearrange + +from src.utils.train_util import instantiate_from_config +from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler, DDPMScheduler, UNet2DConditionModel +from .pipeline import RefOnlyNoisedUNet + + +def scale_latents(latents): + latents = (latents - 0.22) * 0.75 + return latents + + +def unscale_latents(latents): + latents = latents / 0.75 + 0.22 + return latents + + +def scale_image(image): + image = image * 0.5 / 0.8 + return image + + +def unscale_image(image): + image = image / 0.5 * 0.8 + return image + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +class MVDiffusion(pl.LightningModule): + def __init__( + self, + stable_diffusion_config, + drop_cond_prob=0.1, + ): + super(MVDiffusion, self).__init__() + + self.drop_cond_prob = drop_cond_prob + + self.register_schedule() + + # init modules + pipeline = DiffusionPipeline.from_pretrained(**stable_diffusion_config) + pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config( + pipeline.scheduler.config, timestep_spacing='trailing' + ) + self.pipeline = pipeline + + train_sched = DDPMScheduler.from_config(self.pipeline.scheduler.config) + if isinstance(self.pipeline.unet, UNet2DConditionModel): + self.pipeline.unet = RefOnlyNoisedUNet(self.pipeline.unet, train_sched, self.pipeline.scheduler) + + self.train_scheduler = train_sched # use ddpm scheduler during training + + self.unet = pipeline.unet + + # validation output buffer + self.validation_step_outputs = [] + + def register_schedule(self): + self.num_timesteps = 1000 + + # replace scaled_linear schedule with linear schedule as Zero123++ + beta_start = 0.00085 + beta_end = 0.0120 + betas = torch.linspace(beta_start, beta_end, 1000, dtype=torch.float32) + + alphas = 1. - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_cumprod_prev = torch.cat([torch.ones(1, dtype=torch.float64), alphas_cumprod[:-1]], 0) + + self.register_buffer('betas', betas.float()) + self.register_buffer('alphas_cumprod', alphas_cumprod.float()) + self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev.float()) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod).float()) + self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1 - alphas_cumprod).float()) + + self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod).float()) + self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1).float()) + + def on_fit_start(self): + device = torch.device(f'cuda:{self.global_rank}') + self.pipeline.to(device) + if self.global_rank == 0: + os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True) + os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True) + + def prepare_batch_data(self, batch): + # prepare stable diffusion input + cond_imgs = batch['cond_imgs'] # (B, C, H, W) + cond_imgs = cond_imgs.to(self.device) + + # random resize the condition image + cond_size = np.random.randint(128, 513) + cond_imgs = v2.functional.resize(cond_imgs, cond_size, interpolation=3, antialias=True).clamp(0, 1) + + target_imgs = batch['target_imgs'] # (B, 6, C, H, W) + target_imgs = v2.functional.resize(target_imgs, 320, interpolation=3, antialias=True).clamp(0, 1) + target_imgs = rearrange(target_imgs, 'b (x y) c h w -> b c (x h) (y w)', x=3, y=2) # (B, C, 3H, 2W) + target_imgs = target_imgs.to(self.device) + + return cond_imgs, target_imgs + + @torch.no_grad() + def forward_vision_encoder(self, images): + dtype = next(self.pipeline.vision_encoder.parameters()).dtype + image_pil = [v2.functional.to_pil_image(images[i]) for i in range(images.shape[0])] + image_pt = self.pipeline.feature_extractor_clip(images=image_pil, return_tensors="pt").pixel_values + image_pt = image_pt.to(device=self.device, dtype=dtype) + global_embeds = self.pipeline.vision_encoder(image_pt, output_hidden_states=False).image_embeds + global_embeds = global_embeds.unsqueeze(-2) + + encoder_hidden_states = self.pipeline._encode_prompt("", self.device, 1, False)[0] + ramp = global_embeds.new_tensor(self.pipeline.config.ramping_coefficients).unsqueeze(-1) + encoder_hidden_states = encoder_hidden_states + global_embeds * ramp + + return encoder_hidden_states + + @torch.no_grad() + def encode_condition_image(self, images): + dtype = next(self.pipeline.vae.parameters()).dtype + image_pil = [v2.functional.to_pil_image(images[i]) for i in range(images.shape[0])] + image_pt = self.pipeline.feature_extractor_vae(images=image_pil, return_tensors="pt").pixel_values + image_pt = image_pt.to(device=self.device, dtype=dtype) + latents = self.pipeline.vae.encode(image_pt).latent_dist.sample() + return latents + + @torch.no_grad() + def encode_target_images(self, images): + dtype = next(self.pipeline.vae.parameters()).dtype + # equals to scaling images to [-1, 1] first and then call scale_image + images = (images - 0.5) / 0.8 # [-0.625, 0.625] + posterior = self.pipeline.vae.encode(images.to(dtype)).latent_dist + latents = posterior.sample() * self.pipeline.vae.config.scaling_factor + latents = scale_latents(latents) + return latents + + def forward_unet(self, latents, t, prompt_embeds, cond_latents): + dtype = next(self.pipeline.unet.parameters()).dtype + latents = latents.to(dtype) + prompt_embeds = prompt_embeds.to(dtype) + cond_latents = cond_latents.to(dtype) + cross_attention_kwargs = dict(cond_lat=cond_latents) + pred_noise = self.pipeline.unet( + latents, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + return pred_noise + + def predict_start_from_z_and_v(self, x_t, t, v): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v + ) + + def get_v(self, x, noise, t): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x + ) + + def training_step(self, batch, batch_idx): + # get input + cond_imgs, target_imgs = self.prepare_batch_data(batch) + + # sample random timestep + B = cond_imgs.shape[0] + + t = torch.randint(0, self.num_timesteps, size=(B,)).long().to(self.device) + + # classifier-free guidance + if np.random.rand() < self.drop_cond_prob: + prompt_embeds = self.pipeline._encode_prompt([""]*B, self.device, 1, False) + cond_latents = self.encode_condition_image(torch.zeros_like(cond_imgs)) + else: + prompt_embeds = self.forward_vision_encoder(cond_imgs) + cond_latents = self.encode_condition_image(cond_imgs) + + latents = self.encode_target_images(target_imgs) + noise = torch.randn_like(latents) + latents_noisy = self.train_scheduler.add_noise(latents, noise, t) + + v_pred = self.forward_unet(latents_noisy, t, prompt_embeds, cond_latents) + v_target = self.get_v(latents, noise, t) + + loss, loss_dict = self.compute_loss(v_pred, v_target) + + # logging + self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log("global_step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False) + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) + + if self.global_step % 500 == 0 and self.global_rank == 0: + with torch.no_grad(): + latents_pred = self.predict_start_from_z_and_v(latents_noisy, t, v_pred) + + latents = unscale_latents(latents_pred) + images = unscale_image(self.pipeline.vae.decode(latents / self.pipeline.vae.config.scaling_factor, return_dict=False)[0]) # [-1, 1] + images = (images * 0.5 + 0.5).clamp(0, 1) + images = torch.cat([target_imgs, images], dim=-2) + + grid = make_grid(images, nrow=images.shape[0], normalize=True, value_range=(0, 1)) + save_image(grid, os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png')) + + return loss + + def compute_loss(self, noise_pred, noise_gt): + loss = F.mse_loss(noise_pred, noise_gt) + + prefix = 'train' + loss_dict = {} + loss_dict.update({f'{prefix}/loss': loss}) + + return loss, loss_dict + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + # get input + cond_imgs, target_imgs = self.prepare_batch_data(batch) + + images_pil = [v2.functional.to_pil_image(cond_imgs[i]) for i in range(cond_imgs.shape[0])] + + outputs = [] + for cond_img in images_pil: + latent = self.pipeline(cond_img, num_inference_steps=75, output_type='latent').images + image = unscale_image(self.pipeline.vae.decode(latent / self.pipeline.vae.config.scaling_factor, return_dict=False)[0]) # [-1, 1] + image = (image * 0.5 + 0.5).clamp(0, 1) + outputs.append(image) + outputs = torch.cat(outputs, dim=0).to(self.device) + images = torch.cat([target_imgs, outputs], dim=-2) + + self.validation_step_outputs.append(images) + + @torch.no_grad() + def on_validation_epoch_end(self): + images = torch.cat(self.validation_step_outputs, dim=0) + + all_images = self.all_gather(images) + all_images = rearrange(all_images, 'r b c h w -> (r b) c h w') + + if self.global_rank == 0: + grid = make_grid(all_images, nrow=8, normalize=True, value_range=(0, 1)) + save_image(grid, os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png')) + + self.validation_step_outputs.clear() # free memory + + def configure_optimizers(self): + lr = self.learning_rate + + optimizer = torch.optim.AdamW(self.unet.parameters(), lr=lr) + scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 3000, eta_min=lr/4) + + return {'optimizer': optimizer, 'lr_scheduler': scheduler} diff --git a/models/zero123plus/pipeline.py b/models/zero123plus/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..0088218346b36f07662d051670e51c658df59f1f --- /dev/null +++ b/models/zero123plus/pipeline.py @@ -0,0 +1,406 @@ +from typing import Any, Dict, Optional +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.schedulers import KarrasDiffusionSchedulers + +import numpy +import torch +import torch.nn as nn +import torch.utils.checkpoint +import torch.distributed +import transformers +from collections import OrderedDict +from PIL import Image +from torchvision import transforms +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +import diffusers +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + DiffusionPipeline, + EulerAncestralDiscreteScheduler, + UNet2DConditionModel, + ImagePipelineOutput +) +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.attention_processor import Attention, AttnProcessor, XFormersAttnProcessor, AttnProcessor2_0 +from diffusers.utils.import_utils import is_xformers_available + + +def to_rgb_image(maybe_rgba: Image.Image): + if maybe_rgba.mode == 'RGB': + return maybe_rgba + elif maybe_rgba.mode == 'RGBA': + rgba = maybe_rgba + img = numpy.random.randint(255, 256, size=[rgba.size[1], rgba.size[0], 3], dtype=numpy.uint8) + img = Image.fromarray(img, 'RGB') + img.paste(rgba, mask=rgba.getchannel('A')) + return img + else: + raise ValueError("Unsupported image type.", maybe_rgba.mode) + + +class ReferenceOnlyAttnProc(torch.nn.Module): + def __init__( + self, + chained_proc, + enabled=False, + name=None + ) -> None: + super().__init__() + self.enabled = enabled + self.chained_proc = chained_proc + self.name = name + + def __call__( + self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, + mode="w", ref_dict: dict = None, is_cfg_guidance = False + ) -> Any: + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + if self.enabled and is_cfg_guidance: + res0 = self.chained_proc(attn, hidden_states[:1], encoder_hidden_states[:1], attention_mask) + hidden_states = hidden_states[1:] + encoder_hidden_states = encoder_hidden_states[1:] + if self.enabled: + if mode == 'w': + ref_dict[self.name] = encoder_hidden_states + elif mode == 'r': + encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict.pop(self.name)], dim=1) + elif mode == 'm': + encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict[self.name]], dim=1) + else: + assert False, mode + res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask) + if self.enabled and is_cfg_guidance: + res = torch.cat([res0, res]) + return res + + +class RefOnlyNoisedUNet(torch.nn.Module): + def __init__(self, unet: UNet2DConditionModel, train_sched: DDPMScheduler, val_sched: EulerAncestralDiscreteScheduler) -> None: + super().__init__() + self.unet = unet + self.train_sched = train_sched + self.val_sched = val_sched + + unet_lora_attn_procs = dict() + for name, _ in unet.attn_processors.items(): + if torch.__version__ >= '2.0': + default_attn_proc = AttnProcessor2_0() + elif is_xformers_available(): + default_attn_proc = XFormersAttnProcessor() + else: + default_attn_proc = AttnProcessor() + unet_lora_attn_procs[name] = ReferenceOnlyAttnProc( + default_attn_proc, enabled=name.endswith("attn1.processor"), name=name + ) + unet.set_attn_processor(unet_lora_attn_procs) + + def __getattr__(self, name: str): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.unet, name) + + def forward_cond(self, noisy_cond_lat, timestep, encoder_hidden_states, class_labels, ref_dict, is_cfg_guidance, **kwargs): + if is_cfg_guidance: + encoder_hidden_states = encoder_hidden_states[1:] + class_labels = class_labels[1:] + self.unet( + noisy_cond_lat, timestep, + encoder_hidden_states=encoder_hidden_states, + class_labels=class_labels, + cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict), + **kwargs + ) + + def forward( + self, sample, timestep, encoder_hidden_states, class_labels=None, + *args, cross_attention_kwargs, + down_block_res_samples=None, mid_block_res_sample=None, + **kwargs + ): + cond_lat = cross_attention_kwargs['cond_lat'] + is_cfg_guidance = cross_attention_kwargs.get('is_cfg_guidance', False) + noise = torch.randn_like(cond_lat) + if self.training: + noisy_cond_lat = self.train_sched.add_noise(cond_lat, noise, timestep) + noisy_cond_lat = self.train_sched.scale_model_input(noisy_cond_lat, timestep) + else: + noisy_cond_lat = self.val_sched.add_noise(cond_lat, noise, timestep.reshape(-1)) + noisy_cond_lat = self.val_sched.scale_model_input(noisy_cond_lat, timestep.reshape(-1)) + ref_dict = {} + self.forward_cond( + noisy_cond_lat, timestep, + encoder_hidden_states, class_labels, + ref_dict, is_cfg_guidance, **kwargs + ) + weight_dtype = self.unet.dtype + return self.unet( + sample, timestep, + encoder_hidden_states, *args, + class_labels=class_labels, + cross_attention_kwargs=dict(mode="r", ref_dict=ref_dict, is_cfg_guidance=is_cfg_guidance), + down_block_additional_residuals=[ + sample.to(dtype=weight_dtype) for sample in down_block_res_samples + ] if down_block_res_samples is not None else None, + mid_block_additional_residual=( + mid_block_res_sample.to(dtype=weight_dtype) + if mid_block_res_sample is not None else None + ), + **kwargs + ) + + +def scale_latents(latents): + latents = (latents - 0.22) * 0.75 + return latents + + +def unscale_latents(latents): + latents = latents / 0.75 + 0.22 + return latents + + +def scale_image(image): + image = image * 0.5 / 0.8 + return image + + +def unscale_image(image): + image = image / 0.5 * 0.8 + return image + + +class DepthControlUNet(torch.nn.Module): + def __init__(self, unet: RefOnlyNoisedUNet, controlnet: Optional[diffusers.ControlNetModel] = None, conditioning_scale=1.0) -> None: + super().__init__() + self.unet = unet + if controlnet is None: + self.controlnet = diffusers.ControlNetModel.from_unet(unet.unet) + else: + self.controlnet = controlnet + DefaultAttnProc = AttnProcessor2_0 + if is_xformers_available(): + DefaultAttnProc = XFormersAttnProcessor + self.controlnet.set_attn_processor(DefaultAttnProc()) + self.conditioning_scale = conditioning_scale + + def __getattr__(self, name: str): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.unet, name) + + def forward(self, sample, timestep, encoder_hidden_states, class_labels=None, *args, cross_attention_kwargs: dict, **kwargs): + cross_attention_kwargs = dict(cross_attention_kwargs) + control_depth = cross_attention_kwargs.pop('control_depth') + down_block_res_samples, mid_block_res_sample = self.controlnet( + sample, + timestep, + encoder_hidden_states=encoder_hidden_states, + controlnet_cond=control_depth, + conditioning_scale=self.conditioning_scale, + return_dict=False, + ) + return self.unet( + sample, + timestep, + encoder_hidden_states=encoder_hidden_states, + down_block_res_samples=down_block_res_samples, + mid_block_res_sample=mid_block_res_sample, + cross_attention_kwargs=cross_attention_kwargs + ) + + +class ModuleListDict(torch.nn.Module): + def __init__(self, procs: dict) -> None: + super().__init__() + self.keys = sorted(procs.keys()) + self.values = torch.nn.ModuleList(procs[k] for k in self.keys) + + def __getitem__(self, key): + return self.values[self.keys.index(key)] + + +class SuperNet(torch.nn.Module): + def __init__(self, state_dict: Dict[str, torch.Tensor]): + super().__init__() + state_dict = OrderedDict((k, state_dict[k]) for k in sorted(state_dict.keys())) + self.layers = torch.nn.ModuleList(state_dict.values()) + self.mapping = dict(enumerate(state_dict.keys())) + self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())} + + # .processor for unet, .self_attn for text encoder + self.split_keys = [".processor", ".self_attn"] + + # we add a hook to state_dict() and load_state_dict() so that the + # naming fits with `unet.attn_processors` + def map_to(module, state_dict, *args, **kwargs): + new_state_dict = {} + for key, value in state_dict.items(): + num = int(key.split(".")[1]) # 0 is always "layers" + new_key = key.replace(f"layers.{num}", module.mapping[num]) + new_state_dict[new_key] = value + + return new_state_dict + + def remap_key(key, state_dict): + for k in self.split_keys: + if k in key: + return key.split(k)[0] + k + return key.split('.')[0] + + def map_from(module, state_dict, *args, **kwargs): + all_keys = list(state_dict.keys()) + for key in all_keys: + replace_key = remap_key(key, state_dict) + new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}") + state_dict[new_key] = state_dict[key] + del state_dict[key] + + self._register_state_dict_hook(map_to) + self._register_load_state_dict_pre_hook(map_from, with_module=True) + + +class Zero123PlusPipeline(diffusers.StableDiffusionPipeline): + tokenizer: transformers.CLIPTokenizer + text_encoder: transformers.CLIPTextModel + vision_encoder: transformers.CLIPVisionModelWithProjection + + feature_extractor_clip: transformers.CLIPImageProcessor + unet: UNet2DConditionModel + scheduler: diffusers.schedulers.KarrasDiffusionSchedulers + + vae: AutoencoderKL + ramping: nn.Linear + + feature_extractor_vae: transformers.CLIPImageProcessor + + depth_transforms_multi = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]) + ]) + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + vision_encoder: transformers.CLIPVisionModelWithProjection, + feature_extractor_clip: CLIPImageProcessor, + feature_extractor_vae: CLIPImageProcessor, + ramping_coefficients: Optional[list] = None, + safety_checker=None, + ): + DiffusionPipeline.__init__(self) + + self.register_modules( + vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, + unet=unet, scheduler=scheduler, safety_checker=None, + vision_encoder=vision_encoder, + feature_extractor_clip=feature_extractor_clip, + feature_extractor_vae=feature_extractor_vae + ) + self.register_to_config(ramping_coefficients=ramping_coefficients) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def prepare(self): + train_sched = DDPMScheduler.from_config(self.scheduler.config) + if isinstance(self.unet, UNet2DConditionModel): + self.unet = RefOnlyNoisedUNet(self.unet, train_sched, self.scheduler).eval() + + def add_controlnet(self, controlnet: Optional[diffusers.ControlNetModel] = None, conditioning_scale=1.0): + self.prepare() + self.unet = DepthControlUNet(self.unet, controlnet, conditioning_scale) + return SuperNet(OrderedDict([('controlnet', self.unet.controlnet)])) + + def encode_condition_image(self, image: torch.Tensor): + image = self.vae.encode(image).latent_dist.sample() + return image + + @torch.no_grad() + def __call__( + self, + image: Image.Image = None, + prompt = "", + *args, + num_images_per_prompt: Optional[int] = 1, + guidance_scale=4.0, + depth_image: Image.Image = None, + output_type: Optional[str] = "pil", + width=640, + height=960, + num_inference_steps=28, + return_dict=True, + **kwargs + ): + self.prepare() + if image is None: + raise ValueError("Inputting embeddings not supported for this pipeline. Please pass an image.") + assert not isinstance(image, torch.Tensor) + image = to_rgb_image(image) + image_1 = self.feature_extractor_vae(images=image, return_tensors="pt").pixel_values + image_2 = self.feature_extractor_clip(images=image, return_tensors="pt").pixel_values + if depth_image is not None and hasattr(self.unet, "controlnet"): + depth_image = to_rgb_image(depth_image) + depth_image = self.depth_transforms_multi(depth_image).to( + device=self.unet.controlnet.device, dtype=self.unet.controlnet.dtype + ) + image = image_1.to(device=self.vae.device, dtype=self.vae.dtype) + image_2 = image_2.to(device=self.vae.device, dtype=self.vae.dtype) + cond_lat = self.encode_condition_image(image) + if guidance_scale > 1: + negative_lat = self.encode_condition_image(torch.zeros_like(image)) + cond_lat = torch.cat([negative_lat, cond_lat]) + encoded = self.vision_encoder(image_2, output_hidden_states=False) + global_embeds = encoded.image_embeds + global_embeds = global_embeds.unsqueeze(-2) + + if hasattr(self, "encode_prompt"): + encoder_hidden_states = self.encode_prompt( + prompt, + self.device, + num_images_per_prompt, + False + )[0] + else: + encoder_hidden_states = self._encode_prompt( + prompt, + self.device, + num_images_per_prompt, + False + ) + ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1) + encoder_hidden_states = encoder_hidden_states + global_embeds * ramp + cak = dict(cond_lat=cond_lat) + if hasattr(self.unet, "controlnet"): + cak['control_depth'] = depth_image + latents: torch.Tensor = super().__call__( + None, + *args, + cross_attention_kwargs=cak, + guidance_scale=guidance_scale, + num_images_per_prompt=num_images_per_prompt, + prompt_embeds=encoder_hidden_states, + num_inference_steps=num_inference_steps, + output_type='latent', + width=width, + height=height, + **kwargs + ).images + latents = unscale_latents(latents) + if not output_type == "latent": + image = unscale_image(self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]) + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..62c746d908b097e3ad07c077676858b5d5d4dd22 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,43 @@ +torch==2.1.0 +torchvision==0.16.0 +-f https://download.pytorch.org/whl/cu121 + +comfy==0.0.1 +einops==0.8.0 +imageio==2.34.1 +Imath==0.0.2 +jaxtyping==0.2.36 +kiui==0.2.10 +mathutils==3.3.0 +matplotlib==3.9.2 +git+https://github.com/NVlabs/nvdiffrast +omegaconf==2.3.0 +onnxruntime==1.18.0 +open3d==0.18.0 +opencv_python==4.10.0.84 +opencv_python_headless==4.9.0.80 +OpenEXR==3.3.2 +Pillow==11.0.0 +plyfile==1.1 +PyGLM==2.7.1 +pygltflib==1.16.3 +PyMCubes==0.1.4 +pymeshlab==2023.12.post1 +PyOpenGL==3.1.0 +pytorch3d @ https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt210/pytorch3d-0.7.4-cp310-cp310-linux_x86_64.whl +pytorch_lightning==2.2.0 +PyYAML==6.0.1 +rembg==2.0.57 +scipy==1.14.1 +torch-scatter +-f https://data.pyg.org/whl/torch-2.1.0+cu121.html +tqdm==4.66.4 +transformers==4.39.3 +trimesh==4.4.0 +webdataset==0.2.86 +xatlas==0.0.9 +ninja +sentencepiece +# 添加本地依赖(可编辑模式) +-e env/diffusers +accelerate diff --git a/text_to_mesh.py b/text_to_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..b291de6ea68a631a67c1db0cb8338dd9237b60e9 --- /dev/null +++ b/text_to_mesh.py @@ -0,0 +1,232 @@ +import os +from einops import rearrange +from omegaconf import OmegaConf +import torch +import numpy as np +import trimesh +import torchvision +import torch.nn.functional as F +from PIL import Image +from torchvision import transforms +from torchvision.transforms import v2 +from diffusers import HeunDiscreteScheduler +from diffusers import FluxPipeline +from pytorch_lightning import seed_everything +import os +import time +from models.lrm.utils.infer_util import save_video +from models.lrm.utils.mesh_util import save_obj, save_obj_with_mtl +from models.lrm.utils.render_utils import rotate_x, rotate_y +from models.lrm.utils.train_util import instantiate_from_config +from models.lrm.utils.camera_util import get_flux_input_cameras +from models.ISOMER.reconstruction_func import reconstruction +from models.ISOMER.projection_func import projection +from utils.tool import NormalTransfer, load_mipmap +from utils.tool import get_background, get_render_cameras_video, render_frames + +device = "cuda" +resolution = 512 +save_dir = "./outputs" +normal_transfer = NormalTransfer() +isomer_azimuths = torch.from_numpy(np.array([0, 90, 180, 270])).float().to(device) +isomer_elevations = torch.from_numpy(np.array([5, 5, 5, 5])).float().to(device) +isomer_radius = 4.5 +isomer_geo_weights = torch.from_numpy(np.array([1, 0.9, 1, 0.9])).float().to(device) +isomer_color_weights = torch.from_numpy(np.array([1, 0.5, 1, 0.5])).float().to(device) + +# model initialization and loading +# flux +flux_pipe = FluxPipeline.from_pretrained("/hpc2hdd/JH_DATA/share/yingcongchen/PrivateShareGroup/yingcongchen_datasets/model_checkpoint/models--black-forest-labs--FLUX.1-dev", torch_dtype=torch.bfloat16).to(device=device, dtype=torch.bfloat16) +flux_pipe.load_lora_weights('./checkpoint/flux_lora/rgb_normal_large.safetensors') + +flux_pipe.to(device=device, dtype=torch.bfloat16) +generator = torch.Generator(device=device).manual_seed(10) + +# lrm +config = OmegaConf.load("./models/lrm/config/PRM_inference.yaml") +model_config = config.model_config +infer_config = config.infer_config +model = instantiate_from_config(model_config) +model_ckpt_path = "./checkpoint/lrm/final_ckpt.ckpt" +state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict'] +state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')} +model.load_state_dict(state_dict, strict=True) + +model = model.to(device) +model.init_flexicubes_geometry(device, fovy=50.0) +model = model.eval() + +# Flux multi-view generation +def multi_view_rgb_normal_generation(prompt, save_path=None): + # generate multi-view images + with torch.no_grad(): + image = flux_pipe( + prompt=prompt, + num_inference_steps=30, + guidance_scale=3.5, + num_images_per_prompt=1, + width=resolution*4, + height=resolution*2, + output_type='np', + generator=generator + ).images + return image + +# lrm reconstructions +def lrm_reconstructions(image, input_cameras, save_path=None, name="temp", export_texmap=False, if_save_video=False): + images = image.unsqueeze(0).to(device) + images = v2.functional.resize(images, 512, interpolation=3, antialias=True).clamp(0, 1) + # breakpoint() + with torch.no_grad(): + # get triplane + planes = model.forward_planes(images, input_cameras) + + mesh_path_idx = os.path.join(save_path, f'{name}.obj') + + mesh_out = model.extract_mesh( + planes, + use_texture_map=export_texmap, + **infer_config, + ) + if export_texmap: + vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out + save_obj_with_mtl( + vertices.data.cpu().numpy(), + uvs.data.cpu().numpy(), + faces.data.cpu().numpy(), + mesh_tex_idx.data.cpu().numpy(), + tex_map.permute(1, 2, 0).data.cpu().numpy(), + mesh_path_idx, + ) + else: + vertices, faces, vertex_colors = mesh_out + save_obj(vertices, faces, vertex_colors, mesh_path_idx) + print(f"Mesh saved to {mesh_path_idx}") + + render_size = 512 + if if_save_video: + video_path_idx = os.path.join(save_path, f'{name}.mp4') + render_size = infer_config.render_resolution + ENV = load_mipmap("models/lrm/env_mipmap/6") + materials = (0.0,0.9) + + all_mv, all_mvp, all_campos = get_render_cameras_video( + batch_size=1, + M=240, + radius=4.5, + elevation=(90, 60.0), + is_flexicubes=True, + fov=30 + ) + + frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames( + model, + planes, + render_cameras=all_mvp, + camera_pos=all_campos, + env=ENV, + materials=materials, + render_size=render_size, + chunk_size=20, + is_flexicubes=True, + ) + normals = (torch.nn.functional.normalize(normals) + 1) / 2 + normals = normals * alphas + (1-alphas) + all_frames = torch.cat([frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals], dim=3) + + save_video( + all_frames, + video_path_idx, + fps=30, + ) + print(f"Video saved to {video_path_idx}") + + return vertices, faces + + +def local_normal_global_transform(local_normal_images, azimuths_deg, elevations_deg): + if local_normal_images.min() >= 0: + local_normal = local_normal_images.float() * 2 - 1 + else: + local_normal = local_normal_images.float() + global_normal = normal_transfer.trans_local_2_global(local_normal, azimuths_deg, elevations_deg, radius=4.5, for_lotus=False) + global_normal[...,0] *= -1 + global_normal = (global_normal + 1) / 2 + global_normal = global_normal.permute(0, 3, 1, 2) + return global_normal + +def main(): + end = time.time() + fix_prompt = 'a grid of 2x4 multi-view image. elevation 5. white background.' + # user prompt + prompt = "a owl wearing a hat." + save_dir_path = os.path.join(save_dir, prompt.split(".")[0].replace(" ", "_")) + os.makedirs(save_dir_path, exist_ok=True) + prompt = fix_prompt+" "+prompt + # generate multi-view images + rgb_normal_grid = multi_view_rgb_normal_generation(prompt) + # lrm reconstructions + images = torch.from_numpy(rgb_normal_grid).squeeze(0).permute(2, 0, 1).contiguous().float() # (3, 1024, 2048) + images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=2, m=4) # (8, 3, 512, 512) + rgb_multi_view = images[:4, :3, :, :] + normal_multi_view = images[4:, :3, :, :] + multi_view_mask = get_background(normal_multi_view) + rgb_multi_view = rgb_multi_view * rgb_multi_view + (1-multi_view_mask) + input_cameras = get_flux_input_cameras(batch_size=1, radius=4.2, fov=30).to(device) + vertices, faces = lrm_reconstructions(rgb_multi_view, input_cameras, save_path=save_dir_path, name='lrm', export_texmap=False, if_save_video=False) + # local normal to global normal + + global_normal = local_normal_global_transform(normal_multi_view.permute(0, 2, 3, 1), isomer_azimuths, isomer_elevations) + global_normal = global_normal * multi_view_mask + (1-multi_view_mask) + + global_normal = global_normal.permute(0,2,3,1) + rgb_multi_view = rgb_multi_view.permute(0,2,3,1) + multi_view_mask = multi_view_mask.permute(0,2,3,1).squeeze(-1) + vertices = torch.from_numpy(vertices).to(device) + faces = torch.from_numpy(faces).to(device) + vertices = vertices @ rotate_x(np.pi / 2, device=vertices.device)[:3, :3] + vertices = vertices @ rotate_y(np.pi / 2, device=vertices.device)[:3, :3] + + # global_normal: B,H,W,3 + # multi_view_mask: B,H,W + # rgb_multi_view: B,H,W,3 + + meshes = reconstruction( + normal_pils=global_normal, + masks=multi_view_mask, + weights=isomer_geo_weights, + fov=30, + radius=isomer_radius, + camera_angles_azi=isomer_azimuths, + camera_angles_ele=isomer_elevations, + expansion_weight_stage1=0.1, + init_type="file", + init_verts=vertices, + init_faces=faces, + stage1_steps=0, + stage2_steps=50, + start_edge_len_stage1=0.1, + end_edge_len_stage1=0.02, + start_edge_len_stage2=0.02, + end_edge_len_stage2=0.005, + ) + + + save_glb_addr = projection( + meshes, + masks=multi_view_mask, + images=rgb_multi_view, + azimuths=isomer_azimuths, + elevations=isomer_elevations, + weights=isomer_color_weights, + fov=30, + radius=isomer_radius, + save_dir=f"{save_dir_path}/ISOMER/", + ) + print(f'saved to {save_glb_addr}') + print(f"Time elapsed: {time.time() - end:.2f}s") + + + +if __name__ == '__main__': + main() diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils/demo.py b/utils/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..8cf5ab58185596db5bd67b5fe245c84fb393c114 --- /dev/null +++ b/utils/demo.py @@ -0,0 +1,162 @@ +import os +import numpy as np +from PIL import Image +import rembg +import PIL +from typing import Any +import torch +import cv2 +from tqdm import tqdm +import torchvision + + +class NormalTransfer: + def __init__(self): + self.identity_w2c = torch.tensor([ + [0.0, 0.0, 1.0, 0.0], + [ 0.0, 1.0, 0.0, 0.0], + [-1.0, 0.0, 0.0, 4.5]]).float() + + def look_at(self,camera_position, target_position, up_vector=np.array([0, 0, 1])): + forward = camera_position - target_position + forward = forward / np.linalg.norm(forward) + + right = np.cross(up_vector, forward) + right = right / np.linalg.norm(right) + + up = np.cross(forward, right) + + rotation_matrix = np.array([right, up, forward]).T + + translation_matrix = np.eye(4) + translation_matrix[:3, 3] = -camera_position + + rotation_homogeneous = np.eye(4) + rotation_homogeneous[:3, :3] = rotation_matrix + + w2c = rotation_homogeneous @ translation_matrix + return w2c + + def generate_target_pose(self, azimuths_deg, elevations_deg, radius=4.5): + azimuths = np.deg2rad(azimuths_deg) + elevations = np.deg2rad(elevations_deg) + + x = radius * np.cos(azimuths) * np.cos(elevations) + y = radius * np.sin(azimuths) * np.cos(elevations) + z = radius * np.sin(elevations) + camera_positions = np.stack([x, y, z], axis=-1) + + target_position = np.array([0, 0, 0]) # 目标点位置 + + # 为每个相机位置生成 w2c 矩阵 + w2c_matrices = [self.look_at(cam_pos, target_position) for cam_pos in camera_positions] + w2c_matrices = np.stack(w2c_matrices, axis=0) + return w2c_matrices + + def convert_to_blender(self, pose): + # Swap the y and z axes + w2c_opengl = pose + w2c_opengl[[1, 2], :] = w2c_opengl[[2, 1], :] + + # Invert the y axis + w2c_opengl[1] *= -1 + R = w2c_opengl[:3, :3] + t = w2c_opengl[:3, 3] + + cam_rec = np.asarray([[1, 0, 0], [0, -1, 0], [0, 0, -1]], np.float32) + R = R.T + t = -R @ t + R_world2cv = cam_rec @ R + t_world2cv = cam_rec @ t + + RT = np.concatenate([R_world2cv,t_world2cv[:,None]],1) + return RT + + def worldNormal2camNormal(self, rot_w2c, normal_map_world): + H,W,_ = normal_map_world.shape + # normal_img = np.matmul(rot_w2c[None, :, :], worldNormal.reshape(-1,3)[:, :, None]).reshape([H, W, 3]) + normal_map_world = normal_map_world[...,:3] + # faster version + normal_map_flat = normal_map_world.contiguous().view(-1, 3) + + normal_map_camera_flat = torch.matmul(normal_map_flat.float(), rot_w2c.T.float()) + + # Reshape the transformed normal map back to its original shape + normal_map_camera = normal_map_camera_flat.view(normal_map_world.shape) + + return normal_map_camera + + def trans_normal(self, normal, RT_w2c, RT_w2c_target): + """ + :param normal: (H,W,3), torch tensor, range [-1,1] + :param RT_w2c: (4,4), torch tensor, world to camera + :param RT_w2c_target: (4,4), torch tensor, world to camera + :return: normal_target_cam: (H,W,3), torch tensor, range [-1,1] + """ + relative_RT = torch.matmul(RT_w2c_target[:3,:3], torch.linalg.inv(RT_w2c[:3,:3])) + normal_target_cam = self.worldNormal2camNormal(relative_RT[:3,:3], normal) + + return normal_target_cam + + def trans_local_2_global(self, normal_local, azimuths_deg, elevations_deg, radius=4.5, for_lotus=True): + """ + :param normal_local: (B,H,W,3), torch tensor, range [-1,1] + :param azimuths_deg: (B,), numpy array, range [0,360] + :param elevations_deg: (B,), numpy array, range [-90,90] + :param radius: float, default 4.5 + :return: global_normal: (B,H,W,3), torch tensor, range [-1,1] + + """ + # print(f"normal_local.shape:{normal_local.shape}") + # print(f"azimuths_deg.shape:{azimuths_deg.shape}") + # print(f"elevations_deg.shape:{elevations_deg.shape}") + assert normal_local.shape[0] == azimuths_deg.shape[0] == elevations_deg.shape[0] + identity_w2c = self.identity_w2c + + # generate target pose + target_w2c = self.generate_target_pose(azimuths_deg, elevations_deg, radius) + target_w2c = torch.from_numpy(np.stack([self.convert_to_blender(w2c) for w2c in target_w2c])).float() + global_normal = [] + + # transform normal + for i in range(normal_local.shape[0]): + normal_local_i = normal_local[i] + normal_zero123 = self.trans_normal(normal_local_i, target_w2c[i], identity_w2c) + global_normal.append(normal_zero123) + + global_normal = torch.stack(global_normal, dim=0) + if for_lotus: + global_normal[...,0] *= -1 + global_normal = global_normal / torch.norm(global_normal, dim=-1, keepdim=True) + return global_normal + + def trans_global_2_local(self, normal_local, azimuths_deg, elevations_deg, radius=4.5): + """ + :param normal_global: (B,H,W,3), torch tensor, range [-1,1] + :param azimuths_deg: (B,), numpy array, range [0,360] + :param elevations_deg: (B,), numpy array, range [-90,90] + :param radius: float, default 4.5 + :return: local_normal: (B,H,W,3), torch tensor, range [-1,1] + + """ + print(f"normal_local.shape:{normal_local.shape}") + print(f"azimuths_deg.shape:{azimuths_deg.shape}") + print(f"elevations_deg.shape:{elevations_deg.shape}") + assert normal_local.shape[0] == azimuths_deg.shape[0] == elevations_deg.shape[0] + identity_w2c = self.identity_w2c + + # generate target pose + target_w2c = self.generate_target_pose(azimuths_deg, elevations_deg, radius) + target_w2c = torch.from_numpy(np.stack([self.convert_to_blender(w2c) for w2c in target_w2c])).float() + local_normal = [] + + # transform normal + for i in range(normal_local.shape[0]): + normal_local_i = normal_local[i] + normal = self.trans_normal(normal_local_i, identity_w2c, target_w2c[i]) + local_normal.append(normal) + + local_normal = torch.stack(local_normal, dim=0) + # global_normal[...,0] *= -1 + local_normal = local_normal / torch.norm(local_normal, dim=-1, keepdim=True) + return local_normal \ No newline at end of file diff --git a/utils/florence_caption.py b/utils/florence_caption.py new file mode 100644 index 0000000000000000000000000000000000000000..c926bb7928fbfff163e626266501cf7b342f54b3 --- /dev/null +++ b/utils/florence_caption.py @@ -0,0 +1,22 @@ + +import subprocess +import gradio as gr +from PIL import Image +import requests +import torch +import uuid +import shutil +import json +import yaml +from transformers import AutoProcessor, AutoModelForCausalLM +from tqdm import tqdm +import os + + + +if __name__ == "__main__": + # url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true" + image = Image.open("examples/cartoon_dinosaur.png").convert("RGB") + breakpoint() + run_captioning(image) + diff --git a/utils/tool.py b/utils/tool.py new file mode 100644 index 0000000000000000000000000000000000000000..3257df7dd13a58631b08d3943cf0ccbd0ffba63a --- /dev/null +++ b/utils/tool.py @@ -0,0 +1,472 @@ +import rembg +import cv2 +import numpy as np +import glm +import torch +from tqdm import tqdm +import torchvision +import torchvision.transforms.v2 as T +from models.lrm.utils import render_utils +import os +# get the background of the image +import torch +import numpy as np +import scipy +import cv2 +from rembg import remove + + +def load_mipmap(env_path): + diffuse_path = os.path.join(env_path, "diffuse.pth") + diffuse = torch.load(diffuse_path, map_location=torch.device('cpu')) + + specular = [] + for i in range(6): + specular_path = os.path.join(env_path, f"specular_{i}.pth") + specular_tensor = torch.load(specular_path, map_location=torch.device('cpu')) + specular.append(specular_tensor) + return [specular, diffuse] + +def get_background(img_tensor): + """ + Args: + img_tensor: 输入图像张量,形状为 (B, 3, H, W),数值范围为 [0, 1] 或 [0, 255]。 + Returns: + mask_tensor: 输出掩码张量,形状为 (B, 1, H, W),二值化。 + """ + B, C, H, W = img_tensor.shape + assert C == 3, "Input tensor must have 3 channels (RGB)." + + # 将 tensor 转换为 numpy 格式 (B, H, W, C),并归一化到 [0, 255] + img_numpy = (img_tensor.permute(0, 2, 3, 1) * 255).byte().cpu().numpy() # (B, H, W, C) + + masks = [] + for i in range(B): + # 调用 rembg 生成掩码 + mask = remove(img_numpy[i], only_mask=True) + + # 转换为二值掩码 + mask_binary = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1] + + # 添加到结果列表 (H, W, 1) + masks.append(mask_binary[..., None]) + + # 将所有掩码组合成 numpy 数组,形状为 (B, H, W, 1) + masks = np.stack(masks, axis=0) + + # 转换为 PyTorch 张量,形状为 (B, 1, H, W),值为 {0, 1} + mask_tensor = torch.from_numpy(masks).permute(0, 3, 1, 2).float() / 255.0 + # breakpoint() + return mask_tensor + +def get_render_cameras_video(batch_size=1, M=120, radius=4.0, elevation=20.0, is_flexicubes=False, fov=50): + """ + Get the rendering camera parameters. + """ + train_res = [512, 512] + cam_near_far = [0.1, 1000.0] + fovy = np.deg2rad(fov) + proj_mtx = render_utils.perspective(fovy, train_res[1] / train_res[0], cam_near_far[0], cam_near_far[1]) + all_mv = [] + all_mvp = [] + all_campos = [] + if isinstance(elevation, tuple): + elevation_0 = np.deg2rad(elevation[0]) + elevation_1 = np.deg2rad(elevation[1]) + for i in range(M//2): + azimuth = 2 * np.pi * i / (M // 2) + z = radius * np.cos(azimuth) * np.sin(elevation_0) + x = radius * np.sin(azimuth) * np.sin(elevation_0) + y = radius * np.cos(elevation_0) + + eye = glm.vec3(x, y, z) + at = glm.vec3(0.0, 0.0, 0.0) + up = glm.vec3(0.0, 1.0, 0.0) + view_matrix = glm.lookAt(eye, at, up) + mv = torch.from_numpy(np.array(view_matrix)) + mvp = proj_mtx @ (mv) #w2c + campos = torch.linalg.inv(mv)[:3, 3] + all_mv.append(mv[None, ...].cuda()) + all_mvp.append(mvp[None, ...].cuda()) + all_campos.append(campos[None, ...].cuda()) + for i in range(M//2): + azimuth = 2 * np.pi * i / (M // 2) + z = radius * np.cos(azimuth) * np.sin(elevation_1) + x = radius * np.sin(azimuth) * np.sin(elevation_1) + y = radius * np.cos(elevation_1) + + eye = glm.vec3(x, y, z) + at = glm.vec3(0.0, 0.0, 0.0) + up = glm.vec3(0.0, 1.0, 0.0) + view_matrix = glm.lookAt(eye, at, up) + mv = torch.from_numpy(np.array(view_matrix)) + mvp = proj_mtx @ (mv) #w2c + campos = torch.linalg.inv(mv)[:3, 3] + all_mv.append(mv[None, ...].cuda()) + all_mvp.append(mvp[None, ...].cuda()) + all_campos.append(campos[None, ...].cuda()) + else: + # elevation = 90 - elevation + for i in range(M): + azimuth = 2 * np.pi * i / M + z = radius * np.cos(azimuth) * np.sin(elevation) + x = radius * np.sin(azimuth) * np.sin(elevation) + y = radius * np.cos(elevation) + + eye = glm.vec3(x, y, z) + at = glm.vec3(0.0, 0.0, 0.0) + up = glm.vec3(0.0, 1.0, 0.0) + view_matrix = glm.lookAt(eye, at, up) + mv = torch.from_numpy(np.array(view_matrix)) + mvp = proj_mtx @ (mv) #w2c + campos = torch.linalg.inv(mv)[:3, 3] + all_mv.append(mv[None, ...].cuda()) + all_mvp.append(mvp[None, ...].cuda()) + all_campos.append(campos[None, ...].cuda()) + all_mv = torch.stack(all_mv, dim=0).unsqueeze(0).squeeze(2) + all_mvp = torch.stack(all_mvp, dim=0).unsqueeze(0).squeeze(2) + all_campos = torch.stack(all_campos, dim=0).unsqueeze(0).squeeze(2) + return all_mv, all_mvp, all_campos + +def get_render_cameras_frames(batch_size=1, radius=4.0, azimuths=0, elevations=20.0, fov=30): + """ + Get the rendering camera parameters. + """ + train_res = [512, 512] + cam_near_far = [0.1, 1000.0] + fovy = np.deg2rad(fov) + proj_mtx = render_utils.perspective(fovy, train_res[1] / train_res[0], cam_near_far[0], cam_near_far[1]) + all_mv = [] + all_mvp = [] + all_campos = [] + elevations = 90 - elevations + if isinstance(elevations, np.ndarray) or isinstance(elevations, torch.Tensor): + if isinstance(elevations, torch.Tensor): + elevations = elevations.cpu().numpy() + if isinstance(azimuths, torch.Tensor): + azimuths = azimuths.cpu().numpy() + azimuths = np.deg2rad(azimuths) + elevations = np.deg2rad(elevations) + for azi, ele in zip(azimuths, elevations): + z = radius * np.cos(azi) * np.sin(ele) + x = radius * np.sin(azi) * np.sin(ele) + y = radius * np.cos(ele) + + eye = glm.vec3(x, y, z) + at = glm.vec3(0.0, 0.0, 0.0) + up = glm.vec3(0.0, 1.0, 0.0) + view_matrix = glm.lookAt(eye, at, up) + mv = torch.from_numpy(np.array(view_matrix)) + mvp = proj_mtx @ (mv) #w2c + campos = torch.linalg.inv(mv)[:3, 3] + all_mv.append(mv[None, ...].cuda()) + all_mvp.append(mvp[None, ...].cuda()) + all_campos.append(campos[None, ...].cuda()) + + else: + z = radius * np.cos(azimuths) * np.sin(elevations) + x = radius * np.sin(azimuths) * np.sin(elevations) + y = radius * np.cos(elevations) + + eye = glm.vec3(x, y, z) + at = glm.vec3(0.0, 0.0, 0.0) + up = glm.vec3(0.0, 1.0, 0.0) + view_matrix = glm.lookAt(eye, at, up) + mv = torch.from_numpy(np.array(view_matrix)) + mvp = proj_mtx @ (mv) #w2c + campos = torch.linalg.inv(mv)[:3, 3] + all_mv.append(mv[None, ...].cuda()) + all_mvp.append(mvp[None, ...].cuda()) + all_campos.append(campos[None, ...].cuda()) + + # TODO, identity pose + identity_azimuths = np.array([0]) + identity_elevations = np.array([90]) + z = radius * np.cos(identity_azimuths) * np.sin(identity_elevations) + x = radius * np.sin(identity_azimuths) * np.sin(identity_elevations) + y = radius * np.cos(identity_elevations) + eye = glm.vec3(x, y, z) + at = glm.vec3(0.0, 0.0, 0.0) + up = glm.vec3(0.0, 1.0, 0.0) + view_matrix = glm.lookAt(eye, at, up) + identity_mv = torch.from_numpy(np.array(view_matrix)) + + all_mv = torch.stack(all_mv, dim=0).unsqueeze(0).squeeze(2) + all_mvp = torch.stack(all_mvp, dim=0).unsqueeze(0).squeeze(2) + all_campos = torch.stack(all_campos, dim=0).unsqueeze(0).squeeze(2) + return all_mv, all_mvp, all_campos, identity_mv + + +def worldNormal2camNormal(rot_w2c, normal_map_world): + H,W,_ = normal_map_world.shape + # normal_img = np.matmul(rot_w2c[None, :, :], worldNormal.reshape(-1,3)[:, :, None]).reshape([H, W, 3]) + normal_map_world = normal_map_world[...,:3] + # faster version + normal_map_flat = normal_map_world.view(-1, 3) + normal_map_camera_flat = torch.matmul(normal_map_flat.float(), rot_w2c.T.float()) + + # Reshape the transformed normal map back to its original shape + normal_map_camera = normal_map_camera_flat.view(normal_map_world.shape) + + return normal_map_camera + + +def trans_normal(normal, RT_w2c, RT_w2c_target): + + # normal_world = camNormal2worldNormal(np.linalg.inv(RT_w2c[:3,:3]), normal) + # normal_target_cam = worldNormal2camNormal(RT_w2c_target[:3,:3], normal_world) + + relative_RT = torch.matmul(RT_w2c_target[:3,:3], torch.linalg.inv(RT_w2c[:3,:3])) + normal_target_cam = worldNormal2camNormal(relative_RT[:3,:3], normal) + + return normal_target_cam + +def render_frames(model, planes, render_cameras, camera_pos, env, materials, render_size=512, chunk_size=1, + is_flexicubes=False, render_mv=None, local_normal=False, identity_mv=None): + """ + Render frames from triplanes. + """ + frames = [] + albedos = [] + pbr_spec_lights = [] + pbr_diffuse_lights = [] + normals = [] + alphas = [] + for i in tqdm(range(0, render_cameras.shape[1])): + out = model.forward_geometry( + planes, + render_cameras[:, i:i+chunk_size], + camera_pos[:, i:i+chunk_size], + [[env]*chunk_size], + [[materials]*chunk_size], + render_size=render_size, + ) + frame = out['pbr_img'] + albedo = out['albedo'] + pbr_spec_light = out['pbr_spec_light'] + pbr_diffuse_light = out['pbr_diffuse_light'] + normal = out['normal'] + alpha = out['mask'] + # breakpoint() + if local_normal: + # TODO global normal to local + target_w2c = render_mv[0,i,:3,:3] + identity_w2c = identity_mv[:3,:3] + # breakpoint() + # torchvision.utils.save_image((normal.permute(0,3,1,2)+1)/2, f"debug_output/global_normal.png") + normal = trans_normal(normal.squeeze(0), identity_w2c.cuda(), target_w2c.cuda()) + normal = normal / torch.norm(normal, dim=-1, keepdim=True) + # torchvision.utils.save_image((normal.permute(2,0,1)+1)/2, f"debug_output/local_normal.png") + background_normal = torch.tensor([1,1,1], dtype=torch.float32, device=normal.device) + normal = normal.unsqueeze(0) + normal[...,0] *= -1 + # breakpoint() + normal = normal * alpha.squeeze(0).permute(0,2,3,1) + background_normal * (1-alpha.squeeze(0).permute(0,2,3,1)) + frames.append(frame) + albedos.append(albedo) + pbr_spec_lights.append(pbr_spec_light) + pbr_diffuse_lights.append(pbr_diffuse_light) + normals.append(normal) + alphas.append(alpha) + + frames = torch.cat(frames, dim=1)[0] # we suppose batch size is always 1 + alphas = torch.cat(alphas, dim=1)[0] + albedos = torch.cat(albedos, dim=1)[0] + pbr_spec_lights = torch.cat(pbr_spec_lights, dim=1)[0] + pbr_diffuse_lights = torch.cat(pbr_diffuse_lights, dim=1)[0] + normals = torch.cat(normals, dim=0).permute(0,3,1,2)[:,:3] + return frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas + +# from https://github.com/cubiq/ComfyUI_essentials +def mask_fix(mask, erode_dilate=0, smooth=0, remove_isolated_pixels=0, blur=0, fill_holes=0): + masks = [] + for m in mask: + # erode and dilate + if erode_dilate != 0: + if erode_dilate < 0: + m = torch.from_numpy(scipy.ndimage.grey_erosion(m.cpu().numpy(), size=(-erode_dilate, -erode_dilate))) + else: + m = torch.from_numpy(scipy.ndimage.grey_dilation(m.cpu().numpy(), size=(erode_dilate, erode_dilate))) + + # fill holes + if fill_holes > 0: + #m = torch.from_numpy(scipy.ndimage.binary_fill_holes(m.cpu().numpy(), structure=np.ones((fill_holes,fill_holes)))).float() + m = torch.from_numpy(scipy.ndimage.grey_closing(m.cpu().numpy(), size=(fill_holes, fill_holes))) + + # remove isolated pixels + if remove_isolated_pixels > 0: + m = torch.from_numpy(scipy.ndimage.grey_opening(m.cpu().numpy(), size=(remove_isolated_pixels, remove_isolated_pixels))) + + # smooth the mask + if smooth > 0: + if smooth % 2 == 0: + smooth += 1 + m = T.functional.gaussian_blur((m > 0.5).unsqueeze(0), smooth).squeeze(0) + + # blur the mask + if blur > 0: + if blur % 2 == 0: + blur += 1 + m = T.functional.gaussian_blur(m.float().unsqueeze(0), blur).squeeze(0) + + masks.append(m.float()) + + masks = torch.stack(masks, dim=0).float() + + return masks + + +class NormalTransfer: + def __init__(self): + self.identity_w2c = torch.tensor([ + [0.0, 0.0, 1.0, 0.0], + [ 0.0, 1.0, 0.0, 0.0], + [-1.0, 0.0, 0.0, 4.5]]).float() + + def look_at(self,camera_position, target_position, up_vector=np.array([0, 0, 1])): + forward = camera_position - target_position + forward = forward / np.linalg.norm(forward) + + right = np.cross(up_vector, forward) + right = right / np.linalg.norm(right) + + up = np.cross(forward, right) + + rotation_matrix = np.array([right, up, forward]).T + + translation_matrix = np.eye(4) + translation_matrix[:3, 3] = -camera_position + + rotation_homogeneous = np.eye(4) + rotation_homogeneous[:3, :3] = rotation_matrix + + w2c = rotation_homogeneous @ translation_matrix + return w2c + + def generate_target_pose(self, azimuths_deg, elevations_deg, radius=4.5): + if isinstance(azimuths_deg, torch.Tensor): + azimuths_deg = azimuths_deg.cpu().numpy() + if isinstance(elevations_deg, torch.Tensor): + elevations_deg = elevations_deg.cpu().numpy() + azimuths = np.deg2rad(azimuths_deg) + elevations = np.deg2rad(elevations_deg) + + x = radius * np.cos(azimuths) * np.cos(elevations) + y = radius * np.sin(azimuths) * np.cos(elevations) + z = radius * np.sin(elevations) + camera_positions = np.stack([x, y, z], axis=-1) + + target_position = np.array([0, 0, 0]) # 目标点位置 + + # 为每个相机位置生成 w2c 矩阵 + w2c_matrices = [self.look_at(cam_pos, target_position) for cam_pos in camera_positions] + w2c_matrices = np.stack(w2c_matrices, axis=0) + return w2c_matrices + + def convert_to_blender(self, pose): + # Swap the y and z axes + w2c_opengl = pose + w2c_opengl[[1, 2], :] = w2c_opengl[[2, 1], :] + + # Invert the y axis + w2c_opengl[1] *= -1 + R = w2c_opengl[:3, :3] + t = w2c_opengl[:3, 3] + + cam_rec = np.asarray([[1, 0, 0], [0, -1, 0], [0, 0, -1]], np.float32) + R = R.T + t = -R @ t + R_world2cv = cam_rec @ R + t_world2cv = cam_rec @ t + + RT = np.concatenate([R_world2cv,t_world2cv[:,None]],1) + return RT + + def worldNormal2camNormal(self, rot_w2c, normal_map_world): + H,W,_ = normal_map_world.shape + # normal_img = np.matmul(rot_w2c[None, :, :], worldNormal.reshape(-1,3)[:, :, None]).reshape([H, W, 3]) + normal_map_world = normal_map_world[...,:3] + # faster version + normal_map_flat = normal_map_world.contiguous().view(-1, 3) + + normal_map_camera_flat = torch.matmul(normal_map_flat.float(), rot_w2c.T.float()) + + # Reshape the transformed normal map back to its original shape + normal_map_camera = normal_map_camera_flat.view(normal_map_world.shape) + + return normal_map_camera + + def trans_normal(self, normal, RT_w2c, RT_w2c_target): + """ + :param normal: (H,W,3), torch tensor, range [-1,1] + :param RT_w2c: (4,4), torch tensor, world to camera + :param RT_w2c_target: (4,4), torch tensor, world to camera + :return: normal_target_cam: (H,W,3), torch tensor, range [-1,1] + """ + relative_RT = torch.matmul(RT_w2c_target[:3,:3], torch.linalg.inv(RT_w2c[:3,:3])) + normal_target_cam = self.worldNormal2camNormal(relative_RT[:3,:3], normal) + + return normal_target_cam + + def trans_local_2_global(self, normal_local, azimuths_deg, elevations_deg, radius=4.5, for_lotus=True): + """ + :param normal_local: (B,H,W,3), torch tensor, range [-1,1] + :param azimuths_deg: (B,), numpy array, range [0,360] + :param elevations_deg: (B,), numpy array, range [-90,90] + :param radius: float, default 4.5 + :return: global_normal: (B,H,W,3), torch tensor, range [-1,1] + + """ + # print(f"normal_local.shape:{normal_local.shape}") + # print(f"azimuths_deg.shape:{azimuths_deg.shape}") + # print(f"elevations_deg.shape:{elevations_deg.shape}") + assert normal_local.shape[0] == azimuths_deg.shape[0] == elevations_deg.shape[0] + identity_w2c = self.identity_w2c + + # generate target pose + target_w2c = self.generate_target_pose(azimuths_deg, elevations_deg, radius) + target_w2c = torch.from_numpy(np.stack([self.convert_to_blender(w2c) for w2c in target_w2c])).float() + global_normal = [] + + # transform normal + for i in range(normal_local.shape[0]): + normal_local_i = normal_local[i] + normal_zero123 = self.trans_normal(normal_local_i, target_w2c[i], identity_w2c) + global_normal.append(normal_zero123) + + global_normal = torch.stack(global_normal, dim=0) + if for_lotus: + global_normal[...,0] *= -1 + global_normal = global_normal / torch.norm(global_normal, dim=-1, keepdim=True) + return global_normal + + def trans_global_2_local(self, normal_local, azimuths_deg, elevations_deg, radius=4.5): + """ + :param normal_global: (B,H,W,3), torch tensor, range [-1,1] + :param azimuths_deg: (B,), numpy array, range [0,360] + :param elevations_deg: (B,), numpy array, range [-90,90] + :param radius: float, default 4.5 + :return: local_normal: (B,H,W,3), torch tensor, range [-1,1] + + """ + print(f"normal_local.shape:{normal_local.shape}") + print(f"azimuths_deg.shape:{azimuths_deg.shape}") + print(f"elevations_deg.shape:{elevations_deg.shape}") + assert normal_local.shape[0] == azimuths_deg.shape[0] == elevations_deg.shape[0] + identity_w2c = self.identity_w2c + + # generate target pose + target_w2c = self.generate_target_pose(azimuths_deg, elevations_deg, radius) + target_w2c = torch.from_numpy(np.stack([w2c for w2c in target_w2c])).float() + local_normal = [] + + # transform normal + for i in range(normal_local.shape[0]): + normal_local_i = normal_local[i] + normal = self.trans_normal(normal_local_i, identity_w2c, target_w2c[i]) + local_normal.append(normal) + + local_normal = torch.stack(local_normal, dim=0) + # global_normal[...,0] *= -1 + local_normal = local_normal / torch.norm(local_normal, dim=-1, keepdim=True) + return local_normal \ No newline at end of file