Commit 
							
							·
						
						690b53e
	
1
								Parent(s):
							
							bd46f72
								
Speed UP!
Browse files
    	
        app.py
    CHANGED
    
    | @@ -3,6 +3,7 @@ import spaces | |
| 3 | 
             
            from gradio_litmodel3d import LitModel3D
         | 
| 4 |  | 
| 5 | 
             
            import os
         | 
|  | |
| 6 | 
             
            from typing import *
         | 
| 7 | 
             
            import torch
         | 
| 8 | 
             
            import numpy as np
         | 
| @@ -131,7 +132,7 @@ def extract_glb(state: dict, mesh_simplify: float, texture_size: int) -> Tuple[s | |
| 131 | 
             
                    str: The path to the extracted GLB file.
         | 
| 132 | 
             
                """
         | 
| 133 | 
             
                gs, mesh, model_id = unpack_state(state)
         | 
| 134 | 
            -
                glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size)
         | 
| 135 | 
             
                glb_path = f"/tmp/Trellis-demo/{model_id}.glb"
         | 
| 136 | 
             
                glb.export(glb_path)
         | 
| 137 | 
             
                return glb_path, glb_path
         | 
| @@ -161,12 +162,12 @@ with gr.Blocks() as demo: | |
| 161 | 
             
                            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
         | 
| 162 | 
             
                            gr.Markdown("Stage 1: Sparse Structure Generation")
         | 
| 163 | 
             
                            with gr.Row():
         | 
| 164 | 
            -
                                ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=5 | 
| 165 | 
            -
                                ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value= | 
| 166 | 
             
                            gr.Markdown("Stage 2: Structured Latent Generation")
         | 
| 167 | 
             
                            with gr.Row():
         | 
| 168 | 
            -
                                slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value= | 
| 169 | 
            -
                                slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value= | 
| 170 |  | 
| 171 | 
             
                        generate_btn = gr.Button("Generate")
         | 
| 172 |  | 
|  | |
| 3 | 
             
            from gradio_litmodel3d import LitModel3D
         | 
| 4 |  | 
| 5 | 
             
            import os
         | 
| 6 | 
            +
            os.environ['SPCONV_ALGO'] = 'native'
         | 
| 7 | 
             
            from typing import *
         | 
| 8 | 
             
            import torch
         | 
| 9 | 
             
            import numpy as np
         | 
|  | |
| 132 | 
             
                    str: The path to the extracted GLB file.
         | 
| 133 | 
             
                """
         | 
| 134 | 
             
                gs, mesh, model_id = unpack_state(state)
         | 
| 135 | 
            +
                glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
         | 
| 136 | 
             
                glb_path = f"/tmp/Trellis-demo/{model_id}.glb"
         | 
| 137 | 
             
                glb.export(glb_path)
         | 
| 138 | 
             
                return glb_path, glb_path
         | 
|  | |
| 162 | 
             
                            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
         | 
| 163 | 
             
                            gr.Markdown("Stage 1: Sparse Structure Generation")
         | 
| 164 | 
             
                            with gr.Row():
         | 
| 165 | 
            +
                                ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
         | 
| 166 | 
            +
                                ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
         | 
| 167 | 
             
                            gr.Markdown("Stage 2: Structured Latent Generation")
         | 
| 168 | 
             
                            with gr.Row():
         | 
| 169 | 
            +
                                slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
         | 
| 170 | 
            +
                                slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
         | 
| 171 |  | 
| 172 | 
             
                        generate_btn = gr.Button("Generate")
         | 
| 173 |  | 
    	
        trellis/modules/sparse/__init__.py
    CHANGED
    
    | @@ -24,6 +24,8 @@ def __from_env(): | |
| 24 | 
             
                if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn']:
         | 
| 25 | 
             
                    ATTN = env_sparse_attn
         | 
| 26 |  | 
|  | |
|  | |
| 27 |  | 
| 28 | 
             
            __from_env()
         | 
| 29 |  | 
|  | |
| 24 | 
             
                if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn']:
         | 
| 25 | 
             
                    ATTN = env_sparse_attn
         | 
| 26 |  | 
| 27 | 
            +
                print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}")
         | 
| 28 | 
            +
                    
         | 
| 29 |  | 
| 30 | 
             
            __from_env()
         | 
| 31 |  | 
    	
        trellis/modules/sparse/conv/__init__.py
    CHANGED
    
    | @@ -1,6 +1,21 @@ | |
| 1 | 
             
            from .. import BACKEND
         | 
| 2 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 3 | 
             
            if BACKEND == 'torchsparse':
         | 
| 4 | 
             
                from .conv_torchsparse import *
         | 
| 5 | 
             
            elif BACKEND == 'spconv':
         | 
| 6 | 
            -
                from .conv_spconv import *
         | 
|  | |
| 1 | 
             
            from .. import BACKEND
         | 
| 2 |  | 
| 3 | 
            +
             | 
| 4 | 
            +
            SPCONV_ALGO = 'auto'    # 'auto', 'implicit_gemm', 'native'
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            def __from_env():
         | 
| 7 | 
            +
                import os
         | 
| 8 | 
            +
                    
         | 
| 9 | 
            +
                global SPCONV_ALGO
         | 
| 10 | 
            +
                env_spconv_algo = os.environ.get('SPCONV_ALGO')
         | 
| 11 | 
            +
                if env_spconv_algo is not None and env_spconv_algo in ['auto', 'implicit_gemm', 'native']:
         | 
| 12 | 
            +
                    SPCONV_ALGO = env_spconv_algo
         | 
| 13 | 
            +
                print(f"[SPARSE][CONV] spconv algo: {SPCONV_ALGO}")
         | 
| 14 | 
            +
                    
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            __from_env()
         | 
| 17 | 
            +
             | 
| 18 | 
             
            if BACKEND == 'torchsparse':
         | 
| 19 | 
             
                from .conv_torchsparse import *
         | 
| 20 | 
             
            elif BACKEND == 'spconv':
         | 
| 21 | 
            +
                from .conv_spconv import *
         | 
    	
        trellis/modules/sparse/conv/conv_spconv.py
    CHANGED
    
    | @@ -2,16 +2,22 @@ import torch | |
| 2 | 
             
            import torch.nn as nn
         | 
| 3 | 
             
            from .. import SparseTensor
         | 
| 4 | 
             
            from .. import DEBUG
         | 
|  | |
| 5 |  | 
| 6 | 
             
            class SparseConv3d(nn.Module):
         | 
| 7 | 
             
                def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
         | 
| 8 | 
             
                    super(SparseConv3d, self).__init__()
         | 
| 9 | 
             
                    if 'spconv' not in globals():
         | 
| 10 | 
             
                        import spconv.pytorch as spconv
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 11 | 
             
                    if stride == 1 and (padding is None):
         | 
| 12 | 
            -
                        self.conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, dilation=dilation, bias=bias, indice_key=indice_key)
         | 
| 13 | 
             
                    else:
         | 
| 14 | 
            -
                        self.conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, indice_key=indice_key)
         | 
| 15 | 
             
                    self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride)
         | 
| 16 | 
             
                    self.padding = padding
         | 
| 17 |  | 
|  | |
| 2 | 
             
            import torch.nn as nn
         | 
| 3 | 
             
            from .. import SparseTensor
         | 
| 4 | 
             
            from .. import DEBUG
         | 
| 5 | 
            +
            from . import SPCONV_ALGO
         | 
| 6 |  | 
| 7 | 
             
            class SparseConv3d(nn.Module):
         | 
| 8 | 
             
                def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
         | 
| 9 | 
             
                    super(SparseConv3d, self).__init__()
         | 
| 10 | 
             
                    if 'spconv' not in globals():
         | 
| 11 | 
             
                        import spconv.pytorch as spconv
         | 
| 12 | 
            +
                    algo = None
         | 
| 13 | 
            +
                    if SPCONV_ALGO == 'native':
         | 
| 14 | 
            +
                        algo = spconv.ConvAlgo.Native
         | 
| 15 | 
            +
                    elif SPCONV_ALGO == 'implicit_gemm':
         | 
| 16 | 
            +
                        algo = spconv.ConvAlgo.MaskImplicitGemm
         | 
| 17 | 
             
                    if stride == 1 and (padding is None):
         | 
| 18 | 
            +
                        self.conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, dilation=dilation, bias=bias, indice_key=indice_key, algo=algo)
         | 
| 19 | 
             
                    else:
         | 
| 20 | 
            +
                        self.conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, indice_key=indice_key, algo=algo)
         | 
| 21 | 
             
                    self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride)
         | 
| 22 | 
             
                    self.padding = padding
         | 
| 23 |  | 
    	
        trellis/utils/postprocessing_utils.py
    CHANGED
    
    | @@ -448,7 +448,7 @@ def to_glb( | |
| 448 | 
             
                    observations, masks, extrinsics, intrinsics,
         | 
| 449 | 
             
                    texture_size=texture_size, mode='opt',
         | 
| 450 | 
             
                    lambda_tv=0.01,
         | 
| 451 | 
            -
                    verbose= | 
| 452 | 
             
                )
         | 
| 453 | 
             
                texture = Image.fromarray(texture)
         | 
| 454 |  | 
|  | |
| 448 | 
             
                    observations, masks, extrinsics, intrinsics,
         | 
| 449 | 
             
                    texture_size=texture_size, mode='opt',
         | 
| 450 | 
             
                    lambda_tv=0.01,
         | 
| 451 | 
            +
                    verbose=verbose
         | 
| 452 | 
             
                )
         | 
| 453 | 
             
                texture = Image.fromarray(texture)
         | 
| 454 |  |