|
import logging |
|
import os |
|
import platform |
|
from typing import Literal |
|
|
|
import chromadb |
|
from dotenv import load_dotenv |
|
|
|
# Opt in to the hf_transfer download backend for Hugging Face Hub.
# NOTE(review): huggingface_hub reads this variable when it is imported;
# chromadb is already imported above, so confirm this is early enough.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

# Imported after the env var is set (kept in its original position).
from chromadb.config import Settings  # noqa: E402

# Configure root logging once; a second basicConfig() call would be a no-op
# because the root logger already has a handler after the first call.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Load variables from a .env file into os.environ (python-dotenv does not
# override variables that are already set unless override=True is passed).
load_dotenv()

# Module-level client cache, populated lazily by get_chroma_client().
CHROMA_CLIENT = None
|
|
|
|
|
def get_save_path() -> Literal["chroma/"] | Literal["/data/chroma/"]:
    """Pick the directory used for the persistent Chroma store.

    On macOS (``platform.system() == "Darwin"``) the relative directory
    ``chroma/`` is used; on every other platform the absolute directory
    ``/data/chroma/`` is used. The chosen path is logged at INFO level.

    Returns:
        The selected directory path, including its trailing slash.
    """
    # NOTE(review): Darwin presumably means a local dev machine and anything
    # else a deployed container with a /data volume — confirm.
    if platform.system() == "Darwin":
        save_dir = "chroma/"
    else:
        save_dir = "/data/chroma/"

    logger.info(f"Using save path: {save_dir}")
    return save_dir
|
|
|
|
|
def get_chroma_client():
    """Return the shared Chroma ``PersistentClient``, creating it on first use.

    The client is cached in the module-level ``CHROMA_CLIENT`` so later calls
    reuse the same instance. It is rooted at ``get_save_path()`` and has
    anonymized telemetry disabled. No locking is performed, so concurrent
    first calls could each construct a client.

    Returns:
        The cached ``chromadb.PersistentClient`` instance.
    """
    global CHROMA_CLIENT

    # Fast path: the client has already been built.
    if CHROMA_CLIENT is not None:
        return CHROMA_CLIENT

    save_path = get_save_path()
    client_settings = Settings(anonymized_telemetry=False, is_persistent=True)
    CHROMA_CLIENT = chromadb.PersistentClient(
        path=save_path,
        settings=client_settings,
    )
    return CHROMA_CLIENT
|
|
|
|
|
def get_collection(chroma_client, embedding_function, collection_name):
    """Fetch the named collection, creating it if it does not exist yet.

    Calls ``chroma_client.create_collection`` with ``get_or_create=True``.
    The HNSW index metadata requests cosine distance.

    Args:
        chroma_client: Chroma client exposing ``create_collection``.
        embedding_function: Embedding function attached to the collection.
        collection_name: Name of the collection to fetch or create.

    Returns:
        The collection object returned by ``create_collection``.
    """
    logger.info(f"Getting or creating collection: {collection_name}")

    # NOTE(review): when the collection already exists, chromadb may not apply
    # this metadata — confirm against the installed chromadb version.
    index_metadata = {"hnsw:space": "cosine"}
    return chroma_client.create_collection(
        name=collection_name,
        metadata=index_metadata,
        embedding_function=embedding_function,
        get_or_create=True,
    )
|
|