ZubairAhmed777 committed on
Commit 2b65df9 · verified · 1 Parent(s): 2d20c0d

Update model.py

Files changed (1): model.py (+81, -0)
model.py CHANGED
@@ -42,6 +42,27 @@ class ImageEncoder(nn.Module):
 
         return l2_norm
 
+class ImageEncoder_attn(nn.Module):
+    def __init__(self, embed_dim):
+        super(ImageEncoder_attn, self).__init__()
+        # Load the convolutional feature extractor of a pretrained VGG19
+        self.model = models.vgg19(pretrained=True).features
+        # 1x1 convolution to project the 512-channel VGG features to the desired embedding dimension
+        self.conv = nn.Conv2d(512, embed_dim, kernel_size=1)
+
+    def forward(self, image):
+        # Extract spatial features of the image with the pretrained VGG19 backbone
+        with torch.no_grad():  # no gradients are tracked, so the backbone stays frozen in this pass
+            img_features = self.model(image)  # Shape: (batch_size, 512, H, W)
+
+        # Project the features to the desired embedding dimension
+        img_features = self.conv(img_features)  # Shape: (batch_size, embed_dim, H, W)
+
+        # Flatten the spatial dimensions to get one feature vector per region
+        img_features = img_features.flatten(2).permute(0, 2, 1)  # Shape: (batch_size, num_regions, embed_dim)
+
+        return img_features
+
 class QuesEncoder(nn.Module):
     def __init__(self, ques_vocab_size, word_embed, hidden_size, num_hidden, qu_feature_size):
         super(QuesEncoder, self).__init__()
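
A quick shape check for the new ImageEncoder_attn (a minimal sketch, not from the commit: it assumes the standard 224x224 ImageNet input size, for which VGG19's conv stack downsamples to a 7x7 grid, i.e. num_regions = 49, and the "from model import ..." path is hypothetical):

import torch
from model import ImageEncoder_attn  # hypothetical import path

encoder = ImageEncoder_attn(embed_dim=1024)
images = torch.randn(2, 3, 224, 224)   # dummy batch: (batch_size, C, H, W)
regions = encoder(images)
print(regions.shape)                   # torch.Size([2, 49, 1024])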
 
@@ -137,3 +158,63 @@ class VQAModel(nn.Module):
         logits = self.fc2(combined_feature)
 
         return logits
+
+class VQAModel_attn(nn.Module):
+    def __init__(self, feature_size, ques_vocab_size, ans_vocab_size, word_embed, hidden_size, num_hidden):
+        super(VQAModel_attn, self).__init__()
+
+        # Encoder to extract per-region image features
+        self.img_encoder = ImageEncoder_attn(feature_size)
+
+        # Encoder to extract question features
+        self.ques_encoder = QuesEncoder(ques_vocab_size, word_embed, hidden_size, num_hidden, feature_size)
+
+        # Attention layer: scores how compatible each image region is with the question
+        self.attention_fc = nn.Linear(2 * feature_size, 1)
+
+        # Dropout layer
+        self.dropout = nn.Dropout(0.5)
+
+        # Fully connected layers for answer prediction
+        self.fc1 = nn.Linear(feature_size, ans_vocab_size)
+        self.fc2 = nn.Linear(ans_vocab_size, ans_vocab_size)
+
+    def forward(self, image, question):
+        # Extract image features: (batch_size, num_regions, feature_size)
+        img_features = self.img_encoder(image)
+
+        # Extract question features: (batch_size, feature_size)
+        qst_feature = self.ques_encoder(question)
+
+        # Expand the question feature to (batch_size, 1, feature_size), then
+        # repeat it across regions so it can be paired with every image region
+        qst_feature_exp = qst_feature.unsqueeze(1).expand(-1, img_features.size(1), -1)
+
+        # At this point img_features is (batch_size, num_regions, feature_size),
+        # qst_feature is (batch_size, feature_size), and
+        # qst_feature_exp is (batch_size, num_regions, feature_size)
+        # Concatenate image and question features along the last dimension
+        # Shape: (batch_size, num_regions, 2 * feature_size)
+        combined_features = torch.cat([img_features, qst_feature_exp], dim=-1)
+
+        # Compute an attention score for each region
+        # Shape: (batch_size, num_regions, 1)
+        attention_scores = self.attention_fc(combined_features)
+
+        # Softmax over regions turns the scores into attention weights
+        # Shape: (batch_size, num_regions)
+        attention_weights = F.softmax(attention_scores.squeeze(-1), dim=1)
+
+        # Weighted sum of the region features
+        # Shape: (batch_size, feature_size)
+        attended_img_feature = torch.sum(img_features * attention_weights.unsqueeze(-1), dim=1)
+
+        # Fuse the attended image feature with the question feature
+        combined_feature = attended_img_feature + qst_feature
+
+        # Dropout and fully connected layers for answer prediction
+        combined_feature = self.dropout(combined_feature)
+        combined_feature = F.relu(self.fc1(combined_feature))
+        logits = self.fc2(combined_feature)
+
+        return logits
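
A minimal end-to-end sketch of how VQAModel_attn could be exercised (the hyperparameters are illustrative, the import path is hypothetical, and it assumes QuesEncoder consumes a batch of token-index sequences of shape (batch_size, seq_len), as is typical for an embedding-plus-LSTM question encoder):

import torch
from model import VQAModel_attn  # hypothetical import path

model = VQAModel_attn(feature_size=1024, ques_vocab_size=10000,
                      ans_vocab_size=1000, word_embed=300,
                      hidden_size=512, num_hidden=2)
model.eval()

images = torch.randn(2, 3, 224, 224)          # dummy image batch
questions = torch.randint(0, 10000, (2, 14))  # dummy token indices (assumed input format)
with torch.no_grad():
    logits = model(images, questions)
print(logits.shape)                           # torch.Size([2, 1000])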