# alps/unitable/src/trainer/train_vqvae.py
import math
import wandb
from pathlib import Path
from typing import Optional, Union
from omegaconf import DictConfig
from hydra.utils import instantiate
import logging
import torch
import time
from functools import partial
from torch import nn, Tensor, autograd
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from torchvision.utils import make_grid
from ..utils import printer, compute_grad_norm
SNAPSHOT_KEYS = {"EPOCH", "STEP", "OPTIMIZER", "LR_SCHEDULER", "MODEL", "LOSS"}
class VqvaeTrainer:
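    """Distributed trainer for the UniTable VQ-VAE stage.

    Wraps the model in DistributedDataParallel, anneals the sampling
    temperature passed to the model during training, logs metrics and
    reconstructions to wandb, and supports resuming from a snapshot or
    loading bare model weights for testing.

    Typical usage (a sketch; the actual entry point is wired up via Hydra
    elsewhere in the repo):

        trainer = VqvaeTrainer(device=local_rank, model=vqvae, log=logger, exp_dir=exp_dir)
        trainer.train(train_dataloader, valid_dataloader, train_cfg, valid_cfg)
    """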
def __init__(
self,
device: int,
model: nn.Module,
log: logging.Logger,
exp_dir: Path,
        snapshot: Optional[Path] = None,
        model_weights: Optional[Path] = None,  # only for testing
) -> None:
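        """Set up the trainer on a single GPU rank.

        At most one of `snapshot` (full training state, for resuming) and
        `model_weights` (bare state_dict, only for testing) may be provided.
        """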
self.device = device
self.log = log
self.exp_dir = exp_dir
assert (
snapshot is None or model_weights is None
), "Snapshot and model weights cannot be set at the same time."
self.model = model
if snapshot is not None and snapshot.is_file():
self.snapshot = self.load_snapshot(snapshot)
self.model.load_state_dict(self.snapshot["MODEL"])
self.start_epoch = self.snapshot["EPOCH"]
self.global_step = self.snapshot["STEP"]
        elif model_weights is not None and model_weights.is_file():
            self.load_model(model_weights)
            self.snapshot = None
            self.start_epoch = 0
            self.global_step = 0
        else:
            self.snapshot = None
            self.start_epoch = 0
            self.global_step = 0
self.model = self.model.to(device)
self.model = DDP(self.model, device_ids=[device])
# https://discuss.pytorch.org/t/extra-10gb-memory-on-gpu-0-in-ddp-tutorial/118113
torch.cuda.set_device(device) # master gpu takes up extra memory
torch.cuda.empty_cache()
def train_epoch(
self,
epoch: int,
starting_temp: float,
anneal_rate: float,
temp_min: float,
grad_clip: float = None,
):
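        """Run one training epoch.

        The sampling temperature is annealed per step as
        temp = max(starting_temp * exp(-anneal_rate * global_step), temp_min).
        Returns the summed loss and sample count for this rank so the caller
        can reduce them across GPUs before normalizing.
        """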
start = time.time()
total_loss = 0.0
total_samples = 0
# load data from dataloader
for i, obj in enumerate(self.train_dataloader):
if isinstance(obj, Tensor):
img = obj.to(self.device)
elif isinstance(obj, (list, tuple)):
img = obj[0].to(self.device)
else:
raise ValueError(f"Unrecognized object type {type(obj)}")
# temperature annealing
self.temp = max(
starting_temp * math.exp(-anneal_rate * self.global_step), temp_min
)
with autograd.detect_anomaly():
loss, soft_recons = self.model(
img, return_loss=True, return_recons=True, temp=self.temp
)
self.optimizer.zero_grad()
loss.backward()
if grad_clip:
nn.utils.clip_grad_norm_(
self.model.parameters(), max_norm=grad_clip
)
self.optimizer.step()
            loss = loss.detach().cpu()
total_loss += loss * img.shape[0]
total_samples += img.shape[0]
self.lr_scheduler.step()
self.global_step += 1
if i % 10 == 0:
grad_norm = compute_grad_norm(self.model)
lr = self.optimizer.param_groups[0]["lr"]
elapsed = time.time() - start
self.log.info(
printer(
self.device,
f"Epoch {epoch} Step {i + 1}/{len(self.train_dataloader)} | Loss {loss:.4f} ({total_loss / total_samples:.4f}) | Grad norm {grad_norm:.3f} | {total_samples / elapsed:4.1f} images/s | lr {lr:5.1e} | Temp {self.temp:.2e}",
)
)
# visualize reconstruction images
if i % 100 == 0 and self.device == 0:
lr = self.optimizer.param_groups[0]["lr"]
k = 4 # num of images saved for visualization
codes = self.model.module.get_codebook_indices(img[:k])
hard_recons = self.model.module.decode(codes)
img = img[:k].detach().cpu()
soft_recons = soft_recons[:k].detach().cpu()
codes = codes.flatten(start_dim=1).detach().cpu()
hard_recons = hard_recons.detach().cpu()
make_vis = partial(make_grid, nrow=int(math.sqrt(k)), normalize=True)
img, soft_recons, hard_recons = map(
make_vis, (img, soft_recons, hard_recons)
)
log_info = {
"epoch": epoch,
"train_loss": loss,
"temperature": self.temp,
"learning rate": lr,
"original images": wandb.Image(
img, caption=f"step: {self.global_step}"
),
"soft reconstruction": wandb.Image(
soft_recons, caption=f"step: {self.global_step}"
),
"hard reconstruction": wandb.Image(
hard_recons, caption=f"step: {self.global_step}"
),
"codebook_indices": wandb.Histogram(codes),
}
wandb.log(
log_info,
step=self.global_step,
)
return total_loss, total_samples
def train(
self,
train_dataloader: DataLoader,
valid_dataloader: DataLoader,
train_cfg: DictConfig,
valid_cfg: DictConfig,
):
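        """Full training loop.

        Instantiates the optimizer and LR scheduler from the Hydra configs
        (restoring their state when resuming from a snapshot), alternates
        training and validation epochs, reduces per-rank losses to rank 0 for
        logging, and saves periodic snapshots plus the best model by
        validation loss.
        """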
self.train_dataloader = train_dataloader
self.valid_dataloader = valid_dataloader
self.optimizer = instantiate(
train_cfg.optimizer, params=self.model.parameters()
)
self.lr_scheduler = instantiate(
train_cfg.lr_scheduler, optimizer=self.optimizer
)
        best_loss = float("inf")
        if self.snapshot is not None:
            self.optimizer.load_state_dict(self.snapshot["OPTIMIZER"])
            self.lr_scheduler.load_state_dict(self.snapshot["LR_SCHEDULER"])
            best_loss = self.snapshot["LOSS"]
        self.model.train()
for epoch in range(self.start_epoch, train_cfg.epochs):
train_dataloader.sampler.set_epoch(epoch)
epoch_loss, epoch_samples = self.train_epoch(
epoch,
starting_temp=train_cfg.starting_temp,
anneal_rate=train_cfg.temp_anneal_rate,
temp_min=train_cfg.temp_min,
grad_clip=train_cfg.grad_clip,
)
torch.cuda.empty_cache()
valid_loss, valid_samples = self.valid(valid_cfg)
# reduce loss to gpu 0
training_info = torch.tensor(
[epoch_loss, epoch_samples, valid_loss, valid_samples],
device=self.device,
)
dist.reduce(
training_info,
dst=0,
op=dist.ReduceOp.SUM,
)
if self.device == 0:
grad_norm = compute_grad_norm(self.model)
epoch_loss, epoch_samples, valid_loss, valid_samples = training_info
epoch_loss, valid_loss = (
float(epoch_loss) / epoch_samples,
float(valid_loss) / valid_samples,
)
log_info = {
"train loss (epoch)": epoch_loss,
"valid loss (epoch)": valid_loss,
"train_samples": epoch_samples,
"valid_samples": valid_samples,
"grad_norm": grad_norm,
}
wandb.log(
log_info,
step=self.global_step,
)
if epoch % train_cfg.save_every == 0:
self.save_snapshot(epoch, best_loss)
if valid_loss < best_loss:
self.save_model(epoch)
best_loss = valid_loss
def valid(self, cfg: DictConfig):
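        """Run validation; returns the summed loss and sample count for this rank."""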
total_samples = 0
total_loss = 0.0
self.model.eval()
for i, obj in enumerate(self.valid_dataloader):
if isinstance(obj, Tensor):
img = obj.to(self.device)
elif isinstance(obj, (list, tuple)):
img = obj[0].to(self.device)
else:
raise ValueError(f"Unrecognized object type {type(obj)}")
with torch.no_grad():
loss = self.model(
img, return_loss=True, return_recons=False, temp=self.temp
)
            loss = loss.detach().cpu()
total_loss += loss * img.shape[0]
total_samples += img.shape[0]
if i % 10 == 0:
self.log.info(
printer(
self.device,
f"Valid: Step {i + 1}/{len(self.valid_dataloader)} | Loss {loss:.4f} ({total_loss / total_samples:.4f})",
)
)
        # switch back to training mode before the next training epoch
        self.model.train()
        return total_loss, total_samples
def save_model(self, epoch: int):
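        """Save the unwrapped model weights for this epoch and update best.pt."""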
filename = Path(self.exp_dir) / "model" / f"epoch{epoch}_model.pt"
torch.save(self.model.module.state_dict(), filename)
self.log.info(printer(self.device, f"Saving model to {filename}"))
filename = Path(self.exp_dir) / "model" / f"best.pt"
torch.save(self.model.module.state_dict(), filename)
def load_model(self, path: Union[str, Path]):
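        """Load bare model weights onto CPU (the model is moved to the GPU afterwards)."""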
self.model.load_state_dict(torch.load(path, map_location="cpu"))
self.log.info(printer(self.device, f"Loading model from {path}"))
def save_snapshot(self, epoch: int, best_loss: float):
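        """Save the full training state: model, optimizer, scheduler, counters, and best loss."""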
state_info = {
"EPOCH": epoch + 1,
"STEP": self.global_step,
"OPTIMIZER": self.optimizer.state_dict(),
"LR_SCHEDULER": self.lr_scheduler.state_dict(),
"MODEL": self.model.module.state_dict(),
"LOSS": best_loss,
}
snapshot_path = Path(self.exp_dir) / "snapshot" / f"epoch{epoch}_snapshot.pt"
torch.save(state_info, snapshot_path)
self.log.info(printer(self.device, f"Saving snapshot to {snapshot_path}"))
def load_snapshot(self, path: Path):
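        """Load a snapshot onto CPU and check that all required keys are present."""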
self.log.info(printer(self.device, f"Loading snapshot from {path}"))
snapshot = torch.load(path, map_location="cpu")
assert SNAPSHOT_KEYS.issubset(snapshot.keys())
return snapshot