import torch
import torch.nn as nn
import torch.nn.functional as F
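

# MLP-MoE connector: a learned gate routes each visual token to its top
# `num_selected` expert MLPs, and the experts' outputs are combined with the
# renormalized gate weights to produce features of width `channels`.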
def build_moe_connector(num_experts, num_selected):
    # Input (multimodal) feature width and output width of the connector.
    mm_hidden_size = 1024
    hidden_size = 4096
    return MLPMoE(
        num_experts=num_experts,
        num_selected=num_selected,
        mm_channels=mm_hidden_size,
        channels=hidden_size,
    )


class MLPMoE(nn.Module):
    def __init__(self, num_experts, num_selected, mm_channels, channels):
        super().__init__()
        self.num_experts = num_experts
        self.num_selected = num_selected
        self.mm_channels = mm_channels
        self.channels = channels
        # Router: produces one logit per expert for every visual token.
        self.gate = nn.Linear(mm_channels, num_experts, bias=False)
        # Experts: two-layer MLPs projecting from mm_channels to channels.
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(mm_channels, channels, bias=True),
                nn.GELU(),
                nn.Linear(channels, channels, bias=True),
            )
            for _ in range(num_experts)
        ])

    def forward(self, x_img):
        # Per-token routing scores over all experts, computed in float32 and
        # cast back to the input dtype.
        gate_logits = self.gate(x_img)
        gate_softmax = F.softmax(gate_logits, dim=-1, dtype=torch.float).to(x_img.dtype)
        # Keep the top-k experts per token and renormalize their weights.
        weights, selected_experts = torch.topk(gate_softmax, self.num_selected)
        weights = weights / torch.sum(weights, dim=-1, keepdim=True).to(x_img.dtype)
        results = torch.zeros(
            (x_img.shape[0], x_img.shape[1], self.channels),
            device=x_img.device, dtype=x_img.dtype,
        )
        # Dispatch each token only to its selected experts and accumulate the
        # weighted expert outputs.
        for b in range(x_img.shape[0]):
            for i, expert in enumerate(self.experts):
                token_idx, nth_expert = torch.where(selected_experts[b] == i)
                results[b][token_idx] += weights[b][token_idx, nth_expert, None] * expert(x_img[b][token_idx])
        return results
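

# Minimal usage sketch (illustrative, not part of the original listing):
# the batch size, token count, and expert counts below are assumptions chosen
# only to demonstrate the expected tensor shapes.
if __name__ == "__main__":
    connector = build_moe_connector(num_experts=4, num_selected=2)
    x_img = torch.randn(2, 576, 1024)   # (batch, visual tokens, mm_hidden_size)
    out = connector(x_img)
    print(out.shape)                    # torch.Size([2, 576, 4096])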