update modeling_qwen.py
Browse files — modeling_qwen.py (+1 −1)
modeling_qwen.py
CHANGED
|
@@ -520,7 +520,7 @@ class QWenAttention(nn.Module):
|
|
| 520 |
|
| 521 |
if not self.use_cache_quantization and SUPPORT_TORCH2:
|
| 522 |
if attention_mask is not None:
|
| 523 |
-
attention_mask = attention_mask.expand(-1, -1,
|
| 524 |
if causal_mask is not None:
|
| 525 |
attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)
|
| 526 |
else:
|
|
|
|
| 520 |
|
| 521 |
if not self.use_cache_quantization and SUPPORT_TORCH2:
|
| 522 |
if attention_mask is not None:
|
| 523 |
+
attention_mask = attention_mask.expand(-1, -1, query.size(2), -1)
|
| 524 |
if causal_mask is not None:
|
| 525 |
attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)
|
| 526 |
else:
|