Spaces:
Running
on
Zero
Running
on
Zero
| # Project EmbodiedGen | |
| # | |
| # Copyright (c) 2025 Horizon Robotics. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |
| # implied. See the License for the specific language governing | |
| # permissions and limitations under the License. | |
| import argparse | |
| import logging | |
| import os | |
| import sys | |
| from glob import glob | |
| import numpy as np | |
| import trimesh | |
| from PIL import Image | |
| from embodied_gen.data.backproject_v2 import entrypoint as backproject_api | |
| from embodied_gen.data.utils import trellis_preprocess | |
| from embodied_gen.models.delight_model import DelightingModel | |
| from embodied_gen.models.gs_model import GaussianOperator | |
| from embodied_gen.models.segment_model import ( | |
| BMGG14Remover, | |
| RembgRemover, | |
| SAMPredictor, | |
| ) | |
| from embodied_gen.models.sr_model import ImageRealESRGAN | |
| from embodied_gen.scripts.render_gs import entrypoint as render_gs_api | |
| from embodied_gen.utils.gpt_clients import GPT_CLIENT | |
| from embodied_gen.utils.process_media import merge_images_video, render_video | |
| from embodied_gen.utils.tags import VERSION | |
| from embodied_gen.validators.quality_checkers import ( | |
| BaseChecker, | |
| ImageAestheticChecker, | |
| ImageSegChecker, | |
| MeshGeoChecker, | |
| ) | |
| from embodied_gen.validators.urdf_convertor import URDFGenerator | |
# Make the repository root importable so the vendored TRELLIS package under
# ``thirdparty/`` can be resolved by the import below.
current_file_path = os.path.abspath(__file__)
current_dir = os.path.dirname(current_file_path)
sys.path.append(os.path.join(current_dir, "../.."))

# Runtime configuration is exported BEFORE TRELLIS is imported. Previously
# these were assigned only after the import, so any value read while TRELLIS
# (or its spconv backend) was being imported never saw them.
# NOTE(review): assumes SPCONV_ALGO is consumed at TRELLIS import time —
# confirm; if an earlier embodied_gen import already pulls in TRELLIS, these
# assignments need to move above the file's import section.
os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser(
    "~/.cache/torch_extensions"
)
os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
os.environ["SPCONV_ALGO"] = "native"

from thirdparty.TRELLIS.trellis.pipelines import (  # noqa: E402
    TrellisImageTo3DPipeline,
)

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)
logger = logging.getLogger(__name__)

# Models are instantiated once at import time and reused for every image
# processed by the __main__ loop.
DELIGHT = DelightingModel()
IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
RBG_REMOVER = RembgRemover()
# NOTE(review): RBG14_REMOVER and SAM_PREDICTOR are not referenced by this
# script's own code; presumably kept for modules importing it — confirm.
RBG14_REMOVER = BMGG14Remover()
SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
    "microsoft/TRELLIS-image-large"
)
PIPELINE.cuda()

# Quality checkers passed (in this order) to BaseChecker.validate.
SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
AESTHETIC_CHECKER = ImageAestheticChecker()
CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]

# Session scratch directory next to this file (not used by the CLI loop).
TMP_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
)
def parse_args(argv=None):
    """Parse and validate the command-line arguments of the image-to-3D CLI.

    Args:
        argv: Optional argument list to parse instead of ``sys.argv[1:]``
            (argparse's behaviour when ``None``); useful for tests and for
            driving the script programmatically.

    Returns:
        argparse.Namespace: the parsed arguments. ``image_path`` is always a
        list: the explicit ``--image_path`` values, or otherwise every
        ``.png``/``.jpg``/``.jpeg`` file (extension matched
        case-insensitively) directly inside ``--image_root``, sorted by path.

    Exits via ``parser.error`` (status 2) when neither ``--image_path`` nor
    ``--image_root`` is given, or when ``--height_range``/``--mass_range``
    is not of the form ``MIN-MAX``.
    """
    parser = argparse.ArgumentParser(description="Image to 3D pipeline args.")
    parser.add_argument(
        "--image_path", type=str, nargs="+", help="Path to the input images."
    )
    parser.add_argument(
        "--image_root", type=str, help="Path to the input images folder."
    )
    parser.add_argument(
        "--output_root",
        type=str,
        required=True,
        help="Root directory for saving outputs.",
    )
    parser.add_argument(
        "--no_mesh", action="store_true", help="Do not output mesh files."
    )
    parser.add_argument(
        "--height_range",
        type=str,
        default=None,
        help='Height range in meters as "MIN-MAX" (e.g. "0.5-1.2"), '
        "used to restore the mesh real size.",
    )
    parser.add_argument(
        "--mass_range",
        type=str,
        default=None,
        help='Mass range in kg as "MIN-MAX" (e.g. "1-3"), '
        "used to restore the mesh real weight.",
    )
    parser.add_argument("--asset_type", type=str, default=None)
    parser.add_argument("--skip_exists", action="store_true")
    parser.add_argument("--strict_seg", action="store_true")
    parser.add_argument("--version", type=str, default=VERSION)
    args = parser.parse_args(argv)

    # parser.error prints usage and exits with status 2; unlike an assert it
    # is not stripped when Python runs with -O.
    if not (args.image_path or args.image_root):
        parser.error("Please provide either --image_path or --image_root.")

    # Fail fast on malformed ranges: the main loop splits them on "-" into
    # two floats, and a bad value would otherwise fail once per image.
    for flag, value in (
        ("--height_range", args.height_range),
        ("--mass_range", args.mass_range),
    ):
        if not value:
            continue
        try:
            _low, _high = map(float, value.split("-"))
        except ValueError:
            parser.error(f'{flag} must look like "MIN-MAX", got {value!r}.')

    if not args.image_path:
        # Case-insensitive extension match so ".PNG"/".JPG" files are found
        # too; sorted for a deterministic processing order.
        image_exts = (".png", ".jpg", ".jpeg")
        args.image_path = sorted(
            path
            for path in glob(os.path.join(args.image_root, "*"))
            if os.path.splitext(path)[1].lower() in image_exts
        )

    return args
if __name__ == "__main__":
    args = parse_args()

    # Each image gets its own sub-directory whenever several images share one
    # --output_root: fixed-name outputs (gs_mesh.mp4, color.png) were
    # otherwise overwritten by every later image. A single --image_path
    # keeps the flat layout it always had.
    use_image_subdir = args.image_root is not None or len(args.image_path) > 1

    for image_path in args.image_path:
        try:
            # splitext keeps dotted stems intact ("chair.v2.png" ->
            # "chair.v2"); split(".")[0] truncated them, merging inputs.
            filename = os.path.splitext(os.path.basename(image_path))[0]
            output_root = args.output_root
            if use_image_subdir:
                output_root = os.path.join(output_root, filename)
            os.makedirs(output_root, exist_ok=True)

            mesh_out = f"{output_root}/{filename}.obj"
            if args.skip_exists and os.path.exists(mesh_out):
                logger.info(
                    f"Skip {image_path}, already processed in {mesh_out}"
                )
                continue

            # Copy the pixels out so the source file handle is closed here
            # instead of lingering until garbage collection.
            with Image.open(image_path) as raw_image:
                image = raw_image.copy()
            image.save(f"{output_root}/{filename}_raw.png")

            # Segmentation: non-RGBA inputs go through the Rembg remover
            # (given save_path=seg_path, presumably it writes that file) and
            # then trellis_preprocess; RGBA inputs are used as-is and saved.
            # NOTE(review): RGBA inputs skip trellis_preprocess while
            # PIPELINE.run uses preprocess_image=False — confirm RGBA inputs
            # are expected to be pre-processed already.
            seg_path = f"{output_root}/{filename}_cond.png"
            if image.mode != "RGBA":
                seg_image = RBG_REMOVER(image, save_path=seg_path)
                seg_image = trellis_preprocess(seg_image)
            else:
                seg_image = image
                seg_image.save(seg_path)

            # Image-to-3D generation. Optional PIPELINE.run kwargs include
            # seed, sparse_structure_sampler_params and slat_sampler_params
            # (e.g. {"steps": 12, "cfg_strength": 7.5}).
            try:
                outputs = PIPELINE.run(seg_image, preprocess_image=False)
            except Exception as e:
                logger.exception(
                    f"[Pipeline Failed] process {image_path}: {e}, skip."
                )
                continue

            # Render the Gaussian (color) and mesh (normal) views and merge
            # them into a single preview video.
            gs_model = outputs["gaussian"][0]
            mesh_model = outputs["mesh"][0]
            color_images = render_video(gs_model)["color"]
            normal_images = render_video(mesh_model)["normal"]
            video_path = os.path.join(output_root, "gs_mesh.mp4")
            merge_images_video(color_images, normal_images, video_path)

            if not args.no_mesh:
                # Save the raw Gaussian model next to the mesh. Paths are
                # built from the stem rather than str.replace(".obj", ...),
                # which would also rewrite matches elsewhere in the path.
                mesh_stem = os.path.splitext(mesh_out)[0]
                gs_path = f"{mesh_stem}_gs.ply"
                gs_model.save_ply(gs_path)

                # Fixed axis-permutation rotations bringing the TRELLIS
                # outputs into the asset frame: mesh vertices are multiplied
                # by mesh_add_rot then rot_matrix, and the GS receives the
                # pose gs_add_rot @ rot_matrix so it aligns with the mesh.
                # NOTE(review): the old comment called rot_matrix "90 degrees
                # around Z-axis", but it leaves the y component unchanged
                # (a Y-axis rotation for row vectors) — confirm the frame.
                rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
                gs_add_rot = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
                mesh_add_rot = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]
                gs_rot = np.array(gs_add_rot) @ np.array(rot_matrix)
                pose = GaussianOperator.trans_to_quatpose(gs_rot)
                aligned_gs_path = f"{mesh_stem}_gs_aligned.ply"
                GaussianOperator.resave_ply(
                    in_ply=gs_path,
                    out_ply=aligned_gs_path,
                    instance_pose=pose,
                    device="cpu",
                )
                color_path = os.path.join(output_root, "color.png")
                render_gs_api(aligned_gs_path, color_path)

                mesh = trimesh.Trimesh(
                    vertices=mesh_model.vertices.cpu().numpy(),
                    faces=mesh_model.faces.cpu().numpy(),
                )
                mesh.vertices = mesh.vertices @ np.array(mesh_add_rot)
                mesh.vertices = mesh.vertices @ np.array(rot_matrix)
                mesh_obj_path = os.path.join(output_root, f"{filename}.obj")
                mesh.export(mesh_obj_path)

                # Texture the mesh from the GS render in color_path (with the
                # delighting and super-resolution models); the .obj is
                # overwritten in place and the textured mesh is returned.
                mesh = backproject_api(
                    delight_model=DELIGHT,
                    imagesr_model=IMAGESR_MODEL,
                    color_path=color_path,
                    mesh_path=mesh_obj_path,
                    output_path=mesh_obj_path,
                    skip_fix_mesh=False,
                    delight=True,
                    texture_wh=[2048, 2048],
                )
                mesh_glb_path = os.path.join(output_root, f"{filename}.glb")
                mesh.export(mesh_glb_path)

                # Build the URDF asset; ranges are "MIN-MAX" strings.
                urdf_convertor = URDFGenerator(GPT_CLIENT, render_view_num=4)
                asset_attrs = {
                    "version": VERSION,
                    "gs_model": f"{urdf_convertor.output_mesh_dir}/{filename}_gs.ply",  # noqa: E501
                }
                if args.height_range:
                    min_height, max_height = map(
                        float, args.height_range.split("-")
                    )
                    asset_attrs["min_height"] = min_height
                    asset_attrs["max_height"] = max_height
                if args.mass_range:
                    min_mass, max_mass = map(float, args.mass_range.split("-"))
                    asset_attrs["min_mass"] = min_mass
                    asset_attrs["max_mass"] = max_mass
                if args.asset_type:
                    asset_attrs["category"] = args.asset_type
                if args.version:
                    asset_attrs["version"] = args.version

                urdf_root = f"{output_root}/URDF_{filename}"
                urdf_path = urdf_convertor(
                    mesh_path=mesh_obj_path,
                    output_root=urdf_root,
                    **asset_attrs,
                )

                # Rescale the aligned GS to the URDF's real height and save
                # it into the URDF mesh folder.
                real_height = urdf_convertor.get_attr_from_urdf(
                    urdf_path, attr_name="real_height"
                )
                urdf_mesh_dir = f"{urdf_root}/{urdf_convertor.output_mesh_dir}"
                out_gs = f"{urdf_mesh_dir}/{filename}_gs.ply"
                GaussianOperator.resave_ply(
                    in_ply=aligned_gs_path,
                    out_ply=out_gs,
                    real_height=real_height,
                    device="cpu",
                )

                # Export a .glb copy of the URDF mesh.
                urdf_mesh_path = f"{urdf_mesh_dir}/{filename}.obj"
                urdf_mesh_glb = f"{os.path.splitext(urdf_mesh_path)[0]}.glb"
                trimesh.load(urdf_mesh_path).export(urdf_mesh_glb)

                # Quality check: the seg checker compares the raw input with
                # its segmentation; the other checkers score the color
                # renders found in the URDF render folder. Results are
                # written back into the .urdf file.
                image_dir = f"{urdf_root}/{urdf_convertor.output_render_dir}/image_color"  # noqa: E501
                image_paths = sorted(glob(f"{image_dir}/*.png"))
                seg_inputs = [
                    f"{output_root}/{filename}_raw.png",
                    f"{output_root}/{filename}_cond.png",
                ]
                images_list = [
                    seg_inputs
                    if isinstance(checker, ImageSegChecker)
                    else image_paths
                    for checker in CHECKERS
                ]
                results = BaseChecker.validate(CHECKERS, images_list)
                urdf_convertor.add_quality_tag(urdf_path, results)
        except Exception as e:
            # Per-image failures are logged with traceback; the batch goes on.
            logger.exception(f"Failed to process {image_path}: {e}, skip.")

    logger.info(f"Processing complete. Outputs saved to {args.output_root}")