yumikimi381's picture
Upload folder using huggingface_hub
daf0288 verified
from typing import List, Tuple
import random
import tokenizers as tk
import torch
from torch import Tensor, nn
import torch.nn.functional as F
from ..vocab import TASK_TOKENS, CELL_SPECIAL
from ..model.encoderdecoder import EncoderDecoder
from .misc import html_table_template
# Public API of this module: every helper defined below.
__all__ = [
    "subsequent_mask",
    "combine_cell_char_seq",
    "prepare_html_seq",
    "prepare_cell_seq",
    "prepare_bbox_seq",
    "random_continuous_sequence",
    "html_str_to_token_list",
    "cell_str_to_token_list",
    "build_table_from_html_and_cell",
    "bbox_str_to_token_list",
    "pred_token_within_range",
    "greedy_sampling",
    "batch_autoregressive_decode",
    "combine_filename_pred_gt",
]
def subsequent_mask(size: int, pad: int = 0):
    """Build a square causal mask of shape ``(size, size)`` with dtype ``torch.bool``.

    Entries strictly above the main diagonal are ``True``; all others are
    ``False``. (``True`` presumably marks future positions the decoder must not
    attend to; confirm against the model's ``tgt_mask`` convention.)

    Args:
        size: Number of rows and columns of the mask.
        pad: When positive, the first ``pad`` rows are cleared to ``False``.
            ``None``, zero and negative values leave the mask untouched.
    """
    mask = torch.ones(size, size).triu(diagonal=1).bool()
    # `pad or 0` folds None into 0 so the comparison never sees None.
    if (pad or 0) > 0:
        mask[:pad] = False
    return mask
def combine_cell_char_seq(seq: List[str]) -> str:
    """Join a cell's character tokens into a single string.

    An empty (or ``None``) sequence is replaced by the ``<empty>`` vocab token.
    """
    if not seq:
        return "<empty>"
    return "".join(seq)
def prepare_html_seq(seq: List[str]) -> List[str]:
    """Wrap HTML structure tokens in the html training template.

    The result is the ``[html]`` task token, then every token of ``seq`` in
    order, then ``<eos>``.
    """
    return ["[html]"] + list(seq) + ["<eos>"]
def prepare_cell_seq(seq: str) -> List[str]:
    """Wrap one cell's text in the cell training template.

    Every occurrence of each entry of ``CELL_SPECIAL`` is deleted from ``seq``
    first. The result is ``["[cell]", <cleaned text>, "<eos>"]``.
    """
    cleaned = seq
    for special in CELL_SPECIAL:
        cleaned = cleaned.replace(special, "")
    template = ["[cell]", cleaned, "<eos>"]
    return template
def prepare_bbox_seq(seq: List[float]) -> List[str]:
    """Convert a flat list of bbox coordinates to the bbox training template.

    Each coordinate is rounded with the built-in ``round`` (ties go to the
    even integer) and emitted as a ``bbox-<int>`` token. The tokens are
    wrapped between the ``[bbox]`` task token and ``<eos>``.

    Args:
        seq: Flat sequence of numeric coordinates, not cell dicts; ``round``
            is applied to every element directly.
    """
    coord_tokens = [f"bbox-{round(coord)}" for coord in seq]
    return ["[bbox]", *coord_tokens, "<eos>"]
def random_continuous_sequence(seq: List, N: int, length: int = 10) -> List:
    """Draw ``N`` random contiguous index spans over ``seq``.

    Returns a list of ``(start, end)`` index pairs, not the sub-sequences
    themselves. All ``N`` start offsets are drawn first, uniformly from
    ``[0, len(seq))``. Then ``N`` span lengths are drawn via
    ``random.randrange(1, length)``, i.e. from ``[1, length - 1]``. Each
    ``end`` is ``start + span`` clipped to ``len(seq)``.

    When ``N > 0``, ``random.randrange`` raises ``ValueError`` if ``seq`` is
    empty or ``length <= 1``.
    """
    total = len(seq)

    starts = []
    for _ in range(N):
        starts.append(random.randrange(total))

    spans = []
    for start in starts:
        stop = start + random.randrange(1, length)
        spans.append((start, stop if stop < total else total))
    return spans
# def prepare_bbox_seq(
# seq: List[dict],
# N: int,
# delimiter: str = "<sep>",
# ) -> List[List[str]]:
# """Convert the annotation to bbox input/output sequence."""
# out = list()
# # bbox_loss_start_idx = list()
# subseq_idx = random_continuous_sequence(seq, N)
# for idx in subseq_idx:
# entry = seq[idx[0] : idx[1]]
# tmp = list()
# bbox_seq = list()
# for i in entry:
# if "tokens" in i.keys():
# # pubtabnet and synthtabnet
# tmp.append(combine_cell_char_seq(i["tokens"]))
# if "bbox" in i.keys():
# bbox_seq.extend([f"bbox-{round(j)}" for j in i["bbox"]])
# elif "text" in i.keys():
# # pubtables and icdar
# tmp.append(i["text"])
# if "bbox" in i.keys():
# bbox_seq.extend([f"bbox-{round(j)}" for j in i["bbox"]])
# cell_seq = [delimiter] * len(tmp)
# cell_seq = [q for pair in zip(tmp, cell_seq) for q in pair]
# cell_seq = ["[bbox]", f"{len(entry)}-cell(s)", delimiter] + cell_seq
# bbox_seq.append("<eos>")
# # bbox_loss_start_idx.append(len(cell_seq))
# out.append(cell_seq + bbox_seq)
# return out
def html_str_to_token_list(
    seq: str, splitter: tk.pre_tokenizers.PreTokenizer = None
) -> List[str]:
    """Turn a decoded HTML-structure string into a list of HTML tokens.

    Everything from the first ``<eos>`` onward is discarded; a string without
    ``<eos>`` is kept whole. Next, ``<eos>``, ``<pad>`` and every entry of
    ``TASK_TOKENS`` are deleted. The remainder is split with ``splitter``,
    which defaults to a ``" "`` pattern with ``behavior="contiguous"``.
    Finally, whitespace-only pieces exactly one character wide are dropped;
    the intent is to keep only the spaces that belong to spanning-cell tokens.
    """
    kept = seq.split("<eos>")[0]
    for unwanted in ("<eos>", "<pad>", *TASK_TOKENS):
        kept = kept.replace(unwanted, "")

    if not splitter:
        splitter = tk.pre_tokenizers.Split(pattern=" ", behavior="contiguous")

    tokens = []
    for piece, (begin, end) in splitter.pre_tokenize_str(kept):
        # Keep a piece if it has visible text or is not exactly one
        # character wide.
        if piece.strip() or end - begin != 1:
            tokens.append(piece)
    return tokens
def cell_str_to_token_list(seq: str) -> str:
    """Clean a decoded cell-content string.

    Keeps only the text before the first ``<eos>``, deletes ``<eos>``,
    ``<pad>`` and every entry of ``TASK_TOKENS``, then strips surrounding
    whitespace. Despite the function name, the result is a single ``str``,
    not a list of tokens.
    """
    seq = seq.split("<eos>")[0]
    token_black_list = ["<eos>", "<pad>", *TASK_TOKENS]
    for token in token_black_list:
        seq = seq.replace(token, "")
    return seq.strip()
def build_table_from_html_and_cell(
    structure: List[str], content: List[str] = None
) -> List[str]:
    """Build table from html and cell token list.

    Walks ``structure`` in order. Each cell slot (``"<td>[]</td>"`` or
    ``">[]</td>"``) gets its ``[]`` replaced by the next entry of ``content``.
    All other tags are copied unchanged. When ``content`` runs out, the
    remaining cell slots are dropped from the output.

    Args:
        structure: HTML structure tokens; must not be ``None``.
        content: Cell strings consumed in order. ``None`` fills every slot
            with ``"placeholder"``. The list is not modified.

    Returns:
        The list of HTML code fragments.
    """
    assert structure is not None

    # deal with empty table
    if content is None:
        content = ["placeholder"] * len(structure)

    # Consume cells through an iterator: no O(n) pop(0) per cell, and the
    # caller's list is left untouched.
    cells = iter(content)
    exhausted = object()  # sentinel distinct from any real cell value

    html_code = []
    for tag in structure:
        if tag in ("<td>[]</td>", ">[]</td>"):
            cell = next(cells, exhausted)
            if cell is exhausted:
                continue
            html_code.append(tag.replace("[]", cell))
        else:
            html_code.append(tag)
    return html_code
def bbox_str_to_token_list(
    seq: str, splitter: tk.pre_tokenizers.PreTokenizer = None
) -> List[List[int]]:
    """Parse a decoded bbox string into groups of four integer coordinates.

    Note the out could be an empty list.

    Returns:
        [[ymin, xmin, ymax, xmax],
         [ymin, xmin, ymax, xmax],
         ...]
        This function only groups consecutive values in fours; the coordinate
        order is whatever the bbox tokens encode.
    """
    seq = seq.split("<eos>")[0]

    token_black_list = ["<eos>", "<pad>", *TASK_TOKENS]
    for token in token_black_list:
        seq = seq.replace(token, "")

    if not splitter:
        splitter = tk.pre_tokenizers.Split(pattern=" ", behavior="removed")

    pieces = splitter.pre_tokenize_str(seq)
    # Each piece is expected to look like "bbox-<int>": the second
    # "-"-separated field is parsed with int(). A piece without "-" raises
    # IndexError; a non-integer field raises ValueError.
    coords = [int(piece.split("-")[1]) for piece, _ in pieces]

    # Drop a trailing incomplete group so every box has exactly four values.
    rounded_seq_len = len(coords) // 4 * 4
    return [coords[i : i + 4] for i in range(0, rounded_seq_len, 4)]
def pred_token_within_range(
    pred: Tensor,
    white_list: List[int] = None,
    black_list: List[int] = None,
) -> Tensor:
    """Mask disallowed token ids by setting their logits to ``-inf``.

    At most one of ``white_list`` / ``black_list`` may be given.

    - With a non-empty ``white_list``, every id in ``range(pred.shape[-1])``
      that is not in it is masked.
    - Otherwise, the ids in ``black_list`` are masked.
    - When neither list has entries, ``pred`` is returned unchanged.

    ``pred`` is modified in place and also returned.

    Raises:
        AssertionError: if both lists are given.
    """
    assert white_list is None or black_list is None
    if white_list:
        allowed = set(white_list)
        black_list = [i for i in range(pred.shape[-1]) if i not in allowed]

    # Guard: indexing with ``None`` would insert a new axis and broadcast the
    # assignment over every element, wiping out all logits.
    if black_list is not None and len(black_list) > 0:
        pred[..., black_list] = -float("inf")
    return pred
def greedy_sampling(logits: Tensor):
    """Pick the most likely token for every row of ``logits``.

    ``logits`` should have shape ``[B, |V|]``. A softmax is taken over the
    last dimension and the single largest entry per row is selected.

    Returns:
        ``(probs, tokens)``: the winning probability and its token id for
        each row.
    """
    best = logits.softmax(dim=-1).topk(k=1)
    return best.values, best.indices
def batch_autoregressive_decode(
    device: int,
    model: EncoderDecoder,
    batch_data,
    prefix: List[int],
    max_decode_len: int,
    eos_id: int,
    valid_token_whitelist: List[int] = None,
    valid_token_blacklist: List[int] = None,
    sampling: str = "greedy",
    use_ddp: bool = True,
) -> Tensor:
    """Auto-regressively generate the output.

    Args:
        device: Device the prefix context and the causal mask are moved to.
        model: Encoder-decoder exposing ``encode``, ``decode`` and
            ``generator``; accessed through ``model.module`` when ``use_ddp``.
        batch_data: Batch object whose ``image`` attribute is fed to the
            encoder. The first dimension of ``image`` is the batch size.
        prefix: Token ids every sequence starts with.
        max_decode_len: Maximum number of decoding steps.
        eos_id: Id of ``<eos>``. Decoding stops early once every row
            contains it.
        valid_token_whitelist: Optional ids allowed at each step (an empty
            list counts as "not given").
        valid_token_blacklist: Optional ids forbidden at each step (an empty
            list counts as "not given").
        sampling: Only ``"greedy"`` is implemented.
        use_ddp: Whether ``model`` is a wrapper that must be unwrapped via
            ``model.module``.

    Returns:
        ``[B, len(prefix) + steps]`` tensor: the prefix followed by the
        generated ids. Rows that already contain ``<eos>`` keep receiving
        tokens until every row is done or ``max_decode_len`` is reached.

    Raises:
        NotImplementedError: on the first decoding step if ``sampling`` is
            not ``"greedy"``.
    """
    # NOTE(review): this function does not call ``model.eval()``; callers
    # should put the model in eval mode so dropout is disabled while decoding.
    net = model.module if use_ddp else model
    white_list = valid_token_whitelist if valid_token_whitelist else None
    black_list = valid_token_blacklist if valid_token_blacklist else None

    with torch.no_grad():
        memory = net.encode(batch_data.image)
        batch_size = batch_data.image.shape[0]
        context = (
            torch.tensor(prefix, dtype=torch.int32).repeat(batch_size, 1).to(device)
        )

        for _ in range(max_decode_len):
            # Vectorised "does every row contain <eos>?" check.
            if bool((context == eos_id).any(dim=1).all()):
                break

            # as long as one sample hasn't reached <eos>, continue decoding until the max seq len
            causal_mask = subsequent_mask(context.shape[1]).to(device)
            logits = net.decode(
                memory, context, tgt_mask=causal_mask, tgt_padding_mask=None
            )
            # Only the distribution for the next (last) position is needed.
            logits = net.generator(logits)[:, -1, :]
            logits = pred_token_within_range(
                logits,
                white_list=white_list,
                black_list=black_list,
            )

            if sampling == "greedy":
                next_probs, next_tokens = greedy_sampling(logits)
            else:
                raise NotImplementedError(
                    f"sampling={sampling!r} is not supported; only 'greedy' is implemented."
                )

            context = torch.cat([context, next_tokens], dim=1)

    return context
def combine_filename_pred_gt(
    filename: List[str], pred_id: Tensor, gt_id: Tensor, vocab: tk.Tokenizer, type: str
) -> dict:
    """Decode predicted and ground-truth id batches and pair them per file.

    Args:
        filename: One name per batch row; used as the keys of the result.
        pred_id: Predicted token ids, one row per entry of ``filename``.
        gt_id: Ground-truth token ids, indexed in the same row order.
        vocab: Tokenizer used to decode both batches. Special tokens are kept
            so the task parsers can cut at ``<eos>``.
        type: ``"html"``, ``"cell"`` or ``"bbox"``; selects the parser applied
            to each decoded string.

    Returns:
        ``{name: {"pred": ..., "gt": ...}}`` with the parsed prediction and
        ground truth for every file.

    Raises:
        AssertionError: if ``filename`` and ``pred_id`` differ in length.
        ValueError: if ``type`` is not a supported task. This is checked
            before any decoding.
    """
    assert len(filename) == len(pred_id)

    parsers = {
        "html": html_str_to_token_list,
        "cell": cell_str_to_token_list,
        "bbox": bbox_str_to_token_list,
    }
    if type not in parsers:
        raise ValueError(
            f"The supported tasks are html, cell and bbox, while {type} is provided."
        )
    parse = parsers[type]

    pred_token = vocab.decode_batch(
        pred_id.detach().cpu().numpy(), skip_special_tokens=False
    )
    gt_token = vocab.decode_batch(
        gt_id.detach().cpu().numpy(), skip_special_tokens=False
    )

    out = dict()
    for idx, name in enumerate(filename):
        out[name] = dict(pred=parse(pred_token[idx]), gt=parse(gt_token[idx]))
    return out