Spaces:
Sleeping
Sleeping
| from typing import Any, Callable, Dict, List, Optional, Sequence, Set | |
| import numpy as np | |
| import numpy.typing as npt | |
| from chromadb.types import ( | |
| EmbeddingRecord, | |
| VectorEmbeddingRecord, | |
| VectorQuery, | |
| VectorQueryResult, | |
| ) | |
| from chromadb.utils import distance_functions | |
| import logging | |
| logger = logging.getLogger(__name__) | |
class BruteForceIndex:
    """A lightweight, numpy based brute force index that is used for batches
    that have not been indexed into hnsw yet.

    Embeddings live in a preallocated ``(size, dimensionality)`` array; rows
    are handed out from ``free_indices`` and returned to it on delete.

    It is not thread safe and callers should ensure that only one thread is
    accessing it at a time.
    """

    # Record ID -> row of ``vectors`` that holds its embedding.
    id_to_index: Dict[str, int]
    # Inverse of ``id_to_index``; only occupied rows have an entry.
    index_to_id: Dict[int, str]
    # Record ID -> ``seq_id`` of the last upsert for that ID.
    id_to_seq_id: Dict[str, int]
    # IDs removed via ``delete`` and not re-added since.
    deleted_ids: Set[str]
    # Rows of ``vectors`` that are currently unoccupied.
    free_indices: List[int]
    # Maximum number of records the index can hold.
    size: int
    dimensionality: int
    distance_fn: Callable[[npt.NDArray[Any], npt.NDArray[Any]], float]
    vectors: npt.NDArray[Any]

    def __init__(self, size: int, dimensionality: int, space: str = "l2"):
        """Create an empty index.

        Args:
            size: Capacity of the index (number of rows preallocated).
            dimensionality: Length of every stored embedding.
            space: Distance function to use: ``"l2"``, ``"ip"`` or ``"cosine"``.

        Raises:
            Exception: If ``space`` is not one of the supported names.
        """
        if space == "l2":
            self.distance_fn = distance_functions.l2
        elif space == "ip":
            self.distance_fn = distance_functions.ip
        elif space == "cosine":
            self.distance_fn = distance_functions.cosine
        else:
            raise Exception(f"Unknown distance function: {space}")

        self.id_to_index = {}
        self.index_to_id = {}
        self.id_to_seq_id = {}
        self.deleted_ids = set()
        self.free_indices = list(range(size))
        self.size = size
        self.dimensionality = dimensionality
        self.vectors = np.zeros((size, dimensionality))

    def __len__(self) -> int:
        """Return the number of records currently stored."""
        return len(self.id_to_index)

    def clear(self) -> None:
        """Remove every record and mark all rows as free again."""
        self.id_to_index = {}
        self.index_to_id = {}
        self.id_to_seq_id = {}
        self.deleted_ids.clear()
        self.free_indices = list(range(self.size))
        self.vectors.fill(0)

    def upsert(self, records: List[EmbeddingRecord]) -> None:
        """Insert new records and overwrite the embeddings of existing ones.

        Raises:
            Exception: If the IDs that are not stored yet do not fit into the
                remaining capacity. Nothing is modified in that case.
        """
        # Only IDs that are not stored yet consume a free row; updates and
        # IDs repeated within the batch reuse an existing row, so they must
        # not count against the capacity.
        new_ids = {
            record["id"] for record in records if record["id"] not in self.id_to_index
        }
        if len(new_ids) + len(self) > self.size:
            raise Exception(
                "Index with capacity {} and {} current entries cannot add {} records".format(
                    self.size, len(self), len(records)
                )
            )

        for record in records:
            record_id = record["id"]
            vector = record["embedding"]
            self.id_to_seq_id[record_id] = record["seq_id"]
            # A re-added ID is no longer considered deleted.
            self.deleted_ids.discard(record_id)

            # TODO: It may be faster to use multi-index selection on the vectors array
            if record_id in self.id_to_index:
                # Update
                self.vectors[self.id_to_index[record_id]] = vector
            else:
                # Add
                next_index = self.free_indices.pop()
                self.id_to_index[record_id] = next_index
                self.index_to_id[next_index] = record_id
                self.vectors[next_index] = vector

    def delete(self, records: List[EmbeddingRecord]) -> None:
        """Remove the given records; unknown IDs are logged and skipped."""
        for record in records:
            record_id = record["id"]
            if record_id in self.id_to_index:
                index = self.id_to_index[record_id]
                self.deleted_ids.add(record_id)
                del self.id_to_index[record_id]
                del self.index_to_id[index]
                del self.id_to_seq_id[record_id]
                # ``np.nan`` rather than ``np.NaN``: the capitalised alias was
                # removed in NumPy 2.0.
                self.vectors[index].fill(np.nan)
                self.free_indices.append(index)
            else:
                logger.warning("Delete of nonexisting embedding ID: %s", record_id)

    def has_id(self, id: str) -> bool:
        """Returns whether the index contains the given ID"""
        return id in self.id_to_index and id not in self.deleted_ids

    def get_vectors(
        self, ids: Optional[Sequence[str]] = None
    ) -> Sequence[VectorEmbeddingRecord]:
        """Return the stored records for ``ids``.

        When ``ids`` is ``None`` or empty, every stored record is returned.

        Raises:
            KeyError: If one of the requested IDs is not stored.
        """
        target_ids = ids or self.id_to_index.keys()

        return [
            VectorEmbeddingRecord(
                id=target_id,
                embedding=self.vectors[self.id_to_index[target_id]].tolist(),
                seq_id=self.id_to_seq_id[target_id],
            )
            for target_id in target_ids
        ]

    def query(self, query: VectorQuery) -> Sequence[Sequence[VectorQueryResult]]:
        """Rank the stored records against every vector in ``query["vectors"]``.

        Returns one list per query vector, ordered by ascending distance. Only
        ``query["vectors"]`` and ``query["allowed_ids"]`` are read; every
        matching record is returned, so this method does no truncation.
        """
        np_query = np.array(query["vectors"])
        allowed_ids = (
            None if query["allowed_ids"] is None else set(query["allowed_ids"])
        )

        # Distance from each query vector to each row of the backing array,
        # including rows that are free or deleted; those are filtered below.
        distances = np.apply_along_axis(
            lambda query_vector: np.apply_along_axis(
                self.distance_fn, 1, self.vectors, query_vector
            ),
            1,
            np_query,
        )

        indices = np.argsort(distances).tolist()
        # Filter out deleted labels
        filtered_results = []
        for i, index_list in enumerate(indices):
            curr_results = []
            for j in index_list:
                # If the index is in the index_to_id map, then it has been added
                if j not in self.index_to_id:
                    continue
                result_id = self.index_to_id[j]
                if result_id in self.deleted_ids:
                    continue
                if allowed_ids is not None and result_id not in allowed_ids:
                    continue
                curr_results.append(
                    VectorQueryResult(
                        id=result_id,
                        distance=distances[i][j].item(),
                        seq_id=self.id_to_seq_id[result_id],
                        embedding=self.vectors[j].tolist(),
                    )
                )
            filtered_results.append(curr_results)
        return filtered_results