韩宇
init
1b7e88c
import base64
import pickle
from typing import Any, Iterable, List, Optional, Tuple
from omagent_core.memories.ltms.ltm_base import LTMBase
from omagent_core.services.connectors.milvus import MilvusConnector
from omagent_core.utils.registry import registry
from pydantic import Field
from pymilvus import (Collection, CollectionSchema, DataType, FieldSchema,
utility)
@registry.register_component()
class VideoMilvusLTM(LTMBase):
milvus_ltm_client: MilvusConnector
storage_name: str = Field(default="default")
dim: int = Field(default=128)
def model_post_init(self, __context: Any) -> None:
pass
def _create_collection(self) -> None:
# Check if collection exists
if not self.milvus_ltm_client._client.has_collection(self.storage_name):
index_params = self.milvus_ltm_client._client.prepare_index_params()
# Define field schemas
key_field = FieldSchema(
name="key", dtype=DataType.VARCHAR, is_primary=True, max_length=256
)
value_field = FieldSchema(
name="value", dtype=DataType.JSON, description="Json value"
)
embedding_field = FieldSchema(
name="embedding",
dtype=DataType.FLOAT_VECTOR,
description="Embedding vector",
dim=self.dim,
)
index_params = self.milvus_ltm_client._client.prepare_index_params()
# Create collection schema
schema = CollectionSchema(
fields=[key_field, value_field, embedding_field],
description="Key-Value storage with embeddings",
)
for field in schema.fields:
if (
field.dtype == DataType.FLOAT_VECTOR
or field.dtype == DataType.BINARY_VECTOR
):
index_params.add_index(
field_name=field.name,
index_name=field.name,
index_type="FLAT",
metric_type="COSINE",
params={"nlist": 128},
)
self.milvus_ltm_client._client.create_collection(
self.storage_name, schema=schema, index_params=index_params
)
# Create index separately after collection creation
print(f"Created storage {self.storage_name} successfully")
def __getitem__(self, key: Any) -> Any:
key_str = str(key)
expr = f'key == "{key_str}"'
res = self.milvus_ltm_client._client.query(
self.storage_name, expr, output_fields=["value"]
)
if res:
value = res[0]["value"]
# value_bytes = base64.b64decode(value_base64)
# value = pickle.loads(value_bytes)
return value
else:
raise KeyError(f"Key {key} not found")
def __setitem__(self, key: Any, value: Any) -> None:
self._create_collection()
key_str = str(key)
# Check if value is a dictionary containing 'value' and 'embedding'
if isinstance(value, dict) and "value" in value and "embedding" in value:
actual_value = value["value"]
embedding = value["embedding"]
else:
raise ValueError(
"When setting an item, value must be a dictionary containing 'value' and 'embedding' keys."
)
# Serialize the actual value and encode it to base64
# value_bytes = pickle.dumps(actual_value)
# value_base64 = base64.b64encode(value_bytes).decode('utf-8')
# Ensure the embedding is provided
if embedding is None:
raise ValueError("An embedding vector must be provided.")
# Check if the key exists and delete it if it does
if key_str in self:
self.__delitem__(key_str)
# Prepare data for insertion (as a list of dictionaries)
data = [
{
"key": key_str,
"value": actual_value,
"embedding": embedding,
}
]
# Insert the new record
self.milvus_ltm_client._client.insert(
collection_name=self.storage_name, data=data
)
def __delitem__(self, key: Any) -> None:
key_str = str(key)
if key_str in self:
expr = f'key == "{key_str}"'
self.milvus_ltm_client._client.delete(self.storage_name, expr)
else:
raise KeyError(f"Key {key} not found")
def __contains__(self, key: Any) -> bool:
key_str = str(key)
expr = f'key == "{key_str}"'
# Adjust the query call to match the expected signature
res = self.milvus_ltm_client._client.query(
self.storage_name, # Pass the collection name as the first argument
filter=expr,
output_fields=["key"],
)
return len(res) > 0
"""
def __len__(self) -> int:
milvus_ltm.collection.flush()
return self.collection.num_entities
"""
def __len__(self) -> int:
expr = 'key != ""' # Expression to match all entities
# self.milvus_ltm_client._client.load(refresh=True)
results = self.milvus_ltm_client._client.query(
self.storage_name, expr, output_fields=["key"], consistency_level="Strong"
)
return len(results)
def keys(self, limit=10) -> Iterable[Any]:
expr = ""
res = self.milvus_ltm_client._client.query(
self.storage_name, expr, output_fields=["key"], limit=limit
)
return (item["key"] for item in res)
def values(self) -> Iterable[Any]:
expr = 'key != ""' # Expression to match all active entities
self.milvus_ltm_client._client.load(refresh=True)
res = self.milvus_ltm_client._client.query(
self.storage_name, expr, output_fields=["value"], consistency_level="Strong"
)
for item in res:
value_base64 = item["value"]
value_bytes = base64.b64decode(value_base64)
value = pickle.loads(value_bytes)
yield value
def items(self) -> Iterable[Tuple[Any, Any]]:
expr = 'key != ""'
res = self.milvus_ltm_client._client.query(
self.storage_name, expr, output_fields=["key", "value"]
)
for item in res:
key = item["key"]
value = item["value"]
# value_bytes = base64.b64decode(value_base64)
# value = pickle.loads(value_bytes)
yield (key, value)
def get(self, key: Any, default: Any = None) -> Any:
try:
return self[key]
except KeyError:
return default
def clear(self) -> None:
expr = (
'key != ""' # This expression matches all records where 'key' is not empty
)
self.milvus_ltm_client._client.delete(self.storage_name, filter=expr)
def pop(self, key: Any, default: Any = None) -> Any:
try:
value = self[key]
self.__delitem__(key)
return value
except KeyError:
if default is not None:
return default
else:
raise
def update(self, other: Iterable[Tuple[Any, Any]]) -> None:
for key, value in other:
self[key] = value
def get_by_vector(
self,
embedding: List[float],
top_k: int = 10,
threshold: float = 0.0,
filter: str = "",
) -> List[Tuple[Any, Any, float]]:
search_params = {
"metric_type": "COSINE",
"params": {"nprobe": 10, "range_filter": 1, "radius": threshold},
}
results = self.milvus_ltm_client._client.search(
self.storage_name,
data=[embedding],
anns_field="embedding",
search_params=search_params,
limit=top_k,
output_fields=["key", "value"],
consistency_level="Strong",
filter=filter,
)
items = []
for match in results[0]:
key = match.get("entity").get("key")
value = match.get("entity").get("value")
items.append((key, value))
return items