from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from datasets import Dataset
from peft import LoraConfig, get_peft_model
import torch
import json
import matplotlib.pyplot as plt
# Load model and tokenizer
model_path = "../tinyllama_model"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)
# Set pad token
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.eos_token_id
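# (TinyLlama's tokenizer ships without a dedicated pad token, so EOS is
# reused for padding and the model config is kept in sync with it.)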
# Load dataset from JSON
with open("dataset.json", "r") as f:
data = json.load(f)
dataset = Dataset.from_list(data)
# Tokenize dataset and include labels
def tokenize_function(examples):
    # Concatenate each prompt/response pair into a single chat-style string
    inputs = [f"<|USER|> {p} <|ASSISTANT|> {r}" for p, r in zip(examples["prompt"], examples["response"])]
    tokenized = tokenizer(inputs, padding="max_length", truncation=True, max_length=128, return_tensors="pt")
    # Causal LM labels are the input ids, with pad positions masked to -100
    # so that padding does not contribute to the loss
    labels = tokenized["input_ids"].clone()
    labels[tokenized["attention_mask"] == 0] = -100
    tokenized["labels"] = labels
    return tokenized
tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["prompt", "response"])
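# Optional sanity check: confirm the mapped dataset has the expected size and columns
print(f"Tokenized {len(tokenized_dataset)} examples; columns: {tokenized_dataset.column_names}")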
# Configure LoRA for efficient fine-tuning
lora_config = LoraConfig(
    r=8,                                  # rank of the LoRA update matrices
    lora_alpha=32,                        # scaling factor for the updates
    target_modules=["q_proj", "v_proj"],  # attention projections to adapt
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)
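# Optional: report how many parameters LoRA actually trains vs. the total
# (print_trainable_parameters is provided by peft's PeftModel)
model.print_trainable_parameters()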
# Training arguments
training_args = TrainingArguments(
    output_dir="./finetuned_weights",
    num_train_epochs=3,
    per_device_train_batch_size=1,
    save_strategy="epoch",
    logging_steps=1,
    learning_rate=2e-4,
    fp16=False,
    report_to="none"
)
# Trainer (no validation dataset due to small size)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
)
# Fine-tune model
train_result = trainer.train()
# Save fine-tuned weights
model.save_pretrained("./finetuned_weights")
tokenizer.save_pretrained("./finetuned_weights")
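# Optional sanity check (illustrative; the test prompt below is an assumption,
# not taken from the training data): generate one reply with the fine-tuned
# adapter, using the same <|USER|>/<|ASSISTANT|> format as training.
model.eval()
test_prompt = "<|USER|> Hello! <|ASSISTANT|>"
test_inputs = tokenizer(test_prompt, return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(**test_inputs, max_new_tokens=50, pad_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))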
# Extract metrics
train_loss = [log["loss"] for log in trainer.state.log_history if "loss" in log]
learning_rate = [log["learning_rate"] for log in trainer.state.log_history if "learning_rate" in log]
# Print final metrics
print(f"Final Train Loss: {train_loss[-1] if train_loss else 'N/A'}")
print(f"Final Learning Rate: {learning_rate[-1] if learning_rate else 'N/A'}")
# Plot train loss
plt.figure(figsize=(10, 6))
if train_loss:
    plt.plot(range(len(train_loss)), train_loss, label="Train Loss", color="#2563eb")
plt.xlabel("Steps")
plt.ylabel("Loss")
plt.title("Training Loss")
plt.legend()
plt.grid()
plt.savefig("loss_plot.png")
plt.show()