Upload 7 files

- crossexpertattention.py  +40 -0
- meshconfig.py  +64 -0
- meshexpert.py  +17 -0
- meshlayer.py  +55 -0
- meshmodel.py  +88 -0
- meshrouter.py  +27 -0
- neighborexchange.py  +82 -0
 
    	
crossexpertattention.py
ADDED

import torch.nn as nn

from meshconfig import MeshConfig  # MeshConfig is defined in meshconfig.py, uploaded in this commit

# Define the Cross-Expert Attention mechanism
class CrossExpertAttention(nn.Module):
    def __init__(self, config: MeshConfig):
        super().__init__()
        self.config = config
        # Multi-head attention for cross-expert communication.
        # This is a placeholder and needs a more detailed implementation.
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,  # reuse the model's attention heads for now
            batch_first=True,
        )

    def forward(self, expert_outputs):
        # expert_outputs shape: (batch_size, sequence_length, num_experts, hidden_size)
        if not self.config.cross_expert_attention_enabled:
            return expert_outputs

        num_experts = self.config.mesh_grid_size[0] * self.config.mesh_grid_size[1]

        # Reshape for attention: (batch_size * sequence_length, num_experts, hidden_size)
        batch_seq_size = expert_outputs.shape[0] * expert_outputs.shape[1]
        reshaped_outputs = expert_outputs.view(batch_seq_size, num_experts, self.config.hidden_size)

        # Self-attention across experts: query, key, and value are all the same.
        # An attention mask could be used to restrict communication if needed.
        cross_attn_output, _ = self.cross_attention(reshaped_outputs, reshaped_outputs, reshaped_outputs)

        # Reshape back: (batch_size, sequence_length, num_experts, hidden_size)
        cross_attn_output = cross_attn_output.view(
            expert_outputs.shape[0], expert_outputs.shape[1], num_experts, self.config.hidden_size
        )

        return cross_attn_output
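A minimal shape check for the module above (a sketch, assuming the seven files in this commit sit together on the Python path):

import torch
from meshconfig import MeshConfig
from crossexpertattention import CrossExpertAttention

config = MeshConfig()                            # defaults: 2x2 grid -> 4 experts, hidden_size=768
attn = CrossExpertAttention(config)
x = torch.randn(2, 16, 4, config.hidden_size)    # (batch, seq, num_experts, hidden)
print(attn(x).shape)                             # torch.Size([2, 16, 4, 768])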
    	
meshconfig.py
ADDED

from transformers import PretrainedConfig

class MeshConfig(PretrainedConfig):
    model_type = "mesh"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=768,
        intermediate_size=2048,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_key_value_heads=12,
        max_position_embeddings=4096,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        # Mesh-specific configurations
        mesh_grid_size=(2, 2),         # 2x2 grid of experts
        expert_intermediate_size=256,  # placeholder; recomputed below
        routing_k=2,                   # top-k routing
        neighbor_exchange_enabled=True,
        cross_expert_attention_enabled=True,
        **kwargs,
    ):
        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            rms_norm_eps=rms_norm_eps,
            use_cache=use_cache,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.mesh_grid_size = mesh_grid_size
        # expert_intermediate_size should eventually be derived from the shared/expert
        # parameter split. Target A242M (top-2): 100M shared (embedding, norm, LM head)
        # + 135M for the 2 active experts + 7M overhead = 242M, with the expert budget
        # living mostly in the intermediate projections. How Gemma's intermediate size
        # maps onto the expert intermediate size is still open, so for now divide the
        # dense intermediate size evenly across the experts (note: this overrides the
        # expert_intermediate_size argument above).
        self.expert_intermediate_size = intermediate_size // (mesh_grid_size[0] * mesh_grid_size[1])
        self.routing_k = routing_k
        self.neighbor_exchange_enabled = neighbor_exchange_enabled
        self.cross_expert_attention_enabled = cross_expert_attention_enabled
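A quick sanity check of the derived sizes (values follow the defaults above):

from meshconfig import MeshConfig

config = MeshConfig()
print(config.expert_intermediate_size)  # 2048 // (2 * 2) = 512; the 256 default is overridden
print(config.routing_k)                 # 2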
    	
meshexpert.py
ADDED

import torch.nn as nn

from meshconfig import MeshConfig

# Define a single Expert within the Mesh: a two-layer feed-forward block
class MeshExpert(nn.Module):
    def __init__(self, config: MeshConfig):
        super().__init__()
        self.fc1 = nn.Linear(config.hidden_size, config.expert_intermediate_size)
        self.gelu = nn.GELU()  # GELU as an example activation
        self.fc2 = nn.Linear(config.expert_intermediate_size, config.hidden_size)

    def forward(self, x):
        return self.fc2(self.gelu(self.fc1(x)))
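Per-expert parameter count under the default config (a sketch; the arithmetic follows the sizes above):

from meshconfig import MeshConfig
from meshexpert import MeshExpert

expert = MeshExpert(MeshConfig())
print(sum(p.numel() for p in expert.parameters()))  # 768*512 + 512 + 512*768 + 768 = 787,712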
    	
meshlayer.py
ADDED

import torch
import torch.nn as nn

from meshconfig import MeshConfig
from meshrouter import MeshRouter
from meshexpert import MeshExpert
from neighborexchange import NeighborExchange
from crossexpertattention import CrossExpertAttention

# Define the main Mesh Layer: route, run the experts, exchange with grid
# neighbors, apply cross-expert attention, then combine the top-k outputs.
class MeshLayer(nn.Module):
    def __init__(self, config: MeshConfig):
        super().__init__()
        self.config = config
        num_experts = config.mesh_grid_size[0] * config.mesh_grid_size[1]
        self.router = MeshRouter(config)
        self.experts = nn.ModuleList([MeshExpert(config) for _ in range(num_experts)])
        self.neighbor_exchange = NeighborExchange(config)
        self.cross_expert_attention = CrossExpertAttention(config)

    def forward(self, hidden_states):
        # hidden_states shape: (batch_size, sequence_length, hidden_size)

        # 1. Routing
        topk_weights, topk_indices = self.router(hidden_states)
        # topk_weights, topk_indices shapes: (batch_size, sequence_length, k)

        # 2. Expert computation: every expert sees the same input.
        # (Can be optimized to compute only the selected experts.)
        expert_outputs = torch.stack([expert(hidden_states) for expert in self.experts], dim=2)
        # expert_outputs shape: (batch_size, sequence_length, num_experts, hidden_size)

        # 3. Neighbor exchange across the expert grid
        exchanged_expert_outputs = self.neighbor_exchange(expert_outputs, topk_indices)

        # 4. Cross-expert attention
        cross_attned_expert_outputs = self.cross_expert_attention(exchanged_expert_outputs)

        # 5. Combine expert outputs based on routing weights.
        # Gather the selected experts' outputs: (batch_size, sequence_length, k, hidden_size)
        gathered_outputs = torch.gather(
            cross_attned_expert_outputs,
            dim=2,
            index=topk_indices.unsqueeze(-1).expand(-1, -1, -1, self.config.hidden_size),
        )
        # Apply routing weights: (batch, seq, k, 1) * (batch, seq, k, hidden), then sum over k
        combined_output = (gathered_outputs * topk_weights.unsqueeze(-1)).sum(dim=2)
        # combined_output shape: (batch_size, sequence_length, hidden_size)

        # Return the combined output and the expert indices for potential visualization
        return combined_output, topk_indices
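A forward-pass smoke test for the layer (a sketch under the same path assumption as above):

import torch
from meshconfig import MeshConfig
from meshlayer import MeshLayer

layer = MeshLayer(MeshConfig())
x = torch.randn(2, 16, 768)                 # (batch, seq, hidden)
out, indices = layer(x)
print(out.shape, indices.shape)             # torch.Size([2, 16, 768]) torch.Size([2, 16, 2])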
    	
meshmodel.py
ADDED

import torch.nn as nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

from meshconfig import MeshConfig
from meshlayer import MeshLayer

class MeshModel(PreTrainedModel):
    config_class = MeshConfig

    def __init__(self, config: MeshConfig):
        super().__init__(config)
        self.config = config

        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([MeshLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # Default return_dict to the config setting when not specified
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            inputs_embeds = self.embedding(input_ids)
        elif inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        hidden_states = inputs_embeds
        expert_indices_list = []  # collect expert indices from each layer

        for layer in self.layers:
            hidden_states, expert_indices = layer(hidden_states)
            expert_indices_list.append(expert_indices)

        hidden_states = self.norm(hidden_states)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n, then compute the scalar cross-entropy loss
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        if return_dict:
            return CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=None,   # caching not implemented yet
                hidden_states=hidden_states,
                attentions=None,        # attention outputs not implemented yet
            )
        # Tuple output: the Trainer expects the loss first for backward; the collected
        # expert indices ride along for a custom callback or logging if needed.
        return (loss, logits, hidden_states, expert_indices_list)
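End-to-end smoke test with labels (a sketch; a 2-layer config keeps it fast):

import torch
from meshconfig import MeshConfig
from meshmodel import MeshModel

config = MeshConfig(num_hidden_layers=2)
model = MeshModel(config)
input_ids = torch.randint(0, config.vocab_size, (2, 16))
out = model(input_ids=input_ids, labels=input_ids)
print(out.loss.item(), out.logits.shape)    # scalar loss, torch.Size([2, 16, 32000])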
    	
meshrouter.py
ADDED

import torch
import torch.nn as nn

from meshconfig import MeshConfig

# Define the Router for dynamic top-k expert selection
class MeshRouter(nn.Module):
    def __init__(self, config: MeshConfig):
        super().__init__()
        self.gate = nn.Linear(config.hidden_size, config.mesh_grid_size[0] * config.mesh_grid_size[1])
        self.softmax = nn.Softmax(dim=-1)
        self.routing_k = config.routing_k

    def forward(self, x):
        # x shape: (batch_size, sequence_length, hidden_size)
        gate_scores = self.gate(x)  # (batch_size, sequence_length, num_experts)
        gate_weights = self.softmax(gate_scores)

        # Select the top-k experts per token
        topk_weights, topk_indices = torch.topk(gate_weights, self.routing_k, dim=-1)

        # Renormalize the top-k weights so they sum to 1
        topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-6)

        # shapes: (batch_size, sequence_length, k) for both
        return topk_weights, topk_indices
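Router usage (a sketch):

import torch
from meshconfig import MeshConfig
from meshrouter import MeshRouter

router = MeshRouter(MeshConfig())
weights, indices = router(torch.randn(2, 16, 768))
print(weights.shape, indices.shape)       # torch.Size([2, 16, 2]) for both
print(float(weights.sum(dim=-1)[0, 0]))   # ~1.0 after renormalization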
    	
neighborexchange.py
ADDED

import torch
import torch.nn as nn

from meshconfig import MeshConfig

class NeighborExchange(nn.Module):
    def __init__(self, config: MeshConfig):
        super().__init__()
        self.config = config
        self.num_experts_x = config.mesh_grid_size[0]
        self.num_experts_y = config.mesh_grid_size[1]
        self.num_experts = self.num_experts_x * self.num_experts_y

        # Parameters for neighbor communication. Options considered:
        # - a learned weight per neighbor direction (up/down/left/right);
        # - a linear layer over concatenated neighbor features (too large);
        # - a single linear layer applied after aggregating the neighbors.
        # We use the last: aggregate neighbor features (hidden_size in, hidden_size
        # out), transform them, then add the result to the expert's own output.
        # (In a 2x2 grid each expert has exactly 2 neighbors.)
        self.exchange_projection = nn.Linear(config.hidden_size, config.hidden_size)

        # Optional: learned weights for the four directions (N, S, E, W)
        # self.neighbor_weights = nn.Parameter(torch.ones(4))

    def forward(self, expert_outputs, expert_indices=None):
        # expert_outputs shape: (batch_size, sequence_length, num_experts, hidden_size)
        # expert_indices shape: (batch_size, sequence_length, k); not used by this
        # simple exchange, which always communicates across the full grid.

        if not self.config.neighbor_exchange_enabled:
            return expert_outputs

        batch_size, seq_length, num_experts, hidden_size = expert_outputs.shape

        # View the experts as a grid: (batch_size, seq_length, grid_x, grid_y, hidden_size)
        reshaped_outputs = expert_outputs.view(batch_size, seq_length, self.num_experts_x, self.num_experts_y, hidden_size)

        # Aggregated neighbor information for each grid position
        aggregated_neighbor_info = torch.zeros_like(reshaped_outputs)

        # For each expert, average its grid neighbors' outputs
        for i in range(self.num_experts_x):
            for j in range(self.num_experts_y):
                neighbors = []
                if i > 0:                        # up neighbor
                    neighbors.append(reshaped_outputs[:, :, i - 1, j, :])
                if i < self.num_experts_x - 1:   # down neighbor
                    neighbors.append(reshaped_outputs[:, :, i + 1, j, :])
                if j > 0:                        # left neighbor
                    neighbors.append(reshaped_outputs[:, :, i, j - 1, :])
                if j < self.num_experts_y - 1:   # right neighbor
                    neighbors.append(reshaped_outputs[:, :, i, j + 1, :])

                if neighbors:
                    # Stack along a new dimension and average: (batch, seq, hidden)
                    neighbor_info = torch.mean(torch.stack(neighbors, dim=-2), dim=-2)
                else:
                    neighbor_info = torch.zeros_like(reshaped_outputs[:, :, i, j, :])

                # Transform the aggregated neighbor information and store it
                # at the current expert's grid position
                aggregated_neighbor_info[:, :, i, j, :] = self.exchange_projection(neighbor_info)

        # Back to (batch_size, sequence_length, num_experts, hidden_size)
        aggregated_neighbor_info = aggregated_neighbor_info.view(batch_size, seq_length, num_experts, hidden_size)

        # Additive combination of expert outputs and neighbor information
        return expert_outputs + aggregated_neighbor_info
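In the default 2x2 grid every expert has exactly two neighbors (its row partner and its column partner), so the exchange is a deterministic function of grid position. A shape check (a sketch):

import torch
from meshconfig import MeshConfig
from neighborexchange import NeighborExchange

exchange = NeighborExchange(MeshConfig())
x = torch.randn(2, 16, 4, 768)   # (batch, seq, num_experts, hidden)
print(exchange(x).shape)         # torch.Size([2, 16, 4, 768])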