|
import torch |
|
import numpy as np |
|
from PIL import Image |
|
import os |
|
from pytorch3d.io import load_obj |
|
import trimesh |
|
from pytorch3d.structures import Meshes |
|
|
|
|
|
def remove_color(arr, threshold=80):
    """Whiten the background of an image and append an alpha channel.

    The pixel at ``arr[0, 0]`` is used as the background reference colour.
    Any pixel whose summed absolute per-channel difference from that
    reference is at most ``threshold`` is treated as background: its colour
    channels are set to 255 and its alpha to 0. Every other pixel keeps its
    colour and receives alpha 255.

    Args:
        arr: Image as a ``torch.Tensor`` or any array-like accepted by
            ``torch.tensor`` (e.g. a NumPy array). Assumed to be laid out as
            ``(H, W, C)``; when ``C == 4`` the fourth channel is discarded
            before processing.
        threshold: Maximum summed channel distance from the reference colour
            for a pixel to count as background. Defaults to 80.

    Returns:
        A new tensor with one extra trailing channel (the alpha mask). The
        input is never modified. Integer inputs narrower than 32 bits are
        returned as ``int32``; wider or floating dtypes are kept.
    """
    if arr.shape[-1] == 4:
        arr = arr[..., :3]

    if isinstance(arr, torch.Tensor):
        # Widen narrow integer types (notably uint8) so ``pixels - reference``
        # cannot wrap around, and copy so the caller's tensor is not mutated.
        work_dtype = torch.promote_types(arr.dtype, torch.int32)
        pixels = arr.to(work_dtype, copy=True)
    else:
        pixels = torch.tensor(arr, dtype=torch.int32)

    reference = pixels[0, 0]
    distance = torch.abs(pixels - reference).sum(dim=-1)
    is_background = distance <= threshold

    # Paint the background white, then build alpha from the foreground mask.
    pixels[is_background] = 255
    alpha = (~is_background).unsqueeze(-1).to(pixels.dtype) * 255

    return torch.cat([pixels, alpha], dim=-1)
|
|
|
def simple_remove_bkg_normal(imgs, rm_bkg_with_rembg, return_Image=False):
    """Strip the background from every image in ``imgs``.

    The original note says this "only works for normal" (presumably
    normal-map renders; confirm with the caller).

    Args:
        imgs: Iterable of images. Tensors are converted to ``uint8`` PIL
            images before being handed to rembg; other values are passed to
            rembg as they are.
        rm_bkg_with_rembg: When true, use ``rembg.remove``; otherwise use the
            colour-distance heuristic in ``remove_color``.
        return_Image: When true, return PIL images instead of tensors.

    Returns:
        A list with one ``uint8`` result per input image, either tensors or
        PIL images depending on ``return_Image``.
    """
    results = []
    for img in imgs:
        if not rm_bkg_with_rembg:
            rgba = remove_color(img)
        else:
            # Imported lazily so rembg is only needed when actually requested.
            from rembg import remove

            if isinstance(img, torch.Tensor):
                source = Image.fromarray(img.to(torch.uint8).detach().cpu().numpy())
            else:
                source = img
            cutout = remove(source)
            rgba = torch.tensor(np.array(cutout), dtype=torch.uint8)

        if return_Image:
            results.append(Image.fromarray(rgba.to(torch.uint8).detach().cpu().numpy()))
        else:
            results.append(rgba.to(torch.uint8))

    return results
|
|
|
|
|
def load_glb(file_path):
    """Load a mesh file with trimesh and wrap its geometry in a ``Meshes``.

    If trimesh returns a ``Scene``, its geometries are first concatenated
    into a single mesh. Only vertices and faces are carried over; no texture
    or material data is attached to the result.

    Args:
        file_path: Path handed to ``trimesh.load``.

    Returns:
        A PyTorch3D ``Meshes`` holding one mesh with ``float32`` vertices and
        ``int64`` faces.
    """
    loaded = trimesh.load(file_path)

    if isinstance(loaded, trimesh.Scene):
        geometry = loaded.dump(concatenate=True)
    else:
        geometry = loaded

    vertex_tensor = torch.tensor(geometry.vertices, dtype=torch.float32)
    face_tensor = torch.tensor(geometry.faces, dtype=torch.int64)

    return Meshes(verts=[vertex_tensor], faces=[face_tensor])
|
|
|
def load_obj_with_verts_faces(file_path, return_mesh=True):
    """Load the vertices and faces of an OBJ file.

    Args:
        file_path: Path handed to PyTorch3D's ``load_obj``.
        return_mesh: When true (default) return a ``Meshes`` object;
            otherwise return the raw ``(verts, faces)`` tensors.

    Returns:
        Either a PyTorch3D ``Meshes`` holding the single loaded mesh, or a
        ``(verts, faces)`` tuple of a ``float32`` vertex tensor and an
        ``int64`` face-index tensor.
    """
    verts, faces, _ = load_obj(file_path)

    # ``load_obj`` already yields tensors; re-wrapping them with
    # ``torch.tensor(...)`` is the discouraged copy-construct form and emits a
    # UserWarning. Detach and cast instead (the tensors are freshly loaded,
    # so sharing storage with them is harmless).
    verts = verts.detach().to(dtype=torch.float32)
    faces = faces.verts_idx.detach().to(dtype=torch.int64)

    if return_mesh:
        return Meshes(verts=[verts], faces=[faces])
    return verts, faces
|
|
|
def normalize_mesh(vertices):
    """Centre vertices on the origin and scale the longest bbox side to 2.

    The axis-aligned bounding box is computed along dimension 0, its centre
    is moved to the origin, and all axes are scaled uniformly so the largest
    extent spans ``[-1, 1]``.

    Args:
        vertices: Tensor whose first dimension indexes vertices (assumed
            ``(V, 3)``; confirm against callers).

    Returns:
        A new tensor of normalised vertices. Empty input is returned
        unchanged, and degenerate input with zero extent (e.g. a single
        vertex) is only centred, not scaled, so the result stays finite.
    """
    # Nothing to measure; torch.min/max would raise on an empty dimension.
    if vertices.numel() == 0:
        return vertices

    min_vals, _ = torch.min(vertices, dim=0)
    max_vals, _ = torch.max(vertices, dim=0)

    center = (max_vals + min_vals) / 2
    centered = vertices - center

    max_extent = torch.max(max_vals - min_vals)
    # Guard against division by zero: with zero extent, substitute 2 so the
    # scale becomes exactly 1 and the centred vertices stay at the origin.
    safe_extent = torch.where(
        max_extent > 0, max_extent, torch.full_like(max_extent, 2)
    )
    scale = 2.0 / safe_extent

    return centered * scale
|
|