{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# QLoRA fine-tuning of Llama 3 8B Instruct on multiple GPUs\n",
    "\n",
    "Starts one worker process per visible CUDA device. Each worker joins an NCCL process group, loads a 4-bit quantized copy of the model onto its own GPU and runs a short `SFTTrainer` job with LoRA adapters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import torch\n",
    "import torch.multiprocessing as mp\n",
    "from datasets import load_dataset\n",
    "from peft import LoraConfig\n",
    "# Kept for backward compatibility; the Trainer applies the DDP wrapper itself.\n",
    "from torch.nn.parallel import DistributedDataParallel as DDP\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments\n",
    "from trl import SFTTrainer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Distributed helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def init_distributed(rank, world_size):\n",
    "    \"\"\"Join the NCCL process group as process `rank` out of `world_size`.\n",
    "\n",
    "    Sets the rendezvous address and exports RANK / LOCAL_RANK / WORLD_SIZE so\n",
    "    that the Trainer stack recognises a distributed run, then binds this\n",
    "    process to GPU `rank` before the process group is created.\n",
    "\n",
    "    Args:\n",
    "        rank: Index of this worker; also used as its CUDA device index.\n",
    "        world_size: Total number of worker processes.\n",
    "    \"\"\"\n",
    "    os.environ[\"MASTER_ADDR\"] = \"localhost\"\n",
    "    os.environ[\"MASTER_PORT\"] = \"12345\"\n",
    "    # Without these the Trainer treats every worker as an independent single-process job.\n",
    "    os.environ[\"RANK\"] = str(rank)\n",
    "    os.environ[\"LOCAL_RANK\"] = str(rank)\n",
    "    os.environ[\"WORLD_SIZE\"] = str(world_size)\n",
    "    if rank == 0:\n",
    "        print(\"Initializing distributed process group...\")\n",
    "    # One GPU per process: select it before NCCL communicators are created.\n",
    "    torch.cuda.set_device(rank)\n",
    "    torch.distributed.init_process_group(backend=\"nccl\", world_size=world_size, rank=rank)\n",
    "\n",
    "\n",
    "def cleanup_distributed():\n",
    "    \"\"\"Destroy the process group if one is currently initialised.\"\"\"\n",
    "    if torch.distributed.is_initialized():\n",
    "        torch.distributed.destroy_process_group()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Per-GPU training worker"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def main_worker(rank, world_size):\n",
    "    \"\"\"Run one fine-tuning worker on GPU `rank`.\n",
    "\n",
    "    Joins the process group, loads the dataset, tokenizer and 4-bit model,\n",
    "    builds the LoRA and training configuration, trains with `SFTTrainer`,\n",
    "    and always tears the process group down afterwards.\n",
    "\n",
    "    Args:\n",
    "        rank: Index of this worker; also the CUDA device it trains on.\n",
    "        world_size: Total number of worker processes.\n",
    "    \"\"\"\n",
    "    init_distributed(rank, world_size)\n",
    "    try:\n",
    "        # Load the dataset\n",
    "        dataset_name = \"ruslanmv/ai-medical-dataset\"\n",
    "        dataset = load_dataset(dataset_name, split=\"train\")\n",
    "        # Keep only the first 100 rows of the dataset\n",
    "        dataset = dataset.select(range(100))\n",
    "\n",
    "        # Load the model + tokenizer\n",
    "        model_name = \"meta-llama/Meta-Llama-3-8B-Instruct\"\n",
    "        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
    "        tokenizer.pad_token = tokenizer.eos_token\n",
    "        bnb_config = BitsAndBytesConfig(\n",
    "            load_in_4bit=True,\n",
    "            bnb_4bit_quant_type=\"nf4\",\n",
    "            bnb_4bit_compute_dtype=torch.float16,\n",
    "        )\n",
    "        # Place the quantized weights directly on this worker's GPU. A 4-bit\n",
    "        # bitsandbytes model must not be moved with `.to()`, and it must reach\n",
    "        # SFTTrainer unwrapped: the trainer attaches the LoRA adapters and\n",
    "        # applies the DistributedDataParallel wrapper itself.\n",
    "        model = AutoModelForCausalLM.from_pretrained(\n",
    "            model_name,\n",
    "            quantization_config=bnb_config,\n",
    "            trust_remote_code=True,\n",
    "            use_cache=False,\n",
    "            device_map={\"\": rank},\n",
    "        )\n",
    "\n",
    "        # PEFT config\n",
    "        lora_alpha = 16\n",
    "        lora_dropout = 0.1\n",
    "        lora_r = 32  # 64\n",
    "        peft_config = LoraConfig(\n",
    "            lora_alpha=lora_alpha,\n",
    "            lora_dropout=lora_dropout,\n",
    "            r=lora_r,\n",
    "            bias=\"none\",\n",
    "            task_type=\"CAUSAL_LM\",\n",
    "            target_modules=[\"k_proj\", \"q_proj\", \"v_proj\", \"up_proj\", \"down_proj\", \"gate_proj\"],\n",
    "            modules_to_save=[\"embed_tokens\", \"input_layernorm\", \"post_attention_layernorm\", \"norm\"],\n",
    "        )\n",
    "\n",
    "        # Args\n",
    "        max_seq_length = 512\n",
    "        output_dir = \"./results\"\n",
    "        per_device_train_batch_size = 2  # reduced batch size to avoid OOM\n",
    "        gradient_accumulation_steps = 2\n",
    "        optim = \"adamw_torch\"\n",
    "        save_steps = 10\n",
    "        logging_steps = 1\n",
    "        learning_rate = 2e-4\n",
    "        max_grad_norm = 0.3\n",
    "        max_steps = 1  # 300 Approx the size of guanaco at bs 8, ga 2, 2 GPUs.\n",
    "        warmup_ratio = 0.1\n",
    "        lr_scheduler_type = \"cosine\"\n",
    "\n",
    "        training_arguments = TrainingArguments(\n",
    "            output_dir=output_dir,\n",
    "            per_device_train_batch_size=per_device_train_batch_size,\n",
    "            gradient_accumulation_steps=gradient_accumulation_steps,\n",
    "            optim=optim,\n",
    "            save_steps=save_steps,\n",
    "            logging_steps=logging_steps,\n",
    "            learning_rate=learning_rate,\n",
    "            fp16=True,\n",
    "            max_grad_norm=max_grad_norm,\n",
    "            max_steps=max_steps,\n",
    "            warmup_ratio=warmup_ratio,\n",
    "            group_by_length=True,\n",
    "            lr_scheduler_type=lr_scheduler_type,\n",
    "            gradient_checkpointing=True,  # gradient checkpointing\n",
    "            # Gradient checkpointing under DDP conflicts with the unused-parameter search.\n",
    "            ddp_find_unused_parameters=False,\n",
    "            report_to=\"wandb\",\n",
    "        )\n",
    "\n",
    "        # Trainer\n",
    "        trainer = SFTTrainer(\n",
    "            model=model,\n",
    "            train_dataset=dataset,\n",
    "            peft_config=peft_config,\n",
    "            dataset_text_field=\"context\",\n",
    "            max_seq_length=max_seq_length,\n",
    "            tokenizer=tokenizer,\n",
    "            args=training_arguments,\n",
    "        )\n",
    "        # Train :)\n",
    "        trainer.train()\n",
    "    finally:\n",
    "        # Release the process group even when loading or training raises.\n",
    "        cleanup_distributed()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Launch one worker per GPU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    world_size = torch.cuda.device_count()\n",
    "    if world_size == 0:\n",
    "        raise RuntimeError(\"No CUDA devices found; the NCCL backend requires at least one GPU.\")\n",
    "\n",
    "    # NOTE: mp.Process uses the platform's default start method (typically fork on\n",
    "    # Linux), so CUDA must not already be initialised in this notebook process\n",
    "    # when the workers are started.\n",
    "    processes = []\n",
    "    for rank in range(world_size):\n",
    "        p = mp.Process(target=main_worker, args=(rank, world_size))\n",
    "        p.start()\n",
    "        processes.append(p)\n",
    "\n",
    "    for p in processes:\n",
    "        p.join()\n",
    "\n",
    "    # Surface worker crashes instead of finishing the cell silently.\n",
    "    failed = [(rank, p.exitcode) for rank, p in enumerate(processes) if p.exitcode != 0]\n",
    "    if failed:\n",
    "        raise RuntimeError(f\"Worker processes failed (rank, exit code): {failed}\")"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}