# Uploaded by yumikimi381 using huggingface_hub (commit daf0288, verified).
from typing import List, Tuple, Dict
import torch
from torch import Tensor, nn
from torchtext.vocab import Vocab
import tokenizers as tk
from ..utils import pred_token_within_range, subsequent_mask
from ..vocab import (
HTML_TOKENS,
TASK_TOKENS,
RESERVED_TOKENS,
BBOX_TOKENS,
)
# Per-task decoding constraints: white list for the HTML-structure head,
# black list for the cell-content head, white list for the bbox head.
VALID_HTML_TOKEN = ["<eos>"] + HTML_TOKENS
# Control/task tokens the cell head must never emit.
INVALID_CELL_TOKEN = (
    ["<sos>", "<pad>", "<empty>", "<sep>"] + TASK_TOKENS + RESERVED_TOKENS
)
VALID_BBOX_TOKEN = [
    "<eos>"
] + BBOX_TOKENS  # image size will be addressed after instantiation
class Batch:
    """Wrap up a batch of training samples with different training targets.

    The input is not torch tensor.
    Shape of the image (src): B, S, E
    Shape of the text (tgt): B, N, S, E (N includes 1 table detection,
    1 structure, 1 cell, and multiple bbox)
    Reshape text to (B * N, S, E) and inflate the image to match the shape
    of the text.

    Args:
    ----
        device: gpu id
        target: which task heads to prepare inputs for; any combination of
            "table", "html", "cell", "bbox" as substrings
        vocab: token vocabulary used to map token strings to ids
        obj: collated sample — obj[0] is the image tensor, obj[1] is a dict
            with "filename" plus one tokenized entry per enabled task
    """

    def __init__(
        self,
        device: torch.device,
        target: str,
        vocab: Vocab,
        obj: List,
    ) -> None:
        self.device = device
        self.image = obj[0].to(device)
        self.name = obj[1]["filename"]
        self.target = target
        self.vocab = vocab
        # NOTE(review): only the last spatial dim is kept — assumes square
        # images; confirm against the data pipeline.
        self.image_size = self.image.shape[-1]
        if "table" in target:
            # Table-detection target is declared but not implemented here.
            raise NotImplementedError
        if "html" in target:
            # Token ids the structure head is allowed to predict (white list).
            self.valid_html_token = [vocab.token_to_id(i) for i in VALID_HTML_TOKEN]
            (
                self.html_src,
                self.html_tgt,
                self.html_casual_mask,  # "casual" kept: attribute name is the external contract
                self.html_padding_mask,
            ) = self._prepare_transformer_input(obj[1]["html"])
        if "cell" in target:
            # Token ids the cell head must never predict (black list).
            self.invalid_cell_token = [vocab.token_to_id(i) for i in INVALID_CELL_TOKEN]
            (
                self.cell_src,
                self.cell_tgt,
                self.cell_casual_mask,
                self.cell_padding_mask,
            ) = self._prepare_transformer_input(obj[1]["cell"])
        if "bbox" in target:
            (
                self.bbox_src,
                self.bbox_tgt,
                self.bbox_casual_mask,
                self.bbox_padding_mask,
            ) = self._prepare_transformer_input(obj[1]["bbox"])

    def _prepare_transformer_input(
        self, seq: List[tk.Encoding]
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Turn tokenized sequences into (src, tgt, causal mask, padding mask).

        src drops the last token and tgt drops the first (teacher forcing:
        the model sees tokens [0, S-1) and predicts tokens [1, S)).
        """
        tmp = [i.ids for i in seq]
        tmp = torch.tensor(tmp, dtype=torch.int32)
        src = tmp[:, :-1].to(self.device)
        # Targets are cast to int64 — loss functions such as CrossEntropyLoss
        # require long-typed class indices.
        tgt = tmp[:, 1:].type(torch.LongTensor).to(self.device)
        casual_mask = subsequent_mask(src.shape[-1]).to(self.device)
        # Encoding.attention_mask is 1 for real tokens / 0 for padding; the
        # transformer padding-mask convention is the opposite, hence the
        # inversion (~) below.
        tmp = [i.attention_mask[:-1] for i in seq]  # padding mask
        tmp = torch.tensor(tmp, dtype=torch.bool)
        padding_mask = (~tmp).to(self.device)
        return src, tgt, casual_mask, padding_mask

    def _inference_one_task(
        self, model, memory, src, casual_mask, padding_mask, use_ddp
    ):
        # Under DistributedDataParallel the real network is model.module.
        if use_ddp:
            out = model.module.decode(memory, src, casual_mask, padding_mask)
            out = model.module.generator(out)
        else:
            out = model.decode(memory, src, casual_mask, padding_mask)
            out = model.generator(out)
        return out

    def inference(
        self,
        model: nn.Module,
        criterion: nn.Module,
        criterion_bbox: nn.Module = None,
        loss_weights: dict = None,
        use_ddp: bool = True,
    ) -> Tuple[Dict, Dict]:
        """Decode every enabled task and return (loss dict, prediction dict).

        Args:
        ----
            model: encoder/decoder network (optionally DDP-wrapped)
            criterion: token-level loss for the html/cell heads
            criterion_bbox: loss for the bbox head; must be provided when
                "bbox" is among the targets (asserted below)
            loss_weights: per-task weights folded into loss["total"];
                NOTE(review): iterated unconditionally — passing None raises
                AttributeError, so callers must always supply a dict
            use_ddp: whether model is wrapped in DistributedDataParallel
        """
        pred = dict()
        loss = dict(table=0, html=0, cell=0, bbox=0)
        if use_ddp:
            memory = model.module.encode(self.image)
        else:
            memory = model.encode(self.image)
        # inference + suppress invalid logits + compute loss
        if "html" in self.target:
            out_html = self._inference_one_task(
                model,
                memory,
                self.html_src,
                self.html_casual_mask,
                self.html_padding_mask,
                use_ddp,
            )
            # Constrain logits to the HTML white list, then move the class
            # dim to index 1 — presumably the (B, C, S) layout the criterion
            # expects (CrossEntropy convention).
            pred["html"] = pred_token_within_range(
                out_html, white_list=self.valid_html_token
            ).permute(0, 2, 1)
            loss["html"] = criterion(pred["html"], self.html_tgt)
        if "cell" in self.target:
            out_cell = self._inference_one_task(
                model,
                memory,
                self.cell_src,
                self.cell_casual_mask,
                self.cell_padding_mask,
                use_ddp,
            )
            # Cell head uses a black list (suppress invalid tokens) instead
            # of a white list.
            pred["cell"] = pred_token_within_range(
                out_cell, black_list=self.invalid_cell_token
            ).permute(0, 2, 1)
            loss["cell"] = criterion(pred["cell"], self.cell_tgt)
        if "bbox" in self.target:
            assert criterion_bbox is not None
            out_bbox = self._inference_one_task(
                model,
                memory,
                self.bbox_src,
                self.bbox_casual_mask,
                self.bbox_padding_mask,
                use_ddp,
            )
            # No logit suppression for bbox tokens.
            pred["bbox"] = out_bbox.permute(0, 2, 1)
            loss["bbox"] = criterion_bbox(pred["bbox"], self.bbox_tgt)
        total = 0.0
        for k, v in loss_weights.items():
            total += loss[k] * v
        loss["total"] = total
        return loss, pred
def configure_optimizer_weight_decay(
    model: nn.Module, weight_decay: float
) -> List[Dict]:
    """Split parameters into two optimizer groups: with and without weight decay.

    Biases, the weights of normalization/embedding layers, and any parameter
    named in the model's optional ``no_weight_decay()`` collection are exempt
    from decay; everything else receives ``weight_decay``.

    Args:
    ----
        model: network whose parameters are being grouped
        weight_decay: decay coefficient for the decayed group

    Returns:
    -------
        Two param-group dicts suitable for a torch optimizer constructor.
    """
    weight_decay_blacklist = (nn.LayerNorm, nn.BatchNorm2d, nn.Embedding)

    # Bug fix: skip_list was previously assigned only inside the hasattr
    # branch, raising NameError below for models without no_weight_decay().
    skip_list = set()
    if hasattr(model, "no_weight_decay"):
        skip_list = set(model.no_weight_decay())

    no_decay = set()
    for mn, m in model.named_modules():
        # named_parameters() recurses, so nested params are seen more than
        # once — fpn is identical each time, making the set idempotent.
        for pn, p in m.named_parameters():
            fpn = "%s.%s" % (mn, pn) if mn else pn  # full param name
            if pn.endswith("bias"):
                no_decay.add(fpn)
            elif pn.endswith("weight") and isinstance(m, weight_decay_blacklist):
                no_decay.add(fpn)
            elif pn in skip_list:
                no_decay.add(fpn)

    param_dict = {pn: p for pn, p in model.named_parameters()}
    # Everything not explicitly exempted gets weight decay.
    decay = param_dict.keys() - no_decay
    optim_groups = [
        {
            "params": [param_dict[pn] for pn in sorted(list(decay))],
            "weight_decay": weight_decay,
        },
        {
            "params": [param_dict[pn] for pn in sorted(list(no_decay))],
            "weight_decay": 0.0,
        },
    ]
    return optim_groups
def turn_off_beit_grad(model: nn.Module):
    "Freeze BEiT pretrained weights."
    # Disable gradients for the pretrained encoder, backbone, and
    # positional-embedding submodules.
    frozen_parts = (model.encoder, model.backbone, model.pos_embed)
    for part in frozen_parts:
        for weight in part.parameters():
            weight.requires_grad = False
def turn_on_beit_grad(model: nn.Module):
    """Re-enable gradients for every parameter of the model."""
    for weight in model.parameters():
        weight.requires_grad = True