# Benchmark-v0 / lora_finetune.py
# Uploaded by yinjun with the upload-large-folder tool (commit 05744dc).
import argparse
import os
import sys
from typing import List
import torch
import transformers
from peft import (
TaskType,
LoraConfig,
get_peft_model,
get_peft_model_state_dict,
set_peft_model_state_dict,
)
from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig
from utils import *
from collator import Collator
def train(args):
    """Fine-tune a LLaMA causal LM with LoRA adapters.

    Loads the base model and tokenizer, extends the vocabulary with
    dataset-specific tokens, wraps the model with a PEFT ``LoraConfig``,
    optionally restores adapter weights from a checkpoint, and runs a
    HuggingFace ``Trainer`` loop, saving the final model to ``args.output_dir``.

    Args:
        args: Parsed CLI namespace (see ``parse_global_args`` /
            ``parse_train_args`` / ``parse_dataset_args`` in this project).
    """
    set_seed(args.seed)
    ensure_dir(args.output_dir)

    # Distributed setup: under torchrun (WORLD_SIZE > 1) each rank pins the
    # model to its own device; otherwise let HF shard across visible GPUs.
    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    local_rank = int(os.environ.get("LOCAL_RANK") or 0)
    if local_rank == 0:
        print(vars(args))
    if ddp:
        device_map = {"": local_rank}

    config = LlamaConfig.from_pretrained(args.base_model)
    tokenizer = LlamaTokenizer.from_pretrained(
        args.base_model,
        model_max_length=args.model_max_length,
        padding_side="right",
    )
    tokenizer.pad_token_id = 0  # reuse id 0 (<unk>) as the pad token

    train_data, valid_data = load_datasets(args)
    # Extend the vocabulary with the dataset's new tokens BEFORE the model is
    # built, and record the final size in the saved config.
    add_num = tokenizer.add_tokens(train_data.datasets[0].get_new_tokens())
    config.vocab_size = len(tokenizer)
    if local_rank == 0:
        print("add {} new token.".format(add_num))
        print("data num:", len(train_data))
        tokenizer.save_pretrained(args.output_dir)
        config.save_pretrained(args.output_dir)

    collator = Collator(args, tokenizer)

    model = LlamaForCausalLM.from_pretrained(
        args.base_model,
        torch_dtype=torch.float16,
        device_map=device_map,
    )
    # Grow the embedding matrix to cover the tokens added above.
    model.resize_token_embeddings(len(tokenizer))

    # Use a distinct name so the LlamaConfig above is not shadowed.
    lora_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        target_modules=args.lora_target_modules.split(","),
        modules_to_save=args.lora_modules_to_save.split(","),
        lora_dropout=args.lora_dropout,
        bias="none",
        inference_mode=False,
        task_type=TaskType.CAUSAL_LM,
    )
    model = get_peft_model(model, lora_config)

    if args.resume_from_checkpoint:
        # Only the LoRA adapter is restored — the LoraConfig above has to
        # match the one the checkpoint was trained with.
        checkpoint_name = os.path.join(
            args.resume_from_checkpoint, "adapter_model.bin"
        )
        args.resume_from_checkpoint = False  # So the trainer won't try loading its state
        if os.path.exists(checkpoint_name):
            if local_rank == 0:
                print(f"Restarting from {checkpoint_name}")
            # map_location keeps the load device-agnostic instead of trying to
            # restore onto whatever GPU the checkpoint was saved from.
            adapters_weights = torch.load(checkpoint_name, map_location="cpu")
            # BUGFIX: set_peft_model_state_dict loads the weights in place and
            # (in recent peft versions) returns the load result, NOT the
            # model — do not overwrite `model` with its return value.
            set_peft_model_state_dict(model, adapters_weights)
        elif local_rank == 0:
            print(f"Checkpoint {checkpoint_name} not found")

    # modules_to_save keeps a frozen "original_module" copy next to each
    # trainable replica; make sure only the replicas receive gradients.
    for n, p in model.named_parameters():
        if "original_module" in n and any(
            module_name in n for module_name in lora_config.modules_to_save
        ):
            p.requires_grad = False

    if local_rank == 0:
        model.print_trainable_parameters()

    # Single process with several GPUs: enable naive model parallelism.
    if not ddp and torch.cuda.device_count() > 1:
        model.is_parallelizable = True
        model.model_parallel = True

    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=valid_data,
        args=transformers.TrainingArguments(
            seed=args.seed,
            per_device_train_batch_size=args.per_device_batch_size,
            per_device_eval_batch_size=args.per_device_batch_size,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            warmup_ratio=args.warmup_ratio,
            num_train_epochs=args.epochs,
            learning_rate=args.learning_rate,
            weight_decay=args.weight_decay,
            lr_scheduler_type=args.lr_scheduler_type,
            fp16=args.fp16,
            bf16=args.bf16,
            logging_steps=args.logging_step,
            optim=args.optim,
            gradient_checkpointing=True,
            evaluation_strategy=args.save_and_eval_strategy,
            save_strategy=args.save_and_eval_strategy,
            eval_steps=args.save_and_eval_steps,
            save_steps=args.save_and_eval_steps,
            output_dir=args.output_dir,
            save_total_limit=5,
            load_best_model_at_end=True,
            deepspeed=args.deepspeed,
            ddp_find_unused_parameters=False if ddp else None,
            # NOTE(review): report_to=None means "all installed integrations"
            # in transformers; if the intent was to disable logging backends,
            # this should be "none" — confirm before changing.
            report_to=None,
            # Skip early evals: they are expensive and the model is still warming up.
            eval_delay=1 if args.save_and_eval_strategy == "epoch" else 2000,
            dataloader_num_workers=args.dataloader_num_workers,
            dataloader_prefetch_factor=args.dataloader_prefetch_factor,
        ),
        tokenizer=tokenizer,
        data_collator=collator,
    )
    # KV caching is incompatible with gradient checkpointing during training.
    model.config.use_cache = False

    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    trainer.train(
        resume_from_checkpoint=args.resume_from_checkpoint,
    )
    trainer.save_state()
    trainer.save_model(output_dir=args.output_dir)
def main():
    """CLI entry point: assemble arguments from every group and run training."""
    parser = argparse.ArgumentParser(description='LLMRec')
    for add_group in (parse_global_args, parse_train_args, parse_dataset_args):
        parser = add_group(parser)
    train(parser.parse_args())


if __name__ == "__main__":
    main()