rahul7star committed · Commit 09a6fb7 · verified · 1 Parent(s): 220e1fc

Update optimization.py

Files changed (1): optimization.py (+18 −27)
optimization.py CHANGED
@@ -39,11 +39,12 @@ INDUCTOR_CONFIGS = {
 def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
     print("[optimize_pipeline_] Starting pipeline optimization")
 
-    # Quantize and compile text encoder first (weight-only int8 quantization can be replaced by autoquant if preferred)
+    # Text encoder: move to CPU first, then quantize+compile to avoid early CUDA init
+    pipeline.text_encoder = pipeline.text_encoder.cpu()
     pipeline.text_encoder = torchao.autoquant(
         torch.compile(pipeline.text_encoder, mode='max-autotune'),
-        qtensor_class_list=DEFAULT_INT4_AUTOQUANT_CLASS_LIST # or remove for default quant
-    )
+        qtensor_class_list=DEFAULT_INT4_AUTOQUANT_CLASS_LIST
+    ).to("cuda")
     print("[optimize_pipeline_] Text encoder autoquantized and compiled")
 
     @spaces.GPU(duration=1500)
@@ -67,7 +68,7 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
         pipeline.fuse_lora(adapter_names=["lightning_2"], lora_scale=1.0, components=["transformer_2"])
         pipeline.unload_lora_weights()
 
-        print("[compile_transformer] Running dummy forward pass to capture component call")
+        print("[compile_transformer] Running dummy forward pass to capture shapes")
         with torch.inference_mode():
             with capture_component_call(pipeline, 'transformer') as call:
                 pipeline(*args, **kwargs)
@@ -75,44 +76,36 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
         dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
         dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
 
-        # Use autoquant + torch.compile on transformers
-        print("[compile_transformer] Autoquantizing and compiling transformer")
+        # Now that we're inside GPU context, move and compile transformers
+        pipeline.transformer = pipeline.transformer.to("cuda")
+        pipeline.transformer_2 = pipeline.transformer_2.to("cuda")
+
         compiled_transformer = torchao.autoquant(
             torch.compile(pipeline.transformer, mode='max-autotune'),
-            qtensor_class_list=DEFAULT_INT4_AUTOQUANT_CLASS_LIST,
+            qtensor_class_list=DEFAULT_INT4_AUTOQUANT_CLASS_LIST
         )
         compiled_transformer_2 = torchao.autoquant(
             torch.compile(pipeline.transformer_2, mode='max-autotune'),
-            qtensor_class_list=DEFAULT_INT4_AUTOQUANT_CLASS_LIST,
+            qtensor_class_list=DEFAULT_INT4_AUTOQUANT_CLASS_LIST
        )
 
+        # Patch forward with landscape/portrait logic
         hidden_states: torch.Tensor = call.kwargs['hidden_states']
         hidden_states_transposed = hidden_states.transpose(-1, -2).contiguous()
 
-        if hidden_states.shape[-1] > hidden_states.shape[-2]:
-            hidden_states_landscape = hidden_states
-            hidden_states_portrait = hidden_states_transposed
-        else:
-            hidden_states_landscape = hidden_states_transposed
-            hidden_states_portrait = hidden_states
-
-        # Replace forward with quantized & compiled versions, wrapped for shape dispatch
         def combined_transformer_1(*a, **k):
             if k['hidden_states'].shape[-1] > k['hidden_states'].shape[-2]:
                 return compiled_transformer(*a, **k)
-            else:
-                # Swap hidden states for portrait? Use transpose if needed.
-                k_mod = k.copy()
-                k_mod['hidden_states'] = hidden_states_portrait
-                return compiled_transformer(*a, **k_mod)
+            k_mod = k.copy()
+            k_mod['hidden_states'] = hidden_states_transposed
+            return compiled_transformer(*a, **k_mod)
 
         def combined_transformer_2(*a, **k):
             if k['hidden_states'].shape[-1] > k['hidden_states'].shape[-2]:
                 return compiled_transformer_2(*a, **k)
-            else:
-                k_mod = k.copy()
-                k_mod['hidden_states'] = hidden_states_portrait
-                return compiled_transformer_2(*a, **k_mod)
+            k_mod = k.copy()
+            k_mod['hidden_states'] = hidden_states_transposed
+            return compiled_transformer_2(*a, **k_mod)
 
         pipeline.transformer.forward = combined_transformer_1
         drain_module_parameters(pipeline.transformer)
@@ -121,8 +114,6 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
         drain_module_parameters(pipeline.transformer_2)
 
         print("[compile_transformer] Transformers autoquantized, compiled, and patched")
-
-        # Return compiled models for reference if needed
        return compiled_transformer, compiled_transformer_2
 
     cl1, cl2 = compile_transformer()
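For reference, the autoquant-around-compile pattern used in both versions follows torchao's documented usage: wrap the module with torch.compile first, then hand the compiled module to torchao.autoquant, which defers kernel selection until real inputs flow through. A minimal self-contained sketch (the toy model and shapes are illustrative, not from this repo; the import path assumes torchao.quantization exports DEFAULT_INT4_AUTOQUANT_CLASS_LIST, as recent torchao versions do):

import torch
import torchao
from torchao.quantization import DEFAULT_INT4_AUTOQUANT_CLASS_LIST

# Toy stand-in for pipeline.text_encoder / pipeline.transformer.
model = torch.nn.Sequential(
    torch.nn.Linear(512, 512),
    torch.nn.ReLU(),
    torch.nn.Linear(512, 512),
).to("cuda")

# autoquant wraps the compiled module; each linear layer is benchmarked
# against the candidate qtensor classes on first use, and the fastest wins.
model = torchao.autoquant(
    torch.compile(model, mode="max-autotune"),
    qtensor_class_list=DEFAULT_INT4_AUTOQUANT_CLASS_LIST,
)

x = torch.randn(8, 512, device="cuda")
model(x)  # first call triggers benchmarking and quantization

torchao's README recommends exactly this composition, since the kernels autoquant settles on then run under the compiled graph.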
 
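The device shuffling is the substance of this commit: both torch.compile and torchao.autoquant are lazy wrappers that run no kernels at wrap time, so the text encoder can be prepared on CPU during startup and only then moved to CUDA, while the two transformers are moved to CUDA inside the @spaces.GPU function, where a GPU is guaranteed to be attached. A minimal sketch of that split (the toy model and duration are illustrative; @spaces.GPU is the Hugging Face ZeroGPU decorator already used above):

import spaces
import torch

# Import time: build and wrap on CPU. torch.compile defers all real work,
# so no CUDA kernels run before a GPU is attached.
model = torch.compile(torch.nn.Linear(64, 64), mode="max-autotune")

@spaces.GPU(duration=120)  # GPU seconds to reserve; 120 is illustrative
def run(x: torch.Tensor) -> torch.Tensor:
    # A GPU is attached for the body of this call, so moving weights and
    # triggering the first (compiling) forward pass is safe here.
    model.to("cuda")
    return model(x.to("cuda")).cpu()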