from typing import Tuple, List, Union, Dict, Optional import torch import wandb import json import os from torch import nn, Tensor, autograd from torch.utils.data import DataLoader from omegaconf import DictConfig from hydra.utils import instantiate import logging from pathlib import Path from torch.nn.parallel import DistributedDataParallel as DDP import tokenizers as tk import torch.nn.functional as F from .utils import ( Batch, configure_optimizer_weight_decay, turn_off_beit_grad, VALID_HTML_TOKEN, INVALID_CELL_TOKEN, VALID_BBOX_TOKEN, ) from ..utils import ( printer, compute_grad_norm, count_total_parameters, batch_autoregressive_decode, combine_filename_pred_gt, ) SNAPSHOT_KEYS = set(["EPOCH", "STEP", "OPTIMIZER", "LR_SCHEDULER", "MODEL", "LOSS"]) class TableTrainer: """A trainer for table recognition. The supported tasks are: 1) table structure extraction 2) table cell bbox detection 3) table cell content recognition Args: ---- device: gpu id vocab: a vocab shared among all tasks model: model architecture log: logger exp_dir: the experiment directory that saves logs, wandb files, model weights, and checkpoints (snapshots) snapshot: specify which snapshot to use, only used in training model_weights: specify which model weight to use, only used in testing beit_pretrained_weights: load SSL pretrained visual encoder freeze_beit_epoch: freeze beit weights for the first {freeze_beit_epoch} epochs """ def __init__( self, device: int, vocab: tk.Tokenizer, model: nn.Module, log: logging.Logger, exp_dir: Path, snapshot: Path = None, model_weights: str = None, beit_pretrained_weights: str = None, freeze_beit_epoch: int = None, ) -> None: self.device = device self.log = log self.exp_dir = exp_dir self.vocab = vocab self.padding_idx = vocab.token_to_id("") self.freeze_beit_epoch = freeze_beit_epoch # loss for training html, cell self.criterion = nn.CrossEntropyLoss(ignore_index=self.padding_idx) self.model = model if ( beit_pretrained_weights is not None and Path(beit_pretrained_weights).is_file() ): self.load_pretrained_beit(Path(beit_pretrained_weights)) assert ( snapshot is None or model_weights is None ), "Cannot set snapshot and model_weights at the same time!" if snapshot is not None and snapshot.is_file(): self.snapshot = self.load_snapshot(snapshot) self.model.load_state_dict(self.snapshot["MODEL"]) self.start_epoch = self.snapshot["EPOCH"] self.global_step = self.snapshot["STEP"] elif model_weights is not None and Path(model_weights).is_file(): self.load_model(Path(model_weights)) else: self.snapshot = None self.start_epoch = 0 self.global_step = 0 if freeze_beit_epoch and freeze_beit_epoch > 0: self._freeze_beit() self.model = self.model.to(device) self.model = DDP(self.model, device_ids=[device]) # https://discuss.pytorch.org/t/extra-10gb-memory-on-gpu-0-in-ddp-tutorial/118113 torch.cuda.set_device(device) # master gpu takes up extra memory torch.cuda.empty_cache() def _freeze_beit(self): if self.start_epoch < self.freeze_beit_epoch: turn_off_beit_grad(self.model) self.log.info( printer( self.device, f"Lock SSL params for {self.freeze_beit_epoch} epochs (params: {count_total_parameters(self.model) / 1e6:.2f}M) - Current epoch {self.start_epoch + 1}", ) ) else: self.log.info( printer( self.device, f"Unlock all weights (params: {count_total_parameters(self.model) / 1e6:.2f}M) - Current epoch {self.start_epoch + 1}", ) ) def train_epoch( self, epoch: int, target: str, loss_weights: List[float], grad_clip: float = None, ): avg_loss = 0.0 # load data from dataloader for i, obj in enumerate(self.train_dataloader): batch = Batch(device=self.device, target=target, vocab=self.vocab, obj=obj) with autograd.detect_anomaly(): loss, _ = batch.inference( self.model, criterion=self.criterion, criterion_bbox=self.criterion_bbox, loss_weights=loss_weights, ) total_loss = loss["total"] self.optimizer.zero_grad() total_loss.backward() if grad_clip: nn.utils.clip_grad_norm_( self.model.parameters(), max_norm=grad_clip ) self.optimizer.step() total_loss = total_loss.detach().cpu().data avg_loss += total_loss self.lr_scheduler.step() self.global_step += 1 if i % 10 == 0: grad_norm = compute_grad_norm(self.model) lr = self.optimizer.param_groups[0]["lr"] # elapsed = time.time() - start loss_info = f"Loss {total_loss:.3f} ({avg_loss / (i + 1):.3f})" if not isinstance(loss["html"], int): loss_info += f" Html {loss['html'].detach().cpu().data:.3f}" if not isinstance(loss["cell"], int): loss_info += f" Cell {loss['cell'].detach().cpu().data:.3f}" if not isinstance(loss["bbox"], int): loss_info += f" Bbox {loss['bbox'].detach().cpu().data:.3f}" self.log.info( printer( self.device, f"Epoch {epoch} Step {i + 1}/{len(self.train_dataloader)} | {loss_info} | Grad norm {grad_norm:.3f} | lr {lr:5.1e}", ) ) if i % 100 == 0 and self.device == 0: log_info = { "epoch": epoch, "train_total_loss": total_loss, "learning rate": lr, "grad_norm": grad_norm, } wandb.log( log_info, step=self.global_step, ) def train( self, train_dataloader: DataLoader, valid_dataloader: DataLoader, train_cfg: DictConfig, valid_cfg: DictConfig, ): self.train_dataloader = train_dataloader self.valid_dataloader = valid_dataloader # ensure correct weight decay: https://github.com/karpathy/minGPT/blob/37baab71b9abea1b76ab957409a1cc2fbfba8a26/mingpt/model.py#L215 optim_params = configure_optimizer_weight_decay( self.model.module, weight_decay=train_cfg.optimizer.weight_decay ) self.optimizer = instantiate(train_cfg.optimizer, optim_params) self.lr_scheduler = instantiate( train_cfg.lr_scheduler, optimizer=self.optimizer ) if self.snapshot is not None: self.optimizer.load_state_dict(self.snapshot["OPTIMIZER"]) self.lr_scheduler.load_state_dict(self.snapshot["LR_SCHEDULER"]) self.criterion_bbox = None if "bbox" in train_cfg.target: tmp = [ self.vocab.token_to_id(i) for i in VALID_BBOX_TOKEN[ : train_cfg.img_size[0] + 2 ] # +1 for +1 for bbox == img_size ] tmp = [1.0 if i in tmp else 0.0 for i in range(self.vocab.get_vocab_size())] self.criterion_bbox = nn.CrossEntropyLoss( weight=torch.tensor(tmp, device=self.device), ignore_index=self.padding_idx, ) best_loss = float("inf") self.model.train() if self.freeze_beit_epoch and self.start_epoch < self.freeze_beit_epoch: max_epoch = self.freeze_beit_epoch else: max_epoch = train_cfg.epochs for epoch in range(self.start_epoch, max_epoch): train_dataloader.sampler.set_epoch(epoch) self.train_epoch( epoch, grad_clip=train_cfg.grad_clip, target=train_cfg.target, loss_weights=train_cfg.loss_weights, ) torch.cuda.empty_cache() valid_loss = self.valid(valid_cfg) if self.device == 0: wandb.log( {"valid loss (epoch)": valid_loss}, step=self.global_step, ) if epoch % train_cfg.save_every == 0: self.save_snapshot(epoch, best_loss) if valid_loss < best_loss: self.save_model(epoch) best_loss = valid_loss def valid(self, cfg: DictConfig): total_loss = 0.0 avg_loss = 0.0 total_samples = 0 self.model.eval() for i, obj in enumerate(self.valid_dataloader): batch = Batch( device=self.device, target=cfg.target, vocab=self.vocab, obj=obj ) with torch.no_grad(): loss, _ = batch.inference( self.model, criterion=self.criterion, criterion_bbox=self.criterion_bbox, loss_weights=cfg.loss_weights, ) total_loss = loss["total"] total_loss = total_loss.detach().cpu().data avg_loss += total_loss * batch.image.shape[0] total_samples += batch.image.shape[0] if i % 10 == 0: loss_info = f"Loss {total_loss:.3f} ({avg_loss / total_samples:.3f})" if not isinstance(loss["html"], int): loss_info += f" Html {loss['html'].detach().cpu().data:.3f}" if not isinstance(loss["cell"], int): loss_info += f" Cell {loss['cell'].detach().cpu().data:.3f}" if not isinstance(loss["bbox"], int): loss_info += f" Bbox {loss['bbox'].detach().cpu().data:.3f}" self.log.info( printer( self.device, f"Valid: Step {i + 1}/{len(self.valid_dataloader)} | {loss_info}", ) ) return avg_loss / total_samples def test(self, test_dataloader: DataLoader, cfg: DictConfig, save_to: str): total_result = dict() for i, obj in enumerate(test_dataloader): batch = Batch( device=self.device, target=cfg.target, vocab=self.vocab, obj=obj ) if cfg.target == "html": prefix = [self.vocab.token_to_id("[html]")] valid_token_whitelist = [ self.vocab.token_to_id(i) for i in VALID_HTML_TOKEN ] valid_token_blacklist = None elif cfg.target == "cell": prefix = [self.vocab.token_to_id("[cell]")] valid_token_whitelist = None valid_token_blacklist = [ self.vocab.token_to_id(i) for i in INVALID_CELL_TOKEN ] elif cfg.target == "bbox": prefix = [self.vocab.token_to_id("[bbox]")] valid_token_whitelist = [ self.vocab.token_to_id(i) for i in VALID_BBOX_TOKEN[: cfg.img_size[0]] ] valid_token_blacklist = None else: raise NotImplementedError pred_id = batch_autoregressive_decode( device=self.device, model=self.model, batch_data=batch, prefix=prefix, max_decode_len=cfg.max_seq_len, eos_id=self.vocab.token_to_id(""), valid_token_whitelist=valid_token_whitelist, valid_token_blacklist=valid_token_blacklist, sampling=cfg.sampling, ) if cfg.target == "html": result = combine_filename_pred_gt( filename=batch.name, pred_id=pred_id, gt_id=batch.html_tgt, vocab=self.vocab, type="html", ) elif cfg.target == "cell": result = combine_filename_pred_gt( filename=batch.name, pred_id=pred_id, gt_id=batch.cell_tgt, vocab=self.vocab, type="cell", ) elif cfg.target == "bbox": result = combine_filename_pred_gt( filename=batch.name, pred_id=pred_id, gt_id=batch.bbox_tgt, vocab=self.vocab, type="bbox", ) else: raise NotImplementedError total_result.update(result) if i % 10 == 0: self.log.info( printer( self.device, f"Test: Step {i + 1}/{len(test_dataloader)}", ) ) self.log.info( printer( self.device, f"Converting {len(total_result)} samples to html tables ...", ) ) with open( os.path.join(save_to, cfg.save_to_prefix + f"_{self.device}.json"), "w", encoding="utf-8", ) as f: json.dump(total_result, f, indent=4) return total_result def save_model(self, epoch: int): filename = Path(self.exp_dir) / "model" / f"epoch{epoch}_model.pt" torch.save(self.model.module.state_dict(), filename) self.log.info(printer(self.device, f"Saving model to {filename}")) filename = Path(self.exp_dir) / "model" / "best.pt" torch.save(self.model.module.state_dict(), filename) def load_model(self, path: Union[str, Path]): self.model.load_state_dict(torch.load(path, map_location="cpu")) self.log.info(printer(self.device, f"Loading model from {path}")) def save_snapshot(self, epoch: int, best_loss: float): state_info = { "EPOCH": epoch + 1, "STEP": self.global_step, "OPTIMIZER": self.optimizer.state_dict(), "LR_SCHEDULER": self.lr_scheduler.state_dict(), "MODEL": self.model.module.state_dict(), "LOSS": best_loss, } snapshot_path = Path(self.exp_dir) / "snapshot" / f"epoch{epoch}_snapshot.pt" torch.save(state_info, snapshot_path) self.log.info(printer(self.device, f"Saving snapshot to {snapshot_path}")) def load_snapshot(self, path: Path): self.log.info(printer(self.device, f"Loading snapshot from {path}")) snapshot = torch.load(path, map_location="cpu") assert SNAPSHOT_KEYS.issubset(snapshot.keys()) return snapshot def load_pretrained_beit(self, path: Path): self.log.info(printer(self.device, f"Loading pretrained BEiT from {path}")) beit = torch.load(path, map_location="cpu") redundant_keys_in_beit = [ "cls_token", "mask_token", "generator.weight", "generator.bias", ] for key in redundant_keys_in_beit: if key in beit: del beit[key] # max_seq_len in finetuning may go beyond the length in pretraining if ( self.model.pos_embed.embedding.weight.shape[0] != beit["pos_embed.embedding.weight"].shape[0] ): emb_shape = self.model.pos_embed.embedding.weight.shape ckpt_emb = beit["pos_embed.embedding.weight"].clone() assert emb_shape[1] == ckpt_emb.shape[1] ckpt_emb = ckpt_emb.unsqueeze(0).permute(0, 2, 1) ckpt_emb = F.interpolate(ckpt_emb, emb_shape[0], mode="nearest") beit["pos_embed.embedding.weight"] = ckpt_emb.permute(0, 2, 1).squeeze() out = self.model.load_state_dict(beit, strict=False) # ensure missing keys are just token_embed, decoder, and generator missing_keys_prefix = ("token_embed", "decoder", "generator") for key in out[0]: assert key.startswith( missing_keys_prefix ), f"Key {key} should be loaded from BEiT, but missing in current state dict." assert len(out[1]) == 0, f"Unexpected keys from BEiT: {out[1]}"