File size: 1,948 Bytes
ee91c59 b668479 ee91c59 ca65cde ee91c59 b668479 ee91c59 ca65cde c9387f0 b668479 c94bd82 ee91c59 ca65cde ee91c59 ca65cde ee91c59 b668479 ee91c59 b668479 ee91c59 ca65cde c9387f0 b668479 ee91c59 ca65cde ee91c59 c9387f0 ee91c59 ca65cde b668479 ca65cde ee91c59 c9387f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import deque
from typing import Deque, Dict, Any
class CognitiveMemory(nn.Module):
    """Bounded memory store with attention-style retrieval over stored contexts.

    Each stored memory is a flattened context vector of exactly
    ``context_size`` elements plus an importance score. Retrieval projects the
    stored contexts and the query through learned linear layers and returns a
    softmax-weighted sum of the projected values.

    Args:
        context_size: Number of elements in every context / query vector.
        capacity: Maximum number of memories kept; once full, appending a new
            memory evicts the oldest one (``deque(maxlen=capacity)``).
    """

    def __init__(self, context_size: int, capacity: int = 100):
        super().__init__()
        self.context_size = context_size
        self.capacity = capacity
        self.memory_queue: Deque[Dict[str, Any]] = deque(maxlen=capacity)
        # Square projections (in == out == context_size), so keys, values and
        # the projected query all live in the same space.
        self.key_proj = nn.Linear(context_size, context_size)
        self.value_proj = nn.Linear(context_size, context_size)
        # NOTE(review): registered as a trainable parameter but not read by any
        # method of this class, so importance never decays. Confirm intent.
        self.importance_decay = nn.Parameter(torch.tensor(0.95))

    def _as_vector(self, tensor: torch.Tensor, name: str) -> torch.Tensor:
        """Flatten ``tensor`` to shape ``(context_size,)`` or raise ValueError."""
        if tensor.numel() != self.context_size:
            raise ValueError(
                f"{name} must contain {self.context_size} elements, "
                f"got shape {tuple(tensor.shape)}"
            )
        # reshape(-1) rather than squeeze(): squeeze() would also collapse a
        # genuine size-1 feature axis (context_size == 1) into a 0-d scalar
        # that can be neither stacked into a matrix nor fed to nn.Linear.
        return tensor.reshape(-1)

    def add_memory(self, context: torch.Tensor, activation: float) -> None:
        """Store a detached copy of ``context`` with an activation-derived importance.

        Args:
            context: Tensor holding exactly ``context_size`` elements in any
                shape (e.g. ``(context_size,)`` or ``(1, context_size)``); it is
                flattened to 1-D before storage.
            activation: Scalar mapped to importance via
                ``sigmoid(0.5 * activation + 0.2)``.

        Raises:
            ValueError: If ``context`` does not hold ``context_size`` elements.
        """
        # clone() so the stored vector does not share storage with the
        # caller's tensor (reshape may return a view).
        vector = self._as_vector(context.detach(), "context").clone()
        # float() accepts Python numbers and one-element tensors alike and
        # avoids torch.tensor()'s warning when handed an existing tensor.
        importance = torch.sigmoid(torch.tensor(float(activation) * 0.5 + 0.2))
        self.memory_queue.append({
            'context': vector,
            'importance': importance,
            # NOTE(review): stored but never updated by this class.
            'age': torch.tensor(0.0),
        })

    def consolidate_memories(self, threshold: float = 0.2) -> None:
        """Drop memories whose importance is not strictly above ``threshold``.

        Survivors keep their original order, and the rebuilt deque keeps
        ``capacity`` as its maximum length.

        Args:
            threshold: Importance cut-off; defaults to 0.2.
        """
        self.memory_queue = deque(
            (m for m in self.memory_queue if m['importance'] > threshold),
            maxlen=self.capacity,
        )

    def retrieve(self, query: torch.Tensor) -> torch.Tensor:
        """Return a softmax-weighted sum of projected memory values for ``query``.

        Args:
            query: Tensor holding exactly ``context_size`` elements in any
                shape; flattened to 1-D. It is not detached, so gradients can
                flow back through it.

        Returns:
            1-D tensor of shape ``(context_size,)``. When memory is empty, a
            zero vector on ``query.device`` with the projection weights' dtype.

        Raises:
            ValueError: If ``query`` does not hold ``context_size`` elements.
        """
        flat_query = self._as_vector(query, "query")
        if not self.memory_queue:
            # Match the dtype the non-empty path would produce.
            return torch.zeros(
                self.context_size,
                device=query.device,
                dtype=self.value_proj.weight.dtype,
            )
        contexts = torch.stack([m['context'] for m in self.memory_queue])  # (N, context_size)
        keys = self.key_proj(contexts)
        values = self.value_proj(contexts)
        query_proj = self.key_proj(flat_query)
        # One score per stored memory, normalised over the memory axis.
        scores = F.softmax(keys @ query_proj, dim=0)  # (N,)
        return (scores.unsqueeze(1) * values).sum(dim=0)