"""LoRA fine-tuning entry point for the LLMRec LLaMA model."""

import argparse
import os
import sys

import torch
import transformers
from peft import (
    TaskType,
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    set_peft_model_state_dict,
)
from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig

# set_seed, ensure_dir, load_datasets and the parse_*_args helpers come from utils
from utils import *
from collator import Collator


def train(args):
    set_seed(args.seed)
    ensure_dir(args.output_dir)

    # Single-process runs let HF place layers automatically; under DDP each
    # process pins the whole model to its own local rank.
    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    local_rank = int(os.environ.get("LOCAL_RANK") or 0)
    if local_rank == 0:
        print(vars(args))
    if ddp:
        device_map = {"": local_rank}

    config = LlamaConfig.from_pretrained(args.base_model)
    tokenizer = LlamaTokenizer.from_pretrained(
        args.base_model,
        model_max_length=args.model_max_length,
        padding_side="right",
    )
    tokenizer.pad_token_id = 0

    train_data, valid_data = load_datasets(args)
    # Register the dataset-specific tokens and keep the config's vocab size
    # in sync with the enlarged tokenizer.
    add_num = tokenizer.add_tokens(train_data.datasets[0].get_new_tokens())
    config.vocab_size = len(tokenizer)
    if local_rank == 0:
        print("added {} new tokens.".format(add_num))
        print("data num:", len(train_data))
        tokenizer.save_pretrained(args.output_dir)
        config.save_pretrained(args.output_dir)

    collator = Collator(args, tokenizer)

    model = LlamaForCausalLM.from_pretrained(
        args.base_model,
        torch_dtype=torch.float16,
        device_map=device_map,
    )
    model.resize_token_embeddings(len(tokenizer))

    lora_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        target_modules=args.lora_target_modules.split(","),
        modules_to_save=args.lora_modules_to_save.split(","),
        lora_dropout=args.lora_dropout,
        bias="none",
        inference_mode=False,
        task_type=TaskType.CAUSAL_LM,
    )
    model = get_peft_model(model, lora_config)

    if args.resume_from_checkpoint:
        # Only the LoRA adapter is restored, so the LoraConfig above has to
        # match the checkpoint.
        checkpoint_name = os.path.join(
            args.resume_from_checkpoint, "adapter_model.bin"
        )
        # Reset the flag so the Trainer won't also try to load its own state.
        args.resume_from_checkpoint = False
        if os.path.exists(checkpoint_name):
            if local_rank == 0:
                print(f"Restarting from {checkpoint_name}")
            adapters_weights = torch.load(checkpoint_name)
            # Loads the adapter weights into the model in place.
            set_peft_model_state_dict(model, adapters_weights)
        else:
            if local_rank == 0:
                print(f"Checkpoint {checkpoint_name} not found")

    # The frozen originals that PEFT keeps for modules_to_save must not
    # receive gradients; only their trainable duplicates are updated.
    for n, p in model.named_parameters():
        if "original_module" in n and any(
            module_name in n for module_name in lora_config.modules_to_save
        ):
            p.requires_grad = False

    if local_rank == 0:
        model.print_trainable_parameters()

    if not ddp and torch.cuda.device_count() > 1:
        # Tell the Trainer the model is already spread across the visible GPUs.
        model.is_parallelizable = True
        model.model_parallel = True

    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=valid_data,
        args=transformers.TrainingArguments(
            seed=args.seed,
            per_device_train_batch_size=args.per_device_batch_size,
            per_device_eval_batch_size=args.per_device_batch_size,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            warmup_ratio=args.warmup_ratio,
            num_train_epochs=args.epochs,
            learning_rate=args.learning_rate,
            weight_decay=args.weight_decay,
            lr_scheduler_type=args.lr_scheduler_type,
            fp16=args.fp16,
            bf16=args.bf16,
            logging_steps=args.logging_step,
            optim=args.optim,
            gradient_checkpointing=True,
            evaluation_strategy=args.save_and_eval_strategy,
            save_strategy=args.save_and_eval_strategy,
            eval_steps=args.save_and_eval_steps,
            save_steps=args.save_and_eval_steps,
            output_dir=args.output_dir,
            save_total_limit=5,
            load_best_model_at_end=True,
            deepspeed=args.deepspeed,
            ddp_find_unused_parameters=False if ddp else None,
            report_to=None,
            eval_delay=1 if args.save_and_eval_strategy == "epoch" else 2000,
            dataloader_num_workers=args.dataloader_num_workers,
            dataloader_prefetch_factor=args.dataloader_prefetch_factor,
        ),
        tokenizer=tokenizer,
        data_collator=collator,
    )
    # The KV cache is useless during training and clashes with gradient
    # checkpointing.
    model.config.use_cache = False

    # Legacy hook that made state_dict() return only the adapter weights;
    # left disabled.
    # old_state_dict = model.state_dict
    # model.state_dict = (
    #     lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
    # ).__get__(model, type(model))

    # Compile on PyTorch 2.x (torch.compile is unsupported on Windows).
    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)

    trainer.save_state()
    trainer.save_model(output_dir=args.output_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="LLMRec")
    parser = parse_global_args(parser)
    parser = parse_train_args(parser)
    parser = parse_dataset_args(parser)
    args = parser.parse_args()

    train(args)
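
# ---------------------------------------------------------------------------
# A minimal sketch (never called by this script) of how the artifacts saved
# above can be reloaded for inference. `base_model_dir` is an assumed
# placeholder for the original base weights; the tokenizer, config, and
# adapter are the ones written to args.output_dir by the rank-0
# save_pretrained() calls and trainer.save_model().
# ---------------------------------------------------------------------------
def load_trained_model(base_model_dir, output_dir):
    from peft import PeftModel

    tokenizer = LlamaTokenizer.from_pretrained(output_dir)
    model = LlamaForCausalLM.from_pretrained(
        base_model_dir, torch_dtype=torch.float16
    )
    # Resize to the enlarged vocabulary before attaching the adapter, which
    # may carry saved embedding modules from modules_to_save.
    model.resize_token_embeddings(len(tokenizer))
    model = PeftModel.from_pretrained(model, output_dir)
    model.eval()
    return tokenizer, model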