Upload 11 files
hyvideo/modules/attenion.py
CHANGED
@@ -178,7 +178,7 @@ def parallel_attention(
             joint_tensor_value=v[:,img_kv_len:cu_seqlens_kv[1]],
             joint_strategy="rear",
         )
-        if flash_attn.__version__ >=
+        if flash_attn.__version__ >= '2.7.0':
             attn2, *_ = _flash_attn_forward(
                 q[:,cu_seqlens_q[1]:],
                 k[:,cu_seqlens_kv[1]:],
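
For context, the hunk above gates the attention call on the installed flash-attn release. Below is a minimal, hypothetical sketch of that gating pattern, not the repository's code: it uses packaging.version instead of the raw string comparison in the diff, and the helper name flash_attn_is_at_least is invented for illustration.

# Sketch only (assumption): robust version gating for the flash-attn call path.
# packaging.version is used rather than comparing version strings directly,
# since lexicographic comparison misorders releases like '2.10.0' vs '2.7.0'.
from packaging import version

import flash_attn

def flash_attn_is_at_least(required="2.7.0"):
    """Return True if the installed flash-attn is at or above `required`."""
    return version.parse(flash_attn.__version__) >= version.parse(required)

# Usage at the call site (branch bodies elided, mirroring the diff):
# if flash_attn_is_at_least("2.7.0"):
#     attn2, *_ = _flash_attn_forward(...)   # 2.7+ call path
# else:
#     ...                                    # pre-2.7 call path
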
hyvideo/modules/fp8_optimization.py
CHANGED
@@ -83,7 +83,7 @@ def convert_fp8_linear(module, dit_weight_path, original_dtype, params_to_keep={
     setattr(module, "fp8_matmul_enabled", True)
 
     # loading fp8 mapping file
-    fp8_map_path = dit_weight_path.replace(
+    fp8_map_path = dit_weight_path.replace('.pt', '_map.pt')
     if os.path.exists(fp8_map_path):
         fp8_map = torch.load(fp8_map_path, map_location=lambda storage, loc: storage)
     else:
@@ -91,7 +91,7 @@ def convert_fp8_linear(module, dit_weight_path, original_dtype, params_to_keep={
 
     fp8_layers = []
     for key, layer in module.named_modules():
-        if isinstance(layer, nn.Linear) and (
+        if isinstance(layer, nn.Linear) and ('double_blocks' in key or 'single_blocks' in key):
             fp8_layers.append(key)
             original_forward = layer.forward
             layer.weight = torch.nn.Parameter(layer.weight.to(torch.float8_e4m3fn))
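
For context, here is a minimal sketch of the map-path derivation and the layer-selection/weight-cast step these hunks touch. It is an illustration under assumptions, not the repository's convert_fp8_linear: the function name cast_selected_linears_to_fp8 is hypothetical, torch.float8_e4m3fn requires a recent PyTorch build, and the real code additionally patches each selected layer's forward using the scales loaded from the *_map.pt file.

import os

import torch
import torch.nn as nn

def cast_selected_linears_to_fp8(module, dit_weight_path):
    # Derive the scale-map path the same way the first hunk does:
    # ".../model.pt" -> ".../model_map.pt"
    fp8_map_path = dit_weight_path.replace('.pt', '_map.pt')
    fp8_map = torch.load(fp8_map_path, map_location="cpu") if os.path.exists(fp8_map_path) else {}

    fp8_layers = []
    for key, layer in module.named_modules():
        # Only linear layers inside the transformer blocks are converted;
        # every other layer keeps its original dtype.
        if isinstance(layer, nn.Linear) and ('double_blocks' in key or 'single_blocks' in key):
            fp8_layers.append(key)
            layer.weight = nn.Parameter(layer.weight.to(torch.float8_e4m3fn))
    return fp8_layers, fp8_map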