from transformers import Trainer, TrainingArguments, AutoModelForCausalLM, AutoTokenizer, TrainerCallback
from datasets import load_dataset
import torch
import os
import psutil
import gc


def cleanup_memory():
    """Release Python garbage and any cached MPS/CUDA memory."""
    gc.collect()
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


# MPS allocator tuning: cap how much unified memory PyTorch may claim, use a
# conservative garbage-collection policy, and fall back to CPU for ops the MPS
# backend does not implement. Set these before any tensors touch the MPS device.
os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.7'
os.environ['PYTORCH_MPS_LOW_WATERMARK_RATIO'] = '0.5'
os.environ['PYTORCH_MPS_ALLOCATOR_POLICY'] = 'garbage_collection_conservative'
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

def print_memory_stats():
    """Print process RSS and, when available, memory allocated on the MPS device."""
    process = psutil.Process()
    print(f"RAM Memory usage: {process.memory_info().rss / 1024 / 1024:.2f} MB")
    if hasattr(torch.mps, 'current_allocated_memory'):
        print(f"MPS Memory allocated: {torch.mps.current_allocated_memory() / 1024 / 1024:.2f} MB")


class MemoryCallback(TrainerCallback):
    """Log memory usage and free caches every 100 optimizer steps."""

    def __init__(self, print_memory_stats_fn):
        self.print_memory_stats_fn = print_memory_stats_fn

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % 100 == 0:
            print(f"\nStep {state.global_step}:")
            self.print_memory_stats_fn()
            cleanup_memory()

device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(f"Using device: {device}")


model_name = "distilgpt2"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    use_cache=False,           # KV cache is unnecessary for training and clashes with gradient checkpointing
    torch_dtype=torch.float32
)
model.to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# GPT-2 has no dedicated pad token, so reuse the EOS token for padding.
tokenizer.pad_token = tokenizer.eos_token
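# Rough footprint diagnostic: parameter count and its fp32 size. Illustrative
# only; it ignores optimizer state and activation memory.
n_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {n_params / 1e6:.1f}M (~{n_params * 4 / 1024 / 1024:.0f} MB in fp32)")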
train_data = load_dataset("json", data_files={"train": "data.json"})


def filter_dataset(example):
    # Pre-tokenization filter: keep examples whose combined prompt and
    # completion are at most 512 characters (characters, not tokens).
    return len(example["prompt"]) + len(example["completion"]) <= 512


train_data = train_data.filter(filter_dataset)
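# Report how much data survives the length filter (assumes a "train" split).
print(f"Training examples after filtering: {len(train_data['train'])}")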
def preprocess_function(examples):
    # Join prompt and completion with an EOS separator, then tokenize to a
    # fixed length so the default collator can batch without re-padding.
    inputs = [prompt + tokenizer.eos_token + completion
              for prompt, completion in zip(examples["prompt"], examples["completion"])]

    model_inputs = tokenizer(
        inputs,
        max_length=256,
        truncation=True,
        padding="max_length"
    )

    # Use the input ids as labels, masking padded positions with -100 so the
    # loss ignores them (the pad token is EOS here and would otherwise be trained on).
    model_inputs["labels"] = [
        [token if mask == 1 else -100 for token, mask in zip(ids, attn)]
        for ids, attn in zip(model_inputs["input_ids"], model_inputs["attention_mask"])
    ]
    return model_inputs


train_dataset = train_data["train"].map(preprocess_function, batched=True)
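# Optional sanity check (assumes at least one example): decode the first
# preprocessed row to confirm the prompt/EOS/completion layout survived
# tokenization and truncation.
print(tokenizer.decode(train_dataset[0]["input_ids"], skip_special_tokens=False)[:200])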
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=15,               # upper bound only; max_steps below takes precedence
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,     # effective batch size of 8 sequences per optimizer step
    logging_dir="./logs",
    fp16=False,                        # keep full precision; fp16 mixed precision is not used on MPS here
    eval_strategy="no",
    learning_rate=1e-5,
    save_steps=100,
    save_total_limit=2,
    gradient_checkpointing=True,       # trade extra compute for lower activation memory
    optim="adamw_torch",
    dataloader_num_workers=0,
    dataloader_pin_memory=False,
    torch_compile=False,
    max_grad_norm=1.0,
    logging_steps=5,
    max_steps=1000,                    # hard stop after 1000 optimizer steps
    warmup_steps=300,                  # takes precedence over warmup_ratio below
    weight_decay=0.2,
    logging_first_step=True,
    lr_scheduler_type="cosine_with_restarts",
    warmup_ratio=0.15,                 # ignored while warmup_steps > 0
)
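# Back-of-the-envelope throughput check: tokens seen per optimizer step and
# over the whole run, given the fixed 256-token sequences above. Illustrative only.
tokens_per_step = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * 256
print(f"Tokens per optimizer step: {tokens_per_step}")
print(f"Tokens over {training_args.max_steps} steps: {tokens_per_step * training_args.max_steps:,}")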
cleanup_memory()

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    callbacks=[MemoryCallback(print_memory_stats)]
)


print("Initial memory usage:")
print_memory_stats()

try:
    trainer.train()
except Exception as e:
    print(f"Training error: {e}")
    cleanup_memory()
    # Best-effort save of whatever progress was made before re-raising.
    try:
        model.save_pretrained("./lockin_model_partial")
        tokenizer.save_pretrained("./lockin_model_partial")
        print("Saved partial progress")
    except Exception:
        print("Could not save partial progress")
    raise
finally:
    cleanup_memory()

try:
    model.save_pretrained("./lockin_model")
    tokenizer.save_pretrained("./lockin_model")
    print("Model saved successfully")
except Exception as e:
    print(f"Error saving model: {e}")


cleanup_memory()
print("\nFinal memory usage:")
print_memory_stats()
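# Optional smoke test with the fine-tuned weights still in memory. A minimal
# sketch with an illustrative prompt; adjust max_new_tokens as needed.
model.eval()
test_prompt = "Example prompt:"
inputs = tokenizer(test_prompt, return_tensors="pt").to(device)
with torch.no_grad():
    output_ids = model.generate(
        **inputs,
        max_new_tokens=50,
        pad_token_id=tokenizer.eos_token_id,
    )
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))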