from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from datasets import Dataset
from peft import LoraConfig, get_peft_model
import torch
import json
import matplotlib.pyplot as plt

# Load model and tokenizer
model_path = "../tinyllama_model"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)

# Set pad token
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.eos_token_id

# Load dataset from JSON
with open("dataset.json", "r") as f:
    data = json.load(f)
dataset = Dataset.from_list(data)

# Tokenize dataset and include labels
def tokenize_function(examples):
    inputs = [f"<|USER|> {p} <|ASSISTANT|> {r}" for p, r in zip(examples["prompt"], examples["response"])]
    tokenized = tokenizer(inputs, padding="max_length", truncation=True, max_length=128, return_tensors="pt")
    # For causal LM fine-tuning, the labels are the input IDs themselves
    tokenized["labels"] = tokenized["input_ids"].clone()
    return tokenized

tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["prompt", "response"])

# Configure LoRA for efficient fine-tuning
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)

# Training arguments
training_args = TrainingArguments(
    output_dir="./finetuned_weights",
    num_train_epochs=3,
    per_device_train_batch_size=1,
    save_strategy="epoch",
    logging_steps=1,
    learning_rate=2e-4,
    fp16=False,
    report_to="none"
)

# Trainer (no validation dataset due to small size)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
)

# Fine-tune model
train_result = trainer.train()

# Save fine-tuned weights
model.save_pretrained("./finetuned_weights")
tokenizer.save_pretrained("./finetuned_weights")

# Extract metrics from the trainer's log history
train_loss = [log["loss"] for log in trainer.state.log_history if "loss" in log]
learning_rate = [log["learning_rate"] for log in trainer.state.log_history if "learning_rate" in log]

# Print final metrics
print(f"Final Train Loss: {train_loss[-1] if train_loss else 'N/A'}")
print(f"Final Learning Rate: {learning_rate[-1] if learning_rate else 'N/A'}")

# Plot train loss
plt.figure(figsize=(10, 6))
if train_loss:
    plt.plot(range(len(train_loss)), train_loss, label="Train Loss", color="#2563eb")
plt.xlabel("Steps")
plt.ylabel("Loss")
plt.title("Training Loss")
plt.legend()
plt.grid()
plt.savefig("loss_plot.png")
plt.show()
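
# Optional follow-up: a minimal sketch of reloading the saved LoRA adapter for
# inference. This is not part of the training run above; it assumes the adapter
# directory ./finetuned_weights written by the script and peft's PeftModel API.
# The example prompt is hypothetical.
from peft import PeftModel

base_model = AutoModelForCausalLM.from_pretrained(model_path)
inference_model = PeftModel.from_pretrained(base_model, "./finetuned_weights")
inference_model.eval()

# Reuse the same prompt template the model was fine-tuned on
prompt = "<|USER|> What does LoRA change during fine-tuning? <|ASSISTANT|>"
inputs = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
    output_ids = inference_model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))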