import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import math
from typing import Optional
from .node import CognitiveNode


class DynamicCognitiveNet(nn.Module):
    """Self-organizing cognitive network with structure learning.

    Every input node is wired to every output node through a learnable scalar
    strength stored in ``connection_strength`` and gated with a sigmoid in
    ``forward``. Strengths are trained by gradient descent in ``train_step``
    and additionally adjusted by a reward-driven rule in ``structural_update``,
    which may also grow new connections.
    """

    def __init__(self, input_size: int, output_size: int):
        """Build the network.

        Args:
            input_size: Number of input nodes; ``forward`` reads
                ``x[0] .. x[input_size - 1]``.
            output_size: Number of output nodes; ``forward`` concatenates one
                activation per output node.
        """
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size

        # Initialize core nodes. Input node i lives under the key 'input_{i}';
        # output nodes are constructed with first argument input_size + i
        # (presumably the node id — TODO confirm against CognitiveNode).
        self.nodes = nn.ModuleDict({
            f'input_{i}': CognitiveNode(i, 1) for i in range(input_size)
        })
        self.output_nodes = nn.ModuleList([
            CognitiveNode(input_size + i, 1) for i in range(output_size)
        ])

        # Structure learning parameters: one scalar per connection, keyed
        # "<source>-><target>".
        self.connection_strength = nn.ParameterDict()
        self.init_connections()

        # Emotional context: a scalar re-squashed through a sigmoid after every
        # training step. It never enters the loss, so it receives no gradient.
        self.emotional_state = nn.Parameter(torch.tensor(0.0))
        # Step size of the reward rule in structural_update (distinct from the
        # optimizer learning rate below).
        self.learning_rate = 0.01

        # Adaptive learning. Connections grown later by structural_update are
        # registered with this optimizer when they are created.
        self.optimizer = optim.AdamW(self.parameters(), lr=0.001)
        self.loss_fn = nn.MSELoss()

    def init_connections(self):
        """Connect every input node to every output node (dense wiring).

        Each connection gets a scalar strength drawn as ``randn(1) * 0.1`` and
        stored under the key ``'input_{i}->{out_node.id}'``.
        """
        for i in range(self.input_size):
            for out_node in self.output_nodes:
                conn_id = f'input_{i}->{out_node.id}'
                self.connection_strength[conn_id] = nn.Parameter(
                    torch.randn(1) * 0.1
                )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Propagate one sample through the network.

        Args:
            x: Indexable input; ``x[i]`` feeds input node ``i``. Presumably a
                1-D tensor of length ``input_size`` — TODO confirm.

        Returns:
            The output-node activations joined with ``torch.cat``. Note that
            ``torch.cat`` fails on an empty list, which happens when
            ``input_size`` or ``output_size`` is 0.
        """
        # Process inputs. Activations are keyed by input index, which is how
        # they are looked up below, so this does not depend on node.id == i.
        activations = {}
        for i in range(self.input_size):
            node = self.nodes[f'input_{i}']
            activations[i] = node(x[i].unsqueeze(0))

        # Fallback raw strength for a missing connection: gate = sigmoid(0) = 0.5.
        # Built once rather than once per lookup.
        missing = torch.tensor(0.0)

        # Propagate through network
        outputs = []
        for out_node in self.output_nodes:
            input_acts = []
            for i in range(self.input_size):
                conn_id = f'input_{i}->{out_node.id}'
                weight = self.connection_strength.get(conn_id, missing)
                input_acts.append(activations[i] * torch.sigmoid(weight))

            if input_acts:
                # Divide by sqrt(fan-in) to keep the summed signal's scale from
                # growing with the number of inputs.
                combined = sum(input_acts) / math.sqrt(len(input_acts))
                out_act = out_node(combined.unsqueeze(0))
                outputs.append(out_act)

        return torch.cat(outputs)

    def structural_update(self, reward: float):
        """Adapt network structure based on performance.

        If ``reward > 0``, every connection strength is increased by
        ``learning_rate * reward``; otherwise every strength is multiplied by
        0.9. All strengths are then clamped to [-1, 1]. When ``reward < -0.5``,
        with probability 0.3 a new connection is proposed and, if it does not
        already exist, added as a fresh parameter that the optimizer tracks.

        Args:
            reward: Scalar performance signal (``train_step`` passes
                ``0.5 - loss``).
        """
        # Manual, in-place parameter edits: keep them out of autograd.
        with torch.no_grad():
            for weight in self.connection_strength.values():
                if reward > 0:
                    # Strengthen productive connections
                    weight.add_(self.learning_rate * reward)
                else:
                    weight.mul_(0.9)
                weight.clamp_(-1.0, 1.0)

        # Add new connections if performance is poor
        if reward < -0.5 and torch.rand(1).item() < 0.3:
            new_conn = self._create_new_connection()
            # Never overwrite an existing connection: replacing the Parameter
            # would discard its value and orphan the optimizer's reference.
            if new_conn and new_conn not in self.connection_strength:
                param = nn.Parameter(
                    torch.randn(1, device=self.emotional_state.device) * 0.1
                )
                self.connection_strength[new_conn] = param
                # The optimizer was built in __init__ and does not know about
                # parameters created afterwards; without this the new strength
                # is never updated and its .grad is never zeroed.
                self.optimizer.add_param_group({'params': [param]})

    def _create_new_connection(self) -> Optional[str]:
        """Propose a connection between the two least-active input nodes.

        A node's activity is the mean of ``node.recent_activations.values()``
        (presumably a mapping of recent activation values — TODO confirm);
        nodes whose ``recent_activations`` is empty are ignored. The choice is
        deterministic: least active node -> second least active node.

        Returns:
            A ``"<source>-><target>"`` connection id, or ``None`` when fewer
            than two nodes have recorded activity.
        """
        # NOTE(review): self.nodes holds only input nodes, so this always
        # yields an 'input_a->input_b' id, which forward() never reads.
        # Confirm whether output nodes were meant to be candidates as well.
        node_activations = {
            node_id: sum(node.recent_activations.values()) / len(node.recent_activations)
            for node_id, node in self.nodes.items()
            if node.recent_activations
        }
        if len(node_activations) < 2:
            return None

        # Sort ascending by mean activity (stable) and take the two lowest.
        ranked = sorted(node_activations.items(), key=lambda item: item[1])
        source = ranked[0][0]
        target = ranked[1][0]
        return f"{source}->{target}"

    def train_step(self, x: torch.Tensor, y: torch.Tensor) -> float:
        """Execute a single training step.

        Runs one optimizer step on the MSE loss plus an L1 penalty (weight
        0.01) on the mean absolute connection strengths, then updates the
        emotional state and applies ``structural_update`` with reward
        ``0.5 - loss``.

        Args:
            x: Input passed to ``forward``.
            y: Target compared against the prediction with ``MSELoss``.

        Returns:
            The total loss (MSE + 0.01 * regularization) as a Python float.
        """
        self.optimizer.zero_grad()
        pred = self(x)
        loss = self.loss_fn(pred, y)

        # Add structural regularization (L1 on connection strengths)
        reg_loss = sum(torch.abs(w).mean() for w in self.connection_strength.values())
        total_loss = loss + 0.01 * reg_loss

        total_loss.backward()
        self.optimizer.step()

        # Reward is positive when the data loss is below 0.5.
        reward = 0.5 - loss.item()

        # Update emotional context (manual update, outside autograd)
        with torch.no_grad():
            self.emotional_state.copy_(
                torch.sigmoid(self.emotional_state + reward * 0.1)
            )

        # Structural updates
        self.structural_update(reward=reward)

        return total_loss.item()