File size: 1,948 Bytes
ee91c59
 
 
 
b668479
ee91c59
 
ca65cde
ee91c59
 
 
 
b668479
ee91c59
ca65cde
c9387f0
 
b668479
c94bd82
ee91c59
ca65cde
ee91c59
 
ca65cde
ee91c59
b668479
ee91c59
b668479
ee91c59
ca65cde
c9387f0
 
 
 
b668479
ee91c59
ca65cde
ee91c59
c9387f0
ee91c59
ca65cde
b668479
 
 
ca65cde
ee91c59
c9387f0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import deque
from typing import Deque, Dict, Any

class CognitiveMemory(nn.Module):
    """Attention-based episodic memory with strict dimension management.

    Stores up to ``capacity`` detached context vectors of ``context_size``
    elements, each with a scalar importance and an age counter, and
    retrieves a softmax-weighted mixture of projected values for a query.
    """

    def __init__(self, context_size: int, capacity: int = 100):
        super().__init__()
        self.context_size = context_size
        self.capacity = capacity
        # FIFO buffer of {'context', 'importance', 'age'} entries; the
        # deque silently evicts the oldest entry once capacity is hit.
        self.memory_queue: Deque[Dict[str, Any]] = deque(maxlen=capacity)

        # Linear projections with identical input/output dimensions.
        self.key_proj = nn.Linear(context_size, context_size)
        self.value_proj = nn.Linear(context_size, context_size)
        # Multiplicative importance decay applied on each consolidation.
        self.importance_decay = nn.Parameter(torch.tensor(0.95))

    def add_memory(self, context: torch.Tensor, activation: float) -> None:
        """Store a context vector with controlled dimensions.

        Args:
            context: tensor holding exactly ``context_size`` elements
                (e.g. ``(C,)`` or ``(1, C)``); detached and stored flat.
            activation: scalar mapped to an importance score via
                ``sigmoid(0.5 * activation + 0.2)``.

        Raises:
            RuntimeError: if ``context`` does not hold exactly
                ``context_size`` elements.
        """
        # as_tensor avoids the copy warning torch.tensor() emits when
        # `activation` is already a tensor.
        importance = torch.sigmoid(
            torch.as_tensor(activation, dtype=torch.float32) * 0.5 + 0.2
        )
        self.memory_queue.append({
            # reshape (not squeeze) guarantees a (C,) vector even when
            # context_size == 1, where squeeze() would yield a 0-d scalar
            # and break torch.stack/matmul in retrieve().
            'context': context.detach().clone().reshape(self.context_size),
            'importance': importance,
            'age': torch.tensor(0.0),
        })

    def consolidate_memories(self) -> None:
        """Age all memories, decay their importance, and prune weak ones.

        Fix: ``importance_decay`` and ``age`` were previously dead code —
        entries never aged and importance never decayed, so consolidation
        only re-checked a static threshold. The decay factor is detached
        so no autograd graph is retained across calls.
        """
        decay = float(self.importance_decay.detach())
        for entry in self.memory_queue:
            entry['age'] = entry['age'] + 1.0
            entry['importance'] = entry['importance'] * decay
        self.memory_queue = deque(
            [m for m in self.memory_queue if m['importance'] > 0.2],
            maxlen=self.capacity,
        )

    def retrieve(self, query: torch.Tensor) -> torch.Tensor:
        """Return an attention-weighted readout of the stored values.

        Args:
            query: tensor with exactly ``context_size`` elements.

        Returns:
            A ``(context_size,)`` tensor; zeros matching the query's
            device and dtype when the memory is empty.
        """
        if not self.memory_queue:
            return torch.zeros(
                self.context_size, device=query.device, dtype=query.dtype
            )

        # (N, C) stack of stored contexts; add_memory guarantees 1-D.
        contexts = torch.stack([m['context'] for m in self.memory_queue])
        keys = self.key_proj(contexts)        # (N, C)
        values = self.value_proj(contexts)    # (N, C)
        query_proj = self.key_proj(query.reshape(self.context_size))  # (C,)

        # (N,) attention weights over the stored memories.
        scores = F.softmax(keys @ query_proj, dim=0)
        return (scores.unsqueeze(1) * values).sum(dim=0)