File size: 1,871 Bytes
ee91c59
 
 
 
b668479
ee91c59
 
c9387f0
ee91c59
 
 
 
b668479
ee91c59
c9387f0
 
 
b668479
c94bd82
ee91c59
c9387f0
ee91c59
 
b668479
ee91c59
b668479
ee91c59
b668479
ee91c59
c9387f0
 
 
 
 
b668479
ee91c59
c9387f0
ee91c59
c9387f0
ee91c59
b668479
 
 
c9387f0
ee91c59
c9387f0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import deque
from typing import Deque, Dict, Any

class CognitiveMemory(nn.Module):
    """Bounded memory store with attention-style retrieval.

    Stored contexts are kept in a FIFO queue of at most ``capacity`` entries
    (the oldest entry is evicted first). Retrieval projects the stored contexts
    and the query with learned linear maps that preserve ``context_size`` and
    returns a softmax-weighted sum of the projected values.

    Args:
        context_size: Size of the last dimension of stored contexts and queries.
        capacity: Maximum number of memories kept; older entries are dropped
            automatically once the queue is full.
    """

    def __init__(self, context_size: int, capacity: int = 100):
        super().__init__()
        self.context_size = context_size
        self.capacity = capacity
        # Plain Python container, not a registered buffer: stored tensors are
        # NOT moved by ``module.to(...)`` and are not included in ``state_dict``.
        self.memory_queue: Deque[Dict[str, Any]] = deque(maxlen=capacity)

        # Projections keep the original dimensionality (context_size -> context_size).
        self.key_proj = nn.Linear(context_size, context_size)
        self.value_proj = nn.Linear(context_size, context_size)
        # NOTE(review): ``importance_decay`` is a learnable parameter, but nothing
        # in this class reads it, and each memory's ``'age'`` is never advanced.
        # Confirm whether decay is applied elsewhere or is still unimplemented.
        self.importance_decay = nn.Parameter(torch.tensor(0.95))

    def add_memory(self, context: torch.Tensor, activation: float):
        """Store a detached copy of ``context`` together with an importance score.

        Importance is ``sigmoid(0.5 * activation + 0.2)``, so it lies in (0, 1)
        and increases with ``activation``. When the queue is full, the oldest
        memory is evicted.

        Args:
            context: Tensor to remember. Assumed to be 1-D ``(context_size,)``
                so that stored contexts stack into ``(N, context_size)``.
            activation: Scalar activation; a Python number or a one-element tensor.
        """
        # float() accepts Python numbers and one-element tensors alike; it avoids
        # torch.tensor(<tensor>) copy-construction, which emits a UserWarning.
        importance = torch.sigmoid(torch.tensor(float(activation) * 0.5 + 0.2))
        self.memory_queue.append({
            # Detach + clone so later in-place edits or autograd on the caller's
            # tensor cannot alter the stored memory.
            'context': context.detach().clone(),
            'importance': importance,
            'age': torch.tensor(0.0),
        })

    def consolidate_memories(self, threshold: float = 0.2):
        """Drop memories whose importance is not strictly above ``threshold``.

        Surviving memories keep their relative order, and the rebuilt queue
        keeps the same ``capacity`` bound.

        Args:
            threshold: Importance cut-off (default 0.2, the previously
                hard-coded value).
        """
        self.memory_queue = deque(
            [m for m in self.memory_queue if m['importance'] > threshold],
            maxlen=self.capacity,
        )

    def retrieve(self, query: torch.Tensor) -> torch.Tensor:
        """Return an attention-weighted summary of the stored memories.

        Scores are dot products between the projected query and the projected
        stored contexts (no ``1/sqrt(d)`` scaling). They are normalised with a
        softmax over the memory axis. The result is the score-weighted sum of
        the projected values.

        Args:
            query: Tensor of shape ``(..., context_size)``. A 1-D query yields a
                1-D result; any leading dimensions are treated as batch dims.

        Returns:
            Tensor of shape ``(..., context_size)``. When no memories are stored,
            it is all zeros with the same dtype and device as ``query``.
        """
        if not self.memory_queue:
            return query.new_zeros(query.shape[:-1] + (self.context_size,))

        # Memories live in a plain deque, so ``module.to(device)`` does not move
        # them; bring them onto the query's device before projecting.
        contexts = torch.stack(
            [m['context'] for m in self.memory_queue]
        ).to(query.device)
        keys = self.key_proj(contexts)        # (N, C), assuming 1-D stored contexts
        values = self.value_proj(contexts)    # (N, C)
        query_proj = self.key_proj(query)     # (..., C)

        # (..., C) @ (C, N) -> (..., N): one score per stored memory.
        scores = F.softmax(query_proj @ keys.t(), dim=-1)
        # (..., N) @ (N, C) -> (..., C): weighted sum of projected values.
        return scores @ values