File size: 8,991 Bytes
f03ee14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
import torch
import hashlib
import numpy as np

class ParameterMemoryBank:
    """
    Parameter Memory Bank (PMB): a hashed, queryable key-value memory.

    Items are addressed by a two-level hash of their ID, giving constant-time
    direct access. Every stored key embedding is also kept in a flat list so
    items can be found by cosine similarity (semantic search).

    - Level 1: A list of ``num_blocks`` blocks, selected by a salted SHA-256
      hash of the item ID.
    - Level 2: Each block is a fixed-size list of ``slots_per_block`` slots,
      selected by a second hash salted with the block index.

    Capacity is bounded by ``num_blocks * slots_per_block`` slots. There is no
    collision resolution: storing an ID whose slot is already occupied
    replaces whatever was there (the same ID, or a different colliding ID).

    For simplicity, we use Python lists and dictionaries. A production system
    would use a more optimized backend (e.g., Redis, custom memory store).
    """

    def __init__(self, num_blocks=1024, slots_per_block=4096, embedding_dim=None):
        self.num_blocks = num_blocks
        self.slots_per_block = slots_per_block
        # Optional length of the zero tensors returned for missing or
        # non-tensor values by retrieve_by_indices.
        self.embedding_dim = embedding_dim

        # PMB is a list of blocks, where each block is a list of slots.
        # Each occupied slot holds a tuple: (item_id, key_embedding, value).
        self.pmb = [[None] * slots_per_block for _ in range(num_blocks)]

        # Parallel lists used for semantic search: all_keys[i] is the key
        # embedding currently stored at slot key_locations[i]. Each occupied
        # slot appears exactly once.
        self.all_keys = []
        self.key_locations = []  # Stores (block_idx, slot_idx) for each key
        # Reverse index (block_idx, slot_idx) -> position in all_keys, so that
        # overwriting a slot replaces its key instead of appending a duplicate.
        self._key_index_by_location = {}

    def _location_index(self):
        """Return the slot -> ``all_keys`` position map, building it if absent.

        The map is rebuilt from ``key_locations`` when the attribute is missing
        (e.g. an instance restored from state saved before the map existed);
        if a location occurs more than once, its last occurrence wins.
        """
        index = self.__dict__.get("_key_index_by_location")
        if index is None:
            index = {location: pos for pos, location in enumerate(self.key_locations)}
            self._key_index_by_location = index
        return index

    def _hash_fn(self, s, salt=""):
        """A simple, salted hash function (SHA-256 of ``str(s) + salt``)."""
        return int(hashlib.sha256((str(s) + salt).encode()).hexdigest(), 16)

    def _get_hash_indices(self, item_id):
        """
        Calculates the block and slot indices for a given item ID using
        the two-level hashing scheme.

        Returns:
            tuple[int, int]: ``(block_idx, slot_idx)``.
        """
        block_hash = self._hash_fn(item_id, salt="block")
        block_idx = block_hash % self.num_blocks

        # The slot hash is salted with the block index, so the two levels are
        # derived independently.
        slot_hash = self._hash_fn(item_id, salt=f"slot_{block_idx}")
        slot_idx = slot_hash % self.slots_per_block

        return block_idx, slot_idx

    def store(self, item_id, key_embedding, value):
        """
        Stores a key-value pair in the PMB using its ID.

        If the ID's slot is already occupied (by the same ID or by a colliding
        one), the slot and its entry in ``all_keys`` are overwritten in place,
        so semantic search never sees stale or duplicate keys.

        Args:
            item_id (str or int): A unique identifier for the data.
            key_embedding (torch.Tensor): The embedding vector (k_i,j).
            value (any): The data to store (v_i,j), e.g., text, metadata.
                Tensor values are detached and moved to CPU.

        Raises:
            TypeError: If ``key_embedding`` is not a ``torch.Tensor``.
        """
        if not isinstance(key_embedding, torch.Tensor):
            raise TypeError("key_embedding must be a torch.Tensor")

        block_idx, slot_idx = self._get_hash_indices(item_id)
        location = (block_idx, slot_idx)

        stored_key = key_embedding.detach().cpu()
        stored_value = value.detach().cpu() if isinstance(value, torch.Tensor) else value

        # Note: this simple implementation has no collision resolution
        # (e.g. cuckoo hashing, chaining); an occupied slot is overwritten.
        self.pmb[block_idx][slot_idx] = (item_id, stored_key, stored_value)

        index = self._location_index()
        key_pos = index.get(location)
        if key_pos is None:
            index[location] = len(self.all_keys)
            self.all_keys.append(stored_key)
            self.key_locations.append(location)
        else:
            # Slot was already occupied: replace its key rather than appending
            # a duplicate whose location now holds a different value.
            self.all_keys[key_pos] = stored_key

    def retrieve_direct(self, item_id):
        """
        Retrieves a value directly using its ID in O(1) time.

        Args:
            item_id (str or int): The unique identifier of the item.

        Returns:
            The stored value, or None if not found (including when the slot
            is occupied by a different, colliding ID).
        """
        block_idx, slot_idx = self._get_hash_indices(item_id)
        item = self.pmb[block_idx][slot_idx]

        # The slot may hold another ID that hashed to the same place.
        if item is not None and item[0] == item_id:
            return item[2]  # Return the value
        return None

    def _zero_value(self):
        """Zero tensor used in place of a missing value.

        Sized by ``embedding_dim`` when set; otherwise shaped like the first
        stored key, or a length-1 tensor when the bank is empty.
        """
        if self.embedding_dim:
            return torch.zeros(self.embedding_dim)
        if self.all_keys:
            return torch.zeros_like(self.all_keys[0])
        return torch.zeros(1)

    def retrieve_by_indices(self, indices):
        """
        Retrieves items by their indices in the `all_keys` list.

        Every index yields exactly one tensor:
        - a stored tensor value is returned as-is;
        - a non-tensor value is replaced by ``zeros(embedding_dim)`` when
          ``embedding_dim`` is set, otherwise by the item's key embedding;
        - an out-of-range index (outside ``[-len, len)``) or an empty slot
          yields a zero tensor (see ``_zero_value``).

        Args:
            indices (list or torch.Tensor): A list of indices.

        Returns:
            A list of tensors, one per index.
        """
        results = []
        num_keys = len(self.key_locations)
        for idx in indices:
            # Python-style negative indices are accepted; anything beyond
            # either end falls back to zeros instead of raising IndexError.
            if not -num_keys <= idx < num_keys:
                results.append(self._zero_value())
                continue

            block_idx, slot_idx = self.key_locations[idx]
            item = self.pmb[block_idx][slot_idx]
            if item is None:
                results.append(self._zero_value())
                continue

            value = item[2]
            if isinstance(value, torch.Tensor):
                results.append(value)
            elif self.embedding_dim:
                results.append(torch.zeros(self.embedding_dim))
            else:
                # Fallback: use the key embedding as the value.
                results.append(item[1])
        return results

    def retrieve_semantic(self, query_embeddings, top_k=1):
        """
        Retrieves the top_k most semantically similar items for a batch of query embeddings.

        Similarity is cosine similarity between each query vector and every
        stored key; the values of the ``top_k`` best keys are averaged.

        Args:
            query_embeddings (torch.Tensor): Query vectors of shape
                (embedding_dim,), (batch_size, embedding_dim) or
                (batch_size, seq_len, embedding_dim). Non-contiguous tensors
                are accepted.
            top_k (int): The number of similar items to return for each query.

        Returns:
            A tensor of the aggregated retrieved values with the same shape as
            query_embeddings. Zeros are returned when the bank is empty, when
            ``top_k <= 0``, or (after printing the error) when retrieval fails,
            e.g. because stored values do not match the query's last dimension.

        Raises:
            TypeError: If ``query_embeddings`` is not a ``torch.Tensor``.
        """
        if not isinstance(query_embeddings, torch.Tensor):
            raise TypeError("query_embeddings must be a torch.Tensor")
        if not self.all_keys or top_k <= 0:
            return torch.zeros_like(query_embeddings)

        original_shape = query_embeddings.shape
        device = query_embeddings.device
        # Normalisation and averaging need a floating dtype; keys and values
        # are cast to the query's dtype (float32 for non-float queries) so
        # mixed-precision inputs do not fail in torch.mm / torch.mean.
        compute_dtype = (
            query_embeddings.dtype if query_embeddings.is_floating_point() else torch.float32
        )

        try:
            # reshape (unlike view) also works on non-contiguous tensors, and
            # turns a single 1-D query into a batch of one.
            query_flat = query_embeddings.reshape(-1, original_shape[-1]).to(compute_dtype)

            all_keys_tensor = torch.stack(self.all_keys, dim=0).to(device=device, dtype=compute_dtype)

            query_norm = torch.nn.functional.normalize(query_flat, p=2, dim=-1)
            keys_norm = torch.nn.functional.normalize(all_keys_tensor, p=2, dim=-1)

            # Similarities: (num_queries, num_keys)
            similarities = torch.mm(query_norm, keys_norm.T)

            k = min(top_k, len(self.all_keys))
            _, top_k_indices = torch.topk(similarities, k=k, dim=1)

            batch_results = []
            for row in top_k_indices.cpu().tolist():
                # retrieve_by_indices returns one tensor per index, so each
                # row (k >= 1 entries) always yields a non-empty list.
                retrieved = self.retrieve_by_indices(row)
                stacked = torch.stack(retrieved, dim=0).to(device=device, dtype=compute_dtype)
                batch_results.append(stacked.mean(dim=0))

            if not batch_results:  # empty batch of queries
                return torch.zeros_like(query_embeddings)
            return torch.stack(batch_results, dim=0).reshape(original_shape)

        except Exception as e:
            # Deliberate best-effort: report and fall back to zeros.
            print(f"Error in PMB retrieve_semantic: {e}")
            return torch.zeros_like(query_embeddings)

    def __len__(self):
        """Number of occupied slots, i.e. keys available to semantic search."""
        return len(self.all_keys)