from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
from trl import SFTTrainer
import torch
from datasets import load_dataset
# Load the base model (Hermes-3-Llama-3.2-3B) in 4-bit via bitsandbytes
model_name = "NousResearch/Hermes-3-Llama-3.2-3B"
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", quantization_config=bnb_config)
tokenizer = AutoTokenizer.from_pretrained(model_name)
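# Llama-family tokenizers often ship without a pad token; reusing EOS as the
# pad token is a common workaround so batched training does not error out
# (assumption: acceptable here since the data is short FAQ text).
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token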
# Prepare model for QLoRA
model = prepare_model_for_kbit_training(model)
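# (This peft helper freezes the quantized base weights, upcasts norm layers to
# fp32, and enables gradient checkpointing for stable k-bit training.)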
# LoRA Configuration
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# Load the dataset from the local JSON file (each record is expected to carry
# an "output" field with the response text)
dataset = load_dataset("json", data_files="sevaai_faq.json")
# Rename the "output" column to "text" so SFTTrainer can find it
dataset["train"] = dataset["train"].rename_column("output", "text")
# Training arguments
training_args = TrainingArguments(
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    num_train_epochs=3,
    learning_rate=2e-4,
    logging_steps=10,
    output_dir="./nirmaya",
    save_steps=1000,
    save_total_limit=2,
    optim="adamw_torch",
)
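# Effective batch size is per_device_train_batch_size * gradient_accumulation_steps
# = 4 * 8 = 32 sequences per optimizer step (per device).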
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset["train"],
    peft_config=lora_config,
    tokenizer=tokenizer,  # note: TRL >= 0.12 renames this argument to `processing_class`
    args=training_args,
)
# Train the model
trainer.train()
# Save the fine-tuned LoRA adapter (save_model on a PEFT model writes only the
# adapter weights) plus the tokenizer, so ./nirmaya is self-contained
trainer.save_model("./nirmaya")
tokenizer.save_pretrained("./nirmaya")
print("Fine-tuning complete! Model saved to ./nirmaya")
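
# ---------------------------------------------------------------------------
# Optional: a minimal inference sketch, not part of the original script.
# Assumes the adapter saved above in ./nirmaya and reuses model_name /
# bnb_config / tokenizer from earlier; the prompt string is a made-up example.
# ---------------------------------------------------------------------------
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", quantization_config=bnb_config)
ft_model = PeftModel.from_pretrained(base, "./nirmaya")
inputs = tokenizer("What services does SevaAI offer?", return_tensors="pt").to(ft_model.device)
outputs = ft_model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))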