from abc import ABC, abstractmethod
from dataclasses import dataclass

import numpy as np
import polars as pl

DEFAULT_MATCH_VECTOR_TOPN = 10
DEFAULT_MATCH_SPARSE_TOPN = 10
VEC = list | np.ndarray


@dataclass
class SparseVector:
    """A sparse vector stored as parallel lists of indices and optional values."""

    indices: list[int]
    values: list[float] | list[int] | None = None

    def __post_init__(self):
        # Raise explicitly rather than ``assert`` so the check is not stripped
        # when Python runs with ``-O``.
        if self.values is not None and len(self.indices) != len(self.values):
            raise ValueError(
                f"SparseVector: len(indices)={len(self.indices)} "
                f"!= len(values)={len(self.values)}"
            )

    def to_dict_old(self):
        """Return the legacy form ``{"indices": [...], "values": [...]}``.

        The ``"values"`` key is omitted when ``values`` is None.
        """
        d = {"indices": self.indices}
        if self.values is not None:
            d["values"] = self.values
        return d

    def to_dict(self):
        """Return a ``{str(index): value}`` mapping.

        Raises:
            ValueError: if ``values`` is None.
        """
        if self.values is None:
            raise ValueError("SparseVector.values is None")
        return {str(i): v for i, v in zip(self.indices, self.values)}

    @staticmethod
    def from_dict(d):
        """Build a SparseVector from either dict form produced by this class.

        Accepts the legacy ``{"indices": [...], "values": [...]}`` form (see
        ``to_dict_old``; ``"values"`` is optional) and the
        ``{str(index): value}`` form returned by ``to_dict``.

        Raises:
            KeyError: if ``d`` has no ``"indices"`` key and its keys are not
                all integer-like (i.e. it is a malformed legacy dict).
        """
        if "indices" in d:
            return SparseVector(d["indices"], d.get("values"))
        # to_dict() form: keys are stringified integer indices.
        try:
            indices = [int(k) for k in d]
        except (TypeError, ValueError):
            # Preserve the historical error for malformed legacy input.
            raise KeyError("indices") from None
        return SparseVector(indices, list(d.values()))

    def __str__(self):
        return f"SparseVector(indices={self.indices}{'' if self.values is None else f', values={self.values}'})"

    def __repr__(self):
        return str(self)


class MatchTextExpr(ABC):
    """Parameters of a full-text match expression, passed to ``DocStoreConnection.search``."""

    def __init__(
        self,
        fields: list[str],
        matching_text: str,
        topn: int,
        extra_options: dict | None = None,
    ):
        self.fields = fields
        self.matching_text = matching_text
        self.topn = topn
        # Build a fresh dict per instance. The former ``dict()`` default was
        # evaluated once at definition time and shared by all instances.
        self.extra_options = extra_options if extra_options is not None else {}


class MatchDenseExpr(ABC):
    """Parameters of a dense-vector match expression, passed to ``DocStoreConnection.search``."""

    def __init__(
        self,
        vector_column_name: str,
        embedding_data: VEC,
        embedding_data_type: str,
        distance_type: str,
        topn: int = DEFAULT_MATCH_VECTOR_TOPN,
        extra_options: dict | None = None,
    ):
        self.vector_column_name = vector_column_name
        self.embedding_data = embedding_data
        self.embedding_data_type = embedding_data_type
        self.distance_type = distance_type
        self.topn = topn
        # Fresh dict per instance (see MatchTextExpr for the rationale).
        self.extra_options = extra_options if extra_options is not None else {}


class MatchSparseExpr(ABC):
    """Parameters of a sparse-vector match expression, passed to ``DocStoreConnection.search``."""

    def __init__(
        self,
        vector_column_name: str,
        sparse_data: SparseVector | dict,
        distance_type: str,
        topn: int,
        opt_params: dict | None = None,
    ):
        self.vector_column_name = vector_column_name
        self.sparse_data = sparse_data
        self.distance_type = distance_type
        self.topn = topn
        self.opt_params = opt_params


class MatchTensorExpr(ABC):
    """Parameters of a tensor match expression, passed to ``DocStoreConnection.search``."""

    def __init__(
        self,
        column_name: str,
        query_data: VEC,
        query_data_type: str,
        topn: int,
        extra_option: dict | None = None,
    ):
        self.column_name = column_name
        self.query_data = query_data
        self.query_data_type = query_data_type
        self.topn = topn
        self.extra_option = extra_option


class FusionExpr(ABC):
    """Parameters describing how results of several match expressions are fused."""

    def __init__(self, method: str, topn: int, fusion_params: dict | None = None):
        self.method = method
        self.topn = topn
        self.fusion_params = fusion_params


MatchExpr = MatchTextExpr | MatchDenseExpr | MatchSparseExpr | MatchTensorExpr | FusionExpr


class OrderByExpr(ABC):
    """Ordered list of sort keys, built fluently with ``asc``/``desc``.

    ``fields`` holds ``(field_name, direction)`` tuples in call order, where
    direction is 0 for ``asc`` and 1 for ``desc``. Read the ``fields``
    attribute directly. There is deliberately no ``fields()`` method: the
    instance attribute would shadow it, so calling it always failed.
    """

    def __init__(self):
        self.fields: list[tuple[str, int]] = []

    def asc(self, field: str):
        """Append ``field`` as an ascending key and return ``self`` for chaining."""
        self.fields.append((field, 0))
        return self

    def desc(self, field: str):
        """Append ``field`` as a descending key and return ``self`` for chaining."""
        self.fields.append((field, 1))
        return self


class DocStoreConnection(ABC):
    """Abstract document-store backend.

    Groups database, table/index, CRUD, search-result helper and SQL
    operations. Every method must be implemented by a concrete backend.
    """

    # ------------------------------------------------------------------
    # Database operations
    # ------------------------------------------------------------------

    @abstractmethod
    def dbType(self) -> str:
        """
        Return the type of the database.
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def health(self) -> dict:
        """
        Return the health status of the database.
        """
        raise NotImplementedError("Not implemented")

    # ------------------------------------------------------------------
    # Table operations
    # ------------------------------------------------------------------

    @abstractmethod
    def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
        """
        Create an index with the given name.
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def deleteIdx(self, indexName: str, knowledgebaseId: str):
        """
        Delete the index with the given name.
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
        """
        Check whether an index with the given name exists.
        """
        raise NotImplementedError("Not implemented")

    # ------------------------------------------------------------------
    # CRUD operations
    # ------------------------------------------------------------------

    @abstractmethod
    def search(
        self, selectFields: list[str],
        highlightFields: list[str],
        condition: dict,
        matchExprs: list[MatchExpr],
        orderBy: OrderByExpr,
        offset: int,
        limit: int,
        indexNames: str | list[str],
        knowledgebaseIds: list[str],
        aggFields: list[str] = [],
        rank_feature: dict | None = None
    ) -> list[dict] | pl.DataFrame:
        """
        Search with the given conjunctive equality filtering condition and
        return all fields of the matched documents.
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
        """
        Get the single chunk with the given id.
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def insert(self, rows: list[dict], indexName: str, knowledgebaseId: str | None = None) -> list[str]:
        """
        Update or insert a batch of rows.
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
        """
        Update rows matching the given conjunctive equality filtering condition.
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
        """
        Delete rows matching the given conjunctive equality filtering condition.
        """
        raise NotImplementedError("Not implemented")

    # ------------------------------------------------------------------
    # Helper functions for search results (``res`` is what ``search`` returned)
    # ------------------------------------------------------------------

    @abstractmethod
    def getTotal(self, res):
        """Return the total hit count of the search result ``res``."""
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getChunkIds(self, res):
        """Return the chunk ids contained in the search result ``res``."""
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getFields(self, res, fields: list[str]) -> dict[str, dict]:
        """Return the requested ``fields`` of each hit in ``res``, keyed per hit."""
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getHighlight(self, res, keywords: list[str], fieldnm: str):
        """Return highlight snippets of ``keywords`` in field ``fieldnm`` of ``res``."""
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getAggregation(self, res, fieldnm: str):
        """Return the aggregation computed over field ``fieldnm`` in ``res``."""
        raise NotImplementedError("Not implemented")

    # ------------------------------------------------------------------
    # SQL
    # ------------------------------------------------------------------

    @abstractmethod
    def sql(self, sql: str, fetch_size: int, format: str):
        """
        Run the SQL generated by text-to-SQL.
        """
        raise NotImplementedError("Not implemented")