import argparse, json, math, os, time
from dataclasses import dataclass
from typing import Optional

import torch
from accelerate import Accelerator
from transformers import AutoTokenizer, AutoModelForCausalLM

from models.research_model import ResearchTransformer, ModelConfig
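
# NOTE: models.research_model is a project-local module. Based on how it is used below,
# this script assumes that ModelConfig accepts (vocab_size, n_layer, n_head, n_embd,
# block_size, dropout), that calling the model with (input_ids, attention_mask, labels)
# returns an object with a .loss attribute, and that the model implements .generate().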


def save_checkpoint(acc: Accelerator, model, optimizer, ckpt_path: str, epoch: int, step: int, extra: dict):
    # Only the main process writes the checkpoint.
    if acc.is_main_process:
        os.makedirs(os.path.dirname(ckpt_path) or ".", exist_ok=True)
        state = {
            "model": acc.get_state_dict(model),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
            "step": step,
            "extra": extra,
        }
        torch.save(state, ckpt_path)


def load_checkpoint(model, optimizer, ckpt_path: str):
    ckpt = torch.load(ckpt_path, map_location="cpu")
    # strict=False so missing/unexpected keys do not abort the load.
    model.load_state_dict(ckpt["model"], strict=False)
    optimizer.load_state_dict(ckpt["optimizer"])
    return ckpt.get("epoch", 0), ckpt.get("step", 0), ckpt.get("extra", {})


def build_tokenizer(name: str):
    tok = AutoTokenizer.from_pretrained(name)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    return tok


def collate_batch(examples, tokenizer, block_size: int):
    # Fall back to the first string field if the example has no "text" key.
    texts = [ex.get("text") or next((v for v in ex.values() if isinstance(v, str)), "") for ex in examples]
    toks = tokenizer(texts, padding="max_length", truncation=True, max_length=block_size, return_tensors="pt")
    input_ids = toks["input_ids"]
    labels = input_ids.clone()
    # Mask padding positions out of the loss (assumes the model follows the HF convention of ignore_index=-100).
    labels[toks["attention_mask"] == 0] = -100
    return {"input_ids": input_ids, "labels": labels, "attention_mask": toks["attention_mask"]}


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", type=str, required=True)
    ap.add_argument("--resume", action="store_true")
    args = ap.parse_args()

    with open(args.config, "r") as f:
        cfg = json.load(f)

    acc = Accelerator()
    acc.print("Accelerator initialized.")

    model_arch = cfg.get("model_architecture", "ResearchTransformer (Experimental)")
    dataset_name = cfg.get("dataset_name", "stas/tiny-stories")
    tokenizer_name = cfg.get("tokenizer_name", "gpt2")
    block_size = int(cfg.get("block_size", 256))
    batch_size = int(cfg.get("batch_size", 8))
    max_batches_per_epoch = int(cfg.get("max_batches_per_epoch", 0)) or None
    params = cfg.get("params", {})
    epochs = int(params.get("epochs", 1))
    lr = float(params.get("learning_rate", 5e-5))
    wd = float(params.get("weight_decay", 0.01))
    accum_steps = int(cfg.get("accum_steps", 1))
    results_file = cfg.get("results_file", "results.json")
    ckpt_path = cfg.get("checkpoint_path", os.path.join(os.path.dirname(results_file) or ".", "checkpoint.pt"))
    sample_every = int(cfg.get("sample_every_steps", 200))
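
    # Example config (illustrative sketch): key names mirror the cfg.get(...) calls above;
    # values are placeholders, not tuned recommendations. n_layer/n_head/n_embd/dropout and
    # vocab_size (read further below) may also be set here.
    # {
    #   "model_architecture": "ResearchTransformer (Experimental)",
    #   "dataset_name": "stas/tiny-stories",
    #   "tokenizer_name": "gpt2",
    #   "block_size": 256,
    #   "batch_size": 8,
    #   "accum_steps": 1,
    #   "params": {"epochs": 3, "learning_rate": 5e-5, "weight_decay": 0.01},
    #   "results_file": "results/run1/results.json",
    #   "checkpoint_path": "results/run1/checkpoint.pt",
    #   "sample_every_steps": 200
    # }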

    tokenizer = build_tokenizer(tokenizer_name)
    vocab_size = int(cfg.get("vocab_size", getattr(tokenizer, "vocab_size", 65536) or 65536))

    if model_arch == "Official Gemma (Baseline)":
        model = AutoModelForCausalLM.from_pretrained(tokenizer_name)
    else:
        mc = ModelConfig(
            vocab_size=vocab_size,
            n_layer=int(cfg.get("n_layer", 6)),
            n_head=int(cfg.get("n_head", 8)),
            n_embd=int(cfg.get("n_embd", 512)),
            block_size=block_size,
            dropout=float(cfg.get("dropout", 0.1)),
        )
        model = ResearchTransformer(mc)

    from datasets import load_dataset
    raw = load_dataset(dataset_name)
    if "train" not in raw:
        raw = {"train": raw}
    ds = raw["train"]
    split = ds.train_test_split(test_size=0.05, seed=42) if hasattr(ds, "train_test_split") else {"train": ds, "test": ds}
    train_ds, val_ds = split["train"], split["test"]

    from torch.utils.data import DataLoader

    def collate(examples):
        return collate_batch(examples, tokenizer, block_size)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, collate_fn=collate)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    model, optimizer, train_loader, val_loader = acc.prepare(model, optimizer, train_loader, val_loader)

    start_epoch = 0
    global_step = 0
    if args.resume and os.path.exists(ckpt_path):
        # Load into the unwrapped model so the keys match the unwrapped state dict saved via acc.get_state_dict().
        start_epoch, global_step, _ = load_checkpoint(acc.unwrap_model(model), optimizer, ckpt_path)
        acc.print(f"Resumed from checkpoint at epoch {start_epoch}, step {global_step}")

    os.makedirs(os.path.dirname(results_file) or ".", exist_ok=True)
    results = {"config": cfg, "status": "running", "history": [], "samples": []}

    def evaluate():
        model.eval()
        losses = []
        with torch.no_grad():
            for i, batch in enumerate(val_loader):
                out = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
                losses.append(acc.gather_for_metrics(out.loss.detach().repeat(batch["input_ids"].size(0))))
                if max_batches_per_epoch and i + 1 >= max_batches_per_epoch:
                    break
        loss = torch.cat(losses).mean().item()
        ppl = math.exp(min(20.0, loss))
        return loss, ppl

    def sample_text(prompt: str = "Once upon a time"):
        model.eval()
        with torch.no_grad():
            ids = tokenizer(prompt, return_tensors="pt").input_ids.to(acc.device)
            # Generate on the unwrapped model: distributed wrappers do not expose .generate().
            gen = acc.unwrap_model(model).generate(ids, max_new_tokens=64)
            text = tokenizer.decode(gen[0], skip_special_tokens=True)
        return text

    best_val = float("inf")
    patience, bad_epochs = 3, 0
    start_time = time.time()

    for epoch in range(start_epoch, epochs):
        model.train()
        epoch_start = time.time()
        optimizer.zero_grad()
        running_loss = 0.0
        for i, batch in enumerate(train_loader):
            out = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
            loss = out.loss / accum_steps
            acc.backward(loss)
            if (i + 1) % accum_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
            running_loss += out.loss.detach().item()
            global_step += 1
            if sample_every and global_step % sample_every == 0 and acc.is_main_process:
                results["samples"].append({"step": global_step, "text": sample_text()})
                model.train()  # sample_text() puts the model in eval mode; restore training mode.
            if max_batches_per_epoch and i + 1 >= max_batches_per_epoch:
                break
        if (i + 1) % accum_steps != 0:
            # Flush gradients left over from a trailing partial accumulation window.
            optimizer.step()
            optimizer.zero_grad()

        train_time = time.time() - epoch_start
        val_loss, val_ppl = evaluate()
        try:
            mem = torch.cuda.max_memory_allocated() / (1024 ** 3)
        except Exception:
            mem = None
        results["history"].append({
            "epoch": epoch + 1,
            "train_time_sec": train_time,
            "val_loss": val_loss,
            "val_ppl": val_ppl,
            "max_cuda_mem_gb": mem,
            "effective_batch_size": batch_size * accum_steps,
        })

        improve = val_loss < best_val - 1e-5
        if improve:
            best_val = val_loss
            bad_epochs = 0
            save_checkpoint(acc, model, optimizer, ckpt_path, epoch + 1, global_step, {"best_val": best_val})
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                acc.print("Early stopping triggered.")
                break

        if acc.is_main_process:
            with open(results_file, "w") as f:
                json.dump(results, f, indent=2)

    total = time.time() - start_time
    results["status"] = "completed"
    results["total_training_time_sec"] = total
    results["final_validation"] = {"loss": best_val, "perplexity": math.exp(min(20.0, best_val))}
    if acc.is_main_process:
        with open(results_file, "w") as f:
            json.dump(results, f, indent=2)
    acc.print(f"Done in {total/60:.1f} min. Best val {best_val:.4f}")


if __name__ == "__main__":
    main()
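

# Usage sketch: since the script builds an Accelerator, it is typically launched with
# `accelerate launch` (the script filename and config path below are placeholders):
#
#   accelerate launch train_research.py --config configs/experiment.json
#   accelerate launch train_research.py --config configs/experiment.json --resume
#
# --config and --resume are the only command-line flags defined in main().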