import argparse
import os
import sys
from typing import List

import torch
import transformers

from peft import (
    TaskType,
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    set_peft_model_state_dict,
)
from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig

from utils import *
from collator import Collator


def train(args):
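    """Fine-tune the LLaMA base model with LoRA adapters using the HuggingFace Trainer."""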
    set_seed(args.seed)
    ensure_dir(args.output_dir)

    # Detect distributed training from the environment variables set by torchrun/launch.
    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    local_rank = int(os.environ.get("LOCAL_RANK") or 0)
    if local_rank == 0:
        print(vars(args))

    # Under DDP each process owns a single GPU, so pin the model to its local rank.
    if ddp:
        device_map = {"": local_rank}
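
    # Load the base config and tokenizer; pad on the right and use token id 0 as the pad token.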
    config = LlamaConfig.from_pretrained(args.base_model)
    tokenizer = LlamaTokenizer.from_pretrained(
        args.base_model,
        model_max_length=args.model_max_length,
        padding_side="right",
    )
    tokenizer.pad_token_id = 0
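
    # Load the datasets, register their new tokens with the tokenizer, and keep the
    # config's vocab size in sync with the extended tokenizer.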
    train_data, valid_data = load_datasets(args)
    add_num = tokenizer.add_tokens(train_data.datasets[0].get_new_tokens())
    config.vocab_size = len(tokenizer)
    if local_rank == 0:
        print("Added {} new tokens.".format(add_num))
        print("Training samples:", len(train_data))
        tokenizer.save_pretrained(args.output_dir)
        config.save_pretrained(args.output_dir)
    collator = Collator(args, tokenizer)

    # Load the base model in fp16 and grow its embeddings to cover the newly added tokens.
    model = LlamaForCausalLM.from_pretrained(
        args.base_model,
        torch_dtype=torch.float16,
        device_map=device_map,
    )
    model.resize_token_embeddings(len(tokenizer))
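
    # Wrap the model with a LoRA adapter; only the LoRA weights and the modules listed
    # in args.lora_modules_to_save remain trainable.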
    config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        target_modules=args.lora_target_modules.split(","),
        modules_to_save=args.lora_modules_to_save.split(","),
        lora_dropout=args.lora_dropout,
        bias="none",
        inference_mode=False,
        task_type=TaskType.CAUSAL_LM,
    )
    model = get_peft_model(model, config)
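
    # Optionally warm-start the LoRA adapter from a previous run's adapter_model.bin.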
    if args.resume_from_checkpoint:
        checkpoint_name = os.path.join(
            args.resume_from_checkpoint, "adapter_model.bin"
        )
        # Only the adapter weights are restored here, so prevent the Trainer from
        # also trying to resume optimizer/scheduler state from this path.
        args.resume_from_checkpoint = False

        if os.path.exists(checkpoint_name):
            if local_rank == 0:
                print(f"Restarting from {checkpoint_name}")
            adapters_weights = torch.load(checkpoint_name)
            # set_peft_model_state_dict loads the adapter weights in place.
            set_peft_model_state_dict(model, adapters_weights)
        else:
            if local_rank == 0:
                print(f"Checkpoint {checkpoint_name} not found")
    # modules_to_save keeps both an "original_module" copy and a trainable replica;
    # freeze the original copy so only the replica receives gradients.
    for n, p in model.named_parameters():
        if "original_module" in n and any(module_name in n for module_name in config.modules_to_save):
            p.requires_grad = False

    if local_rank == 0:
        model.print_trainable_parameters()
    # When several GPUs are visible without DDP, the model is already sharded via
    # device_map; flag it as model-parallel so the Trainer does not wrap it in DataParallel.
    if not ddp and torch.cuda.device_count() > 1:
        model.is_parallelizable = True
        model.model_parallel = True
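
    # Standard HF Trainer setup; checkpointing and evaluation share the same strategy
    # and step interval, and the best checkpoint is reloaded at the end of training.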
    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=valid_data,
        args=transformers.TrainingArguments(
            seed=args.seed,
            per_device_train_batch_size=args.per_device_batch_size,
            per_device_eval_batch_size=args.per_device_batch_size,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            warmup_ratio=args.warmup_ratio,
            num_train_epochs=args.epochs,
            learning_rate=args.learning_rate,
            weight_decay=args.weight_decay,
            lr_scheduler_type=args.lr_scheduler_type,
            fp16=args.fp16,
            bf16=args.bf16,
            logging_steps=args.logging_step,
            optim=args.optim,
            gradient_checkpointing=True,
            evaluation_strategy=args.save_and_eval_strategy,
            save_strategy=args.save_and_eval_strategy,
            eval_steps=args.save_and_eval_steps,
            save_steps=args.save_and_eval_steps,
            output_dir=args.output_dir,
            save_total_limit=5,
            load_best_model_at_end=True,
            deepspeed=args.deepspeed,
            ddp_find_unused_parameters=False if ddp else None,
            report_to=None,
            eval_delay=1 if args.save_and_eval_strategy == "epoch" else 2000,
            dataloader_num_workers=args.dataloader_num_workers,
            dataloader_prefetch_factor=args.dataloader_prefetch_factor,
        ),
        tokenizer=tokenizer,
        data_collator=collator,
    )
    # The KV cache is incompatible with gradient checkpointing, so disable it for training.
    model.config.use_cache = False
    # On PyTorch 2.x (torch.compile is not supported on Windows), compile the model
    # to speed up training.
    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    trainer.train(
        resume_from_checkpoint=args.resume_from_checkpoint,
    )

    trainer.save_state()
    trainer.save_model(output_dir=args.output_dir)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='LLMRec')
    parser = parse_global_args(parser)
    parser = parse_train_args(parser)
    parser = parse_dataset_args(parser)

    args = parser.parse_args()

    train(args)