Update meshrouter.py

meshrouter.py  CHANGED  +26 -2

@@ -1,3 +1,27 @@
+from transformers import PretrainedConfig, PreTrainedModel, AutoModelForCausalLM  # Import AutoModelForCausalLM
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+from transformers.modeling_outputs import CausalLMOutputWithPast  # Import the necessary output class
 
+# Define the Router for dynamic routing
+class MeshRouter(nn.Module):
+    def __init__(self, config: MeshConfig):
+        super().__init__()
+        self.gate = nn.Linear(config.hidden_size, config.mesh_grid_size[0] * config.mesh_grid_size[1])
+        self.softmax = nn.Softmax(dim=-1)
+        self.routing_k = config.routing_k
+
+    def forward(self, x):
+        # x shape: (batch_size, sequence_length, hidden_size)
+        gate_scores = self.gate(x)  # shape: (batch_size, sequence_length, num_experts)
+        gate_weights = self.softmax(gate_scores)
+
+        # Select top-k experts
+        topk_weights, topk_indices = torch.topk(gate_weights, self.routing_k, dim=-1)
+
+        # Normalize top-k weights
+        topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-6)
+
+        return topk_weights, topk_indices  # shapes: (batch_size, sequence_length, k), (batch_size, sequence_length, k)
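For context, a minimal usage sketch of the router added above follows. It is illustrative only: it assumes MeshConfig (not defined in this diff, presumably provided elsewhere in the repository) exposes hidden_size, mesh_grid_size, and routing_k, as implied by the attributes the router reads, and it assumes MeshConfig is in scope when meshrouter.py is imported, since the __init__ type hint references it. The _DemoMeshConfig stand-in below is hypothetical.

import torch

# Hypothetical stand-in for MeshConfig; illustrative only. The real MeshConfig
# is assumed to be defined elsewhere in this repository.
class _DemoMeshConfig:
    hidden_size = 64
    mesh_grid_size = (4, 4)   # 4 x 4 mesh -> 16 experts
    routing_k = 2             # route each token to its top-2 experts

config = _DemoMeshConfig()
router = MeshRouter(config)   # MeshRouter as defined in the diff above

x = torch.randn(2, 10, config.hidden_size)   # (batch_size, sequence_length, hidden_size)
topk_weights, topk_indices = router(x)

print(topk_weights.shape)   # torch.Size([2, 10, 2]); weights along the last dim sum to ~1
print(topk_indices.shape)   # torch.Size([2, 10, 2]); expert indices in [0, 16)

Softmax over all experts followed by top-k selection and renormalization is a common mixture-of-experts routing pattern; the 1e-6 term guards against division by zero if the selected weights underflow.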