# The kiss3d pipeline wrapper for inference

import os
import numpy as np
import torch
import yaml
import uuid
from typing import Union, Any, Dict
from einops import rearrange
from PIL import Image

from pipeline.utils import logger, TMP_DIR, OUT_DIR
from pipeline.utils import lrm_reconstruct, isomer_reconstruct

import torchvision

# for reconstruction model
from omegaconf import OmegaConf
from models.lrm.utils.train_util import instantiate_from_config
from models.lrm.utils.render_utils import rotate_x, rotate_y

from utils.tool import get_background

# for florence2
from transformers import AutoProcessor, AutoModelForCausalLM

from diffusers import FluxPipeline, FluxControlNetImg2ImgPipeline, FluxImg2ImgPipeline, DiffusionPipeline, EulerAncestralDiscreteScheduler
from diffusers.models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel


# Text prepended to every Flux prompt so the model renders a 2x4 "3D bundle" grid.
_BUNDLE_PROMPT_PREFIX = 'A grid of 2x4 multi-view image, elevation 5. White background.'


def _load_flux_pipeline(flux_config):
    """Build the Flux img2img pipeline (ControlNet variant when configured) on its device."""
    flux_device = flux_config.get('device', 'cpu')
    flux_base_model_pth = flux_config.get('base_model', None)
    flux_controlnet_pth = flux_config.get('controlnet', None)
    flux_lora_pth = flux_config.get('lora', None)

    # load flux model and controlnet
    if flux_controlnet_pth is not None:
        flux_controlnet = FluxControlNetModel.from_pretrained(flux_controlnet_pth)
        flux_pipe = FluxControlNetImg2ImgPipeline.from_pretrained(
            flux_base_model_pth, controlnet=[flux_controlnet], torch_dtype=torch.bfloat16
        )
    else:
        # The pipeline constructor takes component modules, not a path: go through from_pretrained.
        flux_pipe = FluxImg2ImgPipeline.from_pretrained(flux_base_model_pth, torch_dtype=torch.bfloat16)

    # load lora weights (optional: the config may not name one)
    if flux_lora_pth is not None:
        flux_pipe.load_lora_weights(flux_lora_pth)
    flux_pipe.to(device=flux_device, dtype=torch.bfloat16)
    return flux_pipe


def _load_multiview_pipeline(multiview_config):
    """Build the multiview diffusion pipeline, optionally overriding its UNet weights."""
    multiview_device = multiview_config.get('device', 'cpu')
    multiview_pipeline = DiffusionPipeline.from_pretrained(
        multiview_config['base_model'],
        custom_pipeline=multiview_config['custom_pipeline'],
        torch_dtype=torch.float16,
    )
    multiview_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
        multiview_pipeline.scheduler.config, timestep_spacing='trailing'
    )

    unet_ckpt_path = multiview_config.get('unet', None)
    if unet_ckpt_path is not None:
        state_dict = torch.load(unet_ckpt_path, map_location='cpu')['state_dict']
        # Keep only the 'unet.unet.*' entries and strip that prefix so keys match the UNet.
        prefix = 'unet.unet.'
        state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
        multiview_pipeline.unet.load_state_dict(state_dict, strict=True)

    multiview_pipeline.to(multiview_device)
    return multiview_pipeline


def _load_caption_model(caption_config):
    """Load the caption model and its processor; returns (processor, model)."""
    caption_device = caption_config.get('device', 'cpu')
    caption_model = AutoModelForCausalLM.from_pretrained(
        caption_config['base_model'], torch_dtype=torch.bfloat16, trust_remote_code=True
    ).to(caption_device)
    caption_processor = AutoProcessor.from_pretrained(caption_config['base_model'], trust_remote_code=True)
    return caption_processor, caption_model


def _load_reconstruction_model(recon_config):
    """Instantiate the LRM reconstruction model from its config + checkpoint; returns (config, model)."""
    recon_device = recon_config.get('device', 'cpu')
    recon_model_config = OmegaConf.load(recon_config['model_config'])
    recon_model = instantiate_from_config(recon_model_config.model_config)

    # load recon model checkpoint: keep 'lrm_generator.*' entries and strip that prefix
    state_dict = torch.load(recon_config['base_model'], map_location='cpu')['state_dict']
    prefix = 'lrm_generator.'
    state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
    recon_model.load_state_dict(state_dict, strict=True)

    recon_model.to(recon_device)
    recon_model.init_flexicubes_geometry(recon_device, fovy=50.0)
    recon_model.eval()
    return recon_model_config, recon_model


def _make_bundle_grid(rgb_views, normal_views):
    """
    Tile RGB views followed by normal views into one image batch of shape (1, 3, H, W).

    make_grid uses nrow=4 and no padding, so with four views each the RGB views fill the
    first row and the normals the second. Normals are mapped with (n + 1) / 2, presumably
    from [-1, 1] into the [0, 1] image range (assumes lrm_reconstruct's convention; confirm).
    """
    return torchvision.utils.make_grid(
        torch.cat([rgb_views.cpu(), (normal_views.cpu() + 1) / 2], dim=0),
        nrow=4,
        padding=0,
    ).unsqueeze(0)


def init_wrapper_from_config(config_path):
    """
    Load every model described by the YAML file at ``config_path`` and wrap them.

    The config must contain the sections 'flux', 'multiview', 'caption' and
    'reconstruction'; the keys read from each are visible in the _load_* helpers.
    return: a ``kiss3d_wrapper`` holding the config and all loaded models
    """
    with open(config_path, 'r', encoding='utf-8') as config_file:
        config_ = yaml.load(config_file, yaml.FullLoader)

    # init flux_pipeline
    logger.info('==> Loading Flux model ...')
    flux_pipe = _load_flux_pipeline(config_['flux'])

    # TODO: load redux model
    # FluxPriorReduxPipeline.from_pretrained()
    # TODO: load pulid model

    # init multiview model
    logger.info('==> Loading multiview diffusion model ...')
    multiview_pipeline = _load_multiview_pipeline(config_['multiview'])

    # load caption model
    logger.info('==> Loading caption model ...')
    caption_processor, caption_model = _load_caption_model(config_['caption'])

    # load reconstruction model
    logger.info('==> Loading reconstruction model ...')
    recon_model_config, recon_model = _load_reconstruction_model(config_['reconstruction'])

    return kiss3d_wrapper(
        config=config_,
        flux_pipeline=flux_pipe,
        multiview_pipeline=multiview_pipeline,
        caption_processor=caption_processor,
        caption_model=caption_model,
        reconstruction_model_config=recon_model_config,
        reconstruction_model=recon_model,
    )


class kiss3d_wrapper(object):
    """
    Holds the loaded Kiss3D models plus their config and exposes each pipeline stage.

    Every run is tagged with ``self.uuid`` (see ``renew_uuid``); intermediate files written
    to TMP_DIR are prefixed with it.
    """

    def __init__(self,
                 config: Dict,
                 flux_pipeline: Union[FluxPipeline, FluxControlNetImg2ImgPipeline, FluxImg2ImgPipeline],
                 multiview_pipeline: DiffusionPipeline,
                 caption_processor: AutoProcessor,
                 caption_model: AutoModelForCausalLM,
                 reconstruction_model_config: Any,
                 reconstruction_model: Any,
                 ):
        self.config = config
        self.flux_pipeline = flux_pipeline
        self.multiview_pipeline = multiview_pipeline
        self.caption_model = caption_model
        self.caption_processor = caption_processor
        self.recon_model_config = reconstruction_model_config
        self.recon_model = reconstruction_model

        self.renew_uuid()

    def renew_uuid(self):
        """Assign a fresh random UUID used to name this run's intermediate files."""
        self.uuid = uuid.uuid4()

    def context(self):
        """
        Return the context manager wrapped around every model call.

        With ``use_zero_gpu`` truthy in the config this is ``spaces.GPU()`` (imported lazily so
        the ``spaces`` package is only needed in that mode); otherwise ``torch.no_grad()``.
        A missing ``use_zero_gpu`` key is treated as False.
        """
        # NOTE(review): spaces.GPU is usually applied as a decorator; confirm it also works
        # as a context manager here.
        if self.config.get('use_zero_gpu', False):
            import spaces
            return spaces.GPU()
        return torch.no_grad()

    def get_image_caption(self, image):
        """
        Caption an image with the loaded caption model ("<MORE_DETAILED_CAPTION>" task).

        image: PIL image, or a path (str / os.PathLike) to an image file
        return: the parsed caption with every "The image is " occurrence removed
        raises: NotImplementedError for any other input type
        """
        torch_dtype = torch.bfloat16
        caption_device = self.config['caption'].get('device', 'cpu')

        if isinstance(image, (str, os.PathLike)):
            # If image is a file path
            image = Image.open(image).convert("RGB")
        elif isinstance(image, Image.Image):
            # `Image` is the PIL module; the image class is `Image.Image`.
            image = image.convert("RGB")
        else:
            raise NotImplementedError('unexpected image type')

        prompt = "<MORE_DETAILED_CAPTION>"
        inputs = self.caption_processor(text=prompt, images=image, return_tensors="pt").to(caption_device, torch_dtype)
        generated_ids = self.caption_model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            num_beams=3,
        )
        # Special tokens are kept in the decoded text and handed to post_process_generation.
        generated_text = self.caption_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        parsed_answer = self.caption_processor.post_process_generation(
            generated_text, task=prompt, image_size=(image.width, image.height)
        )
        caption_text = parsed_answer["<MORE_DETAILED_CAPTION>"].replace("The image is ", "")
        return caption_text

    def generate_multiview(self, image):
        """Run the multiview diffusion pipeline on ``image`` and return its first 1024x1024 output."""
        with self.context():
            mv_image = self.multiview_pipeline(
                image,
                num_inference_steps=self.config['multiview']['num_inference_steps'],
                width=512 * 2,
                height=512 * 2,
            ).images[0]
        return mv_image

    def reconstruct_from_multiview(self, mv_image):
        """
        Reconstruct a mesh from a 2x2 multiview grid image with the LRM model.

        mv_image: PIL.Image holding a 2x2 grid of views (1024x1024 from generate_multiview)
        return: (vertices, faces, normals, rgb, albedo) exactly as returned by lrm_reconstruct
        """
        recon_device = self.config['reconstruction'].get('device', 'cpu')
        rgb_multi_view = np.asarray(mv_image, dtype=np.float32) / 255.0
        # (H, W, C) -> (C, H, W), e.g. (3, 1024, 1024) for generate_multiview's output
        rgb_multi_view = torch.from_numpy(rgb_multi_view).squeeze(0).permute(2, 0, 1).contiguous().float()
        # split the 2x2 grid into 4 views and add a batch dim: (1, 4, C, H/2, W/2)
        rgb_multi_view = rearrange(rgb_multi_view, 'c (n h) (m w) -> (n m) c h w', n=2, m=2).unsqueeze(0).to(recon_device)

        with self.context():
            vertices, faces, lrm_multi_view_normals, lrm_multi_view_rgb, lrm_multi_view_albedo = \
                lrm_reconstruct(self.recon_model, self.recon_model_config.infer_config,
                                rgb_multi_view, name=self.uuid)

        return vertices, faces, lrm_multi_view_normals, lrm_multi_view_rgb, lrm_multi_view_albedo

    def generate_reference_3D_bundle_image_zero123(self, image, save_intermediate_results=True):
        """
        Build a reference 3D bundle image (LRM-rendered RGB row + normal row) from one image.

        input: image, PIL.Image
        return: ref_3D_bundle_image, Tensor in [0, 1] of shape (1, 3, H, W), presumably
                (1, 3, 1024, 2048) for 512x512 LRM renders. When ``save_intermediate_results``
                is True, returns the tuple (ref_3D_bundle_image, save_path) instead.
        """
        mv_image = self.generate_multiview(image)

        if save_intermediate_results:
            mv_image.save(os.path.join(TMP_DIR, f'{self.uuid}_mv_image.png'))

        vertices, faces, lrm_multi_view_normals, lrm_multi_view_rgb, lrm_multi_view_albedo = \
            self.reconstruct_from_multiview(mv_image)

        ref_3D_bundle_image = _make_bundle_grid(lrm_multi_view_rgb, lrm_multi_view_normals)  # range [0, 1]

        if save_intermediate_results:
            save_path = os.path.join(TMP_DIR, f'{self.uuid}_ref_3d_bundle_image.png')
            torchvision.utils.save_image(ref_3D_bundle_image, save_path)
            logger.info(f"Save reference 3D bundle image to {save_path}")
            return ref_3D_bundle_image, save_path

        return ref_3D_bundle_image

    def _base_flux_hparams(self, prompt, image, strength, num_inference_steps, lora_scale):
        """Build the Flux call arguments shared by the text and ControlNet bundle generators."""
        flux_device = self.config['flux'].get('device', 'cpu')
        seed = self.config['flux'].get('seed', 0)
        generator = torch.Generator(device=flux_device).manual_seed(seed)

        if image is None:
            # All-zero placeholder init image when the caller supplies none. `is None` is used
            # instead of `or` because truth-testing a multi-element tensor raises.
            image = torch.zeros((1, 3, 1024, 2048), dtype=torch.float32, device=flux_device)

        return {
            'prompt': ' '.join([_BUNDLE_PROMPT_PREFIX, prompt]),
            'image': image,
            'strength': strength,
            'num_inference_steps': num_inference_steps,
            'guidance_scale': 3.5,
            'num_images_per_prompt': 1,
            'width': 2048,
            'height': 1024,
            'output_type': 'np',
            'generator': generator,
            'joint_attention_kwargs': {"scale": lora_scale},
        }

    def _finalize_bundle_image(self, gen_3d_bundle_image, save_intermediate_results):
        """
        Convert Flux's numpy output (1, H, W, 3) into a (3, H, W) float tensor; optionally save it.

        return: the tensor, or (tensor, save_path) when ``save_intermediate_results`` is True
        """
        gen_3d_bundle_image_ = torch.from_numpy(gen_3d_bundle_image).squeeze(0).permute(2, 0, 1).contiguous().float()  # (3, 1024, 2048)

        if save_intermediate_results:
            save_path = os.path.join(TMP_DIR, f'{self.uuid}_gen_3d_bundle_image.png')
            torchvision.utils.save_image(gen_3d_bundle_image_, save_path)
            logger.info(f"Save generated 3D bundle image to {save_path}")
            return gen_3d_bundle_image_, save_path

        return gen_3d_bundle_image_

    def generate_3d_bundle_image_controlnet(self,
                                            prompt,
                                            image=None,
                                            strength=1.0,
                                            control_image=None,
                                            control_mode=None,
                                            control_guidance_start=None,
                                            control_guidance_end=None,
                                            controlnet_conditioning_scale=None,
                                            lora_scale=1.0,
                                            save_intermediate_results=True,
                                            **kwargs):
        """
        Generate a 3D bundle image with Flux, optionally guided by one or more ControlNet inputs.

        control_image / control_mode: parallel lists (default: none); each mode is a key of
            control_mode_dict below. When non-empty, the pipeline's ControlNet is replaced by a
            FluxMultiControlNetModel holding one copy of its first net per control image
            (this change persists on ``self.flux_pipeline``).
        control_guidance_start / control_guidance_end / controlnet_conditioning_scale: per-image
            lists; falsy values default to 0.0 / 1.0 / 1.0 for every control image.
        kwargs: extra Flux call arguments, overriding the defaults (but not the ControlNet ones).
        return: (3, 1024, 2048) tensor, or (tensor, save_path) when ``save_intermediate_results``
        """
        control_mode_dict = {
            'canny': 0,
            'tile': 1,
            'depth': 2,
            'blur': 3,
            'pose': 4,
            'gray': 5,
            'lq': 6,
        }  # for https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union only

        control_image = [] if control_image is None else control_image
        control_mode = [] if control_mode is None else control_mode

        hparam_dict = self._base_flux_hparams(prompt, image, strength, 30, lora_scale)
        hparam_dict.update(kwargs)

        # append controlnet hparams
        if len(control_image) > 0:
            assert isinstance(self.flux_pipeline, FluxControlNetImg2ImgPipeline)
            assert len(control_mode) == len(control_image)  # the count of image should be the same as control mode

            n_ctrl = len(control_image)
            flux_ctrl_net = self.flux_pipeline.controlnet.nets[0]
            self.flux_pipeline.controlnet = FluxMultiControlNetModel([flux_ctrl_net for _ in range(n_ctrl)])

            hparam_dict.update({
                'control_mode': [control_mode_dict[mode_] for mode_ in control_mode],
                'control_image': control_image,
                'control_guidance_start': control_guidance_start or [0.0] * n_ctrl,
                'control_guidance_end': control_guidance_end or [1.0] * n_ctrl,
                'controlnet_conditioning_scale': controlnet_conditioning_scale or [1.0] * n_ctrl,
            })

        with self.context():
            gen_3d_bundle_image = self.flux_pipeline(**hparam_dict).images

        return self._finalize_bundle_image(gen_3d_bundle_image, save_intermediate_results)

    def generate_3d_bundle_image_text(self,
                                      prompt,
                                      image=None,
                                      strength=1.0,
                                      lora_scale=1.0,
                                      num_inference_steps=30,
                                      save_intermediate_results=True,
                                      **kwargs):
        """
        Generate a 3D bundle image from text (optionally img2img from ``image``) without ControlNet.

        If the stored pipeline is the ControlNet variant, a plain FluxImg2ImgPipeline is built
        from its shared components for this call.
        return: gen_3d_bundle_image, torch.Tensor of shape (3, 1024, 2048), range [0., 1.];
                (tensor, save_path) when ``save_intermediate_results`` is True
        """
        if isinstance(self.flux_pipeline, FluxControlNetImg2ImgPipeline):
            flux_pipeline = FluxImg2ImgPipeline(
                scheduler=self.flux_pipeline.scheduler,
                vae=self.flux_pipeline.vae,
                text_encoder=self.flux_pipeline.text_encoder,
                tokenizer=self.flux_pipeline.tokenizer,
                text_encoder_2=self.flux_pipeline.text_encoder_2,
                tokenizer_2=self.flux_pipeline.tokenizer_2,
                transformer=self.flux_pipeline.transformer,
            )
        else:
            flux_pipeline = self.flux_pipeline

        hparam_dict = self._base_flux_hparams(prompt, image, strength, num_inference_steps, lora_scale)
        hparam_dict.update(kwargs)

        with self.context():
            gen_3d_bundle_image = flux_pipeline(**hparam_dict).images

        return self._finalize_bundle_image(gen_3d_bundle_image, save_intermediate_results)

    def reconstruct_3d_bundle_image(self, image, save_intermediate_results=True):
        """
        Reconstruct a mesh from a 3D bundle image (top row RGB views, bottom row normal views).

        image: torch.Tensor, range [0., 1.], (3, 1024, 2048)
        return: whatever isomer_reconstruct returns for save_path
                ``TMP_DIR/<uuid>_isomer_recon_mesh.obj`` (callers treat it as the mesh path)
        """
        recon_device = self.config['reconstruction'].get('device', 'cpu')

        # split rgb and normal
        images = rearrange(image, 'c (n h) (m w) -> (n m) c h w', n=2, m=4)  # (3, 1024, 2048) -> (8, 3, 512, 512)
        rgb_multi_view, normal_multi_view = images.chunk(2, dim=0)
        multi_view_mask = get_background(normal_multi_view).to(recon_device)
        # composite the RGB views onto a white background using the normal-derived mask
        rgb_multi_view = rgb_multi_view.to(recon_device) * multi_view_mask + (1 - multi_view_mask)

        with self.context():
            vertices, faces, lrm_multi_view_normals, lrm_multi_view_rgb, lrm_multi_view_albedo = \
                lrm_reconstruct(self.recon_model, self.recon_model_config.infer_config,
                                rgb_multi_view.unsqueeze(0).to(recon_device), name=self.uuid,
                                input_camera_type='kiss3d', render_3d_bundle_image=save_intermediate_results,
                                render_azimuths=[0, 90, 180, 270])

        if save_intermediate_results:
            recon_3D_bundle_image = _make_bundle_grid(lrm_multi_view_rgb, lrm_multi_view_normals)  # range [0, 1]
            # Use this instance's uuid (not a script-level global) so the method works when imported.
            torchvision.utils.save_image(
                recon_3D_bundle_image,
                os.path.join(TMP_DIR, f'{self.uuid}_lrm_recon_3d_bundle_image.png'),
            )

        recon_mesh_path = os.path.join(TMP_DIR, f"{self.uuid}_isomer_recon_mesh.obj")

        return isomer_reconstruct(rgb_multi_view=rgb_multi_view,
                                  normal_multi_view=normal_multi_view,
                                  multi_view_mask=multi_view_mask,
                                  vertices=vertices,
                                  faces=faces,
                                  save_path=recon_mesh_path)


def run_text_to_3d(k3d_wrapper, prompt, init_image_path=None):
    """
    Example text-to-3D run: generate a 3D bundle image from ``prompt``, then reconstruct it.

    init_image_path: optional image used as the img2img init image
    return: (path of the saved generated bundle image, result of reconstruct_3d_bundle_image)
    """
    # ======================================= Example of text to 3D generation ======================================

    # Renew The uuid
    k3d_wrapper.renew_uuid()

    # FOR Text to 3D (also for image to image) with init image
    init_image = None
    if init_image_path is not None:
        init_image = Image.open(init_image_path)

    gen_3d_bundle_image, gen_save_path = k3d_wrapper.generate_3d_bundle_image_text(
        prompt, image=init_image, strength=1.0, save_intermediate_results=True
    )

    # recon from 3D Bundle image
    recon_mesh_path = k3d_wrapper.reconstruct_3d_bundle_image(gen_3d_bundle_image, save_intermediate_results=False)

    return gen_save_path, recon_mesh_path


def run_image_to_3d(k3d_wrapper, init_image_path):
    """
    Example image-to-3D run: build the reference 3D bundle image and caption the input image.

    return: (path of the saved reference bundle image, caption text)
    """
    # ======================================= Example of image to 3D generation ======================================

    # Renew The uuid
    k3d_wrapper.renew_uuid()

    # FOR IMAGE TO 3D: generate reference 3D bundle image from a single input image
    input_image = Image.open(init_image_path)

    reference_3d_bundle_image, reference_save_path = k3d_wrapper.generate_reference_3D_bundle_image_zero123(input_image)
    caption = k3d_wrapper.get_image_caption(input_image)

    # TODO: the remaining image-to-3D stages are not wired up yet; a leftover pdb
    # breakpoint used to stop here, which blocked non-interactive runs.
    return reference_save_path, caption


if __name__ == "__main__":
    k3d_wrapper = init_wrapper_from_config('/hpc2hdd/home/jlin695/code/Kiss3DGen/pipeline/pipeline_config/default.yaml')

    # Example of loading existing 3D bundle Image
    # demo_image = Image.open('/hpc2hdd/home/jlin695/code/github/Kiss3DGen/outputs/tmp/ea25bc9b-d775-46bb-9827-660a9a6540c8_gen_3d_bundle_image.png')
    # gen_3d_bundle_image = torchvision.transforms.functional.to_tensor(demo_image)

    run_image_to_3d(k3d_wrapper, '/hpc2hdd/home/jlin695/code/Kiss3DGen/examples/蓝色小怪物.webp')
    # run_text_to_3d(k3d_wrapper, prompt='A doll of a girl in Harry Potter')