from typing import Dict, Generator, List, Optional, Sequence, Union
import numpy as np
from numpy.typing import NDArray

import pytest
import chromadb
from chromadb.api.types import URI, DataLoader, Documents, IDs, Image, URIs
from chromadb.api import ServerAPI
from chromadb.test.ef.test_multimodal_ef import hashing_multimodal_ef


def encode_data(data: str) -> NDArray[np.uint8]:
    """Encode ``data`` to bytes and wrap the result in a NumPy array.

    The string is encoded as UTF-8 (the ``str.encode`` default) and the
    resulting ``bytes`` object is handed to ``np.array`` as a single value.

    NOTE(review): NumPy treats a ``bytes`` object as one scalar, so the
    result is a zero-dimensional byte-string array rather than a sequence
    of ``uint8`` values as the annotation suggests. Comparisons between two
    such arrays therefore yield a single boolean.
    """
    raw_bytes = data.encode("utf-8")
    return np.array(raw_bytes)


class DefaultDataLoader(DataLoader[List[Optional[Image]]]):
    """Data loader used by these tests: maps each URI to ``encode_data(uri)``."""

    def __call__(self, uris: Sequence[Optional[URI]]) -> List[Optional[Image]]:
        """Return one entry per input URI, in the same order.

        A ``None`` URI produces a ``None`` entry; any other URI is passed
        through ``encode_data``.
        """
        loaded: List[Optional[Image]] = []
        for uri in uris:
            # Records without a URI have nothing to load.
            loaded.append(encode_data(uri) if uri is not None else None)
        return loaded


def record_set_with_uris(n: int = 3) -> Dict[str, Union[IDs, Documents, URIs]]:
    """Build a record set of ``n`` entries with ids, documents and URIs.

    For each ``i`` in ``range(n)`` the record set contains the id ``"<i>"``,
    the document ``"document_<i>"`` and the URI ``"uri_<i>"``. The three
    lists are returned under the keys ``"ids"``, ``"documents"`` and
    ``"uris"``.
    """
    ids: List[str] = []
    documents: List[str] = []
    uris: List[str] = []
    for i in range(n):
        ids.append(f"{i}")
        documents.append(f"document_{i}")
        uris.append(f"uri_{i}")
    return {"ids": ids, "documents": documents, "uris": uris}


@pytest.fixture()
def collection_with_data_loader(
    api: ServerAPI,
) -> Generator[chromadb.Collection, None, None]:
    """Yield a collection configured with ``DefaultDataLoader``.

    The collection is created with the hashing multimodal embedding
    function and is deleted again once the test has finished with it.
    """
    loader = DefaultDataLoader()
    embedder = hashing_multimodal_ef()
    created = api.create_collection(
        name="collection_with_data_loader",
        data_loader=loader,
        embedding_function=embedder,
    )
    yield created
    # Teardown: remove the collection so it does not leak into other tests.
    api.delete_collection(created.name)


@pytest.fixture
def collection_without_data_loader(
    api: ServerAPI,
) -> Generator[chromadb.Collection, None, None]:
    """Yield a collection created without any data loader.

    Only the hashing multimodal embedding function is supplied. The
    collection is deleted again once the test has finished with it.
    """
    embedder = hashing_multimodal_ef()
    created = api.create_collection(
        name="collection_without_data_loader",
        embedding_function=embedder,
    )
    yield created
    # Teardown: remove the collection so it does not leak into other tests.
    api.delete_collection(created.name)


def test_without_data_loader(
    collection_without_data_loader: chromadb.Collection,
    n_examples: int = 3,
) -> None:
    """A collection without a data loader rejects URI-based operations.

    Both adding records that only carry URIs and requesting ``"data"`` from
    ``get`` are expected to raise ``ValueError``.
    """
    records = record_set_with_uris(n=n_examples)
    ids, uris = records["ids"], records["uris"]

    # Adding URIs requires loading their data to embed it, which is
    # impossible without a data loader.
    with pytest.raises(ValueError):
        collection_without_data_loader.add(ids=ids, uris=uris)

    # Likewise, "data" cannot be returned when nothing can load it.
    with pytest.raises(ValueError):
        collection_without_data_loader.get(include=["data"])


def test_without_uris(
    collection_with_data_loader: chromadb.Collection, n_examples: int = 3
) -> None:
    """Records added without URIs come back with ``None`` as their data."""
    records = record_set_with_uris(n=n_examples)

    # Deliberately leave the URIs out: only ids and documents are stored.
    collection_with_data_loader.add(
        documents=records["documents"],
        ids=records["ids"],
    )

    fetched = collection_with_data_loader.get(include=["data"])
    loaded = fetched["data"]

    # The "data" field itself must be present...
    assert loaded is not None
    # ...but every entry in it is empty because there was no URI to load.
    for entry in loaded:
        assert entry is None


def test_data_loader(
    collection_with_data_loader: chromadb.Collection, n_examples: int = 3
) -> None:
    """Data is loaded from URIs for both ``get`` and ``query``.

    Records are added with URIs only. The test then checks that

    * ``get(include=["data"])`` returns, for each record, the value
      ``encode_data`` produces for that record's URI, and
    * ``query(query_uris=...)`` returns, for every query, data that matches
      ``encode_data`` applied to the URI reported alongside it.
    """
    record_set = record_set_with_uris(n=n_examples)
    uris = record_set["uris"]

    collection_with_data_loader.add(
        ids=record_set["ids"],
        uris=uris,
    )

    # Get with "data"
    get_result = collection_with_data_loader.get(include=["data"])

    assert get_result["data"] is not None
    for i, data in enumerate(get_result["data"]):
        assert data is not None
        assert data == encode_data(uris[i])

    # Query by URI. Ask for as many results as there are records; the
    # count is the length of the URI list, not of the first URI string.
    query_result = collection_with_data_loader.query(
        query_uris=uris,
        n_results=len(uris),
        include=["data", "uris"],
    )

    assert query_result["data"] is not None
    assert query_result["uris"] is not None
    # Check the results of every query, not just the first one: each
    # returned data item must correspond to the URI returned with it.
    for result_data, result_uris in zip(query_result["data"], query_result["uris"]):
        for data, uri in zip(result_data, result_uris):
            assert data is not None
            assert data == encode_data(uri)