Upload 7 files

- .gitattributes +35 -35
- README.md +20 -14
- data.py +50 -0
- main.py +5 -0
- requirements-colab.txt +2 -0
- requirements.txt +11 -0
- train.py +210 -0
.gitattributes
CHANGED
@@ -1,35 +1,35 @@

Both sides of the diff are textually identical — the file holds the standard Hugging Face LFS patterns:

```
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
```
README.md
CHANGED
```diff
@@ -1,14 +1,20 @@
----
-title:
-emoji:
-colorFrom:
-colorTo:
-sdk: gradio
-sdk_version: 5.44.0
-app_file:
-pinned: false
-
-
-
-
-
+---
+title: LLM Algorithm Lab
+emoji: 🧪
+colorFrom: indigo
+colorTo: blue
+sdk: gradio
+sdk_version: 5.44.0
+app_file: main.py
+pinned: false
+---
+
+# Scientific LLM Algorithm Laboratory — Refactor (Full)
+
+This repository contains the full refactor with:
+- Hugging Face Spaces demo UI (toy runs)
+- Colab UI with full hyperparameters
+- Secure GitHub pushing (token via env)
+- Robust dataloader and training orchestrator
+
+See app/ for UI and core orchestrator. Use requirements-colab.txt for Colab.
```
data.py
ADDED
```python
from typing import Tuple

from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, DataCollatorForLanguageModeling


def build_dataloaders(
    dataset_name: str,
    tokenizer_name: str,
    batch_size: int,
    val_split: float = 0.05,
    block_size: int = 512,
    num_workers: int = 2,
) -> Tuple:
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # GPT-style tokenizers ship without a pad token

    # load_dataset returns a DatasetDict; fall back to the object itself if
    # there is no explicit 'train' split.
    raw = load_dataset(dataset_name)
    ds = raw["train"] if "train" in raw else raw

    split = ds.train_test_split(test_size=val_split, seed=42) if hasattr(ds, "train_test_split") else {"train": ds, "test": ds}
    train_ds, val_ds = split["train"], split["test"]

    def text_key(example):
        # Pick the first string-valued column as the text field.
        for k in example.keys():
            if isinstance(example[k], str):
                return k
        return None

    sample = train_ds[0]
    tkey = text_key(sample) or "text"

    def tokenize(ex):
        return tokenizer(ex[tkey], truncation=True, padding="max_length", max_length=block_size)

    train_tok = train_ds.map(tokenize, batched=True, remove_columns=train_ds.column_names)
    val_tok = val_ds.map(tokenize, batched=True, remove_columns=val_ds.column_names)

    def labelize(batch):
        # Copy input_ids to labels, masking pad positions with -100 so the loss
        # ignores them. (DataCollatorForLanguageModeling with mlm=False rebuilds
        # labels the same way at collate time, so this mainly keeps the mapped
        # datasets self-describing.)
        batch["labels"] = [
            [(-100 if token == tokenizer.pad_token_id else token) for token in ids]
            for ids in batch["input_ids"]
        ]
        return batch

    train_tok = train_tok.map(labelize, batched=True)
    val_tok = val_tok.map(labelize, batched=True)

    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    train_loader = DataLoader(train_tok, batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=collator)
    val_loader = DataLoader(val_tok, batch_size=max(2, batch_size), shuffle=False, num_workers=num_workers, collate_fn=collator)

    return tokenizer, train_loader, val_loader
```
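A minimal usage sketch for `build_dataloaders` — the dataset and tokenizer names below mirror train.py's defaults and are illustrative, not mandated by data.py itself:

```python
# Usage sketch for data.py (names follow train.py's defaults; any HF dataset
# with at least one string column should work).
from data import build_dataloaders

tokenizer, train_loader, val_loader = build_dataloaders(
    dataset_name="stas/tiny-stories",
    tokenizer_name="gpt2",
    batch_size=8,
    block_size=256,
)

batch = next(iter(train_loader))
print(batch["input_ids"].shape)          # torch.Size([8, 256])
print((batch["labels"] == -100).any())   # pad positions are masked out of the loss
```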
main.py
ADDED
```python
from app.ui.ui_spaces import build as build_space

if __name__ == "__main__":
    app = build_space()
    app.launch()
```
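The `app.ui.ui_spaces` module is not part of this seven-file upload, so the only contract visible here is that `build()` returns an object with a `.launch()` method. A hypothetical stub, assuming a Gradio Blocks app:

```python
# Hypothetical sketch of app/ui/ui_spaces.py — NOT the actual module, which
# lives outside this upload; it only illustrates the interface main.py expects.
import gradio as gr

def build() -> gr.Blocks:
    with gr.Blocks(title="LLM Algorithm Lab") as demo:
        gr.Markdown("# Scientific LLM Algorithm Laboratory")
        # The real Space UI presumably wires toy-run hyperparameters into the
        # training orchestrator described in the README.
    return demo
```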
requirements-colab.txt
ADDED
```
-r requirements.txt
bitsandbytes
```
requirements.txt
ADDED
```
torch
transformers
accelerate
gradio
pandas
datasets
sentencepiece
PyGithub
wandb
huggingface_hub
tenacity
```
train.py
ADDED
```python
import argparse, json, math, os, time

import torch
from accelerate import Accelerator
from transformers import AutoTokenizer, AutoModelForCausalLM

from models.research_model import ResearchTransformer, ModelConfig


def save_checkpoint(acc: Accelerator, model, optimizer, ckpt_path: str, epoch: int, step: int, extra: dict):
    # Only rank 0 writes; acc.get_state_dict returns the unwrapped state dict.
    if acc.is_main_process:
        os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
        state = {
            "model": acc.get_state_dict(model),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
            "step": step,
            "extra": extra,
        }
        torch.save(state, ckpt_path)


def load_checkpoint(model, optimizer, ckpt_path: str):
    # Pass the *unwrapped* model so parameter names match what save_checkpoint stored.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(ckpt["model"], strict=False)
    optimizer.load_state_dict(ckpt["optimizer"])
    return ckpt.get("epoch", 0), ckpt.get("step", 0), ckpt.get("extra", {})


def build_tokenizer(name: str):
    tok = AutoTokenizer.from_pretrained(name)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    return tok


def collate_batch(examples, tokenizer, block_size: int):
    # Fall back to the first string-valued field if there is no "text" column.
    texts = [ex.get("text") or next((v for v in ex.values() if isinstance(v, str)), "") for ex in examples]
    toks = tokenizer(texts, padding="max_length", truncation=True, max_length=block_size, return_tensors="pt")
    input_ids = toks["input_ids"]
    labels = input_ids.clone()
    if tokenizer.pad_token_id is not None:
        labels[labels == tokenizer.pad_token_id] = -100  # exclude padding from the loss
    return {"input_ids": input_ids, "labels": labels, "attention_mask": toks["attention_mask"]}


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", type=str, required=True)
    ap.add_argument("--resume", action="store_true")
    args = ap.parse_args()

    with open(args.config, "r") as f:
        cfg = json.load(f)

    acc = Accelerator()
    acc.print("Accelerator initialized.")

    model_arch = cfg.get("model_architecture", "ResearchTransformer (Experimental)")
    dataset_name = cfg.get("dataset_name", "stas/tiny-stories")
    tokenizer_name = cfg.get("tokenizer_name", "gpt2")
    block_size = int(cfg.get("block_size", 256))
    batch_size = int(cfg.get("batch_size", 8))
    max_batches_per_epoch = int(cfg.get("max_batches_per_epoch", 0)) or None

    params = cfg.get("params", {})
    epochs = int(params.get("epochs", 1))
    lr = float(params.get("learning_rate", 5e-5))
    wd = float(params.get("weight_decay", 0.01))
    accum_steps = int(cfg.get("accum_steps", 1))

    results_file = cfg.get("results_file", "results.json")
    ckpt_path = cfg.get("checkpoint_path", os.path.join(os.path.dirname(results_file) or ".", "checkpoint.pt"))
    sample_every = int(cfg.get("sample_every_steps", 200))

    tokenizer = build_tokenizer(tokenizer_name)
    vocab_size = int(cfg.get("vocab_size", getattr(tokenizer, "vocab_size", 65536) or 65536))

    if model_arch == "Official Gemma (Baseline)":
        # Baseline: load the pretrained checkpoint matching the tokenizer name.
        model = AutoModelForCausalLM.from_pretrained(tokenizer_name)
    else:
        mc = ModelConfig(
            vocab_size=vocab_size,
            n_layer=int(cfg.get("n_layer", 6)),
            n_head=int(cfg.get("n_head", 8)),
            n_embd=int(cfg.get("n_embd", 512)),
            block_size=block_size,
            dropout=float(cfg.get("dropout", 0.1)),
        )
        model = ResearchTransformer(mc)

    from datasets import load_dataset
    raw = load_dataset(dataset_name)
    ds = raw["train"] if "train" in raw else raw
    split = ds.train_test_split(test_size=0.05, seed=42) if hasattr(ds, "train_test_split") else {"train": ds, "test": ds}
    train_ds, val_ds = split["train"], split["test"]

    from torch.utils.data import DataLoader

    def collate(examples):
        return collate_batch(examples, tokenizer, block_size)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, collate_fn=collate)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)

    model, optimizer, train_loader, val_loader = acc.prepare(model, optimizer, train_loader, val_loader)

    start_epoch = 0
    global_step = 0
    if args.resume and os.path.exists(ckpt_path):
        # Load into the unwrapped model so key names match the saved state dict.
        start_epoch, global_step, _ = load_checkpoint(acc.unwrap_model(model), optimizer, ckpt_path)
        acc.print(f"Resumed from checkpoint at epoch {start_epoch}, step {global_step}")

    os.makedirs(os.path.dirname(results_file) or ".", exist_ok=True)
    results = {"config": cfg, "status": "running", "history": [], "samples": []}

    def evaluate():
        model.eval()
        losses = []
        with torch.no_grad():
            for i, batch in enumerate(val_loader):
                out = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
                losses.append(acc.gather_for_metrics(out.loss.detach().repeat(batch["input_ids"].size(0))))
                if max_batches_per_epoch and i + 1 >= max_batches_per_epoch:
                    break
        loss = torch.cat(losses).mean().item()
        ppl = math.exp(min(20.0, loss))  # clamp the exponent so divergence doesn't overflow
        return loss, ppl

    def sample_text(prompt: str = "Once upon a time"):
        model.eval()
        with torch.no_grad():
            ids = tokenizer(prompt, return_tensors="pt").input_ids.to(acc.device)
            gen = acc.unwrap_model(model).generate(ids, max_new_tokens=64)  # generate on the unwrapped model
            text = tokenizer.decode(gen[0], skip_special_tokens=True)
        return text

    best_val = float("inf")
    patience, bad_epochs = 3, 0
    start_time = time.time()
    for epoch in range(start_epoch, epochs):
        model.train()
        epoch_start = time.time()
        optimizer.zero_grad()
        running_loss = 0.0

        for i, batch in enumerate(train_loader):
            out = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
            loss = out.loss / accum_steps
            acc.backward(loss)

            if (i + 1) % accum_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

            running_loss += out.loss.detach().item()
            global_step += 1

            if sample_every and global_step % sample_every == 0 and acc.is_main_process:
                results["samples"].append({"step": global_step, "text": sample_text()})
                model.train()  # sample_text leaves the model in eval mode

            if max_batches_per_epoch and i + 1 >= max_batches_per_epoch:
                break

        # Flush gradients left over from a partial accumulation window.
        if (i + 1) % accum_steps != 0:
            optimizer.step()
            optimizer.zero_grad()

        train_time = time.time() - epoch_start
        val_loss, val_ppl = evaluate()

        try:
            mem = torch.cuda.max_memory_allocated() / (1024 ** 3)
        except Exception:
            mem = None

        results["history"].append({
            "epoch": epoch + 1,
            "train_time_sec": train_time,
            "val_loss": val_loss,
            "val_ppl": val_ppl,
            "max_cuda_mem_gb": mem,
            "effective_batch_size": batch_size * accum_steps,
        })

        improved = val_loss < best_val - 1e-5
        if improved:
            best_val = val_loss
            bad_epochs = 0
            save_checkpoint(acc, model, optimizer, ckpt_path, epoch + 1, global_step, {"best_val": best_val})
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                acc.print("Early stopping triggered.")
                break

        if acc.is_main_process:
            with open(results_file, "w") as f:
                json.dump(results, f, indent=2)

    total = time.time() - start_time
    results["status"] = "completed"
    results["total_training_time_sec"] = total
    results["final_validation"] = {"loss": best_val, "perplexity": math.exp(min(20.0, best_val))}
    if acc.is_main_process:
        with open(results_file, "w") as f:
            json.dump(results, f, indent=2)
    acc.print(f"Done in {total/60:.1f} min. Best val {best_val:.4f}")


if __name__ == "__main__":
    main()
```
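For reference, a sketch of a config that exercises the keys `main()` actually reads — the key names come from the code above, while the values are illustrative choices, not repo defaults beyond those visible in the code:

```python
# Illustrative config writer/launcher for train.py. Every key mirrors a
# cfg.get(...) call in main(); paths and sizes are example values.
import json
import subprocess

cfg = {
    "model_architecture": "ResearchTransformer (Experimental)",
    "dataset_name": "stas/tiny-stories",
    "tokenizer_name": "gpt2",
    "block_size": 256,
    "batch_size": 8,
    "accum_steps": 4,              # effective batch size 8 * 4 = 32
    "max_batches_per_epoch": 200,  # 0 disables the per-epoch cap
    "params": {"epochs": 3, "learning_rate": 5e-5, "weight_decay": 0.01},
    "n_layer": 6, "n_head": 8, "n_embd": 512, "dropout": 0.1,
    "results_file": "runs/exp1/results.json",
    "checkpoint_path": "runs/exp1/checkpoint.pt",
    "sample_every_steps": 200,
}

with open("config.json", "w") as f:
    json.dump(cfg, f, indent=2)

subprocess.run(["python", "train.py", "--config", "config.json"], check=True)
# Add "--resume" to continue from runs/exp1/checkpoint.pt if it exists.
```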