File size: 1,700 Bytes
ccb175b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
import torch.nn as nn
import torch.nn.functional as F

def build_moe_connector(num_experts, num_selected, *, mm_hidden_size=1024, hidden_size=4096):
    """Build the mixture-of-experts connector module.

    Args:
        num_experts: Number of expert MLPs in the connector.
        num_selected: Number of experts each token is routed to.
        mm_hidden_size: Width of the incoming features (passed to
            ``MLPMoE`` as ``mm_channels``). Defaults to 1024, the value
            that was previously hard-coded.
        hidden_size: Width of the produced features (passed to ``MLPMoE``
            as ``channels``). Defaults to 4096, the value that was
            previously hard-coded.

    Returns:
        A freshly constructed ``MLPMoE`` instance.
    """
    return MLPMoE(
        num_experts=num_experts,
        num_selected=num_selected,
        mm_channels=mm_hidden_size,
        channels=hidden_size,
    )


class MLPMoE(nn.Module):
    """Sparse mixture-of-experts MLP with top-k token routing.

    A bias-free linear gate scores every token against ``num_experts``
    experts. Each token is sent to its ``num_selected`` highest-scoring
    experts and the expert outputs are summed, weighted by the gate
    probabilities renormalised over the selected experts. Every expert is
    ``Linear(mm_channels, channels) -> GELU -> Linear(channels, channels)``.

    Args:
        num_experts: Number of expert MLPs.
        num_selected: Number of experts each token is routed to.
        mm_channels: Size of the last dimension of the input.
        channels: Size of the last dimension of the output.

    Raises:
        ValueError: If ``num_selected`` is larger than ``num_experts``
            (``torch.topk`` could never satisfy such a request).
    """

    def __init__(self, num_experts, num_selected, mm_channels, channels):
        super().__init__()
        # Fail at construction time instead of deep inside forward().
        if num_selected > num_experts:
            raise ValueError(
                f"num_selected ({num_selected}) must not exceed "
                f"num_experts ({num_experts})"
            )

        self.num_experts = num_experts
        self.num_selected = num_selected
        self.mm_channels = mm_channels
        self.channels = channels

        self.gate = nn.Linear(mm_channels, num_experts, bias=False)
        self.experts = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(mm_channels, channels, bias=True),
                    nn.GELU(),
                    nn.Linear(channels, channels, bias=True),
                )
                for _ in range(num_experts)
            ]
        )

    def forward(self, x_img: torch.Tensor) -> torch.Tensor:
        """Route each token through its top-k experts and mix the outputs.

        Args:
            x_img: Tensor whose last dimension has size ``mm_channels``.
                Any number of leading dimensions is accepted; they are
                flattened for routing and restored on the output.

        Returns:
            Tensor with the same leading dimensions, device and dtype as
            ``x_img`` and a last dimension of size ``channels``.
        """
        lead_shape = x_img.shape[:-1]
        # Routing is independent per token, so collapse all leading
        # dimensions and iterate over experts only.
        tokens = x_img.reshape(-1, x_img.shape[-1])

        gate_logits = self.gate(tokens)
        # Softmax is evaluated in float32, then cast back to the input dtype.
        gate_softmax = F.softmax(gate_logits, dim=-1, dtype=torch.float).to(x_img.dtype)

        weights, selected_experts = torch.topk(gate_softmax, self.num_selected)
        # Renormalise so the selected experts' weights sum to one per token.
        weights = weights / torch.sum(weights, dim=-1, keepdim=True).to(x_img.dtype)

        # Allocate directly with the input's device and dtype.
        results = x_img.new_zeros((tokens.shape[0], self.channels))
        for i, expert in enumerate(self.experts):
            # token_idx: rows routed to expert i; nth_expert: which of the
            # token's top-k slots chose it (selects the matching weight).
            token_idx, nth_expert = torch.where(selected_experts == i)
            # An expert with no tokens is still called on an empty batch,
            # which keeps its parameters in the autograd graph.
            results[token_idx] += weights[token_idx, nth_expert, None] * expert(tokens[token_idx])

        return results.reshape(*lead_shape, self.channels)