Brian Tang
committed
Commit 8f0a794 · Parent(s): 49ebb9c
Adds flash attention check with the device type
modeling_jina_embeddings_v4.py
CHANGED
@@ -569,7 +569,8 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         kwargs["torch_dtype"] = "auto"
 
         kwargs["key_mapping"] = super()._checkpoint_conversion_mapping
-
-        kwargs["attn_implementation"] = "sdpa"
+        device = kwargs.get("device", "auto")
+        if not is_flash_attn_2_available() or device == "cpu":
+            kwargs["attn_implementation"] = "sdpa"
 
         base_model = super().from_pretrained(
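
For context, a minimal sketch of the fallback logic this commit introduces, pulled out of from_pretrained into a standalone helper. The helper name pick_attn_implementation and the flash_attention_2 return value in the else branch are assumptions for illustration (the diff only shows the SDPA fallback); is_flash_attn_2_available is the real utility from transformers.utils that the new code calls:

from transformers.utils import is_flash_attn_2_available

def pick_attn_implementation(device: str = "auto") -> str:
    # FlashAttention 2 only runs on supported CUDA GPUs, so fall back to
    # PyTorch's scaled-dot-product attention ("sdpa") when the kernel is
    # not installed or the model is being loaded on CPU.
    if not is_flash_attn_2_available() or device == "cpu":
        return "sdpa"
    # Assumption: the model otherwise defaults to FlashAttention 2;
    # the diff itself only sets the sdpa fallback path.
    return "flash_attention_2"

With this check in place, loading the checkpoint with device="cpu", or on a machine where flash-attn is not installed, sets kwargs["attn_implementation"] = "sdpa" instead of failing when FlashAttention 2 is unavailable.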