Spaces:
Paused
Paused
| import spaces | |
| import os | |
| import subprocess | |
| import tempfile | |
| import uuid | |
| import glob | |
| import shutil | |
| import time | |
| import gradio as gr | |
| import sys | |
| from PIL import Image | |
def _prepend_env_path(var, entry):
    """Prepend *entry* to the ``os.pathsep``-separated environment variable *var*.

    No separator is added when *var* is unset or empty. A naive
    ``f"{entry}:{old}"`` would leave an empty trailing element, which the
    loader/shell interpret as "current directory".
    """
    current = os.environ.get(var, "")
    os.environ[var] = f"{entry}{os.pathsep}{current}" if current else entry


def install_cuda_toolkit(
    url="https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run",
    cuda_home="/usr/local/cuda",
):
    """Download and silently install the CUDA toolkit, then export its paths.

    The NVIDIA ``.run`` installer is fetched with ``wget`` into ``/tmp``, made
    executable and run with ``--silent --toolkit``. Afterwards ``CUDA_HOME``,
    ``PATH``, ``LD_LIBRARY_PATH`` and ``TORCH_CUDA_ARCH_LIST`` are updated in
    ``os.environ`` (and so inherited by subprocesses started later).

    Args:
        url: Location of the installer; defaults to CUDA 12.1.0.
        cuda_home: Directory exported as ``CUDA_HOME`` (the toolkit's expected
            install location).

    Raises:
        subprocess.CalledProcessError: If the download or the installer exits
            with a non-zero status.
        OSError: If the downloaded installer cannot be made executable.
    """
    installer = os.path.join("/tmp", os.path.basename(url))
    # check=True: a failed download/install used to be ignored and only showed
    # up later as an opaque error when CUDA code was built or loaded.
    subprocess.run(["wget", "-q", url, "-O", installer], check=True)
    os.chmod(installer, os.stat(installer).st_mode | 0o111)  # chmod +x
    subprocess.run([installer, "--silent", "--toolkit"], check=True)

    os.environ["CUDA_HOME"] = cuda_home
    _prepend_env_path("PATH", os.path.join(cuda_home, "bin"))
    # NOTE(review): on x86_64 Linux the CUDA libraries usually live in
    # <cuda_home>/lib64; "lib" is kept to match previous behavior — confirm.
    _prepend_env_path("LD_LIBRARY_PATH", os.path.join(cuda_home, "lib"))
    # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
    # Pinning the arch list avoids that error; "9.0" must match the compute
    # capability of the GPU this runs on.
    os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0"
    print("==> finished installation")
# Install the toolkit at import time, before the torch / pytorch3d / pixel3dmm
# imports below, so CUDA_HOME, PATH and TORCH_CUDA_ARCH_LIST are already set
# when those modules load (presumably needed for building CUDA extensions such
# as the nvdiffrast renderer — confirm).
install_cuda_toolkit()
| import os | |
| import torch | |
| import numpy as np | |
| import trimesh | |
| from pytorch3d.io import load_obj | |
| from pixel3dmm.tracking.renderer_nvdiffrast import NVDRenderer | |
| from pixel3dmm.tracking.flame.FLAME import FLAME | |
| from pixel3dmm.tracking.tracker import Tracker | |
| from pixel3dmm import env_paths | |
| from omegaconf import OmegaConf | |
# Models below are built once at import time and shared by every session;
# step4_track passes flame_model and diff_renderer to each Tracker.
DEVICE = "cuda"
# Tracking configuration loaded from the code base's configs/tracking.yaml.
base_conf = OmegaConf.load(f'{env_paths.CODE_BASE}/configs/tracking.yaml')
# Head template mesh; also handed to the renderer as obj_filename.
_mesh_file = env_paths.head_template
flame_model = FLAME(base_conf).to(DEVICE)
# Element [1] of load_obj's result (the faces).
# NOTE(review): _obj_faces is not referenced anywhere else in this file, so this
# load looks like dead work — confirm nothing else relies on it before removing.
_obj_faces = load_obj(_mesh_file)[1]
# Differentiable renderer at the configured tracking resolution (base_conf.size).
diff_renderer = NVDRenderer(
    image_size=base_conf.size,
    obj_filename=_mesh_file,
    no_sh=False,
    white_bg=True
).to(DEVICE)
def first_image_from_dir(directory):
    """Return the lexicographically smallest image path inside *directory*.

    Only files matching the glob patterns ``*.jpg``, ``*.png`` and ``*.jpeg``
    are considered. Returns None when nothing matches, which includes the case
    where *directory* does not exist.
    """
    matches = [
        path
        for pattern in ("*.jpg", "*.png", "*.jpeg")
        for path in glob.glob(os.path.join(directory, pattern))
    ]
    return min(matches) if matches else None
def reset_all():
    """Clear every result and restore the initial button states after an upload.

    Returns:
        A 10-tuple in the order of ``outputs_for_reset``:
        - four cleared previews (crop, normals, UV, tracking);
        - the status text;
        - an empty session state;
        - interactivity updates for the four step buttons, with only Step 1 enabled.
    """
    cleared_previews = (None,) * 4  # crop_img, normals_img, uv_img, track_img
    button_states = (
        gr.update(interactive=True),   # preprocess_btn
        gr.update(interactive=False),  # normals_btn
        gr.update(interactive=False),  # uv_map_btn
        gr.update(interactive=False),  # track_btn
    )
    return (*cleared_previews, "Awaiting new image upload...", {}, *button_states)
# Step 1: Preprocess the input image (save it, then crop it).
# @spaces.GPU()
def preprocess_image(image_array, state):
    """Save the uploaded image in a fresh session directory and run the cropping script.

    Args:
        image_array: Array from the ``gr.Image(type="numpy")`` input, or None
            when nothing has been uploaded.
        state: Per-session dict; updated in place with ``session_id``,
            ``base_dir`` and ``image_path``.

    Returns:
        Tuple ``(status, cropped_image_path_or_None, state, preprocess_btn_update,
        normals_btn_update)``. On failure the state is replaced by ``{}`` and the
        session directory is removed.
    """
    if image_array is None:
        return "❌ Please upload an image first.", None, state, gr.update(interactive=True), gr.update(interactive=False)

    session_id = str(uuid.uuid4())
    base_dir = os.path.join(os.environ["PIXEL3DMM_PREPROCESSED_DATA"], session_id)
    os.makedirs(base_dir, exist_ok=True)
    state.update({"session_id": session_id, "base_dir": base_dir})

    saved_image_path = os.path.join(base_dir, f"{session_id}.png")
    Image.fromarray(image_array).save(saved_image_path)
    state["image_path"] = saved_image_path

    try:
        # sys.executable: run the script with this app's own interpreter/venv
        # rather than whichever "python" happens to be first on PATH.
        subprocess.run(
            [sys.executable, "scripts/run_preprocessing.py", "--video_or_images_path", saved_image_path],
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        err = f"❌ Preprocess failed (exit {e.returncode}).\n\n{e.stdout}\n{e.stderr}"
        # Best-effort cleanup: a cleanup failure must not hide the real error.
        shutil.rmtree(base_dir, ignore_errors=True)
        return err, None, {}, gr.update(interactive=True), gr.update(interactive=False)

    image = first_image_from_dir(os.path.join(base_dir, "cropped"))
    return "✅ Step 1 complete. Ready for Normals.", image, state, gr.update(interactive=False), gr.update(interactive=True)
def _run_network_inference(state, prediction_type, output_subdir, failure_label, success_msg):
    """Run ``scripts/network_inference.py`` for one prediction type (shared by Steps 2 and 3).

    Args:
        state: Per-session dict; must hold ``session_id`` and ``base_dir`` from Step 1.
        prediction_type: Value passed as ``model.prediction_type=`` to the script.
        output_subdir: Folder under ``<base_dir>/p3dmm`` whose first image is shown.
        failure_label: Prefix used in the failure status message.
        success_msg: Status message shown on success.

    Returns:
        Tuple ``(status, image_path_or_None, state, this_btn_update, next_btn_update)``.
        On script failure this step's button stays enabled so the user can retry.
    """
    session_id = state.get("session_id")
    if not session_id:
        return "❌ State lost. Please start from Step 1.", None, state, gr.update(interactive=False), gr.update(interactive=False)
    try:
        # sys.executable: same interpreter/venv as the app, not PATH's "python".
        subprocess.run(
            [
                sys.executable,
                "scripts/network_inference.py",
                f"model.prediction_type={prediction_type}",
                f"video_name={session_id}",
            ],
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        err = f"❌ {failure_label} failed (exit {e.returncode}).\n\n{e.stdout}\n{e.stderr}"
        return err, None, state, gr.update(interactive=True), gr.update(interactive=False)
    image = first_image_from_dir(os.path.join(state["base_dir"], "p3dmm", output_subdir))
    return success_msg, image, state, gr.update(interactive=False), gr.update(interactive=True)


# Step 2: Normals inference → normals image
def step2_normals(state):
    """Predict the normal map for the current session and enable Step 3."""
    return _run_network_inference(
        state, "normals", "normals", "Normal map", "✅ Step 2 complete. Ready for UV Map."
    )


# Step 3: UV map inference → uv map image
def step3_uv_map(state):
    """Predict the UV map for the current session and enable Step 4."""
    return _run_network_inference(
        state, "uv_map", "uv_map", "UV map", "✅ Step 3 complete. Ready for Tracking."
    )
# Step 4: Tracking → final tracking image
def step4_track(state):
    """Run FLAME tracking for the current session and show the first tracked frame.

    Args:
        state: Per-session dict; must hold ``session_id`` set by Step 1.

    Returns:
        Tuple ``(status, tracking_image_path_or_None, state, track_btn_update)``.
        If tracking raises, the error is shown in the status box and the button
        stays enabled for a retry.
    """
    session_id = state.get("session_id")
    if not session_id:
        # Same guard as Steps 2/3. Without it an empty state set video_name to
        # 'None' and os.path.join(..., None, ...) raised TypeError.
        return "❌ State lost. Please start from Step 1.", None, state, gr.update(interactive=False)

    # Derive a per-call config rather than assigning video_name on the shared
    # module-level base_conf, so no session's name is left behind in it.
    conf = OmegaConf.merge(base_conf, {"video_name": session_id})
    try:
        Tracker(conf, flame_model, diff_renderer).run()
    except Exception as exc:  # UI boundary: log and report instead of a bare Gradio error
        import traceback

        traceback.print_exc()
        return f"❌ Tracking failed: {exc}", None, state, gr.update(interactive=True)

    tracking_dir = os.path.join(os.environ["PIXEL3DMM_TRACKING_OUTPUT"], session_id, "frames")
    image = first_image_from_dir(tracking_dir)
    return "✅ Pipeline complete!", image, state, gr.update(interactive=False)
# Build the Gradio UI: upload column on the left, a 2x2 grid of step previews
# on the right, and one button per pipeline step.
# NOTE(review): nesting was reconstructed from a copy whose indentation had been
# stripped — confirm the button row belongs at page level (not in the right column).
with gr.Blocks() as demo:
    gr.Markdown("## Image Processing Pipeline")
    gr.Markdown("Upload an image, then click the buttons in order. Uploading a new image will reset the process.")

    with gr.Row():
        # Left: input image, status line and per-session state.
        with gr.Column():
            image_in = gr.Image(label="Upload Image", type="numpy", height=512)
            status = gr.Textbox(label="Status", lines=2, interactive=False, value="Upload an image to start.")
            state = gr.State({})
        # Right: one preview per pipeline step.
        with gr.Column():
            with gr.Row():
                crop_img = gr.Image(label="Preprocessed", height=256)
                normals_img = gr.Image(label="Normals", height=256)
            with gr.Row():
                uv_img = gr.Image(label="UV Map", height=256)
                track_img = gr.Image(label="Tracking", height=256)

    # Only Step 1 starts enabled; each step's handler enables the next button.
    with gr.Row():
        preprocess_btn = gr.Button("Step 1: Preprocess", interactive=True)
        normals_btn = gr.Button("Step 2: Normals", interactive=False)
        uv_map_btn = gr.Button("Step 3: UV Map", interactive=False)
        track_btn = gr.Button("Step 4: Track", interactive=False)

    # Components cleared by reset_all, in the order of its returned tuple.
    outputs_for_reset = [crop_img, normals_img, uv_img, track_img, status, state, preprocess_btn, normals_btn, uv_map_btn, track_btn]

    # (button, handler, inputs, outputs) for each pipeline step, wired in order.
    _step_wiring = (
        (preprocess_btn, preprocess_image, [image_in, state], [status, crop_img, state, preprocess_btn, normals_btn]),
        (normals_btn, step2_normals, [state], [status, normals_img, state, normals_btn, uv_map_btn]),
        (uv_map_btn, step3_uv_map, [state], [status, uv_img, state, uv_map_btn, track_btn]),
        (track_btn, step4_track, [state], [status, track_img, state, track_btn]),
    )
    for _button, _handler, _inputs, _outputs in _step_wiring:
        _button.click(fn=_handler, inputs=_inputs, outputs=_outputs)

    # A fresh upload restarts the pipeline from Step 1.
    image_in.upload(fn=reset_all, inputs=None, outputs=outputs_for_reset)
# ------------------------------------------------------------------
# Serve the app: enable the request queue, then start the server.
# ------------------------------------------------------------------
demo.queue().launch(share=True, ssr_mode=False)