Maximilian Werk committed
Commit cf456d3 · Parent(s): b7707d5

feat: reduced default noise of the model
configuration_jina_embeddings_v4.py
CHANGED
@@ -2,6 +2,7 @@ from transformers.models.qwen2_5_vl import Qwen2_5_VLConfig
 
 from typing import Optional
 
+
 class JinaEmbeddingsV4Config(Qwen2_5_VLConfig):
     """
     Configuration for the JinaEmbeddingsV4 model.
@@ -12,10 +13,11 @@ class JinaEmbeddingsV4Config(Qwen2_5_VLConfig):
         single_vector_pool_strategy: str = "mean",
         multi_vector_projector_dim: int = 128,
         pretrained_peft_model_name_or_path: Optional[str] = None,
+        verbosity: int = 0,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.single_vector_pool_strategy = single_vector_pool_strategy
         self.multi_vector_projector_dim = multi_vector_projector_dim
         self.pretrained_peft_model_name_or_path = pretrained_peft_model_name_or_path
-
+        self.verbosity = verbosity
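The net effect of the config change is a new verbosity field defaulting to 0 (quiet). A minimal sketch of opting back into progress output, assuming the published jinaai/jina-embeddings-v4 repo id and the transformers Auto classes with trust_remote_code (both are illustrative assumptions, not part of this commit):

from transformers import AutoConfig, AutoModel

# Load the remote-code config and flip the new flag; 0 (the default)
# keeps the model quiet, any nonzero value re-enables the tqdm bars.
config = AutoConfig.from_pretrained("jinaai/jina-embeddings-v4", trust_remote_code=True)
config.verbosity = 1

model = AutoModel.from_pretrained(
    "jinaai/jina-embeddings-v4", config=config, trust_remote_code=True
)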
modeling_jina_embeddings_v4.py
CHANGED
@@ -146,6 +146,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             self.name_or_path, trust_remote_code=True, use_fast=True
         )
         self.multi_vector_projector_dim = config.multi_vector_projector_dim
+        self.verbosity = config.verbosity
         self._task = None
 
     @property
@@ -335,7 +336,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         assert not return_numpy, "`return_numpy` is not supported when `return_multivector=True` and more than one data is encoded"
         results = []
         self.eval()
-        for batch in tqdm(dataloader, desc=desc):
+        for batch in tqdm(dataloader, desc=desc, disable=self.verbosity == 0):
             with torch.no_grad():
                 batch = {k: v.to(self.device) for k, v in batch.items()}
                 with torch.autocast(
@@ -349,7 +350,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
                         embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)
                 else:
                     embeddings = embeddings.multi_vec_emb
-
+
                 if return_multivector and not return_numpy:
                     valid_tokens = batch["attention_mask"].bool()
                     embeddings = [
@@ -453,7 +454,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
            if return_numpy:
                print("Warning: `return_numpy` is ignored when `return_multivector=True` and `len(texts) > 1`")
                return_numpy = False
-
+
        if isinstance(texts, str):
            texts = [texts]
 
@@ -468,7 +469,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             **encode_kwargs,
         )
 
-        return embeddings if return_list else embeddings[0]
+        return embeddings if return_list else embeddings[0]
 
     def _load_images_if_needed(
         self, images: List[Union[str, Image.Image]]
@@ -515,7 +516,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         )
         encode_kwargs = self._validate_encoding_params(truncate_dim=truncate_dim)
         task = self._validate_task(task)
-
+
         return_list = isinstance(images, list)
 
         # If return_multivector is True and encoding multiple images, ignore return_numpy
@@ -527,7 +528,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         # Convert single image to list
         if isinstance(images, (str, Image.Image)):
             images = [images]
-
+
         images = self._load_images_if_needed(images)
         embeddings = self._process_batches(
             data=images,
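The only functional change in this file is the disable flag passed to tqdm in _process_batches; the remaining hunks strip trailing whitespace. A standalone sketch of the silencing pattern (the loop bounds and body are placeholders, not the model's actual batching logic):

from tqdm import tqdm

verbosity = 0  # mirrors config.verbosity; 0 is the new default

# With disable=True, tqdm renders nothing at all, which is the
# "reduced default noise" the commit message refers to.
for batch in tqdm(range(3), desc="Encoding", disable=verbosity == 0):
    pass  # placeholder for the per-batch encoding work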