feat: return a single tensor when a single image is given
modeling_jina_embeddings_v4.py (CHANGED)
@@ -417,7 +417,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         return_numpy: bool = False,
         truncate_dim: Optional[int] = None,
         prompt_name: Optional[str] = None,
-    ) -> List[torch.Tensor]:
+    ) -> Union[List[torch.Tensor], torch.Tensor]:
         """
         Encodes a list of texts into embeddings.

@@ -431,7 +431,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             prompt_name: Type of text being encoded ('query' or 'passage')

         Returns:
-            List of text embeddings as tensors or numpy arrays
+            List of text embeddings as tensors or numpy arrays when encoding multiple texts, or single text embedding as tensor when encoding a single text
         """
         prompt_name = prompt_name or "query"
         encode_kwargs = self._validate_encoding_params(
@@ -459,7 +459,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             **encode_kwargs,
         )

-        return embeddings
+        return embeddings if len(texts) > 1 else embeddings[0]

     def _load_images_if_needed(
         self, images: List[Union[str, Image.Image]]
@@ -484,9 +484,9 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         return_numpy: bool = False,
         truncate_dim: Optional[int] = None,
         max_pixels: Optional[int] = None,
-    ) -> List[torch.Tensor]:
+    ) -> Union[List[torch.Tensor], torch.Tensor]:
         """
-        Encodes a list of images into
+        Encodes a list of images or a single image into embedding(s).

         Args:
             images: image(s) to encode, can be PIL Image(s), URL(s), or local file path(s)
@@ -497,7 +497,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             max_pixels: Maximum number of pixels to process per image

         Returns:
-            List of image embeddings as tensors or numpy arrays
+            List of image embeddings as tensors or numpy arrays when encoding multiple images, or single image embedding as tensor when encoding a single image
         """
         if max_pixels:
             default_max_pixels = self.processor.image_processor.max_pixels
@@ -525,7 +525,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         if max_pixels:
             self.processor.image_processor.max_pixels = default_max_pixels

-        return embeddings
+        return embeddings if len(images) > 1 else embeddings[0]

     @classmethod
     def from_pretrained(
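The practical effect: when a single text or image is passed, these methods now return the bare embedding tensor (embeddings[0]) instead of a one-element list. A minimal sketch of the new calling convention; the method names encode_text/encode_image, the model id, and the loading flags are assumptions here, since the hunk headers do not show the full signatures:

import torch
from transformers import AutoModel

# Assumed loading path for the model; adjust to your environment.
model = AutoModel.from_pretrained(
    "jinaai/jina-embeddings-v4", trust_remote_code=True
)

# Multiple inputs: still a list of per-input tensors (return_numpy=False).
batch = model.encode_text(texts=["first query", "second query"])
assert isinstance(batch, list) and isinstance(batch[0], torch.Tensor)

# Single input: now the bare tensor, not a one-element list.
single = model.encode_text(texts=["only query"])
assert isinstance(single, torch.Tensor)

# Same convention for images ("photo.jpg" is a hypothetical local path).
one_image = model.encode_image(images=["photo.jpg"])
assert isinstance(one_image, torch.Tensor)

Two caveats follow directly from the diff: callers that previously unwrapped single results with a trailing [0] must drop the subscript (it would now index into the tensor itself), and an empty input list also falls into the else branch, where embeddings[0] raises an IndexError.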