amaiasalvador committed
Commit 3ab629c · Parent: 9f2a462

transformer no longer returns unnecessary attention weights. fix: allow backward when training ingredient decoder

Files changed (2)
  1. src/model.py +2 -2
  2. src/modules/transformer_decoder.py +8 -11
src/model.py CHANGED

@@ -211,7 +211,7 @@ class InverseCookingModel(nn.Module):
         ingr_ids[sample_mask == 0] = self.pad_value

         outputs['ingr_ids'] = ingr_ids
-        outputs['ingr_probs'] = ingr_probs
+        outputs['ingr_probs'] = ingr_probs.data

         mask = sample_mask
         input_mask = mask.float().unsqueeze(1)
@@ -230,7 +230,7 @@ class InverseCookingModel(nn.Module):
         ids, probs = self.recipe_decoder.sample(input_feats, input_mask, greedy, temperature, beam, img_features, 0,
                                                 last_token_value=1)

-        outputs['recipe_probs'] = probs
+        outputs['recipe_probs'] = probs.data
         outputs['recipe_ids'] = ids

         return outputs
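Together with the `logits` change in `transformer_decoder.py` below, this moves the autograd detach out of the decoder's `sample()` and into the stored inference outputs: the decoder keeps returning differentiable logits so the ingredient-decoder loss can call backward, while the copies kept in `outputs` hold no reference to the graph. A minimal sketch of the distinction, assuming standard PyTorch autograd semantics (variable names are illustrative, not from the repo):

```python
import torch

# Illustrative names; not code from the repo.
w = torch.randn(4, requires_grad=True)
logits = torch.log_softmax(w, dim=0)    # attached to the graph, like the decoder's logits

detached = logits.data                  # a view with no autograd history
try:
    detached.sum().backward()           # the old sample() returned this: training breaks
except RuntimeError as err:
    print(err)                          # "... does not require grad and does not have a grad_fn"

loss = -logits.sum()                    # the new sample() keeps logits attached,
loss.backward()                         # so a loss on them can backpropagate;
print(logits.data.requires_grad)        # model.py detaches only the stored copy: False
```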
src/modules/transformer_decoder.py CHANGED

@@ -161,12 +161,11 @@ class TransformerDecoderLayer(nn.Module):
         self.last_ln = LayerNorm(self.embed_dim)

     def forward(self, x, ingr_features, ingr_mask, incremental_state, img_features):
-        attn_dict = dict()

         # self attention
         residual = x
         x = self.maybe_layer_norm(0, x, before=True)
-        x, attn_selfatt = self.self_attn(
+        x, _ = self.self_attn(
             query=x,
             key=x,
             value=x,
@@ -184,7 +183,7 @@ class TransformerDecoderLayer(nn.Module):
         # attention
         if ingr_features is None:

-            x, attn_condatt = self.cond_att(query=x,
+            x, _ = self.cond_att(query=x,
                                  key=img_features,
                                  value=img_features,
                                  key_padding_mask=None,
@@ -192,7 +191,7 @@ class TransformerDecoderLayer(nn.Module):
                                  static_kv=True,
                                  )
         elif img_features is None:
-            x, attn_condatt = self.cond_att(query=x,
+            x, _ = self.cond_att(query=x,
                                  key=ingr_features,
                                  value=ingr_features,
                                  key_padding_mask=ingr_mask,
@@ -206,7 +205,7 @@ class TransformerDecoderLayer(nn.Module):
             kv = torch.cat((img_features, ingr_features), 0)
             mask = torch.cat((torch.zeros(img_features.shape[1], img_features.shape[0], dtype=torch.uint8).to(device),
                               ingr_mask), 1)
-            x, attn_condatt = self.cond_att(query=x,
+            x, _ = self.cond_att(query=x,
                                  key=kv,
                                  value=kv,
                                  key_padding_mask=mask,
@@ -229,7 +228,7 @@ class TransformerDecoderLayer(nn.Module):
         if self.use_last_ln:
             x = self.last_ln(x)

-        return x, attn_dict
+        return x

     def maybe_layer_norm(self, i, x, before=False, after=False):
         assert before ^ after
@@ -308,16 +307,14 @@ class DecoderTransformer(nn.Module):
         x = x.transpose(0, 1)

         for p, layer in enumerate(self.layers):
-            x, attn = layer(
+            x = layer(
                 x,
                 ingr_features,
                 ingr_mask,
                 incremental_state,
                 img_features
             )
-            for key in attn.keys():
-                attn_dict[key][p] = attn[key]
-            #attn_layers.append(attn)
+
         # T x B x C -> B x T x C
         x = x.transpose(0, 1)

@@ -387,7 +384,7 @@ class DecoderTransformer(nn.Module):
             sampled_ids.append(predicted)

         sampled_ids = torch.stack(sampled_ids[1:], 1)
-        logits = torch.stack(logits, 1).data
+        logits = torch.stack(logits, 1)

         return sampled_ids, logits
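Because no caller consumed the per-layer attention maps, the layer's `forward` now returns just `x` and the `attn_dict` plumbing disappears. A minimal sketch of the new calling convention, using a toy layer (not the repo's class) to keep it self-contained:

```python
import torch
import torch.nn as nn

class ToyLayer(nn.Module):
    """Stand-in for the new TransformerDecoderLayer contract: hidden states only."""
    def __init__(self, dim):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads=2)

    def forward(self, x):
        out, _ = self.attn(x, x, x)   # attention weights are discarded at the call site
        return out                    # single return value, no (x, attn_dict) tuple

layers = nn.ModuleList([ToyLayer(8) for _ in range(2)])
x = torch.randn(5, 3, 8)              # T x B x C, matching the decoder's layout
for layer in layers:
    x = layer(x)                      # the loop body no longer unpacks two values
print(x.shape)                        # torch.Size([5, 3, 8])
```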