vincentiusyoshuac commited on
Commit
ee91c59
·
verified ·
1 Parent(s): b385599

Create memory.py

Browse files
Files changed (1) hide show
  1. memory.py +54 -0
memory.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from collections import deque
5
+ from typing import Dict, List, Optional, Tuple
6
+
7
+ class CognitiveMemory(nn.Module):
8
+ """Differentiable memory system with consolidation and retrieval"""
9
+ def __init__(self, context_size: int, capacity: int = 100):
10
+ super().__init__()
11
+ self.context_size = context_size
12
+ self.capacity = capacity
13
+ self.memory_queue = deque(maxlen=capacity)
14
+
15
+ # Memory importance parameters
16
+ self.importance_decay = nn.Parameter(torch.tensor(0.95))
17
+ self.consolidation_threshold = 0.7
18
+
19
+ # Memory projection layers
20
+ self.key_proj = nn.Linear(context_size, 64)
21
+ self.value_proj = nn.Linear(context_size, 64)
22
+
23
+ def add_memory(self, context: torch.Tensor, activation: float):
24
+ """Store new memory with adaptive importance"""
25
+ importance = torch.sigmoid(torch.tensor(activation * 0.5 + 0.2))
26
+ self.memory_queue.append({
27
+ 'context': context.detach(),
28
+ 'importance': importance,
29
+ 'age': 0.0
30
+ })
31
+
32
+ def consolidate_memories(self):
33
+ """Memory consolidation through importance reweighting"""
34
+ for mem in self.memory_queue:
35
+ mem['importance'] *= self.importance_decay
36
+ mem['age'] += 0.1
37
+
38
+ # Remove unimportant memories
39
+ self.memory_queue = deque(
40
+ [m for m in self.memory_queue if m['importance'] > 0.2],
41
+ maxlen=self.capacity
42
+ )
43
+
44
+ def retrieve(self, query: torch.Tensor) -> torch.Tensor:
45
+ """Attention-based memory retrieval"""
46
+ if not self.memory_queue:
47
+ return torch.zeros_like(query)
48
+
49
+ keys = torch.stack([self.key_proj(m['context']) for m in self.memory_queue])
50
+ values = torch.stack([self.value_proj(m['context']) for m in self.memory_queue])
51
+ query_proj = self.key_proj(query)
52
+
53
+ scores = F.softmax(keys @ query_proj.t(), dim=0)
54
+ return (scores * values).sum(dim=0)