chroma / chromadb /test /ef /test_multimodal_ef.py
badalsahani's picture
feat: chroma initial deploy
287a0bc
raw
history blame contribute delete
5.59 kB
from typing import Generator, cast
import numpy as np
import pytest
import chromadb
from chromadb.api.types import (
Embeddable,
EmbeddingFunction,
Embeddings,
Image,
Document,
)
from chromadb.test.property.strategies import hashing_embedding_function
from chromadb.test.property.invariants import _exact_distances
# A 'standard' multimodal embedding function, which converts inputs to strings
# then hashes them to a fixed dimension.
class hashing_multimodal_ef(EmbeddingFunction[Embeddable]):
    """Embedding function for tests that accepts any embeddable input.

    Every input item is converted with ``str()`` and passed to a
    ``hashing_embedding_function`` built with ``dim=10``; the resulting
    rows are then scaled to unit L2 norm.
    """

    def __init__(self) -> None:
        # Use np.float64 explicitly: the np.float_ alias for the same type was
        # removed in NumPy 2.0 and raises AttributeError there.
        self._hef = hashing_embedding_function(dim=10, dtype=np.float64)

    def __call__(self, input: Embeddable) -> Embeddings:
        """Embed ``input`` and return one unit-length vector per item.

        Args:
            input: Iterable of items; each one is stringified before hashing.

        Returns:
            The normalized embeddings as nested Python lists.
        """
        to_texts = [str(i) for i in input]
        embeddings = np.array(self._hef(to_texts))
        # Normalize the embeddings
        # This is so we can generate random unit vectors and have them be close to the embeddings
        embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)
        return cast(Embeddings, embeddings.tolist())
def random_image() -> Image:
    """Return a random 10x10x3 ``int32`` array drawn from numpy's global RNG."""
    shape = (10, 10, 3)
    # randint's upper bound is exclusive, so generated values lie in [0, 254].
    return np.random.randint(low=0, high=255, size=shape, dtype=np.int32)
def random_document() -> Document:
    """Return the ``str()`` rendering of a freshly generated random image."""
    image = random_image()
    return str(image)
@pytest.fixture
def multimodal_collection(
    default_ef: EmbeddingFunction[Embeddable] = hashing_multimodal_ef(),
) -> Generator[chromadb.Collection, None, None]:
    """Yield a collection named "multimodal_collection" that embeds with ``default_ef``.

    The default embedding function instance is created once, when this
    function is defined. After the consumer is done with the collection,
    the client's system cache is cleared.
    """
    chroma_client = chromadb.Client()
    yield chroma_client.create_collection(
        embedding_function=default_ef,
        name="multimodal_collection",
    )
    # Teardown: executed when the generator is resumed after the yield.
    chroma_client.clear_system_cache()
# Test adding and querying of a multimodal collection consisting of images and documents
def test_multimodal(
    multimodal_collection: chromadb.Collection,
    default_ef: EmbeddingFunction[Embeddable] = hashing_multimodal_ef(),
    n_examples: int = 10,
    n_query_results: int = 3,
) -> None:
    """Add images and documents to one collection, then check get() and query().

    Verifies that:
      * adding a document and an image for the same record raises ValueError;
      * get() returns every record, with ``None`` as the document of image records;
      * querying with both images and texts raises ValueError;
      * image and text queries return the ids of the exact nearest neighbors
        computed with ``_exact_distances`` over all stored embeddings.

    Args:
        multimodal_collection: Collection under test.
        default_ef: Embedding function used to compute the expected embeddings.
        n_examples: Number of images and, separately, of documents to add.
        n_query_results: Number of neighbors requested per query.
    """
    # Fix numpy's random seed for reproducibility
    random_state = np.random.get_state()
    np.random.seed(0)
    try:
        image_ids = [str(i) for i in range(n_examples)]
        images = [random_image() for _ in range(n_examples)]
        image_embeddings = default_ef(images)

        document_ids = [str(i) for i in range(n_examples, 2 * n_examples)]
        documents = [random_document() for _ in range(n_examples)]
        document_embeddings = default_ef(documents)

        # Trying to add a document and an image at the same time should fail
        with pytest.raises(
            ValueError, match="You can only provide documents or images, not both."
        ):
            multimodal_collection.add(
                ids=image_ids[0], documents=documents[0], images=images[0]
            )

        # Add some documents
        multimodal_collection.add(ids=document_ids, documents=documents)
        # Add some images
        multimodal_collection.add(ids=image_ids, images=images)

        # get() should return all the documents and images
        # ids corresponding to images should not have documents
        get_result = multimodal_collection.get(include=["documents"])
        assert len(get_result["ids"]) == len(document_ids) + len(image_ids)
        assert get_result["documents"] is not None
        for i, record_id in enumerate(get_result["ids"]):
            assert record_id in document_ids or record_id in image_ids
            if record_id in document_ids:
                assert (
                    get_result["documents"][i]
                    == documents[document_ids.index(record_id)]
                )
            if record_id in image_ids:
                assert get_result["documents"][i] is None

        # Generate a random query image
        query_image = random_image()
        query_image_embedding = default_ef([query_image])
        image_neighbor_indices, _ = _exact_distances(
            query_image_embedding, image_embeddings + document_embeddings
        )
        # Get the ids of the nearest neighbors.
        # Indices below n_examples refer to images, the remainder to documents,
        # matching the concatenation order passed to _exact_distances.
        nearest_image_neighbor_ids = [
            image_ids[i] if i < n_examples else document_ids[i % n_examples]
            for i in image_neighbor_indices[0][:n_query_results]
        ]

        # Generate a random query document
        query_document = random_document()
        query_document_embedding = default_ef([query_document])
        document_neighbor_indices, _ = _exact_distances(
            query_document_embedding, image_embeddings + document_embeddings
        )
        nearest_document_neighbor_ids = [
            image_ids[i] if i < n_examples else document_ids[i % n_examples]
            for i in document_neighbor_indices[0][:n_query_results]
        ]

        # Querying with both images and documents should fail
        with pytest.raises(ValueError):
            multimodal_collection.query(
                query_images=[query_image], query_texts=[query_document]
            )

        # Query with images
        query_result = multimodal_collection.query(
            query_images=[query_image],
            n_results=n_query_results,
            include=["documents"],
        )
        assert query_result["ids"][0] == nearest_image_neighbor_ids

        # Query with documents
        query_result = multimodal_collection.query(
            query_texts=[query_document],
            n_results=n_query_results,
            include=["documents"],
        )
        assert query_result["ids"][0] == nearest_document_neighbor_ids
    finally:
        # Restore the global RNG state even when an assertion above fails, so
        # the fixed seed does not leak into whatever runs afterwards.
        np.random.set_state(random_state)
@pytest.mark.xfail
def test_multimodal_update_with_image(
    multimodal_collection: chromadb.Collection,
) -> None:
    """Updating an entry that has a document with an image should remove the document.

    The test is marked ``xfail``, i.e. it is currently expected to fail.
    """
    record_id = "0"
    original_document = random_document()
    replacement_image = random_image()

    # Store the record as a document first, then overwrite it with an image.
    multimodal_collection.add(ids=record_id, documents=original_document)
    multimodal_collection.update(ids=record_id, images=replacement_image)

    fetched = multimodal_collection.get(ids=record_id, include=["documents"])
    fetched_documents = fetched["documents"]
    assert fetched_documents is not None
    assert fetched_documents[0] is None