Spaces:
Runtime error
Runtime error
"""Milvus reader.""" | |
from typing import Any, Dict, List, Optional | |
from uuid import uuid4 | |
from gpt_index.readers.base import BaseReader | |
from gpt_index.readers.schema.base import Document | |
class MilvusReader(BaseReader):
    """Milvus reader.

    Runs a vector search against a Milvus collection and converts each hit
    into a ``Document`` built from the hit's ``doc_id`` and ``text`` fields.
    """

    def __init__(
        self,
        host: str = "localhost",
        port: int = 19530,
        user: str = "",
        password: str = "",
        use_secure: bool = False,
    ):
        """Initialize with parameters and open (or reuse) a connection.

        Args:
            host (str): Milvus server host name.
            port (int): Milvus server port.
            user (str): User name, passed to ``connections.connect``.
            password (str): Password, passed to ``connections.connect``.
            use_secure (bool): Passed as ``secure`` to ``connections.connect``.

        Raises:
            ImportError: If the ``pymilvus`` package is not installed.
            pymilvus.MilvusException: If the connection cannot be established.
        """
        import_err_msg = (
            "`pymilvus` package not found, please run `pip install pymilvus`"
        )
        try:
            import pymilvus  # noqa: F401
        except ImportError as err:
            raise ImportError(import_err_msg) from err

        self.host = host
        self.port = port
        self.user = user
        self.password = password
        self.use_secure = use_secure
        # Set by load_data(); _create_search_params() reads its index info.
        self.collection = None
        # Per-index-type search defaults. "metric_type" here is a placeholder:
        # _create_search_params() replaces it with the index's actual metric.
        self.default_search_params = {
            "IVF_FLAT": {"metric_type": "IP", "params": {"nprobe": 10}},
            "IVF_SQ8": {"metric_type": "IP", "params": {"nprobe": 10}},
            "IVF_PQ": {"metric_type": "IP", "params": {"nprobe": 10}},
            "HNSW": {"metric_type": "IP", "params": {"ef": 10}},
            "RHNSW_FLAT": {"metric_type": "IP", "params": {"ef": 10}},
            "RHNSW_SQ": {"metric_type": "IP", "params": {"ef": 10}},
            "RHNSW_PQ": {"metric_type": "IP", "params": {"ef": 10}},
            "IVF_HNSW": {"metric_type": "IP", "params": {"nprobe": 10, "ef": 10}},
            "ANNOY": {"metric_type": "IP", "params": {"search_k": 10}},
            "AUTOINDEX": {"metric_type": "IP", "params": {}},
        }
        # Any MilvusException propagates unchanged to the caller.
        self._create_connection_alias()

    def load_data(
        self,
        query_vector: List[float],
        collection_name: str,
        expr: Any = None,
        search_params: Optional[dict] = None,
        limit: int = 10,
    ) -> List[Document]:
        """Load data from Milvus.

        Args:
            query_vector (List[float]): Query vector, searched against the
                collection's ``embedding`` field.
            collection_name (str): Name of the Milvus collection.
            expr (Any): Optional filter expression forwarded as ``expr`` to
                ``Collection.search``.
            search_params (Optional[dict]): Search parameters. If ``None``,
                they are derived from the collection's first index.
            limit (int): Number of results to return.

        Returns:
            List[Document]: One document per hit, built from the hit's
            ``doc_id`` and ``text`` output fields.

        Raises:
            pymilvus.MilvusException: If opening, loading or searching the
                collection fails.
            ValueError: If ``search_params`` is omitted and the collection has
                no index to derive defaults from.
        """
        from pymilvus import Collection

        collection = Collection(collection_name, using=self.alias)
        self.collection = collection
        collection.load()

        if search_params is None:
            search_params = self._create_search_params()

        res = collection.search(
            [query_vector],
            "embedding",
            param=search_params,
            expr=expr,
            output_fields=["doc_id", "text"],
            limit=limit,
        )

        # A single query vector was submitted, so its hits are in res[0].
        # TODO: In future append embedding when more efficient
        return [
            Document(doc_id=hit.entity.get("doc_id"), text=hit.entity.get("text"))
            for hit in res[0]
        ]

    def _create_connection_alias(self) -> None:
        """Set ``self.alias`` to a pymilvus connection alias for host:port.

        Reuses the first listed connection whose handler is set and whose
        address equals ``host:port``. Otherwise it connects under a new random
        alias, using the stored credentials and ``use_secure`` flag.
        """
        from pymilvus import connections

        self.alias = None
        target_address = f"{self.host}:{self.port}"
        # Attempt to reuse an open connection.
        # NOTE(review): reuse matches on address only; user/password/secure of
        # the existing connection are not compared — confirm this is intended.
        for alias, handler in connections.list_connections():
            addr = connections.get_connection_addr(alias)
            if handler and addr.get("address") == target_address:
                self.alias = alias
                break

        # No reusable connection: open a new one with the configured settings.
        if self.alias is None:
            self.alias = uuid4().hex
            connections.connect(
                alias=self.alias,
                host=self.host,
                port=self.port,
                user=self.user,  # type: ignore
                password=self.password,  # type: ignore
                secure=self.use_secure,
            )

    def _create_search_params(self) -> Dict[str, Any]:
        """Build search parameters matching the current collection's index.

        Returns:
            Dict[str, Any]: A new dict whose ``metric_type`` comes from the
            collection's first index. Its ``params`` is a copy of the defaults
            for that index type, or ``{}`` when the type has no default entry
            (e.g. ``FLAT``). The shared defaults are never mutated.

        Raises:
            ValueError: If the collection has no index.
        """
        # Internal invariant: load_data() sets self.collection before calling.
        assert self.collection is not None
        indexes = self.collection.indexes
        if not indexes:
            raise ValueError(
                "Collection has no index; pass `search_params` explicitly."
            )
        index = indexes[0]._index_params
        defaults = self.default_search_params.get(index["index_type"], {})
        # Return a fresh dict (including a copied "params") so neither the
        # caller nor later calls can corrupt self.default_search_params.
        return {
            "metric_type": index["metric_type"],
            "params": dict(defaults.get("params", {})),
        }