Spaces:
Running
Running
fix: correct use of dtype
Browse files
- dalle_mini/model.py +4 -10
dalle_mini/model.py
CHANGED
|
@@ -18,21 +18,20 @@ class CustomFlaxBartModule(FlaxBartModule):
|
|
| 18 |
self.shared = nn.Embed(
|
| 19 |
self.config.vocab_size,
|
| 20 |
self.config.d_model,
|
| 21 |
-
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
| 22 |
-
dtype=self.dtype,
|
| 23 |
)
|
| 24 |
# a separate embedding is used for the decoder
|
| 25 |
self.decoder_embed = nn.Embed(
|
| 26 |
self.config.image_vocab_size + 1,
|
| 27 |
self.config.d_model,
|
| 28 |
-
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
| 29 |
-
dtype=self.dtype,
|
| 30 |
)
|
| 31 |
self.encoder = FlaxBartEncoder(
|
| 32 |
self.config, dtype=self.dtype, embed_tokens=self.shared
|
| 33 |
)
|
| 34 |
|
| 35 |
# the decoder has a different config
|
|
|
|
| 36 |
decoder_config = BartConfig(self.config.to_dict())
|
| 37 |
decoder_config.max_position_embeddings = (
|
| 38 |
self.config.image_length + 1 # image tokens + BOS
|
|
@@ -47,16 +46,11 @@ class CustomFlaxBartForConditionalGenerationModule(
|
|
| 47 |
FlaxBartForConditionalGenerationModule
|
| 48 |
):
|
| 49 |
def setup(self):
|
| 50 |
-
# check config is valid, otherwise set default values
|
| 51 |
-
# TODO: simplify with custom config class
|
| 52 |
-
self.config.text_normalized = True / False
|
| 53 |
-
|
| 54 |
self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
|
| 55 |
self.lm_head = nn.Dense(
|
| 56 |
self.config.image_vocab_size + 1, # encoded image token space + 1 for bos
|
| 57 |
use_bias=False,
|
| 58 |
-
dtype=self.dtype,
|
| 59 |
-
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
| 60 |
)
|
| 61 |
self.final_logits_bias = self.param(
|
| 62 |
"final_logits_bias", self.bias_init, (1, self.config.image_vocab_size + 1)
|
|
|
|
| 18 |
self.shared = nn.Embed(
|
| 19 |
self.config.vocab_size,
|
| 20 |
self.config.d_model,
|
| 21 |
+
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
|
|
|
| 22 |
)
|
| 23 |
# a separate embedding is used for the decoder
|
| 24 |
self.decoder_embed = nn.Embed(
|
| 25 |
self.config.image_vocab_size + 1,
|
| 26 |
self.config.d_model,
|
| 27 |
+
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
|
|
|
| 28 |
)
|
| 29 |
self.encoder = FlaxBartEncoder(
|
| 30 |
self.config, dtype=self.dtype, embed_tokens=self.shared
|
| 31 |
)
|
| 32 |
|
| 33 |
# the decoder has a different config
|
| 34 |
+
# TODO: should not be needed once we have custom config/module
|
| 35 |
decoder_config = BartConfig(self.config.to_dict())
|
| 36 |
decoder_config.max_position_embeddings = (
|
| 37 |
self.config.image_length + 1 # image tokens + BOS
|
|
|
|
| 46 |
FlaxBartForConditionalGenerationModule
|
| 47 |
):
|
| 48 |
def setup(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
|
| 50 |
self.lm_head = nn.Dense(
|
| 51 |
self.config.image_vocab_size + 1, # encoded image token space + 1 for bos
|
| 52 |
use_bias=False,
|
| 53 |
+
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
|
|
|
| 54 |
)
|
| 55 |
self.final_logits_bias = self.param(
|
| 56 |
"final_logits_bias", self.bias_init, (1, self.config.image_vocab_size + 1)
|