File size: 4,439 Bytes
606bb52 ddcd0d7 606bb52 e25b866 606bb52 254bf80 606bb52 254bf80 ddcd0d7 606bb52 c41d138 606bb52 e25b866 ddcd0d7 606bb52 254bf80 606bb52 ddcd0d7 e25b866 ddcd0d7 606bb52 ddcd0d7 606bb52 254bf80 e25b866 606bb52 ddcd0d7 606bb52 ddcd0d7 e25b866 606bb52 ddcd0d7 e25b866 ddcd0d7 254bf80 e25b866 ddcd0d7 e25b866 ddcd0d7 e25b866 ddcd0d7 e25b866 254bf80 ddcd0d7 606bb52 ddcd0d7 e25b866 606bb52 e25b866 ddcd0d7 606bb52 e25b866 606bb52 e25b866 254bf80 e25b866 ddcd0d7 606bb52 e25b866 254bf80 e25b866 606bb52 254bf80 606bb52 ddcd0d7 606bb52 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
import torch
import torch.nn as nn
import torch.optim as optim
import math
import numpy as np
from typing import Dict, Optional
from .node import CognitiveNode
class DynamicCognitiveNet(nn.Module):
    """Two-layer network of ``CognitiveNode`` units with learnable connections.

    Every input node is connected to every output node through a scalar weight
    stored in ``self.connections`` under the key ``"<input id>-><output id>"``.
    Besides gradient training, ``train_step`` applies reward-driven
    "structural" updates that shift all connection weights and may add new
    connections.
    """

    def __init__(self, input_size: int, output_size: int):
        """Build the nodes, the base connections and the optimizer.

        Args:
            input_size: Number of input nodes (one per input element).
            output_size: Number of output nodes (one per output element).
        """
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        # Nodes are built with input size 1. Input nodes get ids
        # 0..input_size-1 and output nodes continue numbering from input_size.
        self.input_nodes = nn.ModuleList([
            CognitiveNode(i, 1) for i in range(input_size)
        ])
        self.output_nodes = nn.ModuleList([
            CognitiveNode(input_size + i, 1) for i in range(output_size)
        ])
        # Connection weights keyed by "<src id>-><dst id>".
        self.connections = nn.ParameterDict()
        self._init_base_connections()
        # Learning machinery. Within this class, emotional_state is not used
        # in forward() or in the loss. train_step updates it heuristically,
        # not by gradient.
        self.emotional_state = nn.Parameter(torch.tensor(0.0))
        # The optimizer only sees the parameters that exist at this point.
        # Parameters added later must be registered with add_param_group.
        self.optimizer = optim.AdamW(self.parameters(), lr=0.001)
        self.loss_fn = nn.MSELoss()

    def _init_base_connections(self):
        """Create one connection per (input node, output node) pair.

        Raw weights start as ``randn(1) * 0.1``. ``forward`` passes them
        through a sigmoid before use.
        """
        for in_node in self.input_nodes:
            for out_node in self.output_nodes:
                conn_id = f"{in_node.id}->{out_node.id}"
                self.connections[conn_id] = nn.Parameter(
                    torch.randn(1) * 0.1
                )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Propagate one input vector through the network.

        Args:
            x: Tensor of any shape. It is flattened, and element ``i`` is fed
                to input node ``i``. Elements beyond ``input_size`` are
                ignored; too few elements raise ``IndexError``.

        Returns:
            The stacked output-node results with size-1 dimensions squeezed
            away, so a single output may come back as a 0-d tensor.
        """
        x = x.view(-1)

        # Each input node receives exactly one scalar element.
        activations = {
            node.id: node(x[i].unsqueeze(0))
            for i, node in enumerate(self.input_nodes)
        }

        outputs = []
        for out_node in self.output_nodes:
            integrated = []
            for in_node in self.input_nodes:
                conn_id = f"{in_node.id}->{out_node.id}"
                # The sigmoid maps the raw weight into (0, 1).
                weight = torch.sigmoid(self.connections[conn_id])
                integrated.append(activations[in_node.id] * weight)
            if integrated:
                # Scale by 1/sqrt(fan-in), presumably to keep the magnitude
                # roughly independent of the number of inputs.
                combined = sum(integrated) / math.sqrt(len(integrated))
                outputs.append(out_node(combined))
        return torch.stack(outputs).squeeze()

    def structural_update(self, global_reward: float):
        """Apply a reward-driven, non-gradient update to the connections.

        Every raw connection weight is shifted by ``0.1 * global_reward`` and
        clamped to [-1, 1]. When the reward is below -0.5, a connection
        between the least-active input and output nodes is added if that key
        is missing.

        Args:
            global_reward: Scalar reward. ``train_step`` passes
                ``0.5 - mse_loss``.
        """
        # This is a heuristic edit of the raw weights, so no autograd graph
        # is needed.
        with torch.no_grad():
            for weight in self.connections.values():
                weight.add_(0.1 * global_reward).clamp_(-1, 1)

        if global_reward < -0.5:
            new_conn = self._find_underutilized_connection()
            # _find_underutilized_connection only yields input->output keys,
            # and _init_base_connections creates all of them. So this branch
            # only adds something if such a key was removed elsewhere.
            if new_conn is not None and new_conn not in self.connections:
                param = nn.Parameter(torch.randn(1) * 0.1)
                self.connections[new_conn] = param
                # Register the new parameter, otherwise the optimizer (built
                # in __init__) would never train it.
                self.optimizer.add_param_group({"params": [param]})

    def _find_underutilized_connection(self) -> Optional[str]:
        """Return ``"src->tgt"`` for the least-active input and output nodes.

        Activity is the mean of each node's ``recent_activations``. Nodes with
        an empty history are skipped. This assumes ``recent_activations`` is
        a sequence whose truthiness means "non-empty" (TODO confirm in
        CognitiveNode).

        Returns:
            The connection key, or None when either layer has no history.
        """
        input_act = {n.id: np.mean(n.recent_activations)
                     for n in self.input_nodes if n.recent_activations}
        output_act = {n.id: np.mean(n.recent_activations)
                      for n in self.output_nodes if n.recent_activations}
        if not input_act or not output_act:
            return None
        src = min(input_act, key=input_act.get)
        tgt = min(output_act, key=output_act.get)
        return f"{src}->{tgt}"

    def train_step(self, x: torch.Tensor, y: torch.Tensor) -> float:
        """Run one gradient step, then the reward-driven updates.

        The optimised loss is ``MSE(pred, y) + 0.01 * sum(mean(|w|))`` over
        the connection weights. After a successful step, the reward
        ``0.5 - mse`` nudges ``emotional_state`` and is passed to
        ``structural_update``.

        Args:
            x: Input tensor (flattened before use).
            y: Target tensor (flattened before use).

        Returns:
            The total loss computed before the step, or NaN in three cases:
            the forward pass raised, the loss was not finite, or
            backward/step raised. The reward-driven updates are skipped in
            all three. For a forward error or a non-finite loss, no gradient
            step is taken either.
        """
        self.optimizer.zero_grad()
        try:
            # forward() squeezes its output, so a single-output network
            # returns a 0-d tensor. Flatten it so it matches the flattened
            # target instead of relying on MSELoss broadcasting.
            pred = self(x.view(-1)).reshape(-1)
            loss = self.loss_fn(pred, y.view(-1))
        except Exception as e:
            print(f"Error forward: {e}")
            return float('nan')

        # Structural regularisation: L1-style penalty on connection weights.
        reg_loss = sum(p.abs().mean() for p in self.connections.values())
        total_loss = loss + 0.01 * reg_loss

        total_value = total_loss.item()
        if not math.isfinite(total_value):
            # Back-propagating NaN/inf would write NaN into every parameter,
            # and the reward updates would spread it further (clamp keeps
            # NaN). So leave the network untouched.
            print(f"Error loss: non-finite total loss {total_value}")
            return float('nan')

        try:
            total_loss.backward()
            self.optimizer.step()
        except Exception as e:
            print(f"Error backward: {e}")
            return float('nan')

        reward = 0.5 - loss.item()
        # emotional_state is updated heuristically, not through the optimizer.
        with torch.no_grad():
            self.emotional_state.copy_(
                torch.sigmoid(self.emotional_state + reward * 0.1)
            )
        self.structural_update(reward)
        return total_value