feat: allow loading a model checkpoint
seq2seq/run_seq2seq_flax.py  +41 -23
@@ -125,6 +125,12 @@ class ModelArguments:
             "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
         },
     )
+    from_checkpoint: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Loads a pretrained wandb checkpoint. Use artifact reference."
+        },
+    )
 
 
 @dataclass
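The new `from_checkpoint` field is an ordinary dataclass field, so `transformers.HfArgumentParser` turns it into a `--from_checkpoint` CLI flag automatically. A minimal sketch of how the flag surfaces (this trimmed-down `ModelArguments` is a stand-in for the full class in the script; the artifact reference is the one hardcoded later in this commit):

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser

@dataclass
class ModelArguments:
    # Trimmed stand-in for the script's ModelArguments; only the new field is shown.
    from_checkpoint: Optional[str] = field(
        default=None,
        metadata={"help": "Loads a pretrained wandb checkpoint. Use artifact reference."},
    )

parser = HfArgumentParser(ModelArguments)
# Equivalent to: python run_seq2seq_flax.py --from_checkpoint <artifact reference>
(model_args,) = parser.parse_args_into_dataclasses(
    ["--from_checkpoint", "wandb/hf-flax-dalle-mini/model-3h3x3565:latest"]
)
assert model_args.from_checkpoint is not None  # selects the checkpoint branch in main()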
@@ -424,36 +430,48 @@ def main():
     # https://huggingface.co/docs/datasets/loading_datasets.html.
 
     # Load pretrained model and tokenizer
-    base_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
-        model_args.model_name_or_path, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
-    )
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
     )
 
-    # Set up our new model config
-    config = BartConfig.from_pretrained(model_args.model_name_or_path)
-    config.tie_word_embeddings = False
-    config.decoder_start_token_id = BOS_TOKEN_ID  # for first token
-    config.bos_token_id = BOS_TOKEN_ID  # should not be used (due to forced_bos_token_id)
-    config.pos_token_id = BOS_TOKEN_ID  # should not be needed (as we generate until max_length)
-    config.eos_token_id = BOS_TOKEN_ID + 1  # unreachable
-    config.forced_bos_token_id = None  # we don't need this token
-    config.forced_eos_token_id = None  # we don't need this token
-    config.force_bos_token_to_be_generated = False  # otherwise it sets bos_token_id at loading
-    config.min_length = data_args.max_target_length
-    config.max_length = data_args.max_target_length
+    if model_args.from_checkpoint is not None:
+        artifact = run.use_artifact('wandb/hf-flax-dalle-mini/model-3h3x3565:latest')
+        artifact_dir = artifact.download()
+        model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)
 
-    # Create a custom model and initialize it randomly
-    model = CustomFlaxBartForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
+        # some models will try to change bos (because of force_bos_token_to_be_generated)
+        # we ensure bos and eos are not forced
+        model.config.force_bos_token_to_be_generated = False
+        model.config.forced_bos_token_id = None
+        model.config.forced_eos_token_id = None
 
-    print(f"TPUs: {jax.device_count()}")
-    assert jax.device_count() == 8, "TPUs in use, please check running processes"
+    else:
+        base_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
+            model_args.model_name_or_path, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+        )
+        # Set up our new model config
+        config = BartConfig.from_pretrained(model_args.model_name_or_path)
+        config.tie_word_embeddings = False
+        config.decoder_start_token_id = BOS_TOKEN_ID  # for first token
+        config.bos_token_id = BOS_TOKEN_ID  # should not be used (due to forced_bos_token_id)
+        config.pos_token_id = BOS_TOKEN_ID  # should not be needed (as we generate until max_length)
+        config.eos_token_id = BOS_TOKEN_ID + 1  # unreachable
+        config.forced_bos_token_id = None  # we don't need this token
+        config.forced_eos_token_id = None  # we don't need this token
+        config.force_bos_token_to_be_generated = False  # otherwise it sets bos_token_id at loading
+        config.min_length = data_args.max_target_length
+        config.max_length = data_args.max_target_length
+
+        # Create a custom model and initialize it randomly
+        model = CustomFlaxBartForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
+
+        # Use pre-trained weights for encoder
+        model.params['model']['encoder'] = base_model.params['model']['encoder']
+        model.params['model']['shared'] = base_model.params['model']['shared']
+        del base_model
 
-    # Use pre-trained weights for encoder
-    model.params['model']['encoder'] = base_model.params['model']['encoder']
-    model.params['model']['shared'] = base_model.params['model']['shared']
-    del base_model
+    print(f"TPUs: {jax.device_count()}")
+    assert jax.device_count() == 8, "TPUs in use, please check running processes"
 
     prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
 
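In the checkpoint branch above, `run` is assumed to be a wandb run object created earlier in the script (e.g. by `wandb.init()`); note that the commit hardcodes the artifact reference instead of reading `model_args.from_checkpoint`, so at this point the flag only acts as a switch. A minimal sketch of the artifact round-trip, assuming wandb credentials and that the referenced artifact exists:

import wandb

# Illustrative entity/project; the script's own run is created elsewhere.
run = wandb.init(entity="wandb", project="hf-flax-dalle-mini")

# use_artifact records the artifact as an input of this run;
# download() fetches its files and returns the local directory.
artifact = run.use_artifact("wandb/hf-flax-dalle-mini/model-3h3x3565:latest")
artifact_dir = artifact.download()

# The directory can then be passed to a from_pretrained-style loader.
print(artifact_dir)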
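In the cold-start branch, a pretrained BART config is rewired before the custom model is randomly initialized: word embeddings are untied (the decoder predicts a different vocabulary than the encoder reads) and generation is pinned to a fixed target length. A condensed sketch of the same idea with illustrative values (in the script, the checkpoint name, `BOS_TOKEN_ID`, and the lengths come from `model_args`, a module-level constant, and `data_args`; `pos_token_id` is not a standard `BartConfig` attribute and looks like it was meant to be `pad_token_id`, though assigning an extra attribute is harmless):

from transformers import BartConfig

BOS_TOKEN_ID = 16384      # illustrative; the script defines its own constant
max_target_length = 257   # illustrative; the script reads data_args.max_target_length

config = BartConfig.from_pretrained("facebook/bart-large-cnn")  # illustrative base
config.tie_word_embeddings = False      # decoder output vocab differs from encoder input vocab
config.decoder_start_token_id = BOS_TOKEN_ID
config.bos_token_id = BOS_TOKEN_ID
config.eos_token_id = BOS_TOKEN_ID + 1  # never reached when generating exactly max_length tokens
config.forced_bos_token_id = None       # don't force special tokens during generation
config.forced_eos_token_id = None
config.min_length = max_target_length   # min == max pins generation to a fixed length
config.max_length = max_target_length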
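The warm start at the end of the `else` branch works because Flax parameters are plain nested dicts: assigning a subtree by key transplants the pretrained encoder and shared embeddings into the freshly initialized model, while the decoder keeps its random initialization. A toy illustration with placeholder leaves instead of real weight arrays:

# Placeholder trees standing in for base_model.params and model.params.
base_params = {"model": {"shared": "pretrained shared embeddings",
                         "encoder": "pretrained encoder weights",
                         "decoder": "pretrained decoder weights"}}
model_params = {"model": {"shared": "random shared embeddings",
                          "encoder": "random encoder weights",
                          "decoder": "random decoder weights"}}

# The same two assignments as the commit: encoder and embeddings are reused.
model_params["model"]["encoder"] = base_params["model"]["encoder"]
model_params["model"]["shared"] = base_params["model"]["shared"]
del base_params  # mirrors `del base_model`: the donor copy is no longer needed

assert model_params["model"]["decoder"] == "random decoder weights"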