from typing import List, Tuple
import random

import tokenizers as tk
import torch
from torch import Tensor, nn
import torch.nn.functional as F

from ..vocab import TASK_TOKENS, CELL_SPECIAL
from ..model.encoderdecoder import EncoderDecoder
from .misc import html_table_template

__all__ = [
    "subsequent_mask",
    "combine_cell_char_seq",
    "random_continuous_sequence",
    "prepare_html_seq",
    "prepare_cell_seq",
    "prepare_bbox_seq",
    "html_str_to_token_list",
    "cell_str_to_token_list",
    "bbox_str_to_token_list",
    "pred_token_within_range",
    "batch_autoregressive_decode",
    "greedy_sampling",
    "combine_filename_pred_gt",
    "build_table_from_html_and_cell",
]


def subsequent_mask(size: int, pad: int = 0) -> Tensor:
    """Boolean causal mask: True marks positions a token must not attend to."""
    attn_shape = (size, size)
    output = torch.triu(torch.ones(attn_shape), diagonal=1).to(torch.bool)
    if pad > 0:
        output[:pad] = False
    return output


def combine_cell_char_seq(seq: List[str]) -> str:
    """Combine characters into a str; replace an empty sequence with the <empty> token in vocab."""
    if seq:
        out = "".join(seq)
    else:
        out = "<empty>"
    return out


def prepare_html_seq(seq: List[str]) -> List[str]:
    """Convert html annotations to the html training template."""
    out = ["[html]", *seq, "<eos>"]
    return out


def prepare_cell_seq(seq: str) -> List[str]:
    """Convert a cell sequence to the training template."""
    # strip special html tokens that must not appear inside cell content
    for black in CELL_SPECIAL:
        seq = seq.replace(black, "")
    out = ["[cell]", seq, "<eos>"]
    return out


def prepare_bbox_seq(seq: List[float]) -> List[str]:
    """Convert a flat list of bbox coordinates to the bbox training template."""
    tmp = [f"bbox-{round(i)}" for i in seq]
    out = ["[bbox]"] + tmp + ["<eos>"]
    return out


def random_continuous_sequence(seq: List, N: int, length: int = 10) -> List[Tuple[int, int]]:
    """Randomly sample a continuous sub-sequence from a sequence N times.

    Returns (start, end) index pairs; each sub-sequence is shorter than `length`
    and is clipped at the end of `seq`.
    """
    start_idx = [random.randrange(len(seq)) for _ in range(N)]
    subseq_len = [random.randrange(1, length) for _ in range(N)]
    output = [(i, min(i + j, len(seq))) for i, j in zip(start_idx, subseq_len)]
    return output


def html_str_to_token_list(
    seq: str, splitter: tk.pre_tokenizers.PreTokenizer = None
) -> List[str]:
    """Convert decoded output (str) to a list of tokens for constructing html table code."""
    # the split is a no-op when <eos> is absent
    seq = seq.split("<eos>")[0]
    token_black_list = ["<eos>", "<pad>", *TASK_TOKENS]
    for i in token_black_list:
        seq = seq.replace(i, "")
    if not splitter:
        splitter = tk.pre_tokenizers.Split(pattern=" ", behavior="contiguous")
    seq = splitter.pre_tokenize_str(seq)
    # only preserve the space for spanning cell tokens
    seq = [i[0] for i in seq if len(i[0].strip()) != 0 or i[1][1] - i[1][0] != 1]
    return seq
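

# Usage sketch (illustrative, not part of the original module): round-tripping
# an html structure sequence through the training template and back. The html
# tokens below are assumptions following the PubTabNet convention, and the
# round trip relies on "[html]" being listed in TASK_TOKENS so it is stripped.
def _example_html_round_trip() -> List[str]:
    structure = ["<thead>", "<tr>", "<td></td>", "</tr>", "</thead>"]
    templated = prepare_html_seq(structure)  # ["[html]", ..., "<eos>"]
    # a decoder would emit this back as one whitespace-joined string
    decoded_str = " ".join(templated)
    # strip special tokens and split back into html structure tokens
    return html_str_to_token_list(decoded_str)  # expected to equal `structure`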


def cell_str_to_token_list(seq: str) -> str:
    """Convert decoded output (str) to the cleaned cell content string."""
    seq = seq.split("<eos>")[0]
    token_black_list = ["<eos>", "<pad>", *TASK_TOKENS]
    for i in token_black_list:
        seq = seq.replace(i, "")
    seq = seq.strip()
    return seq


def build_table_from_html_and_cell(
    structure: List[str], content: List[str] = None
) -> List[str]:
    """Build a table from the html structure and cell content token lists."""
    assert structure is not None
    html_code = list()

    # deal with an empty table
    if content is None:
        content = ["placeholder"] * len(structure)

    for tag in structure:
        if tag in ("<td>[]</td>", ">[]</td>"):
            if len(content) == 0:
                continue
            cell = content.pop(0)
            html_code.append(tag.replace("[]", cell))
        else:
            html_code.append(tag)

    return html_code


def bbox_str_to_token_list(
    seq: str, splitter: tk.pre_tokenizers.PreTokenizer = None
) -> List[List[int]]:
    """Convert decoded output (str) to a list of bboxes.

    Note the output could be an empty list.

    Returns:
        [[ymin, xmin, ymax, xmax], [ymin, xmin, ymax, xmax], ...]
    """
    seq = seq.split("<eos>")[0]
    token_black_list = ["<eos>", "<pad>", *TASK_TOKENS]
    for i in token_black_list:
        seq = seq.replace(i, "")
    if not splitter:
        splitter = tk.pre_tokenizers.Split(pattern=" ", behavior="removed")
    seq = splitter.pre_tokenize_str(seq)
    seq = [int(i[0].split("-")[1]) for i in seq]
    # truncate to a multiple of 4 so incomplete coordinate groups are dropped
    rounded_seq_len = len(seq) // 4 * 4
    out = [seq[i : i + 4] for i in range(0, rounded_seq_len, 4)]
    return out


def pred_token_within_range(
    pred: Tensor,
    white_list: List[int] = None,
    black_list: List[int] = None,
) -> Tensor:
    """Constrain predictions by masking logits with -inf (in place).

    Provide either a white list (allowed ids) or a black list (forbidden ids), not both.
    """
    assert white_list is None or black_list is None
    if white_list:
        total = set(range(pred.shape[-1]))
        black_list = list(total.difference(set(white_list)))
    if black_list:
        pred[..., black_list] = -float("inf")
    return pred


def greedy_sampling(logits: Tensor) -> Tuple[Tensor, Tensor]:
    """Greedy decoding; logits should have shape [B, |V|]."""
    probs = F.softmax(logits, dim=-1)
    next_probs, next_tokens = probs.topk(1)
    return next_probs, next_tokens


def batch_autoregressive_decode(
    device: int,
    model: EncoderDecoder,
    batch_data,
    prefix: List[int],
    max_decode_len: int,
    eos_id: int,
    valid_token_whitelist: List[int] = None,
    valid_token_blacklist: List[int] = None,
    sampling: str = "greedy",
    use_ddp: bool = True,
) -> Tensor:
    """Auto-regressively generate the output."""
    model.eval()
    with torch.no_grad():
        if use_ddp:
            memory = model.module.encode(batch_data.image)
        else:
            memory = model.encode(batch_data.image)

    B = batch_data.image.shape[0]
    context = torch.tensor(prefix, dtype=torch.int32).repeat(B, 1).to(device)

    for _ in range(max_decode_len):
        eos_flag = [eos_id in k for k in context]
        if all(eos_flag):
            # stop early only when every sample has reached <eos>; as long as
            # one sample hasn't, keep decoding up to the max sequence length
            break

        causal_mask = subsequent_mask(context.shape[1]).to(device)
        with torch.no_grad():
            if use_ddp:
                logits = model.module.decode(
                    memory, context, tgt_mask=causal_mask, tgt_padding_mask=None
                )
                logits = model.module.generator(logits)[:, -1, :]
            else:
                logits = model.decode(
                    memory, context, tgt_mask=causal_mask, tgt_padding_mask=None
                )
                logits = model.generator(logits)[:, -1, :]

        logits = pred_token_within_range(
            logits.detach(),
            white_list=valid_token_whitelist if valid_token_whitelist else None,
            black_list=valid_token_blacklist if valid_token_blacklist else None,
        )

        if sampling == "greedy":
            next_probs, next_tokens = greedy_sampling(logits)
        else:
            raise NotImplementedError(f"Unsupported sampling strategy: {sampling}")

        context = torch.cat([context, next_tokens], dim=1)

    return context
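

# Usage sketch (illustrative, not part of the original module): one constrained
# greedy decoding step, mirroring what batch_autoregressive_decode does per
# iteration. The vocab size (8) and the whitelisted ids are made-up values.
def _example_constrained_greedy_step() -> Tensor:
    torch.manual_seed(0)
    logits = torch.randn(2, 8)  # batch of 2 over a hypothetical 8-token vocab
    # e.g. during bbox decoding, only allow coordinate token ids 4..7
    logits = pred_token_within_range(logits, white_list=[4, 5, 6, 7])
    next_probs, next_tokens = greedy_sampling(logits)
    assert int(next_tokens.min()) >= 4  # masked ids end up with zero probability
    return next_tokens  # shape [2, 1], ready for torch.cat along dim=1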


def combine_filename_pred_gt(
    filename: List[str],
    pred_id: Tensor,
    gt_id: Tensor,
    vocab: tk.Tokenizer,
    type: str,
) -> dict:
    """Pair each filename with its decoded prediction and ground truth for evaluation."""
    out = dict()
    assert len(filename) == len(pred_id)

    pred_id = pred_id.detach().cpu().numpy()
    gt_id = gt_id.detach().cpu().numpy()
    pred_token = vocab.decode_batch(pred_id, skip_special_tokens=False)
    gt_token = vocab.decode_batch(gt_id, skip_special_tokens=False)

    for idx, name in enumerate(filename):
        if type == "html":
            pred_token_list = html_str_to_token_list(pred_token[idx])
            gt_token_list = html_str_to_token_list(gt_token[idx])
        elif type == "cell":
            pred_token_list = cell_str_to_token_list(pred_token[idx])
            gt_token_list = cell_str_to_token_list(gt_token[idx])
        elif type == "bbox":
            pred_token_list = bbox_str_to_token_list(pred_token[idx])
            gt_token_list = bbox_str_to_token_list(gt_token[idx])
        else:
            raise ValueError(
                f"The supported tasks are html, cell and bbox, while {type} is provided."
            )
        out[name] = dict(pred=pred_token_list, gt=gt_token_list)

    return out
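

# Usage sketch (illustrative, not part of the original module): assembling the
# final html from predicted structure + cell content, and parsing a decoded
# bbox string. The structure tokens follow the "<td>[]</td>" placeholder
# convention that build_table_from_html_and_cell expects.
def _example_assemble_outputs() -> str:
    structure = ["<tr>", "<td>[]</td>", "<td>[]</td>", "</tr>"]
    content = ["cell 1", "cell 2"]
    html_code = build_table_from_html_and_cell(structure, content)
    # -> ["<tr>", "<td>cell 1</td>", "<td>cell 2</td>", "</tr>"]

    boxes = bbox_str_to_token_list("bbox-10 bbox-20 bbox-110 bbox-45 <eos>")
    assert boxes == [[10, 20, 110, 45]]  # trailing partial groups of 4 are dropped

    return "".join(html_code)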