import argparse
import os

from accelerate import Accelerator
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, logging, set_seed

from trl import SFTTrainer
from trl.trainer import ConstantLengthDataset


"""
Fine-Tune Llama-7b on SE paired dataset
"""


def get_args():
    parser = argparse.ArgumentParser()
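    # data arguments: Stack Exchange paired dataset and streaming options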
parser.add_argument("--model_path", type=str, default="") |
|
parser.add_argument("--dataset_name", type=str, default="lvwerra/stack-exchange-paired") |
|
parser.add_argument("--subset", type=str, default="data/finetune") |
|
parser.add_argument("--split", type=str, default="train") |
|
parser.add_argument("--size_valid_set", type=int, default=4000) |
|
parser.add_argument("--streaming", action="store_true") |
|
parser.add_argument("--shuffle_buffer", type=int, default=5000) |
|
|
|
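    # sequence length, step count, and batching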
parser.add_argument("--seq_length", type=int, default=1024) |
|
parser.add_argument("--max_steps", type=int, default=10000) |
|
parser.add_argument("--batch_size", type=int, default=4) |
|
parser.add_argument("--gradient_accumulation_steps", type=int, default=1) |
|
parser.add_argument("--eos_token_id", type=int, default=49152) |
|
|
|
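    # optimizer and learning-rate schedule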
parser.add_argument("--learning_rate", type=float, default=1e-4) |
|
parser.add_argument("--lr_scheduler_type", type=str, default="cosine") |
|
parser.add_argument("--num_warmup_steps", type=int, default=100) |
|
parser.add_argument("--weight_decay", type=float, default=0.05) |
|
|
|
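    # runtime: distributed rank, precision, gradient checkpointing, seeding, and logging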
parser.add_argument("--local_rank", type=int, default=0) |
|
parser.add_argument("--no_fp16", action="store_false") |
|
parser.add_argument("--bf16", action="store_true", default=False) |
|
parser.add_argument("--no_gradient_checkpointing", action="store_false", default=False) |
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--num_workers", type=int, default=None)
    parser.add_argument("--output_dir", type=str, default="./checkpoints")
    parser.add_argument("--log_freq", default=1, type=int)
    parser.add_argument("--eval_freq", default=1000, type=int)
    parser.add_argument("--save_freq", default=1000, type=int)

    return parser.parse_args()


def chars_token_ratio(dataset, tokenizer, nb_examples=400):
    """
    Estimate the average number of characters per token in the dataset.
    """
    total_characters, total_tokens = 0, 0
    for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):
        text = prepare_sample_text(example)
        total_characters += len(text)
        if tokenizer.is_fast:
            total_tokens += len(tokenizer(text).tokens())
        else:
            total_tokens += len(tokenizer.tokenize(text))

    return total_characters / total_tokens


def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )


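# Note: the "Question: ...\n\nAnswer: ..." template below is assumed to match the prompt
# format used by the later reward-modeling / RL stages of the stack_llama pipeline, so keep
# it consistent if you adapt this script.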
def prepare_sample_text(example):
    """Prepare the text from a sample of the dataset."""
    text = f"Question: {example['question']}\n\nAnswer: {example['response_j']}"
    return text


def create_datasets(tokenizer, args):
    dataset = load_dataset(
        args.dataset_name,
        data_dir=args.subset,
        split=args.split,
        use_auth_token=True,
        num_proc=args.num_workers if not args.streaming else None,
        streaming=args.streaming,
    )
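    # In streaming mode, take the first `size_valid_set` examples for validation and shuffle the
    # remaining stream with a fixed-size buffer; otherwise fall back to a regular 0.5% test split.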
    if args.streaming:
        print("Loading the dataset in streaming mode")
        valid_data = dataset.take(args.size_valid_set)
        train_data = dataset.skip(args.size_valid_set)
        train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed)
    else:
        dataset = dataset.train_test_split(test_size=0.005, seed=args.seed)
        train_data = dataset["train"]
        valid_data = dataset["test"]
        print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}")

    chars_per_token = chars_token_ratio(train_data, tokenizer)
    print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}")

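    # ConstantLengthDataset tokenizes and packs examples into fixed-length blocks of
    # `seq_length` tokens, using `chars_per_token` to estimate how much raw text to buffer.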
    train_dataset = ConstantLengthDataset(
        tokenizer,
        train_data,
        formatting_func=prepare_sample_text,
        infinite=True,
        seq_length=args.seq_length,
        chars_per_token=chars_per_token,
    )
    valid_dataset = ConstantLengthDataset(
        tokenizer,
        valid_data,
        formatting_func=prepare_sample_text,
        infinite=False,
        seq_length=args.seq_length,
        chars_per_token=chars_per_token,
    )
    return train_dataset, valid_dataset


def run_training(args, train_data, val_data):
    print("Loading the model")

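    # LoRA adapter config: rank-16 adapters with scaling alpha=32. With PEFT's default
    # target-module mapping for Llama this is assumed to inject adapters into the attention
    # q/v projections; only the adapter weights are trained, the 8-bit base model stays frozen.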
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )

    train_data.start_iteration = 0

    print("Starting main loop")

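    # Standard Hugging Face TrainingArguments; report_to="wandb" assumes a configured
    # Weights & Biases environment (otherwise switch report_to, e.g. to "none").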
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        dataloader_drop_last=True,
        evaluation_strategy="steps",
        max_steps=args.max_steps,
        eval_steps=args.eval_freq,
        save_steps=args.save_freq,
        logging_steps=args.log_freq,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        lr_scheduler_type=args.lr_scheduler_type,
        warmup_steps=args.num_warmup_steps,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        gradient_checkpointing=not args.no_gradient_checkpointing,
        fp16=not args.no_fp16,
        bf16=args.bf16,
        weight_decay=args.weight_decay,
        run_name="llama-7b-finetuned",
        report_to="wandb",
        ddp_find_unused_parameters=False,
    )

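    # Load the base model with 8-bit weights (bitsandbytes) and pin the whole model to this
    # process's GPU via the Accelerate process index, so each DDP worker holds its own copy.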
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path, load_in_8bit=True, device_map={"": Accelerator().process_index}
    )

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_data,
        eval_dataset=val_data,
        peft_config=lora_config,
        packing=True,
    )

    print_trainable_parameters(trainer.model)

    print("Training...")
    trainer.train()

    print("Saving last checkpoint of the model")
    trainer.model.save_pretrained(os.path.join(args.output_dir, "final_checkpoint/"))
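    # Note: this saves only the LoRA adapter weights. To get a standalone model, merge the
    # adapters into the base weights afterwards (e.g. with PEFT's merge_and_unload()).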


def main(args):
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    train_dataset, eval_dataset = create_datasets(tokenizer, args)
    run_training(args, train_dataset, eval_dataset)


if __name__ == "__main__":
    args = get_args()
    assert args.model_path != "", "Please provide the llama model path"

    set_seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)

    logging.set_verbosity_error()

    main(args)