binhnase04854's picture
first deploy
b699122
raw
history blame
15.8 kB
"""Milvus vector store index.
An index that is built within Milvus.
"""
import logging
from typing import Any, Dict, List, Optional
from uuid import uuid4
from gpt_index.data_structs.node_v2 import DocumentRelationship, Node
from gpt_index.vector_stores.types import (
NodeEmbeddingResult,
VectorStore,
VectorStoreQuery,
VectorStoreQueryMode,
VectorStoreQueryResult,
)
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class MilvusVectorStore(VectorStore):
    """The Milvus Vector Store.

    In this vector store we store the text, its embedding and
    a few pieces of its metadata in a Milvus collection. This implementation
    allows the use of an already existing collection if it is one that was
    created by this vector store. It also supports creating a new one if the
    collection doesn't exist or if `overwrite` is set to True.

    Args:
        collection_name (str, optional): The name of the collection where data
            will be stored. Defaults to "llamalection".
        index_params (dict, optional): The index parameters for Milvus, if none
            are provided an HNSW index will be used. Defaults to None.
        search_params (dict, optional): The search parameters for a Milvus
            query. If none are provided, default params will be generated.
            Defaults to None.
        dim (int, optional): The dimension of the embeddings. If it is not
            provided, collection creation will be done on first insert.
            Defaults to None.
        host (str, optional): The host address of Milvus. Defaults to
            "localhost".
        port (int, optional): The port of Milvus. Defaults to 19530.
        user (str, optional): The username for RBAC. Defaults to "".
        password (str, optional): The password for RBAC. Defaults to "".
        use_secure (bool, optional): Use https. Defaults to False.
        overwrite (bool, optional): Whether to overwrite existing collection
            with same name. Defaults to False.

    Raises:
        ImportError: Unable to import `pymilvus`.
        MilvusException: Error communicating with Milvus, more can be found in
            logging under Debug.

    Returns:
        MilvusVectorstore: Vectorstore that supports add, delete, and query.
    """

    stores_text: bool = True
    stores_node: bool = False

    def __init__(
        self,
        collection_name: str = "llamalection",
        index_params: Optional[dict] = None,
        search_params: Optional[dict] = None,
        dim: Optional[int] = None,
        host: str = "localhost",
        port: int = 19530,
        user: str = "",
        password: str = "",
        use_secure: bool = False,
        overwrite: bool = False,
        **kwargs: Any,
    ) -> None:
        """Connect to Milvus and attach to, or create, the backing collection.

        Extra ``**kwargs`` are accepted and ignored.
        """
        import_err_msg = (
            "`pymilvus` package not found, please run `pip install pymilvus`"
        )
        try:
            import pymilvus  # noqa: F401
        except ImportError as err:
            # Chain the cause so the original import failure stays visible.
            raise ImportError(import_err_msg) from err

        from pymilvus import Collection, MilvusException, utility

        self.collection_name = collection_name
        self.search_params = search_params
        self.index_params = index_params
        self.dim = dim
        self.host = host
        self.port = port
        self.user = user
        self.password = password
        self.use_secure = use_secure
        self.overwrite = overwrite

        # Per-index-type search parameter templates. They are treated as
        # read-only; `_build_search_params` hands out copies.
        self.default_search_params = {
            "IVF_FLAT": {"metric_type": "IP", "params": {"nprobe": 10}},
            "IVF_SQ8": {"metric_type": "IP", "params": {"nprobe": 10}},
            "IVF_PQ": {"metric_type": "IP", "params": {"nprobe": 10}},
            "HNSW": {"metric_type": "IP", "params": {"ef": 10}},
            "RHNSW_FLAT": {"metric_type": "IP", "params": {"ef": 10}},
            "RHNSW_SQ": {"metric_type": "IP", "params": {"ef": 10}},
            "RHNSW_PQ": {"metric_type": "IP", "params": {"ef": 10}},
            "IVF_HNSW": {"metric_type": "IP", "params": {"nprobe": 10, "ef": 10}},
            "ANNOY": {"metric_type": "IP", "params": {"search_k": 10}},
            "AUTOINDEX": {"metric_type": "IP", "params": {}},
        }

        # Generate a connection alias
        self._create_connection_alias()

        # Figure out if there is already a created collection
        if utility.has_collection(self.collection_name, using=self.alias):
            self.collection = Collection(
                self.collection_name, using=self.alias, consistency_level="Strong"
            )
        else:
            self.collection = None

        # If a collection already exists and we are overwriting, delete it
        if self.collection is not None and self.overwrite is True:
            try:
                utility.drop_collection(self.collection_name, using=self.alias)
                self.collection = None
                logger.debug(
                    f"Successfully removed old collection: {self.collection_name}"
                )
            except MilvusException:
                logger.debug(
                    f"Failed to remove old collection: {self.collection_name}"
                )
                raise

        # If there is no collection and a dim is provided, we can create one.
        # Without a dim, creation is deferred to the first `add` call.
        if self.collection is None and self.dim is not None:
            self._create_collection()

        # If there is a collection and no index exists on it, create an index
        if self.collection is not None and len(self.collection.indexes) == 0:
            self._create_index()
        # If using an existing index and no search params were provided,
        # generate the correct params
        elif self.collection is not None and self.search_params is None:
            self._create_search_params()

        # If there is a collection with an index, make sure it is loaded
        if self.collection is not None and len(self.collection.indexes) != 0:
            self.collection.load()

    def _create_connection_alias(self) -> None:
        """Set ``self.alias`` to a reusable or freshly opened connection alias."""
        from pymilvus import connections

        self.alias = None
        target_address = f"{self.host}:{self.port}"

        # Attempt to reuse an open connection.
        # NOTE(review): matching is by "host:port" only; user/password/secure
        # of the existing connection are not compared.
        for x in connections.list_connections():
            addr = connections.get_connection_addr(x[0])
            if x[1] and ("address" in addr) and (addr["address"] == target_address):
                self.alias = x[0]
                logger.debug(f"Using previous connection: {self.alias}")
                break

        # No matching connection: open one under a random alias using the
        # connection settings given to the constructor.
        if self.alias is None:
            self.alias = uuid4().hex
            connections.connect(
                alias=self.alias,
                host=self.host,
                port=self.port,
                user=self.user,  # type: ignore
                password=self.password,  # type: ignore
                secure=self.use_secure,
            )
            logger.debug(f"Creating new connection: {self.alias}")

    def _create_collection(self) -> None:
        """Create the collection (id, doc_id, text, embedding) in Milvus.

        Uses ``self.dim`` as the embedding dimension.

        Raises:
            MilvusException: Collection creation failed.
        """
        from pymilvus import (
            Collection,
            CollectionSchema,
            DataType,
            FieldSchema,
            MilvusException,
        )

        try:
            fields = [
                FieldSchema(
                    name="id",
                    dtype=DataType.VARCHAR,
                    description="Unique ID",
                    is_primary=True,
                    auto_id=False,
                    max_length=65535,
                ),
                FieldSchema(
                    name="doc_id",
                    dtype=DataType.VARCHAR,
                    description="Source document ID",
                    max_length=65535,
                ),
                FieldSchema(
                    name="text",
                    dtype=DataType.VARCHAR,
                    # Was mislabeled "The embedding vector" (copy-paste).
                    description="The text of the node",
                    max_length=65535,
                ),
                FieldSchema(
                    name="embedding",
                    dtype=DataType.FLOAT_VECTOR,
                    description="The embedding vector",
                    dim=self.dim,
                ),
            ]
            col_schema = CollectionSchema(fields=fields)
            self.collection = Collection(
                self.collection_name,
                col_schema,
                using=self.alias,
                consistency_level="Strong",
            )
            logger.debug(
                f"Successfully created a new collection: {self.collection_name}"
            )
        except MilvusException:
            logger.debug(
                f"Failure to create a new collection: {self.collection_name}"
            )
            raise

    def _build_search_params(
        self, index_type: str, metric_type: Optional[str] = None
    ) -> Dict[str, Any]:
        """Return a fresh search-params dict for ``index_type``.

        The template in ``self.default_search_params`` is copied so callers
        can modify the result without altering the defaults. When
        ``metric_type`` is given it replaces the template's metric.

        Raises:
            KeyError: ``index_type`` has no entry in the defaults table.
        """
        template = self.default_search_params[index_type]
        return {
            "metric_type": (
                template["metric_type"] if metric_type is None else metric_type
            ),
            "params": dict(template["params"]),
        }

    def _create_index(self) -> None:
        """Create an index on the ``embedding`` field.

        Falls back to an HNSW / inner-product index when no ``index_params``
        were supplied, and derives ``search_params`` from the index settings
        when none were supplied.

        Raises:
            MilvusException: Index creation failed.
        """
        from pymilvus import MilvusException

        try:
            # If no index params, use a default HNSW based one
            if self.index_params is None:
                self.index_params = {
                    "metric_type": "IP",
                    "index_type": "HNSW",
                    "params": {"M": 8, "efConstruction": 64},
                }
            assert self.index_params is not None  # narrowing for type checkers
            self.collection.create_index(
                "embedding", index_params=self.index_params, using=self.alias
            )

            # If search params don't exist already, derive them from the index
            # so searches are issued with the same metric the index was built
            # with (the templates alone always say "IP").
            if self.search_params is None:
                self.search_params = self._build_search_params(
                    self.index_params["index_type"],
                    self.index_params.get("metric_type"),
                )
            logger.debug(
                f"Successfully created an index on collection: {self.collection_name}"
            )
        except MilvusException:
            logger.debug(
                f"Failed to create an index on collection: {self.collection_name}"
            )
            raise

    def _create_search_params(self) -> None:
        """Derive ``search_params`` from the collection's first existing index."""
        # NOTE(review): reads the private `_index_params` attribute of the
        # pymilvus index object, as the original code did.
        index = self.collection.indexes[0]._index_params
        self.search_params = self._build_search_params(
            index["index_type"], index["metric_type"]
        )

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "VectorStore":
        """Build a store by passing ``config_dict`` as constructor kwargs."""
        return cls(**config_dict)

    @property
    def client(self) -> Any:
        """Get client (the underlying collection object, or None)."""
        return self.collection

    @property
    def config_dict(self) -> dict:
        """Return config dict.

        ``overwrite`` is deliberately left out so that a store rebuilt from
        this dict uses the constructor default (False) and does not drop the
        existing collection.
        """
        return {
            "collection_name": self.collection_name,
            "index_params": self.index_params,
            "search_params": self.search_params,
            "dim": self.dim,
            "host": self.host,
            "port": self.port,
            "user": self.user,
            "password": self.password,
            "use_secure": self.use_secure,
        }

    def add(self, embedding_results: List[NodeEmbeddingResult]) -> List[str]:
        """Add the embeddings and their nodes into Milvus.

        Args:
            embedding_results (List[NodeEmbeddingResult]): The embeddings and
                their data to insert.

        Raises:
            MilvusException: Failed to insert data.

        Returns:
            List[str]: List of ids inserted (empty when nothing was given).
        """
        from pymilvus import MilvusException

        # Nothing to insert. Previously this fell through to
        # `self.collection.insert`, which failed on a store whose collection
        # had not been created yet (collection is None).
        if not embedding_results:
            return []

        # If the collection doesn't exist yet, create the collection and index
        # (sized from the first embedding), then load it.
        if self.collection is None:
            self.dim = len(embedding_results[0].embedding)
            self._create_collection()
            self._create_index()
            assert self.collection is not None
            self.collection.load()

        # Column-wise data, in schema order: id, doc_id, text, embedding.
        ids = [result.id for result in embedding_results]
        doc_ids = [result.doc_id for result in embedding_results]
        texts = [result.node.get_text() for result in embedding_results]
        embeddings = [result.embedding for result in embedding_results]

        try:
            # Insert the data into milvus
            self.collection.insert([ids, doc_ids, texts, embeddings])
            logger.debug(
                f"Successfully inserted embeddings into: {self.collection_name} "
                f"Num Inserted: {len(ids)}"
            )
        except MilvusException:
            logger.debug(f"Failed to insert embeddings into: {self.collection_name}")
            raise
        return ids

    def delete(self, doc_id: str, **delete_kwargs: Any) -> None:
        """Delete a document from Milvus.

        Looks up the primary keys of every row whose ``doc_id`` matches and
        deletes those rows. Does nothing when there is no collection or no
        matching row.

        Args:
            doc_id (str): The document id to delete. A list of ids is also
                accepted.

        Raises:
            MilvusException: Failed to delete the doc.
        """
        from pymilvus import MilvusException

        if self.collection is None:
            return

        # Adds ability for multiple doc delete in future.
        doc_ids: List[str]
        if isinstance(doc_id, list):
            doc_ids = doc_id
        else:
            doc_ids = [doc_id]
        if not doc_ids:
            return

        # NOTE(review): ids are interpolated into the expression unescaped; an
        # id containing a double quote would produce a malformed expression.
        doc_ids = ['"' + entry + '"' for entry in doc_ids]
        try:
            # Begin by querying for the primary keys to delete
            entries = self.collection.query(f"doc_id in [{','.join(doc_ids)}]")
            ids = ['"' + entry["id"] + '"' for entry in entries]
            if not ids:
                # Avoid sending an empty `id in []` delete expression.
                logger.debug(f"No embedding found with doc_id: {doc_ids}")
                return
            self.collection.delete(f"id in [{','.join(ids)}]")
            logger.debug(f"Successfully deleted embedding with doc_id: {doc_ids}")
        except MilvusException:
            logger.debug(f"Unsuccessfully deleted embedding with doc_id: {doc_ids}")
            raise

    def query(self, query: VectorStoreQuery) -> VectorStoreQueryResult:
        """Query index for top k most similar nodes.

        Args:
            query (VectorStoreQuery): The query. This method reads its
                ``query_embedding``, ``similarity_top_k``, ``doc_ids``
                (optional filter) and ``mode`` fields.

        Raises:
            ValueError: The collection has not been created yet, or the query
                mode is not ``VectorStoreQueryMode.DEFAULT``.
            MilvusException: The search failed.

        Returns:
            VectorStoreQueryResult: Matching nodes with their similarity
                scores and ids.
        """
        from pymilvus import MilvusException

        if self.collection is None:
            raise ValueError("Milvus instance not initialized.")
        if query.mode != VectorStoreQueryMode.DEFAULT:
            raise ValueError(f"Milvus does not support {query.mode} yet.")

        # Optional filter restricting the search to the given source documents.
        expr: Optional[str] = None
        if query.doc_ids is not None and len(query.doc_ids) != 0:
            expr_list = ['"' + entry + '"' for entry in query.doc_ids]
            expr = f"doc_id in [{','.join(expr_list)}]"

        try:
            res = self.collection.search(
                [query.query_embedding],
                "embedding",
                self.search_params,
                limit=query.similarity_top_k,
                output_fields=["doc_id", "text"],
                expr=expr,
            )
            logger.debug(
                f"Successfully searched embedding in collection: {self.collection_name}"
                f" Num Results: {len(res[0])}"
            )
        except MilvusException:
            logger.debug(
                f"Unsuccessfully searched embedding in collection: "
                f"{self.collection_name}"
            )
            raise

        nodes = []
        similarities = []
        ids = []

        # A single query vector was sent, so only the first result set is read.
        for hit in res[0]:
            node = Node(
                doc_id=hit.id,
                text=hit.entity.get("text"),
                relationships={
                    DocumentRelationship.SOURCE: hit.entity.get("doc_id"),
                },
            )
            nodes.append(node)
            similarities.append(hit.score)
            ids.append(hit.id)

        return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)