dhruv2842 committed
Commit d10a3b5 · verified · 1 Parent(s): 44cd570

Upload 3 files

Files changed (3)
  1. glam_efficientnet_model.py +106 -0
  2. glam_module.py +71 -0
  3. swin_module.py +72 -0
glam_efficientnet_model.py ADDED
@@ -0,0 +1,106 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig, EfficientNetModel

# --------------------------------------------------
# Import your GLAM, SwinWindowAttention blocks here
# --------------------------------------------------
# from .glam_module import GLAM
# from .swin_module import SwinWindowAttention

from glam_module import GLAM
from swin_module import SwinWindowAttention


class GLAMEfficientNetConfig(PretrainedConfig):
    """Hugging Face-style configuration for GLAM EfficientNet."""
    model_type = "glam_efficientnet"

    def __init__(self,
                 num_classes: int = 3,
                 embed_dim: int = 512,
                 num_heads: int = 8,
                 window_size: int = 7,
                 reduction_ratio: int = 8,
                 dropout: float = 0.5,
                 **kwargs):
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.reduction_ratio = reduction_ratio
        self.dropout = dropout


class GLAMEfficientNetForClassification(PreTrainedModel):
    """Hugging Face-style model for the EfficientNet + GLAM + Swin architecture."""
    config_class = GLAMEfficientNetConfig

    def __init__(self, config: GLAMEfficientNetConfig, glam_module_cls, swin_module_cls):
        super().__init__(config)

        # 1) Hugging Face EfficientNet backbone
        self.features = EfficientNetModel.from_pretrained("google/efficientnet-b0")

        # 1x1 conv to project the 1280-channel backbone output to embed_dim
        self.conv1x1 = nn.Conv2d(1280, config.embed_dim, kernel_size=1)

        # 2) Swin attention block
        self.swin_attn = swin_module_cls(
            embed_dim=config.embed_dim,
            window_size=config.window_size,
            num_heads=config.num_heads,
            dropout=config.dropout
        )
        self.pre_attn_norm = nn.LayerNorm(config.embed_dim)
        self.post_attn_norm = nn.LayerNorm(config.embed_dim)

        # 3) GLAM block
        self.glam = glam_module_cls(in_channels=config.embed_dim, reduction_ratio=config.reduction_ratio)

        # 4) Self-adaptive gating
        self.gate_fc = nn.Linear(config.embed_dim, 1)

        # Final classification head
        self.dropout = nn.Dropout(config.dropout)
        self.classifier = nn.Linear(config.embed_dim, config.num_classes)

    def forward(self, pixel_values, labels=None, **kwargs):
        """Perform the forward pass."""
        # 1) EfficientNet backbone
        backbone_output = self.features(pixel_values)    # returns a BaseModelOutput
        feats = backbone_output.last_hidden_state        # [B, C, H', W']
        feats = self.conv1x1(feats)                      # adjust channel dims to embed_dim
        B, C, H, W = feats.shape

        # 2) Transformer branch
        x_perm = feats.permute(0, 2, 3, 1).contiguous()  # [B, H', W', C]
        x_norm = self.pre_attn_norm(x_perm).permute(0, 3, 1, 2).contiguous()
        x_norm = self.dropout(x_norm)

        T_out = self.swin_attn(x_norm)                   # [B, C, H', W']

        T_out = self.post_attn_norm(T_out.permute(0, 2, 3, 1).contiguous())
        T_out = T_out.permute(0, 3, 1, 2).contiguous()

        # 3) GLAM branch
        G_out = self.glam(feats)

        # 4) Self-adaptive gating between the two branches
        gap_feats = F.adaptive_avg_pool2d(feats, (1, 1)).view(B, C)
        g = torch.sigmoid(self.gate_fc(gap_feats)).view(B, 1, 1, 1)

        F_out = g * T_out + (1 - g) * G_out

        # Final pooling and classifier
        pooled = F.adaptive_avg_pool2d(F_out, (1, 1)).view(B, -1)
        logits = self.classifier(self.dropout(pooled))

        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits, labels)

        return {"loss": loss, "logits": logits}
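As a quick sanity check, the model can be instantiated with the two attention modules and run on dummy input. The following is a minimal sketch, assuming glam_module.py and swin_module.py sit next to glam_efficientnet_model.py on the Python path and that the google/efficientnet-b0 weights can be downloaded; the batch size of 2 and 224x224 input size are arbitrary choices for illustration.

    import torch
    from glam_module import GLAM
    from swin_module import SwinWindowAttention
    from glam_efficientnet_model import GLAMEfficientNetConfig, GLAMEfficientNetForClassification

    # build the config and inject the GLAM and Swin classes into the model
    config = GLAMEfficientNetConfig(num_classes=3)
    model = GLAMEfficientNetForClassification(config, GLAM, SwinWindowAttention)
    model.eval()

    dummy = torch.randn(2, 3, 224, 224)        # two RGB images, 224x224 (illustrative size)
    with torch.no_grad():
        out = model(pixel_values=dummy)
    print(out["logits"].shape)                 # expected: torch.Size([2, 3])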
glam_module.py ADDED
@@ -0,0 +1,71 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class GLAM(nn.Module):
    """
    Global-Local Attention Module (GLAM) that produces a refined feature map.
    """
    def __init__(self, in_channels, reduction_ratio=8):
        super(GLAM, self).__init__()

        # --- Local Channel Attention ---
        self.local_channel_conv = nn.Conv2d(in_channels, in_channels // reduction_ratio, kernel_size=1)
        self.local_channel_act = nn.Sigmoid()
        self.local_channel_expand = nn.Conv2d(in_channels // reduction_ratio, in_channels, kernel_size=1)

        # --- Local Spatial Attention ---
        # 3-dilated and 5-dilated convs, merged with the input by a 1x1 conv
        self.local_spatial_conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=3, dilation=3)
        self.local_spatial_conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=5, dilation=5)
        self.local_spatial_merge = nn.Conv2d(in_channels * 3, in_channels, kernel_size=1)
        self.local_spatial_act = nn.Sigmoid()

        # --- Global Channel Attention ---
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.global_channel_fc1 = nn.Linear(in_channels, in_channels // reduction_ratio)
        self.global_channel_fc2 = nn.Linear(in_channels // reduction_ratio, in_channels)
        self.global_channel_act = nn.Sigmoid()

        # --- Global Spatial Attention ---
        self.global_spatial_conv = nn.Conv2d(in_channels, 1, kernel_size=1)
        self.global_spatial_softmax = nn.Softmax(dim=-1)

        # --- Weighted parameters for fusing the local and global paths ---
        self.local_attention_weight = nn.Parameter(torch.tensor(1.0))
        self.global_attention_weight = nn.Parameter(torch.tensor(1.0))

    def forward(self, x):
        # Local channel attention
        lca = self.local_channel_conv(x)
        lca = self.local_channel_act(lca)
        lca = self.local_channel_expand(lca)
        lca_out = lca * x

        # Local spatial attention
        lsa3 = self.local_spatial_conv3(x)
        lsa5 = self.local_spatial_conv5(x)
        lsa_cat = torch.cat([x, lsa3, lsa5], dim=1)
        lsa = self.local_spatial_merge(lsa_cat)
        lsa = self.local_spatial_act(lsa)
        lsa_out = lsa * lca_out
        lsa_out = lsa_out + lca_out

        # Global channel attention
        B, C, H, W = x.size()
        gca = self.global_avg_pool(x).view(B, C)
        gca = F.relu(self.global_channel_fc1(gca), inplace=True)
        gca = self.global_channel_fc2(gca)
        gca = self.global_channel_act(gca)
        gca = gca.view(B, C, 1, 1)
        gca_out = gca * x

        # Global spatial attention
        gsa = self.global_spatial_conv(x)   # [B, 1, H, W]
        gsa = gsa.view(B, -1)               # [B, H*W]
        gsa = self.global_spatial_softmax(gsa)
        gsa = gsa.view(B, 1, H, W)
        gsa_out = gsa * gca_out
        gsa_out = gsa_out + gca_out

        # Fuse the local and global paths with the residual input
        out = lsa_out * self.local_attention_weight + gsa_out * self.global_attention_weight + x
        return out
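GLAM is shape-preserving: every attention path ends in an element-wise product with (and residual addition to) a tensor of the input's shape. A minimal sketch of that property, assuming an illustrative 512-channel 7x7 feature map:

    import torch
    from glam_module import GLAM

    glam = GLAM(in_channels=512, reduction_ratio=8)
    feats = torch.randn(2, 512, 7, 7)   # [B, C, H, W], e.g. backbone features after the 1x1 conv
    refined = glam(feats)
    print(refined.shape)                # torch.Size([2, 512, 7, 7]) -- same shape as the input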
swin_module.py ADDED
@@ -0,0 +1,72 @@
import torch.nn as nn
import torch.nn.functional as F


# -------------------------------
# SWIN-STYLE TRANSFORMER UTILS
# -------------------------------
def window_partition(x, window_size):
    """
    x: (B, H, W, C)
    Returns windows of shape: (num_windows*B, window_size*window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # permute to gather patches
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    # merge dimensions
    windows = x.view(-1, window_size * window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Reverse of window_partition.
    windows: (num_windows*B, window_size*window_size, C)
    Returns: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    x = x.view(B, H, W, -1)
    return x


class SwinWindowAttention(nn.Module):
    """
    A simplified Swin-like window attention block:
      1) Partition the input into windows
      2) Perform multi-head self-attention within each window
      3) Merge the windows back
    """
    def __init__(self, embed_dim, window_size, num_heads, dropout=0.0):
        super(SwinWindowAttention, self).__init__()
        self.embed_dim = embed_dim
        self.window_size = window_size
        self.num_heads = num_heads

        self.mha = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: (B, C, H, W) --> rearrange to (B, H, W, C)
        B, C, H, W = x.shape
        x = x.permute(0, 2, 3, 1).contiguous()

        # pad if needed so H and W are multiples of window_size
        pad_h = (self.window_size - H % self.window_size) % self.window_size
        pad_w = (self.window_size - W % self.window_size) % self.window_size
        if pad_h or pad_w:
            x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))

        Hp, Wp = x.shape[1], x.shape[2]
        # Partition into windows
        windows = window_partition(x, self.window_size)  # (num_windows*B, window_size*window_size, C)
        # Multi-head self-attention within each window
        attn_windows, _ = self.mha(windows, windows, windows)
        attn_windows = self.dropout(attn_windows)

        # Reverse the window partition
        x = window_reverse(attn_windows, self.window_size, Hp, Wp)

        # Remove padding if it was added
        if pad_h or pad_w:
            x = x[:, :H, :W, :].contiguous()

        # back to (B, C, H, W)
        x = x.permute(0, 3, 1, 2).contiguous()
        return x
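Two properties are worth checking here: window_partition and window_reverse are exact inverses of each other, and SwinWindowAttention preserves the [B, C, H, W] shape even when H and W are not multiples of window_size, thanks to the internal padding. A minimal sketch with illustrative sizes:

    import torch
    from swin_module import SwinWindowAttention, window_partition, window_reverse

    # round-trip: partition a 14x14 map into 7x7 windows and merge back
    x = torch.randn(2, 14, 14, 512)                      # [B, H, W, C]
    windows = window_partition(x, window_size=7)         # [2*4, 49, 512]
    assert torch.equal(window_reverse(windows, 7, 14, 14), x)

    # shape preservation with padding: 10 is not a multiple of 7
    attn = SwinWindowAttention(embed_dim=512, window_size=7, num_heads=8, dropout=0.1)
    y = attn(torch.randn(2, 512, 10, 10))
    print(y.shape)                                       # torch.Size([2, 512, 10, 10])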