# NOTE: removed extraction artifact (stray page text "Spaces: Sleeping Sleeping"
# was not valid Python and did not belong to this module).
from chromadb.api import ServerAPI | |
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System | |
from chromadb.db.system import SysDB | |
from chromadb.segment import SegmentManager, MetadataReader, VectorReader | |
from chromadb.telemetry.opentelemetry import ( | |
add_attributes_to_current_span, | |
OpenTelemetryClient, | |
OpenTelemetryGranularity, | |
trace_method, | |
) | |
from chromadb.telemetry.product import ProductTelemetryClient | |
from chromadb.ingest import Producer | |
from chromadb.api.models.Collection import Collection | |
from chromadb import __version__ | |
from chromadb.errors import InvalidDimensionException, InvalidCollectionException | |
import chromadb.utils.embedding_functions as ef | |
from chromadb.api.types import ( | |
URI, | |
CollectionMetadata, | |
Embeddable, | |
Document, | |
EmbeddingFunction, | |
DataLoader, | |
IDs, | |
Embeddings, | |
Embedding, | |
Loadable, | |
Metadatas, | |
Documents, | |
URIs, | |
Where, | |
WhereDocument, | |
Include, | |
GetResult, | |
QueryResult, | |
validate_metadata, | |
validate_update_metadata, | |
validate_where, | |
validate_where_document, | |
validate_batch, | |
) | |
from chromadb.telemetry.product.events import ( | |
CollectionAddEvent, | |
CollectionDeleteEvent, | |
CollectionGetEvent, | |
CollectionUpdateEvent, | |
CollectionQueryEvent, | |
ClientCreateCollectionEvent, | |
) | |
import chromadb.types as t | |
from typing import Any, Optional, Sequence, Generator, List, cast, Set, Dict | |
from overrides import override | |
from uuid import UUID, uuid4 | |
import time | |
import logging | |
import re | |
# Module-level logger, named after this module per the stdlib logging convention.
logger = logging.getLogger(__name__)
# mimics s3 bucket requirements for naming
def check_index_name(index_name: str) -> None:
    """Validate a collection name against the (s3-bucket-like) naming rules.

    Raises:
        ValueError: if the name violates any of the constraints described in
            the error message (length, allowed characters, '..', IPv4 shape).
    """
    msg = (
        "Expected collection name that "
        "(1) contains 3-63 characters, "
        "(2) starts and ends with an alphanumeric character, "
        "(3) otherwise contains only alphanumeric characters, underscores or hyphens (-), "
        "(4) contains no two consecutive periods (..) and "
        "(5) is not a valid IPv4 address, "
        f"got {index_name}"
    )
    length_ok = 3 <= len(index_name) <= 63
    # Must start/end alphanumeric; interior may add '.', '_' and '-'.
    shape_ok = (
        re.match("^[a-zA-Z0-9][a-zA-Z0-9._-]*[a-zA-Z0-9]$", index_name) is not None
    )
    no_double_dot = ".." not in index_name
    looks_like_ipv4 = (
        re.match("^[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}$", index_name)
        is not None
    )
    if not (length_ok and shape_ok and no_double_dot) or looks_like_ipv4:
        raise ValueError(msg)
class SegmentAPI(ServerAPI):
    """API implementation utilizing the new segment-based internal architecture"""
    # NOTE(review): `trace_method`, `OpenTelemetryGranularity` and `override` are
    # imported at the top of this file but never referenced below. Upstream
    # versions of this class decorate most of these methods; the decorators may
    # have been stripped from this copy — confirm against the original source.
    _settings: Settings  # system-wide settings, exposed via get_settings()
    _sysdb: SysDB  # system-of-record for collections/segments/tenants/databases
    _manager: SegmentManager  # creates/deletes segments and vends segment readers
    _producer: Producer  # ingest log; every write is submitted through it
    _product_telemetry_client: ProductTelemetryClient  # anonymous usage events
    _opentelemetry_client: OpenTelemetryClient  # distributed tracing client
    _tenant_id: str
    _topic_ns: str
    _collection_cache: Dict[UUID, t.Collection]  # read-through cache, see _get_collection
    def __init__(self, system: System):
        """Resolve every dependency from the DI `system`; start with an empty cache."""
        super().__init__(system)
        self._settings = system.settings
        self._sysdb = self.require(SysDB)
        self._manager = self.require(SegmentManager)
        self._product_telemetry_client = self.require(ProductTelemetryClient)
        self._opentelemetry_client = self.require(OpenTelemetryClient)
        self._producer = self.require(Producer)
        self._collection_cache = {}
    def heartbeat(self) -> int:
        """Return the current wall-clock time in nanoseconds since the epoch."""
        return int(time.time_ns())
    def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
        """Create a database named `name` under `tenant` with a fresh uuid.

        Raises:
            ValueError: if the name is shorter than 3 characters.
        """
        if len(name) < 3:
            raise ValueError("Database name must be at least 3 characters long")
        self._sysdb.create_database(
            id=uuid4(),
            name=name,
            tenant=tenant,
        )
    def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> t.Database:
        """Fetch the database record for `name` under `tenant` from the sysdb."""
        return self._sysdb.get_database(name=name, tenant=tenant)
    def create_tenant(self, name: str) -> None:
        """Create a tenant named `name`.

        Raises:
            ValueError: if the name is shorter than 3 characters.
        """
        if len(name) < 3:
            raise ValueError("Tenant name must be at least 3 characters long")
        self._sysdb.create_tenant(
            name=name,
        )
    def get_tenant(self, name: str) -> t.Tenant:
        """Fetch the tenant record for `name` from the sysdb."""
        return self._sysdb.get_tenant(name=name)
    # TODO: Actually fix CollectionMetadata type to remove type: ignore flags. This is
    # necessary because changing the value type from `Any` to `Union[str, int, float]`
    # causes the system to somehow convert all values to strings.
    def create_collection(
        self,
        name: str,
        metadata: Optional[CollectionMetadata] = None,
        embedding_function: Optional[
            EmbeddingFunction[Any]
        ] = ef.DefaultEmbeddingFunction(),
        data_loader: Optional[DataLoader[Loadable]] = None,
        get_or_create: bool = False,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> Collection:
        """Create a collection (or fetch the existing one when `get_or_create`)
        and return its client-side model.

        Segments are only materialized when the collection was actually created.

        Raises:
            ValueError: for invalid metadata or an invalid name (check_index_name).
        """
        if metadata is not None:
            validate_metadata(metadata)
        # TODO: remove backwards compatibility in naming requirements
        check_index_name(name)
        # Candidate uuid for the new collection; when get_or_create returns an
        # existing collection the authoritative id is coll["id"], not this
        # value. (Also note: `id` shadows the builtin.)
        id = uuid4()
        coll, created = self._sysdb.create_collection(
            id=id,
            name=name,
            metadata=metadata,
            dimension=None,
            get_or_create=get_or_create,
            tenant=tenant,
            database=database,
        )
        if created:
            segments = self._manager.create_segments(coll)
            for segment in segments:
                self._sysdb.create_segment(segment)
        # TODO: This event doesn't capture the get_or_create case appropriately
        # NOTE(review): both the event and the span below report the freshly
        # generated uuid, which differs from coll["id"] on the get_or_create path.
        self._product_telemetry_client.capture(
            ClientCreateCollectionEvent(
                collection_uuid=str(id),
                embedding_function=embedding_function.__class__.__name__,
            )
        )
        add_attributes_to_current_span({"collection_uuid": str(id)})
        return Collection(
            client=self,
            id=coll["id"],
            name=name,
            metadata=coll["metadata"],  # type: ignore
            embedding_function=embedding_function,
            data_loader=data_loader,
            tenant=tenant,
            database=database,
        )
    def get_or_create_collection(
        self,
        name: str,
        metadata: Optional[CollectionMetadata] = None,
        embedding_function: Optional[
            EmbeddingFunction[Embeddable]
        ] = ef.DefaultEmbeddingFunction(),  # type: ignore
        data_loader: Optional[DataLoader[Loadable]] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> Collection:
        """Convenience wrapper: create_collection with get_or_create=True."""
        return self.create_collection(  # type: ignore
            name=name,
            metadata=metadata,
            embedding_function=embedding_function,
            data_loader=data_loader,
            get_or_create=True,
            tenant=tenant,
            database=database,
        )
    # TODO: Actually fix CollectionMetadata type to remove type: ignore flags. This is
    # necessary because changing the value type from `Any` to `Union[str, int, float]`
    # causes the system to somehow convert all values to strings
    def get_collection(
        self,
        name: Optional[str] = None,
        id: Optional[UUID] = None,
        embedding_function: Optional[
            EmbeddingFunction[Embeddable]
        ] = ef.DefaultEmbeddingFunction(),  # type: ignore
        data_loader: Optional[DataLoader[Loadable]] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> Collection:
        """Look up a collection by exactly one of `name` or `id`.

        Raises:
            ValueError: if neither or both of `name`/`id` are given, or if the
                collection does not exist.
        """
        if id is None and name is None or (id is not None and name is not None):
            raise ValueError("Name or id must be specified, but not both")
        existing = self._sysdb.get_collections(
            id=id, name=name, tenant=tenant, database=database
        )
        if existing:
            return Collection(
                client=self,
                id=existing[0]["id"],
                name=existing[0]["name"],
                metadata=existing[0]["metadata"],  # type: ignore
                embedding_function=embedding_function,
                data_loader=data_loader,
                tenant=existing[0]["tenant"],
                database=existing[0]["database"],
            )
        else:
            raise ValueError(f"Collection {name} does not exist.")
    def list_collections(
        self,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> Sequence[Collection]:
        """Return Collection models for each sysdb record (paged by limit/offset).

        Note: no embedding_function/data_loader is forwarded here; the models
        are built with Collection's defaults for those.
        """
        collections = []
        db_collections = self._sysdb.get_collections(
            limit=limit, offset=offset, tenant=tenant, database=database
        )
        for db_collection in db_collections:
            collections.append(
                Collection(
                    client=self,
                    id=db_collection["id"],
                    name=db_collection["name"],
                    metadata=db_collection["metadata"],  # type: ignore
                    tenant=db_collection["tenant"],
                    database=db_collection["database"],
                )
            )
        return collections
    def count_collections(
        self,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> int:
        """Return the number of collections in `database` under `tenant`."""
        collection_count = len(
            self._sysdb.get_collections(tenant=tenant, database=database)
        )
        return collection_count
    def _modify(
        self,
        id: UUID,
        new_name: Optional[str] = None,
        new_metadata: Optional[CollectionMetadata] = None,
    ) -> None:
        """Rename and/or replace the metadata of collection `id`.

        Both values are validated before anything is written; a no-op when both
        arguments are falsy.
        """
        if new_name:
            # backwards compatibility in naming requirements (for now)
            check_index_name(new_name)
        if new_metadata:
            validate_update_metadata(new_metadata)
        # TODO eventually we'll want to use OptionalArgument and Unspecified in the
        # signature of `_modify` but not changing the API right now.
        if new_name and new_metadata:
            self._sysdb.update_collection(id, name=new_name, metadata=new_metadata)
        elif new_name:
            self._sysdb.update_collection(id, name=new_name)
        elif new_metadata:
            self._sysdb.update_collection(id, metadata=new_metadata)
    def delete_collection(
        self,
        name: str,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> None:
        """Delete the collection named `name`, its segments, and its cache entry.

        Raises:
            ValueError: if no such collection exists.
        """
        existing = self._sysdb.get_collections(
            name=name, tenant=tenant, database=database
        )
        if existing:
            self._sysdb.delete_collection(
                existing[0]["id"], tenant=tenant, database=database
            )
            for s in self._manager.delete_segments(existing[0]["id"]):
                self._sysdb.delete_segment(s)
            # (the `existing and` re-check is redundant inside this branch)
            if existing and existing[0]["id"] in self._collection_cache:
                del self._collection_cache[existing[0]["id"]]
        else:
            raise ValueError(f"Collection {name} does not exist.")
    def _add(
        self,
        ids: IDs,
        collection_id: UUID,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        """Validate and submit ADD records to the collection's ingest topic.

        Returns True; failures surface as exceptions.
        """
        coll = self._get_collection(collection_id)
        self._manager.hint_use_collection(collection_id, t.Operation.ADD)
        validate_batch(
            (ids, embeddings, metadatas, documents, uris),
            {"max_batch_size": self.max_batch_size},
        )
        records_to_submit = []
        for r in _records(
            t.Operation.ADD,
            ids=ids,
            collection_id=collection_id,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=documents,
            uris=uris,
        ):
            # Reject (or learn) the embedding dimension before anything is queued.
            self._validate_embedding_record(coll, r)
            records_to_submit.append(r)
        self._producer.submit_embeddings(coll["topic"], records_to_submit)
        self._product_telemetry_client.capture(
            CollectionAddEvent(
                collection_uuid=str(collection_id),
                add_amount=len(ids),
                with_metadata=len(ids) if metadatas is not None else 0,
                with_documents=len(ids) if documents is not None else 0,
                with_uris=len(ids) if uris is not None else 0,
            )
        )
        return True
    def _update(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Optional[Embeddings] = None,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        """Validate and submit UPDATE records to the collection's ingest topic.

        Returns True; failures surface as exceptions.
        """
        coll = self._get_collection(collection_id)
        self._manager.hint_use_collection(collection_id, t.Operation.UPDATE)
        validate_batch(
            (ids, embeddings, metadatas, documents, uris),
            {"max_batch_size": self.max_batch_size},
        )
        records_to_submit = []
        for r in _records(
            t.Operation.UPDATE,
            ids=ids,
            collection_id=collection_id,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=documents,
            uris=uris,
        ):
            self._validate_embedding_record(coll, r)
            records_to_submit.append(r)
        self._producer.submit_embeddings(coll["topic"], records_to_submit)
        self._product_telemetry_client.capture(
            CollectionUpdateEvent(
                collection_uuid=str(collection_id),
                update_amount=len(ids),
                with_embeddings=len(embeddings) if embeddings else 0,
                with_metadata=len(metadatas) if metadatas else 0,
                with_documents=len(documents) if documents else 0,
                with_uris=len(uris) if uris else 0,
            )
        )
        return True
    def _upsert(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        """Validate and submit UPSERT records to the collection's ingest topic.

        Returns True; failures surface as exceptions.
        """
        coll = self._get_collection(collection_id)
        self._manager.hint_use_collection(collection_id, t.Operation.UPSERT)
        validate_batch(
            (ids, embeddings, metadatas, documents, uris),
            {"max_batch_size": self.max_batch_size},
        )
        records_to_submit = []
        for r in _records(
            t.Operation.UPSERT,
            ids=ids,
            collection_id=collection_id,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=documents,
            uris=uris,
        ):
            self._validate_embedding_record(coll, r)
            records_to_submit.append(r)
        self._producer.submit_embeddings(coll["topic"], records_to_submit)
        # NOTE: unlike _add/_update/_delete, no product telemetry event is
        # emitted for upserts.
        return True
    def _get(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = {},
        sort: Optional[str] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        page: Optional[int] = None,
        page_size: Optional[int] = None,
        where_document: Optional[WhereDocument] = {},
        include: Include = ["embeddings", "metadatas", "documents"],
    ) -> GetResult:
        """Fetch records by ids and/or metadata/document filters.

        When both `page` and `page_size` are given they override
        `offset`/`limit`. `sort` is unsupported and raises NotImplementedError.
        """
        add_attributes_to_current_span(
            {
                "collection_id": str(collection_id),
                "ids_count": len(ids) if ids else 0,
            }
        )
        # Empty filters are normalized to None so segments skip filtering.
        where = validate_where(where) if where is not None and len(where) > 0 else None
        where_document = (
            validate_where_document(where_document)
            if where_document is not None and len(where_document) > 0
            else None
        )
        metadata_segment = self._manager.get_segment(collection_id, MetadataReader)
        if sort is not None:
            raise NotImplementedError("Sorting is not yet supported")
        if page and page_size:
            offset = (page - 1) * page_size
            limit = page_size
        records = metadata_segment.get_metadata(
            where=where,
            where_document=where_document,
            ids=ids,
            limit=limit,
            offset=offset,
        )
        if len(records) == 0:
            # Nothing to return if there are no records
            return GetResult(
                ids=[],
                embeddings=[] if "embeddings" in include else None,
                metadatas=[] if "metadatas" in include else None,
                documents=[] if "documents" in include else None,
                uris=[] if "uris" in include else None,
                data=[] if "data" in include else None,
            )
        vectors: Sequence[t.VectorEmbeddingRecord] = []
        if "embeddings" in include:
            vector_ids = [r["id"] for r in records]
            vector_segment = self._manager.get_segment(collection_id, VectorReader)
            vectors = vector_segment.get_vectors(ids=vector_ids)
        # TODO: Fix type so we don't need to ignore
        # It is possible to have a set of records, some with metadata and some without
        # Same with documents
        metadatas = [r["metadata"] for r in records]
        if "documents" in include:
            documents = [_doc(m) for m in metadatas]
        if "uris" in include:
            uris = [_uri(m) for m in metadatas]
        ids_amount = len(ids) if ids else 0
        # NOTE(review): the include_* counters below report the *requested* id
        # count (0 when ids is None), not the number of records returned.
        self._product_telemetry_client.capture(
            CollectionGetEvent(
                collection_uuid=str(collection_id),
                ids_count=ids_amount,
                limit=limit if limit else 0,
                include_metadata=ids_amount if "metadatas" in include else 0,
                include_documents=ids_amount if "documents" in include else 0,
                include_uris=ids_amount if "uris" in include else 0,
            )
        )
        return GetResult(
            ids=[r["id"] for r in records],
            embeddings=[r["embedding"] for r in vectors]
            if "embeddings" in include
            else None,
            metadatas=_clean_metadatas(metadatas)
            if "metadatas" in include
            else None,  # type: ignore
            documents=documents if "documents" in include else None,  # type: ignore
            uris=uris if "uris" in include else None,  # type: ignore
            data=None,
        )
    def _delete(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = None,
        where_document: Optional[WhereDocument] = None,
    ) -> IDs:
        """Delete records matching ids and/or filters; return the deleted ids.

        Raises:
            ValueError: if no ids and no filters are supplied (a full wipe must
                go through delete_collection instead).
        """
        add_attributes_to_current_span(
            {
                "collection_id": str(collection_id),
                "ids_count": len(ids) if ids else 0,
            }
        )
        where = validate_where(where) if where is not None and len(where) > 0 else None
        where_document = (
            validate_where_document(where_document)
            if where_document is not None and len(where_document) > 0
            else None
        )
        # You must have at least one of non-empty ids, where, or where_document.
        if (
            (ids is None or (ids is not None and len(ids) == 0))
            and (where is None or (where is not None and len(where) == 0))
            and (
                where_document is None
                or (where_document is not None and len(where_document) == 0)
            )
        ):
            raise ValueError(
                """
                You must provide either ids, where, or where_document to delete. If
                you want to delete all data in a collection you can delete the
                collection itself using the delete_collection method. Or alternatively,
                you can get() all the relevant ids and then delete them.
                """
            )
        coll = self._get_collection(collection_id)
        self._manager.hint_use_collection(collection_id, t.Operation.DELETE)
        if (where or where_document) or not ids:
            # Filters were supplied (or no explicit ids): resolve the target id
            # set through the metadata segment first.
            metadata_segment = self._manager.get_segment(collection_id, MetadataReader)
            records = metadata_segment.get_metadata(
                where=where, where_document=where_document, ids=ids
            )
            ids_to_delete = [r["id"] for r in records]
        else:
            ids_to_delete = ids
        if len(ids_to_delete) == 0:
            return []
        records_to_submit = []
        for r in _records(
            operation=t.Operation.DELETE, ids=ids_to_delete, collection_id=collection_id
        ):
            self._validate_embedding_record(coll, r)
            records_to_submit.append(r)
        self._producer.submit_embeddings(coll["topic"], records_to_submit)
        self._product_telemetry_client.capture(
            CollectionDeleteEvent(
                collection_uuid=str(collection_id), delete_amount=len(ids_to_delete)
            )
        )
        return ids_to_delete
    def _count(self, collection_id: UUID) -> int:
        """Return the number of records in the collection's metadata segment."""
        add_attributes_to_current_span({"collection_id": str(collection_id)})
        metadata_segment = self._manager.get_segment(collection_id, MetadataReader)
        return metadata_segment.count()
    def _query(
        self,
        collection_id: UUID,
        query_embeddings: Embeddings,
        n_results: int = 10,
        where: Where = {},
        where_document: WhereDocument = {},
        include: Include = ["documents", "metadatas", "distances"],
    ) -> QueryResult:
        """Run a vector similarity query, returning up to `n_results` neighbors
        per query embedding, optionally restricted by metadata/document filters.
        """
        add_attributes_to_current_span(
            {
                "collection_id": str(collection_id),
                "n_results": n_results,
                "where": str(where),
            }
        )
        # Unlike _get/_delete, empty filters are kept as {} (not set to None).
        where = validate_where(where) if where is not None and len(where) > 0 else where
        where_document = (
            validate_where_document(where_document)
            if where_document is not None and len(where_document) > 0
            else where_document
        )
        allowed_ids = None
        coll = self._get_collection(collection_id)
        # Every query vector must match the collection's known dimensionality.
        for embedding in query_embeddings:
            self._validate_dimension(coll, len(embedding), update=False)
        metadata_reader = self._manager.get_segment(collection_id, MetadataReader)
        if where or where_document:
            records = metadata_reader.get_metadata(
                where=where, where_document=where_document
            )
            allowed_ids = [r["id"] for r in records]
        query = t.VectorQuery(
            vectors=query_embeddings,
            k=n_results,
            allowed_ids=allowed_ids,
            include_embeddings="embeddings" in include,
            options=None,
        )
        vector_reader = self._manager.get_segment(collection_id, VectorReader)
        results = vector_reader.query_vectors(query)
        ids: List[List[str]] = []
        distances: List[List[float]] = []
        embeddings: List[List[Embedding]] = []
        documents: List[List[Document]] = []
        uris: List[List[URI]] = []
        metadatas: List[List[t.Metadata]] = []
        for result in results:
            ids.append([r["id"] for r in result])
            if "distances" in include:
                distances.append([r["distance"] for r in result])
            if "embeddings" in include:
                embeddings.append([cast(Embedding, r["embedding"]) for r in result])
        if "documents" in include or "metadatas" in include or "uris" in include:
            all_ids: Set[str] = set()
            for id_list in ids:
                all_ids.update(id_list)
            records = metadata_reader.get_metadata(ids=list(all_ids))
            metadata_by_id = {r["id"]: r["metadata"] for r in records}
            for id_list in ids:
                # In the segment based architecture, it is possible for one segment
                # to have a record that another segment does not have. This results in
                # data inconsistency. For the case of the local segments and the
                # local segment manager, there is a case where a thread writes
                # a record to the vector segment but not the metadata segment.
                # Then a query'ing thread reads from the vector segment and
                # queries the metadata segment. The metadata segment does not have
                # the record. In this case we choose to return potentially
                # incorrect data in the form of None.
                metadata_list = [metadata_by_id.get(id, None) for id in id_list]
                if "metadatas" in include:
                    metadatas.append(_clean_metadatas(metadata_list))  # type: ignore
                if "documents" in include:
                    doc_list = [_doc(m) for m in metadata_list]
                    documents.append(doc_list)  # type: ignore
                if "uris" in include:
                    uri_list = [_uri(m) for m in metadata_list]
                    uris.append(uri_list)  # type: ignore
        query_amount = len(query_embeddings)
        # NOTE(review): `where`/`where_document` default to {} and are never set
        # to None in this method, so the "is not None" filter counters below
        # fire on every query, even with empty filters — possibly intended to
        # be truthiness checks. Confirm.
        self._product_telemetry_client.capture(
            CollectionQueryEvent(
                collection_uuid=str(collection_id),
                query_amount=query_amount,
                n_results=n_results,
                with_metadata_filter=query_amount if where is not None else 0,
                with_document_filter=query_amount if where_document is not None else 0,
                include_metadatas=query_amount if "metadatas" in include else 0,
                include_documents=query_amount if "documents" in include else 0,
                include_uris=query_amount if "uris" in include else 0,
                include_distances=query_amount if "distances" in include else 0,
            )
        )
        return QueryResult(
            ids=ids,
            distances=distances if distances else None,
            metadatas=metadatas if metadatas else None,
            embeddings=embeddings if embeddings else None,
            documents=documents if documents else None,
            uris=uris if uris else None,
            data=None,
        )
    def _peek(self, collection_id: UUID, n: int = 10) -> GetResult:
        """Return the first `n` records of the collection (a _get with limit=n)."""
        add_attributes_to_current_span({"collection_id": str(collection_id)})
        return self._get(collection_id, limit=n)  # type: ignore
    def get_version(self) -> str:
        """Return the chromadb package version string."""
        return __version__
    def reset_state(self) -> None:
        """Drop local cached state (the collection cache)."""
        self._collection_cache = {}
    def reset(self) -> bool:
        """Reset the whole system (delegates to System.reset_state); returns True."""
        self._system.reset_state()
        return True
    def get_settings(self) -> Settings:
        """Return the Settings this instance was configured with."""
        return self._settings
    def max_batch_size(self) -> int:
        """Maximum number of records the producer accepts in one submission.

        NOTE(review): call sites above reference `self.max_batch_size` without
        calling it (the validate_batch calls in _add/_update/_upsert), which
        suggests this should be a @property — the decorator may have been
        stripped from this copy. Confirm against the original source.
        """
        return self._producer.max_batch_size
    # TODO: This could potentially cause race conditions in a distributed version of the
    # system, since the cache is only local.
    # TODO: promote collection -> topic to a base class method so that it can be
    # used for channel assignment in the distributed version of the system.
    def _validate_embedding_record(
        self, collection: t.Collection, record: t.SubmitEmbeddingRecord
    ) -> None:
        """Validate the dimension of an embedding record before submitting it to the system."""
        add_attributes_to_current_span({"collection_id": str(collection["id"])})
        # Truthiness check: records with no embedding (or an empty one) skip
        # dimension validation entirely.
        if record["embedding"]:
            self._validate_dimension(collection, len(record["embedding"]), update=True)
    def _validate_dimension(
        self, collection: t.Collection, dim: int, update: bool
    ) -> None:
        """Validate that a collection supports records of the given dimension. If update
        is true, update the collection if the collection doesn't already have a
        dimension."""
        if collection["dimension"] is None:
            if update:
                id = collection["id"]
                self._sysdb.update_collection(id=id, dimension=dim)
                # Keep the read-through cache coherent. Assumes the entry is
                # already cached (callers obtain `collection` via _get_collection).
                self._collection_cache[id]["dimension"] = dim
        elif collection["dimension"] != dim:
            raise InvalidDimensionException(
                f"Embedding dimension {dim} does not match collection dimensionality {collection['dimension']}"
            )
        else:
            return  # all is well
    def _get_collection(self, collection_id: UUID) -> t.Collection:
        """Read-through cache for collection data

        Raises:
            InvalidCollectionException: if the collection does not exist in the
                sysdb.
        """
        if collection_id not in self._collection_cache:
            collections = self._sysdb.get_collections(id=collection_id)
            if not collections:
                raise InvalidCollectionException(
                    f"Collection {collection_id} does not exist."
                )
            self._collection_cache[collection_id] = collections[0]
        return self._collection_cache[collection_id]
def _records(
    operation: t.Operation,
    ids: IDs,
    collection_id: UUID,
    embeddings: Optional[Embeddings] = None,
    metadatas: Optional[Metadatas] = None,
    documents: Optional[Documents] = None,
    uris: Optional[URIs] = None,
) -> Generator[t.SubmitEmbeddingRecord, None, None]:
    """Lazily convert parallel lists of embeddings, metadatas and documents into
    SubmitEmbeddingRecords.

    Assumes callers went through the Collection model, so every supplied list
    is already normalized and all lists share the same length as `ids`.
    """
    for idx, record_id in enumerate(ids):
        # Start from the caller-supplied metadata (if any), then fold the
        # document and uri in under reserved "chroma:" keys.
        merged = metadatas[idx] if metadatas else None
        if documents:
            merged = {**(merged or {}), "chroma:document": documents[idx]}
        if uris:
            merged = {**(merged or {}), "chroma:uri": uris[idx]}
        yield t.SubmitEmbeddingRecord(
            id=record_id,
            embedding=embeddings[idx] if embeddings else None,
            encoding=t.ScalarEncoding.FLOAT32,  # Hardcode for now
            metadata=merged,
            operation=operation,
            collection_id=collection_id,
        )
def _doc(metadata: Optional[t.Metadata]) -> Optional[str]: | |
"""Retrieve the document (if any) from a Metadata map""" | |
if metadata and "chroma:document" in metadata: | |
return str(metadata["chroma:document"]) | |
return None | |
def _uri(metadata: Optional[t.Metadata]) -> Optional[str]: | |
"""Retrieve the uri (if any) from a Metadata map""" | |
if metadata and "chroma:uri" in metadata: | |
return str(metadata["chroma:uri"]) | |
return None | |
def _clean_metadatas(
    metadata: List[Optional[t.Metadata]],
) -> List[Optional[t.Metadata]]:
    """Strip chroma-internal keys from every map in a list of metadata maps."""
    cleaned: List[Optional[t.Metadata]] = []
    for entry in metadata:
        cleaned.append(_clean_metadata(entry))
    return cleaned
def _clean_metadata(metadata: Optional[t.Metadata]) -> Optional[t.Metadata]: | |
"""Remove any chroma-specific metadata keys that the client shouldn't see from a | |
metadata map.""" | |
if not metadata: | |
return None | |
result = {} | |
for k, v in metadata.items(): | |
if not k.startswith("chroma:"): | |
result[k] = v | |
if len(result) == 0: | |
return None | |
return result | |