vincentiusyoshuac committed
Commit b668479 · verified · 1 Parent(s): c94bd82

Update memory.py

Files changed (1)
  memory.py  +29 -35
memory.py CHANGED
@@ -1,61 +1,55 @@
+# cognitive_net/memory.py
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from collections import deque
-from typing import Dict, List, Optional, Tuple
+from typing import Deque, Dict, Any
 
 class CognitiveMemory(nn.Module):
-    """Differentiable memory system with consolidation and retrieval"""
+    """Differentiable memory system with biological consolidation mechanisms"""
     def __init__(self, context_size: int, capacity: int = 100):
         super().__init__()
         self.context_size = context_size
         self.capacity = capacity
-        self.memory_queue = deque(maxlen=capacity)
+        self.memory_queue: Deque[Dict[str, Any]] = deque(maxlen=capacity)
 
-        # Memory importance parameters
-        self.importance_decay = nn.Parameter(torch.tensor(0.95))
-        self.consolidation_threshold = 0.7
-
-        # Memory projection layers
+        # Memory projection layers with adaptive scaling
         self.key_proj = nn.Linear(context_size, 64)
         self.value_proj = nn.Linear(context_size, 64)
+        self.importance_decay = nn.Parameter(torch.tensor(0.95))
+
+        # Consolidation parameters
+        self.consolidation_threshold = 0.7
+        self.age_decay = 0.1
 
     def add_memory(self, context: torch.Tensor, activation: float):
-        """Store new memory with adaptive importance"""
-        # Ensure context is 1D tensor with single value
-        context = context.reshape(-1)
+        """Store memory with dynamic importance weighting"""
         importance = torch.sigmoid(torch.tensor(activation * 0.5 + 0.2))
         self.memory_queue.append({
-            'context': context.detach(),
+            'context': context.detach().clone(),
             'importance': importance,
-            'age': 0.0
+            'age': torch.tensor(0.0)
         })
 
     def consolidate_memories(self):
-        """Memory consolidation through importance reweighting"""
+        """Memory optimization through importance-based pruning"""
+        new_queue = deque(maxlen=self.capacity)
         for mem in self.memory_queue:
             mem['importance'] *= self.importance_decay
-            mem['age'] += 0.1
-
-        # Remove unimportant memories
-        self.memory_queue = deque(
-            [m for m in self.memory_queue if m['importance'] > 0.2],
-            maxlen=self.capacity
-        )
+            mem['age'] += self.age_decay
+            if mem['importance'] > 0.2:
+                new_queue.append(mem)
+        self.memory_queue = new_queue
 
     def retrieve(self, query: torch.Tensor) -> torch.Tensor:
-        """Attention-based memory retrieval"""
+        """Content-based memory retrieval with attention"""
        if not self.memory_queue:
-            return torch.zeros_like(query)
+            return torch.zeros(64, device=query.device)
 
-        # Ensure query is 1D tensor with single value
-        query = query.reshape(1, 1)
-        memories = torch.stack([m['context'].reshape(1, 1) for m in self.memory_queue])
-
-        keys = self.key_proj(memories)
-        values = self.value_proj(memories)
-        query_proj = self.key_proj(query)
+        contexts = torch.stack([m['context'] for m in self.memory_queue])
+        keys = self.key_proj(contexts)
+        values = self.value_proj(contexts)
+        query_proj = self.key_proj(query.unsqueeze(0))
 
-        scores = F.softmax(torch.matmul(keys, query_proj.transpose(0, 1)), dim=0)
-        retrieved = torch.matmul(scores.transpose(0, 1), values)
-        return retrieved.squeeze(0)
+        scores = F.softmax(keys @ query_proj.T, dim=0)
+        return (scores * values).sum(dim=0)
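For reviewers trying the change locally, here is a minimal usage sketch of the updated API. It is an illustration, not part of the commit: the import path cognitive_net.memory is taken from the header comment added in this commit, and context_size=16 plus the random inputs are arbitrary choices.

import torch
from cognitive_net.memory import CognitiveMemory  # path per the new header comment

mem = CognitiveMemory(context_size=16)  # illustrative size; any positive int works

# Store a few observations; initial importance is sigmoid(activation * 0.5 + 0.2).
for step in range(5):
    mem.add_memory(torch.randn(16), activation=float(step))

# Decay importances, age each entry by age_decay, and drop entries at or below 0.2.
mem.consolidate_memories()

# Retrieval returns a 64-dim vector (the projection width) in both the
# empty-queue and populated cases.
retrieved = mem.retrieve(torch.randn(16))
assert retrieved.shape == (64,)

The sketch also makes the behavioral fix visible: the old code returned torch.zeros_like(query) on an empty queue, which did not match the 64-dim output of its populated path, while the new code returns a consistently shaped vector either way.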