Spaces:
Running
Running
Pedro Cuenca
committed on
Commit
·
ae983d7
1
Parent(s):
7e48337
Use DalleBartTokenizer. State restoration reverted to previous method:
Browse files
explicitly download artifact and use the download directory.
A better solution will be addressed in #120.
- tools/train/train.py +13 -8
tools/train/train.py
CHANGED
|
@@ -44,7 +44,7 @@ from tqdm import tqdm
|
|
| 44 |
from transformers import AutoTokenizer, HfArgumentParser
|
| 45 |
|
| 46 |
from dalle_mini.data import Dataset
|
| 47 |
-
from dalle_mini.model import DalleBart, DalleBartConfig
|
| 48 |
|
| 49 |
logger = logging.getLogger(__name__)
|
| 50 |
|
|
@@ -435,9 +435,15 @@ def main():
|
|
| 435 |
)
|
| 436 |
|
| 437 |
if training_args.resume_from_checkpoint is not None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 438 |
# load model
|
| 439 |
model = DalleBart.from_pretrained(
|
| 440 |
-
|
| 441 |
dtype=getattr(jnp, model_args.dtype),
|
| 442 |
abstract_init=True,
|
| 443 |
)
|
|
@@ -445,8 +451,8 @@ def main():
|
|
| 445 |
print(model.params)
|
| 446 |
|
| 447 |
# load tokenizer
|
| 448 |
-
tokenizer =
|
| 449 |
-
|
| 450 |
use_fast=True,
|
| 451 |
)
|
| 452 |
|
|
@@ -481,9 +487,8 @@ def main():
|
|
| 481 |
model_args.tokenizer_name, use_fast=True
|
| 482 |
)
|
| 483 |
else:
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
model.config.resolved_name_or_path,
|
| 487 |
use_fast=True,
|
| 488 |
)
|
| 489 |
|
|
@@ -621,7 +626,7 @@ def main():
|
|
| 621 |
if training_args.resume_from_checkpoint is not None:
|
| 622 |
# restore optimizer state and other parameters
|
| 623 |
# we currently ignore partial epoch training: see https://github.com/borisdayma/dalle-mini/issues/105
|
| 624 |
-
state = state.restore_state(
|
| 625 |
|
| 626 |
# label smoothed cross entropy
|
| 627 |
def loss_fn(logits, labels):
|
|
|
|
| 44 |
from transformers import AutoTokenizer, HfArgumentParser
|
| 45 |
|
| 46 |
from dalle_mini.data import Dataset
|
| 47 |
+
from dalle_mini.model import DalleBart, DalleBartConfig, DalleBartTokenizer
|
| 48 |
|
| 49 |
logger = logging.getLogger(__name__)
|
| 50 |
|
|
|
|
| 435 |
)
|
| 436 |
|
| 437 |
if training_args.resume_from_checkpoint is not None:
|
| 438 |
+
if jax.process_index() == 0:
|
| 439 |
+
artifact = wandb.run.use_artifact(training_args.resume_from_checkpoint)
|
| 440 |
+
else:
|
| 441 |
+
artifact = wandb.Api().artifact(training_args.resume_from_checkpoint)
|
| 442 |
+
artifact_dir = artifact.download()
|
| 443 |
+
|
| 444 |
# load model
|
| 445 |
model = DalleBart.from_pretrained(
|
| 446 |
+
artifact_dir,
|
| 447 |
dtype=getattr(jnp, model_args.dtype),
|
| 448 |
abstract_init=True,
|
| 449 |
)
|
|
|
|
| 451 |
print(model.params)
|
| 452 |
|
| 453 |
# load tokenizer
|
| 454 |
+
tokenizer = DalleBartTokenizer.from_pretrained(
|
| 455 |
+
artifact_dir,
|
| 456 |
use_fast=True,
|
| 457 |
)
|
| 458 |
|
|
|
|
| 487 |
model_args.tokenizer_name, use_fast=True
|
| 488 |
)
|
| 489 |
else:
|
| 490 |
+
tokenizer = DalleBartTokenizer.from_pretrained(
|
| 491 |
+
model_args.model_name_or_path,
|
|
|
|
| 492 |
use_fast=True,
|
| 493 |
)
|
| 494 |
|
|
|
|
| 626 |
if training_args.resume_from_checkpoint is not None:
|
| 627 |
# restore optimizer state and other parameters
|
| 628 |
# we currently ignore partial epoch training: see https://github.com/borisdayma/dalle-mini/issues/105
|
| 629 |
+
state = state.restore_state(artifact_dir)
|
| 630 |
|
| 631 |
# label smoothed cross entropy
|
| 632 |
def loss_fn(logits, labels):
|