|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Supervised fine-tuning script for decoder language models. |
|
""" |
|
|
|
import logging |
|
import random |
|
import sys |
|
|
|
import datasets |
|
import torch |
|
import transformers |
|
from transformers import AutoModelForCausalLM, set_seed |
|
|
|
from alignment import ( |
|
DataArguments, |
|
H4ArgumentParser, |
|
ModelArguments, |
|
SFTConfig, |
|
apply_chat_template, |
|
decontaminate_humaneval, |
|
get_checkpoint, |
|
get_datasets, |
|
get_kbit_device_map, |
|
get_peft_config, |
|
get_quantization_config, |
|
get_tokenizer, |
|
) |
|
from trl import SFTTrainer, setup_chat_format |
|
|
|
|
|
# Module-scoped logger; its level is set from the training arguments inside main().
logger = logging.getLogger(name=__name__)
|
|
|
|
|
# Local JSON files holding the training / evaluation conversations. Loading a single
# file with `datasets.load_dataset("json", data_files=...)` places it under a split
# named "train", which is why both are indexed with ["train"] below.
TRAIN_DATA_FILE = "/proj/memorization/FK/warrior/data/warrior_train.json"
EVAL_DATA_FILE = "/proj/memorization/FK/warrior/data/warrior_test.json"


def _configure_logging(training_args):
    """Send Python, `datasets` and `transformers` logs to stdout at the per-process level."""
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()


def _apply_sft_chat_template(dataset, tokenizer, num_proc):
    """Render each example of `dataset` with the tokenizer's chat template for SFT.

    Args:
        dataset: a single `datasets.Dataset` split.
        tokenizer: the tokenizer whose chat template is applied.
        num_proc: number of worker processes for `Dataset.map` (None = in-process).

    Returns:
        The mapped dataset, with all of the input dataset's original columns removed.
    """
    return dataset.map(
        apply_chat_template,
        fn_kwargs={
            "tokenizer": tokenizer,
            "task": "sft",
            "auto_insert_empty_system_msg": False,
        },
        num_proc=num_proc,
        # Remove *this* dataset's own columns: the train and eval files need not share a schema.
        remove_columns=dataset.column_names,
        desc="Applying chat template",
    )


def main():
    """Supervised fine-tuning of a causal LM on local JSON conversation data with TRL's SFTTrainer.

    `H4ArgumentParser` parses model, data and `SFTConfig` training arguments. The script
    then does the following:

    - Trains, resuming from an explicit or auto-detected checkpoint.
    - Saves the model to `training_args.output_dir`.
    - On the main process, writes a model card and the inference-ready config.
    - Optionally evaluates (`do_eval`).
    - Optionally pushes to the Hub (`push_to_hub`).
    """
    parser = H4ArgumentParser((ModelArguments, DataArguments, SFTConfig))
    model_args, data_args, training_args = parser.parse()

    # Set seed for reproducibility
    set_seed(training_args.seed)

    ###############
    # Setup logging
    ###############
    _configure_logging(training_args)

    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu},"
        + f" distributed training: {training_args.local_rank != -1}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Model parameters {model_args}")
    logger.info(f"Data parameters {data_args}")
    logger.info(f"Training/evaluation parameters {training_args}")

    # Check for last checkpoint
    last_checkpoint = get_checkpoint(training_args)
    if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
        logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")

    ###############
    # Load datasets
    ###############
    # `load_dataset` is not imported by name in this module, so go through `datasets`.
    raw_datasets = datasets.load_dataset("json", data_files=TRAIN_DATA_FILE)
    eval_raw_datasets = datasets.load_dataset("json", data_files=EVAL_DATA_FILE)
    logger.info(
        f"Training on the following datasets and their proportions: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
    )

    ################
    # Load tokenizer
    ################
    tokenizer = get_tokenizer(model_args, data_args)

    #######################
    # Load pretrained model
    #######################
    logger.info("*** Load pretrained model ***")
    torch_dtype = (
        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
    )
    quantization_config = get_quantization_config(model_args)

    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        # The KV cache is not used with gradient checkpointing during training.
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )

    # By default, pass the model name/path and let SFTTrainer instantiate it from model_kwargs.
    model = model_args.model_name_or_path
    # Guard against a tokenizer without a chat template (None would make `in` raise TypeError).
    chat_template = tokenizer.chat_template or ""
    if "<|im_start|>" in chat_template and "gemma-tokenizer-chatml" not in tokenizer.name_or_path:
        # ChatML template: instantiate the model here so setup_chat_format can adjust model and tokenizer together.
        model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)
        model, tokenizer = setup_chat_format(model, tokenizer)
        # The model is already instantiated, so no init kwargs are handed to SFTTrainer.
        model_kwargs = None

    #####################
    # Apply chat template
    #####################
    train_dataset = _apply_sft_chat_template(
        raw_datasets["train"], tokenizer, data_args.preprocessing_num_workers
    )
    eval_dataset = _apply_sft_chat_template(
        eval_raw_datasets["train"], tokenizer, data_args.preprocessing_num_workers
    )

    ########################
    # Initialize the Trainer
    ########################
    trainer = SFTTrainer(
        model=model,
        model_init_kwargs=model_kwargs,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        dataset_text_field="text",
        max_seq_length=training_args.max_seq_length,
        tokenizer=tokenizer,
        packing=True,
        peft_config=get_peft_config(model_args),
        dataset_kwargs=training_args.dataset_kwargs,
    )

    ###############
    # Training loop
    ###############
    logger.info("*** Train ***")
    # An explicit resume_from_checkpoint wins; otherwise fall back to the auto-detected one (or None).
    checkpoint = (
        training_args.resume_from_checkpoint
        if training_args.resume_from_checkpoint is not None
        else last_checkpoint
    )
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
    metrics = train_result.metrics
    metrics["train_samples"] = len(train_dataset)
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    ##################################
    # Save model and create model card
    ##################################
    logger.info("*** Save model ***")
    trainer.save_model(training_args.output_dir)
    logger.info(f"Model saved to {training_args.output_dir}")

    # dataset_mixer may be unset (None) because the data above comes from local JSON files;
    # fall back to no dataset names instead of crashing after training has finished.
    # NOTE(review): when set, the mixer does not describe the data actually used here.
    dataset_names = list(data_args.dataset_mixer or {})
    kwargs = {
        "finetuned_from": model_args.model_name_or_path,
        "dataset": dataset_names,
        "dataset_tags": dataset_names,
        "tags": ["alignment-handbook"],
    }
    if trainer.accelerator.is_main_process:
        trainer.create_model_card(**kwargs)
        # Re-enable the KV cache in the saved config for fast inference.
        trainer.model.config.use_cache = True
        trainer.model.config.save_pretrained(training_args.output_dir)

    ##########
    # Evaluate
    ##########
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate()
        metrics["eval_samples"] = len(eval_dataset)
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    if training_args.push_to_hub:
        logger.info("Pushing to hub...")
        trainer.push_to_hub(**kwargs)

    logger.info("*** Training complete ***")
|
|
|
|
|
# Script entry point: run the fine-tuning pipeline only when executed directly, not on import.
if __name__ == "__main__":
    main()