vincentiusyoshuac commited on
Commit
ca65cde
·
verified ·
1 Parent(s): 771ab14

Update memory.py

Browse files
Files changed (1) hide show
  1. memory.py +8 -7
memory.py CHANGED
@@ -5,43 +5,44 @@ from collections import deque
5
  from typing import Deque, Dict, Any
6
 
7
  class CognitiveMemory(nn.Module):
8
- """Memory system dengan dimensi konsisten"""
9
  def __init__(self, context_size: int, capacity: int = 100):
10
  super().__init__()
11
  self.context_size = context_size
12
  self.capacity = capacity
13
  self.memory_queue: Deque[Dict[str, Any]] = deque(maxlen=capacity)
14
 
15
- # Proyeksi mempertahankan dimensi asli
16
  self.key_proj = nn.Linear(context_size, context_size)
17
  self.value_proj = nn.Linear(context_size, context_size)
18
  self.importance_decay = nn.Parameter(torch.tensor(0.95))
19
 
20
  def add_memory(self, context: torch.Tensor, activation: float):
21
- """Menyimpan memori dengan dimensi yang sesuai"""
22
  importance = torch.sigmoid(torch.tensor(activation * 0.5 + 0.2))
23
  self.memory_queue.append({
24
- 'context': context.detach().clone(),
25
  'importance': importance,
26
  'age': torch.tensor(0.0)
27
  })
28
 
29
  def consolidate_memories(self):
30
- """Konsolidasi memori dengan manajemen dimensi"""
31
  self.memory_queue = deque(
32
  [m for m in self.memory_queue if m['importance'] > 0.2],
33
  maxlen=self.capacity
34
  )
35
 
36
  def retrieve(self, query: torch.Tensor) -> torch.Tensor:
37
- """Retrieval dengan penanganan dimensi yang aman"""
38
  if not self.memory_queue:
39
  return torch.zeros(self.context_size, device=query.device)
40
 
 
41
  contexts = torch.stack([m['context'] for m in self.memory_queue])
42
  keys = self.key_proj(contexts)
43
  values = self.value_proj(contexts)
44
- query_proj = self.key_proj(query)
45
 
46
  scores = F.softmax(keys @ query_proj, dim=0)
47
  return (scores.unsqueeze(1) * values).sum(dim=0)
 
5
  from typing import Deque, Dict, Any
6
 
7
  class CognitiveMemory(nn.Module):
8
+ """Memory system dengan manajemen dimensi yang ketat"""
9
  def __init__(self, context_size: int, capacity: int = 100):
10
  super().__init__()
11
  self.context_size = context_size
12
  self.capacity = capacity
13
  self.memory_queue: Deque[Dict[str, Any]] = deque(maxlen=capacity)
14
 
15
+ # Proyeksi linear dengan dimensi input/output sama
16
  self.key_proj = nn.Linear(context_size, context_size)
17
  self.value_proj = nn.Linear(context_size, context_size)
18
  self.importance_decay = nn.Parameter(torch.tensor(0.95))
19
 
20
  def add_memory(self, context: torch.Tensor, activation: float):
21
+ """Menyimpan memori dengan dimensi terkontrol"""
22
  importance = torch.sigmoid(torch.tensor(activation * 0.5 + 0.2))
23
  self.memory_queue.append({
24
+ 'context': context.detach().clone().squeeze(),
25
  'importance': importance,
26
  'age': torch.tensor(0.0)
27
  })
28
 
29
  def consolidate_memories(self):
30
+ """Konsolidasi memori dengan validasi dimensi"""
31
  self.memory_queue = deque(
32
  [m for m in self.memory_queue if m['importance'] > 0.2],
33
  maxlen=self.capacity
34
  )
35
 
36
  def retrieve(self, query: torch.Tensor) -> torch.Tensor:
37
+ """Retrieval dengan penanganan tensor 1D"""
38
  if not self.memory_queue:
39
  return torch.zeros(self.context_size, device=query.device)
40
 
41
+ # Penanganan dimensi yang konsisten
42
  contexts = torch.stack([m['context'] for m in self.memory_queue])
43
  keys = self.key_proj(contexts)
44
  values = self.value_proj(contexts)
45
+ query_proj = self.key_proj(query.squeeze())
46
 
47
  scores = F.softmax(keys @ query_proj, dim=0)
48
  return (scores.unsqueeze(1) * values).sum(dim=0)