Upload handler.py
Browse files- handler.py +5 -7
 
    	
        handler.py
    CHANGED
    
    | 
         @@ -9,8 +9,12 @@ import torch 
     | 
|
| 9 | 
         
             
            from torchao.quantization import quantize_, autoquant, int8_dynamic_activation_int8_weight, int8_dynamic_activation_int4_weight, float8_dynamic_activation_float8_weight
         
     | 
| 10 | 
         
             
            from torchao.quantization.quant_api import PerRow
         
     | 
| 11 | 
         
             
            from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKL, TorchAoConfig
         
     | 
| 
         | 
|
| 12 | 
         | 
| 13 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 14 | 
         
             
            IS_COMPILE = True
         
     | 
| 15 | 
         
             
            IS_TURBO = False
         
     | 
| 16 | 
         
             
            IS_4BIT = False
         
     | 
| 
         @@ -19,16 +23,10 @@ IS_4BIT = False 
     | 
|
| 19 | 
         
             
            # This setting optimizes performance on NVIDIA GPUs with Ampere architecture (e.g., A100, RTX 30 series) or newer.
         
     | 
| 20 | 
         
             
            if IS_NEW_GPU: torch.set_float32_matmul_precision("high")
         
     | 
| 21 | 
         | 
| 22 | 
         
            -
            import subprocess
         
     | 
| 23 | 
         
            -
            subprocess.run("nvcc -V", shell=True)
         
     | 
| 24 | 
         
            -
            subprocess.run("pip list", shell=True)
         
     | 
| 25 | 
         
            -
             
     | 
| 26 | 
         
             
            if IS_COMPILE:
         
     | 
| 27 | 
         
             
                import torch._dynamo
         
     | 
| 28 | 
         
             
                torch._dynamo.config.suppress_errors = True
         
     | 
| 29 | 
         | 
| 30 | 
         
            -
            from huggingface_inference_toolkit.logging import logger
         
     | 
| 31 | 
         
            -
             
     | 
| 32 | 
         
             
            def load_pipeline_stable(repo_id: str, dtype: torch.dtype) -> Any:
         
     | 
| 33 | 
         
             
                quantization_config = TorchAoConfig("int4dq" if IS_4BIT else "int8dq")
         
     | 
| 34 | 
         
             
                vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype)
         
     | 
| 
         | 
|
| 9 | 
         
             
            from torchao.quantization import quantize_, autoquant, int8_dynamic_activation_int8_weight, int8_dynamic_activation_int4_weight, float8_dynamic_activation_float8_weight
         
     | 
| 10 | 
         
             
            from torchao.quantization.quant_api import PerRow
         
     | 
| 11 | 
         
             
            from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKL, TorchAoConfig
         
     | 
| 12 | 
         
            +
            from huggingface_inference_toolkit.logging import logger
         
     | 
| 13 | 
         | 
| 14 | 
         
            +
            import subprocess
         
     | 
| 15 | 
         
            +
            subprocess.run("pip list", shell=True)
         
     | 
| 16 | 
         
            +
             
     | 
| 17 | 
         
            +
            IS_NEW_GPU = False
         
     | 
| 18 | 
         
             
            IS_COMPILE = True
         
     | 
| 19 | 
         
             
            IS_TURBO = False
         
     | 
| 20 | 
         
             
            IS_4BIT = False
         
     | 
| 
         | 
|
| 23 | 
         
             
            # This setting optimizes performance on NVIDIA GPUs with Ampere architecture (e.g., A100, RTX 30 series) or newer.
         
     | 
| 24 | 
         
             
            if IS_NEW_GPU: torch.set_float32_matmul_precision("high")
         
     | 
| 25 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 26 | 
         
             
            if IS_COMPILE:
         
     | 
| 27 | 
         
             
                import torch._dynamo
         
     | 
| 28 | 
         
             
                torch._dynamo.config.suppress_errors = True
         
     | 
| 29 | 
         | 
| 
         | 
|
| 
         | 
|
| 30 | 
         
             
            def load_pipeline_stable(repo_id: str, dtype: torch.dtype) -> Any:
         
     | 
| 31 | 
         
             
                quantization_config = TorchAoConfig("int4dq" if IS_4BIT else "int8dq")
         
     | 
| 32 | 
         
             
                vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype)
         
     |