# Kiss3DGen/pipeline/utils.py
import os
import sys
import logging
# Make the repository root importable when this module is run directly.
__workdir__ = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, __workdir__)
import numpy as np
import torch
from torchvision.transforms import v2
from PIL import Image
import rembg
from models.lrm.online_render.render_single import load_mipmap
from models.lrm.utils.camera_util import get_zero123plus_input_cameras, get_custom_zero123plus_input_cameras, get_flux_input_cameras
from models.lrm.utils.render_utils import rotate_x, rotate_y
from models.lrm.utils.mesh_util import save_obj, save_obj_with_mtl
from models.lrm.utils.infer_util import remove_background, resize_foreground
from models.ISOMER.reconstruction_func import reconstruction
from models.ISOMER.projection_func import projection
from utils.tool import NormalTransfer, get_render_cameras_frames, get_background, get_render_cameras_video, render_frames, mask_fix
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('kiss3d_wrapper')
OUT_DIR = './outputs'
TMP_DIR = './outputs/tmp'
os.makedirs(TMP_DIR, exist_ok=True)
@torch.no_grad()
def lrm_reconstruct(model, infer_config, images,
                    name='', export_texmap=False,
                    input_camera_type='zero123',
                    render_3d_bundle_image=True,
                    render_azimuths=[270, 0, 90, 180],
                    render_elevations=[5, 5, 5, 5],
                    render_radius=4.15):
    """
    Reconstruct a mesh from multi-view images with the LRM model.

    images: Tensor, shape (1, c, h, w), values in [0, 1]
    """
    mesh_path_idx = os.path.join(TMP_DIR, f'{name}_recon_from_{input_camera_type}.obj')

    device = images.device
    if input_camera_type == 'zero123':
        input_cameras = get_custom_zero123plus_input_cameras(batch_size=1, radius=3.5, fov=30).to(device)
    elif input_camera_type == 'kiss3d':
        input_cameras = get_flux_input_cameras(batch_size=1, radius=3.5, fov=30).to(device)
    else:
        raise NotImplementedError(f'Unexpected input camera type: {input_camera_type}')

    # Resize to the LRM input resolution (interpolation=3 is bicubic).
    images = v2.functional.resize(images, 512, interpolation=3, antialias=True).clamp(0, 1)

    logger.info("==> Running LRM reconstruction ...")
    planes = model.forward_planes(images, input_cameras)

    mesh_out = model.extract_mesh(
        planes,
        use_texture_map=export_texmap,
        **infer_config,
    )
    if export_texmap:
        vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
        save_obj_with_mtl(
            vertices.data.cpu().numpy(),
            uvs.data.cpu().numpy(),
            faces.data.cpu().numpy(),
            mesh_tex_idx.data.cpu().numpy(),
            tex_map.permute(1, 2, 0).data.cpu().numpy(),
            mesh_path_idx,
        )
    else:
        vertices, faces, vertex_colors = mesh_out
        save_obj(vertices, faces, vertex_colors, mesh_path_idx)
    logger.info(f"Mesh saved to {mesh_path_idx}")

    if render_3d_bundle_image:
        assert render_azimuths is not None and render_elevations is not None and render_radius is not None
        render_azimuths = torch.Tensor(render_azimuths).to(device)
        render_elevations = torch.Tensor(render_elevations).to(device)
        render_size = infer_config.render_resolution
        ENV = load_mipmap("models/lrm/env_mipmap/6")
        materials = (0.0, 0.9)
        all_mv, all_mvp, all_campos, identity_mv = get_render_cameras_frames(
            batch_size=1,
            radius=render_radius,
            azimuths=render_azimuths,
            elevations=render_elevations,
            fov=30
        )
        frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames(
            model,
            planes,
            render_cameras=all_mvp,
            camera_pos=all_campos,
            env=ENV,
            materials=materials,
            render_size=render_size,
            render_mv=all_mv,
            local_normal=True,
            identity_mv=identity_mv,
        )
    else:
        normals = None
        frames = None
        albedos = None

    # extract_mesh may return tensors (texture-map path) or numpy arrays;
    # normalize to tensors before applying the coordinate fix below.
    if isinstance(vertices, np.ndarray):
        vertices = torch.from_numpy(vertices)
    if isinstance(faces, np.ndarray):
        faces = torch.from_numpy(faces)
    vertices = vertices.to(device)
    faces = faces.to(device)

    # Rotate the mesh into the renderer's coordinate convention.
    vertices = vertices @ rotate_x(np.pi / 2, device=device)[:3, :3]
    vertices = vertices @ rotate_y(np.pi / 2, device=device)[:3, :3]

    return vertices.cpu(), faces.cpu(), normals, frames, albedos
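# Usage sketch (hedged; `model` and `infer_config` are assumed to be the LRM
# model and inference config loaded elsewhere in the pipeline, and `mv_images`
# a multi-view image tensor matching the docstring above):
#
#   vertices, faces, normals, frames, albedos = lrm_reconstruct(
#       model, infer_config, mv_images.to('cuda'),
#       name='demo', input_camera_type='kiss3d',
#       render_3d_bundle_image=False)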
normal_transfer = NormalTransfer()
def local_normal_global_transform(local_normal_images, azimuths_deg, elevations_deg):
    """Convert per-view (camera-space) normal maps to world-space normal maps."""
    # Map normals from image range [0, 1] to vector range [-1, 1] if needed.
    if local_normal_images.min() >= 0:
        local_normal = local_normal_images.float() * 2 - 1
    else:
        local_normal = local_normal_images.float()
    global_normal = normal_transfer.trans_local_2_global(local_normal, azimuths_deg, elevations_deg, radius=4.5, for_lotus=False)
    # Flip the x axis, then map back to [0, 1] and return channel-first maps.
    global_normal[..., 0] *= -1
    global_normal = (global_normal + 1) / 2
    global_normal = global_normal.permute(0, 3, 1, 2)
    return global_normal
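# Hedged example: four per-view normal maps in [0, 1] with shape (V, H, W, 3),
# matching the permute(0, 2, 3, 1) call in isomer_reconstruct below; the exact
# layout expected by NormalTransfer is an assumption here.
#
#   local_normals = torch.rand(4, 512, 512, 3)
#   azi = torch.tensor([0., 90., 180., 270.])
#   ele = torch.tensor([5., 5., 5., 5.])
#   world_normals = local_normal_global_transform(local_normals, azi, ele)
#   # -> (4, 3, 512, 512), values in [0, 1]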
def isomer_reconstruct(
        rgb_multi_view,
        normal_multi_view,
        multi_view_mask,
        vertices,
        faces,
        save_path=None,
        azimuths=[0, 90, 180, 270],
        elevations=[5, 5, 5, 5],
        geo_weights=[1, 0.9, 1, 0.9],
        color_weights=[1, 0.5, 1, 0.5],
        reconstruction_stage1_steps=10,
        reconstruction_stage2_steps=50,
        radius=4.5):
    device = rgb_multi_view.device
    to_tensor_ = lambda x: torch.Tensor(x).float().to(device)

    # Convert local (camera-space) normals to global (world-space) normals,
    # then composite onto a white background outside the foreground mask.
    global_normal = local_normal_global_transform(normal_multi_view.permute(0, 2, 3, 1).cpu(), to_tensor_(azimuths), to_tensor_(elevations)).to(device)
    global_normal = global_normal * multi_view_mask + (1 - multi_view_mask)

    global_normal = global_normal.permute(0, 2, 3, 1)
    multi_view_mask = multi_view_mask.squeeze(1)
    rgb_multi_view = rgb_multi_view.permute(0, 2, 3, 1)

    logger.info("==> Running ISOMER reconstruction ...")
    meshes = reconstruction(
        normal_pils=global_normal,
        masks=multi_view_mask,
        weights=to_tensor_(geo_weights),
        fov=30,
        radius=radius,
        camera_angles_azi=to_tensor_(azimuths),
        camera_angles_ele=to_tensor_(elevations),
        expansion_weight_stage1=0.1,
        init_type="file",
        init_verts=vertices,
        init_faces=faces,
        stage1_steps=reconstruction_stage1_steps,
        stage2_steps=reconstruction_stage2_steps,
        start_edge_len_stage1=0.1,
        end_edge_len_stage1=0.02,
        start_edge_len_stage2=0.02,
        end_edge_len_stage2=0.005,
    )

    # Erode and blur the mask slightly so background pixels are not projected
    # onto the mesh near silhouette edges.
    multi_view_mask_proj = mask_fix(multi_view_mask, erode_dilate=-10, blur=5)

    logger.info("==> Running ISOMER projection ...")
    save_glb_addr = projection(
        meshes,
        masks=multi_view_mask_proj.to(device),
        images=rgb_multi_view.to(device),
        azimuths=to_tensor_(azimuths),
        elevations=to_tensor_(elevations),
        weights=to_tensor_(color_weights),
        fov=30,
        radius=radius,
        save_dir=TMP_DIR,
        save_glb_addr=save_path
    )

    logger.info(f"==> Saved mesh to {save_glb_addr} ...")
    return save_glb_addr
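# Usage sketch (hedged; tensors are assumed channel-first multi-view batches,
# and `vertices`/`faces` come from lrm_reconstruct above):
#
#   glb_path = isomer_reconstruct(
#       rgb_multi_view=rgbs,            # (V, 3, H, W)
#       normal_multi_view=normals,      # (V, 3, H, W)
#       multi_view_mask=masks,          # (V, 1, H, W)
#       vertices=vertices, faces=faces,
#       save_path=os.path.join(OUT_DIR, 'demo.glb'))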
def to_rgb_image(maybe_rgba):
    """Flatten an RGBA image onto a neutral gray background; pass RGB through."""
    assert isinstance(maybe_rgba, Image.Image)
    if maybe_rgba.mode == 'RGB':
        return maybe_rgba, None
    elif maybe_rgba.mode == 'RGBA':
        rgba = maybe_rgba
        # Constant gray (value 127) background; randint's upper bound is exclusive.
        img = np.random.randint(127, 128, size=[rgba.size[1], rgba.size[0], 3], dtype=np.uint8)
        img = Image.fromarray(img, 'RGB')
        img.paste(rgba, mask=rgba.getchannel('A'))
        return img, rgba.getchannel('A')
    else:
        raise ValueError("Unsupported image type.", maybe_rgba.mode)
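# Example: an RGBA input is flattened onto gray and its alpha channel returned
# ('input.png' is a placeholder path):
#
#   rgb, alpha = to_rgb_image(Image.open('input.png').convert('RGBA'))
#   # rgb.mode == 'RGB'; alpha is the original 'A' channel (None for RGB input)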
# Shared rembg session, created once so the segmentation model is not reloaded per call.
rembg_session = rembg.new_session("u2net")

def preprocess_input_image(input_image):
    """
    input_image: PIL.Image
    output_image: PIL.Image, (3, 512, 512), mode = RGB, background = white
    """
    image = remove_background(to_rgb_image(input_image)[0], rembg_session, bgcolor=(255, 255, 255, 255))
    image = resize_foreground(image, ratio=0.85, pad_value=255)
    return to_rgb_image(image)[0]
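# Usage sketch (hedged; 'input.png' is a placeholder path):
#
#   img = preprocess_input_image(Image.open('input.png'))
#   img.save(os.path.join(TMP_DIR, 'preprocessed.png'))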