File size: 3,613 Bytes
287a0bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from typing import Dict, Generator, List, Optional, Sequence, Union
import numpy as np
from numpy.typing import NDArray

import pytest
import chromadb
from chromadb.api.types import URI, DataLoader, Documents, IDs, Image, URIs
from chromadb.api import ServerAPI
from chromadb.test.ef.test_multimodal_ef import hashing_multimodal_ef


def encode_data(data: str) -> NDArray[np.uint8]:
    """Return the UTF-8 bytes of ``data`` wrapped in a numpy array.

    NOTE(review): ``np.array`` applied to a ``bytes`` object produces a 0-d
    array with a fixed-width bytes dtype (``S<len>``), not a ``uint8`` vector
    as the return annotation suggests. The tests in this module compare these
    0-d arrays with ``==``, which relies on that scalar-like shape.
    """
    payload = data.encode()
    return np.array(payload)


class DefaultDataLoader(DataLoader[List[Optional[Image]]]):
    """Test data loader that "loads" a URI by encoding the URI string itself.

    ``None`` entries in the input are passed through unchanged as ``None``.
    """

    def __call__(self, uris: Sequence[Optional[URI]]) -> List[Optional[Image]]:
        loaded: List[Optional[Image]] = []
        for uri in uris:
            if uri is None:
                loaded.append(None)
            else:
                loaded.append(encode_data(uri))
        return loaded


def record_set_with_uris(n: int = 3) -> Dict[str, Union[IDs, Documents, URIs]]:
    """Build ``n`` parallel test records as a dict of column lists.

    Record ``i`` has id ``"i"``, document ``"document_i"`` and URI ``"uri_i"``.
    The keys are ``ids``, ``documents`` and ``uris``, in that order.
    """
    indices = range(n)
    ids = [str(i) for i in indices]
    documents = [f"document_{i}" for i in indices]
    uris = [f"uri_{i}" for i in indices]
    return {"ids": ids, "documents": documents, "uris": uris}


@pytest.fixture()
def collection_with_data_loader(
    api: ServerAPI,
) -> Generator[chromadb.Collection, None, None]:
    """Yield a collection using ``DefaultDataLoader`` and the hashing multimodal EF.

    The collection is deleted by name after the yield, during fixture teardown.
    """
    loader = DefaultDataLoader()
    embedder = hashing_multimodal_ef()
    created = api.create_collection(
        name="collection_with_data_loader",
        data_loader=loader,
        embedding_function=embedder,
    )
    yield created
    api.delete_collection(created.name)


@pytest.fixture
def collection_without_data_loader(
    api: ServerAPI,
) -> Generator[chromadb.Collection, None, None]:
    """Yield a collection with the hashing multimodal EF but no data loader.

    The collection is deleted by name after the yield, during fixture teardown.
    """
    embedder = hashing_multimodal_ef()
    created = api.create_collection(
        name="collection_without_data_loader",
        embedding_function=embedder,
    )
    yield created
    api.delete_collection(created.name)


def test_without_data_loader(
    collection_without_data_loader: chromadb.Collection,
    n_examples: int = 3,
) -> None:
    """URI-only writes and ``data`` reads must raise when no data loader is set."""
    collection = collection_without_data_loader
    records = record_set_with_uris(n=n_examples)

    # Nothing can load the URIs, so there is nothing to embed: add must fail.
    with pytest.raises(ValueError):
        collection.add(ids=records["ids"], uris=records["uris"])

    # For the same reason, "data" cannot be materialized on read.
    with pytest.raises(ValueError):
        collection.get(include=["data"])


def test_without_uris(
    collection_with_data_loader: chromadb.Collection, n_examples: int = 3
) -> None:
    """Records added without URIs must come back with ``None`` in ``data``.

    The length check makes sure the per-record loop cannot pass vacuously if
    ``get`` were to return an empty ``data`` list.
    """
    record_set = record_set_with_uris(n=n_examples)

    collection_with_data_loader.add(
        ids=record_set["ids"],
        documents=record_set["documents"],
    )

    get_result = collection_with_data_loader.get(include=["data"])

    assert get_result["data"] is not None
    # Expect exactly one data slot per inserted record.
    assert len(get_result["data"]) == len(record_set["ids"])
    for data in get_result["data"]:
        assert data is None


def test_data_loader(
    collection_with_data_loader: chromadb.Collection, n_examples: int = 3
) -> None:
    """Data loaded from URIs is returned by both ``get`` and ``query``.

    Every returned ``data`` entry must equal ``encode_data`` applied to the URI
    of the same record.
    """
    record_set = record_set_with_uris(n=n_examples)
    n_records = len(record_set["ids"])

    collection_with_data_loader.add(
        ids=record_set["ids"],
        uris=record_set["uris"],
    )

    # Get with "data". Records are matched by id rather than by position so
    # the check does not depend on the order in which ``get`` returns them.
    get_result = collection_with_data_loader.get(include=["data"])
    uri_by_id = dict(zip(record_set["ids"], record_set["uris"]))

    assert get_result["data"] is not None
    assert len(get_result["data"]) == n_records
    for record_id, data in zip(get_result["ids"], get_result["data"]):
        assert data is not None
        assert data == encode_data(uri_by_id[record_id])

    # Query by URI: one query per stored URI, each asking for every record.
    # n_results is the number of records, not the length of a URI string.
    query_result = collection_with_data_loader.query(
        query_uris=record_set["uris"],
        n_results=n_records,
        include=["data", "uris"],
    )

    assert query_result["data"] is not None
    assert query_result["uris"] is not None
    # One result list per query URI.
    assert len(query_result["data"]) == len(record_set["uris"])
    # Check every query's results, not only the first query's.
    for query_data, query_uris in zip(query_result["data"], query_result["uris"]):
        assert len(query_data) == n_records
        for data, uri in zip(query_data, query_uris):
            assert data is not None
            assert data == encode_data(uri)