import argparse, json, math, os, time
from dataclasses import dataclass
from typing import Optional

import torch
from accelerate import Accelerator
from transformers import AutoTokenizer, AutoModelForCausalLM

from models.research_model import ResearchTransformer, ModelConfig
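# Illustrative launch commands (the script filename and config path are assumptions, adjust as needed):
#   accelerate launch train.py --config configs/experiment.json
#   accelerate launch train.py --config configs/experiment.json --resume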
def save_checkpoint(acc: Accelerator, model, optimizer, ckpt_path: str, epoch: int, step: int, extra: dict):
    # Gather the state dict on every process (required for sharded setups), then save on main only.
    model_state = acc.get_state_dict(model)
    if acc.is_main_process:
        os.makedirs(os.path.dirname(ckpt_path) or ".", exist_ok=True)
        state = {
            "model": model_state,
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
            "step": step,
            "extra": extra,
        }
        torch.save(state, ckpt_path)
def load_checkpoint(model, optimizer, ckpt_path: str):
    ckpt = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(ckpt["model"], strict=False)
    optimizer.load_state_dict(ckpt["optimizer"])
    return ckpt.get("epoch", 0), ckpt.get("step", 0), ckpt.get("extra", {})
def build_tokenizer(name: str):
    tok = AutoTokenizer.from_pretrained(name)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    return tok
def collate_batch(examples, tokenizer, block_size: int):
    texts = [ex.get("text") or next((v for v in ex.values() if isinstance(v, str)), "") for ex in examples]
    toks = tokenizer(texts, padding="max_length", truncation=True, max_length=block_size, return_tensors="pt")
    input_ids = toks["input_ids"]
    labels = input_ids.clone()
    # Mask padding positions so they do not contribute to the LM loss (pad == eos here).
    labels[toks["attention_mask"] == 0] = -100
    return {"input_ids": input_ids, "labels": labels, "attention_mask": toks["attention_mask"]}
def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", type=str, required=True)
    ap.add_argument("--resume", action="store_true")
    args = ap.parse_args()

    with open(args.config, "r") as f:
        cfg = json.load(f)

    acc = Accelerator()
    acc.print("Accelerator initialized.")
    model_arch = cfg.get("model_architecture", "ResearchTransformer (Experimental)")
    dataset_name = cfg.get("dataset_name", "stas/tiny-stories")
    tokenizer_name = cfg.get("tokenizer_name", "gpt2")
    block_size = int(cfg.get("block_size", 256))
    batch_size = int(cfg.get("batch_size", 8))
    max_batches_per_epoch = int(cfg.get("max_batches_per_epoch", 0)) or None
    params = cfg.get("params", {})
    epochs = int(params.get("epochs", 1))
    lr = float(params.get("learning_rate", 5e-5))
    wd = float(params.get("weight_decay", 0.01))
    accum_steps = int(cfg.get("accum_steps", 1))
    results_file = cfg.get("results_file", "results.json")
    ckpt_path = cfg.get("checkpoint_path", os.path.join(os.path.dirname(results_file) or ".", "checkpoint.pt"))
    sample_every = int(cfg.get("sample_every_steps", 200))

    tokenizer = build_tokenizer(tokenizer_name)
    vocab_size = int(cfg.get("vocab_size", getattr(tokenizer, "vocab_size", 65536) or 65536))
    if model_arch == "Official Gemma (Baseline)":
        # Baseline: load a pretrained HF model from the same repo name as the tokenizer.
        model = AutoModelForCausalLM.from_pretrained(tokenizer_name)
    else:
        mc = ModelConfig(
            vocab_size=vocab_size,
            n_layer=int(cfg.get("n_layer", 6)),
            n_head=int(cfg.get("n_head", 8)),
            n_embd=int(cfg.get("n_embd", 512)),
            block_size=block_size,
            dropout=float(cfg.get("dropout", 0.1)),
        )
        model = ResearchTransformer(mc)
    from datasets import load_dataset
    raw = load_dataset(dataset_name)
    # Use the train split when present, otherwise fall back to the first available split.
    ds = raw["train"] if "train" in raw else raw[next(iter(raw))]
    split = ds.train_test_split(test_size=0.05, seed=42) if hasattr(ds, "train_test_split") else {"train": ds, "test": ds}
    train_ds, val_ds = split["train"], split["test"]
    from torch.utils.data import DataLoader

    def collate(examples):
        return collate_batch(examples, tokenizer, block_size)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, collate_fn=collate)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    model, optimizer, train_loader, val_loader = acc.prepare(model, optimizer, train_loader, val_loader)
    start_epoch = 0
    global_step = 0
    if args.resume and os.path.exists(ckpt_path):
        # Load into the unwrapped module so state-dict keys match regardless of DDP wrapping.
        start_epoch, global_step, _ = load_checkpoint(acc.unwrap_model(model), optimizer, ckpt_path)
        acc.print(f"Resumed from checkpoint at epoch {start_epoch}, step {global_step}")

    os.makedirs(os.path.dirname(results_file) or ".", exist_ok=True)
    results = {"config": cfg, "status": "running", "history": [], "samples": []}
    def evaluate():
        model.eval()
        losses = []
        with torch.no_grad():
            for i, batch in enumerate(val_loader):
                out = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
                # Repeat the batch-mean loss per sample so gather_for_metrics weights uneven batches correctly.
                losses.append(acc.gather_for_metrics(out.loss.detach().repeat(batch["input_ids"].size(0))))
                if max_batches_per_epoch and i + 1 >= max_batches_per_epoch:
                    break
        loss = torch.cat(losses).mean().item()
        ppl = math.exp(min(20.0, loss))  # clamp to avoid overflow on divergent runs
        return loss, ppl
    def sample_text(prompt: str = "Once upon a time"):
        model.eval()
        with torch.no_grad():
            ids = tokenizer(prompt, return_tensors="pt").input_ids.to(acc.device)
            # Unwrap so .generate() is reachable even when the model is wrapped (e.g. by DDP).
            gen = acc.unwrap_model(model).generate(ids, max_new_tokens=64)
        return tokenizer.decode(gen[0], skip_special_tokens=True)
    best_val = float("inf")
    patience, bad_epochs = 3, 0
    start_time = time.time()
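    # Gradient accumulation: each micro-batch contributes loss / accum_steps and the optimizer
    # steps every accum_steps batches, giving an effective batch size of batch_size * accum_steps;
    # any leftover gradients from a partial window are flushed once at the end of the epoch.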
    for epoch in range(start_epoch, epochs):
        model.train()
        epoch_start = time.time()
        optimizer.zero_grad()
        running_loss = 0.0
        for i, batch in enumerate(train_loader):
            out = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
            loss = out.loss / accum_steps
            acc.backward(loss)
            if (i + 1) % accum_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
            running_loss += out.loss.detach().item()
            global_step += 1
            if sample_every and global_step % sample_every == 0 and acc.is_main_process:
                results["samples"].append({"step": global_step, "text": sample_text()})
                model.train()  # sample_text() switches to eval mode; restore training mode
            if max_batches_per_epoch and i + 1 >= max_batches_per_epoch:
                break
        if (i + 1) % accum_steps != 0:
            optimizer.step()
            optimizer.zero_grad()
        train_time = time.time() - epoch_start
        val_loss, val_ppl = evaluate()
        try:
            mem = torch.cuda.max_memory_allocated() / (1024 ** 3)
        except Exception:
            mem = None

        results["history"].append({
            "epoch": epoch + 1,
            "train_time_sec": train_time,
            "train_loss": running_loss / max(1, i + 1),
            "val_loss": val_loss,
            "val_ppl": val_ppl,
            "max_cuda_mem_gb": mem,
            "effective_batch_size": batch_size * accum_steps,
        })
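        # Early stopping: checkpoint only when validation loss improves beyond a small tolerance;
        # stop after `patience` consecutive epochs without improvement.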
        improved = val_loss < best_val - 1e-5
        if improved:
            best_val = val_loss
            bad_epochs = 0
            save_checkpoint(acc, model, optimizer, ckpt_path, epoch + 1, global_step, {"best_val": best_val})
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                acc.print("Early stopping triggered.")
                break

        # Persist intermediate results after every epoch so progress survives interruptions.
        if acc.is_main_process:
            with open(results_file, "w") as f:
                json.dump(results, f, indent=2)
    total = time.time() - start_time
    results["status"] = "completed"
    results["total_training_time_sec"] = total
    results["final_validation"] = {"loss": best_val, "perplexity": math.exp(min(20.0, best_val))}
    if acc.is_main_process:
        with open(results_file, "w") as f:
            json.dump(results, f, indent=2)
    acc.print(f"Done in {total/60:.1f} min. Best val {best_val:.4f}")


if __name__ == "__main__":
    main()