vincentiusyoshuac commited on
Commit
c9387f0
·
verified ·
1 Parent(s): a135dcc

Update memory.py

Browse files
Files changed (1) hide show
  1. memory.py +15 -23
memory.py CHANGED
@@ -1,4 +1,3 @@
1
- # cognitive_net/memory.py
2
  import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
@@ -6,24 +5,20 @@ from collections import deque
6
  from typing import Deque, Dict, Any
7
 
8
  class CognitiveMemory(nn.Module):
9
- """Differentiable memory system with biological consolidation mechanisms"""
10
  def __init__(self, context_size: int, capacity: int = 100):
11
  super().__init__()
12
  self.context_size = context_size
13
  self.capacity = capacity
14
  self.memory_queue: Deque[Dict[str, Any]] = deque(maxlen=capacity)
15
 
16
- # Memory projection layers with adaptive scaling
17
- self.key_proj = nn.Linear(context_size, 64)
18
- self.value_proj = nn.Linear(context_size, 64)
19
  self.importance_decay = nn.Parameter(torch.tensor(0.95))
20
-
21
- # Consolidation parameters
22
- self.consolidation_threshold = 0.7
23
- self.age_decay = 0.1
24
 
25
  def add_memory(self, context: torch.Tensor, activation: float):
26
- """Store memory with dynamic importance weighting"""
27
  importance = torch.sigmoid(torch.tensor(activation * 0.5 + 0.2))
28
  self.memory_queue.append({
29
  'context': context.detach().clone(),
@@ -32,24 +27,21 @@ class CognitiveMemory(nn.Module):
32
  })
33
 
34
  def consolidate_memories(self):
35
- """Memory optimization through importance-based pruning"""
36
- new_queue = deque(maxlen=self.capacity)
37
- for mem in self.memory_queue:
38
- mem['importance'] *= self.importance_decay
39
- mem['age'] += self.age_decay
40
- if mem['importance'] > 0.2:
41
- new_queue.append(mem)
42
- self.memory_queue = new_queue
43
 
44
  def retrieve(self, query: torch.Tensor) -> torch.Tensor:
45
- """Content-based memory retrieval with attention"""
46
  if not self.memory_queue:
47
- return torch.zeros(64, device=query.device)
48
 
49
  contexts = torch.stack([m['context'] for m in self.memory_queue])
50
  keys = self.key_proj(contexts)
51
  values = self.value_proj(contexts)
52
- query_proj = self.key_proj(query.unsqueeze(0))
53
 
54
- scores = F.softmax(keys @ query_proj.T, dim=0)
55
- return (scores * values).sum(dim=0)
 
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
 
5
  from typing import Deque, Dict, Any
6
 
7
  class CognitiveMemory(nn.Module):
8
+ """Memory system dengan dimensi konsisten"""
9
  def __init__(self, context_size: int, capacity: int = 100):
10
  super().__init__()
11
  self.context_size = context_size
12
  self.capacity = capacity
13
  self.memory_queue: Deque[Dict[str, Any]] = deque(maxlen=capacity)
14
 
15
+ # Proyeksi mempertahankan dimensi asli
16
+ self.key_proj = nn.Linear(context_size, context_size)
17
+ self.value_proj = nn.Linear(context_size, context_size)
18
  self.importance_decay = nn.Parameter(torch.tensor(0.95))
 
 
 
 
19
 
20
  def add_memory(self, context: torch.Tensor, activation: float):
21
+ """Menyimpan memori dengan dimensi yang sesuai"""
22
  importance = torch.sigmoid(torch.tensor(activation * 0.5 + 0.2))
23
  self.memory_queue.append({
24
  'context': context.detach().clone(),
 
27
  })
28
 
29
  def consolidate_memories(self):
30
+ """Konsolidasi memori dengan manajemen dimensi"""
31
+ self.memory_queue = deque(
32
+ [m for m in self.memory_queue if m['importance'] > 0.2],
33
+ maxlen=self.capacity
34
+ )
 
 
 
35
 
36
  def retrieve(self, query: torch.Tensor) -> torch.Tensor:
37
+ """Retrieval dengan penanganan dimensi yang aman"""
38
  if not self.memory_queue:
39
+ return torch.zeros(self.context_size, device=query.device)
40
 
41
  contexts = torch.stack([m['context'] for m in self.memory_queue])
42
  keys = self.key_proj(contexts)
43
  values = self.value_proj(contexts)
44
+ query_proj = self.key_proj(query)
45
 
46
+ scores = F.softmax(keys @ query_proj, dim=0)
47
+ return (scores.unsqueeze(1) * values).sum(dim=0)