JMalott committed on
Commit
3ff83b9
·
1 Parent(s): ca509e8

Update min_dalle/models/dalle_bart_decoder.py

Browse files
min_dalle/models/dalle_bart_decoder.py CHANGED
@@ -3,6 +3,7 @@ import torch
3
  from torch import nn, LongTensor, FloatTensor, BoolTensor
4
  from .dalle_bart_encoder import GLU, AttentionBase
5
  import gc
 
6
 
7
  IMAGE_TOKEN_COUNT = 256
8
 
@@ -154,6 +155,12 @@ class DalleBartDecoder(nn.Module):
154
  decoder_state += self.embed_positions.forward(token_index_batched)
155
  decoder_state = self.layernorm_embedding.forward(decoder_state)
156
  decoder_state = decoder_state[:, None]
 
 
 
 
 
 
157
  for i in range(self.layer_count):
158
  decoder_state, attention_state[i] = self.layers[i].forward(
159
  decoder_state,
@@ -173,6 +180,7 @@ class DalleBartDecoder(nn.Module):
173
  logits[:image_count] * (1 - supercondition_factor) +
174
  logits[image_count:] * supercondition_factor
175
  )
 
176
  del supercondition_factor
177
  logits_sorted, _ = logits.sort(descending=True)
178
  is_kept = logits >= logits_sorted[:, top_k - 1]
@@ -188,4 +196,6 @@ class DalleBartDecoder(nn.Module):
188
  del logits
189
  gc.collect()
190
 
 
 
191
  return image_tokens, attention_state
 
3
  from torch import nn, LongTensor, FloatTensor, BoolTensor
4
  from .dalle_bart_encoder import GLU, AttentionBase
5
  import gc
6
+ import tracemalloc
7
 
8
  IMAGE_TOKEN_COUNT = 256
9
 
 
155
  decoder_state += self.embed_positions.forward(token_index_batched)
156
  decoder_state = self.layernorm_embedding.forward(decoder_state)
157
  decoder_state = decoder_state[:, None]
158
+
159
+ tracemalloc.start()
160
+ print("--")
161
+ # print the (current, peak) traced memory sizes, in bytes
162
+ print(tracemalloc.get_traced_memory())
163
+
164
  for i in range(self.layer_count):
165
  decoder_state, attention_state[i] = self.layers[i].forward(
166
  decoder_state,
 
180
  logits[:image_count] * (1 - supercondition_factor) +
181
  logits[image_count:] * supercondition_factor
182
  )
183
+ print(tracemalloc.get_traced_memory())
184
  del supercondition_factor
185
  logits_sorted, _ = logits.sort(descending=True)
186
  is_kept = logits >= logits_sorted[:, top_k - 1]
 
196
  del logits
197
  gc.collect()
198
 
199
+ print(tracemalloc.get_traced_memory())
200
+
201
  return image_tokens, attention_state