import torch
import torch.nn as nn
import torch.nn.functional as F

class AveragePooling(nn.Module):
    """Average-pools a flattened square grid of image features."""

    def __init__(self, pooling_size=2, device='cpu'):
        super(AveragePooling, self).__init__()
        self.pooling_size = pooling_size
        self.device = device
        self.to(device)

    def forward(self, image_features):
        # image_features: (batch_size, num_features, dim), where num_features
        # is assumed to be a perfect square (a flattened H x W patch grid).
        batch_size, num_features, dim = image_features.size()
        height = width = int(num_features ** 0.5)
        image_features = image_features.view(batch_size, height, width, dim)
        # avg_pool2d expects channels first, so permute, pool, then permute back.
        pooled_features = F.avg_pool2d(image_features.permute(0, 3, 1, 2), kernel_size=self.pooling_size)
        pooled_features = pooled_features.permute(0, 2, 3, 1)
        # Flatten the pooled grid back to (batch_size, num_pooled_features, dim).
        pooled_features = pooled_features.view(batch_size, -1, dim)
        return pooled_features

class AttentionPooling(nn.Module):
    """Pools each local region of the feature grid with learned attention weights."""

    def __init__(self, input_dim, pooling_size=2, device='cpu', dtype=torch.float32):
        super(AttentionPooling, self).__init__()
        self.pooling_size = pooling_size
        self.device = device
        # Small MLP that scores each token inside a pooling region.
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, input_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(input_dim, 1))
        # self.mlp.to(device, dtype)

    def forward(self, x):
        # x: (batch_size, n, dim), with n a perfect square divisible by pooling_size.
        batch_size, n, dim = x.shape
        sqrt_n = int(n ** 0.5)
        pooling_size = self.pooling_size

        x = x.view(batch_size, sqrt_n, sqrt_n, dim)

        pooled_features = []
        for i in range(0, sqrt_n, pooling_size):
            for j in range(0, sqrt_n, pooling_size):
                # Gather one pooling_size x pooling_size region of tokens.
                region = x[:, i:i+pooling_size, j:j+pooling_size, :]
                region = region.reshape(batch_size, -1, dim)
                # Score each token, normalize within the region, and take the
                # attention-weighted sum.
                alpha = self.mlp(region)
                alpha = torch.softmax(alpha, dim=1)
                region_pooled = torch.sum(alpha * region, dim=1)
                pooled_features.append(region_pooled)
        # Output: (batch_size, num_regions, dim), regions ordered row-major.
        output = torch.stack(pooled_features, dim=1)

        return output

def build_pooling(pooling_type, input_dim=None, pooling_size=2, device='cpu', dtype=torch.float32):
    """Factory that returns the requested pooling module ('average' or 'attention')."""
    if pooling_type == 'average':
        return AveragePooling(pooling_size=pooling_size, device=device)
    elif pooling_type == 'attention':
        if input_dim is None:
            raise ValueError("input_dim must be specified for attention pooling")
        return AttentionPooling(input_dim=input_dim, pooling_size=pooling_size, device=device, dtype=dtype)
    else:
        raise ValueError("Unknown pooling type: {}".format(pooling_type))