# Hugging Face Space (status when captured: Sleeping)
import gradio as gr
from gradio_client import Client, handle_file
# Shared gradio_client connection used by every helper below.
TRELLIS_SPACE_ID = "JeffreyXiang/TRELLIS"  # Replace with your Hugging Face Space
client = Client(TRELLIS_SPACE_ID)
# Helper functions for API calls
def start_session():
    """Call the Space's ``/start_session`` endpoint and return its response."""
    return client.predict(api_name="/start_session")
def preprocess_image(image):
    """Send a single image to the Space's ``/preprocess_image`` endpoint.

    Args:
        image: Image reference that is passed through ``handle_file``
            before being sent.

    Returns:
        The endpoint's response, unchanged.
    """
    uploaded = handle_file(image)
    return client.predict(image=uploaded, api_name="/preprocess_image")
def preprocess_images(images):
    """Send several images to the Space's ``/preprocess_images`` endpoint.

    Args:
        images: Iterable of image references; each one is passed through
            ``handle_file`` and sent with a ``None`` caption.

    Returns:
        The endpoint's response, unchanged.
    """
    payload = []
    for source in images:
        payload.append({"image": handle_file(source), "caption": None})
    return client.predict(images=payload, api_name="/preprocess_images")
def get_seed(randomize_seed, seed):
    """Forward the seed settings to the Space's ``/get_seed`` endpoint.

    Args:
        randomize_seed: Sent as the endpoint's ``randomize_seed`` argument.
        seed: Sent as the endpoint's ``seed`` argument.

    Returns:
        The endpoint's response, unchanged.
    """
    seed_args = {"randomize_seed": randomize_seed, "seed": seed}
    return client.predict(api_name="/get_seed", **seed_args)
def image_to_3d(image, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo):
    """Generate a 3D preview video from one image via ``/image_to_3d``.

    Args:
        image: Image reference; the UI supplies a file path. It is passed
            through ``handle_file`` before being sent.
        seed: Sent as the endpoint's ``seed`` argument.
        ss_guidance_strength: Sent as ``ss_guidance_strength``.
        ss_sampling_steps: Sent as ``ss_sampling_steps``.
        slat_guidance_strength: Sent as ``slat_guidance_strength``.
        slat_sampling_steps: Sent as ``slat_sampling_steps``.
        multiimage_algo: Sent as ``multiimage_algo``; the multi-image list
            itself is always sent empty.

    Returns:
        The ``"video"`` entry of the endpoint's response.

    Raises:
        ValueError: If no image was supplied (e.g. "Generate" was clicked
            before uploading anything).
    """
    # Fail early with a readable message instead of letting ``handle_file``
    # choke on ``None``; the app launches with ``show_error=True`` so the
    # message reaches the user.
    if not image:
        raise ValueError("Please upload an image before generating a 3D model.")

    # The file defines a /start_session wrapper but never calls it.
    # NOTE(review): assumes the remote endpoint is idempotent and prepares
    # the per-session state that generation relies on -- confirm against the
    # Space's API page.
    start_session()

    result = client.predict(
        image=handle_file(image),
        multiimages=[],
        seed=seed,
        ss_guidance_strength=ss_guidance_strength,
        ss_sampling_steps=ss_sampling_steps,
        slat_guidance_strength=slat_guidance_strength,
        slat_sampling_steps=slat_sampling_steps,
        multiimage_algo=multiimage_algo,
        api_name="/image_to_3d",
    )
    return result["video"]
def extract_glb(mesh_simplify, texture_size):
    """Request a GLB export from the Space's ``/extract_glb`` endpoint.

    Args:
        mesh_simplify: Sent as the endpoint's ``mesh_simplify`` argument.
        texture_size: Sent as the endpoint's ``texture_size`` argument.

    Returns:
        The second element of the endpoint's response, which the UI feeds
        into the GLB file output for download.
    """
    outputs = client.predict(
        api_name="/extract_glb",
        mesh_simplify=mesh_simplify,
        texture_size=texture_size,
    )
    # Index 1 is the entry used as the downloadable GLB file path.
    return outputs[1]
def extract_gaussian():
    """Request a Gaussian export from the Space's ``/extract_gaussian`` endpoint.

    Returns:
        The second element of the endpoint's response, which the UI feeds
        into the Gaussian file output for download.
    """
    outputs = client.predict(api_name="/extract_gaussian")
    # Index 1 is the entry used as the downloadable Gaussian file path.
    return outputs[1]
# Gradio UI: generation controls on the left, preview and downloads on the right.
with gr.Blocks() as demo:
    gr.Markdown(value="# Image to 3D Model with TRELLIS API")

    with gr.Row():
        # Input column: source image plus every generation setting.
        with gr.Column():
            image = gr.Image(label="Upload Image", type="filepath")
            seed = gr.Slider(minimum=0, maximum=100, value=0, step=1, label="Seed")
            ss_guidance_strength = gr.Slider(
                minimum=0.0, maximum=10.0, value=7.5, step=0.1,
                label="SS Guidance Strength",
            )
            ss_sampling_steps = gr.Slider(
                minimum=1, maximum=50, value=12, step=1,
                label="SS Sampling Steps",
            )
            slat_guidance_strength = gr.Slider(
                minimum=0.0, maximum=10.0, value=3.0, step=0.1,
                label="SLAT Guidance Strength",
            )
            slat_sampling_steps = gr.Slider(
                minimum=1, maximum=50, value=12, step=1,
                label="SLAT Sampling Steps",
            )
            multiimage_algo = gr.Radio(
                choices=["stochastic", "multidiffusion"],
                value="stochastic",
                label="Multi-image Algorithm",
            )
            generate_btn = gr.Button(value="Generate 3D Model")

        # Output column: preview video, export buttons and their file slots.
        with gr.Column():
            video_output = gr.Video(label="3D Model Preview")
            download_glb_btn = gr.Button(value="Download GLB")
            download_gaussian_btn = gr.Button(value="Download Gaussian")
            glb_file = gr.File(label="GLB File")
            gaussian_file = gr.File(label="Gaussian File")

    # Generation: the list order must match image_to_3d's parameter order.
    generation_inputs = [
        image,
        seed,
        ss_guidance_strength,
        ss_sampling_steps,
        slat_guidance_strength,
        slat_sampling_steps,
        multiimage_algo,
    ]
    generate_btn.click(
        fn=image_to_3d,
        inputs=generation_inputs,
        outputs=[video_output],
    )

    # Downloads: neither export reads any UI component.
    download_glb_btn.click(
        # Kept as a lambda so the handler's auto-derived name is unchanged;
        # the export settings are fixed values.
        fn=lambda: extract_glb(mesh_simplify=0.95, texture_size=1024),
        inputs=[],
        outputs=[glb_file],
    )
    download_gaussian_btn.click(
        fn=extract_gaussian,
        inputs=[],
        outputs=[gaussian_file],
    )
# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(show_error=True)