feat: handle gradient checkpointing
Files changed:
- src/dalle_mini/model/modeling.py (+2 -2)
- tools/train/train.py (+23 -1)
src/dalle_mini/model/modeling.py

```diff
@@ -144,7 +144,7 @@ class FlaxBartEncoderLayerCollection(FlaxBartEncoderLayerCollection):
 
     def setup(self):
         layer_module = (
-            nn.remat(FlaxBartEncoderLayer)
+            nn.remat(FlaxBartEncoderLayer, concrete=True)
             if self.config.gradient_checkpointing
             else FlaxBartEncoderLayer
         )
@@ -211,7 +211,7 @@ class FlaxBartDecoderLayerCollection(FlaxBartDecoderLayerCollection):
 
     def setup(self):
         layer_module = (
-            nn.remat(FlaxBartDecoderLayer)
+            nn.remat(FlaxBartDecoderLayer, concrete=True)
             if self.config.gradient_checkpointing
             else FlaxBartDecoderLayer
         )
```
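For context: `nn.remat` is Flax's module-level wrapper around `jax.checkpoint`. A rematerialized layer discards its intermediate activations during the forward pass and recomputes them during the backward pass, trading compute for memory. The `concrete=True` flag was needed in the Flax/JAX versions of this era when the checkpointed call contains Python control flow that depends on argument values; newer JAX releases dropped it in favor of `static_argnums`. A minimal sketch of the same selection pattern, with hypothetical `Block` and `Stack` modules (not from the repo) standing in for the BART layers:

```python
import flax.linen as nn

class Block(nn.Module):
    """Hypothetical residual layer standing in for FlaxBartEncoderLayer."""

    @nn.compact
    def __call__(self, x):
        return x + nn.Dense(x.shape[-1])(nn.gelu(x))

class Stack(nn.Module):
    n_layers: int = 4
    gradient_checkpointing: bool = False

    @nn.compact
    def __call__(self, x):
        # Pick the layer class once; with checkpointing on, each Block's
        # activations are recomputed during backprop instead of stored.
        # concrete=True allowed value-dependent Python control flow inside
        # the rematerialized call (older flax/jax; since removed).
        block_cls = (
            nn.remat(Block, concrete=True)
            if self.gradient_checkpointing
            else Block
        )
        for _ in range(self.n_layers):
            x = block_cls()(x)
        return x
```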
tools/train/train.py
```diff
@@ -18,6 +18,7 @@ Training DALL·E Mini.
 Script adapted from run_summarization_flax.py
 """
 
+import copy
 import io
 import logging
 import os
@@ -531,6 +532,8 @@ def main():
     # Set up our new model config
     if model_args.config_name:
         config = DalleBartConfig.from_pretrained(model_args.config_name)
+        # initializing params with gradient checkpointing creates issues
+        config.gradient_checkpointing = False
     else:
         config = None
 
@@ -553,8 +556,27 @@ def main():
     )
 
     # update model config per training args
+    # Done after initialization of weights to avoid issues with remat
+    # This is still considered correctly during training as the function is pjitted
     model.config.gradient_checkpointing = training_args.gradient_checkpointing
 
+    # eval model cannot use remat
+    eval_config = copy.deepcopy(model.config)
+    eval_config.gradient_checkpointing = False
+
+    if training_args.gradient_checkpointing:
+        eval_model = DalleBart(
+            eval_config,
+            seed=training_args.seed_model,
+            dtype=getattr(jnp, model_args.dtype),
+            abstract_init=True,
+            load_on_cpu=True,
+        )
+        del eval_model._params
+        eval_fn = eval_model.__call__
+    else:
+        eval_fn = model.__call__
+
     # get model metadata
     model_metadata = model_args.get_metadata()
 
@@ -967,7 +989,7 @@ def main():
 
     def compute_eval_loss(batch):
         batch, labels = batch.pop("labels")
-        logits = model(**batch, params=state.params, train=False)[0]
+        logits = eval_fn(**batch, params=state.params, train=False)[0]
         return loss_fn(logits, labels)
 
     # calculate loss independently per dp_device
```
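The sequencing in this diff is deliberate: the config is forced to `gradient_checkpointing = False` while the weights are initialized, and the flag is only flipped on `model.config` afterwards. This works because `nn.remat` changes how a layer executes, not its parameter structure, so weights initialized by the plain module apply unchanged to the rematerialized one; and since the train step is traced afterwards (pjit), the updated flag is baked into the compiled computation anyway. A sketch of the idea, reusing the hypothetical `Stack` from above:

```python
import jax
import jax.numpy as jnp

x = jnp.ones((2, 16, 64))

# Initialize with checkpointing off: tracing nn.remat(..., concrete=True)
# at init time is what caused issues in these flax/jax versions.
params = Stack(gradient_checkpointing=False).init(jax.random.PRNGKey(0), x)

# remat leaves the parameter pytree unchanged, so the same params can be
# applied with checkpointing enabled once training starts.
train_model = Stack(gradient_checkpointing=True)
grads = jax.grad(lambda p: train_model.apply(p, x).sum())(params)
```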
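The separate eval path exists because rematerialization only pays for itself when there is a backward pass: evaluation is forward-only, so running it through the remat'd model would recompute activations for nothing. The diff therefore builds a second `DalleBart` with checkpointing disabled, deletes its freshly initialized weights (`del eval_model._params`), and calls it with the training parameters passed explicitly (`abstract_init` and `load_on_cpu` are dalle-mini-specific constructor flags). The same idea with the hypothetical `Stack`:

```python
import jax
import jax.numpy as jnp

x = jnp.ones((2, 16, 64))

# One parameter pytree, two execution strategies over it.
train_model = Stack(gradient_checkpointing=True)   # remat: pays off under grad
eval_model = Stack(gradient_checkpointing=False)   # forward-only: no recompute

params = train_model.init(jax.random.PRNGKey(0), x)

grads = jax.grad(lambda p: train_model.apply(p, x).sum())(params)  # train step
preds = eval_model.apply(params, x)                                # eval step
```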
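The final hunk is what makes the weight-less `eval_model` safe: Hugging Face Flax models accept a `params=` override on `__call__`, so `eval_fn` always runs against the live training state rather than weights of its own. Note that `batch` here is a `flax.core.FrozenDict`, whose `pop` returns a `(rest, value)` pair instead of mutating in place. A hedged sketch of the call pattern, with `eval_fn` and `state` as defined in the surrounding script and `optax.softmax_cross_entropy_with_integer_labels` standing in for the script's actual loss:

```python
import optax

def loss_fn(logits, labels):
    # Stand-in for the script's loss: mean token-level cross-entropy.
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

def compute_eval_loss(batch):
    # batch is a flax FrozenDict: pop returns (dict_without_key, value)
    batch, labels = batch.pop("labels")
    # eval_fn is the checkpointing-free model's __call__; the weights come
    # from the training state, never from eval_model itself.
    logits = eval_fn(**batch, params=state.params, train=False)[0]
    return loss_fn(logits, labels)
```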
|