Amir8212 commited on
Commit
e4fa547
·
1 Parent(s): 86715a0

feat: add gradio ui

Browse files
Files changed (1) hide show
  1. app.py +141 -140
app.py CHANGED
@@ -373,156 +373,157 @@ async def remove_image_background(image: UploadFile = File(...), token: str = De
373
  raise HTTPException(status_code=500, detail=str(e))
374
 
375
 
376
- @app.get("/")
377
- async def root_():
378
- return JSONResponse(content={"message": "HI!"})
379
-
380
- # with gr.Blocks(delete_cache=(600, 600)) as demo:
381
- # gr.Markdown("""
382
- # ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
383
- # * Upload an image and click "Generate" to create a 3D asset. If the image has alpha channel, it be used as the mask. Otherwise, we use `rembg` to remove the background.
384
- # * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
385
 
386
- # ✨New: 1) Experimental multi-image support. 2) Gaussian file extraction.
387
- # """)
 
388
 
389
- # with gr.Row():
390
- # with gr.Column():
391
- # with gr.Tabs() as input_tabs:
392
- # with gr.Tab(label="Single Image", id=0) as single_image_input_tab:
393
- # image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
394
- # with gr.Tab(label="Multiple Images", id=1) as multiimage_input_tab:
395
- # multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)
396
- # gr.Markdown("""
397
- # Input different views of the object in separate images.
398
 
399
- # *NOTE: this is an experimental algorithm without training a specialized model. It may not produce the best results for all images, especially those having different poses or inconsistent details.*
400
- # """)
401
 
402
- # with gr.Accordion(label="Generation Settings", open=False):
403
- # seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
404
- # randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
405
- # gr.Markdown("Stage 1: Sparse Structure Generation")
406
- # with gr.Row():
407
- # ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
408
- # ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
409
- # gr.Markdown("Stage 2: Structured Latent Generation")
410
- # with gr.Row():
411
- # slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
412
- # slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
413
- # multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="stochastic")
414
-
415
- # generate_btn = gr.Button("Generate")
416
 
417
- # with gr.Accordion(label="GLB Extraction Settings", open=False):
418
- # mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
419
- # texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
420
 
421
- # with gr.Row():
422
- # extract_glb_btn = gr.Button("Extract GLB", interactive=False)
423
- # extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
424
- # gr.Markdown("""
425
- # *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
426
- # """)
427
-
428
- # with gr.Column():
429
- # video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
430
- # model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
431
 
432
- # with gr.Row():
433
- # download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
434
- # download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
435
 
436
- # is_multiimage = gr.State(False)
437
- # output_buf = gr.State()
438
-
439
- # # Example images at the bottom of the page
440
- # with gr.Row() as single_image_example:
441
- # examples = gr.Examples(
442
- # examples=[
443
- # f'assets/example_image/{image}'
444
- # for image in os.listdir("assets/example_image")
445
- # ],
446
- # inputs=[image_prompt],
447
- # fn=preprocess_image,
448
- # outputs=[image_prompt],
449
- # run_on_click=True,
450
- # examples_per_page=64,
451
- # )
452
- # with gr.Row(visible=False) as multiimage_example:
453
- # examples_multi = gr.Examples(
454
- # examples=prepare_multi_example(),
455
- # inputs=[image_prompt],
456
- # fn=split_image,
457
- # outputs=[multiimage_prompt],
458
- # run_on_click=True,
459
- # examples_per_page=8,
460
- # )
461
-
462
- # # Handlers
463
- # demo.load(start_session)
464
- # demo.unload(end_session)
465
 
466
- # single_image_input_tab.select(
467
- # lambda: tuple([False, gr.Row.update(visible=True), gr.Row.update(visible=False)]),
468
- # outputs=[is_multiimage, single_image_example, multiimage_example]
469
- # )
470
- # multiimage_input_tab.select(
471
- # lambda: tuple([True, gr.Row.update(visible=False), gr.Row.update(visible=True)]),
472
- # outputs=[is_multiimage, single_image_example, multiimage_example]
473
- # )
474
 
475
- # image_prompt.upload(
476
- # preprocess_image,
477
- # inputs=[image_prompt],
478
- # outputs=[image_prompt],
479
- # )
480
- # multiimage_prompt.upload(
481
- # preprocess_images,
482
- # inputs=[multiimage_prompt],
483
- # outputs=[multiimage_prompt],
484
- # )
485
-
486
- # generate_btn.click(
487
- # get_seed,
488
- # inputs=[randomize_seed, seed],
489
- # outputs=[seed],
490
- # ).then(
491
- # image_to_3d,
492
- # inputs=[image_prompt, multiimage_prompt, is_multiimage, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo],
493
- # outputs=[output_buf, video_output],
494
- # ).then(
495
- # lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
496
- # outputs=[extract_glb_btn, extract_gs_btn],
497
- # )
498
-
499
- # video_output.clear(
500
- # lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
501
- # outputs=[extract_glb_btn, extract_gs_btn],
502
- # )
503
-
504
- # extract_glb_btn.click(
505
- # extract_glb,
506
- # inputs=[output_buf, mesh_simplify, texture_size],
507
- # outputs=[model_output, download_glb],
508
- # ).then(
509
- # lambda: gr.Button(interactive=True),
510
- # outputs=[download_glb],
511
- # )
512
 
513
- # extract_gs_btn.click(
514
- # extract_gaussian,
515
- # inputs=[output_buf],
516
- # outputs=[model_output, download_gs],
517
- # ).then(
518
- # lambda: gr.Button(interactive=True),
519
- # outputs=[download_gs],
520
- # )
521
-
522
- # model_output.clear(
523
- # lambda: gr.Button(interactive=False),
524
- # outputs=[download_glb],
525
- # )
526
 
527
 
528
  # Launch the Gradio app
@@ -535,6 +536,6 @@ if __name__ == "__main__":
535
  pipeline.cuda()
536
 
537
  print("STARTING SERVER ...")
538
- # app = mount_gradio_app(app, demo, path="/")
539
  uvicorn.run(app, host="0.0.0.0", port=7860)
540
  print("SERVER STARTED!!!")
 
373
  raise HTTPException(status_code=500, detail=str(e))
374
 
375
 
376
+ # @app.get("/")
377
+ # async def root_():
378
+ # return JSONResponse(content={"message": "HI!"})
379
+
380
+ with gr.Blocks(delete_cache=(600, 600)) as demo:
381
+ gr.Markdown("""
382
+ ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
383
+ * Upload an image and click "Generate" to create a 3D asset. If the image has alpha channel, it be used as the mask. Otherwise, we use `rembg` to remove the background.
384
+ * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
385
 
386
+ ✨New: 1) Experimental multi-image support. 2) Gaussian file extraction.
387
+ """)
388
+ demo.load(verify_token)
389
 
390
+ with gr.Row():
391
+ with gr.Column():
392
+ with gr.Tabs() as input_tabs:
393
+ with gr.Tab(label="Single Image", id=0) as single_image_input_tab:
394
+ image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
395
+ with gr.Tab(label="Multiple Images", id=1) as multiimage_input_tab:
396
+ multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)
397
+ gr.Markdown("""
398
+ Input different views of the object in separate images.
399
 
400
+ *NOTE: this is an experimental algorithm without training a specialized model. It may not produce the best results for all images, especially those having different poses or inconsistent details.*
401
+ """)
402
 
403
+ with gr.Accordion(label="Generation Settings", open=False):
404
+ seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
405
+ randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
406
+ gr.Markdown("Stage 1: Sparse Structure Generation")
407
+ with gr.Row():
408
+ ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
409
+ ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
410
+ gr.Markdown("Stage 2: Structured Latent Generation")
411
+ with gr.Row():
412
+ slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
413
+ slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
414
+ multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="stochastic")
415
+
416
+ generate_btn = gr.Button("Generate")
417
 
418
+ with gr.Accordion(label="GLB Extraction Settings", open=False):
419
+ mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
420
+ texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
421
 
422
+ with gr.Row():
423
+ extract_glb_btn = gr.Button("Extract GLB", interactive=False)
424
+ extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
425
+ gr.Markdown("""
426
+ *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
427
+ """)
428
+
429
+ with gr.Column():
430
+ video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
431
+ model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
432
 
433
+ with gr.Row():
434
+ download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
435
+ download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
436
 
437
+ is_multiimage = gr.State(False)
438
+ output_buf = gr.State()
439
+
440
+ # Example images at the bottom of the page
441
+ with gr.Row() as single_image_example:
442
+ examples = gr.Examples(
443
+ examples=[
444
+ f'assets/example_image/{image}'
445
+ for image in os.listdir("assets/example_image")
446
+ ],
447
+ inputs=[image_prompt],
448
+ fn=preprocess_image,
449
+ outputs=[image_prompt],
450
+ run_on_click=True,
451
+ examples_per_page=64,
452
+ )
453
+ with gr.Row(visible=False) as multiimage_example:
454
+ examples_multi = gr.Examples(
455
+ examples=prepare_multi_example(),
456
+ inputs=[image_prompt],
457
+ fn=split_image,
458
+ outputs=[multiimage_prompt],
459
+ run_on_click=True,
460
+ examples_per_page=8,
461
+ )
462
+
463
+ # Handlers
464
+ demo.load(start_session)
465
+ demo.unload(end_session)
466
 
467
+ single_image_input_tab.select(
468
+ lambda: tuple([False, gr.Row.update(visible=True), gr.Row.update(visible=False)]),
469
+ outputs=[is_multiimage, single_image_example, multiimage_example]
470
+ )
471
+ multiimage_input_tab.select(
472
+ lambda: tuple([True, gr.Row.update(visible=False), gr.Row.update(visible=True)]),
473
+ outputs=[is_multiimage, single_image_example, multiimage_example]
474
+ )
475
 
476
+ image_prompt.upload(
477
+ preprocess_image,
478
+ inputs=[image_prompt],
479
+ outputs=[image_prompt],
480
+ )
481
+ multiimage_prompt.upload(
482
+ preprocess_images,
483
+ inputs=[multiimage_prompt],
484
+ outputs=[multiimage_prompt],
485
+ )
486
+
487
+ generate_btn.click(
488
+ get_seed,
489
+ inputs=[randomize_seed, seed],
490
+ outputs=[seed],
491
+ ).then(
492
+ image_to_3d,
493
+ inputs=[image_prompt, multiimage_prompt, is_multiimage, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo],
494
+ outputs=[output_buf, video_output],
495
+ ).then(
496
+ lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
497
+ outputs=[extract_glb_btn, extract_gs_btn],
498
+ )
499
+
500
+ video_output.clear(
501
+ lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
502
+ outputs=[extract_glb_btn, extract_gs_btn],
503
+ )
504
+
505
+ extract_glb_btn.click(
506
+ extract_glb,
507
+ inputs=[output_buf, mesh_simplify, texture_size],
508
+ outputs=[model_output, download_glb],
509
+ ).then(
510
+ lambda: gr.Button(interactive=True),
511
+ outputs=[download_glb],
512
+ )
513
 
514
+ extract_gs_btn.click(
515
+ extract_gaussian,
516
+ inputs=[output_buf],
517
+ outputs=[model_output, download_gs],
518
+ ).then(
519
+ lambda: gr.Button(interactive=True),
520
+ outputs=[download_gs],
521
+ )
522
+
523
+ model_output.clear(
524
+ lambda: gr.Button(interactive=False),
525
+ outputs=[download_glb],
526
+ )
527
 
528
 
529
  # Launch the Gradio app
 
536
  pipeline.cuda()
537
 
538
  print("STARTING SERVER ...")
539
+ app = mount_gradio_app(app, demo, path="/")
540
  uvicorn.run(app, host="0.0.0.0", port=7860)
541
  print("SERVER STARTED!!!")