Spaces:
Sleeping
Sleeping
| import base64 | |
| import pickle | |
| from typing import Any, Iterable, List, Optional, Tuple | |
| from omagent_core.memories.ltms.ltm_base import LTMBase | |
| from omagent_core.services.connectors.milvus import MilvusConnector | |
| from omagent_core.utils.registry import registry | |
| from pydantic import Field | |
| from pymilvus import (Collection, CollectionSchema, DataType, FieldSchema, | |
| utility) | |
# Sentinel distinguishing "no default supplied" from an explicit ``None`` in
# ``VideoMilvusLTM.pop``. Module-level on purpose: an underscore-prefixed class
# attribute would be turned into a private attribute by pydantic.
_MISSING = object()


class VideoMilvusLTM(LTMBase):
    """Dict-like long-term memory stored in a Milvus collection.

    Each entry is one row of the collection ``storage_name`` with three fields:
    ``key`` (VARCHAR primary key, max length 256; keys are stringified with
    ``str()``), ``value`` (JSON) and ``embedding`` (FLOAT_VECTOR of length
    ``dim``). Rows can be accessed by key like a mapping, or retrieved by
    cosine similarity with ``get_by_vector``.

    The collection is only created on the first write (``__setitem__``);
    reads issued before any write target a collection that may not exist yet.
    """

    # Connector whose ``_client`` attribute is used for every Milvus call.
    milvus_ltm_client: MilvusConnector
    # Name of the Milvus collection backing this memory.
    storage_name: str = Field(default="default")
    # Dimension of the ``embedding`` vector field.
    dim: int = Field(default=128)

    def model_post_init(self, __context: Any) -> None:
        """Intentionally do nothing; the collection is created lazily on write."""
        pass

    def _key_filter(self, key: Any) -> str:
        """Build a Milvus filter expression matching rows whose key is ``str(key)``.

        Backslashes and double quotes are escaped so that keys containing them
        yield a valid string literal instead of a malformed or altered filter.
        """
        escaped = str(key).replace("\\", "\\\\").replace('"', '\\"')
        return f'key == "{escaped}"'

    def _create_collection(self) -> None:
        """Create the backing collection and its vector index if missing.

        Called before every write, so it returns immediately once the
        collection exists.
        """
        client = self.milvus_ltm_client._client
        if client.has_collection(self.storage_name):
            return

        schema = CollectionSchema(
            fields=[
                FieldSchema(
                    name="key", dtype=DataType.VARCHAR, is_primary=True, max_length=256
                ),
                FieldSchema(
                    name="value", dtype=DataType.JSON, description="Json value"
                ),
                FieldSchema(
                    name="embedding",
                    dtype=DataType.FLOAT_VECTOR,
                    description="Embedding vector",
                    dim=self.dim,
                ),
            ],
            description="Key-Value storage with embeddings",
        )

        # FLAT index with cosine similarity on the only vector field. The index
        # is handed to create_collection so it is built with the collection.
        index_params = client.prepare_index_params()
        index_params.add_index(
            field_name="embedding",
            index_name="embedding",
            index_type="FLAT",
            metric_type="COSINE",
            # NOTE(review): nlist is an IVF build parameter and presumably has
            # no effect on a FLAT index; kept unchanged to match existing data.
            params={"nlist": 128},
        )
        client.create_collection(
            self.storage_name, schema=schema, index_params=index_params
        )
        print(f"Created storage {self.storage_name} successfully")

    def __getitem__(self, key: Any) -> Any:
        """Return the stored value for ``key``.

        Raises:
            KeyError: if no row has this key.
        """
        res = self.milvus_ltm_client._client.query(
            self.storage_name, self._key_filter(key), output_fields=["value"]
        )
        if not res:
            raise KeyError(f"Key {key} not found")
        # Values are stored as raw JSON by __setitem__, so return them as-is.
        return res[0]["value"]

    def __setitem__(self, key: Any, value: Any) -> None:
        """Store ``value["value"]`` under ``key`` together with ``value["embedding"]``.

        ``value`` must be a dict containing both ``"value"`` and ``"embedding"``.
        An existing row with the same key is replaced.

        Raises:
            ValueError: if ``value`` lacks either key or the embedding is None.
        """
        self._create_collection()
        key_str = str(key)
        if not (isinstance(value, dict) and "value" in value and "embedding" in value):
            raise ValueError(
                "When setting an item, value must be a dictionary containing 'value' and 'embedding' keys."
            )
        actual_value = value["value"]
        embedding = value["embedding"]
        if embedding is None:
            raise ValueError("An embedding vector must be provided.")

        # Upsert replaces any row sharing the primary key in a single call, so
        # overwriting a key never leaves duplicate rows behind.
        self.milvus_ltm_client._client.upsert(
            collection_name=self.storage_name,
            data=[{"key": key_str, "value": actual_value, "embedding": embedding}],
        )

    def __delitem__(self, key: Any) -> None:
        """Delete the row stored under ``key``.

        Raises:
            KeyError: if no row has this key.
        """
        key_str = str(key)
        if key_str not in self:
            raise KeyError(f"Key {key} not found")
        # ``filter`` must be passed by keyword: the second positional parameter
        # of MilvusClient.delete is the primary-key list, not a filter string.
        self.milvus_ltm_client._client.delete(
            self.storage_name, filter=self._key_filter(key_str)
        )

    def __contains__(self, key: Any) -> bool:
        """Return True if a row with key ``str(key)`` exists."""
        res = self.milvus_ltm_client._client.query(
            self.storage_name,
            filter=self._key_filter(key),
            output_fields=["key"],
        )
        return len(res) > 0

    def __len__(self) -> int:
        """Return the number of rows with a non-empty key (strongly consistent)."""
        expr = 'key != ""'  # matches every row whose key is not the empty string
        results = self.milvus_ltm_client._client.query(
            self.storage_name, expr, output_fields=["key"], consistency_level="Strong"
        )
        return len(results)

    def keys(self, limit=10) -> Iterable[Any]:
        """Return a generator over at most ``limit`` keys."""
        expr = ""
        res = self.milvus_ltm_client._client.query(
            self.storage_name, expr, output_fields=["key"], limit=limit
        )
        return (item["key"] for item in res)

    def values(self) -> Iterable[Any]:
        """Yield the value of every row with a non-empty key.

        Values are yielded exactly as stored by ``__setitem__`` (raw JSON),
        consistent with ``__getitem__`` and ``items``.
        """
        expr = 'key != ""'
        res = self.milvus_ltm_client._client.query(
            self.storage_name, expr, output_fields=["value"], consistency_level="Strong"
        )
        for item in res:
            yield item["value"]

    def items(self) -> Iterable[Tuple[Any, Any]]:
        """Yield ``(key, value)`` for every row with a non-empty key."""
        expr = 'key != ""'
        res = self.milvus_ltm_client._client.query(
            self.storage_name, expr, output_fields=["key", "value"]
        )
        for item in res:
            yield (item["key"], item["value"])

    def get(self, key: Any, default: Any = None) -> Any:
        """Return the value for ``key``, or ``default`` if it is missing."""
        try:
            return self[key]
        except KeyError:
            return default

    def clear(self) -> None:
        """Delete every row whose key is non-empty."""
        expr = 'key != ""'
        self.milvus_ltm_client._client.delete(self.storage_name, filter=expr)

    def pop(self, key: Any, default: Any = _MISSING) -> Any:
        """Remove ``key`` and return its value.

        If the key is missing, return ``default`` when one was supplied
        (including ``None``); otherwise re-raise ``KeyError``.
        """
        try:
            value = self[key]
            self.__delitem__(key)
        except KeyError:
            if default is _MISSING:
                raise
            return default
        return value

    def update(self, other: Iterable[Tuple[Any, Any]]) -> None:
        """Store every ``(key, value)`` pair from ``other`` via ``__setitem__``.

        Like ``dict.update``, ``other`` may be an iterable of pairs or any
        object with a ``keys()`` method (e.g. a dict). Each value must have the
        shape ``__setitem__`` accepts.
        """
        if hasattr(other, "keys"):
            pairs = ((k, other[k]) for k in other.keys())
        else:
            pairs = other
        for key, value in pairs:
            self[key] = value

    def get_by_vector(
        self,
        embedding: List[float],
        top_k: int = 10,
        threshold: float = 0.0,
        filter: str = "",
    ) -> List[Tuple[Any, Any]]:
        """Return up to ``top_k`` ``(key, value)`` pairs similar to ``embedding``.

        Runs a COSINE range search on the ``embedding`` field with
        ``radius=threshold`` and ``range_filter=1``; ``filter`` is an optional
        Milvus boolean expression restricting candidate rows. Pairs are in the
        order Milvus returns them; similarity scores are not included.
        """
        search_params = {
            "metric_type": "COSINE",
            "params": {"nprobe": 10, "range_filter": 1, "radius": threshold},
        }
        results = self.milvus_ltm_client._client.search(
            self.storage_name,
            data=[embedding],
            anns_field="embedding",
            search_params=search_params,
            limit=top_k,
            output_fields=["key", "value"],
            consistency_level="Strong",
            filter=filter,
        )
        # A single query vector was sent, so only the first hit list matters.
        return [
            (match.get("entity").get("key"), match.get("entity").get("value"))
            for match in results[0]
        ]