Commit 3ab629c by amaiasalvador · Parent(s): 9f2a462

transformer no longer returns unnecessary attention weights. fix: allow backward when training ingredient decoder

Files changed:
- src/model.py (+2 -2)
- src/modules/transformer_decoder.py (+8 -11)
src/model.py CHANGED

@@ -211,7 +211,7 @@ class InverseCookingModel(nn.Module):
             ingr_ids[sample_mask == 0] = self.pad_value
 
             outputs['ingr_ids'] = ingr_ids
-            outputs['ingr_probs'] = ingr_probs
+            outputs['ingr_probs'] = ingr_probs.data
 
             mask = sample_mask
             input_mask = mask.float().unsqueeze(1)
@@ -230,7 +230,7 @@ class InverseCookingModel(nn.Module):
         ids, probs = self.recipe_decoder.sample(input_feats, input_mask, greedy, temperature, beam, img_features, 0,
                                                 last_token_value=1)
 
-        outputs['recipe_probs'] = probs
+        outputs['recipe_probs'] = probs.data
         outputs['recipe_ids'] = ids
 
         return outputs
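Both assignments now store `.data` views of the sampled probabilities, i.e. tensors detached from the autograd graph, which presumably is what "fix: allow backward when training ingredient decoder" refers to: the sampling outputs no longer hold references into the graph. A minimal sketch of the tensor behavior this relies on, separate from the repository code and with hypothetical names:

import torch

# .data gives a view of the same values that is detached from autograd,
# so storing it in an output dict does not keep the sampling graph alive.
probs = torch.softmax(torch.randn(2, 5, requires_grad=True), dim=-1)
detached = probs.data

print(probs.requires_grad)     # True
print(detached.requires_grad)  # False

# detach() produces an equivalent result and is the usual modern spelling.
assert torch.equal(detached, probs.detach())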
src/modules/transformer_decoder.py CHANGED

@@ -161,12 +161,11 @@ class TransformerDecoderLayer(nn.Module):
         self.last_ln = LayerNorm(self.embed_dim)
 
     def forward(self, x, ingr_features, ingr_mask, incremental_state, img_features):
-        attn_dict = dict()
 
         # self attention
         residual = x
         x = self.maybe_layer_norm(0, x, before=True)
-        x,
+        x, _ = self.self_attn(
             query=x,
             key=x,
             value=x,
@@ -184,7 +183,7 @@ class TransformerDecoderLayer(nn.Module):
         # attention
         if ingr_features is None:
 
-            x,
+            x, _ = self.cond_att(query=x,
                                  key=img_features,
                                  value=img_features,
                                  key_padding_mask=None,
@@ -192,7 +191,7 @@ class TransformerDecoderLayer(nn.Module):
                                  static_kv=True,
                                  )
         elif img_features is None:
-            x,
+            x, _ = self.cond_att(query=x,
                                  key=ingr_features,
                                  value=ingr_features,
                                  key_padding_mask=ingr_mask,
@@ -206,7 +205,7 @@ class TransformerDecoderLayer(nn.Module):
             kv = torch.cat((img_features, ingr_features), 0)
             mask = torch.cat((torch.zeros(img_features.shape[1], img_features.shape[0], dtype=torch.uint8).to(device),
                               ingr_mask), 1)
-            x,
+            x, _ = self.cond_att(query=x,
                                  key=kv,
                                  value=kv,
                                  key_padding_mask=mask,
@@ -229,7 +228,7 @@ class TransformerDecoderLayer(nn.Module):
         if self.use_last_ln:
             x = self.last_ln(x)
 
-        return x
+        return x
 
     def maybe_layer_norm(self, i, x, before=False, after=False):
         assert before ^ after
@@ -308,16 +307,14 @@ class DecoderTransformer(nn.Module):
         x = x.transpose(0, 1)
 
         for p, layer in enumerate(self.layers):
-            x
+            x = layer(
                 x,
                 ingr_features,
                 ingr_mask,
                 incremental_state,
                 img_features
             )
-
-            attn_dict[key][p] = attn[key]
-            #attn_layers.append(attn)
+
         # T x B x C -> B x T x C
         x = x.transpose(0, 1)
 
@@ -387,7 +384,7 @@ class DecoderTransformer(nn.Module):
             sampled_ids.append(predicted)
 
         sampled_ids = torch.stack(sampled_ids[1:], 1)
-        logits = torch.stack(logits, 1)
+        logits = torch.stack(logits, 1)
 
         return sampled_ids, logits
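After this change TransformerDecoderLayer.forward returns only the hidden states, and DecoderTransformer simply reassigns them in its layer loop with the attn_dict bookkeeping gone. A minimal sketch of that calling pattern with a toy stand-in layer (hypothetical module, not the repository's decoder):

import torch
import torch.nn as nn

class ToyLayer(nn.Module):
    # Stand-in for a decoder layer that, like the patched one, returns only x.
    def __init__(self, dim=8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x)  # single return value, no attention weights

layers = nn.ModuleList([ToyLayer() for _ in range(3)])
x = torch.randn(5, 2, 8)      # T x B x C, as in the decoder

for p, layer in enumerate(layers):
    x = layer(x)              # mirrors `x = layer(...)` in the diff

x = x.transpose(0, 1)         # T x B x C -> B x T x C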