"""Train a causal language model with Hugging Face Accelerate.

Reads a JSON config (``--config``), builds either the in-house
``ResearchTransformer`` or a pretrained ``AutoModelForCausalLM`` baseline,
trains it on a Hugging Face dataset with optional gradient accumulation and
early stopping, and writes per-epoch metrics plus generated samples to a JSON
results file. ``--resume`` restarts from the saved checkpoint, which is only
written when validation loss improves.
"""

import argparse
import json
import math
import os
import time
import warnings
from dataclasses import dataclass
from typing import Optional

import torch
from accelerate import Accelerator
from huggingface_hub import login
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer

from models.research_model import ModelConfig, ResearchTransformer

# Log in to the Hugging Face Hub at import time when HF_TOKEN is set, so later
# Hub downloads (tokenizer, model, dataset) can use it.
token = os.environ.get("HF_TOKEN")
if token:
    login(token)

# Label value excluded from the loss: the default ``ignore_index`` of torch's
# cross-entropy and the convention used by transformers' causal-LM heads.
IGNORE_INDEX = -100

# Cap applied to the loss before exponentiating, so perplexity cannot overflow.
_MAX_LOG_PPL = 20.0


def save_checkpoint(acc: Accelerator, model, optimizer, ckpt_path: str, epoch: int, step: int, extra: dict):
    """Save model/optimizer state and training progress to ``ckpt_path``.

    Call this on every process. ``acc.get_state_dict`` may be a collective
    operation (e.g. DeepSpeed ZeRO-3 gathers sharded weights), so it runs on
    all ranks. Only the main process writes the file, then all ranks
    synchronise.

    Args:
        acc: The Accelerator that prepared ``model`` and ``optimizer``.
        model: The (possibly wrapped) model; its unwrapped state dict is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        ckpt_path: Destination file; its parent directory is created if needed.
        epoch: Number of completed epochs (the epoch to resume from).
        step: Global micro-batch step counter.
        extra: Additional metadata stored verbatim (e.g. ``{"best_val": ...}``).
    """
    model_state = acc.get_state_dict(model)
    if acc.is_main_process:
        # dirname() is "" for a bare filename, and makedirs("") would raise.
        os.makedirs(os.path.dirname(ckpt_path) or ".", exist_ok=True)
        torch.save(
            {
                "model": model_state,
                "optimizer": optimizer.state_dict(),
                "epoch": epoch,
                "step": step,
                "extra": extra,
            },
            ckpt_path,
        )
    acc.wait_for_everyone()


def load_checkpoint(model, optimizer, ckpt_path: str):
    """Restore model/optimizer state written by ``save_checkpoint``.

    Args:
        model: Model to load into; pass the unwrapped module. A
            ``DistributedDataParallel``/``DataParallel`` wrapper is unwrapped
            here, because its ``module.`` key prefix would make every saved
            key mismatch and ``strict=False`` would silently skip them all.
        optimizer: Optimizer to restore.
        ckpt_path: Path of the checkpoint file.

    Returns:
        ``(epoch, step, extra)`` as saved, defaulting to ``(0, 0, {})``.
    """
    ckpt = torch.load(ckpt_path, map_location="cpu")
    if isinstance(model, (torch.nn.parallel.DistributedDataParallel, torch.nn.DataParallel)):
        model = model.module
    result = model.load_state_dict(ckpt["model"], strict=False)
    missing = getattr(result, "missing_keys", [])
    unexpected = getattr(result, "unexpected_keys", [])
    if missing or unexpected:
        # strict=False tolerates mismatches; surface them instead of hiding them.
        warnings.warn(
            f"Checkpoint {ckpt_path!r} does not match the model exactly: "
            f"{len(missing)} missing keys kept their current values and "
            f"{len(unexpected)} unexpected keys were ignored."
        )
    optimizer.load_state_dict(ckpt["optimizer"])
    return ckpt.get("epoch", 0), ckpt.get("step", 0), ckpt.get("extra", {})


def build_tokenizer(name: str):
    """Load tokenizer ``name``, reusing EOS as the pad token when none is set.

    ``collate_batch`` pads to a fixed length, which needs a pad token.
    """
    tok = AutoTokenizer.from_pretrained(name)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    return tok


def collate_batch(examples, tokenizer, block_size: int):
    """Tokenize dataset rows into fixed-length causal-LM tensors.

    Each row contributes its ``"text"`` field when non-empty, otherwise its
    first string-valued field (``""`` if it has none). Sequences are
    truncated and padded to exactly ``block_size`` tokens.

    Returns:
        Dict with ``input_ids``, ``attention_mask`` and ``labels``, each of
        shape ``(len(examples), block_size)``. ``labels`` copies ``input_ids``
        with padded positions set to ``IGNORE_INDEX`` so they add no loss.
    """
    texts = [
        ex.get("text") or next((v for v in ex.values() if isinstance(v, str)), "")
        for ex in examples
    ]
    toks = tokenizer(
        texts,
        padding="max_length",
        truncation=True,
        max_length=block_size,
        return_tensors="pt",
    )
    input_ids = toks["input_ids"]
    attention_mask = toks["attention_mask"]
    labels = input_ids.clone()
    # The pad token may be EOS (see build_tokenizer). Left unmasked, the loss
    # would train the model to emit EOS at every padded position and distort
    # validation loss/perplexity.
    # NOTE(review): assumes ResearchTransformer's loss ignores -100 like
    # torch.nn.functional.cross_entropy's default -- confirm.
    labels[attention_mask == 0] = IGNORE_INDEX
    return {"input_ids": input_ids, "labels": labels, "attention_mask": attention_mask}


def _perplexity(loss: float) -> float:
    """Return exp(loss), with the loss capped at ``_MAX_LOG_PPL``."""
    return math.exp(min(_MAX_LOG_PPL, loss))


def _load_splits(dataset_name: str):
    """Load ``dataset_name`` and return ``(train, validation)`` datasets.

    Uses the ``"train"`` split, or the first available split if there is no
    ``"train"``, and holds out 5% (seed 42) for validation. If the dataset
    object cannot be split, it is used for both training and validation.
    """
    from datasets import load_dataset

    raw = load_dataset(dataset_name)
    if isinstance(raw, dict):  # DatasetDict subclasses dict
        ds = raw["train"] if "train" in raw else next(iter(raw.values()))
    else:
        ds = raw
    if hasattr(ds, "train_test_split"):
        split = ds.train_test_split(test_size=0.05, seed=42)
        return split["train"], split["test"]
    return ds, ds


def _build_model(cfg: dict, model_arch: str, tokenizer_name: str, vocab_size: int, block_size: int):
    """Build the pretrained baseline or a fresh ResearchTransformer from ``cfg``."""
    if model_arch == "Official Gemma (Baseline)":
        return AutoModelForCausalLM.from_pretrained(tokenizer_name)
    mc = ModelConfig(
        vocab_size=vocab_size,
        n_layer=int(cfg.get("n_layer", 6)),
        n_head=int(cfg.get("n_head", 8)),
        n_embd=int(cfg.get("n_embd", 512)),
        block_size=block_size,
        dropout=float(cfg.get("dropout", 0.1)),
    )
    return ResearchTransformer(mc)


def _evaluate(acc: Accelerator, model, val_loader, max_batches: Optional[int]):
    """Return ``(mean_loss, perplexity)`` over the validation loader.

    Per-batch losses are gathered across processes and weighted by batch
    size. At most ``max_batches`` batches are used when it is set. Returns
    ``(nan, nan)`` if the loader yields nothing. The model's train/eval mode
    is restored afterwards.
    """
    was_training = model.training
    model.eval()
    losses = []
    with torch.no_grad():
        for i, batch in enumerate(val_loader):
            out = model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"],
            )
            # One copy of the batch-mean loss per row, so gather_for_metrics
            # can drop padded duplicates and the mean is weighted by batch size.
            n_rows = batch["input_ids"].size(0)
            losses.append(acc.gather_for_metrics(out.loss.detach().repeat(n_rows)))
            if max_batches and i + 1 >= max_batches:
                break
    model.train(was_training)
    if not losses:
        return float("nan"), float("nan")
    loss = torch.cat(losses).mean().item()
    return loss, _perplexity(loss)


def _sample_text(acc: Accelerator, model, tokenizer, prompt: str = "Once upon a time") -> str:
    """Generate up to 64 new tokens after ``prompt`` and return the decoded text.

    Generation uses the unwrapped model, since distributed wrappers do not
    expose ``generate``. The model's previous train/eval mode is restored
    afterwards, so dropout stays active for the rest of training.
    """
    base = acc.unwrap_model(model)
    was_training = base.training
    base.eval()
    try:
        with torch.no_grad():
            ids = tokenizer(prompt, return_tensors="pt").input_ids.to(acc.device)
            gen = base.generate(ids, max_new_tokens=64)
    finally:
        base.train(was_training)
    return tokenizer.decode(gen[0], skip_special_tokens=True)


def _peak_cuda_mem_gb() -> Optional[float]:
    """Peak CUDA memory allocated on the current device in GiB, or None without CUDA."""
    if not torch.cuda.is_available():
        return None
    return torch.cuda.max_memory_allocated() / (1024 ** 3)


def _write_results(acc: Accelerator, results: dict, results_file: str) -> None:
    """Write ``results`` as indented JSON from the main process only."""
    if acc.is_main_process:
        with open(results_file, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)


def main():
    """Parse CLI flags, train the model described by the JSON config, and write results.

    Flags: ``--config PATH`` (required) and ``--resume``. Optimiser settings
    (``epochs``, ``learning_rate``, ``weight_decay``) live under the config's
    ``"params"`` key; all other settings are top-level keys.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", type=str, required=True)
    ap.add_argument("--resume", action="store_true")
    args = ap.parse_args()
    with open(args.config, "r", encoding="utf-8") as f:
        cfg = json.load(f)

    acc = Accelerator()
    acc.print("Accelerator initialized.")

    model_arch = cfg.get("model_architecture", "ResearchTransformer (Experimental)")
    dataset_name = cfg.get("dataset_name", "stas/tiny-stories")
    tokenizer_name = cfg.get("tokenizer_name", "gpt2")
    block_size = int(cfg.get("block_size", 256))
    batch_size = int(cfg.get("batch_size", 8))
    max_batches_per_epoch = int(cfg.get("max_batches_per_epoch", 0)) or None
    params = cfg.get("params", {})
    epochs = int(params.get("epochs", 1))
    lr = float(params.get("learning_rate", 5e-5))
    wd = float(params.get("weight_decay", 0.01))
    accum_steps = int(cfg.get("accum_steps", 1))
    if accum_steps < 1:
        raise ValueError(f"accum_steps must be >= 1, got {accum_steps}")
    patience = int(cfg.get("patience", 3))
    results_file = cfg.get("results_file", "results.json")
    results_dir = os.path.dirname(results_file) or "."
    ckpt_path = cfg.get("checkpoint_path", os.path.join(results_dir, "checkpoint.pt"))
    sample_every = int(cfg.get("sample_every_steps", 200))

    tokenizer = build_tokenizer(tokenizer_name)
    # len() includes added special tokens; tokenizer.vocab_size does not, and an
    # embedding sized from it could be too small for those token ids.
    vocab_size = int(cfg.get("vocab_size", len(tokenizer)))
    model = _build_model(cfg, model_arch, tokenizer_name, vocab_size, block_size)

    # Let the main process download/cache the dataset before the others read it.
    with acc.main_process_first():
        train_ds, val_ds = _load_splits(dataset_name)

    def collate(examples):
        return collate_batch(examples, tokenizer, block_size)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, collate_fn=collate)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    model, optimizer, train_loader, val_loader = acc.prepare(model, optimizer, train_loader, val_loader)

    start_epoch = 0
    global_step = 0
    best_val = float("inf")
    if args.resume and os.path.exists(ckpt_path):
        # Checkpoint keys carry no wrapper prefix, so load into the bare module.
        start_epoch, global_step, extra = load_checkpoint(acc.unwrap_model(model), optimizer, ckpt_path)
        # Restore the best score so a worse epoch cannot overwrite the checkpoint.
        best_val = float(extra.get("best_val", best_val))
        acc.print(f"Resumed from checkpoint at epoch {start_epoch}, step {global_step}")

    os.makedirs(results_dir, exist_ok=True)
    results = {"config": cfg, "status": "running", "history": [], "samples": []}

    bad_epochs = 0
    start_time = time.time()
    for epoch in range(start_epoch, epochs):
        model.train()
        epoch_start = time.time()
        optimizer.zero_grad()
        running_loss = 0.0
        num_batches = 0
        for batch in train_loader:
            out = model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"],
            )
            # Scale so the accumulated gradient averages over the window.
            acc.backward(out.loss / accum_steps)
            num_batches += 1
            if num_batches % accum_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
            running_loss += out.loss.detach().item()
            global_step += 1
            if sample_every and global_step % sample_every == 0 and acc.is_main_process:
                results["samples"].append({"step": global_step, "text": _sample_text(acc, model, tokenizer)})
            if max_batches_per_epoch and num_batches >= max_batches_per_epoch:
                break
        # Apply gradients left over from a trailing, incomplete accumulation window.
        if num_batches % accum_steps != 0:
            optimizer.step()
            optimizer.zero_grad()
        train_time = time.time() - epoch_start

        val_loss, val_ppl = _evaluate(acc, model, val_loader, max_batches_per_epoch)
        results["history"].append({
            "epoch": epoch + 1,
            "train_time_sec": train_time,
            # Mean over this process's batches only (not gathered across ranks).
            "train_loss": running_loss / num_batches if num_batches else None,
            "val_loss": val_loss,
            "val_ppl": val_ppl,
            "max_cuda_mem_gb": _peak_cuda_mem_gb(),
            "effective_batch_size": batch_size * accum_steps * acc.num_processes,
        })

        if val_loss < best_val - 1e-5:
            best_val = val_loss
            bad_epochs = 0
            # Runs on every rank: val_loss is gathered, so all ranks agree here.
            save_checkpoint(acc, model, optimizer, ckpt_path, epoch + 1, global_step, {"best_val": best_val})
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                acc.print("Early stopping triggered.")
                break

        _write_results(acc, results, results_file)

    total = time.time() - start_time
    results["status"] = "completed"
    results["total_training_time_sec"] = total
    results["final_validation"] = {"loss": best_val, "perplexity": _perplexity(best_val)}
    _write_results(acc, results, results_file)
    acc.print(f"Done in {total/60:.1f} min. Best val {best_val:.4f}")


if __name__ == "__main__":
    main()