Commit 725b8ba
Parent(s): 70044fb

fix: image pooling

Signed-off-by: jupyterjazz <[email protected]>

modeling_jina_embeddings_v4.py (+18 -12)
@@ -216,22 +216,21 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         Project the hidden states to single-vector embeddings.
         """
         if self._input_has_image(input_ids[0]):  # got document image
-
-
-
-
-
-
-
-
-
-                .unsqueeze(0)
-            )
+            img_start_positions = torch.where(input_ids == self.config.vision_start_token_id)[1]
+            img_end_positions = torch.where(input_ids == self.config.vision_end_token_id)[1]
+
+            batch_size, seq_len = input_ids.shape
+            position_indices = torch.arange(seq_len, device=input_ids.device).expand(batch_size, -1)
+            image_mask = (position_indices >= img_start_positions.unsqueeze(1)) & (position_indices <= img_end_positions.unsqueeze(1))
+
+            masked_hidden_states = hidden_states * image_mask.unsqueeze(-1)
+            pooled_output = masked_hidden_states.sum(dim=1) / image_mask.sum(dim=1, keepdim=True)
 
         else:  # got query text
             pooled_output = torch.sum(
                 hidden_states * attention_mask.unsqueeze(-1), dim=1
             ) / torch.sum(attention_mask, dim=1, keepdim=True)
+
         single_vec_emb = self.single_vector_projector(pooled_output)
         return torch.nn.functional.normalize(single_vec_emb, dim=-1)
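The new image branch builds a boolean mask over the span between the vision start and end tokens and mean-pools the hidden states over that span, per batch element. A minimal standalone sketch of that masked mean pooling, using the same tensor names as the diff (the helper name and the toy token IDs are made up for illustration; exactly one image span per sequence is assumed, since the torch.where indexing only lines up with batch rows in that case):

import torch

def pool_image_tokens(hidden_states, input_ids, vision_start_token_id, vision_end_token_id):
    # hidden_states: (batch, seq_len, hidden_dim); input_ids: (batch, seq_len)
    # Position of the single vision start/end token in each sequence.
    img_start_positions = torch.where(input_ids == vision_start_token_id)[1]  # (batch,)
    img_end_positions = torch.where(input_ids == vision_end_token_id)[1]      # (batch,)

    batch_size, seq_len = input_ids.shape
    position_indices = torch.arange(seq_len, device=input_ids.device).expand(batch_size, -1)
    # True exactly on the image-token span of each sequence.
    image_mask = (position_indices >= img_start_positions.unsqueeze(1)) & (
        position_indices <= img_end_positions.unsqueeze(1)
    )

    # Mean over the image span only (the mask zeroes out everything else).
    masked_hidden_states = hidden_states * image_mask.unsqueeze(-1)
    return masked_hidden_states.sum(dim=1) / image_mask.sum(dim=1, keepdim=True)

# Toy check with made-up token IDs (10 = vision start, 11 = vision end).
hidden = torch.randn(2, 6, 4)
ids = torch.tensor([[1, 10, 5, 5, 11, 2],
                    [1, 10, 5, 11, 0, 0]])
pooled = pool_image_tokens(hidden, ids, vision_start_token_id=10, vision_end_token_id=11)
assert torch.allclose(pooled[0], hidden[0, 1:5].mean(dim=0))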
|
|
@@ -310,14 +309,19 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             with torch.no_grad():
                 batch = {k: v.to(self.device) for k, v in batch.items()}
                 with torch.autocast(device_type=torch.device(self.device).type):
+                    for key, value in batch.items():
+                        if hasattr(value, 'shape'):
+                            print(f"{key}: {value.shape}")
+                        else:
+                            print(f"{key}: {type(value)}")
                     embeddings = self(**batch)
+                    print(embeddings.single_vec_emb.shape, embeddings.multi_vec_emb.shape)
                     if vector_type == "single_vector":
                         embeddings = embeddings.single_vec_emb
                         if truncate_dim is not None:
                             embeddings = embeddings[:, :truncate_dim]
                     else:
                         embeddings = embeddings.multi_vec_emb
-
                     results.append(
                         embeddings.cpu()
                         if return_numpy
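Around the new debug prints, each batch is moved to the model's device, run under no_grad and autocast, and then either the single-vector or multi-vector embedding is taken from the output, with optional truncation of the single vectors to truncate_dim. A condensed sketch of that per-batch flow, assuming a model whose forward returns an object with single_vec_emb and multi_vec_emb attributes as in the diff (embed_batch itself is an illustrative helper, not part of the model):

import torch

def embed_batch(model, batch, vector_type="single_vector", truncate_dim=None):
    # Mirrors the loop body above: device transfer, no_grad + autocast forward,
    # then selection of the requested embedding type.
    device = torch.device(model.device)
    with torch.no_grad():
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.autocast(device_type=device.type):
            output = model(**batch)
    if vector_type == "single_vector":
        embeddings = output.single_vec_emb            # (batch, dim)
        if truncate_dim is not None:
            embeddings = embeddings[:, :truncate_dim]  # keep leading dimensions only
    else:
        embeddings = output.multi_vec_emb             # (batch, seq_len, dim)
    return embeddings.cpu()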
@@ -442,6 +446,8 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         encode_kwargs = self._validate_encoding_params(vector_type, truncate_dim)
 
         is_single = len(images) == 1
+        print(is_single)
+        print(len(images))
         embeddings = self._process_batches(
             data=images,
             processor_fn=self.processor.process_images,
|
|
|