rm decorator
Browse files- modeling_bd3lm.py +0 -1
modeling_bd3lm.py
CHANGED
|
@@ -71,7 +71,6 @@ def block_diff_mask(b, h, q_idx, kv_idx, block_size=None, n=None):
|
|
| 71 |
# **4. Combine Masks **
|
| 72 |
return block_diagonal | offset_block_causal | block_causal
|
| 73 |
|
| 74 |
-
@torch.compile(fullgraph=True, mode="max-autotune-no-cudagraphs")
|
| 75 |
def fused_flex_attention(q, k, v, mask=None):
|
| 76 |
return flex_attention(q, k, v, block_mask=mask)
|
| 77 |
|
|
|
|
| 71 |
# **4. Combine Masks **
|
| 72 |
return block_diagonal | offset_block_causal | block_causal
|
| 73 |
|
|
|
|
| 74 |
def fused_flex_attention(q, k, v, mask=None):
|
| 75 |
return flex_attention(q, k, v, block_mask=mask)
|
| 76 |
|