# Author: JiantaoLin
# Commit: 98bebfc ("initial commit")
import torch
import numpy as np
from PIL import Image
import os
from pytorch3d.io import load_obj
import trimesh
from pytorch3d.structures import Meshes
# from rembg import remove
def remove_color(arr, threshold=80):
    """Make an image's background transparent, keyed on its top-left pixel colour.

    The colour of pixel (0, 0) is taken as the background colour. Every pixel
    whose summed absolute channel difference from it is at most ``threshold``
    is treated as background: its RGB is set to 255 and its alpha to 0. All
    other pixels keep their RGB and get alpha 255.

    Args:
        arr: H x W x C image — a torch tensor, a numpy array, or anything
            ``np.asarray`` accepts (e.g. a PIL image). If C == 4 the existing
            alpha channel is discarded before processing.
        threshold: maximum summed channel difference for a pixel to count as
            background. Defaults to 80, the previously hard-coded value.

    Returns:
        A new torch.int32 tensor of shape H x W x 4 (RGB + alpha in {0, 255}).
        The input is never modified.
    """
    if isinstance(arr, torch.Tensor):
        # Always work on an int32 copy: subtracting in uint8 would wrap around
        # (e.g. 5 - 10 -> 251), and the background write below must not
        # mutate the caller's tensor.
        rgb = arr.to(torch.int32, copy=True)
    else:
        rgb = torch.tensor(np.asarray(arr), dtype=torch.int32)
    if rgb.shape[-1] == 4:
        rgb = rgb[..., :3]

    base = rgb[0, 0]
    diffs = (rgb - base).abs().sum(dim=-1)
    background = diffs <= threshold
    rgb[background] = 255

    alpha = (~background).to(torch.int32).unsqueeze(-1) * 255
    return torch.cat([rgb, alpha], dim=-1)
def simple_remove_bkg_normal(imgs, rm_bkg_with_rembg, return_Image=False):
    """Remove the background from a sequence of images. Only works for normal.

    Args:
        imgs: iterable of images. On the rembg path a torch tensor item is
            cast to uint8, moved to CPU and wrapped in a PIL image; any other
            item is handed to ``rembg.remove`` as-is. On the other path each
            item is passed to ``remove_color``.
        rm_bkg_with_rembg: if True, segment with rembg; otherwise use
            ``remove_color``.
        return_Image: if True, return PIL images; otherwise uint8 tensors.

    Returns:
        A list with one background-removed result per input image.
    """
    rets = []
    session = None
    for img in imgs:
        if rm_bkg_with_rembg:
            # Imported lazily so rembg is only required when this path is used.
            from rembg import new_session, remove
            if session is None:
                # rembg.remove() builds a fresh model session on every call
                # when none is given; build the default one once per batch.
                session = new_session()
            if isinstance(img, torch.Tensor):
                image = Image.fromarray(img.to(torch.uint8).detach().cpu().numpy())
            else:
                image = img
            arr = torch.tensor(np.array(remove(image, session=session)), dtype=torch.uint8)
        else:
            arr = remove_color(img)
        arr = arr.to(torch.uint8)
        if return_Image:
            rets.append(Image.fromarray(arr.detach().cpu().numpy()))
        else:
            rets.append(arr)
    return rets
def load_glb(file_path):
    """Load a mesh file (e.g. .glb) into a single-mesh PyTorch3D ``Meshes``.

    If trimesh returns a ``Scene``, its geometry is merged into one mesh with
    ``Scene.dump(concatenate=True)``. Only vertex positions and faces are
    kept; no material or texture data is transferred to the result.
    """
    loaded = trimesh.load(file_path)
    if isinstance(loaded, trimesh.Scene):
        # Flatten the scene into a single mesh object.
        loaded = loaded.dump(concatenate=True)

    vertex_tensor = torch.tensor(loaded.vertices, dtype=torch.float32)
    face_tensor = torch.tensor(loaded.faces, dtype=torch.int64)
    return Meshes(verts=[vertex_tensor], faces=[face_tensor])
def load_obj_with_verts_faces(file_path, return_mesh=True):
    """Load the geometry of a Wavefront .obj file with PyTorch3D.

    Args:
        file_path: path to the .obj file.
        return_mesh: if True (default), return a single-mesh ``Meshes``;
            otherwise return the ``(verts, faces)`` tensors.

    Returns:
        ``Meshes`` or ``(verts, faces)``, with verts as float32 and faces
        (vertex indices) as int64.
    """
    # Materials/textures are discarded below, so skip reading texture images.
    verts, faces, _ = load_obj(file_path, load_textures=False)
    # load_obj already returns tensors; .to() converts the dtype without the
    # copy-construct warning that torch.tensor(<tensor>) emits.
    verts = verts.to(torch.float32)
    faces = faces.verts_idx.to(torch.int64)
    if return_mesh:
        return Meshes(verts=[verts], faces=[faces])
    return verts, faces
def normalize_mesh(vertices):
    """Center vertices on their bounding box and scale them into [-1, 1].

    The axis-aligned bounding-box center (computed along dim 0) is moved to
    the origin. All coordinates are then scaled uniformly so that the longest
    bounding-box side has length 2.

    Args:
        vertices: tensor of vertex positions, one vertex per row along dim 0
            (presumably (N, 3)); must be non-empty.

    Returns:
        A new tensor of the same shape; the input is not modified. If every
        vertex coincides (zero extent), the centered vertices are returned
        unscaled instead of dividing by zero.
    """
    min_vals = vertices.amin(dim=0)
    max_vals = vertices.amax(dim=0)
    centered = vertices - (max_vals + min_vals) / 2
    max_extent = (max_vals - min_vals).max()
    # A single vertex (or all-coincident vertices) has zero extent; 2.0 / 0
    # would turn every coordinate into NaN, so leave such input unscaled.
    if max_extent <= 0:
        return centered
    return centered * (2.0 / max_extent)