Spaces:
				
			
			
	
			
			
		Running
		
			on 
			
			Zero
	
	
	
			
			
	
	
	
	
		
		
		Running
		
			on 
			
			Zero
	Update optimization.py
Browse files- optimization.py +25 -10
 
    	
        optimization.py
    CHANGED
    
    | 
         @@ -15,7 +15,7 @@ from torchao.quantization import Int8WeightOnlyConfig 
     | 
|
| 15 | 
         
             
            from optimization_utils import capture_component_call
         
     | 
| 16 | 
         
             
            from optimization_utils import aoti_compile
         
     | 
| 17 | 
         
             
            from optimization_utils import ZeroGPUCompiledModel
         
     | 
| 18 | 
         
            -
             
     | 
| 19 | 
         | 
| 20 | 
         
             
            P = ParamSpec('P')
         
     | 
| 21 | 
         | 
| 
         @@ -43,6 +43,26 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw 
     | 
|
| 43 | 
         
             
                @spaces.GPU(duration=1500)
         
     | 
| 44 | 
         
             
                def compile_transformer():
         
     | 
| 45 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 46 | 
         
             
                    with capture_component_call(pipeline, 'transformer') as call:
         
     | 
| 47 | 
         
             
                        pipeline(*args, **kwargs)
         
     | 
| 48 | 
         | 
| 
         @@ -105,13 +125,8 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw 
     | 
|
| 105 | 
         
             
                    else:
         
     | 
| 106 | 
         
             
                        return cp2(*args, **kwargs)
         
     | 
| 107 | 
         | 
| 108 | 
         
            -
                 
     | 
| 109 | 
         
            -
                 
     | 
| 110 | 
         
            -
             
     | 
| 111 | 
         
            -
                pipeline.transformer = combined_transformer_1
         
     | 
| 112 | 
         
            -
                pipeline.transformer.config = transformer_config # pyright: ignore[reportAttributeAccessIssue]
         
     | 
| 113 | 
         
            -
                pipeline.transformer.dtype = transformer_dtype # pyright: ignore[reportAttributeAccessIssue]
         
     | 
| 114 | 
         | 
| 115 | 
         
            -
                pipeline.transformer_2 = combined_transformer_2
         
     | 
| 116 | 
         
            -
                pipeline.transformer_2 
     | 
| 117 | 
         
            -
                pipeline.transformer_2.dtype = transformer_dtype # pyright: ignore[reportAttributeAccessIssue]
         
     | 
| 
         | 
|
| 15 | 
         
             
            from optimization_utils import capture_component_call
         
     | 
| 16 | 
         
             
            from optimization_utils import aoti_compile
         
     | 
| 17 | 
         
             
            from optimization_utils import ZeroGPUCompiledModel
         
     | 
| 18 | 
         
            +
            from optimization_utils import drain_module_parameters
         
     | 
| 19 | 
         | 
| 20 | 
         
             
            P = ParamSpec('P')
         
     | 
| 21 | 
         | 
| 
         | 
|
| 43 | 
         
             
                @spaces.GPU(duration=1500)
         
     | 
| 44 | 
         
             
                def compile_transformer():
         
     | 
| 45 | 
         | 
| 46 | 
         
            +
                    pipeline.load_lora_weights(
         
     | 
| 47 | 
         
            +
                        "Kijai/WanVideo_comfy", 
         
     | 
| 48 | 
         
            +
                        weight_name="Lightx2v/lightx2v_T2V_14B_cfg_step_distill_v2_lora_rank128_bf16.safetensors", 
         
     | 
| 49 | 
         
            +
                        adapter_name="lightning"
         
     | 
| 50 | 
         
            +
                    )
         
     | 
| 51 | 
         
            +
                    kwargs_lora = {}
         
     | 
| 52 | 
         
            +
                    kwargs_lora["load_into_transformer_2"] = True
         
     | 
| 53 | 
         
            +
                    pipeline.load_lora_weights(
         
     | 
| 54 | 
         
            +
                        "Kijai/WanVideo_comfy", 
         
     | 
| 55 | 
         
            +
                         weight_name="Lightx2v/lightx2v_T2V_14B_cfg_step_distill_v2_lora_rank128_bf16.safetensors", 
         
     | 
| 56 | 
         
            +
                        #weight_name="Wan22-Lightning/Wan2.2-Lightning_T2V-A14B-4steps-lora_LOW_fp16.safetensors", 
         
     | 
| 57 | 
         
            +
                        adapter_name="lightning_2", **kwargs_lora
         
     | 
| 58 | 
         
            +
                    )
         
     | 
| 59 | 
         
            +
                    pipeline.set_adapters(["lightning", "lightning_2"], adapter_weights=[1., 1.])
         
     | 
| 60 | 
         
            +
             
     | 
| 61 | 
         
            +
             
     | 
| 62 | 
         
            +
                    pipeline.fuse_lora(adapter_names=["lightning"], lora_scale=3., components=["transformer"])
         
     | 
| 63 | 
         
            +
                    pipeline.fuse_lora(adapter_names=["lightning_2"], lora_scale=1., components=["transformer_2"])
         
     | 
| 64 | 
         
            +
                    pipeline.unload_lora_weights()
         
     | 
| 65 | 
         
            +
                    
         
     | 
| 66 | 
         
             
                    with capture_component_call(pipeline, 'transformer') as call:
         
     | 
| 67 | 
         
             
                        pipeline(*args, **kwargs)
         
     | 
| 68 | 
         | 
| 
         | 
|
| 125 | 
         
             
                    else:
         
     | 
| 126 | 
         
             
                        return cp2(*args, **kwargs)
         
     | 
| 127 | 
         | 
| 128 | 
         
            +
                pipeline.transformer.forward = combined_transformer_1
         
     | 
| 129 | 
         
            +
                drain_module_parameters(pipeline.transformer)
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 130 | 
         | 
| 131 | 
         
            +
                pipeline.transformer_2.forward = combined_transformer_2
         
     | 
| 132 | 
         
            +
                drain_module_parameters(pipeline.transformer_2)
         
     | 
| 
         |