Spaces:
Sleeping
Sleeping
from typing import Generator, cast | |
import numpy as np | |
import pytest | |
import chromadb | |
from chromadb.api.types import ( | |
Embeddable, | |
EmbeddingFunction, | |
Embeddings, | |
Image, | |
Document, | |
) | |
from chromadb.test.property.strategies import hashing_embedding_function | |
from chromadb.test.property.invariants import _exact_distances | |
# A 'standard' multimodal embedding function, which converts inputs to strings
# then hashes them to a fixed dimension.
class hashing_multimodal_ef(EmbeddingFunction[Embeddable]):
    """Deterministic multimodal embedding function used by the tests below.

    Each input item is converted with ``str()`` and handed to a
    ``hashing_embedding_function`` configured with ``dim=10``; every resulting
    row is then scaled to unit L2 norm.
    """

    def __init__(self) -> None:
        # np.float64 is spelled out explicitly: the np.float_ alias was
        # removed in NumPy 2.0, while np.float64 is the same dtype on 1.x.
        self._hef = hashing_embedding_function(dim=10, dtype=np.float64)

    def __call__(self, input: Embeddable) -> Embeddings:
        """Embed ``input`` and return one unit-length vector per item.

        Args:
            input: Iterable of items to embed; each is stringified first.

        Returns:
            The normalized embeddings converted to nested Python lists.
        """
        to_texts = [str(i) for i in input]
        embeddings = np.array(self._hef(to_texts))
        # Normalize the embeddings
        # This is so we can generate random unit vectors and have them be close to the embeddings
        embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)
        return cast(Embeddings, embeddings.tolist())
def random_image() -> Image:
    """Return a random 10x10x3 ``int32`` array drawn from NumPy's global RNG.

    The upper bound passed to ``randint`` is exclusive, so generated values
    lie in the range 0..254.
    """
    shape = (10, 10, 3)
    return np.random.randint(low=0, high=255, size=shape, dtype=np.int32)
def random_document() -> Document:
    """Return a random document: the ``str()`` form of a fresh random image."""
    image = random_image()
    return str(image)
@pytest.fixture
def multimodal_collection(
    default_ef: EmbeddingFunction[Embeddable] = hashing_multimodal_ef(),
) -> Generator[chromadb.Collection, None, None]:
    """Pytest fixture yielding a collection backed by a multimodal embedding function.

    The tests in this module request ``multimodal_collection`` by parameter
    name, so this generator must be registered as a fixture for pytest to
    resolve it.

    Args:
        default_ef: Embedding function attached to the created collection.

    Yields:
        A collection named ``"multimodal_collection"`` created on a new
        ``chromadb.Client()``.
    """
    client = chromadb.Client()
    collection = client.create_collection(
        name="multimodal_collection", embedding_function=default_ef
    )
    yield collection
    # Teardown: runs once the requesting test has finished.
    client.clear_system_cache()
# Test adding and querying of a multimodal collection consisting of images and documents
def test_multimodal(
    multimodal_collection: chromadb.Collection,
    default_ef: EmbeddingFunction[Embeddable] = hashing_multimodal_ef(),
    n_examples: int = 10,
    n_query_results: int = 3,
) -> None:
    """Add images and documents to one collection and check get()/query() results.

    Args:
        multimodal_collection: Collection under test.
        default_ef: Embedding function used to compute the expected embeddings
            locally.
        n_examples: Number of images and, separately, number of documents added.
        n_query_results: Number of nearest neighbors requested per query.
    """
    # Fix numpy's random seed for reproducibility
    random_state = np.random.get_state()
    np.random.seed(0)
    try:
        image_ids = [str(i) for i in range(n_examples)]
        images = [random_image() for _ in range(n_examples)]
        image_embeddings = default_ef(images)

        document_ids = [str(i) for i in range(n_examples, 2 * n_examples)]
        documents = [random_document() for _ in range(n_examples)]
        document_embeddings = default_ef(documents)

        # Trying to add a document and an image at the same time should fail
        with pytest.raises(
            ValueError, match="You can only provide documents or images, not both."
        ):
            multimodal_collection.add(
                ids=image_ids[0], documents=documents[0], images=images[0]
            )

        # Add some documents
        multimodal_collection.add(ids=document_ids, documents=documents)
        # Add some images
        multimodal_collection.add(ids=image_ids, images=images)

        # get() should return all the documents and images
        # ids corresponding to images should not have documents
        get_result = multimodal_collection.get(include=["documents"])
        assert len(get_result["ids"]) == len(document_ids) + len(image_ids)
        returned_documents = get_result["documents"]
        assert returned_documents is not None
        documents_by_id = dict(zip(document_ids, documents))
        known_image_ids = set(image_ids)
        for i, entry_id in enumerate(get_result["ids"]):
            assert entry_id in documents_by_id or entry_id in known_image_ids
            if entry_id in documents_by_id:
                assert returned_documents[i] == documents_by_id[entry_id]
            if entry_id in known_image_ids:
                assert returned_documents[i] is None

        # Neighbor indices refer to positions in the images-then-documents
        # ordering, so a combined id list in the same order maps them back.
        all_ids = image_ids + document_ids
        all_embeddings = image_embeddings + document_embeddings

        # Generate a random query image
        query_image = random_image()
        query_image_embedding = default_ef([query_image])
        image_neighbor_indices, _ = _exact_distances(
            query_image_embedding, all_embeddings
        )
        # Get the ids of the nearest neighbors
        nearest_image_neighbor_ids = [
            all_ids[i] for i in image_neighbor_indices[0][:n_query_results]
        ]

        # Generate a random query document
        query_document = random_document()
        query_document_embedding = default_ef([query_document])
        document_neighbor_indices, _ = _exact_distances(
            query_document_embedding, all_embeddings
        )
        nearest_document_neighbor_ids = [
            all_ids[i] for i in document_neighbor_indices[0][:n_query_results]
        ]

        # Querying with both images and documents should fail
        with pytest.raises(ValueError):
            multimodal_collection.query(
                query_images=[query_image], query_texts=[query_document]
            )

        # Query with images
        query_result = multimodal_collection.query(
            query_images=[query_image],
            n_results=n_query_results,
            include=["documents"],
        )
        assert query_result["ids"][0] == nearest_image_neighbor_ids

        # Query with documents
        query_result = multimodal_collection.query(
            query_texts=[query_document],
            n_results=n_query_results,
            include=["documents"],
        )
        assert query_result["ids"][0] == nearest_document_neighbor_ids
    finally:
        # Restore the global RNG state even when an assertion above fails, so
        # the fixed seed does not leak into tests that run afterwards.
        np.random.set_state(random_state)
def test_multimodal_update_with_image(
    multimodal_collection: chromadb.Collection,
) -> None:
    """Updating an entry that holds a document with an image should remove the document."""
    entry_id = "0"
    # The document is generated before the image, matching the order in which
    # the global RNG is consumed.
    document = random_document()
    image = random_image()

    multimodal_collection.add(ids=entry_id, documents=document)
    multimodal_collection.update(ids=entry_id, images=image)

    stored = multimodal_collection.get(ids=entry_id, include=["documents"])
    assert stored["documents"] is not None
    assert stored["documents"][0] is None