"""Quantizers that map layout values to special tokens and back, plus masking helpers."""

import copy
import json
import math
import random
import re
from collections import OrderedDict
from functools import lru_cache

import torch
import numpy as np
from datasets import ClassLabel
from matplotlib import font_manager
from colorama import Fore, Style, init


class BaseQuantizer:
    """JSON (de)serialisation and mask-token bookkeeping shared by quantizers.

    ``dump2json`` / ``load_json`` read ``self.additional_special_tokens``,
    which this class never sets; subclasses are expected to provide it
    (``QuantizerV4.setup_tokenizer`` does).
    """

    def __init__(self, simplify_json=False, mask_all=False, num_mask_tokens=0,
                 mask_type='cm3', **kwargs):
        """Store the masking configuration.

        Args:
            simplify_json: When true, ``dump2json`` emits compact JSON and
                strips the double quotes around special tokens.
            mask_all: Default for ``apply_masking(mask_all=None)``.
            num_mask_tokens: Upper bound on the number of masked values.
            mask_type: ``'cm3'`` (one ``<mask-i>`` per masked value plus a
                ``<pred-start>`` token) or ``'mask_aug'`` (single
                ``<mask-aug>`` token).
            **kwargs: Ignored here; accepted so subclasses can forward
                their own keyword arguments.

        Raises:
            ValueError: If ``mask_type`` is not one of the two known values.
        """
        self.simplify_json = simplify_json
        self.io_ignore_replace_tokens = ['<split-text>']
        self.mask_all = mask_all
        self.num_mask_tokens = num_mask_tokens
        self.mask_type = mask_type
        if self.mask_type == 'mask_aug':
            self.mask_aug_token = '<mask-aug>'
        elif self.mask_type == 'cm3':
            self.predict_start_token = '<pred-start>'
            # Defined up front so `ignore_tokens` and `apply_masking` work
            # even before `get_additional_mask_tokens()` has been called.
            self.mask_tokens = ['<mask-%d>' % i for i in range(self.num_mask_tokens)]
        else:
            raise ValueError(f'Invalid mask type {self.mask_type}')

    @property
    def ignore_tokens(self):
        """Mask-related tokens, or an empty list when masking is disabled."""
        if self.num_mask_tokens <= 0:
            return []
        if self.mask_type == 'cm3':
            return [self.predict_start_token] + self.mask_tokens
        if self.mask_type == 'mask_aug':
            return [self.mask_aug_token]
        raise ValueError(f'Invalid mask type {self.mask_type}')

    def get_additional_mask_tokens(self):
        """Return the mask tokens to register with the tokenizer.

        Two configurations exist:
        1. ``cm3``: ``['<pred-start>']`` followed by ``num_mask_tokens``
           tokens ``'<mask-i>'`` (``self.mask_tokens`` is rebuilt here).
        2. ``mask_aug``: ``['<mask-aug>']``.
        """
        if self.mask_type == 'cm3':
            self.mask_tokens = ['<mask-%d>' % i for i in range(self.num_mask_tokens)]
            return [self.predict_start_token] + self.mask_tokens
        if self.mask_type == 'mask_aug':
            return [self.mask_aug_token]
        raise ValueError(f'Invalid mask type {self.mask_type}')

    def dump2json(self, json_example):
        """Serialise ``json_example`` to a string.

        With ``simplify_json`` the output has no extra whitespace and the
        double quotes around every special token are removed.
        """
        if not self.simplify_json:
            return json.dumps(json_example)
        content = json.dumps(json_example, separators=(',', ':'))
        for token in self.additional_special_tokens:
            content = content.replace(f'"{token}"', token)
        return content

    def load_json(self, content):
        """Parse a string produced by ``dump2json`` back into Python objects.

        With ``simplify_json`` the quotes stripped by ``dump2json`` are put
        back around every special token that is not listed in
        ``io_ignore_replace_tokens`` before parsing.
        """
        replace_tokens = set(self.additional_special_tokens) - set(self.io_ignore_replace_tokens)
        if self.simplify_json:
            for token in replace_tokens:
                content = content.replace(token, f'"{token}"')
        return json.loads(content)

    def _mask_token_at(self, mask_i):
        """Return the mask token used for the ``mask_i``-th masked value."""
        if self.mask_type == 'cm3':
            return self.mask_tokens[mask_i]
        if self.mask_type == 'mask_aug':
            return self.mask_aug_token
        raise ValueError(f'Invalid mask type {self.mask_type}')

    def _select_mask_targets(self, target_tokens, mask_all):
        """Reduce the candidate list to the entries that will be masked.

        ``mask_all`` false: draw a random budget in
        ``[1, num_mask_tokens]``, keep a random subset of that size and
        restore (shape index, key index) order.
        ``mask_all`` true: keep the last ``num_mask_tokens`` candidates.
        """
        if not mask_all:
            # random.randint raises ValueError when num_mask_tokens < 1.
            target_num_mask_tokens = random.randint(1, self.num_mask_tokens)
            if len(target_tokens) > target_num_mask_tokens:
                random.shuffle(target_tokens)
                target_tokens = target_tokens[:target_num_mask_tokens]
            # Sort by shape_i and key_i.
            return sorted(target_tokens, key=lambda x: x[0] * 100 + x[1])
        if len(target_tokens) > self.num_mask_tokens:
            # Keep the trailing entries.
            # NOTE(review): with num_mask_tokens == 0 the slice [-0:] keeps
            # every candidate -- confirm whether that case can occur.
            target_tokens = target_tokens[-self.num_mask_tokens:]
        return target_tokens

    def apply_masking(self, json_example, mask_all=None, return_meta=False,
                      target_keys=('width', 'height', 'left', 'top'),
                      target_element_types=None):
        """Replace selected values of the text layers with mask tokens.

        Works on a deep copy of ``json_example``. Candidates are the
        entries of each element of ``json_example['layers']['textlayer']``
        whose key is in ``target_keys``.

        Args:
            json_example: Example whose ``['layers']['textlayer']`` is a
                sequence of dicts.
            mask_all: Overrides ``self.mask_all`` when not ``None``.
            return_meta: Also return ``(shape_i, key)`` per masked value.
            target_keys: Keys eligible for masking.
            target_element_types: Unused.

        Returns:
            ``(masked_example, tuples)`` or, with ``return_meta``,
            ``(masked_example, tuples, meta_infos)`` where each tuple is
            ``(mask_token, original_value, num_token)``.
        """
        if mask_all is None:
            mask_all = self.mask_all
        json_example = copy.deepcopy(json_example)
        target_keys = set(target_keys)

        target_tokens = []
        for shape_i, shape in enumerate(json_example['layers']['textlayer']):
            for key_i, key in enumerate(shape):
                if key in target_keys:
                    target_tokens.append((shape_i, key_i, key, shape[key]))
        target_tokens = self._select_mask_targets(target_tokens, mask_all)

        tuples = []
        meta_infos = []
        for mask_i, (shape_i, key_i, key, value) in enumerate(target_tokens):
            mask_token = self._mask_token_at(mask_i)
            # Token strings look like <one-1><decimal0-1><decimal1-2>;
            # anything else is counted by its spaces.
            if '<' in value:
                num_token = value.count('<')
            else:
                num_token = value.count(' ')
            json_example['layers']['textlayer'][shape_i][key] = mask_token
            tuples.append((mask_token, value, num_token))
            meta_infos.append((shape_i, key))

        if return_meta:
            return json_example, tuples, meta_infos
        return json_example, tuples

    def make_prediction_postfix(self, tuples):
        """Concatenate ``<pred-start>`` with every ``mask_token + value`` pair.

        Reads ``self.predict_start_token``, which ``__init__`` only sets
        for ``mask_type == 'cm3'``.
        """
        return self.predict_start_token + ''.join(
            f'{mask_token}{value}' for mask_token, value, _num_token in tuples
        )


# Maps a layout attribute name to the quantization type it uses.
# An earlier configuration also covered left/top, opacity, angle,
# font_size, ratio, letter_spacing and textlen.
specs = {
    "width": "size",
    "height": "size",
    "x": "pos",  # center x
    "y": "pos",  # center y
    "color": "color",
    "font": "font",
}

# TODO change min_max_bins
# (min value, max value, number of bins) per quantization type.
min_max_bins = {
    'size': (0, 1, 256),
    'pos': (0, 1, 256),
    'color': (0, 137, 138),
    'font': (0, 511, 512),
}


def get_keys_and_multipliers(pre_decimal=3, post_decimal=2):
    """Return digit names and their place values, most significant first.

    ``pre_decimal`` is the number of integer digits (at most 4) and
    ``post_decimal`` the number of fractional digits; e.g. ``(2, 1)`` gives
    ``(['ten', 'one', 'decimal0'], [10, 1, 0.1])``.
    """
    pre_keys = ['one', 'ten', 'hundred', 'thousand']
    pre_multiplers = [1, 10, 100, 1000]
    assert pre_decimal <= len(pre_keys)
    pre_keys = pre_keys[:pre_decimal][::-1]
    pre_multiplers = pre_multiplers[:pre_decimal][::-1]
    post_keys = [f'decimal{x}' for x in range(post_decimal)]
    post_multiplers = [10 ** -(x + 1) for x in range(post_decimal)]
    return pre_keys + post_keys, pre_multiplers + post_multiplers


class DecimalQuantizer:
    """Encode a number as one token per decimal digit, e.g. ``<one-7>``."""

    def __init__(self, max_pre_decimal=3, max_post_decimal=2):
        """Fix the largest digit layout this quantizer can encode/decode."""
        self.max_pre_decimal = max_pre_decimal
        self.max_post_decimal = max_post_decimal
        self.keys, self.multiplers = get_keys_and_multipliers(max_pre_decimal, max_post_decimal)
        # Sign tokens: negative -> '<symbol-1>', zero or positive -> '<symbol-0>'.
        self.symbols = {
            -1: '<symbol-1>',
            1: '<symbol-0>',
        }

    def get_vocab(self):
        """Return the sign tokens followed by ``<key-0>``..``<key-9>`` per digit key."""
        special_tokens = [*self.symbols.values()]
        for key in self.keys:
            special_tokens.extend([f'<{key}-{i}>' for i in range(10)])
        return special_tokens

    def check_valid(self, token):
        """Tell whether ``token`` starts with a sign or digit token of this quantizer."""
        # '<symbol-1>' -> 'symbol-1>' -> 'symbol'
        prefix = token.lstrip('<').split('-')[0]
        return prefix == 'symbol' or prefix in self.keys

    def __call__(self, val, pre_decimal=None, post_decimal=None, need_symbol=False):
        """Encode ``val`` as a string of digit tokens.

        Args:
            val: Number to encode; rounded to ``post_decimal`` places.
            pre_decimal: Integer digits to emit (default: configured max).
            post_decimal: Fractional digits to emit (default: configured max).
            need_symbol: Prefix the sign token. When false, ``val`` must
                not be negative.

        Returns:
            e.g. ``'<one-0><decimal0-2><decimal1-9>'`` for ``0.29`` with
            ``pre_decimal=1, post_decimal=2``.

        Raises:
            ValueError: If ``val`` needs more integer digits than ``pre_decimal``.
        """
        if pre_decimal is None:
            pre_decimal = self.max_pre_decimal
        if post_decimal is None:
            post_decimal = self.max_post_decimal
        assert pre_decimal <= self.max_pre_decimal
        assert post_decimal <= self.max_post_decimal
        keys, _ = get_keys_and_multipliers(pre_decimal, post_decimal)

        # Two classes only: zero counts as positive.
        symbol = -1 if val < 0 else 1
        tokens = []
        if need_symbol:
            tokens.append(self.symbols[symbol])
        else:
            assert symbol >= 0

        magnitude = round(abs(val), post_decimal)
        # Work on an exact integer number of 10**-post_decimal units:
        # peeling digits off a float with floor-division loses digits
        # (0.29 - 0.2 == 0.08999..., whose next digit floors to 8, not 9).
        scaled = int(round(magnitude * 10 ** post_decimal))
        if keys and scaled >= 10 ** len(keys):
            raise ValueError(
                f'Invalid value {magnitude} for {pre_decimal} pre_decimal and {post_decimal} post_decimal'
            )
        digits = str(scaled).zfill(len(keys))
        tokens.extend(f'<{key}-{digit}>' for key, digit in zip(keys, digits))
        return ''.join(tokens)

    def parse_token(self, token):
        """Split one token: ``'<hundred-1>'`` -> ``('hundred', 1)``."""
        key, val = token[1:-1].split('-')
        return key, int(val)

    def decode(self, tokens_str):
        """Turn a string produced by ``__call__`` back into a number.

        The result is rounded to ``max_post_decimal`` places so that
        summing float place values does not leak representation error
        (0.2 + 0.09 would otherwise give 0.29000000000000004).
        """
        # Split on '>' and put the '>' back to recover the token list.
        tokens = [x + '>' for x in tokens_str.split('>') if x != '']
        if tokens[0].startswith('<symbol'):
            inv_map = {v: k for k, v in self.symbols.items()}
            symbol = inv_map[tokens[0]]
            tokens = tokens[1:]
        else:
            symbol = 1

        accumulater = 0
        for token in tokens:
            key, val = self.parse_token(token)
            accumulater += val * self.multiplers[self.keys.index(key)]
        return round(accumulater * symbol, self.max_post_decimal)


# Keyword arguments passed to DecimalQuantizer.__call__ per quantization type.
pre_post_decimals = {
    'size': {
        'pre_decimal': 1,
        'post_decimal': 2,
        'need_symbol': False,
    },
    'pos': {
        'pre_decimal': 1,
        'post_decimal': 2,
        'need_symbol': True,
    },
    'opacity': {
        'pre_decimal': 1,
        'post_decimal': 1,
        'need_symbol': False,
    },
    'color': {
        'pre_decimal': 3,
        'post_decimal': 0,
        'need_symbol': False,
    },
    'angle': {
        'pre_decimal': 1,
        'post_decimal': 2,
        'need_symbol': False,
    },
    'font_size': {
        'pre_decimal': 3,
        'post_decimal': 0,
        'need_symbol': False,
    },
}


class QuantizerV4(BaseQuantizer):
    """Quantize layout attributes into ``<type-index>`` or decimal digit tokens."""

    def __init__(self, quant=True, decimal_quantize_types=None,
                 decimal_quantize_kwargs=None, mask_values=False, **kwargs):
        """Configure quantization.

        Args:
            quant: When false, ``quantize`` returns its input unchanged.
            decimal_quantize_types: Quantization types (values of ``specs``)
                encoded with ``DecimalQuantizer`` instead of bins.
                Defaults to none.
            decimal_quantize_kwargs: Arguments for ``DecimalQuantizer``;
                defaults to ``max_pre_decimal=3, max_post_decimal=2``.
            mask_values: Enables masking in ``apply_masking``.
            **kwargs: Forwarded to ``BaseQuantizer``; ``width`` and
                ``height`` (defaults 1456 / 1457) are also read here.
        """
        if decimal_quantize_types is None:
            decimal_quantize_types = []
        if decimal_quantize_kwargs is None:
            decimal_quantize_kwargs = {'max_pre_decimal': 3, 'max_post_decimal': 2}
        super().__init__(**kwargs)
        # The builtins are stored as attributes; kept as in the original.
        self.min = min
        self.max = max
        self.quant = quant
        self.mask_values = mask_values
        self.text_split_token = '<split-text>'
        self.decimal_quantize_types = decimal_quantize_types
        self.decimal_quantize = len(decimal_quantize_types) > 0
        if self.decimal_quantize:
            print('decimal quantize types', decimal_quantize_types)
            self.decimal_quantizer = DecimalQuantizer(**decimal_quantize_kwargs)
        else:
            self.decimal_quantizer = None

        self.set_min_max_bins(min_max_bins)
        self.width = int(kwargs.get('width', 1456))
        self.height = int(kwargs.get('height', 1457))

    def set_min_max_bins(self, min_max_bins):
        """Store a copy of ``min_max_bins`` with every bin count increased by one.

        Each configured bin count must be even. Cached bin edges are
        dropped so ``get_bins`` reflects the new configuration.
        """
        adjusted = {}
        for type_name, (min_val, max_val, n_bins) in copy.deepcopy(min_max_bins).items():
            assert n_bins % 2 == 0  # must be even
            adjusted[type_name] = (min_val, max_val, n_bins + 1)
        self.min_max_bins = adjusted
        self._bins_cache = {}

    def setup_tokenizer(self, tokenizer):
        """Register every special token with ``tokenizer`` and return it.

        The tokens are, in order: ``'<split-text>'``; the decimal
        quantizer vocabulary (when enabled); ``<type-i>`` for every
        binned type; the mask tokens (when ``num_mask_tokens > 0``).
        Also sets ``self.additional_special_tokens``.
        """
        additional_special_tokens = [self.text_split_token]
        if self.decimal_quantize:
            special_tokens = self.decimal_quantizer.get_vocab()
            self.io_ignore_replace_tokens += special_tokens
            additional_special_tokens += special_tokens
        # The order must be preserved, otherwise the tokenizer will be wrong.
        rest_types = [key for key in self.min_max_bins if key not in self.decimal_quantize_types]
        for type_name in rest_types:
            n_bins = self.min_max_bins[type_name][2]
            additional_special_tokens += [f'<{type_name}-{i}>' for i in range(n_bins)]
        if self.num_mask_tokens > 0:
            additional_special_tokens.extend(self.get_additional_mask_tokens())
        print('additional_special_tokens', additional_special_tokens)

        tokenizer.add_special_tokens({'additional_special_tokens': additional_special_tokens})
        self.additional_special_tokens = set(additional_special_tokens)
        return tokenizer

    def get_bins(self, real_type):
        """Return ``(min_val, max_val, bin_edges)`` for a quantization type.

        The evenly spaced edges are cached per instance; the cache is
        reset by ``set_min_max_bins``.
        """
        cache = self.__dict__.setdefault('_bins_cache', {})
        if real_type not in cache:
            min_val, max_val, n_bins = self.min_max_bins[real_type]
            cache[real_type] = (min_val, max_val, np.linspace(min_val, max_val, n_bins))
        return cache[real_type]

    def quantize(self, x, type):
        """Quantize a scalar ``x`` of attribute ``type`` into a token string.

        ``type`` is a key of ``specs`` (e.g. ``'y'``). ``x`` is clipped to
        the type's range, then encoded either with the decimal quantizer
        or as ``'<real_type-index>'``. Returns ``x`` unchanged when
        quantization is disabled.
        """
        if not self.quant:
            return x
        real_type = specs[type]
        min_val, max_val, bins = self.get_bins(real_type)
        x = np.clip(float(x), min_val, max_val)
        if self.decimal_quantize and real_type in self.decimal_quantize_types:
            return self.decimal_quantizer(x, **pre_post_decimals[real_type])
        # Index of the bin edge at or below x.
        val = np.digitize(x, bins) - 1
        n_bins = len(bins)
        assert val >= 0 and val < n_bins
        return f'<{real_type}-{val}>'

    def dequantize(self, x):
        """Map a token string back to a number.

        Decimal digit tokens are decoded by the decimal quantizer;
        ``'<type-i>'`` returns the ``i``-th bin edge of ``type``.
        """
        if self.decimal_quantize and self.decimal_quantizer.check_valid(x):
            return self.decimal_quantizer.decode(x)
        # '<pos-1>' -> real_type 'pos', val '1'
        head, tail = x.split('-')[0], x.split('-')[1]
        real_type = head[1:]
        val = tail.strip('>')
        _, _, bins = self.get_bins(real_type)
        return bins[int(val)]

    def construct_map_dict(self):
        """Map every ``<size-i>`` and ``<pos-i>`` token to its value as a string."""
        map_dict = {}
        for type_name in ('size', 'pos'):
            for i in range(self.min_max_bins[type_name][2]):
                name = "<%s-%d>" % (type_name, i)
                map_dict[name] = str(self.dequantize(name))
        return map_dict

    def postprocess_colorandfont(self, json_example):
        """Wrap every ``<font-N>`` and ``<color-N>`` token in double quotes."""
        json_example = re.sub(r'(<font-\d+>)', r'"\1"', json_example)
        json_example = re.sub(r'(<color-\d+>)', r'"\1"', json_example)
        return json_example

    def to_str(self, x, type):
        """Return ``self.get_feature(type).int2str(x)``.

        NOTE(review): ``get_feature`` is not defined in this module --
        confirm it is provided elsewhere.
        """
        feature = self.get_feature(type)
        return feature.int2str(x)

    def convert2layout(self, example):
        """Convert raw layout values into quantized tokens.

        x/width are divided by ``self.width`` and y/height by
        ``self.height`` before quantization. Returns an ``OrderedDict``
        with ``'wholecaption'`` and the new ``'layout'`` list.
        """
        new_example = OrderedDict()
        new_example['wholecaption'] = example['wholecaption']
        new_example['layout'] = [
            {
                "layer": meta_layer["layer"],
                "x": self.quantize(meta_layer["x"] / self.width, 'x'),
                "y": self.quantize(meta_layer["y"] / self.height, 'y'),
                "width": self.quantize(meta_layer["width"] / self.width, 'width'),
                "height": self.quantize(meta_layer["height"] / self.height, 'height'),
            }
            for meta_layer in example['layout']
        ]
        return new_example

    def apply_masking(self, json_example, mask_all=None, return_meta=False,
                      mask_values=True):
        """Replace selected values of ``json_example`` with mask tokens.

        Works on a deep copy. When both ``self.mask_values`` and
        ``mask_values`` are true the candidates are the global fields
        (globalcaption, canvas_width, canvas_height, category, keywords),
        the background/object layer fields (bgcaption, flag, objcaption)
        and every entry of the ``textlayer`` mapping. The number masked is
        limited by ``self.num_mask_tokens``; without ``mask_all`` a random
        subset is chosen.

        Returns:
            ``(masked_example, tuples)`` or, with ``return_meta``,
            ``(masked_example, tuples, meta_infos)``. Each tuple is
            ``(mask_token, original_value, num_token)``; each meta entry
            is ``(shape_i, key)`` with ``shape_i == -1`` for non-text
            fields and the text-layer position otherwise.
        """
        if mask_all is None:
            mask_all = self.mask_all
        json_example = copy.deepcopy(json_example)

        target_tokens = []
        text_layer_keys = []
        if self.mask_values and mask_values:
            layers = json_example['layers']
            for key in ('globalcaption', 'canvas_width', 'canvas_height', 'category', 'keywords'):
                target_tokens.append((-1, -1, key, json_example[key]))
            target_tokens.append((-1, -1, 'bgcaption', layers['bglayer']['bgcaption']))
            target_tokens.append((-1, -1, 'flag', layers['objlayer']['flag']))
            target_tokens.append((-1, -1, 'objcaption', layers['objlayer']['objcaption']))
            # Remember the mapping's own keys so the mask is written back
            # to the same entry the value was read from, whatever the
            # layers are called and in whatever order they appear.
            text_layer_keys = list(layers['textlayer'])
            for layer_i, layer_key in enumerate(text_layer_keys):
                target_tokens.append((layer_i, -1, 'text', layers['textlayer'][layer_key]))

        target_tokens = self._select_mask_targets(target_tokens, mask_all)

        tuples = []
        meta_infos = []
        for mask_i, (shape_i, key_i, key, value) in enumerate(target_tokens):
            mask_token = self._mask_token_at(mask_i)
            # Token strings look like <one-1><decimal0-1><decimal1-2>;
            # plain text is counted in space-separated words.
            # NOTE(review): assumes every candidate value is a string.
            if '<' in value:
                num_token = value.count('<')
            else:
                num_token = value.count(' ') + 1

            if shape_i == -1:
                if key == 'bgcaption':
                    json_example['layers']['bglayer']['bgcaption'] = mask_token
                elif key == 'objcaption':
                    json_example['layers']['objlayer']['objcaption'] = mask_token
                elif key == 'flag':
                    json_example['layers']['objlayer']['flag'] = mask_token
                else:
                    json_example[key] = mask_token
            else:
                json_example['layers']['textlayer'][text_layer_keys[shape_i]] = mask_token
            tuples.append((mask_token, value, num_token))
            meta_infos.append((shape_i, key))

        if return_meta:
            return json_example, tuples, meta_infos
        return json_example, tuples


# Unused; originally used for rendering.
def is_font_exists(font_name):
    """Tell whether any system font path contains ``font_name`` (case-insensitive)."""
    wanted = font_name.lower()
    return any(wanted in font.lower() for font in font_manager.findSystemFonts())


def print_info(msg):
    """Print ``msg`` in green with an ``[INFO]`` prefix, then reset the style."""
    print(Fore.GREEN + "[INFO] " + msg + Style.RESET_ALL)


def print_warning(msg):
    """Print ``msg`` in yellow with a ``[WARNING]`` prefix, then reset the style."""
    print(Fore.YELLOW + "[WARNING] " + msg + Style.RESET_ALL)


def print_error(msg):
    """Print ``msg`` in red with an ``[ERROR]`` prefix, then reset the style."""
    print(Fore.RED + "[ERROR] " + msg + Style.RESET_ALL)


def load_feature(path):
    """Build a ``ClassLabel`` from a JSON file mapping ``"0".."n-1"`` to names."""
    with open(path, encoding='utf-8') as f:
        content = json.load(f)
    names = [content[str(i)] for i in range(len(content))]
    return ClassLabel(num_classes=len(names), names=names)


def get_quantizer(version='v1', update_vocab=False, **kwargs):
    """Create the quantizer for ``version``.

    Only ``'v4'`` is implemented; ``kwargs`` are passed to ``QuantizerV4``.
    ``update_vocab`` is accepted but unused.

    Raises:
        NotImplementedError: For any other ``version`` (including the default).
    """
    if version == 'v4':
        return QuantizerV4(**kwargs)
    raise NotImplementedError(f'Unsupported quantizer version: {version!r}')