feat: add gradio ui
Browse files
app.py
CHANGED
@@ -373,156 +373,157 @@ async def remove_image_background(image: UploadFile = File(...), token: str = De
|
|
373 |
raise HTTPException(status_code=500, detail=str(e))
|
374 |
|
375 |
|
376 |
-
@app.get("/")
|
377 |
-
async def root_():
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
|
386 |
-
|
387 |
-
|
|
|
388 |
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
|
399 |
-
|
400 |
-
|
401 |
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
411 |
-
|
412 |
-
|
413 |
-
|
414 |
-
|
415 |
-
|
416 |
|
417 |
-
|
418 |
-
|
419 |
-
|
420 |
|
421 |
-
|
422 |
-
|
423 |
-
|
424 |
-
|
425 |
-
|
426 |
-
|
427 |
-
|
428 |
-
|
429 |
-
|
430 |
-
|
431 |
|
432 |
-
|
433 |
-
|
434 |
-
|
435 |
|
436 |
-
|
437 |
-
|
438 |
-
|
439 |
-
#
|
440 |
-
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
|
451 |
-
|
452 |
-
|
453 |
-
|
454 |
-
|
455 |
-
|
456 |
-
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
#
|
463 |
-
|
464 |
-
|
465 |
|
466 |
-
|
467 |
-
|
468 |
-
|
469 |
-
|
470 |
-
|
471 |
-
|
472 |
-
|
473 |
-
|
474 |
|
475 |
-
|
476 |
-
|
477 |
-
|
478 |
-
|
479 |
-
|
480 |
-
|
481 |
-
|
482 |
-
|
483 |
-
|
484 |
-
|
485 |
-
|
486 |
-
|
487 |
-
|
488 |
-
|
489 |
-
|
490 |
-
|
491 |
-
|
492 |
-
|
493 |
-
|
494 |
-
|
495 |
-
|
496 |
-
|
497 |
-
|
498 |
-
|
499 |
-
|
500 |
-
|
501 |
-
|
502 |
-
|
503 |
-
|
504 |
-
|
505 |
-
|
506 |
-
|
507 |
-
|
508 |
-
|
509 |
-
|
510 |
-
|
511 |
-
|
512 |
|
513 |
-
|
514 |
-
|
515 |
-
|
516 |
-
|
517 |
-
|
518 |
-
|
519 |
-
|
520 |
-
|
521 |
-
|
522 |
-
|
523 |
-
|
524 |
-
|
525 |
-
|
526 |
|
527 |
|
528 |
# Launch the Gradio app
|
@@ -535,6 +536,6 @@ if __name__ == "__main__":
|
|
535 |
pipeline.cuda()
|
536 |
|
537 |
print("STARTING SERVER ...")
|
538 |
-
|
539 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|
540 |
print("SERVER STARTED!!!")
|
|
|
373 |
raise HTTPException(status_code=500, detail=str(e))
|
374 |
|
375 |
|
376 |
+
# @app.get("/")
|
377 |
+
# async def root_():
|
378 |
+
# return JSONResponse(content={"message": "HI!"})
|
379 |
+
|
380 |
+
with gr.Blocks(delete_cache=(600, 600)) as demo:
|
381 |
+
gr.Markdown("""
|
382 |
+
## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
|
383 |
+
* Upload an image and click "Generate" to create a 3D asset. If the image has alpha channel, it be used as the mask. Otherwise, we use `rembg` to remove the background.
|
384 |
+
* If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
|
385 |
|
386 |
+
✨New: 1) Experimental multi-image support. 2) Gaussian file extraction.
|
387 |
+
""")
|
388 |
+
demo.load(verify_token)
|
389 |
|
390 |
+
with gr.Row():
|
391 |
+
with gr.Column():
|
392 |
+
with gr.Tabs() as input_tabs:
|
393 |
+
with gr.Tab(label="Single Image", id=0) as single_image_input_tab:
|
394 |
+
image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
|
395 |
+
with gr.Tab(label="Multiple Images", id=1) as multiimage_input_tab:
|
396 |
+
multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)
|
397 |
+
gr.Markdown("""
|
398 |
+
Input different views of the object in separate images.
|
399 |
|
400 |
+
*NOTE: this is an experimental algorithm without training a specialized model. It may not produce the best results for all images, especially those having different poses or inconsistent details.*
|
401 |
+
""")
|
402 |
|
403 |
+
with gr.Accordion(label="Generation Settings", open=False):
|
404 |
+
seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
|
405 |
+
randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
|
406 |
+
gr.Markdown("Stage 1: Sparse Structure Generation")
|
407 |
+
with gr.Row():
|
408 |
+
ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
|
409 |
+
ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
|
410 |
+
gr.Markdown("Stage 2: Structured Latent Generation")
|
411 |
+
with gr.Row():
|
412 |
+
slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
|
413 |
+
slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
|
414 |
+
multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="stochastic")
|
415 |
+
|
416 |
+
generate_btn = gr.Button("Generate")
|
417 |
|
418 |
+
with gr.Accordion(label="GLB Extraction Settings", open=False):
|
419 |
+
mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
|
420 |
+
texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
|
421 |
|
422 |
+
with gr.Row():
|
423 |
+
extract_glb_btn = gr.Button("Extract GLB", interactive=False)
|
424 |
+
extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
|
425 |
+
gr.Markdown("""
|
426 |
+
*NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
|
427 |
+
""")
|
428 |
+
|
429 |
+
with gr.Column():
|
430 |
+
video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
|
431 |
+
model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
|
432 |
|
433 |
+
with gr.Row():
|
434 |
+
download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
|
435 |
+
download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
|
436 |
|
437 |
+
is_multiimage = gr.State(False)
|
438 |
+
output_buf = gr.State()
|
439 |
+
|
440 |
+
# Example images at the bottom of the page
|
441 |
+
with gr.Row() as single_image_example:
|
442 |
+
examples = gr.Examples(
|
443 |
+
examples=[
|
444 |
+
f'assets/example_image/{image}'
|
445 |
+
for image in os.listdir("assets/example_image")
|
446 |
+
],
|
447 |
+
inputs=[image_prompt],
|
448 |
+
fn=preprocess_image,
|
449 |
+
outputs=[image_prompt],
|
450 |
+
run_on_click=True,
|
451 |
+
examples_per_page=64,
|
452 |
+
)
|
453 |
+
with gr.Row(visible=False) as multiimage_example:
|
454 |
+
examples_multi = gr.Examples(
|
455 |
+
examples=prepare_multi_example(),
|
456 |
+
inputs=[image_prompt],
|
457 |
+
fn=split_image,
|
458 |
+
outputs=[multiimage_prompt],
|
459 |
+
run_on_click=True,
|
460 |
+
examples_per_page=8,
|
461 |
+
)
|
462 |
+
|
463 |
+
# Handlers
|
464 |
+
demo.load(start_session)
|
465 |
+
demo.unload(end_session)
|
466 |
|
467 |
+
single_image_input_tab.select(
|
468 |
+
lambda: tuple([False, gr.Row.update(visible=True), gr.Row.update(visible=False)]),
|
469 |
+
outputs=[is_multiimage, single_image_example, multiimage_example]
|
470 |
+
)
|
471 |
+
multiimage_input_tab.select(
|
472 |
+
lambda: tuple([True, gr.Row.update(visible=False), gr.Row.update(visible=True)]),
|
473 |
+
outputs=[is_multiimage, single_image_example, multiimage_example]
|
474 |
+
)
|
475 |
|
476 |
+
image_prompt.upload(
|
477 |
+
preprocess_image,
|
478 |
+
inputs=[image_prompt],
|
479 |
+
outputs=[image_prompt],
|
480 |
+
)
|
481 |
+
multiimage_prompt.upload(
|
482 |
+
preprocess_images,
|
483 |
+
inputs=[multiimage_prompt],
|
484 |
+
outputs=[multiimage_prompt],
|
485 |
+
)
|
486 |
+
|
487 |
+
generate_btn.click(
|
488 |
+
get_seed,
|
489 |
+
inputs=[randomize_seed, seed],
|
490 |
+
outputs=[seed],
|
491 |
+
).then(
|
492 |
+
image_to_3d,
|
493 |
+
inputs=[image_prompt, multiimage_prompt, is_multiimage, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo],
|
494 |
+
outputs=[output_buf, video_output],
|
495 |
+
).then(
|
496 |
+
lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
|
497 |
+
outputs=[extract_glb_btn, extract_gs_btn],
|
498 |
+
)
|
499 |
+
|
500 |
+
video_output.clear(
|
501 |
+
lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
|
502 |
+
outputs=[extract_glb_btn, extract_gs_btn],
|
503 |
+
)
|
504 |
+
|
505 |
+
extract_glb_btn.click(
|
506 |
+
extract_glb,
|
507 |
+
inputs=[output_buf, mesh_simplify, texture_size],
|
508 |
+
outputs=[model_output, download_glb],
|
509 |
+
).then(
|
510 |
+
lambda: gr.Button(interactive=True),
|
511 |
+
outputs=[download_glb],
|
512 |
+
)
|
513 |
|
514 |
+
extract_gs_btn.click(
|
515 |
+
extract_gaussian,
|
516 |
+
inputs=[output_buf],
|
517 |
+
outputs=[model_output, download_gs],
|
518 |
+
).then(
|
519 |
+
lambda: gr.Button(interactive=True),
|
520 |
+
outputs=[download_gs],
|
521 |
+
)
|
522 |
+
|
523 |
+
model_output.clear(
|
524 |
+
lambda: gr.Button(interactive=False),
|
525 |
+
outputs=[download_glb],
|
526 |
+
)
|
527 |
|
528 |
|
529 |
# Launch the Gradio app
|
|
|
536 |
pipeline.cuda()
|
537 |
|
538 |
print("STARTING SERVER ...")
|
539 |
+
app = mount_gradio_app(app, demo, path="/")
|
540 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|
541 |
print("SERVER STARTED!!!")
|