# chroma/chromadb/db/impl/sqlite.py
# badalsahani's picture
# feat: chroma initial deploy
# 287a0bc
from chromadb.db.impl.sqlite_pool import Connection, LockPool, PerThreadPool, Pool
from chromadb.db.migrations import MigratableDB, Migration
from chromadb.config import System, Settings
import chromadb.db.base as base
from chromadb.db.mixins.embeddings_queue import SqlEmbeddingsQueue
from chromadb.db.mixins.sysdb import SqlSysDB
from chromadb.telemetry.opentelemetry import (
OpenTelemetryClient,
OpenTelemetryGranularity,
trace_method,
)
import sqlite3
from overrides import override
import pypika
from typing import Sequence, cast, Optional, Type, Any
from typing_extensions import Literal
from types import TracebackType
import os
from uuid import UUID
from threading import local
from importlib_resources import files
from importlib_resources.abc import Traversable
class TxWrapper(base.TxWrapper):
    """Context manager that runs a (possibly nested) SQLite transaction.

    A connection is taken from the pool when the wrapper is constructed.
    Entering the wrapper issues ``BEGIN`` only if the shared stack is empty,
    i.e. this is the outermost wrapper; nested wrappers just push themselves.
    On exit the wrapper pops the stack, and the outermost wrapper commits
    (no exception) or rolls back (an exception escaped the ``with`` body).
    The connection is handed back to the pool on every exit, even when the
    commit or rollback itself fails.
    """

    _conn: Connection
    _pool: Pool
    # Object carrying a ``stack`` list attribute that tracks the wrappers
    # currently entered; the caller is expected to have initialised it.
    _tx_stack: local

    def __init__(self, conn_pool: Pool, stack: local):
        """Acquire a connection from ``conn_pool`` and remember ``stack``.

        Args:
            conn_pool: Pool the connection is taken from and later returned to.
            stack: Object whose ``stack`` attribute lists the active wrappers.
        """
        self._tx_stack = stack
        self._conn = conn_pool.connect()
        self._pool = conn_pool

    @override
    def __enter__(self) -> base.Cursor:
        """Start the transaction if outermost and return a fresh cursor."""
        if not self._tx_stack.stack:
            try:
                self._conn.execute("BEGIN;")
            except BaseException:
                # __exit__ will not run when __enter__ raises, so hand the
                # connection back here to avoid leaking it.
                self._pool.return_to_pool(self._conn)
                raise
        self._tx_stack.stack.append(self)
        return self._conn.cursor()  # type: ignore

    @override
    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> Literal[False]:
        """Finish the transaction and return the connection to the pool.

        Commits when the outermost wrapper exits without an exception and
        rolls back when it exits with one. Always returns ``False`` so that
        an exception from the ``with`` body keeps propagating.
        """
        self._tx_stack.stack.pop()
        try:
            # Only the outermost wrapper decides the fate of the transaction.
            if not self._tx_stack.stack:
                if exc_type is None:
                    self._conn.commit()
                else:
                    self._conn.rollback()
        finally:
            try:
                # NOTE(review): this opens a new cursor only to close it and
                # looks like a no-op; kept to preserve existing behavior.
                self._conn.cursor().close()
            finally:
                # Must run even if commit/rollback raised, otherwise the
                # connection is never given back to the pool.
                self._pool.return_to_pool(self._conn)
        return False
class SqliteDB(MigratableDB, SqlEmbeddingsQueue, SqlSysDB):
    """SQLite implementation of the migratable DB, embeddings queue and sysdb.

    When the ``is_persistent`` setting is false the database is a
    shared-cache in-memory SQLite database accessed through a ``LockPool``;
    otherwise it is the file ``chroma.sqlite3`` under ``persist_directory``
    accessed through a ``PerThreadPool``.
    """

    _conn_pool: Pool
    _settings: Settings
    # Packages whose migration files are applied to this database.
    _migration_imports: Sequence[Traversable]
    # Path or URI handed to the connection pool.
    _db_file: str
    # Thread-local holder for the stack of active TxWrapper objects.
    _tx_stack: local
    _is_persistent: bool

    def __init__(self, system: System):
        """Read settings from ``system`` and build the connection pool.

        Args:
            system: Provides the settings and the OpenTelemetry client.
        """
        self._settings = system.settings
        self._migration_imports = [
            files("chromadb.migrations.embeddings_queue"),
            files("chromadb.migrations.sysdb"),
            files("chromadb.migrations.metadb"),
        ]
        self._is_persistent = self._settings.require("is_persistent")
        self._opentelemetry_client = system.require(OpenTelemetryClient)
        if not self._is_persistent:
            # In order to allow sqlite to be shared between multiple threads, we need to use a
            # URI connection string with shared cache.
            # See https://www.sqlite.org/sharedcache.html
            # https://stackoverflow.com/questions/3315046/sharing-a-memory-database-between-different-threads-in-python-using-sqlite3-pa
            self._db_file = "file::memory:?cache=shared"
            self._conn_pool = LockPool(self._db_file, is_uri=True)
        else:
            self._db_file = (
                self._settings.require("persist_directory") + "/chroma.sqlite3"
            )
            # exist_ok=True makes this a no-op when the directory is already
            # there, so no separate existence check is needed.
            os.makedirs(os.path.dirname(self._db_file), exist_ok=True)
            self._conn_pool = PerThreadPool(self._db_file)
        self._tx_stack = local()
        super().__init__(system)

    @trace_method("SqliteDB.start", OpenTelemetryGranularity.ALL)
    @override
    def start(self) -> None:
        """Start the component, set connection PRAGMAs and run migrations."""
        super().start()
        with self.tx() as cur:
            # NOTE(review): tx() issues BEGIN before these run, and SQLite
            # documents `PRAGMA foreign_keys` as a no-op inside a
            # transaction -- confirm foreign keys are actually enforced.
            cur.execute("PRAGMA foreign_keys = ON")
            cur.execute("PRAGMA case_sensitive_like = ON")
        self.initialize_migrations()

    @trace_method("SqliteDB.stop", OpenTelemetryGranularity.ALL)
    @override
    def stop(self) -> None:
        """Stop the component and close the connection pool."""
        super().stop()
        self._conn_pool.close()

    @staticmethod
    @override
    def querybuilder() -> Type[pypika.Query]:
        """Return the pypika query class used to build SQL for this backend."""
        return pypika.Query  # type: ignore

    @staticmethod
    @override
    def parameter_format() -> str:
        """Return the placeholder used for bound parameters (``?``)."""
        return "?"

    @staticmethod
    @override
    def migration_scope() -> str:
        """Return the scope name attached to this backend's migrations."""
        return "sqlite"

    @override
    def migration_dirs(self) -> Sequence[Traversable]:
        """Return the migration packages configured in ``__init__``."""
        return self._migration_imports

    @override
    def tx(self) -> TxWrapper:
        """Return a transaction wrapper bound to this thread's stack.

        The thread-local ``stack`` list is created on first use in each
        thread so that nested ``tx()`` calls on that thread share it.
        """
        if not hasattr(self._tx_stack, "stack"):
            self._tx_stack.stack = []
        return TxWrapper(self._conn_pool, stack=self._tx_stack)

    @trace_method("SqliteDB.reset_state", OpenTelemetryGranularity.ALL)
    @override
    def reset_state(self) -> None:
        """Drop every table, close the pool and start the database again.

        Raises:
            ValueError: If the ``allow_reset`` setting is not enabled.
        """
        if not self._settings.require("allow_reset"):
            raise ValueError(
                "Resetting the database is not allowed. Set `allow_reset` to true in the config in tests or other non-production environments where reset should be permitted."
            )
        with self.tx() as cur:
            # Drop all tables
            cur.execute(
                """
                SELECT name FROM sqlite_master
                WHERE type='table'
                """
            )
            table_names = [cast(str, row[0]) for row in cur.fetchall()]
            for table_name in table_names:
                # Names starting with "sqlite_" are reserved for SQLite's own
                # bookkeeping tables, which cannot be dropped.
                if table_name.startswith("sqlite_"):
                    continue
                # Quote the identifier so unusual table names cannot break
                # or alter the statement.
                quoted_name = '"' + table_name.replace('"', '""') + '"'
                cur.execute(f"DROP TABLE IF EXISTS {quoted_name}")
        self._conn_pool.close()
        self.start()
        super().reset_state()

    @trace_method("SqliteDB.setup_migrations", OpenTelemetryGranularity.ALL)
    @override
    def setup_migrations(self) -> None:
        """Create the ``migrations`` bookkeeping table if it is missing."""
        with self.tx() as cur:
            cur.execute(
                """
                CREATE TABLE IF NOT EXISTS migrations (
                    dir TEXT NOT NULL,
                    version INTEGER NOT NULL,
                    filename TEXT NOT NULL,
                    sql TEXT NOT NULL,
                    hash TEXT NOT NULL,
                    PRIMARY KEY (dir, version)
                )
                """
            )

    @trace_method("SqliteDB.migrations_initialized", OpenTelemetryGranularity.ALL)
    @override
    def migrations_initialized(self) -> bool:
        """Return whether the ``migrations`` table exists."""
        with self.tx() as cur:
            cur.execute(
                """SELECT count(*) FROM sqlite_master
                WHERE type='table' AND name='migrations'"""
            )
            return bool(cur.fetchone()[0] != 0)

    @trace_method("SqliteDB.db_migrations", OpenTelemetryGranularity.ALL)
    @override
    def db_migrations(self, dir: Traversable) -> Sequence[Migration]:
        """Return the migrations recorded for ``dir``, ordered by version.

        Args:
            dir: Migration directory; rows are matched on its ``name``.
        """
        with self.tx() as cur:
            cur.execute(
                """
                SELECT dir, version, filename, sql, hash
                FROM migrations
                WHERE dir = ?
                ORDER BY version ASC
                """,
                (dir.name,),
            )
            # Column order follows the SELECT list above.
            return [
                Migration(
                    dir=cast(str, row[0]),
                    version=cast(int, row[1]),
                    filename=cast(str, row[2]),
                    sql=cast(str, row[3]),
                    hash=cast(str, row[4]),
                    scope=self.migration_scope(),
                )
                for row in cur.fetchall()
            ]

    @override
    def apply_migration(self, cur: base.Cursor, migration: Migration) -> None:
        """Run the migration's SQL script and record it in ``migrations``.

        Args:
            cur: Cursor used for both the script and the bookkeeping insert.
            migration: Migration whose ``sql`` is executed and whose fields
                are stored.
        """
        # NOTE(review): sqlite3's executescript may commit a pending
        # transaction before running -- confirm the intended atomicity.
        cur.executescript(migration["sql"])
        cur.execute(
            """
            INSERT INTO migrations (dir, version, filename, sql, hash)
            VALUES (?, ?, ?, ?, ?)
            """,
            (
                migration["dir"],
                migration["version"],
                migration["filename"],
                migration["sql"],
                migration["hash"],
            ),
        )

    @staticmethod
    @override
    def uuid_from_db(value: Optional[Any]) -> Optional[UUID]:
        """Convert a stored value to a ``UUID``; ``None`` stays ``None``."""
        return UUID(value) if value is not None else None

    @staticmethod
    @override
    def uuid_to_db(uuid: Optional[UUID]) -> Optional[Any]:
        """Convert a ``UUID`` to its string form; ``None`` stays ``None``."""
        return str(uuid) if uuid is not None else None

    @staticmethod
    @override
    def unique_constraint_error() -> Type[BaseException]:
        """Return the exception type treated as a unique-constraint failure."""
        return sqlite3.IntegrityError