Kiss3DGen / models /ISOMER /projection_func.py
JiantaoLin
initial commit
98bebfc
raw
history blame
2.5 kB
import os
import numpy as np
import torch
from PIL import Image
import os
from .scripts.proj_commands import projection as isomer_projection
from .data.utils import simple_remove_bkg_normal
# mesh_address,
def projection(
meshes,
masks,
images,
azimuths,
elevations,
weights,
fov,
radius,
save_dir,
save_glb_addr=None,
remove_background=False,
auto_center=False,
projection_type="perspective",
below_confidence_strategy="smooth",
complete_unseen=True,
mesh_scale_factor=1.0,
rm_bkg_with_rembg=True,
):
if save_glb_addr is None:
os.makedirs(save_dir, exist_ok=True)
save_glb_addr=os.path.join(save_dir, "rgb_projected.glb")
bs = len(images)
assert len(azimuths) == bs, f'len(azimuths) ({len(azimuths)} != batchsize ({bs}))'
assert len(elevations) == bs, f'len(elevations) ({len(elevations)} != batchsize ({bs}))'
assert len(weights) == bs, f'len(weights) ({len(weights)} != batchsize ({bs}))'
image_rgba = torch.cat([images[:,:,:,:3], masks.unsqueeze(-1)], dim=-1)
assert image_rgba.shape[-1] == 4, f'image_rgba.shape is {image_rgba.shape}'
img_list = [Image.fromarray((image.cpu()*255).numpy().astype(np.uint8)) for image in image_rgba]
if remove_background:
if rm_bkg_with_rembg:
os.environ["OMP_NUM_THREADS"] = '8'
img_list = simple_remove_bkg_normal(img_list, rm_bkg_with_rembg, return_Image=True)
resolution = img_list[0].size[0]
new_img_list = []
for i in range(len(img_list)):
new_img = img_list[i].resize((resolution,resolution))
path_dir = os.path.join(save_dir, f'projection_images')
os.makedirs(path_dir, exist_ok=True)
path_ = os.path.join(path_dir, f'ProjectionImg{i}.png')
new_img.save(path_)
new_img_list.append(new_img)
img_list = new_img_list
isomer_projection(meshes,
img_list=img_list,
weights=weights,
azimuths=azimuths,
elevations=elevations,
projection_type=projection_type,
auto_center=auto_center,
resolution=resolution,
fovy=fov,
radius=radius,
scale_factor=mesh_scale_factor,
save_glb_addr=save_glb_addr,
scale_verts=True,
complete_unseen=complete_unseen,
below_confidence_strategy=below_confidence_strategy
)
return save_glb_addr