# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

import os
from typing import Union

import cv2
import numpy as np
import spaces
import torch
from diffusers import (
    EulerAncestralDiscreteScheduler,
    StableDiffusionInstructPix2PixPipeline,
)
from huggingface_hub import snapshot_download
from PIL import Image

from embodied_gen.models.segment_model import RembgRemover

__all__ = [
    "DelightingModel",
]


class DelightingModel(object):
    """A model to remove lighting effects from an image in image space.

    This class wraps the Hunyuan3D-Delight model from
    https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0 # noqa

    Attributes:
        image_guide_scale (float): Weight of image guidance in the diffusion process.
        text_guide_scale (float): Weight of text (prompt) guidance in the diffusion process.
        num_infer_step (int): Number of inference steps for the diffusion model.
        mask_erosion_size (int): Size of the erosion kernel for alpha mask cleanup.
        device (str): Device used for inference, e.g., 'cuda' or 'cpu'.
        seed (int): Random seed for diffusion model reproducibility.
        model_path (str): Filesystem path to the pretrained model weights.
        pipeline: Lazy-loaded diffusion pipeline instance.
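
    Example:
        A minimal usage sketch (the input path is illustrative):

            model = DelightingModel()
            out = model("input.png", preprocess=True, target_wh=(512, 512))
            out.save("delight.png")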
| """ | |

    def __init__(
        self,
        model_path: str = None,
        num_infer_step: int = 50,
        mask_erosion_size: int = 3,
        image_guide_scale: float = 1.5,
        text_guide_scale: float = 1.0,
        device: str = "cuda",
        seed: int = 0,
    ) -> None:
        self.image_guide_scale = image_guide_scale
        self.text_guide_scale = text_guide_scale
        self.num_infer_step = num_infer_step
        self.mask_erosion_size = mask_erosion_size
        self.kernel = np.ones(
            (self.mask_erosion_size, self.mask_erosion_size), np.uint8
        )
        self.seed = seed
        self.device = device
        self.pipeline = None  # Lazy-loaded to stay compatible with @spaces.GPU.
        if model_path is None:
            # Download only the delight subfolder of the Hunyuan3D-2 repo.
            suffix = "hunyuan3d-delight-v2-0"
            model_path = snapshot_download(
                repo_id="tencent/Hunyuan3D-2", allow_patterns=f"{suffix}/*"
            )
            model_path = os.path.join(model_path, suffix)
        self.model_path = model_path

    def _lazy_init_pipeline(self):
        """Build the InstructPix2Pix pipeline once, on first use."""
        if self.pipeline is None:
            pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
                self.model_path,
                torch_dtype=torch.float16,
                safety_checker=None,
            )
            pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
                pipeline.scheduler.config
            )
            pipeline.set_progress_bar_config(disable=True)
            pipeline.to(self.device, torch.float16)
            self.pipeline = pipeline

    def recenter_image(
        self, image: Image.Image, border_ratio: float = 0.2
    ) -> Image.Image:
        """Center the non-transparent content of an RGBA image on a square canvas."""
        if image.mode == "RGB":
            return image
        elif image.mode == "L":
            image = image.convert("RGB")
            return image
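        # RGBA path: crop to the alpha bounding box, pad by `border_ratio`,
        # and paste onto a square transparent canvas.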
        alpha_channel = np.array(image)[:, :, 3]
        non_zero_indices = np.argwhere(alpha_channel > 0)
        if non_zero_indices.size == 0:
            raise ValueError("Image is fully transparent")
        min_row, min_col = non_zero_indices.min(axis=0)
        max_row, max_col = non_zero_indices.max(axis=0)
        cropped_image = image.crop(
            (min_col, min_row, max_col + 1, max_row + 1)
        )
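        # Grow the crop by `border_ratio` on each side, then use the larger
        # dimension as the square canvas size.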
        width, height = cropped_image.size
        border_width = int(width * border_ratio)
        border_height = int(height * border_ratio)
        new_width = width + 2 * border_width
        new_height = height + 2 * border_height
        square_size = max(new_width, new_height)
        new_image = Image.new(
            "RGBA", (square_size, square_size), (255, 255, 255, 0)
        )
        paste_x = (square_size - new_width) // 2 + border_width
        paste_y = (square_size - new_height) // 2 + border_height
        new_image.paste(cropped_image, (paste_x, paste_y))
        return new_image

    @spaces.GPU  # Assumed: restores the otherwise-unused `spaces` import, per the lazy-load note above.
    def __call__(
        self,
        image: Union[str, np.ndarray, Image.Image],
        preprocess: bool = False,
        target_wh: tuple[int, int] = None,
    ) -> Image.Image:
        self._lazy_init_pipeline()
        if isinstance(image, str):
            image = Image.open(image)
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image)
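        # Optionally strip the background and recenter the object on a
        # transparent square canvas before delighting.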
        if preprocess:
            bg_remover = RembgRemover()
            image = bg_remover(image)
            image = self.recenter_image(image)
        if target_wh is not None:
            image = image.resize(target_wh)
        else:
            target_wh = image.size
        image_array = np.array(image)
        assert image_array.shape[-1] == 4, "Image must have alpha channel"
        raw_alpha_channel = image_array[:, :, 3]
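        # Erode the alpha mask slightly to drop halo pixels at the silhouette,
        # then force the background to white as the model expects.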
        alpha_channel = cv2.erode(raw_alpha_channel, self.kernel, iterations=1)
        image_array[alpha_channel == 0, :3] = 255  # must be white background
        image_array[:, :, 3] = alpha_channel
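        # Run the delighting diffusion pipeline with an empty text prompt;
        # guidance comes primarily from the input image.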
        image = self.pipeline(
            prompt="",
            image=Image.fromarray(image_array).convert("RGB"),
            generator=torch.manual_seed(self.seed),
            num_inference_steps=self.num_infer_step,
            image_guidance_scale=self.image_guide_scale,
            guidance_scale=self.text_guide_scale,
        ).images[0]
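        # Re-attach the eroded alpha channel to the delighted RGB output.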
        alpha_channel = Image.fromarray(alpha_channel)
        rgba_image = image.convert("RGBA").resize(target_wh)
        rgba_image.putalpha(alpha_channel)
        return rgba_image
| if __name__ == "__main__": | |
| delighting_model = DelightingModel() | |
| image_path = "apps/assets/example_image/sample_12.jpg" | |
| image = delighting_model( | |
| image_path, preprocess=True, target_wh=(512, 512) | |
| ) # noqa | |
| image.save("delight.png") | |
| # image_path = "embodied_gen/scripts/test_robot.png" | |
| # image = delighting_model(image_path) | |
| # image.save("delighting_image_a2.png") | |