Spaces:
Build error
Build error
File size: 6,933 Bytes
daf0288 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 |
from typing import List, Tuple, Dict
import torch
from torch import Tensor, nn
from torchtext.vocab import Vocab
import tokenizers as tk
from ..utils import pred_token_within_range, subsequent_mask
from ..vocab import (
HTML_TOKENS,
TASK_TOKENS,
RESERVED_TOKENS,
BBOX_TOKENS,
)
# Token inventories consulted when constraining per-task predictions.

# Tokens the HTML-structure task is allowed to emit (used as a white list).
VALID_HTML_TOKEN = ["<eos>", *HTML_TOKENS]

# Tokens the cell-content task must never emit (used as a black list).
INVALID_CELL_TOKEN = [
    "<sos>",
    "<pad>",
    "<empty>",
    "<sep>",
    *TASK_TOKENS,
    *RESERVED_TOKENS,
]

# Tokens valid for the bbox task; image size will be addressed after
# instantiation.
VALID_BBOX_TOKEN = ["<eos>", *BBOX_TOKENS]
class Batch:
    """Wrap one batch of samples together with its per-task decoder inputs.

    The raw batch is not a single tensor: ``obj[0]`` is the image tensor and
    ``obj[1]`` is a dict holding ``"filename"`` plus one list of tokenizer
    encodings per task (``"html"``, ``"cell"``, ``"bbox"``). For every task
    named in ``target`` the encodings are turned into a shifted source/target
    pair of id tensors plus the two masks handed to the decoder.

    Shape notes carried over from the original author (not verified here):
        image (src): B, S, E
        text (tgt): B, N, S, E, where N covers 1 table detection, 1 structure,
        1 cell and multiple bbox sequences. Text is reshaped to (B * N, S, E)
        and the image is inflated to match the shape of the text.

    The attribute names keep the historical ``casual`` spelling (for
    "causal") because they are part of this class's public surface.

    Args:
    ----
        device: device every prepared tensor is moved to.
        target: task selector; checked with ``"html" in target``,
            ``"cell" in target`` and ``"bbox" in target``.
        vocab: object exposing ``token_to_id``.
            NOTE(review): ``token_to_id`` looks like the ``tokenizers``
            API rather than torchtext's ``Vocab`` - confirm the annotation.
        obj: two-item batch ``(image, info)`` as described above.

    Raises:
    ------
        NotImplementedError: if ``target`` contains ``"table"``.
    """

    def __init__(
        self,
        device: torch.device,
        target: str,
        vocab: "Vocab",
        obj: List,
    ) -> None:
        self.device = device
        self.image = obj[0].to(device)
        self.name = obj[1]["filename"]
        self.target = target
        self.vocab = vocab
        # Size of the image's last dimension.
        self.image_size = self.image.shape[-1]

        if "table" in target:
            raise NotImplementedError

        if "html" in target:
            # Ids of the only tokens the HTML task may predict.
            self.valid_html_token = [vocab.token_to_id(i) for i in VALID_HTML_TOKEN]
            (
                self.html_src,
                self.html_tgt,
                self.html_casual_mask,
                self.html_padding_mask,
            ) = self._prepare_transformer_input(obj[1]["html"])

        if "cell" in target:
            # Ids of the tokens the cell task must never predict.
            self.invalid_cell_token = [vocab.token_to_id(i) for i in INVALID_CELL_TOKEN]
            (
                self.cell_src,
                self.cell_tgt,
                self.cell_casual_mask,
                self.cell_padding_mask,
            ) = self._prepare_transformer_input(obj[1]["cell"])

        if "bbox" in target:
            (
                self.bbox_src,
                self.bbox_tgt,
                self.bbox_casual_mask,
                self.bbox_padding_mask,
            ) = self._prepare_transformer_input(obj[1]["bbox"])

    def _prepare_transformer_input(
        self, seq: "List[tk.Encoding]"
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Turn a list of encodings into decoder inputs on ``self.device``.

        All encodings must have the same length, since their ids are stacked
        into one tensor.

        Returns:
        -------
            ``(src, tgt, casual_mask, padding_mask)`` where ``src`` is every
            id but the last (int32), ``tgt`` is every id but the first
            (int64), ``casual_mask`` is ``subsequent_mask(src_len)`` and
            ``padding_mask`` is True wherever the encoding's
            ``attention_mask`` is 0 over the ``src`` positions.
        """
        ids = torch.tensor([enc.ids for enc in seq], dtype=torch.int32)
        # Teacher forcing: the decoder reads tokens [0, n-1) and is scored
        # against tokens [1, n).
        src = ids[:, :-1].to(self.device)
        tgt = ids[:, 1:].long().to(self.device)
        casual_mask = subsequent_mask(src.shape[-1]).to(self.device)

        # Drop the last position so the mask lines up with ``src``.
        attended = torch.tensor(
            [enc.attention_mask[:-1] for enc in seq], dtype=torch.bool
        )
        padding_mask = (~attended).to(self.device)
        return src, tgt, casual_mask, padding_mask

    def _inference_one_task(
        self, model, memory, src, casual_mask, padding_mask, use_ddp
    ):
        """Decode one task from ``memory`` and project with the generator.

        When ``use_ddp`` is true the real network is reached through
        ``model.module``.
        """
        net = model.module if use_ddp else model
        out = net.decode(memory, src, casual_mask, padding_mask)
        return net.generator(out)

    def inference(
        self,
        model: nn.Module,
        criterion: nn.Module,
        criterion_bbox: nn.Module = None,
        loss_weights: dict = None,
        use_ddp: bool = True,
    ) -> Tuple[Dict, Dict]:
        """Encode the image, decode every selected task and compute losses.

        Args:
        ----
            model: network exposing ``encode``, ``decode`` and ``generator``
                (looked up on ``model.module`` when ``use_ddp`` is true).
            criterion: loss applied to the html and cell predictions.
            criterion_bbox: loss applied to the bbox predictions; must be
                given when ``"bbox"`` is in ``self.target``.
            loss_weights: mapping from task name (``"table"``, ``"html"``,
                ``"cell"``, ``"bbox"``) to its weight in the total loss.
                ``None`` weights every task with 1.0.
            use_ddp: whether ``model`` is a wrapper holding the network in
                ``model.module``.

        Returns:
        -------
            ``(loss, pred)``. ``loss`` holds one entry per task (0 for tasks
            that were not run) plus ``"total"``, the weighted sum. ``pred``
            holds the permuted model output for each task that was run.
        """
        pred = {}
        loss = {"table": 0, "html": 0, "cell": 0, "bbox": 0}

        net = model.module if use_ddp else model
        memory = net.encode(self.image)

        # inference + suppress invalid logits + compute loss
        # Outputs have dims 1 and 2 swapped before the loss - presumably
        # (batch, seq, vocab) -> (batch, vocab, seq); confirm against the
        # criterion in use.
        if "html" in self.target:
            out_html = self._inference_one_task(
                model,
                memory,
                self.html_src,
                self.html_casual_mask,
                self.html_padding_mask,
                use_ddp,
            )
            pred["html"] = pred_token_within_range(
                out_html, white_list=self.valid_html_token
            ).permute(0, 2, 1)
            loss["html"] = criterion(pred["html"], self.html_tgt)

        if "cell" in self.target:
            out_cell = self._inference_one_task(
                model,
                memory,
                self.cell_src,
                self.cell_casual_mask,
                self.cell_padding_mask,
                use_ddp,
            )
            pred["cell"] = pred_token_within_range(
                out_cell, black_list=self.invalid_cell_token
            ).permute(0, 2, 1)
            loss["cell"] = criterion(pred["cell"], self.cell_tgt)

        if "bbox" in self.target:
            assert (
                criterion_bbox is not None
            ), "criterion_bbox is required when 'bbox' is in target"
            out_bbox = self._inference_one_task(
                model,
                memory,
                self.bbox_src,
                self.bbox_casual_mask,
                self.bbox_padding_mask,
                use_ddp,
            )
            pred["bbox"] = out_bbox.permute(0, 2, 1)
            loss["bbox"] = criterion_bbox(pred["bbox"], self.bbox_tgt)

        if loss_weights is None:
            # The signature advertises None as the default, so it must not
            # crash: fall back to an unweighted sum over all tasks.
            loss_weights = dict.fromkeys(loss, 1.0)

        total = 0.0
        for task, weight in loss_weights.items():
            total += loss[task] * weight
        loss["total"] = total

        return loss, pred
def configure_optimizer_weight_decay(
    model: nn.Module, weight_decay: float
) -> List[Dict]:
    """Split the model's parameters into decayed and non-decayed groups.

    A parameter is exempt from weight decay when any of these holds:
      * its name ends with ``"bias"``;
      * it is the ``weight`` of a LayerNorm, BatchNorm2d or Embedding module;
      * its name is listed by ``model.no_weight_decay()`` (only consulted
        when the model defines that method).

    Args:
    ----
        model: module whose parameters are to be grouped.
        weight_decay: decay value assigned to the decayed group.

    Returns:
    -------
        Two optimizer param-group dicts: first the decayed parameters with
        ``weight_decay``, then the exempt parameters with ``0.0``. Within
        each group parameters are ordered by their full name.
    """
    weight_decay_blacklist = (nn.LayerNorm, nn.BatchNorm2d, nn.Embedding)
    # Models without a ``no_weight_decay`` hook have nothing extra to skip;
    # binding an empty container keeps the lookup below well-defined.
    skip_list = model.no_weight_decay() if hasattr(model, "no_weight_decay") else ()

    no_decay = set()
    for mn, m in model.named_modules():
        for pn, _ in m.named_parameters():
            fpn = f"{mn}.{pn}" if mn else pn  # full param name
            if (
                pn.endswith("bias")
                or (pn.endswith("weight") and isinstance(m, weight_decay_blacklist))
                or pn in skip_list
            ):
                no_decay.add(fpn)

    param_dict = dict(model.named_parameters())
    # Everything not explicitly exempted is decayed.
    decay = param_dict.keys() - no_decay

    return [
        {
            "params": [param_dict[pn] for pn in sorted(decay)],
            "weight_decay": weight_decay,
        },
        {
            "params": [param_dict[pn] for pn in sorted(no_decay)],
            "weight_decay": 0.0,
        },
    ]
def turn_off_beit_grad(model: nn.Module):
    """Freeze BEiT pretrained weights.

    Disables gradients for every parameter under ``model.encoder``,
    ``model.backbone`` and ``model.pos_embed``, in that order.
    """
    for part_name in ("encoder", "backbone", "pos_embed"):
        # Looked up lazily so the parts are visited one after another.
        for weight in getattr(model, part_name).parameters():
            weight.requires_grad = False
def turn_on_beit_grad(model: nn.Module):
    """Re-enable gradients for every parameter of ``model``."""
    model.requires_grad_(True)
|