Spaces:
Running
on
Zero
Running
on
Zero
| import json | |
| import os | |
| import re | |
| import shutil | |
| import cv2 | |
| import imageio | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import torch | |
| import wandb | |
| from matplotlib import cm | |
| from matplotlib.colors import LinearSegmentedColormap | |
| from PIL import Image, ImageDraw | |
| from pytorch_lightning.loggers import WandbLogger | |
| import lrm | |
| from ..models.mesh import Mesh | |
| from ..utils.typing import * | |
class SaverMixin:
    """Mixin bundling helpers that write images, arrays, videos, meshes and JSON under one save directory."""

    # Root directory for every output; stays None until set_save_dir() is called.
    _save_dir: Optional[str] = None
    # Optional Weights & Biases logger; savers only log when this is set.
    _wandb_logger: Optional[WandbLogger] = None
def set_save_dir(self, save_dir: str):
    """Record ``save_dir`` as the root directory used by later save calls."""
    self._save_dir = save_dir
def get_save_dir(self):
    """Return the configured save directory.

    Raises:
        ValueError: if no save directory has been set.
    """
    configured = self._save_dir
    if configured is not None:
        return configured
    raise ValueError("Save dir is not set")
def convert_data(self, data):
    """Recursively convert ``data`` into NumPy-backed values.

    Args:
        data: ``None``, a ``numpy.ndarray``, a ``torch.Tensor``, or a list /
            dict whose items are themselves convertible.

    Returns:
        ``None`` and arrays unchanged; tensors as detached CPU NumPy arrays;
        lists and dicts rebuilt with every item converted.

    Raises:
        TypeError: if ``data`` (or a nested item) has any other type.
    """
    if data is None or isinstance(data, np.ndarray):
        return data
    if isinstance(data, torch.Tensor):
        # Half-precision tensors are upcast to float32 before the NumPy conversion.
        if data.dtype in [torch.float16, torch.bfloat16]:
            data = data.float()
        return data.detach().cpu().numpy()
    if isinstance(data, list):
        return [self.convert_data(d) for d in data]
    if isinstance(data, dict):
        return {k: self.convert_data(v) for k, v in data.items()}
    # Build one message string; passing the type as a second argument made
    # the exception text render as a tuple.
    raise TypeError(
        "Data must be in type numpy.ndarray, torch.Tensor, list or dict, "
        f"getting {type(data)}"
    )
def get_save_path(self, filename):
    """Join ``filename`` onto the save directory and create its parent directory if needed."""
    target = os.path.join(self.get_save_dir(), filename)
    parent = os.path.dirname(target)
    os.makedirs(parent, exist_ok=True)
    return target
# Defaults for RGB images: channels-last layout, values in [0, 1].
DEFAULT_RGB_KWARGS = dict(data_format="HWC", data_range=(0, 1))
# Defaults for UV images: same layout/range, drawn as a checkerboard.
DEFAULT_UV_KWARGS = dict(
    data_format="HWC",
    data_range=(0, 1),
    cmap="checkerboard",
)
# Defaults for scalar maps: rescale by the image's own min/max and colour with "jet".
DEFAULT_GRAYSCALE_KWARGS = dict(data_range=None, cmap="jet")
# Defaults for image grids: tiles in a row are resized to the tallest one.
DEFAULT_GRID_KWARGS = dict(align="max")
def get_rgb_image_(self, img, data_format, data_range, rgba=False):
    """Convert an image into a uint8 array in OpenCV channel order.

    Args:
        img: image accepted by ``convert_data``.
        data_format: ``"CHW"`` or ``"HWC"``; CHW input is transposed to HWC.
        data_range: ``(low, high)`` used to clip and rescale non-uint8 input to 0..255.
        rgba: group channels in fours (RGBA) instead of threes (RGB).

    Returns:
        uint8 array. Channels are split into groups of 3 (or 4), a short last
        group is zero-padded, and the groups are laid side by side along the
        width before the RGB(A)->BGR(A) conversion.
    """
    pixels = self.convert_data(img)
    assert data_format in ["CHW", "HWC"]
    if data_format == "CHW":
        pixels = pixels.transpose(1, 2, 0)
    if pixels.dtype != np.uint8:
        low, high = data_range[0], data_range[1]
        pixels = pixels.clip(min=low, max=high)
        pixels = ((pixels - low) / (high - low) * 255.0).astype(np.uint8)

    group = 4 if rgba else 3
    tiles = []
    for first in range(0, pixels.shape[-1], group):
        tile = pixels[..., first : first + group]
        if tile.shape[-1] != group:
            # Zero-pad an incomplete trailing group up to the full channel count.
            filler = np.zeros(
                (tile.shape[0], tile.shape[1], group - tile.shape[2]),
                dtype=tile.dtype,
            )
            tile = np.concatenate([tile, filler], axis=-1)
        tiles.append(tile)
    strip = np.concatenate(tiles, axis=1)

    conversion = cv2.COLOR_RGBA2BGRA if rgba else cv2.COLOR_RGB2BGR
    return cv2.cvtColor(strip, conversion)
def _save_rgb_image(
    self,
    filename,
    img,
    data_format,
    data_range,
    name: Optional[str] = None,
    step: Optional[int] = None,
):
    """Write an RGB image to ``filename`` and optionally log it to wandb.

    Args:
        filename: destination path handed straight to ``cv2.imwrite``. The
            callers visible in this file pass a path that was already
            resolved against the save directory.
        img: image accepted by ``get_rgb_image_``.
        data_format: ``"CHW"`` or ``"HWC"``.
        data_range: ``(low, high)`` value range of ``img``.
        name: wandb key; nothing is logged when empty or when no logger is set.
        step: step forwarded to the logger.
    """
    img = self.get_rgb_image_(img, data_format, data_range)
    cv2.imwrite(filename, img)
    if name and self._wandb_logger:
        # Log the file that was just written. Passing ``filename`` through
        # get_save_path() again would prefix a relative save dir twice and
        # create a stray directory tree.
        self._wandb_logger.log_image(key=name, images=[filename], step=step)
def save_rgb_image(
    self,
    filename,
    img,
    data_format=DEFAULT_RGB_KWARGS["data_format"],
    data_range=DEFAULT_RGB_KWARGS["data_range"],
    name: Optional[str] = None,
    step: Optional[int] = None,
) -> str:
    """Save ``img`` as an RGB image under the save directory and return the written path."""
    target = self.get_save_path(filename)
    self._save_rgb_image(
        target, img, data_format, data_range, name=name, step=step
    )
    return target
def get_uv_image_(self, img, data_format, data_range, cmap):
    """Visualise a UV map as a uint8 image.

    Args:
        img: UV map accepted by ``convert_data``; the first two channels are used.
        data_format: ``"CHW"`` or ``"HWC"``.
        data_range: ``(low, high)`` used to clip and rescale to [0, 1].
        cmap: ``"checkerboard"`` (64x64 white/magenta cells) or ``"color"``
            (U and V written to the first two RGB channels).

    Returns:
        uint8 image after an RGB->BGR conversion.
    """
    uv = self.convert_data(img)
    assert data_format in ["CHW", "HWC"]
    if data_format == "CHW":
        uv = uv.transpose(1, 2, 0)
    low, high = data_range[0], data_range[1]
    uv = uv.clip(min=low, max=high)
    uv = (uv - low) / (high - low)
    height, width = uv.shape[0], uv.shape[1]

    assert cmap in ["checkerboard", "color"]
    if cmap == "checkerboard":
        n_grid = 64
        cells = (uv * n_grid).astype(int)
        # A cell is coloured when the sum of its U and V cell indices is even.
        even = (cells[..., 0] + cells[..., 1]) % 2 == 0
        board = np.full((height, width, 3), 255, dtype=np.uint8)
        board[even] = np.array([255, 0, 255], dtype=np.uint8)
        return cv2.cvtColor(board, cv2.COLOR_RGB2BGR)
    if cmap == "color":
        canvas = np.zeros((height, width, 3), dtype=np.uint8)
        for channel in (0, 1):
            canvas[..., channel] = (uv[..., channel] * 255).astype(np.uint8)
        return cv2.cvtColor(canvas, cv2.COLOR_RGB2BGR)
    return uv
def save_uv_image(
    self,
    filename,
    img,
    data_format=DEFAULT_UV_KWARGS["data_format"],
    data_range=DEFAULT_UV_KWARGS["data_range"],
    cmap=DEFAULT_UV_KWARGS["cmap"],
) -> str:
    """Render a UV map with ``get_uv_image_``, write it under the save directory and return the path."""
    target = self.get_save_path(filename)
    rendered = self.get_uv_image_(img, data_format, data_range, cmap)
    cv2.imwrite(target, rendered)
    return target
def get_grayscale_image_(self, img, data_range, cmap):
    """Render a scalar map as a colour-mapped uint8 image.

    Args:
        img: scalar field accepted by ``convert_data``; NaNs are replaced via
            ``np.nan_to_num``.
        data_range: ``(low, high)`` used to clip and rescale to [0, 1], or
            ``None`` to rescale by the image's own min/max.
        cmap: ``None`` (grey value replicated to three channels), ``"jet"``,
            ``"magma"`` or ``"spectral"``.

    Returns:
        uint8 image; for 2-D input (or input with one trailing channel when
        ``cmap`` is ``None``) the shape is ``(H, W, 3)``.
    """
    img = self.convert_data(img)
    img = np.nan_to_num(img)
    if data_range is None:
        low, high = img.min(), img.max()
        if high > low:
            img = (img - low) / (high - low)
        else:
            # Constant image: avoid 0/0 and map everything to the range bottom.
            img = np.zeros_like(img, dtype=np.float32)
    else:
        img = img.clip(data_range[0], data_range[1])
        img = (img - data_range[0]) / (data_range[1] - data_range[0])

    assert cmap in [None, "jet", "magma", "spectral"]
    if cmap is None:
        img = (img * 255.0).astype(np.uint8)
        # Accept "H W 1" maps as well as "H W" so the result is always H W 3.
        if img.ndim == 3 and img.shape[-1] == 1:
            img = img[..., 0]
        img = np.repeat(img[..., None], 3, axis=2)
    elif cmap == "jet":
        img = (img * 255.0).astype(np.uint8)
        img = cv2.applyColorMap(img, cv2.COLORMAP_JET)
    elif cmap == "magma":
        img = 1.0 - img
        # plt.get_cmap is used (as in the "spectral" branch) because
        # matplotlib.cm.get_cmap is no longer available in recent Matplotlib.
        base = plt.get_cmap("magma")
        num_bins = 256
        colormap = LinearSegmentedColormap.from_list(
            f"{base.name}{num_bins}", base(np.linspace(0, 1, num_bins)), num_bins
        )(np.linspace(0, 1, num_bins))[:, :3]
        # Linear interpolation between the two nearest colour-table entries.
        a = np.floor(img * 255.0)
        b = (a + 1).clip(max=255.0)
        f = img * 255.0 - a
        a = a.astype(np.uint16).clip(0, 255)
        b = b.astype(np.uint16).clip(0, 255)
        img = colormap[a] + (colormap[b] - colormap[a]) * f[..., None]
        img = (img * 255.0).astype(np.uint8)
        # NOTE(review): unlike the "spectral" branch, no RGB->BGR conversion
        # is applied here - confirm whether that is intended.
    elif cmap == "spectral":
        colormap = plt.get_cmap("Spectral")

        def blend_rgba(image):
            # Composite RGBA over a white background.
            return image[..., :3] * image[..., -1:] + (1.0 - image[..., -1:])

        img = colormap(img)
        img = blend_rgba(img)
        img = (img * 255).astype(np.uint8)
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    return img
def _save_grayscale_image(
    self,
    filename,
    img,
    data_range,
    cmap,
    name: Optional[str] = None,
    step: Optional[int] = None,
):
    """Write a colour-mapped scalar image to ``filename`` and optionally log it to wandb.

    Args:
        filename: destination path handed straight to ``cv2.imwrite``. The
            callers visible in this file pass a path that was already
            resolved against the save directory.
        img: scalar map accepted by ``get_grayscale_image_``.
        data_range: ``(low, high)`` or ``None``, forwarded to ``get_grayscale_image_``.
        cmap: colour map name, forwarded to ``get_grayscale_image_``.
        name: wandb key; nothing is logged when empty or when no logger is set.
        step: step forwarded to the logger.
    """
    img = self.get_grayscale_image_(img, data_range, cmap)
    cv2.imwrite(filename, img)
    if name and self._wandb_logger:
        # Log the file that was just written. Passing ``filename`` through
        # get_save_path() again would prefix a relative save dir twice and
        # create a stray directory tree.
        self._wandb_logger.log_image(key=name, images=[filename], step=step)
def save_grayscale_image(
    self,
    filename,
    img,
    data_range=DEFAULT_GRAYSCALE_KWARGS["data_range"],
    cmap=DEFAULT_GRAYSCALE_KWARGS["cmap"],
    name: Optional[str] = None,
    step: Optional[int] = None,
) -> str:
    """Save a scalar map as a colour-mapped image under the save directory and return the written path."""
    target = self.get_save_path(filename)
    self._save_grayscale_image(
        target, img, data_range, cmap, name=name, step=step
    )
    return target
def get_image_grid_(self, imgs, align):
    """Assemble image specs into one grid image.

    Args:
        imgs: either a list of specs (one row) or a list of such lists (rows
            stacked vertically). Each spec is a dict with ``"type"``
            (``"rgb"``, ``"uv"`` or ``"grayscale"``), ``"img"`` and
            ``"kwargs"`` overriding that type's defaults.
        align: ``"max"``, ``"min"`` or an int giving the common tile height.

    Returns:
        The tiles of a row resized to a common height and concatenated along
        the width; rows concatenated along the height.

    Raises:
        ValueError: if ``align`` is not ``"min"``, ``"max"`` or an int.
    """
    if isinstance(imgs[0], list):
        rows = [self.get_image_grid_(row, align) for row in imgs]
        return np.concatenate(rows, axis=0)

    # type name -> (default kwargs, renderer)
    renderers = {
        "rgb": (self.DEFAULT_RGB_KWARGS, self.get_rgb_image_),
        "uv": (self.DEFAULT_UV_KWARGS, self.get_uv_image_),
        "grayscale": (self.DEFAULT_GRAYSCALE_KWARGS, self.get_grayscale_image_),
    }
    tiles = []
    for spec in imgs:
        assert spec["type"] in ["rgb", "uv", "grayscale"]
        entry = renderers.get(spec["type"])
        if entry is None:
            continue
        defaults, render = entry
        options = defaults.copy()
        options.update(spec["kwargs"])
        tiles.append(render(spec["img"], **options))

    heights = [tile.shape[0] for tile in tiles]
    if align == "max":
        target_h = max(heights)
    elif align == "min":
        target_h = min(heights)
    elif isinstance(align, int):
        target_h = align
    else:
        raise ValueError(
            f"Unsupported image grid align: {align}, should be min, max, or int"
        )

    for idx, tile in enumerate(tiles):
        if tile.shape[0] != target_h:
            # Scale the width by the same factor as the height.
            new_w = int(tile.shape[1] * target_h / tile.shape[0])
            tiles[idx] = cv2.resize(
                tile, (new_w, target_h), interpolation=cv2.INTER_CUBIC
            )
    return np.concatenate(tiles, axis=1)
def save_image_grid(
    self,
    filename,
    imgs,
    align=DEFAULT_GRID_KWARGS["align"],
    name: Optional[str] = None,
    step: Optional[int] = None,
    texts: Optional[List[float]] = None,
):
    """Build an image grid, optionally label it, write it and return the path.

    Args:
        filename: output name relative to the save directory.
        imgs: image specs accepted by ``get_image_grid_``.
        align: tile height alignment forwarded to ``get_image_grid_``.
        name: wandb key; nothing is logged when empty or when no logger is set.
        step: step forwarded to the logger.
        texts: optional labels; label ``i`` is drawn at the left edge, at
            ``i`` times (image height // number of labels) from the top.
    """
    target = self.get_save_path(filename)
    grid = self.get_image_grid_(imgs, align=align)
    if texts is not None:
        canvas = Image.fromarray(grid)
        pen = ImageDraw.Draw(canvas)
        black, white = (0, 0, 0), (255, 255, 255)
        # Four white copies offset by one pixel form an outline; the black
        # copy on top is the label itself.
        strokes = [
            ((2, 1), white),
            ((0, 1), white),
            ((2, -1), white),
            ((0, -1), white),
            ((1, 0), black),
        ]
        for i, text in enumerate(texts):
            top = (canvas.size[1] // len(texts)) * i
            for (x, dy), color in strokes:
                pen.text((x, top + dy), f"{text}", color)
        grid = np.asarray(canvas)
    cv2.imwrite(target, grid)
    if name and self._wandb_logger:
        self._wandb_logger.log_image(key=name, images=[target], step=step)
    return target
def save_image(self, filename, img) -> str:
    """Write a uint8/uint16 image under the save directory and return the path.

    Three- and four-channel images are converted from RGB(A) to BGR(A)
    before writing; other shapes are written unchanged.
    """
    target = self.get_save_path(filename)
    pixels = self.convert_data(img)
    assert pixels.dtype == np.uint8 or pixels.dtype == np.uint16
    if pixels.ndim == 3:
        channels = pixels.shape[-1]
        if channels == 3:
            pixels = cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR)
        elif channels == 4:
            pixels = cv2.cvtColor(pixels, cv2.COLOR_RGBA2BGRA)
    cv2.imwrite(target, pixels)
    return target
def save_cubemap(self, filename, img, data_range=(0, 1), rgba=False) -> str:
    """Save a six-face cubemap as an unfolded cross image and return the path.

    Args:
        filename: output name relative to the save directory.
        img: array of shape ``(6, S, S, C)`` after ``convert_data``.
        data_range: ``(low, high)`` value range of ``img``.
        rgba: treat channels in groups of four (RGBA) instead of three (RGB).

    Each channel group becomes one 3x4 cross (face 2 above and face 3 below
    the row of faces 1, 4, 0, 5); crosses of successive groups are placed
    side by side.
    """
    save_path = self.get_save_path(filename)
    img = self.convert_data(img)
    assert img.ndim == 4 and img.shape[0] == 6 and img.shape[1] == img.shape[2]

    # Group channels the same way get_rgb_image_ does, so RGBA input yields
    # one RGBA cross instead of being split after three channels.
    group = 4 if rgba else 3
    imgs_full = []
    for start in range(0, img.shape[-1], group):
        chunk = img[..., start : start + group]
        faces = np.stack(
            [
                self.get_rgb_image_(chunk[i], "HWC", data_range, rgba=rgba)
                for i in range(chunk.shape[0])
            ],
            axis=0,
        )
        # The blank tile must have the faces' channel count, otherwise the
        # concatenation fails for rgba=True. float32 is kept from the
        # original implementation.
        placeholder = np.zeros(faces[0].shape, dtype=np.float32)
        img_full = np.concatenate(
            [
                np.concatenate(
                    [placeholder, faces[2], placeholder, placeholder], axis=1
                ),
                np.concatenate([faces[1], faces[4], faces[0], faces[5]], axis=1),
                np.concatenate(
                    [placeholder, faces[3], placeholder, placeholder], axis=1
                ),
            ],
            axis=0,
        )
        imgs_full.append(img_full)
    imgs_full = np.concatenate(imgs_full, axis=1)
    cv2.imwrite(save_path, imgs_full)
    return save_path
def save_data(self, filename, data) -> str:
    """Save converted data with NumPy and return the written path.

    Dicts are stored with ``np.savez`` (``.npz``); anything else with
    ``np.save`` (``.npy``). The matching extension is appended to
    ``filename`` when it is missing.
    """
    converted = self.convert_data(data)
    is_mapping = isinstance(converted, dict)
    suffix = ".npz" if is_mapping else ".npy"
    if not filename.endswith(suffix):
        filename += suffix
    target = self.get_save_path(filename)
    if is_mapping:
        np.savez(target, **converted)
    else:
        np.save(target, converted)
    return target
def save_state_dict(self, filename, data) -> str:
    """Serialise ``data`` with ``torch.save`` under the save directory and return the path."""
    target = self.get_save_path(filename)
    torch.save(data, target)
    return target
def save_img_sequence(
    self,
    filename,
    img_dir,
    matcher,
    save_format="mp4",
    fps=30,
    name: Optional[str] = None,
    step: Optional[int] = None,
) -> str:
    """Assemble matching images of a directory into a GIF or MP4 and return the path.

    Args:
        filename: output name relative to the save directory; the format
            suffix is appended when ``filename`` does not already end with it.
        img_dir: directory (relative to the save directory) holding the frames.
        matcher: regular expression searched in each file name; its first
            group must be an integer, which determines the frame order.
        save_format: ``"gif"`` or ``"mp4"``.
        fps: frames per second passed to ``imageio.mimsave``.
        name: when set together with a wandb logger, only a warning is
            emitted (video logging is not implemented).
        step: unused here.
    """
    assert save_format in ["gif", "mp4"]
    if not filename.endswith(save_format):
        filename += f".{save_format}"
    target = self.get_save_path(filename)
    pattern = re.compile(matcher)
    frame_dir = os.path.join(self.get_save_dir(), img_dir)

    frame_names = [entry for entry in os.listdir(frame_dir) if pattern.search(entry)]
    frame_names = sorted(
        frame_names, key=lambda entry: int(pattern.search(entry).groups()[0])
    )
    frames = [cv2.imread(os.path.join(frame_dir, entry)) for entry in frame_names]

    if save_format in ("gif", "mp4"):
        # Frames are read by OpenCV as BGR; the video writer receives RGB.
        frames = [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in frames]
        extra = {"palettesize": 256} if save_format == "gif" else {}
        imageio.mimsave(target, frames, fps=fps, **extra)
    if name and self._wandb_logger:
        lrm.warn("Wandb logger does not support video logging yet!")
    return target
def save_img_sequences(
    self,
    seq_dir,
    matcher,
    save_format="mp4",
    fps=30,
    delete=True,
    name: Optional[str] = None,
    step: Optional[int] = None,
):
    """Turn every sub-directory of ``seq_dir`` into a video, best effort.

    Args:
        seq_dir: directory (relative to the save directory) whose
            sub-directories each hold one image sequence.
        matcher: frame-name regular expression forwarded to ``save_img_sequence``.
        save_format: ``"gif"`` or ``"mp4"``.
        fps: frames per second.
        delete: remove a sequence directory after its video was written.
        name: optional logging prefix; each video is named ``"{name}_{subdir}"``.
        step: step forwarded to ``save_img_sequence``.

    A failure for one directory is reported with a warning and does not stop
    the remaining directories.
    """
    seq_dir_ = os.path.join(self.get_save_dir(), seq_dir)
    for f in os.listdir(seq_dir_):
        img_dir_ = os.path.join(seq_dir_, f)
        if not os.path.isdir(img_dir_):
            continue
        try:
            self.save_img_sequence(
                os.path.join(seq_dir, f),
                os.path.join(seq_dir, f),
                matcher,
                save_format=save_format,
                fps=fps,
                # Keep "no name" as None instead of the truthy "None_<dir>".
                name=f"{name}_{f}" if name is not None else None,
                step=step,
            )
            if delete:
                shutil.rmtree(img_dir_)
        except Exception as err:
            # Exception (not a bare except) so Ctrl-C and SystemExit still
            # propagate; report the directory that actually failed.
            lrm.warn(f"Video saving for directory {img_dir_} failed: {err}")
def save_mesh(self, filename, v_pos, t_pos_idx, v_tex=None, t_tex_idx=None) -> str:
    """Export vertices and faces through ``trimesh`` and return the written path.

    ``v_tex`` and ``t_tex_idx`` are accepted but not used by this method.
    """
    import trimesh

    target = self.get_save_path(filename)
    vertices = self.convert_data(v_pos)
    faces = self.convert_data(t_pos_idx)
    trimesh.Trimesh(vertices=vertices, faces=faces).export(target)
    return target
def save_obj(
    self,
    filename: str,
    mesh: Mesh,
    save_mat: bool = False,
    save_normal: bool = False,
    save_uv: bool = False,
    save_vertex_color: bool = False,
    map_Kd: Optional[Float[Tensor, "H W 3"]] = None,
    map_Ks: Optional[Float[Tensor, "H W 3"]] = None,
    map_Bump: Optional[Float[Tensor, "H W 3"]] = None,
    map_Pm: Optional[Float[Tensor, "H W 1"]] = None,
    map_Pr: Optional[Float[Tensor, "H W 1"]] = None,
    map_format: str = "jpg",
) -> List[str]:
    """Export ``mesh`` as a Wavefront OBJ, optionally with a material file.

    Args:
        filename: output name relative to the save directory; ``.obj`` is
            appended when missing.
        mesh: source of ``v_pos`` / ``t_pos_idx`` and, when requested,
            ``v_nrm``, ``v_tex`` / ``t_tex_idx`` and ``v_rgb``.
        save_mat: also write a ``.mtl`` file (material ``"default"``) and the
            supplied texture maps.
        save_normal: write vertex normals.
        save_uv: write texture coordinates.
        save_vertex_color: append vertex colours to the vertex lines.
        map_Kd, map_Ks, map_Bump, map_Pm, map_Pr: optional texture maps,
            only used when ``save_mat`` is true.
        map_format: image extension used for the texture files.

    Returns:
        Paths of all files written (material and textures first, OBJ last).
    """
    save_paths: List[str] = []
    if not filename.endswith(".obj"):
        filename += ".obj"
    v_pos, t_pos_idx = self.convert_data(mesh.v_pos), self.convert_data(
        mesh.t_pos_idx
    )
    v_nrm, v_tex, t_tex_idx, v_rgb = None, None, None, None
    if save_normal:
        v_nrm = self.convert_data(mesh.v_nrm)
    if save_uv:
        v_tex, t_tex_idx = self.convert_data(mesh.v_tex), self.convert_data(
            mesh.t_tex_idx
        )
    if save_vertex_color:
        v_rgb = self.convert_data(mesh.v_rgb)

    matname, mtllib = None, None
    if save_mat:
        matname = "default"
        # Swap only the trailing ".obj"; str.replace would also rewrite any
        # ".obj" occurring earlier in the path (e.g. a directory name).
        mtl_filename = filename[: -len(".obj")] + ".mtl"
        mtllib = os.path.basename(mtl_filename)
        mtl_save_paths = self._save_mtl(
            mtl_filename,
            matname,
            map_Kd=self.convert_data(map_Kd),
            map_Ks=self.convert_data(map_Ks),
            map_Bump=self.convert_data(map_Bump),
            map_Pm=self.convert_data(map_Pm),
            map_Pr=self.convert_data(map_Pr),
            map_format=map_format,
        )
        save_paths += mtl_save_paths

    obj_save_path = self._save_obj(
        filename,
        v_pos,
        t_pos_idx,
        v_nrm=v_nrm,
        v_tex=v_tex,
        t_tex_idx=t_tex_idx,
        v_rgb=v_rgb,
        matname=matname,
        mtllib=mtllib,
    )
    save_paths.append(obj_save_path)
    return save_paths
def _save_obj(
    self,
    filename,
    v_pos,
    t_pos_idx,
    v_nrm=None,
    v_tex=None,
    t_tex_idx=None,
    v_rgb=None,
    matname=None,
    mtllib=None,
) -> str:
    """Write a Wavefront OBJ text file and return the written path.

    Args:
        filename: output name relative to the save directory.
        v_pos: vertex positions, indexable as ``v_pos[i][0..2]``.
        t_pos_idx: zero-based triangle vertex indices.
        v_nrm: optional per-vertex normals; faces reuse the position index
            for the normal index.
        v_tex: optional texture coordinates; V is written as ``1 - v``.
        t_tex_idx: zero-based triangle texture indices, required with ``v_tex``.
        v_rgb: optional per-vertex colours appended to each ``v`` line.
        matname: material name; when given, ``mtllib``/``g``/``usemtl``
            header lines are emitted.
        mtllib: material library file name referenced in the header.
    """
    # Collect pieces and join once; repeated ``+=`` on a growing string is
    # quadratic for large meshes.
    parts = []
    if matname is not None:
        parts.append(f"mtllib {mtllib}\n")
        parts.append("g object\n")
        parts.append(f"usemtl {matname}\n")

    for i in range(len(v_pos)):
        parts.append(f"v {v_pos[i][0]} {v_pos[i][1]} {v_pos[i][2]}")
        if v_rgb is not None:
            parts.append(f" {v_rgb[i][0]} {v_rgb[i][1]} {v_rgb[i][2]}")
        parts.append("\n")

    if v_nrm is not None:
        parts.extend(f"vn {v[0]} {v[1]} {v[2]}\n" for v in v_nrm)

    if v_tex is not None:
        parts.extend(f"vt {v[0]} {1.0 - v[1]}\n" for v in v_tex)

    for i in range(len(t_pos_idx)):
        parts.append("f")
        for j in range(3):
            # OBJ indices are one-based: position/texture/normal.
            parts.append(f" {t_pos_idx[i][j] + 1}/")
            if v_tex is not None:
                parts.append(f"{t_tex_idx[i][j] + 1}")
            parts.append("/")
            if v_nrm is not None:
                parts.append(f"{t_pos_idx[i][j] + 1}")
        parts.append("\n")

    save_path = self.get_save_path(filename)
    with open(save_path, "w") as f:
        f.write("".join(parts))
    return save_path
def _save_mtl(
    self,
    filename,
    matname,
    Ka=(0.0, 0.0, 0.0),
    Kd=(1.0, 1.0, 1.0),
    Ks=(0.0, 0.0, 0.0),
    map_Kd=None,
    map_Ks=None,
    map_Bump=None,
    map_Pm=None,
    map_Pr=None,
    map_format="jpg",
    step: Optional[int] = None,
) -> List[str]:
    """Write a material (``.mtl``) file plus any supplied texture images.

    Args:
        filename: material file name relative to the save directory.
        matname: material name written after ``newmtl``.
        Ka, Kd, Ks: constant colours; ``Kd`` / ``Ks`` are only written when
            the corresponding texture map is absent.
        map_Kd, map_Ks, map_Bump: optional RGB textures saved next to the
            material file as ``texture_kd`` / ``texture_ks`` / ``texture_nrm``.
        map_Pm, map_Pr: optional scalar textures saved as
            ``texture_metallic`` / ``texture_roughness``.
        map_format: image extension for the texture files.
        step: step forwarded to the image savers.

    Returns:
        The material path followed by every texture path written.
    """
    mtl_path = self.get_save_path(filename)
    texture_dir = os.path.dirname(mtl_path)
    written = [mtl_path]
    lines = [f"newmtl {matname}\n", f"Ka {Ka[0]} {Ka[1]} {Ka[2]}\n"]

    def register(key, stem):
        # Add the "<key> <file>" statement and return the texture's full path.
        texture_file = f"{stem}.{map_format}"
        lines.append(f"{key} {texture_file}\n")
        return os.path.join(texture_dir, texture_file)

    def attach_rgb(key, stem, texture, tag):
        texture_path = register(key, stem)
        self._save_rgb_image(
            texture_path,
            texture,
            data_format="HWC",
            data_range=(0, 1),
            name=f"{matname}_{tag}",
            step=step,
        )
        written.append(texture_path)

    def attach_scalar(key, stem, texture, tag):
        texture_path = register(key, stem)
        self._save_grayscale_image(
            texture_path,
            texture,
            data_range=(0, 1),
            cmap=None,
            name=f"{matname}_{tag}",
            step=step,
        )
        written.append(texture_path)

    if map_Kd is None:
        lines.append(f"Kd {Kd[0]} {Kd[1]} {Kd[2]}\n")
    else:
        attach_rgb("map_Kd", "texture_kd", map_Kd, "Kd")

    if map_Ks is None:
        lines.append(f"Ks {Ks[0]} {Ks[1]} {Ks[2]}\n")
    else:
        attach_rgb("map_Ks", "texture_ks", map_Ks, "Ks")

    if map_Bump is not None:
        attach_rgb("map_Bump", "texture_nrm", map_Bump, "Bump")
    if map_Pm is not None:
        attach_scalar("map_Pm", "texture_metallic", map_Pm, "refl")
    if map_Pr is not None:
        attach_scalar("map_Pr", "texture_roughness", map_Pr, "Ns")

    with open(self.get_save_path(filename), "w") as handle:
        handle.write("".join(lines))
    return written
def save_glb(
    self,
    filename: str,
    mesh: Mesh,
    save_mat: bool = False,
    save_normal: bool = False,
    save_uv: bool = False,
    save_vertex_color: bool = False,
    map_Kd: Optional[Float[Tensor, "H W 3"]] = None,
    map_Ks: Optional[Float[Tensor, "H W 3"]] = None,
    map_Bump: Optional[Float[Tensor, "H W 3"]] = None,
    map_Pm: Optional[Float[Tensor, "H W 1"]] = None,
    map_Pr: Optional[Float[Tensor, "H W 1"]] = None,
    map_format: str = "jpg",
) -> List[str]:
    """Export ``mesh`` as a GLB file and return the written path in a list.

    ``.glb`` is appended to ``filename`` when missing. Normals, UVs and
    vertex colours are included only when their flag is set. ``save_mat``,
    the ``map_*`` textures and ``map_format`` are accepted but not used here.
    """
    if not filename.endswith(".glb"):
        filename += ".glb"

    v_pos = self.convert_data(mesh.v_pos)
    t_pos_idx = self.convert_data(mesh.t_pos_idx)
    v_nrm = self.convert_data(mesh.v_nrm) if save_normal else None
    if save_uv:
        v_tex = self.convert_data(mesh.v_tex)
        t_tex_idx = self.convert_data(mesh.t_tex_idx)
    else:
        v_tex, t_tex_idx = None, None
    v_rgb = self.convert_data(mesh.v_rgb) if save_vertex_color else None

    glb_path = self._save_glb(
        filename,
        v_pos,
        t_pos_idx,
        v_nrm=v_nrm,
        v_tex=v_tex,
        t_tex_idx=t_tex_idx,
        v_rgb=v_rgb,
    )
    return [glb_path]
def _save_glb(
    self,
    filename,
    v_pos,
    t_pos_idx,
    v_nrm=None,
    v_tex=None,
    t_tex_idx=None,
    v_rgb=None,
    matname=None,
    mtllib=None,
) -> str:
    """Build a ``trimesh.Trimesh`` from the given arrays, export it and return the path.

    ``t_tex_idx``, ``matname`` and ``mtllib`` are accepted but not used.
    """
    import trimesh

    geometry = dict(
        vertices=v_pos,
        faces=t_pos_idx,
        vertex_normals=v_nrm,
        vertex_colors=v_rgb,
    )
    exported = trimesh.Trimesh(**geometry)
    # not tested
    if v_tex is not None:
        exported.visual = trimesh.visual.TextureVisuals(uv=v_tex)
    target = self.get_save_path(filename)
    exported.export(target)
    return target
def save_file(self, filename, src_path, delete=False) -> str:
    """Copy ``src_path`` to ``filename`` under the save directory and return the new path.

    When ``delete`` is true the source file is removed after the copy.
    """
    target = self.get_save_path(filename)
    shutil.copyfile(src_path, target)
    if not delete:
        return target
    os.remove(src_path)
    return target
def save_json(self, filename, payload) -> str:
    """Serialise ``payload`` as JSON under the save directory and return the written path."""
    target = self.get_save_path(filename)
    with open(target, "w") as handle:
        serialized = json.dumps(payload)
        handle.write(serialized)
    return target