|
import os |
|
import sys |
|
import logging |
|
|
|
# Repository root: the parent of the directory that contains this file.
# os.path.dirname is used instead of splitting on '/' so the result is correct
# for any path separator and never collapses to an empty string (an empty
# sys.path entry would silently mean "current working directory").
__workdir__ = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# Put the root first on sys.path so the project-local `models.*` and `utils.*`
# packages imported later in this file can be resolved.
sys.path.insert(0, __workdir__)

print(__workdir__)
|
|
|
import numpy as np |
|
import torch |
|
from torchvision.transforms import v2 |
|
from PIL import Image |
|
import rembg |
|
|
|
from models.lrm.online_render.render_single import load_mipmap |
|
from models.lrm.utils.camera_util import get_zero123plus_input_cameras, get_custom_zero123plus_input_cameras, get_flux_input_cameras |
|
from models.lrm.utils.render_utils import rotate_x, rotate_y |
|
from models.lrm.utils.mesh_util import save_obj, save_obj_with_mtl |
|
from models.lrm.utils.infer_util import remove_background, resize_foreground |
|
|
|
from models.ISOMER.reconstruction_func import reconstruction |
|
from models.ISOMER.projection_func import projection |
|
|
|
from utils.tool import NormalTransfer, get_render_cameras_frames, get_background, get_render_cameras_video, render_frames, mask_fix |
|
|
|
# Configure root logging once, at import time, at INFO verbosity.
logging.basicConfig(level=logging.INFO)

# Logger used by every function in this module.
logger = logging.getLogger('kiss3d_wrapper')

# Output locations, relative to the current working directory.
OUT_DIR = './outputs'
TMP_DIR = './outputs/tmp'

# Make sure the scratch directory exists before anything writes to it;
# makedirs also creates the missing parent directory.
os.makedirs(TMP_DIR, exist_ok=True)
|
|
|
@torch.no_grad()
def lrm_reconstruct(model, infer_config, images,
                    name='', export_texmap=False,
                    input_camera_type='zero123',
                    render_3d_bundle_image=True,
                    render_azimuths=[270, 0, 90, 180],
                    render_elevations=[5, 5, 5, 5],
                    render_radius=4.15):
    """Reconstruct a mesh with the LRM model and optionally render a view bundle.

    The extracted mesh is written to
    ``TMP_DIR/{name}_recon_from_{input_camera_type}.obj`` (with a material /
    texture map when ``export_texmap`` is true, with vertex colors otherwise).

    Args:
        model: Object exposing ``forward_planes(images, cameras)`` and
            ``extract_mesh(planes, use_texture_map=..., **infer_config)``.
        infer_config: Mapping-like config. It is unpacked as keyword arguments
            into ``model.extract_mesh`` and its ``render_resolution``
            attribute is read when rendering.
        images: Input image tensor; the original note documents the shape as
            (1, c, h, w). It is resized to 512 and clamped to [0, 1].
        name: Prefix of the saved mesh file name.
        export_texmap: Export a textured mesh instead of a vertex-colored one.
        input_camera_type: ``'zero123'`` or ``'kiss3d'``; selects the input
            camera layout.
        render_3d_bundle_image: When true, render frames / albedos / normals
            from the given azimuths and elevations.
        render_azimuths: Sequence of azimuths for the rendered views. The
            list default is never mutated (it is only converted to a tensor).
        render_elevations: Sequence of elevations for the rendered views.
            Same remark about the list default.
        render_radius: Camera radius for the rendered views.

    Returns:
        Tuple ``(vertices, faces, normals, frames, albedos)``. ``vertices``
        (rotated about x then y by pi/2) and ``faces`` are CPU tensors;
        the last three are ``None`` when ``render_3d_bundle_image`` is false.

    Raises:
        NotImplementedError: If ``input_camera_type`` is not supported.
        ValueError: If rendering is requested but an azimuth, elevation or
            radius argument is ``None``.
    """
    mesh_path_idx = os.path.join(TMP_DIR, f'{name}_recon_from_{input_camera_type}.obj')

    device = images.device
    if input_camera_type == 'zero123':
        input_cameras = get_custom_zero123plus_input_cameras(batch_size=1, radius=3.5, fov=30).to(device)
    elif input_camera_type == 'kiss3d':
        input_cameras = get_flux_input_cameras(batch_size=1, radius=3.5, fov=30).to(device)
    else:
        raise NotImplementedError(f'Unexpected input camera type: {input_camera_type}')

    # interpolation=3 is the PIL integer code for bicubic resampling.
    images = v2.functional.resize(images, 512, interpolation=3, antialias=True).clamp(0, 1)

    logger.info("==> Runing LRM reconstruction ...")
    planes = model.forward_planes(images, input_cameras)
    mesh_out = model.extract_mesh(
        planes,
        use_texture_map=export_texmap,
        **infer_config,
    )
    if export_texmap:
        # In this mode the mesh components are tensors (hence the explicit
        # conversion to numpy for the OBJ writer).
        vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
        save_obj_with_mtl(
            vertices.data.cpu().numpy(),
            uvs.data.cpu().numpy(),
            faces.data.cpu().numpy(),
            mesh_tex_idx.data.cpu().numpy(),
            tex_map.permute(1, 2, 0).data.cpu().numpy(),
            mesh_path_idx,
        )
    else:
        vertices, faces, vertex_colors = mesh_out
        save_obj(vertices, faces, vertex_colors, mesh_path_idx)
    logger.info(f"Mesh saved to {mesh_path_idx}")

    if render_3d_bundle_image:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if render_azimuths is None or render_elevations is None or render_radius is None:
            raise ValueError(
                'render_azimuths, render_elevations and render_radius must not be None '
                'when render_3d_bundle_image is enabled'
            )
        render_azimuths = torch.Tensor(render_azimuths).to(device)
        render_elevations = torch.Tensor(render_elevations).to(device)

        render_size = infer_config.render_resolution
        ENV = load_mipmap("models/lrm/env_mipmap/6")
        materials = (0.0, 0.9)
        all_mv, all_mvp, all_campos, identity_mv = get_render_cameras_frames(
            batch_size=1,
            radius=render_radius,
            azimuths=render_azimuths,
            elevations=render_elevations,
            fov=30
        )
        frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames(
            model,
            planes,
            render_cameras=all_mvp,
            camera_pos=all_campos,
            env=ENV,
            materials=materials,
            render_size=render_size,
            render_mv=all_mv,
            local_normal=True,
            identity_mv=identity_mv,
        )
    else:
        normals = None
        frames = None
        albedos = None

    # extract_mesh yields numpy arrays without a texture map but tensors with
    # one; torch.as_tensor accepts both, whereas torch.from_numpy raises
    # TypeError on a tensor and broke the export_texmap=True path.
    vertices = torch.as_tensor(vertices, device=device)
    faces = torch.as_tensor(faces, device=device)
    vertices = vertices @ rotate_x(np.pi / 2, device=device)[:3, :3]
    vertices = vertices @ rotate_y(np.pi / 2, device=device)[:3, :3]

    return vertices.cpu(), faces.cpu(), normals, frames, albedos
|
|
|
# Single NormalTransfer instance created at import time and reused by
# local_normal_global_transform.
normal_transfer = NormalTransfer()


def local_normal_global_transform(local_normal_images, azimuths_deg, elevations_deg):
    """Convert camera-local normal images into global-frame normal images.

    Args:
        local_normal_images: Tensor of local normals with the component axis
            last. If its minimum is non-negative it is treated as encoded
            in [0, 1] and decoded to [-1, 1]; otherwise it is used as is.
        azimuths_deg: Per-view azimuths forwarded to
            ``NormalTransfer.trans_local_2_global``.
        elevations_deg: Per-view elevations forwarded likewise.

    Returns:
        Tensor of global normals re-encoded as ``(n + 1) / 2`` with the first
        component negated, permuted with ``(0, 3, 1, 2)``.
    """
    decoded = local_normal_images.float()
    if local_normal_images.min() >= 0:
        decoded = decoded * 2 - 1

    world = normal_transfer.trans_local_2_global(
        decoded,
        azimuths_deg,
        elevations_deg,
        radius=4.5,
        for_lotus=False,
    )

    # Negate the first normal component in place.
    world[..., 0].mul_(-1)

    encoded = (world + 1) / 2
    return encoded.permute(0, 3, 1, 2)
|
|
|
|
|
def isomer_reconstruct(
        rgb_multi_view,
        normal_multi_view,
        multi_view_mask,
        vertices,
        faces,
        save_path=None,
        azimuths=[0, 90, 180, 270],
        elevations=[5, 5, 5, 5],
        geo_weights=[1, 0.9, 1, 0.9],
        color_weights=[1, 0.5, 1, 0.5],
        reconstruction_stage1_steps=10,
        reconstruction_stage2_steps=50,
        radius=4.5):
    """Run ISOMER geometry reconstruction followed by color projection.

    Args:
        rgb_multi_view: Multi-view color tensor; its device is used for all
            computation and it is permuted with ``(0, 2, 3, 1)`` before
            projection.
        normal_multi_view: Multi-view local normal tensor, permuted with
            ``(0, 2, 3, 1)`` and converted to global normals.
        multi_view_mask: Per-view mask that broadcasts against the normals;
            its axis 1 is squeezed before being handed to ISOMER.
        vertices: Initial mesh vertices (``init_verts``).
        faces: Initial mesh faces (``init_faces``).
        save_path: Forwarded to ``projection`` as ``save_glb_addr``.
        azimuths: Per-view camera azimuths.
        elevations: Per-view camera elevations.
        geo_weights: Per-view weights for the geometry stage.
        color_weights: Per-view weights for the color projection stage.
        reconstruction_stage1_steps: Step count of reconstruction stage 1.
        reconstruction_stage2_steps: Step count of reconstruction stage 2.
        radius: Camera radius used by both stages.

    Returns:
        The value returned by ``projection`` (logged as the saved mesh
        address).
    """
    device = rgb_multi_view.device

    def as_float_tensor(values):
        # A fresh float tensor on `device` is built on every call, so the
        # callees below never share a tensor object.
        return torch.Tensor(values).float().to(device)

    normals_component_last = normal_multi_view.permute(0, 2, 3, 1).cpu()
    world_normals = local_normal_global_transform(
        normals_component_last,
        as_float_tensor(azimuths),
        as_float_tensor(elevations),
    ).to(device)
    # Blend towards 1 wherever the mask is 0.
    world_normals = world_normals * multi_view_mask + (1 - multi_view_mask)

    world_normals = world_normals.permute(0, 2, 3, 1)
    squeezed_mask = multi_view_mask.squeeze(1)
    rgb_component_last = rgb_multi_view.permute(0, 2, 3, 1)

    logger.info("==> Runing ISOMER reconstruction ...")
    meshes = reconstruction(
        normal_pils=world_normals,
        masks=squeezed_mask,
        weights=as_float_tensor(geo_weights),
        fov=30,
        radius=radius,
        camera_angles_azi=as_float_tensor(azimuths),
        camera_angles_ele=as_float_tensor(elevations),
        expansion_weight_stage1=0.1,
        init_type="file",
        init_verts=vertices,
        init_faces=faces,
        stage1_steps=reconstruction_stage1_steps,
        stage2_steps=reconstruction_stage2_steps,
        start_edge_len_stage1=0.1,
        end_edge_len_stage1=0.02,
        start_edge_len_stage2=0.02,
        end_edge_len_stage2=0.005,
    )

    # The projection stage uses a post-processed copy of the mask.
    projection_mask = mask_fix(squeezed_mask, erode_dilate=-10, blur=5)

    logger.info("==> Runing ISOMER projection ...")
    save_glb_addr = projection(
        meshes,
        masks=projection_mask.to(device),
        images=rgb_component_last.to(device),
        azimuths=as_float_tensor(azimuths),
        elevations=as_float_tensor(elevations),
        weights=as_float_tensor(color_weights),
        fov=30,
        radius=radius,
        save_dir=TMP_DIR,
        save_glb_addr=save_path,
    )

    logger.info(f"==> Save mesh to {save_glb_addr} ...")
    return save_glb_addr
|
|
|
|
|
def to_rgb_image(maybe_rgba):
    """Return an RGB version of a PIL image together with its alpha channel.

    Args:
        maybe_rgba: ``PIL.Image.Image`` in mode ``'RGB'`` or ``'RGBA'``.

    Returns:
        Tuple ``(image, alpha)``. For an RGB input this is the input itself
        and ``None``. For an RGBA input it is a new RGB image in which the
        input is pasted over a uniform grey (127, 127, 127) canvas using its
        alpha channel as the mask, plus that alpha channel as an image.

    Raises:
        TypeError: If ``maybe_rgba`` is not a PIL image.
        ValueError: If the image mode is neither ``'RGB'`` nor ``'RGBA'``.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(maybe_rgba, Image.Image):
        raise TypeError(f"Expected a PIL.Image.Image, got {type(maybe_rgba).__name__}")

    if maybe_rgba.mode == 'RGB':
        return maybe_rgba, None

    if maybe_rgba.mode == 'RGBA':
        rgba = maybe_rgba
        alpha = rgba.getchannel('A')
        # Constant grey canvas; the previous randint(127, 128) call could only
        # ever produce 127, so build the canvas directly.
        img = Image.new('RGB', rgba.size, (127, 127, 127))
        img.paste(rgba, mask=alpha)
        return img, alpha

    raise ValueError("Unsupported image type.", maybe_rgba.mode)
|
|
|
# Background-removal session for the "u2net" model, created once at import
# time and reused by preprocess_input_image.
rembg_session = rembg.new_session("u2net")


def preprocess_input_image(input_image):
    """Strip the background of an input image and re-frame its foreground.

    Args:
        input_image: PIL.Image in a mode accepted by ``to_rgb_image``.

    Returns:
        PIL.Image in RGB mode. The original note describes it as
        (3, 512, 512) with a white background; the size is determined by
        ``remove_background`` / ``resize_foreground`` and is not enforced
        here.
    """
    rgb_input, _ = to_rgb_image(input_image)
    foreground = remove_background(rgb_input, rembg_session, bgcolor=(255, 255, 255, 255))
    framed = resize_foreground(foreground, ratio=0.85, pad_value=255)
    rgb_output, _ = to_rgb_image(framed)
    return rgb_output
|
|