trellisapitest / app.py
theekshana03289's picture
Create app.py
690d8e3 verified
from gradio_client import Client, handle_file
import gradio as gr
# Initialize Gradio Client
# One module-level client is created at import time and reused by every helper
# function in this file.
# NOTE(review): because this single client is shared by all visitors of this
# app, any per-session state kept by the remote Space would also be shared
# between them -- confirm whether that is acceptable.
client = Client("JeffreyXiang/TRELLIS")  # Replace with your Hugging Face Space
# Helper Functions for API Calls
def start_session():
    """Invoke the Space's ``/start_session`` endpoint.

    Returns:
        The endpoint's response, passed through unchanged.
    """
    return client.predict(api_name="/start_session")
def preprocess_image(image):
    """Send a single image to the Space's ``/preprocess_image`` endpoint.

    Args:
        image: Image location in a form accepted by ``handle_file``.

    Returns:
        The endpoint's response, passed through unchanged.
    """
    upload = handle_file(image)
    return client.predict(image=upload, api_name="/preprocess_image")
def preprocess_images(images):
    """Send several images to the Space's ``/preprocess_images`` endpoint.

    Args:
        images: Iterable of image locations, each in a form accepted by
            ``handle_file``.

    Returns:
        The endpoint's response, passed through unchanged.
    """
    # Each entry is sent as an image/caption pair; no caption is supplied.
    gallery = [{"image": handle_file(path), "caption": None} for path in images]
    return client.predict(images=gallery, api_name="/preprocess_images")
def get_seed(randomize_seed, seed):
    """Ask the Space's ``/get_seed`` endpoint for the seed to use.

    Args:
        randomize_seed: Forwarded as the endpoint's ``randomize_seed`` argument.
        seed: Forwarded as the endpoint's ``seed`` argument.

    Returns:
        The endpoint's response, passed through unchanged.
    """
    request = {"randomize_seed": randomize_seed, "seed": seed}
    return client.predict(api_name="/get_seed", **request)
def image_to_3d(image, seed, ss_guidance_strength, ss_sampling_steps,
                slat_guidance_strength, slat_sampling_steps, multiimage_algo):
    """Generate a 3D model preview by calling the Space's ``/image_to_3d`` endpoint.

    Args:
        image: Location of the input image, in a form accepted by
            ``handle_file``. The UI supplies a file path, or ``None`` when
            nothing has been uploaded.
        seed: Forwarded as the endpoint's ``seed`` argument.
        ss_guidance_strength: Forwarded unchanged to the endpoint.
        ss_sampling_steps: Forwarded unchanged to the endpoint.
        slat_guidance_strength: Forwarded unchanged to the endpoint.
        slat_sampling_steps: Forwarded unchanged to the endpoint.
        multiimage_algo: Forwarded unchanged to the endpoint.

    Returns:
        The value stored under the ``"video"`` key of the endpoint's response
        (the UI feeds it to a ``gr.Video`` component).

    Raises:
        gr.Error: If no image was provided.
    """
    # Fail fast with a readable message instead of letting handle_file choke
    # on a missing path when the user clicks Generate without an upload.
    if not image:
        raise gr.Error("Please upload an image before generating a 3D model.")

    result = client.predict(
        image=handle_file(image),
        multiimages=[],  # this single-image flow sends an empty list
        seed=seed,
        ss_guidance_strength=ss_guidance_strength,
        ss_sampling_steps=ss_sampling_steps,
        slat_guidance_strength=slat_guidance_strength,
        slat_sampling_steps=slat_sampling_steps,
        multiimage_algo=multiimage_algo,
        api_name="/image_to_3d",
    )
    return result["video"]
def extract_glb(mesh_simplify, texture_size):
    """Request a GLB export from the Space's ``/extract_glb`` endpoint.

    Args:
        mesh_simplify: Forwarded as the endpoint's ``mesh_simplify`` argument.
        texture_size: Forwarded as the endpoint's ``texture_size`` argument.

    Returns:
        The element at index 1 of the endpoint's response, which this app
        uses as the GLB file offered for download.
    """
    response = client.predict(
        mesh_simplify=mesh_simplify,
        texture_size=texture_size,
        api_name="/extract_glb",
    )
    return response[1]
def extract_gaussian():
    """Request a Gaussian export from the Space's ``/extract_gaussian`` endpoint.

    Returns:
        The element at index 1 of the endpoint's response, which this app
        uses as the Gaussian file offered for download.
    """
    response = client.predict(api_name="/extract_gaussian")
    return response[1]
# Define Gradio UI
with gr.Blocks() as demo:
    gr.Markdown(value="# Image to 3D Model with TRELLIS API")

    with gr.Row():
        # Input side: the source image plus every generation setting.
        with gr.Column():
            image = gr.Image(label="Upload Image", type="filepath")
            seed = gr.Slider(minimum=0, maximum=100, value=0, step=1, label="Seed")
            ss_guidance_strength = gr.Slider(
                minimum=0.0, maximum=10.0, value=7.5, step=0.1,
                label="SS Guidance Strength",
            )
            ss_sampling_steps = gr.Slider(
                minimum=1, maximum=50, value=12, step=1,
                label="SS Sampling Steps",
            )
            slat_guidance_strength = gr.Slider(
                minimum=0.0, maximum=10.0, value=3.0, step=0.1,
                label="SLAT Guidance Strength",
            )
            slat_sampling_steps = gr.Slider(
                minimum=1, maximum=50, value=12, step=1,
                label="SLAT Sampling Steps",
            )
            multiimage_algo = gr.Radio(
                choices=["stochastic", "multidiffusion"],
                value="stochastic",
                label="Multi-image Algorithm",
            )
            generate_btn = gr.Button(value="Generate 3D Model")

        # Output side: the preview video and the two export controls.
        with gr.Column():
            video_output = gr.Video(label="3D Model Preview")
            download_glb_btn = gr.Button(value="Download GLB")
            download_gaussian_btn = gr.Button(value="Download Gaussian")
            glb_file = gr.File(label="GLB File")
            gaussian_file = gr.File(label="Gaussian File")

    # Component order here must match the parameter order of image_to_3d.
    generation_inputs = [
        image,
        seed,
        ss_guidance_strength,
        ss_sampling_steps,
        slat_guidance_strength,
        slat_sampling_steps,
        multiimage_algo,
    ]
    generate_btn.click(
        fn=image_to_3d,
        inputs=generation_inputs,
        outputs=[video_output],
    )

    # The GLB export takes no UI inputs; it always uses these fixed settings.
    download_glb_btn.click(
        fn=lambda: extract_glb(mesh_simplify=0.95, texture_size=1024),
        inputs=[],
        outputs=[glb_file],
    )

    # The Gaussian export needs no arguments, so the helper is wired directly.
    download_gaussian_btn.click(
        fn=extract_gaussian,
        inputs=[],
        outputs=[gaussian_file],
    )
# Launch Gradio App
# The server is started only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    # show_error=True asks Gradio to surface handler errors in the browser UI.
    demo.launch(show_error=True)