jupyterjazz committed
Commit 725b8ba · 1 parent: 70044fb

fix: image pooling

Signed-off-by: jupyterjazz <[email protected]>

Files changed (1)
  1. modeling_jina_embeddings_v4.py +18 -12
modeling_jina_embeddings_v4.py CHANGED
@@ -216,22 +216,21 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         Project the hidden states to single-vector embeddings.
         """
         if self._input_has_image(input_ids[0]): # got document image
-            img_start_pos = torch.where(
-                input_ids[0] == self.config.vision_start_token_id
-            )[0][0]
-            img_end_pos = torch.where(input_ids[0] == self.config.vision_end_token_id)[
-                0
-            ][0]
-            pooled_output = (
-                hidden_states[0][img_start_pos : img_end_pos + 1]
-                .mean(dim=0)
-                .unsqueeze(0)
-            )
+            img_start_positions = torch.where(input_ids == self.config.vision_start_token_id)[1]
+            img_end_positions = torch.where(input_ids == self.config.vision_end_token_id)[1]
+
+            batch_size, seq_len = input_ids.shape
+            position_indices = torch.arange(seq_len, device=input_ids.device).expand(batch_size, -1)
+            image_mask = (position_indices >= img_start_positions.unsqueeze(1)) & (position_indices <= img_end_positions.unsqueeze(1))
+
+            masked_hidden_states = hidden_states * image_mask.unsqueeze(-1)
+            pooled_output = masked_hidden_states.sum(dim=1) / image_mask.sum(dim=1, keepdim=True)

         else: # got query text
             pooled_output = torch.sum(
                 hidden_states * attention_mask.unsqueeze(-1), dim=1
             ) / torch.sum(attention_mask, dim=1, keepdim=True)
+
         single_vec_emb = self.single_vector_projector(pooled_output)
         return torch.nn.functional.normalize(single_vec_emb, dim=-1)
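
Why this is a fix: the old code indexed input_ids[0] and hidden_states[0], so for batch sizes above one it pooled only the first sequence's image span and emitted a single row for the whole batch. The new code locates each sequence's own vision start/end tokens, builds a per-sample boolean mask over that span, and takes a masked mean across the entire batch in one shot. A minimal self-contained sketch of the same technique (toy token ids 1 and 2 stand in for the real vision_start_token_id / vision_end_token_id; like the model code, it assumes exactly one image span per sequence):

```python
import torch

vision_start_token_id, vision_end_token_id = 1, 2
input_ids = torch.tensor([
    [0, 1, 5, 5, 2, 0, 0],   # image span at positions 1..4
    [0, 0, 0, 1, 5, 2, 0],   # image span at positions 3..5
])
batch_size, seq_len = input_ids.shape
hidden_states = torch.randn(batch_size, seq_len, 8)

# One start/end position per sequence; index [1] keeps the column indices.
starts = torch.where(input_ids == vision_start_token_id)[1]
ends = torch.where(input_ids == vision_end_token_id)[1]

# Boolean mask that is True exactly on each sequence's own image span.
positions = torch.arange(seq_len).expand(batch_size, -1)
image_mask = (positions >= starts.unsqueeze(1)) & (positions <= ends.unsqueeze(1))

# Masked mean: zero out non-image tokens, then divide by each span's length.
pooled = (hidden_states * image_mask.unsqueeze(-1)).sum(dim=1) / image_mask.sum(
    dim=1, keepdim=True
)

# Sanity check against explicit per-sample slicing; the old implementation
# only agreed with this loop for sample 0.
for i in range(batch_size):
    ref = hidden_states[i, starts[i] : ends[i] + 1].mean(dim=0)
    assert torch.allclose(pooled[i], ref, atol=1e-6)
```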
 
@@ -310,14 +309,19 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             with torch.no_grad():
                 batch = {k: v.to(self.device) for k, v in batch.items()}
                 with torch.autocast(device_type=torch.device(self.device).type):
+                    for key, value in batch.items():
+                        if hasattr(value, 'shape'):
+                            print(f"{key}: {value.shape}")
+                        else:
+                            print(f"{key}: {type(value)}")
                     embeddings = self(**batch)
+                    print(embeddings.single_vec_emb.shape, embeddings.multi_vec_emb.shape)
                     if vector_type == "single_vector":
                         embeddings = embeddings.single_vec_emb
                         if truncate_dim is not None:
                             embeddings = embeddings[:, :truncate_dim]
                     else:
                         embeddings = embeddings.multi_vec_emb
-
                     results.append(
                         embeddings.cpu()
                         if return_numpy
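
The added print calls are shape diagnostics inside the existing no_grad/autocast inference loop. Setting the diagnostics aside, the surrounding loop is a standard mixed-precision batching pattern; here is a hypothetical reduced version (embed_batches, model, and batches are illustrative stand-ins, not this repository's API, and the model is assumed to return a plain tensor):

```python
import torch

def embed_batches(model, batches, device):
    """Run mixed-precision inference over pre-tokenized batches."""
    results = []
    with torch.no_grad():                      # no gradients needed at inference
        for batch in batches:                  # each batch: dict of tensors
            batch = {k: v.to(device) for k, v in batch.items()}
            # autocast derives the mixed-precision context from the device type
            with torch.autocast(device_type=torch.device(device).type):
                out = model(**batch)
            results.append(out.cpu())          # move results off-device as you go
    return torch.cat(results)
```

Moving each result to the CPU inside the loop keeps device memory bounded when many batches are processed.
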
@@ -442,6 +446,8 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         encode_kwargs = self._validate_encoding_params(vector_type, truncate_dim)

         is_single = len(images) == 1
+        print(is_single)
+        print(len(images))
         embeddings = self._process_batches(
             data=images,
             processor_fn=self.processor.process_images,
 