Update model.py
model.py CHANGED
@@ -42,6 +42,27 @@ class ImageEncoder(nn.Module):
 
         return l2_norm
 
+class ImageEncoder_attn(nn.Module):
+    def __init__(self, embed_dim):
+        super(ImageEncoder_attn, self).__init__()
+        # Load a pretrained VGG19 model
+        self.model = models.vgg19(pretrained=True).features
+        # Adding a 1x1 convolutional layer to map features to the desired embedding dimension
+        self.conv = nn.Conv2d(512, embed_dim, kernel_size=1)
+
+    def forward(self, image):
+        # Extracting spatial features of the image using the modified VGG19 model
+        with torch.no_grad():  # Freezing the weights of the pretrained model during this pass
+            img_features = self.model(image)  # Shape: (batch_size, 512, H, W)
+
+        # Map features to the desired embedding dimension
+        img_features = self.conv(img_features)  # Shape: (batch_size, embed_dim, H, W)
+
+        # Flatten spatial dimensions to get per-region features
+        img_features = img_features.flatten(2).permute(0, 2, 1)  # Shape: (batch_size, num_regions, embed_dim)
+
+        return img_features
+
 class QuesEncoder(nn.Module):
     def __init__(self, ques_vocab_size, word_embed, hidden_size, num_hidden, qu_feature_size):
         super(QuesEncoder, self).__init__()
@@ -137,3 +158,63 @@ class VQAModel(nn.Module):
         logits = self.fc2(combined_feature)
 
         return logits
+
+class VQAModel_attn(nn.Module):
+    def __init__(self, feature_size, ques_vocab_size, ans_vocab_size, word_embed, hidden_size, num_hidden):
+        super(VQAModel_attn, self).__init__()
+
+        # Encoder to extract image features
+        self.img_encoder = ImageEncoder_attn(feature_size)
+
+        # Encoder to extract question features
+        self.ques_encoder = QuesEncoder(ques_vocab_size, word_embed, hidden_size, num_hidden, feature_size)
+
+        # Attention mechanism layers
+        self.attention_fc = nn.Linear(2 * feature_size, 1)  # For compatibility scoring
+
+        # Dropout layer
+        self.dropout = nn.Dropout(0.5)
+
+        # Fully connected layers for answer prediction
+        self.fc1 = nn.Linear(feature_size, ans_vocab_size)
+        self.fc2 = nn.Linear(ans_vocab_size, ans_vocab_size)
+
+    def forward(self, image, question):
+        # Extract image features (batch_size, num_regions, feature_size)
+        img_features = self.img_encoder(image)
+
+        # Extract question features (batch_size, feature_size)
+        qst_feature = self.ques_encoder(question)
+
+        # Ensure qst_feature has the correct dimensions
+        # Expand to (batch_size, 1, feature_size), then repeat to match num_regions
+        qst_feature_exp = qst_feature.unsqueeze(1).expand(-1, img_features.size(1), -1)
+
+        #print(f"img_features shape: {img_features.shape}")
+        #print(f"qst_feature shape: {qst_feature.shape}")
+        #print(f"qst_feature_exp shape: {qst_feature_exp.shape}")
+        # Concatenate image and question features along the last dimension
+        # Shape: (batch_size, num_regions, 2 * feature_size)
+        combined_features = torch.cat([img_features, qst_feature_exp], dim=-1)
+
+        # Compute attention scores for each region
+        # Shape: (batch_size, num_regions, 1)
+        attention_scores = self.attention_fc(combined_features)
+
+        # Apply softmax to get attention weights
+        # Shape: (batch_size, num_regions)
+        attention_weights = F.softmax(attention_scores.squeeze(-1), dim=1)
+
+        # Compute the weighted sum of image features
+        # Shape: (batch_size, feature_size)
+        attended_img_feature = torch.sum(img_features * attention_weights.unsqueeze(-1), dim=1)
+
+        # Combine attended image features with question features
+        combined_feature = attended_img_feature + qst_feature
+
+        # Dropout and fully connected layers for answer prediction
+        combined_feature = self.dropout(combined_feature)
+        combined_feature = F.relu(self.fc1(combined_feature))
+        logits = self.fc2(combined_feature)
+
+        return logits
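
The new classes wire a simple concatenation-based attention over VGG19 feature-map regions: each region vector is concatenated with the question vector, scored by a single linear layer, softmax-normalized across regions, and used to form a weighted sum that is added back to the question feature before the classifier. As a quick smoke test of this path, the sketch below runs VQAModel_attn on dummy tensors and checks the output shape. It is a minimal sketch, not part of the commit: all sizes (feature_size=1024, vocab sizes, 224x224 images, 14-token questions) are assumed values, and it assumes QuesEncoder takes a LongTensor of token ids and returns a (batch_size, feature_size) vector, as the existing VQAModel already expects.

    # Hypothetical smoke test for the attention-based model (illustrative values only)
    import torch
    from model import VQAModel_attn  # assumes this file is importable as model.py

    model = VQAModel_attn(
        feature_size=1024,      # assumed shared embedding size for image regions and question
        ques_vocab_size=10000,  # assumed question vocabulary size
        ans_vocab_size=1000,    # assumed number of answer classes
        word_embed=300,         # assumed word-embedding dimension
        hidden_size=512,        # assumed LSTM hidden size
        num_hidden=2,           # assumed number of LSTM layers
    )
    model.eval()

    images = torch.randn(2, 3, 224, 224)          # VGG19 features give a 7x7 grid here -> 49 regions
    questions = torch.randint(0, 10000, (2, 14))  # batch of 2 questions, 14 token ids each (length assumed)

    with torch.no_grad():
        logits = model(images, questions)

    print(logits.shape)  # expected: torch.Size([2, 1000]), one score per answer class

Note that because ImageEncoder_attn wraps its VGG19 pass in torch.no_grad(), only the 1x1 conv, the question encoder, the attention scorer, and the classifier heads receive gradients during training.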