|
from abc import ABC, abstractmethod |
|
from dataclasses import dataclass |
|
import numpy as np |
|
import polars as pl |
|
|
|
DEFAULT_MATCH_VECTOR_TOPN = 10 |
|
DEFAULT_MATCH_SPARSE_TOPN = 10 |
|
VEC = list | np.ndarray |
|
|
|
|
|
@dataclass |
|
class SparseVector: |
|
indices: list[int] |
|
values: list[float] | list[int] | None = None |
|
|
|
def __post_init__(self): |
|
assert (self.values is None) or (len(self.indices) == len(self.values)) |
|
|
|
def to_dict_old(self): |
|
d = {"indices": self.indices} |
|
if self.values is not None: |
|
d["values"] = self.values |
|
return d |
|
|
|
def to_dict(self): |
|
if self.values is None: |
|
raise ValueError("SparseVector.values is None") |
|
result = {} |
|
for i, v in zip(self.indices, self.values): |
|
result[str(i)] = v |
|
return result |
|
|
|
@staticmethod |
|
def from_dict(d): |
|
return SparseVector(d["indices"], d.get("values")) |
|
|
|
def __str__(self): |
|
return f"SparseVector(indices={self.indices}{'' if self.values is None else f', values={self.values}'})" |
|
|
|
def __repr__(self): |
|
return str(self) |
|
|
|
|
|
class MatchTextExpr(ABC): |
|
def __init__( |
|
self, |
|
fields: list[str], |
|
matching_text: str, |
|
topn: int, |
|
extra_options: dict = dict(), |
|
): |
|
self.fields = fields |
|
self.matching_text = matching_text |
|
self.topn = topn |
|
self.extra_options = extra_options |
|
|
|
|
|
class MatchDenseExpr(ABC):
    """Dense-vector match expression against a vector column."""

    def __init__(
        self,
        vector_column_name: str,
        embedding_data: VEC,
        embedding_data_type: str,
        distance_type: str,
        topn: int = DEFAULT_MATCH_VECTOR_TOPN,
        extra_options: dict | None = None,
    ):
        """Store the parameters of a dense-vector match.

        Args:
            vector_column_name: Name of the vector column to search.
            embedding_data: Query embedding, a list or numpy array.
            embedding_data_type: Stored as given; presumably names the
                element type of ``embedding_data`` -- confirm with backends.
            distance_type: Stored as given; presumably names the distance
                metric -- confirm with backends.
            topn: Number of top results requested.
            extra_options: Optional extra options; the keys are not
                interpreted here. Defaults to a new empty dict.
        """
        self.vector_column_name = vector_column_name
        self.embedding_data = embedding_data
        self.embedding_data_type = embedding_data_type
        self.distance_type = distance_type
        self.topn = topn
        # A fresh dict per instance: the previous ``dict()`` default was
        # created once at definition time and shared by every instance that
        # omitted the argument.
        self.extra_options = {} if extra_options is None else extra_options
|
|
|
|
|
class MatchSparseExpr(ABC):
    """Sparse-vector match expression against a vector column."""

    def __init__(
        self,
        vector_column_name: str,
        sparse_data: SparseVector | dict,
        distance_type: str,
        topn: int,
        opt_params: dict | None = None,
    ):
        """Record the arguments of a sparse-vector match unchanged.

        Args:
            vector_column_name: Name of the vector column to search.
            sparse_data: Query vector, either a SparseVector or a dict.
            distance_type: Stored as given; not interpreted here.
            topn: Number of top results requested.
            opt_params: Optional parameters; stays None when omitted.
        """
        # What to search and with which query vector.
        self.vector_column_name, self.sparse_data = vector_column_name, sparse_data
        # How to rank and how many hits to keep.
        self.distance_type, self.topn = distance_type, topn
        self.opt_params = opt_params
|
|
|
|
|
class MatchTensorExpr(ABC):
    """Tensor match expression against a single column."""

    def __init__(
        self,
        column_name: str,
        query_data: VEC,
        query_data_type: str,
        topn: int,
        extra_option: dict | None = None,
    ):
        """Record the arguments of a tensor match unchanged.

        Args:
            column_name: Name of the column to search.
            query_data: Query data, a list or numpy array.
            query_data_type: Stored as given; not interpreted here.
            topn: Number of top results requested.
            extra_option: Optional extra options; stays None when omitted.
        """
        # Target column and the query payload.
        self.column_name, self.query_data = column_name, query_data
        self.query_data_type = query_data_type
        # Result size and backend-specific extras.
        self.topn, self.extra_option = topn, extra_option
|
|
|
|
|
class FusionExpr(ABC): |
|
def __init__(self, method: str, topn: int, fusion_params: dict | None = None): |
|
self.method = method |
|
self.topn = topn |
|
self.fusion_params = fusion_params |
|
|
|
|
|
# Union of every match / fusion expression class; used to annotate the
# ``matchExprs`` argument of DocStoreConnection.search.
MatchExpr = (
    MatchTextExpr
    | MatchDenseExpr
    | MatchSparseExpr
    | MatchTensorExpr
    | FusionExpr
)
|
|
|
class OrderByExpr(ABC):
    """Fluent builder for an ordered list of sort keys.

    Attributes:
        fields: ``(field_name, direction)`` tuples in the order they were
            added; direction is ``0`` for ascending and ``1`` for descending.
    """

    def __init__(self):
        """Start with no sort keys."""
        # ``fields`` is a plain instance attribute and is read directly.
        # A former ``fields()`` accessor method was removed: this attribute
        # always shadowed it on instances, so calling it raised
        # ``TypeError: 'list' object is not callable``.
        self.fields: list[tuple[str, int]] = []

    def asc(self, field: str):
        """Append ``field`` as an ascending sort key and return ``self``."""
        self.fields.append((field, 0))
        return self

    def desc(self, field: str):
        """Append ``field`` as a descending sort key and return ``self``."""
        self.fields.append((field, 1))
        return self
|
|
|
class DocStoreConnection(ABC):
    """Abstract interface for a document-store backend.

    Methods are grouped into database, table (index), CRUD, search-result
    helper and SQL operations. Every method here is abstract and its body
    only raises ``NotImplementedError``.
    """

    # ------------------------------------------------------------------
    # Database operations
    # ------------------------------------------------------------------

    @abstractmethod
    def dbType(self) -> str:
        """
        Return the type of the database.
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def health(self) -> dict:
        """
        Return the health status of the database.
        """
        raise NotImplementedError("Not implemented")

    # ------------------------------------------------------------------
    # Table operations
    # ------------------------------------------------------------------

    @abstractmethod
    def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
        """
        Create an index with given name
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def deleteIdx(self, indexName: str, knowledgebaseId: str):
        """
        Delete an index with given name
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
        """
        Check if an index with given name exists
        """
        raise NotImplementedError("Not implemented")

    # ------------------------------------------------------------------
    # CRUD operations
    # ------------------------------------------------------------------

    @abstractmethod
    def search(
        self,
        selectFields: list[str],
        highlightFields: list[str],
        condition: dict,
        matchExprs: list[MatchExpr],
        orderBy: OrderByExpr,
        offset: int,
        limit: int,
        indexNames: str | list[str],
        knowledgebaseIds: list[str],
        # Mutable default kept so the published signature is unchanged;
        # implementations should treat it as read-only.
        aggFields: list[str] = [],
        rank_feature: dict | None = None,
    ) -> list[dict] | pl.DataFrame:
        """
        Search with given conjunctive equivalent filtering condition and return all fields of matched documents
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
        """
        Get single chunk with given id
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def insert(self, rows: list[dict], indexName: str, knowledgebaseId: str | None = None) -> list[str]:
        """
        Update or insert a bulk of rows
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
        """
        Update rows with given conjunctive equivalent filtering condition
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
        """
        Delete rows with given conjunctive equivalent filtering condition
        """
        raise NotImplementedError("Not implemented")

    # ------------------------------------------------------------------
    # Helper functions for search result
    # NOTE(review): ``res`` is presumably the value returned by ``search``;
    # confirm against the concrete implementations.
    # ------------------------------------------------------------------

    @abstractmethod
    def getTotal(self, res):
        """Helper for a search result ``res``; by its name, returns the total hit count."""
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getChunkIds(self, res):
        """Helper for a search result ``res``; by its name, returns the chunk ids."""
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getFields(self, res, fields: list[str]) -> dict[str, dict]:
        """Helper for a search result ``res``; returns a dict of dicts for the given ``fields``."""
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getHighlight(self, res, keywords: list[str], fieldnm: str):
        """Helper for a search result ``res``; by its name, returns highlights for ``fieldnm``."""
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getAggregation(self, res, fieldnm: str):
        """Helper for a search result ``res``; by its name, returns the aggregation on ``fieldnm``."""
        raise NotImplementedError("Not implemented")

    # ------------------------------------------------------------------
    # SQL
    # ------------------------------------------------------------------

    @abstractmethod
    def sql(self, sql: str, fetch_size: int, format: str):
        """
        Run the sql generated by text-to-sql
        """
        # ``self`` was missing from this signature; without it the ``sql``
        # parameter would have received the instance on a bound call.
        raise NotImplementedError("Not implemented")
|
|