import torch
from PIL import Image
import numpy as np
from typing import Generator, Tuple, List, Union, Dict
from pathlib import Path
import base64
from io import BytesIO
import re
import io
import matplotlib.cm as cm
from colpali_engine.models import ColPali, ColPaliProcessor
from colpali_engine.utils.torch_utils import get_torch_device
from vidore_benchmark.interpretability.torch_utils import (
    normalize_similarity_map_per_query_token,
)
from functools import lru_cache
import logging


class SimMapGenerator:
    """
    Generates similarity maps based on query embeddings and image patches using the ColPali model.
    """

    colormap = cm.get_cmap("viridis")  # Preload colormap for efficiency
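    # The colormap maps normalized values in [0, 1] to RGBA floats in [0, 1];
    # _blend_image below scales these to uint8 when rendering the heatmap.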

    def __init__(
        self,
        logger: logging.Logger,
        model_name: str = "vidore/colpali-v1.2",
        n_patch: int = 32,
    ):
        """
        Initializes the SimMapGenerator class with a specified model and patch dimension.

        Args:
            logger (logging.Logger): Logger used for progress and debug output.
            model_name (str): The model name for loading the ColPali model.
            n_patch (int): The number of patches per dimension.
        """
        self.model_name = model_name
        self.n_patch = n_patch
        self.device = get_torch_device("auto")
        self.logger = logger
        self.logger.info(f"Using device: {self.device}")
        self.model, self.processor = self.load_model()

    def load_model(self) -> Tuple[ColPali, ColPaliProcessor]:
        """
        Loads the ColPali model and processor.

        Returns:
            Tuple[ColPali, ColPaliProcessor]: Loaded model and processor.
        """
        model = ColPali.from_pretrained(
            self.model_name,
            # Note: the embeddings created at feed time were float32 (then
            # binarized), yet bfloat16 seems to produce the most similar
            # results both locally (MPS) and on HF (CUDA).
            torch_dtype=torch.bfloat16,
            device_map=self.device,
        ).eval()
        processor = ColPaliProcessor.from_pretrained(self.model_name)
        return model, processor
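
    # Note: with the default n_patch=32, an image maps to a 32 x 32 = 1024
    # patch grid, which is why 1024 is used as the fallback for
    # image_seq_length in _prepare_similarity_map_tensor below.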

    def gen_similarity_maps(
        self,
        query: str,
        query_embs: torch.Tensor,
        token_idx_map: Dict[int, str],
        images: List[Union[Path, str]],
        vespa_sim_maps: List[Dict],
    ) -> Generator[Tuple[int, str, int, str], None, None]:
        """
        Generates similarity maps for the provided images and query, and yields base64-encoded blended images.

        Args:
            query (str): The query string.
            query_embs (torch.Tensor): Query embeddings tensor.
            token_idx_map (Dict[int, str]): Mapping from token indices to tokens.
            images (List[Union[Path, str]]): List of image paths or base64-encoded strings.
            vespa_sim_maps (List[Dict]): List of Vespa similarity maps.

        Yields:
            Tuple[int, str, int, str]: A tuple containing the image index, the token,
                the token index, and the base64-encoded blended image.
        """
        processed_images, original_images, original_sizes = [], [], []
        for img in images:
            img_pil = self._load_image(img)
            original_images.append(img_pil.copy())
            original_sizes.append(img_pil.size)
            processed_images.append(img_pil)

        vespa_sim_map_tensor = self._prepare_similarity_map_tensor(
            query_embs, vespa_sim_maps
        )
        similarity_map_normalized = normalize_similarity_map_per_query_token(
            vespa_sim_map_tensor
        )

        for idx, img in enumerate(original_images):
            for token_idx, token in token_idx_map.items():
                if self.should_filter_token(token):
                    continue
                sim_map = similarity_map_normalized[idx, token_idx, :, :]
                blended_img_base64 = self._blend_image(
                    img, sim_map, original_sizes[idx]
                )
                yield idx, token, token_idx, blended_img_base64
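
    # Illustrative usage (variable names are placeholders):
    #
    #     for idx, token, token_idx, b64_img in gen.gen_similarity_maps(
    #         query, q_emb, idx_to_token, images, vespa_sim_maps
    #     ):
    #         ...  # e.g. embed b64_img in an <img src="data:image/png;base64,..."> tag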

    def _load_image(self, img: Union[Path, str]) -> Image.Image:
        """
        Loads an image from a file path or a base64-encoded string.

        Args:
            img (Union[Path, str]): The image to load.

        Returns:
            Image.Image: The loaded PIL image.
        """
        try:
            if isinstance(img, Path):
                return Image.open(img).convert("RGB")
            elif isinstance(img, str):
                return Image.open(BytesIO(base64.b64decode(img))).convert("RGB")
            else:
                raise TypeError(f"Unsupported image type: {type(img)}")
        except Exception as e:
            raise ValueError(f"Failed to load image: {e}")
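
    # Accepts either a filesystem path or a raw base64 string (no data-URI
    # prefix). Illustrative calls:
    #
    #     gen._load_image(Path("page_0.png"))
    #     gen._load_image(base64.b64encode(png_bytes).decode("utf-8"))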

    def _prepare_similarity_map_tensor(
        self, query_embs: torch.Tensor, vespa_sim_maps: List[Dict]
    ) -> torch.Tensor:
        """
        Prepares a similarity map tensor from Vespa similarity maps.

        Args:
            query_embs (torch.Tensor): Query embeddings tensor.
            vespa_sim_maps (List[Dict]): List of Vespa similarity maps.

        Returns:
            torch.Tensor: The prepared similarity map tensor.
        """
        vespa_sim_map_tensor = torch.zeros(
            (len(vespa_sim_maps), query_embs.size(1), self.n_patch, self.n_patch)
        )
        # Number of image patch tokens; constant for a given processor, so it
        # is looked up once rather than inside the loop.
        image_seq_length = getattr(self.processor, "image_seq_length", 1024)
        for idx, vespa_sim_map in enumerate(vespa_sim_maps):
            for cell in vespa_sim_map["quantized"]["cells"]:
                patch = int(cell["address"]["patch"])
                query_token = int(cell["address"]["querytoken"])
                value = cell["value"]
                if patch >= image_seq_length:
                    continue
                vespa_sim_map_tensor[
                    idx,
                    query_token,
                    patch // self.n_patch,  # row in the patch grid
                    patch % self.n_patch,  # column in the patch grid
                ] = value
        return vespa_sim_map_tensor
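
    # For reference, each entry of vespa_sim_maps is expected to look roughly
    # like the schematic below (values are illustrative; patch and querytoken
    # addresses arrive as strings and are converted with int() above):
    #
    #     {
    #         "quantized": {
    #             "cells": [
    #                 {"address": {"patch": "0", "querytoken": "3"}, "value": 7.0},
    #                 ...
    #             ]
    #         }
    #     }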

    def _blend_image(
        self, img: Image.Image, sim_map: torch.Tensor, original_size: Tuple[int, int]
    ) -> str:
        """
        Builds a colormapped heatmap from a similarity map and encodes it to base64.

        Note: the original image is not composited here; the returned heatmap
        PNG is intended to be overlaid on the image downstream (e.g. in the UI).

        Args:
            img (Image.Image): The original image (kept for interface symmetry; unused here).
            sim_map (torch.Tensor): The similarity map tensor.
            original_size (Tuple[int, int]): The original size of the image (width, height).

        Returns:
            str: The base64-encoded heatmap PNG.
        """
        SCALING_FACTOR = 8
        sim_map_resolution = (
            max(32, int(original_size[0] / SCALING_FACTOR)),
            max(32, int(original_size[1] / SCALING_FACTOR)),
        )

        sim_map_np = sim_map.cpu().float().numpy()
        sim_map_img = Image.fromarray(sim_map_np).resize(
            sim_map_resolution, resample=Image.BICUBIC
        )
        sim_map_resized_np = np.array(sim_map_img, dtype=np.float32)
        sim_map_normalized = self._normalize_sim_map(sim_map_resized_np)

        heatmap = self.colormap(sim_map_normalized)
        heatmap_img = Image.fromarray((heatmap * 255).astype(np.uint8)).convert("RGBA")

        buffer = io.BytesIO()
        heatmap_img.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")
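
    # If server-side compositing were desired instead, a minimal sketch
    # (hypothetical, not used by this app) could be:
    #
    #     base = img.convert("RGBA").resize(heatmap_img.size)
    #     blended = Image.blend(base, heatmap_img, alpha=0.5)
    #
    # Returning the heatmap alone leaves opacity and stacking to the client.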

    @staticmethod
    def _normalize_sim_map(sim_map: np.ndarray) -> np.ndarray:
        """
        Normalizes a similarity map to the range [0, 1].

        Args:
            sim_map (np.ndarray): The similarity map.

        Returns:
            np.ndarray: The normalized similarity map.
        """
        sim_map_min, sim_map_max = sim_map.min(), sim_map.max()
        if sim_map_max - sim_map_min > 1e-6:
            return (sim_map - sim_map_min) / (sim_map_max - sim_map_min)
        return np.zeros_like(sim_map)

    @staticmethod
    def should_filter_token(token: str) -> bool:
        """
        Determines if a token should be filtered out based on predefined patterns.

        The function filters out tokens that:

        - Start with '<' (e.g., '<bos>', '<pad>')
        - Consist entirely of whitespace
        - Are purely punctuation (excluding tokens that contain digits or start with '▁')
        - Start with an underscore '_'
        - Exactly match the word 'Question'
        - Are exactly the single character '▁'

        Output of test:
            Token: '2'         | False
            Token: '0'         | False
            Token: '2'         | False
            Token: '3'         | False
            Token: '▁2'        | False
            Token: '▁hi'       | False
            Token: 'norwegian' | False
            Token: 'unlisted'  | False
            Token: '<bos>'     | True
            Token: 'Question'  | True
            Token: ':'         | True
            Token: '<pad>'     | True
            Token: '\n'        | True
            Token: '▁'         | True
            Token: '?'         | True
            Token: ')'         | True
            Token: '%'         | True
            Token: '/)'        | True

        Args:
            token (str): The token to check.

        Returns:
            bool: True if the token should be filtered out, False otherwise.
        """
        pattern = re.compile(
            r"^<.*$|^\s+$|^(?!.*\d)(?!▁)[^\w\s]+$|^_.*$|^Question$|^▁$"
        )
        return bool(pattern.match(token))
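
    # Note: '▁' is SentencePiece's word-boundary marker, so '▁hi' is a real
    # word piece and is kept, while a bare '▁' carries no content and is
    # filtered out.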

    def get_query_embeddings_and_token_map(
        self, query: str
    ) -> Tuple[torch.Tensor, dict]:
        """
        Retrieves query embeddings and a token index map.

        Args:
            query (str): The query string.

        Returns:
            Tuple[torch.Tensor, dict]: Query embeddings and token index map.
        """
        inputs = self.processor.process_queries([query]).to(self.model.device)
        with torch.no_grad():
            q_emb = self.model(**inputs).to("cpu")[0]
        query_tokens = self.processor.tokenizer.tokenize(
            self.processor.decode(inputs.input_ids[0])
        )
        idx_to_token = {idx: token for idx, token in enumerate(query_tokens)}
        return q_emb, idx_to_token
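

# Minimal usage sketch (illustrative): exercises only the query side, since
# generating blended maps also requires real Vespa similarity-map payloads.
# Running this downloads the ColPali weights; the query string is arbitrary.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    sim_map_gen = SimMapGenerator(logger=logging.getLogger(__name__))
    q_emb, idx_to_token = sim_map_gen.get_query_embeddings_and_token_map(
        "percentage of unlisted companies in norway?"
    )
    print(f"Query embeddings shape: {tuple(q_emb.shape)}")
    for idx, token in idx_to_token.items():
        kept = not SimMapGenerator.should_filter_token(token)
        print(f"{idx:>3} {token!r} kept={kept}")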