# lockinaiv2 / train.py
from transformers import Trainer, TrainingArguments, AutoModelForCausalLM, AutoTokenizer, TrainerCallback
from datasets import load_dataset
import torch
import os
import psutil
import gc
# Memory management and environment setup
def cleanup_memory():
    gc.collect()
    # Empty the cached allocator on whichever backend is actually available
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
# Set MPS memory limits and environment variables
# Note: Changed watermark ratio to a more conservative value
os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.7' # Changed from 0.8
os.environ['PYTORCH_MPS_LOW_WATERMARK_RATIO'] = '0.5' # Added explicit low watermark
os.environ['PYTORCH_MPS_ALLOCATOR_POLICY'] = 'garbage_collection_conservative'
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
# Memory monitoring
def print_memory_stats():
    process = psutil.Process()
    print(f"RAM Memory usage: {process.memory_info().rss / 1024 / 1024:.2f} MB")
    if hasattr(torch.mps, 'current_allocated_memory'):
        print(f"MPS Memory allocated: {torch.mps.current_allocated_memory() / 1024 / 1024:.2f} MB")
# Custom callback for memory monitoring
class MemoryCallback(TrainerCallback):
    def __init__(self, print_memory_stats_fn):
        self.print_memory_stats_fn = print_memory_stats_fn

    def on_step_end(self, args, state, control, **kwargs):
        # Every 100 steps, report memory usage and release cached memory
        if state.global_step % 100 == 0:
            print(f"\nStep {state.global_step}:")
            self.print_memory_stats_fn()
            cleanup_memory()
# Set device
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(f"Using device: {device}")
# Load model and tokenizer
model_name = "distilgpt2"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    use_cache=False,
    torch_dtype=torch.float32
)
model.to(device) # Explicitly move model to device
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Add pad token
tokenizer.pad_token = tokenizer.eos_token
# Load and filter dataset
train_data = load_dataset("json", data_files={"train": "data.json"})
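# The contents of data.json are not shown here; based on the fields accessed
# below, each record is assumed to be a JSON object of roughly this shape
# (field values illustrative only):
#   {"prompt": "instruction or question text", "completion": "target response text"}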
def filter_dataset(example):
    # Keep examples whose combined prompt + completion length (in characters,
    # not tokens) fits within 512 characters
    return len(example["prompt"]) + len(example["completion"]) <= 512
train_data = train_data.filter(filter_dataset)
# Preprocess function
def preprocess_function(examples):
    # Concatenate each prompt and completion, separated by the EOS token
    inputs = [prompt + tokenizer.eos_token + completion
              for prompt, completion in zip(examples["prompt"], examples["completion"])]
    model_inputs = tokenizer(
        inputs,
        max_length=256,
        truncation=True,
        padding="max_length"
    )
    # Causal LM objective: labels are a copy of the input IDs. Pad positions are
    # not masked to -100 here (pad token == eos token), so padding also
    # contributes to the loss.
    model_inputs["labels"] = model_inputs["input_ids"].copy()
    return model_inputs
# Preprocess the dataset
train_dataset = train_data["train"].map(preprocess_function, batched=True)
# Training arguments
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=15,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,  # Reduced from 32
    logging_dir="./logs",
    fp16=False,
    eval_strategy="no",
    learning_rate=1e-5,  # Reduced from 5e-5
    save_steps=100,
    save_total_limit=2,
    gradient_checkpointing=True,
    optim="adamw_torch",
    dataloader_num_workers=0,
    dataloader_pin_memory=False,
    torch_compile=False,
    max_grad_norm=1.0,  # Increased from 0.5
    logging_steps=5,  # More frequent logging
    max_steps=1000,  # Overrides num_train_epochs once reached
    warmup_steps=300,  # Increased warmup steps; takes precedence over warmup_ratio
    weight_decay=0.2,  # Increased from 0.01
    logging_first_step=True,
    lr_scheduler_type="cosine_with_restarts",  # Changed to cosine with restarts
    warmup_ratio=0.15,  # Only used when warmup_steps == 0, so inactive here
)
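# With per_device_train_batch_size=1 and gradient_accumulation_steps=8, each
# optimizer step accumulates gradients over 8 examples (effective batch size 8),
# so max_steps=1000 corresponds to roughly 8,000 training examples processed.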
# Clear cache before training
cleanup_memory()
# Initialize trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    callbacks=[MemoryCallback(print_memory_stats)]
)
# Monitor initial memory usage
print("Initial memory usage:")
print_memory_stats()
# Training with error handling
try:
    trainer.train()
except Exception as e:
    print(f"Training error: {str(e)}")
    cleanup_memory()
    # Try to save whatever progress was made before re-raising
    try:
        model.save_pretrained("./lockin_model_partial")
        tokenizer.save_pretrained("./lockin_model_partial")
        print("Saved partial progress")
    except Exception:
        print("Could not save partial progress")
    raise
finally:
    cleanup_memory()
# Save the complete model
try:
    model.save_pretrained("./lockin_model")
    tokenizer.save_pretrained("./lockin_model")
    print("Model saved successfully")
except Exception as e:
    print(f"Error saving model: {str(e)}")
# Final cleanup
cleanup_memory()
print("\nFinal memory usage:")
print_memory_stats()
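# Optional post-training sanity check (a sketch, not part of the original script):
# the saved model can be reloaded with the standard transformers pipeline API and
# used for a quick generation test. The prompt below is purely illustrative.
#
# from transformers import pipeline
# generator = pipeline("text-generation", model="./lockin_model", tokenizer="./lockin_model")
# print(generator("Example prompt", max_new_tokens=50)[0]["generated_text"])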