JMalott commited on
Commit
602c80d
·
1 Parent(s): 9055def

Update min_dalle/models/dalle_bart_decoder.py

Browse files
min_dalle/models/dalle_bart_decoder.py CHANGED
@@ -161,7 +161,7 @@ class DalleBartDecoder(nn.Module):
161
  )
162
  decoder_state = self.final_ln(decoder_state)
163
  logits = self.lm_head(decoder_state)
164
- del decorder_state
165
  temperature = settings[[0]]
166
  top_k = settings[[1]].to(torch.long)
167
  supercondition_factor = settings[[2]]
@@ -176,6 +176,7 @@ class DalleBartDecoder(nn.Module):
176
  logits -= logits_sorted[:, [0]]
177
  del logits_sorted
178
  logits /= temperature
 
179
  logits.exp_()
180
  logits *= is_kept.to(torch.float32)
181
  image_tokens = torch.multinomial(logits, 1)[:, 0]
 
161
  )
162
  decoder_state = self.final_ln(decoder_state)
163
  logits = self.lm_head(decoder_state)
164
+ del decoder_state
165
  temperature = settings[[0]]
166
  top_k = settings[[1]].to(torch.long)
167
  supercondition_factor = settings[[2]]
 
176
  logits -= logits_sorted[:, [0]]
177
  del logits_sorted
178
  logits /= temperature
179
+ del temperature
180
  logits.exp_()
181
  logits *= is_kept.to(torch.float32)
182
  image_tokens = torch.multinomial(logits, 1)[:, 0]