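# NOTE(review): the three lines below look like page-header residue captured
# together with this file rather than Python source; confirm and remove.
# Spaces:
# Sleeping
# Sleeping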
from typing import Any, Callable, Dict, List, Sequence, Optional | |
import fastapi | |
from fastapi import FastAPI as _FastAPI, Response | |
from fastapi.responses import JSONResponse | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.routing import APIRoute | |
from fastapi import HTTPException, status | |
from uuid import UUID | |
from chromadb.api.models.Collection import Collection | |
from chromadb.api.types import GetResult, QueryResult | |
from chromadb.auth import ( | |
AuthzDynamicParams, | |
AuthzResourceActions, | |
AuthzResourceTypes, | |
DynamicAuthzResource, | |
) | |
from chromadb.auth.fastapi import ( | |
FastAPIChromaAuthMiddleware, | |
FastAPIChromaAuthMiddlewareWrapper, | |
FastAPIChromaAuthzMiddleware, | |
FastAPIChromaAuthzMiddlewareWrapper, | |
authz_context, | |
set_overwrite_singleton_tenant_database_access_from_auth, | |
) | |
from chromadb.auth.fastapi_utils import ( | |
attr_from_collection_lookup, | |
attr_from_resource_object, | |
) | |
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System | |
import chromadb.api | |
from chromadb.api import ServerAPI | |
from chromadb.errors import ( | |
ChromaError, | |
InvalidDimensionException, | |
InvalidHTTPVersion, | |
) | |
from chromadb.server.fastapi.types import ( | |
AddEmbedding, | |
CreateDatabase, | |
CreateTenant, | |
DeleteEmbedding, | |
GetEmbedding, | |
QueryEmbedding, | |
CreateCollection, | |
UpdateCollection, | |
UpdateEmbedding, | |
) | |
from starlette.requests import Request | |
import logging | |
from chromadb.server.fastapi.utils import fastapi_json_response, string_to_uuid as _uuid | |
from chromadb.telemetry.opentelemetry.fastapi import instrument_fastapi | |
from chromadb.types import Database, Tenant | |
from chromadb.telemetry.product import ServerContext, ProductTelemetryClient | |
from chromadb.telemetry.opentelemetry import ( | |
OpenTelemetryClient, | |
OpenTelemetryGranularity, | |
trace_method, | |
) | |
# Module-level logger named after this module; catch_exceptions_middleware uses
# it to record unexpected exceptions together with their traceback.
logger: logging.Logger = logging.getLogger(name=__name__)
def use_route_names_as_operation_ids(app: _FastAPI) -> None:
    """
    Replace each API route's operation ID with the route's own name.

    Generated API clients derive their function names from operation IDs, so
    reusing the short route name keeps those client names simple. Routes that
    are not ``APIRoute`` instances are left untouched.

    Call this only after every route has been added to ``app``; routes added
    afterwards keep their default operation IDs.
    """
    api_routes = (candidate for candidate in app.routes if isinstance(candidate, APIRoute))
    for api_route in api_routes:
        api_route.operation_id = api_route.name
async def catch_exceptions_middleware(
    request: Request, call_next: Callable[[Request], Any]
) -> Response:
    """
    HTTP middleware that turns exceptions raised downstream into JSON responses.

    ``ChromaError`` instances are rendered by ``fastapi_json_response``. Any
    other exception is logged with its traceback and answered with status 500
    and a body of ``{"error": repr(exc)}``.
    """
    try:
        response = await call_next(request)
    except ChromaError as chroma_err:
        return fastapi_json_response(chroma_err)
    except Exception as unexpected:  # top-level boundary: never let it escape raw
        logger.exception(unexpected)
        # NOTE(review): repr() of an arbitrary exception is sent to the client,
        # which may expose internal details; kept as-is for client compatibility.
        return JSONResponse(content={"error": repr(unexpected)}, status_code=500)
    return response
async def check_http_version_middleware(
    request: Request, call_next: Callable[[Request], Any]
) -> Response:
    """
    HTTP middleware that only lets HTTP/1.1 and HTTP/2 requests through.

    Raises:
        InvalidHTTPVersion: if the ASGI scope's ``http_version`` is anything
            other than ``"1.1"`` or ``"2"``, including when it is absent.
    """
    version = request.scope.get("http_version")
    if version in ("1.1", "2"):
        return await call_next(request)
    raise InvalidHTTPVersion(f"HTTP version {version} is not supported")
class ChromaAPIRouter(fastapi.APIRouter):  # type: ignore
    """
    ``fastapi.APIRouter`` that serves each route both with and without a
    trailing "/".

    Only the variant without the trailing slash appears in the OpenAPI schema.
    Neither variant appears when the caller passes ``include_in_schema=False``.
    """

    def add_api_route(self, path: str, *args: Any, **kwargs: Any) -> None:
        """
        Register ``path`` and its trailing-slash twin with identical arguments.

        The twin of "/x" is "/x/" and vice versa. For the root path "/" the
        twin would be the empty path. FastAPI's ``include_router`` rejects an
        empty path when no prefix is given, so no twin is registered in that
        case.
        """
        # include_in_schema absent or truthy -> document only the slash-free path.
        exclude_from_schema = not kwargs.get("include_in_schema", True)

        def include_in_schema(route_path: str) -> bool:
            return not exclude_from_schema and not route_path.endswith("/")

        if path.endswith("/"):
            twin = path[:-1]
        else:
            twin = path + "/"

        variants = [path]
        if twin:  # "/" has no non-empty slash-free twin
            variants.append(twin)

        for route_path in variants:
            kwargs["include_in_schema"] = include_in_schema(route_path)
            super().add_api_route(route_path, *args, **kwargs)
class FastAPI(chromadb.server.Server):
    """
    HTTP front end for Chroma built on FastAPI.

    ``__init__`` does the following, in order:

    - builds a ``System`` from ``settings`` and obtains its ``ServerAPI``;
    - starts the system and installs middleware;
    - registers the ``/api/v1`` routes on a ``ChromaAPIRouter``.

    Each endpoint method below is a thin adapter that forwards path, query and
    request-body data to ``self._api``.

    NOTE(review): ``trace_method``, ``authz_context`` and the ``Authz*`` /
    ``attr_from_*`` helpers are imported by this module, yet no endpoint here
    is decorated with them. Confirm that authorization is actually enforced
    when ``chroma_server_authz_provider`` is configured.
    """

    def __init__(self, settings: Settings):
        super().__init__(settings)
        ProductTelemetryClient.SERVER_CONTEXT = ServerContext.FASTAPI

        app = fastapi.FastAPI(debug=True)
        self._app = app
        self._system = System(settings)
        self._api: ServerAPI = self._system.instance(ServerAPI)
        self._opentelemetry_client = self._api.require(OpenTelemetryClient)
        self._system.start()

        # Middleware stacking depends on registration order; keep it unchanged.
        for http_middleware in (
            check_http_version_middleware,
            catch_exceptions_middleware,
        ):
            app.middleware("http")(http_middleware)
        app.add_middleware(
            CORSMiddleware,
            allow_headers=["*"],
            allow_origins=settings.chroma_server_cors_allow_origins,
            allow_methods=["*"],
        )
        app.on_event("shutdown")(self.shutdown)

        if settings.chroma_server_authz_provider:
            app.add_middleware(
                FastAPIChromaAuthzMiddlewareWrapper,
                authz_middleware=self._api.require(FastAPIChromaAuthzMiddleware),
            )
        if settings.chroma_server_auth_provider:
            app.add_middleware(
                FastAPIChromaAuthMiddlewareWrapper,
                auth_middleware=self._api.require(FastAPIChromaAuthMiddleware),
            )
        set_overwrite_singleton_tenant_database_access_from_auth(
            settings.chroma_overwrite_singleton_tenant_database_access_from_auth
        )

        # Route table: (path, endpoint, HTTP method, extra add_api_route kwargs),
        # listed in registration order. The first five entries deliberately pass
        # no response_model at all: an explicit None is not the same as leaving
        # FastAPI's default in place.
        untyped: Dict[str, Any] = {"response_model": None}
        created: Dict[str, Any] = {
            "response_model": None,
            "status_code": status.HTTP_201_CREATED,
        }
        route_table = [
            ("/api/v1", self.root, "GET", {}),
            ("/api/v1/reset", self.reset, "POST", {}),
            ("/api/v1/version", self.version, "GET", {}),
            ("/api/v1/heartbeat", self.heartbeat, "GET", {}),
            ("/api/v1/pre-flight-checks", self.pre_flight_checks, "GET", {}),
            ("/api/v1/databases", self.create_database, "POST", untyped),
            ("/api/v1/databases/{database}", self.get_database, "GET", untyped),
            ("/api/v1/tenants", self.create_tenant, "POST", untyped),
            ("/api/v1/tenants/{tenant}", self.get_tenant, "GET", untyped),
            ("/api/v1/collections", self.list_collections, "GET", untyped),
            (
                "/api/v1/count_collections",
                self.count_collections,
                "GET",
                untyped,
            ),
            ("/api/v1/collections", self.create_collection, "POST", untyped),
            ("/api/v1/collections/{collection_id}/add", self.add, "POST", created),
            (
                "/api/v1/collections/{collection_id}/update",
                self.update,
                "POST",
                untyped,
            ),
            (
                "/api/v1/collections/{collection_id}/upsert",
                self.upsert,
                "POST",
                untyped,
            ),
            ("/api/v1/collections/{collection_id}/get", self.get, "POST", untyped),
            (
                "/api/v1/collections/{collection_id}/delete",
                self.delete,
                "POST",
                untyped,
            ),
            (
                "/api/v1/collections/{collection_id}/count",
                self.count,
                "GET",
                untyped,
            ),
            (
                "/api/v1/collections/{collection_id}/query",
                self.get_nearest_neighbors,
                "POST",
                untyped,
            ),
            (
                "/api/v1/collections/{collection_name}",
                self.get_collection,
                "GET",
                untyped,
            ),
            (
                "/api/v1/collections/{collection_id}",
                self.update_collection,
                "PUT",
                untyped,
            ),
            (
                "/api/v1/collections/{collection_name}",
                self.delete_collection,
                "DELETE",
                untyped,
            ),
        ]

        self.router = ChromaAPIRouter()
        for path, endpoint, method, extra in route_table:
            self.router.add_api_route(path, endpoint, methods=[method], **extra)

        app.include_router(self.router)
        use_route_names_as_operation_ids(app)
        instrument_fastapi(app)

    def shutdown(self) -> None:
        """Stop the underlying System; registered as the app's shutdown hook."""
        self._system.stop()

    def app(self) -> fastapi.FastAPI:
        """Return the configured FastAPI application."""
        return self._app

    def root(self) -> Dict[str, int]:
        """Return the API heartbeat under the "nanosecond heartbeat" key."""
        beat = self._api.heartbeat()
        return {"nanosecond heartbeat": beat}

    def heartbeat(self) -> Dict[str, int]:
        """Same payload as ``root``."""
        return self.root()

    def version(self) -> str:
        """Return the version string reported by the API."""
        return self._api.get_version()

    def create_database(
        self, database: CreateDatabase, tenant: str = DEFAULT_TENANT
    ) -> None:
        """Create the database named in the request body under ``tenant``."""
        return self._api.create_database(database.name, tenant)

    def get_database(self, database: str, tenant: str = DEFAULT_TENANT) -> Database:
        """Look up ``database`` within ``tenant``."""
        return self._api.get_database(database, tenant)

    def create_tenant(self, tenant: CreateTenant) -> None:
        """Create the tenant named in the request body."""
        return self._api.create_tenant(tenant.name)

    def get_tenant(self, tenant: str) -> Tenant:
        """Look up a tenant by name."""
        return self._api.get_tenant(tenant)

    def list_collections(
        self,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> Sequence[Collection]:
        """List collections in ``tenant``/``database``, optionally paged."""
        return self._api.list_collections(
            limit=limit,
            offset=offset,
            tenant=tenant,
            database=database,
        )

    def count_collections(
        self,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> int:
        """Return the collection count reported for ``tenant``/``database``."""
        return self._api.count_collections(tenant=tenant, database=database)

    def create_collection(
        self,
        collection: CreateCollection,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> Collection:
        """Create (or, with ``get_or_create``, fetch) a collection from the body."""
        return self._api.create_collection(
            name=collection.name,
            metadata=collection.metadata,
            get_or_create=collection.get_or_create,
            tenant=tenant,
            database=database,
        )

    def get_collection(
        self,
        collection_name: str,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> Collection:
        """Fetch a collection by name."""
        return self._api.get_collection(
            collection_name,
            tenant=tenant,
            database=database,
        )

    def update_collection(
        self, collection_id: str, collection: UpdateCollection
    ) -> None:
        """Rename and/or replace the metadata of the collection ``collection_id``."""
        target_id = _uuid(collection_id)
        return self._api._modify(
            id=target_id,
            new_name=collection.new_name,
            new_metadata=collection.new_metadata,
        )

    def delete_collection(
        self,
        collection_name: str,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> None:
        """Delete a collection by name."""
        return self._api.delete_collection(
            collection_name,
            tenant=tenant,
            database=database,
        )

    def add(self, collection_id: str, add: AddEmbedding) -> None:
        """
        Add records from the request body to a collection.

        Returns whatever ``ServerAPI._add`` returns. An
        ``InvalidDimensionException`` is re-raised as an ``HTTPException`` with
        status 500 and the exception text as its detail.
        """
        try:
            outcome = self._api._add(
                collection_id=_uuid(collection_id),
                embeddings=add.embeddings,  # type: ignore
                metadatas=add.metadatas,  # type: ignore
                documents=add.documents,  # type: ignore
                uris=add.uris,  # type: ignore
                ids=add.ids,
            )
        except InvalidDimensionException as dim_err:
            raise HTTPException(status_code=500, detail=str(dim_err))
        return outcome  # type: ignore

    def update(self, collection_id: str, add: UpdateEmbedding) -> None:
        """Update existing records of a collection; the API's result is discarded."""
        self._api._update(
            ids=add.ids,
            collection_id=_uuid(collection_id),
            embeddings=add.embeddings,
            documents=add.documents,  # type: ignore
            uris=add.uris,  # type: ignore
            metadatas=add.metadatas,  # type: ignore
        )

    def upsert(self, collection_id: str, upsert: AddEmbedding) -> None:
        """Insert-or-update records of a collection; the API's result is discarded."""
        self._api._upsert(
            collection_id=_uuid(collection_id),
            ids=upsert.ids,
            embeddings=upsert.embeddings,  # type: ignore
            documents=upsert.documents,  # type: ignore
            uris=upsert.uris,  # type: ignore
            metadatas=upsert.metadatas,  # type: ignore
        )

    def get(self, collection_id: str, get: GetEmbedding) -> GetResult:
        """Fetch records by ids and/or filters, with sort, limit and offset."""
        return self._api._get(
            collection_id=_uuid(collection_id),
            ids=get.ids,
            where=get.where,
            where_document=get.where_document,
            sort=get.sort,
            limit=get.limit,
            offset=get.offset,
            include=get.include,
        )

    def delete(self, collection_id: str, delete: DeleteEmbedding) -> List[UUID]:
        """Delete records matching the ids and/or filters in the request body."""
        return self._api._delete(
            where=delete.where,  # type: ignore
            ids=delete.ids,
            collection_id=_uuid(collection_id),
            where_document=delete.where_document,
        )

    def count(self, collection_id: str) -> int:
        """Return the count reported by ``ServerAPI._count`` for the collection."""
        target_id = _uuid(collection_id)
        return self._api._count(target_id)

    def reset(self) -> bool:
        """Forward a reset request to the API and return its result."""
        return self._api.reset()

    def get_nearest_neighbors(
        self, collection_id: str, query: QueryEmbedding
    ) -> QueryResult:
        """Run a query-embedding search against a collection."""
        return self._api._query(
            collection_id=_uuid(collection_id),
            where=query.where,  # type: ignore
            where_document=query.where_document,  # type: ignore
            query_embeddings=query.query_embeddings,
            n_results=query.n_results,
            include=query.include,
        )

    def pre_flight_checks(self) -> Dict[str, Any]:
        """Expose server limits to clients (currently only ``max_batch_size``)."""
        return dict(max_batch_size=self._api.max_batch_size)