import os
import shlex
import subprocess
import sys

import imageio
import numpy as np
import gradio as gr
import spaces
from loguru import logger

current_path = os.path.dirname(os.path.abspath(__file__))
MAX_SEED = np.iinfo(np.int32).max
TMP_DIR = os.path.join(current_path, 'out')
os.makedirs(TMP_DIR, exist_ok=True)
TAG = "gsdiff_gobj83k_sd15__render"
MAX_NAME_LEN = 20  # TODO: make `20` configurable

# subprocess.run(shlex.split("pip install extensions/RaDe-GS/submodules/diff-gaussian-rasterization/"))

# download checkpoints
# subprocess.run(shlex.split("python3 download_ckpt.py --model_type pas"))  # for txt condition
# subprocess.run(shlex.split("python3 download_ckpt.py --model_type pas --image_cond"))  # for img condition
# img_commands = "PYTHONPATH=./ bash scripts/infer.sh src/infer_gsdiff_pas.py configs/gsdiff_pas.yaml {} \
#     --rembg_and_center --triangle_cfg_scaling --save_ply --output_video_type mp4 --guidance_scale {} \
#     --image_path {} --elevation {} --prompt {} --seed {}"
# txt_commands = "PYTHONPATH=./ bash scripts/infer.sh src/infer_gsdiff_pas.py configs/gsdiff_pas.yaml \
#     --rembg_and_center --triangle_cfg_scaling --save_ply --output_video_type mp4 --guidance_scale {} \
#     --image_path {} --elevation {} --prompt {} --seed {}"

# SD1.5
subprocess.run(shlex.split("python3 download_ckpt.py --model_type sd15"))  # for txt condition
# subprocess.run(shlex.split("python3 download_ckpt.py --model_type sd15 --image_cond"))  # for img condition
# NOTE(review): only the text-conditioned checkpoint is downloaded and TAG names a single run,
# yet image inputs are also routed to TAG — confirm image-to-3D is using the intended model.

# Shell command templates. Placeholders, in order:
#   img: TAG, guidance_scale, image_path, elevation, prompt, seed
#   txt: TAG, guidance_scale, elevation, prompt, seed
# Every value substituted into them must be passed through `_shell_arg` first.
img_commands = (
    "PYTHONPATH=./ bash scripts/infer.sh src/infer_gsdiff_sd.py configs/gsdiff_sd15.yaml {} "
    "--rembg_and_center --triangle_cfg_scaling --save_ply --output_video_type mp4 --guidance_scale {} "
    "--image_path {} --elevation {} --prompt {} --seed {}"
)
txt_commands = (
    "PYTHONPATH=./ bash scripts/infer.sh src/infer_gsdiff_sd.py configs/gsdiff_sd15.yaml {} "
    "--rembg_and_center --save_ply --output_video_type mp4 --guidance_scale {} "
    "--elevation {} --prompt {} --seed {}"
)


def _shell_arg(value):
    """Render *value* as a single, safely quoted shell word.

    Strings made only of shell-safe characters (e.g. ``a_toy_robot``, ``2.0``)
    are returned unchanged, so existing underscore-style prompts produce the
    same command as before; anything else is quoted so user input cannot inject
    shell syntax or split into several arguments.
    """
    return shlex.quote(str(value))


def _latest_checkpoint_iter(ckpt_dir):
    """Return the largest numeric checkpoint iteration found in *ckpt_dir*.

    Entries are compared as integers (not strings), and non-numeric entries are
    ignored.

    Raises:
        gr.Error: if *ckpt_dir* contains no numeric entries.
    """
    iters = [int(entry) for entry in os.listdir(ckpt_dir) if entry.isdecimal()]
    if not iters:
        raise gr.Error(f"No checkpoints found in {ckpt_dir}")
    return max(iters)


# process function
@spaces.GPU
def process(input_image, prompt='a_high_quality_3D_asset',
            prompt_neg='ugly, blurry, pixelated obscure, unnatural colors, poor lighting, dull, unclear, cropped, lowres, low quality, artifacts, duplicate',
            input_elevation=20, guidance_scale=2., input_seed=0):
    """Run DiffSplat inference for one request and return its outputs.

    Args:
        input_image: image from the UI (``gr.Image(type='pil')``), or ``None``
            for text-only generation. When given, it is saved to
            ``TMP_DIR/input_image.png`` and passed via ``--image_path``.
        prompt: text prompt; underscores are turned into spaces when the output
            file name is reconstructed below.
        prompt_neg: negative prompt. NOTE(review): accepted from the UI but never
            forwarded to the inference command — confirm whether the inference
            script supports a negative-prompt flag.
        input_elevation: forwarded as ``--elevation``.
        guidance_scale: forwarded as ``--guidance_scale``.
        input_seed: forwarded as ``--seed``.

    Returns:
        Tuple of (image read from ``<name>_gs.png``, path of the ``.mp4`` video,
        path of the ``.ply`` file), all under ``TMP_DIR/TAG/inference``.

    Raises:
        gr.Error: if no checkpoint iteration is found or the expected output
            image was not produced.
    """
    if input_image is not None:
        image_path = os.path.join(TMP_DIR, "input_image.png")
        image_name = os.path.splitext(os.path.basename(image_path))[0]
        input_image.save(image_path)
        args = (TAG, guidance_scale, image_path, input_elevation, prompt, input_seed)
        full_command = img_commands.format(*map(_shell_arg, args))
    else:
        args = (TAG, guidance_scale, input_elevation, prompt, input_seed)
        full_command = txt_commands.format(*map(_shell_arg, args))
        image_name = ""

    exit_status = os.system(full_command)
    if exit_status != 0:
        # Not fatal by itself: the output check below decides whether to fail.
        logger.error("Inference command exited with status {}: {}", exit_status, full_command)

    # save video and ply files
    ckpt_dir = os.path.join(TMP_DIR, TAG, "checkpoints")
    infer_from_iter = _latest_checkpoint_iter(ckpt_dir)

    # Presumably mirrors the file naming used by the inference script — keep in sync.
    prompt = prompt.replace("_", " ")
    prompt_name = prompt[:MAX_NAME_LEN] + "..." if prompt else prompt
    name = f"[{image_name}]_[{prompt_name}]_{infer_from_iter:06d}"

    inference_dir = os.path.join(TMP_DIR, TAG, "inference")
    output_video_path = os.path.join(inference_dir, name + ".mp4")
    output_ply_path = os.path.join(inference_dir, name + ".ply")
    output_img_path = os.path.join(inference_dir, name + "_gs.png")
    # Use an explicit template: loguru treats the first argument as a format string.
    logger.info("Command: {} | video: {} | ply: {}", full_command, output_video_path, output_ply_path)

    if not os.path.exists(output_img_path):
        raise gr.Error("3D generation failed; please check the server logs for details.")
    output_image = imageio.imread(output_img_path)
    return output_image, output_video_path, output_ply_path


# gradio UI

_TITLE = '''DiffSplat: Repurposing Image Diffusion Models for Scalable Gaussian Splat Generation'''

_DESCRIPTION = '''
### If you find our work helpful, please consider citing our paper 📚 or giving the repo a star 🌟
* Input can be only text, only image, or both image and text.
* If you find the generated 3D asset satisfactory, download the "3D Gaussians (ply format)" file from the output panel.
* Upload an image and click "Generate" to create a 3D asset. If the image has an alpha channel, it will be used as the mask. Otherwise, we use `rembg` to remove the background.
'''

block = gr.Blocks(title=_TITLE).queue()
with block:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown('# ' + _TITLE)
    gr.Markdown(_DESCRIPTION)

    with gr.Row(variant='panel'):
        with gr.Column(scale=1):
            # input image
            input_image = gr.Image(label="image", type='pil')
            # input prompt
            input_text = gr.Textbox(label="prompt", value="a_high_quality_3D_asset")
            # negative prompt
            input_neg_text = gr.Textbox(label="negative prompt", value="")
            # guidance_scale
            guidance_scale = gr.Slider(label="guidance scale", minimum=1., maximum=7.5, step=0.5, value=2.0)
            # elevation
            input_elevation = gr.Slider(label="elevation", minimum=-90, maximum=90, step=1, value=20)
            # # inference steps
            # input_num_steps = gr.Slider(label="inference steps", minimum=1, maximum=100, step=1, value=30)
            # random seed
            input_seed = gr.Slider(label="random seed", minimum=0, maximum=100000, step=1, value=0)
            # gen button
            button_gen = gr.Button("Generate")

        with gr.Column(scale=1):
            with gr.Tab("Video"):
                # final video results
                output_video = gr.Video(label="video")
                # ply file
                output_file = gr.File(label="3D Gaussians (ply format)")
            with gr.Tab("Splatter Images"):
                output_image = gr.Image(interactive=False, show_label=False)

    button_gen.click(
        process,
        inputs=[input_image, input_text, input_neg_text, input_elevation, guidance_scale, input_seed],
        outputs=[output_image, output_video, output_file],
    )

    gr.Examples(
        # sorted() gives a deterministic example order (os.listdir order is arbitrary).
        examples=[
            f'assets/diffsplat/{image}'
            for image in sorted(os.listdir("assets/diffsplat"))
            if image.endswith('.png')
        ],
        inputs=[input_image],
        outputs=[output_image, output_video, output_file],
        fn=lambda x: process(input_image=x),
        # cache_examples=True,
        run_on_click=True,
        label='Image-to-3D Examples',
    )

    gr.Examples(
        examples=[
            "a_toy_robot",
            "a_cute_panda",
            "a_book",
        ],
        inputs=[input_text],
        outputs=[output_image, output_video, output_file],
        fn=lambda x: process(input_image=None, prompt=x),
        # cache_examples=True,
        run_on_click=True,
        label='Text-to-3D Examples',
    )

# Launch the Gradio app
if __name__ == "__main__":
    block.launch(share=True)