Spaces:
Sleeping
Sleeping
| from threading import Lock | |
| from chromadb.segment import ( | |
| SegmentImplementation, | |
| SegmentManager, | |
| MetadataReader, | |
| SegmentType, | |
| VectorReader, | |
| S, | |
| ) | |
| import logging | |
| from chromadb.segment.impl.manager.cache.cache import SegmentLRUCache, BasicCache,SegmentCache | |
| import os | |
| from chromadb.config import System, get_class | |
| from chromadb.db.system import SysDB | |
| from overrides import override | |
| from chromadb.segment.impl.vector.local_persistent_hnsw import ( | |
| PersistentLocalHnswSegment, | |
| ) | |
| from chromadb.telemetry.opentelemetry import ( | |
| OpenTelemetryClient, | |
| OpenTelemetryGranularity, | |
| trace_method, | |
| ) | |
| from chromadb.types import Collection, Operation, Segment, SegmentScope, Metadata | |
| from typing import Dict, Type, Sequence, Optional, cast | |
| from uuid import UUID, uuid4 | |
| import platform | |
| from chromadb.utils.lru_cache import LRUCache | |
| from chromadb.utils.directory import get_directory_size | |
| if platform.system() != "Windows": | |
| import resource | |
| elif platform.system() == "Windows": | |
| import ctypes | |
# Maps each SegmentType this manager can instantiate to the dotted import path of
# its implementation class. The path is turned into a class with get_class() both
# when a new Segment record is built (_segment) and when an instance is created
# (LocalSegmentManager._cls). The keys are also the set of segment types that
# LocalSegmentManager._get_segment_sysdb() accepts as "known".
SEGMENT_TYPE_IMPLS = {
    SegmentType.SQLITE: "chromadb.segment.impl.metadata.sqlite.SqliteMetadataSegment",
    SegmentType.HNSW_LOCAL_MEMORY: "chromadb.segment.impl.vector.local_hnsw.LocalHnswSegment",
    SegmentType.HNSW_LOCAL_PERSISTED: "chromadb.segment.impl.vector.local_persistent_hnsw.PersistentLocalHnswSegment",
}
class LocalSegmentManager(SegmentManager):
    """Segment manager for a single-node Chroma deployment.

    Keeps exactly one SegmentImplementation instance per segment id, creating
    (and starting) instances lazily the first time a segment is requested.
    Segment records looked up from the SysDB are cached per scope, keyed by
    collection id. When ``chroma_segment_cache_policy == "LRU"`` and
    ``chroma_memory_limit_bytes > 0``, vector segment records live in a
    size-bounded LRU cache whose evictions stop and drop the matching instance.
    For persistent deployments, a second LRU cache bounds how many persistent
    HNSW indexes keep their files open at once.
    """

    _sysdb: SysDB
    _system: System
    _opentelemetry_client: OpenTelemetryClient
    _instances: Dict[UUID, SegmentImplementation]
    # LRU cache to manage file handles across vector segment instances
    _vector_instances_file_handle_cache: LRUCache[UUID, PersistentLocalHnswSegment]
    _vector_segment_type: SegmentType = SegmentType.HNSW_LOCAL_MEMORY
    _lock: Lock
    _max_file_handles: int

    def __init__(self, system: System):
        super().__init__(system)
        self._sysdb = self.require(SysDB)
        self._system = system
        self._opentelemetry_client = system.require(OpenTelemetryClient)
        self.logger = logging.getLogger(__name__)
        self._instances = {}

        # Metadata segment records are always held in an unbounded cache. Vector
        # segment records optionally use an LRU cache whose entry size is the
        # on-disk size of the collection's vector segment directory.
        self.segment_cache: Dict[SegmentScope, SegmentCache] = {
            SegmentScope.METADATA: BasicCache()
        }
        settings = system.settings
        if (
            settings.chroma_segment_cache_policy == "LRU"
            and settings.chroma_memory_limit_bytes > 0
        ):
            self.segment_cache[SegmentScope.VECTOR] = SegmentLRUCache(
                capacity=settings.chroma_memory_limit_bytes,
                callback=lambda k, v: self.callback_cache_evict(v),
                size_func=lambda k: self._get_segment_disk_size(k),
            )
        else:
            self.segment_cache[SegmentScope.VECTOR] = BasicCache()

        self._lock = Lock()

        # TODO: prototyping with distributed segment for now, but this should be a configurable option
        # we need to think about how to handle this configuration
        if self._system.settings.require("is_persistent"):
            self._vector_segment_type = SegmentType.HNSW_LOCAL_PERSISTED
            # Size the file-handle LRU from the process's open-file limit so
            # that open persistent indexes never exceed it.
            if platform.system() != "Windows":
                self._max_file_handles = resource.getrlimit(resource.RLIMIT_NOFILE)[0]
            else:
                self._max_file_handles = ctypes.windll.msvcrt._getmaxstdio()  # type: ignore
            segment_limit = (
                self._max_file_handles
                // PersistentLocalHnswSegment.get_file_handle_count()
            )
            self._vector_instances_file_handle_cache = LRUCache(
                segment_limit, callback=lambda _, v: v.close_persistent_index()
            )

    def callback_cache_evict(self, segment: Segment) -> None:
        """Stop and forget the instance of a segment evicted from the LRU cache.

        Args:
            segment: The evicted segment record.
        """
        collection_id = segment["collection"]
        self.logger.info(f"LRU cache evict collection {collection_id}")
        # Look the instance up instead of calling _instance(): an evicted
        # segment that was never instantiated needs no work, and _instance()
        # would construct and start an instance only to stop it right away.
        instance = self._instances.get(segment["id"])
        if instance is None:
            return
        instance.stop()
        del self._instances[segment["id"]]

    def start(self) -> None:
        """Start every existing segment instance, then this component."""
        for instance in self._instances.values():
            instance.start()
        super().start()

    def stop(self) -> None:
        """Stop every existing segment instance, then this component."""
        for instance in self._instances.values():
            instance.stop()
        super().stop()

    def reset_state(self) -> None:
        """Stop and reset all instances, forget them, and clear all segment caches."""
        for instance in self._instances.values():
            instance.stop()
            instance.reset_state()
        self._instances = {}
        # Reset every scope's cache, not only the vector one. Otherwise cached
        # metadata segment records would survive the reset.
        for cache in self.segment_cache.values():
            cache.reset()
        super().reset_state()

    def create_segments(self, collection: Collection) -> Sequence[Segment]:
        """Build (but do not persist) the vector and metadata segments for a collection.

        Returns:
            ``[vector_segment, metadata_segment]``.
        """
        vector_segment = _segment(
            self._vector_segment_type, SegmentScope.VECTOR, collection
        )
        metadata_segment = _segment(
            SegmentType.SQLITE, SegmentScope.METADATA, collection
        )
        return [vector_segment, metadata_segment]

    def delete_segments(self, collection_id: UUID) -> Sequence[UUID]:
        """Delete the live instances of a collection's segments and drop cached records.

        Only instances of persistent-HNSW and SQLite segments have ``delete()``
        called on them; any live instance is removed from the instance map.

        Returns:
            The ids of all segments the SysDB lists for the collection.
        """
        segments = self._sysdb.get_segments(collection=collection_id)
        for segment in segments:
            segment_id = segment["id"]
            if segment_id in self._instances:
                # Use the instance for exactly this segment. Going through
                # get_segment() could re-query the SysDB and re-insert the record
                # into the (possibly LRU) cache just before it is popped below.
                if segment["type"] in (
                    SegmentType.HNSW_LOCAL_PERSISTED.value,
                    SegmentType.SQLITE.value,
                ):
                    self._instances[segment_id].delete()
                del self._instances[segment_id]
            scope = segment["scope"]
            if scope is SegmentScope.VECTOR or scope is SegmentScope.METADATA:
                self.segment_cache[scope].pop(collection_id)
        return [s["id"] for s in segments]

    def _get_segment_disk_size(self, collection_id: UUID) -> int:
        """Return the on-disk size of the collection's vector segment directory.

        Returns 0 when the SysDB has no vector segment for the collection.
        """
        segments = self._sysdb.get_segments(
            collection=collection_id, scope=SegmentScope.VECTOR
        )
        if len(segments) == 0:
            return 0
        # With local segment manager (single server chroma), a collection always have one segment.
        size = get_directory_size(
            os.path.join(
                self._system.settings.require("persist_directory"),
                str(segments[0]["id"]),
            )
        )
        return size

    def _get_segment_sysdb(self, collection_id: UUID, scope: SegmentScope) -> Segment:
        """Return the first segment of a known type for this collection and scope.

        Raises:
            ValueError: If the SysDB holds no segment of a type listed in
                SEGMENT_TYPE_IMPLS for this collection and scope.
        """
        segments = self._sysdb.get_segments(collection=collection_id, scope=scope)
        known_types = {k.value for k in SEGMENT_TYPE_IMPLS}
        # Get the first segment of a known type. Use an explicit default rather
        # than letting StopIteration escape from a regular function (PEP 479).
        segment = next((s for s in segments if s["type"] in known_types), None)
        if segment is None:
            raise ValueError(
                f"No segment of a known type found for collection {collection_id} "
                f"and scope {scope}"
            )
        return segment

    def get_segment(self, collection_id: UUID, type: Type[S]) -> S:
        """Return the live instance of the collection's metadata or vector segment.

        Args:
            collection_id: Collection whose segment is requested.
            type: Either ``MetadataReader`` or ``VectorReader``.

        Raises:
            ValueError: If ``type`` is neither reader type.
        """
        if type == MetadataReader:
            scope = SegmentScope.METADATA
        elif type == VectorReader:
            scope = SegmentScope.VECTOR
        else:
            raise ValueError(f"Invalid segment type: {type}")

        segment = self.segment_cache[scope].get(collection_id)
        if segment is None:
            segment = self._get_segment_sysdb(collection_id, scope)
            self.segment_cache[scope].set(collection_id, segment)

        # Instances must be atomically created, so we use a lock to ensure that only one thread
        # creates the instance.
        with self._lock:
            instance = self._instance(segment)
        return cast(S, instance)

    def hint_use_collection(self, collection_id: UUID, hint_type: Operation) -> None:
        """Pre-load both the metadata and vector segment instances for a collection."""
        for reader_type in (MetadataReader, VectorReader):
            # Just use get_segment to load the segment into the cache
            instance = self.get_segment(collection_id, reader_type)
            # If the segment is a vector segment, we need to keep segments in an LRU cache
            # to avoid hitting the OS file handle limit.
            if reader_type == VectorReader and self._system.settings.require(
                "is_persistent"
            ):
                instance = cast(PersistentLocalHnswSegment, instance)
                instance.open_persistent_index()
                self._vector_instances_file_handle_cache.set(collection_id, instance)

    def _cls(self, segment: Segment) -> Type[SegmentImplementation]:
        """Resolve the implementation class for a segment's type."""
        classname = SEGMENT_TYPE_IMPLS[SegmentType(segment["type"])]
        cls = get_class(classname, SegmentImplementation)
        return cls

    def _instance(self, segment: Segment) -> SegmentImplementation:
        """Return the instance for ``segment``, creating and starting it if absent.

        Not synchronized itself; get_segment() calls it while holding ``_lock``.
        """
        if segment["id"] not in self._instances:
            cls = self._cls(segment)
            instance = cls(self._system, segment)
            instance.start()
            self._instances[segment["id"]] = instance
        return self._instances[segment["id"]]
def _segment(type: SegmentType, scope: SegmentScope, collection: Collection) -> Segment:
    """Build a fresh Segment record of the given type and scope for ``collection``.

    The record gets a new random id plus the collection's topic and id. Its
    metadata is whatever the implementation class for ``type`` returns from
    ``propagate_collection_metadata``. That call is made only when the
    collection has truthy metadata; otherwise the metadata is None.
    """
    impl_cls = get_class(SEGMENT_TYPE_IMPLS[type], SegmentImplementation)
    source_metadata = collection.get("metadata", None)
    propagated: Optional[Metadata] = (
        impl_cls.propagate_collection_metadata(source_metadata)
        if source_metadata
        else None
    )
    return Segment(
        id=uuid4(),
        type=type.value,
        scope=scope,
        topic=collection["topic"],
        collection=collection["id"],
        metadata=propagated,
    )