# ragflow/rag/utils/doc_store_conn.py
# Commit b691127 by zhichyu: "Integration with Infinity (#2894)"
# (web-viewer residue: raw / history / blame links, file size 6.96 kB)
from abc import ABC, abstractmethod
from typing import Optional, Union
from dataclasses import dataclass
import numpy as np
import polars as pl
from typing import List, Dict
# Default ``topn`` for dense-vector matches; used as the default value of
# ``MatchDenseExpr.__init__``'s ``topn`` parameter.
DEFAULT_MATCH_VECTOR_TOPN = 10
# NOTE(review): presumably the default ``topn`` for sparse-vector matches, but
# nothing visible in this module references it (``MatchSparseExpr.topn`` has no
# default) -- confirm against callers.
DEFAULT_MATCH_SPARSE_TOPN = 10
# Type alias for vector payloads: either a plain ``list`` or a NumPy array.
VEC = Union[list, np.ndarray]
@dataclass
class SparseVector:
    """Sparse vector held as parallel ``indices`` / ``values`` lists.

    ``values`` is optional. When it is supplied it must have exactly as many
    entries as ``indices``; this is checked with an ``assert`` right after
    construction.
    """

    indices: list[int]
    values: Union[list[float], list[int], None] = None

    def __post_init__(self):
        # The length check only applies when values were actually provided.
        if self.values is not None:
            assert len(self.indices) == len(self.values)

    def to_dict_old(self):
        """Return ``{"indices": [...]}``, plus ``"values"`` when values are set."""
        payload = {"indices": self.indices}
        if self.values is None:
            return payload
        payload["values"] = self.values
        return payload

    def to_dict(self):
        """Return a ``{str(index): value}`` mapping.

        Raises:
            ValueError: if this vector has no values.
        """
        if self.values is None:
            raise ValueError("SparseVector.values is None")
        return {str(idx): val for idx, val in zip(self.indices, self.values)}

    @staticmethod
    def from_dict(d):
        """Build a vector from a mapping shaped like ``to_dict_old()`` output.

        ``"indices"`` is required; ``"values"`` is optional.
        """
        return SparseVector(indices=d["indices"], values=d.get("values"))

    def __str__(self):
        # The ", values=..." part is omitted entirely when no values are set.
        values_part = "" if self.values is None else f", values={self.values}"
        return f"SparseVector(indices={self.indices}{values_part})"

    def __repr__(self):
        return str(self)
class MatchTextExpr(ABC):
    """Full-text match expression.

    A plain value holder: the constructor stores its arguments as attributes
    of the same name.

    Args:
        fields: Stored as ``self.fields`` (annotated ``str``).
        matching_text: Stored as ``self.matching_text``.
        topn: Stored as ``self.topn``. NOTE(review): presumably the maximum
            number of matches to return -- confirm with a backend.
        extra_options: Optional dict of backend-specific options. When
            omitted (or ``None``), each instance gets its own empty dict.
    """

    def __init__(
        self,
        fields: str,
        matching_text: str,
        topn: int,
        extra_options: Optional[dict] = None,
    ):
        self.fields = fields
        self.matching_text = matching_text
        self.topn = topn
        # A ``dict()`` default in the signature would be created once and
        # shared by every instance, so a mutation on one expression would leak
        # into all others. Create a fresh dict per instance instead.
        self.extra_options = {} if extra_options is None else extra_options
class MatchDenseExpr(ABC):
    """Dense-vector match expression.

    A plain value holder: the constructor stores its arguments as attributes
    of the same name.

    Args:
        vector_column_name: Stored as ``self.vector_column_name``.
        embedding_data: The query vector (``list`` or NumPy array, per ``VEC``).
        embedding_data_type: Stored as ``self.embedding_data_type``.
            NOTE(review): accepted values are not defined in this module.
        distance_type: Stored as ``self.distance_type``.
            NOTE(review): accepted values are not defined in this module.
        topn: Stored as ``self.topn``; defaults to
            ``DEFAULT_MATCH_VECTOR_TOPN``.
        extra_options: Optional dict of backend-specific options. When
            omitted (or ``None``), each instance gets its own empty dict.
    """

    def __init__(
        self,
        vector_column_name: str,
        embedding_data: VEC,
        embedding_data_type: str,
        distance_type: str,
        topn: int = DEFAULT_MATCH_VECTOR_TOPN,
        extra_options: Optional[dict] = None,
    ):
        self.vector_column_name = vector_column_name
        self.embedding_data = embedding_data
        self.embedding_data_type = embedding_data_type
        self.distance_type = distance_type
        self.topn = topn
        # A ``dict()`` default in the signature would be created once and
        # shared by every instance, so a mutation on one expression would leak
        # into all others. Create a fresh dict per instance instead.
        self.extra_options = {} if extra_options is None else extra_options
class MatchSparseExpr(ABC):
    """Sparse-vector match expression.

    A plain value holder: every constructor argument is stored unchanged as
    an attribute of the same name. ``sparse_data`` may be a ``SparseVector``
    or a ``dict``; ``opt_params`` is optional and stays ``None`` when omitted.
    """

    def __init__(
        self,
        vector_column_name: str,
        sparse_data: SparseVector | dict,
        distance_type: str,
        topn: int,
        opt_params: Optional[dict] = None,
    ):
        # What to match against, and with which query payload.
        self.vector_column_name = vector_column_name
        self.sparse_data = sparse_data

        # How to match.
        self.distance_type = distance_type
        self.topn = topn
        self.opt_params = opt_params
class MatchTensorExpr(ABC):
    """Tensor match expression.

    A plain value holder: every constructor argument is stored unchanged as
    an attribute of the same name. ``query_data`` is a ``list`` or NumPy array
    (per ``VEC``); ``extra_option`` is optional and stays ``None`` when omitted.
    """

    def __init__(
        self,
        column_name: str,
        query_data: VEC,
        query_data_type: str,
        topn: int,
        extra_option: Optional[dict] = None,
    ):
        # Target column and the query payload with its declared type.
        self.column_name = column_name
        self.query_data = query_data
        self.query_data_type = query_data_type

        # Result size and optional extras.
        self.topn = topn
        self.extra_option = extra_option
class FusionExpr(ABC):
    """Fusion expression.

    A plain value holder for a fusion ``method`` name, a ``topn`` value and
    an optional ``fusion_params`` dict (``None`` when omitted).
    NOTE(review): the accepted ``method`` names are not defined in this module.
    """

    def __init__(
        self,
        method: str,
        topn: int,
        fusion_params: Optional[dict] = None,
    ):
        self.method = method
        self.topn = topn
        self.fusion_params = fusion_params
# Union of every expression class defined above; used as the element type of
# the ``matchExprs`` parameter of ``DocStoreConnection.search``.
MatchExpr = Union[
MatchTextExpr, MatchDenseExpr, MatchSparseExpr, MatchTensorExpr, FusionExpr
]
class OrderByExpr(ABC):
    """Fluent builder for an ordered list of sort keys.

    Attributes:
        fields: List of ``(field_name, direction)`` tuples in the order they
            were added; direction is ``0`` for ascending, ``1`` for descending.

    Example:
        ``OrderByExpr().asc("page_num").desc("create_time").fields``
        evaluates to ``[("page_num", 0), ("create_time", 1)]``.
    """

    def __init__(self):
        # ``fields`` is deliberately a plain attribute and there is no
        # ``fields()`` method: a method of the same name would be shadowed by
        # this instance attribute, so calling it on an instance could only
        # ever raise ``TypeError: 'list' object is not callable``.
        self.fields: list[tuple[str, int]] = []

    def asc(self, field: str) -> "OrderByExpr":
        """Append ``field`` as an ascending sort key and return ``self``."""
        self.fields.append((field, 0))
        return self

    def desc(self, field: str) -> "OrderByExpr":
        """Append ``field`` as a descending sort key and return ``self``."""
        self.fields.append((field, 1))
        return self
class DocStoreConnection(ABC):
    """Abstract interface for a document-store backend.

    Every method is abstract: a concrete connection must override all of
    them, and each base implementation raises ``NotImplementedError`` if it
    is ever reached. Methods are grouped into database, table (index), CRUD,
    search-result helper and SQL operations.

    NOTE(review): this module does not define the shape of the ``res``
    objects, of ``condition`` dicts, or of the returned rows; comments on
    those below are inferred from names and should be confirmed against a
    concrete backend.
    """

    # ------------------------------------------------------------------
    # Database operations
    # ------------------------------------------------------------------

    @abstractmethod
    def dbType(self) -> str:
        """Return the type of the database."""
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def health(self) -> dict:
        """Return the health status of the database."""
        raise NotImplementedError("Not implemented")

    # ------------------------------------------------------------------
    # Table operations
    # ------------------------------------------------------------------

    @abstractmethod
    def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
        """Create an index with given name."""
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def deleteIdx(self, indexName: str, knowledgebaseId: str):
        """Delete an index with given name."""
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
        """Check if an index with given name exists."""
        raise NotImplementedError("Not implemented")

    # ------------------------------------------------------------------
    # CRUD operations
    # ------------------------------------------------------------------

    @abstractmethod
    def search(
        self,
        selectFields: list[str],
        highlight: list[str],
        condition: dict,
        matchExprs: list[MatchExpr],
        orderBy: OrderByExpr,
        offset: int,
        limit: int,
        indexNames: str | list[str],
        knowledgebaseIds: list[str],
    ) -> list[dict] | pl.DataFrame:
        """Search with given conjunctive equivalent filtering condition and
        return all fields of matched documents.

        The result is either a list of dicts or a polars ``DataFrame``,
        depending on the backend.
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
        """Get single chunk with given id.

        The return annotation allows ``None``. NOTE(review): presumably that
        means "no such chunk" -- confirm with a concrete backend.
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def insert(self, rows: list[dict], indexName: str, knowledgebaseId: str) -> list[str]:
        """Update or insert a bulk of rows.

        NOTE(review): the meaning of the returned ``list[str]`` is not stated
        in this module -- confirm with a concrete backend.
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
        """Update rows with given conjunctive equivalent filtering condition."""
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
        """Delete rows with given conjunctive equivalent filtering condition."""
        raise NotImplementedError("Not implemented")

    # ------------------------------------------------------------------
    # Helper functions for search result
    # ------------------------------------------------------------------

    @abstractmethod
    def getTotal(self, res):
        """Helper for a search result ``res``.

        NOTE(review): presumably returns the total number of hits -- confirm.
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getChunkIds(self, res):
        """Helper for a search result ``res``.

        NOTE(review): presumably returns the ids of the matched chunks -- confirm.
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getFields(self, res, fields: list[str]) -> dict[str, dict]:
        """Helper for a search result ``res``; returns a dict of dicts.

        NOTE(review): presumably keyed by chunk id, with each inner dict
        restricted to ``fields`` -- confirm.
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getHighlight(self, res, keywords: list[str], fieldnm: str):
        """Helper for a search result ``res``.

        NOTE(review): presumably returns highlighted text of field
        ``fieldnm`` for the given ``keywords`` -- confirm.
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getAggregation(self, res, fieldnm: str):
        """Helper for a search result ``res``.

        NOTE(review): presumably returns an aggregation over field
        ``fieldnm`` -- confirm.
        """
        raise NotImplementedError("Not implemented")

    # ------------------------------------------------------------------
    # SQL
    # ------------------------------------------------------------------

    @abstractmethod
    def sql(self, sql: str, fetch_size: int, format: str):
        """Run the sql generated by text-to-sql.

        ``self`` is declared explicitly so this abstract signature matches
        the instance-method form used by every other method of the interface.
        """
        raise NotImplementedError("Not implemented")