JMalott commited on
Commit
563067a
·
1 Parent(s): 14ae027

Update min_dalle/models/dalle_bart_decoder.py

Browse files
min_dalle/models/dalle_bart_decoder.py CHANGED
@@ -175,7 +175,7 @@ class DalleBartDecoder(nn.Module):
175
  logits /= temperature
176
  logits.exp_()
177
  logits *= is_kept.to(torch.float32)
178
- image_tokens = torch.multinomial(logits, 1)[:, 0]
179
- del logits
180
- print("hi")
181
- return image_tokens, attention_state
 
175
  logits /= temperature
176
  logits.exp_()
177
  logits *= is_kept.to(torch.float32)
178
+ #image_tokens = torch.multinomial(logits, 1)[:, 0]
179
+ #del logits
180
+
181
+ return torch.multinomial(logits, 1)[:, 0], attention_state