import torch
import hashlib
import numpy as np


class ParameterMemoryBank:
""" |
|
Parameter Memory Bank (PMB) for infinite, queryable memory. |
|
|
|
This implementation uses a two-level hashing system for constant-time |
|
direct access and supports semantic similarity search. |
|
|
|
- Level 1: A list of 'blocks'. |
|
- Level 2: Each block is a dictionary-like structure mapping slots to items. |
|
|
|
For simplicity, we use Python lists and dictionaries. A production system |
|
would use a more optimized backend (e.g., Redis, custom memory store). |
|
""" |
|
    def __init__(self, num_blocks=1024, slots_per_block=4096, embedding_dim=None):
        self.num_blocks = num_blocks
        self.slots_per_block = slots_per_block
        self.embedding_dim = embedding_dim

        # Two-level store: pmb[block_idx][slot_idx] holds one
        # (item_id, key_embedding, value) tuple, or None if the slot is empty.
        self.pmb = [[None] * slots_per_block for _ in range(num_blocks)]

        # Flat index used for semantic search: all_keys[i] is the key embedding
        # whose entry lives at key_locations[i] == (block_idx, slot_idx).
        self.all_keys = []
        self.key_locations = []

def _hash_fn(self, s, salt=""): |
|
"""A simple, salted hash function.""" |
|
return int(hashlib.sha256((str(s) + salt).encode()).hexdigest(), 16) |
|
|
|
    def _get_hash_indices(self, item_id):
        """
        Calculates the block and slot indices for a given item ID using
        the two-level hashing scheme.
        """
        block_hash = self._hash_fn(item_id, salt="block")
        block_idx = block_hash % self.num_blocks

        slot_hash = self._hash_fn(item_id, salt=f"slot_{block_idx}")
        slot_idx = slot_hash % self.slots_per_block

        return block_idx, slot_idx

    def store(self, item_id, key_embedding, value):
        """
        Stores a key-value pair in the PMB using its ID.

        Args:
            item_id (str or int): A unique identifier for the data.
            key_embedding (torch.Tensor): The embedding vector (k_i,j).
            value (any): The data to store (v_i,j), e.g., text, metadata.
        """
        if not isinstance(key_embedding, torch.Tensor):
            raise TypeError("key_embedding must be a torch.Tensor")

        block_idx, slot_idx = self._get_hash_indices(item_id)

        # Hash collisions are not resolved: an occupied slot is simply
        # overwritten by the new entry.
        if self.pmb[block_idx][slot_idx] is not None:
            pass

        self.pmb[block_idx][slot_idx] = (
            item_id,
            key_embedding.detach().cpu(),
            value.detach().cpu() if isinstance(value, torch.Tensor) else value,
        )

        # Keep a flat copy of the key so retrieve_semantic can search all keys.
        self.all_keys.append(key_embedding.detach().cpu())
        self.key_locations.append((block_idx, slot_idx))

    def retrieve_direct(self, item_id):
        """
        Retrieves a value directly using its ID in O(1) time.

        Args:
            item_id (str or int): The unique identifier of the item.

        Returns:
            The stored value, or None if not found.
        """
        block_idx, slot_idx = self._get_hash_indices(item_id)
        item = self.pmb[block_idx][slot_idx]

        # Check the stored ID so a colliding entry is not returned by mistake.
        if item and item[0] == item_id:
            return item[2]
        return None

    def retrieve_by_indices(self, indices):
        """
        Retrieves items by their indices in the `all_keys` list.

        Args:
            indices (list or torch.Tensor): A list of indices.

        Returns:
            A list of the retrieved values.
        """
        # Placeholder used when an index is out of range or a slot is empty.
        if self.embedding_dim:
            placeholder = torch.zeros(self.embedding_dim)
        elif self.all_keys:
            placeholder = torch.zeros_like(self.all_keys[0])
        else:
            placeholder = torch.zeros(1)

        results = []
        for idx in indices:
            if idx >= len(self.key_locations):
                results.append(placeholder)
                continue

            block_idx, slot_idx = self.key_locations[idx]
            item = self.pmb[block_idx][slot_idx]
            if not item:
                results.append(placeholder)
                continue

            value = item[2]
            if isinstance(value, torch.Tensor):
                results.append(value)
            elif self.embedding_dim:
                # Non-tensor values cannot be stacked; substitute a zero vector.
                results.append(torch.zeros(self.embedding_dim))
            else:
                # Fall back to the stored key embedding to keep shapes consistent.
                results.append(item[1])
        return results

    def retrieve_semantic(self, query_embeddings, top_k=1):
        """
        Retrieves the top_k most semantically similar items for a batch of
        query embeddings.

        Args:
            query_embeddings (torch.Tensor): Query vectors of shape
                (batch_size, embedding_dim) or (batch_size, seq_len, embedding_dim).
            top_k (int): The number of similar items to return for each query.

        Returns:
            A tensor of aggregated retrieved values with the same shape as
            query_embeddings.
        """
        if not self.all_keys or top_k == 0:
            return torch.zeros_like(query_embeddings)

        if not isinstance(query_embeddings, torch.Tensor):
            raise TypeError("query_embeddings must be a torch.Tensor")

        original_shape = query_embeddings.shape
        device = query_embeddings.device

        # Flatten (batch, seq, dim) queries to (batch * seq, dim) for the search.
        if query_embeddings.dim() > 2:
            query_flat = query_embeddings.view(-1, original_shape[-1])
        else:
            query_flat = query_embeddings

        try:
            # Cosine similarity between every query and every stored key.
            all_keys_tensor = torch.stack(self.all_keys, dim=0).to(device)
            query_norm = torch.nn.functional.normalize(query_flat, p=2, dim=-1)
            keys_norm = torch.nn.functional.normalize(all_keys_tensor, p=2, dim=-1)
            similarities = torch.mm(query_norm, keys_norm.T)

            k = min(top_k, len(self.all_keys))
            if k <= 0:
                return torch.zeros_like(query_embeddings)
            top_k_scores, top_k_indices = torch.topk(similarities, k=k, dim=1)

            # For each query, fetch the top-k values and aggregate them by averaging.
            batch_results = []
            for i in range(query_flat.size(0)):
                retrieved_values = self.retrieve_by_indices(top_k_indices[i].cpu().tolist())
                if retrieved_values:
                    stacked_values = torch.stack(retrieved_values, dim=0).to(device)
                    batch_results.append(torch.mean(stacked_values, dim=0))
                else:
                    batch_results.append(torch.zeros(original_shape[-1], device=device))

            if not batch_results:
                return torch.zeros_like(query_embeddings)
            return torch.stack(batch_results, dim=0).view(original_shape)

        except Exception as e:
            print(f"Error in PMB retrieve_semantic: {e}")
            return torch.zeros_like(query_embeddings)

    def __len__(self):
        return len(self.all_keys)
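

# Minimal usage sketch, not part of the original class: it exercises store(),
# retrieve_direct(), and retrieve_semantic() with toy 8-dimensional embeddings.
# The bank sizes and item IDs below are illustrative placeholders.
if __name__ == "__main__":
    torch.manual_seed(0)
    pmb = ParameterMemoryBank(num_blocks=64, slots_per_block=256, embedding_dim=8)

    # Store tensor-valued items so retrieve_semantic has something to aggregate.
    for i in range(4):
        key = torch.randn(8)
        pmb.store(f"item_{i}", key, key.clone())

    print(len(pmb))                        # 4 entries stored
    print(pmb.retrieve_direct("item_2"))   # the tensor stored under "item_2"
    print(pmb.retrieve_direct("missing"))  # None

    # Batched semantic lookup: the output has the same shape as the queries.
    queries = torch.randn(2, 8)
    retrieved = pmb.retrieve_semantic(queries, top_k=2)
    print(retrieved.shape)                 # torch.Size([2, 8])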