Spaces:
Runtime error
Runtime error
Update min_dalle/models/dalle_bart_decoder.py
Browse files
min_dalle/models/dalle_bart_decoder.py
CHANGED
@@ -3,6 +3,7 @@ import torch
|
|
3 |
from torch import nn, LongTensor, FloatTensor, BoolTensor
|
4 |
from .dalle_bart_encoder import GLU, AttentionBase
|
5 |
import gc
|
|
|
6 |
|
7 |
IMAGE_TOKEN_COUNT = 256
|
8 |
|
@@ -154,6 +155,12 @@ class DalleBartDecoder(nn.Module):
|
|
154 |
decoder_state += self.embed_positions.forward(token_index_batched)
|
155 |
decoder_state = self.layernorm_embedding.forward(decoder_state)
|
156 |
decoder_state = decoder_state[:, None]
|
|
|
|
|
|
|
|
|
|
|
|
|
157 |
for i in range(self.layer_count):
|
158 |
decoder_state, attention_state[i] = self.layers[i].forward(
|
159 |
decoder_state,
|
@@ -173,6 +180,7 @@ class DalleBartDecoder(nn.Module):
|
|
173 |
logits[:image_count] * (1 - supercondition_factor) +
|
174 |
logits[image_count:] * supercondition_factor
|
175 |
)
|
|
|
176 |
del supercondition_factor
|
177 |
logits_sorted, _ = logits.sort(descending=True)
|
178 |
is_kept = logits >= logits_sorted[:, top_k - 1]
|
@@ -188,4 +196,6 @@ class DalleBartDecoder(nn.Module):
|
|
188 |
del logits
|
189 |
gc.collect()
|
190 |
|
|
|
|
|
191 |
return image_tokens, attention_state
|
|
|
3 |
from torch import nn, LongTensor, FloatTensor, BoolTensor
|
4 |
from .dalle_bart_encoder import GLU, AttentionBase
|
5 |
import gc
|
6 |
+
import tracemalloc
|
7 |
|
8 |
IMAGE_TOKEN_COUNT = 256
|
9 |
|
|
|
155 |
decoder_state += self.embed_positions.forward(token_index_batched)
|
156 |
decoder_state = self.layernorm_embedding.forward(decoder_state)
|
157 |
decoder_state = decoder_state[:, None]
|
158 |
+
|
159 |
+
tracemalloc.start()
|
160 |
+
print("--")
|
161 |
+
# displaying the memory
|
162 |
+
print(tracemalloc.get_traced_memory())
|
163 |
+
|
164 |
for i in range(self.layer_count):
|
165 |
decoder_state, attention_state[i] = self.layers[i].forward(
|
166 |
decoder_state,
|
|
|
180 |
logits[:image_count] * (1 - supercondition_factor) +
|
181 |
logits[image_count:] * supercondition_factor
|
182 |
)
|
183 |
+
print(tracemalloc.get_traced_memory())
|
184 |
del supercondition_factor
|
185 |
logits_sorted, _ = logits.sort(descending=True)
|
186 |
is_kept = logits >= logits_sorted[:, top_k - 1]
|
|
|
196 |
del logits
|
197 |
gc.collect()
|
198 |
|
199 |
+
print(tracemalloc.get_traced_memory())
|
200 |
+
|
201 |
return image_tokens, attention_state
|