import torch
import numpy as np
import copy
from collections import OrderedDict
import json
from datasets import ClassLabel
import random
import math
from functools import lru_cache
from matplotlib import font_manager
from colorama import Fore, Style, init 


class BaseQuantizer:
    """Shared JSON (de)serialisation and value-masking helpers for layout quantizers.

    ``dump2json`` / ``load_json`` read ``self.additional_special_tokens``, which
    this class never sets; a subclass (e.g. ``QuantizerV4.setup_tokenizer``)
    must assign it before they are used.
    """

    @property
    def ignore_tokens(self):
        """Mask-related special tokens for the configured ``mask_type``.

        Empty when ``num_mask_tokens`` is not positive.

        Raises:
            ValueError: for an unknown ``mask_type``.
        """
        if self.num_mask_tokens <= 0:
            return []
        if self.mask_type == 'cm3':
            return [self.predict_start_token] + self.mask_tokens
        if self.mask_type == 'mask_aug':
            return [self.mask_aug_token]
        raise ValueError(f'Invalid mask type {self.mask_type}')

    def __init__(self, simplify_json=False, mask_all=False,
                 num_mask_tokens=0, mask_type='cm3', **kwargs):
        """
        Args:
            simplify_json: if True, ``dump2json`` emits compact JSON with special
                tokens unquoted and ``load_json`` re-quotes them.
            mask_all: default for ``apply_masking``'s ``mask_all``.
            num_mask_tokens: maximum number of values masked per example.
            mask_type: ``'cm3'`` (numbered ``<mask-i>`` tokens plus
                ``<pred-start>``) or ``'mask_aug'`` (a single ``<mask-aug>``).
            **kwargs: ignored here; accepted so subclasses can forward theirs.

        Raises:
            ValueError: for an unknown ``mask_type``.
        """
        self.simplify_json = simplify_json
        # Tokens load_json must not re-quote; subclasses may extend this list.
        self.io_ignore_replace_tokens = ['<split-text>']
        self.mask_all = mask_all
        self.num_mask_tokens = num_mask_tokens
        self.mask_type = mask_type
        if self.mask_type == 'mask_aug':
            self.mask_aug_token = '<mask-aug>'
        elif self.mask_type == 'cm3':
            self.predict_start_token = '<pred-start>'
            # Built eagerly so ignore_tokens / apply_masking work even before
            # get_additional_mask_tokens() has been called.
            self.mask_tokens = ['<mask-%d>' % i for i in range(self.num_mask_tokens)]
        else:
            raise ValueError(f'Invalid mask type {self.mask_type}')

    def get_additional_mask_tokens(self):
        """Return the mask tokens to register with the tokenizer.

        ``'cm3'``: ``['<pred-start>', '<mask-0>', ..., '<mask-{n-1}>']`` with
        ``n = num_mask_tokens`` (``self.mask_tokens`` is refreshed);
        ``'mask_aug'``: ``['<mask-aug>']``.
        """
        if self.mask_type == 'cm3':
            self.mask_tokens = ['<mask-%d>' % i for i in range(self.num_mask_tokens)]
            return [self.predict_start_token] + self.mask_tokens
        elif self.mask_type == 'mask_aug':
            return [self.mask_aug_token]
        else:
            raise ValueError(f'Invalid mask type {self.mask_type}')

    def dump2json(self, json_example):
        """Serialise ``json_example`` to a string.

        With ``simplify_json`` the output has no spaces/newlines and every value
        that is exactly one special token is written without its quotes.
        """
        if self.simplify_json:
            content = json.dumps(json_example, separators=(',', ':'))
            for token in self.additional_special_tokens:
                content = content.replace(f'"{token}"', token)
        else:
            content = json.dumps(json_example)
        return content

    def load_json(self, content):
        """Parse a string produced by ``dump2json`` back into Python objects.

        With ``simplify_json`` the unquoted special tokens (except those in
        ``io_ignore_replace_tokens``) are re-quoted before parsing.
        """
        replace_tokens = set(self.additional_special_tokens) - set(self.io_ignore_replace_tokens)
        if self.simplify_json:
            for token in replace_tokens:
                content = content.replace(token, f'"{token}"')
        return json.loads(content)

    def apply_masking(self,
                      json_example,
                      mask_all=None,
                      return_meta=False,
                      target_keys=('width', 'height', 'left', 'top'),
                      target_element_types=None
                      ):
        """Replace selected shape values in ``json_example['layers']['textlayer']`` with mask tokens.

        Candidates are the values of ``target_keys`` in every shape, in document
        order. With ``mask_all`` the last ``num_mask_tokens`` candidates are
        masked; otherwise a random count in ``[1, num_mask_tokens]`` is picked
        and restored to document order. The input is deep-copied, not modified.
        ``target_element_types`` is accepted but currently ignored.

        Returns:
            ``(masked_example, tuples)`` plus ``meta_infos`` when ``return_meta``
            is true; ``tuples`` holds ``(mask_token, original_value, num_token)``
            and ``meta_infos`` holds ``(shape_index, key)``.

        Raises:
            ValueError: from ``random.randint`` when ``mask_all`` is false and
                ``num_mask_tokens < 1``; or for an unknown ``mask_type``.
        """
        if mask_all is None:
            mask_all = self.mask_all
        json_example = copy.deepcopy(json_example)
        target_keys = set(target_keys)
        target_tokens = [
            (shape_i, key_i, key, shape[key])
            for shape_i, shape in enumerate(json_example['layers']['textlayer'])
            for key_i, key in enumerate(shape.keys())
            if key in target_keys
        ]
        if not mask_all:
            target_num_mask_tokens = random.randint(1, self.num_mask_tokens)
            if len(target_tokens) > target_num_mask_tokens:
                random.shuffle(target_tokens)
                # Restore document order (shape index, then key index).
                target_tokens = sorted(target_tokens[:target_num_mask_tokens],
                                       key=lambda t: (t[0], t[1]))
        elif len(target_tokens) > self.num_mask_tokens:
            # Keep the last num_mask_tokens; an explicit start index avoids the
            # [-0:] slice, which would keep everything when num_mask_tokens == 0.
            target_tokens = target_tokens[len(target_tokens) - self.num_mask_tokens:]

        tuples = []
        meta_infos = []
        for mask_i, (shape_i, key_i, key, value) in enumerate(target_tokens):
            if self.mask_type == 'cm3':
                mask_token = self.mask_tokens[mask_i]
            elif self.mask_type == 'mask_aug':
                mask_token = self.mask_aug_token
            else:
                raise ValueError(f'Invalid mask type {self.mask_type}')
            # Token strings such as '<one-1><decimal0-1><decimal1-2>' count one
            # token per '<'; plain text counts words (spaces + 1), matching
            # QuantizerV4.apply_masking.
            if '<' in value:
                num_token = value.count('<')
            else:
                num_token = value.count(' ') + 1
            json_example['layers']['textlayer'][shape_i][key] = mask_token
            tuples.append((mask_token, value, num_token))
            meta_infos.append((shape_i, key))
        if return_meta:
            return json_example, tuples, meta_infos
        return json_example, tuples

    def make_prediction_postfix(self, tuples):
        """Build ``'<pred-start>' + mask_token + value + ...`` from ``apply_masking`` tuples.

        Only valid for ``mask_type == 'cm3'`` (the only mode defining
        ``predict_start_token``).
        """
        return self.predict_start_token + ''.join(
            f'{mask_token}{value}' for mask_token, value, _ in tuples)

# Layout field name -> value type. The value type selects the range/bin
# settings in `min_max_bins` (and `pre_post_decimals` for decimal quantization).
# Earlier revisions also mapped left/top, opacity, angle, font_size, ratio,
# letter_spacing and textlen; only the fields below are used now.
specs = dict(
    width="size",
    height="size",
    x="pos",  # center x
    y="pos",  # center y
    color="color",
    font="font",
)

# TODO change min_max_bins
# Value type -> (min value, max value, number of bins). QuantizerV4 requires
# an even bin count and uses n_bins + 1 evenly spaced bin values. Key order
# matters: it fixes the order in which bin tokens are added to the tokenizer.
# (An older configuration also had opacity/angle/font_size/spacing/textlen
# types with a (-1, 1) position range and a (0, 2) size range.)
min_max_bins = dict(
    size=(0, 1, 256),
    pos=(0, 1, 256),
    color=(0, 137, 138),
    font=(0, 511, 512),
)

import numpy as np

# `pre_decimal` / `post_decimal` are the number of digits kept before / after
# the decimal point; every digit position gets its own key and place value.
def get_keys_and_multipliers(pre_decimal=3, post_decimal=2):
    """Return ``(keys, multipliers)`` for each digit position, most significant first.

    Integer positions use the keys 'thousand'/'hundred'/'ten'/'one' with integer
    place values; fractional positions use 'decimal0', 'decimal1', ... with
    float place values 0.1, 0.01, ...

    >>> get_keys_and_multipliers(2, 1)
    (['ten', 'one', 'decimal0'], [10, 1, 0.1])
    """
    pre_keys = ['one', 'ten', 'hundred', 'thousand']
    pre_multiplers = [1, 10, 100, 1000]
    assert pre_decimal <= len(pre_keys)
    pre_keys = pre_keys[:pre_decimal][::-1]
    pre_multiplers = pre_multiplers[:pre_decimal][::-1]

    post_keys = [f'decimal{x}' for x in range(post_decimal)]
    post_multiplers = [10 ** -(x + 1) for x in range(post_decimal)]

    return pre_keys + post_keys, pre_multiplers + post_multiplers


class DecimalQuantizer:
    """Encode numbers as one special token per decimal digit, and decode them back.

    Example with ``pre_decimal=1, post_decimal=2``:
    0.29 -> ``'<one-0><decimal0-2><decimal1-9>'``. With ``need_symbol=True`` a
    leading ``<symbol-0>`` (value >= 0) or ``<symbol-1>`` (value < 0) is added.
    """

    def __init__(self, max_pre_decimal=3, max_post_decimal=2):
        self.max_pre_decimal = max_pre_decimal
        self.max_post_decimal = max_post_decimal
        self.keys, self.multiplers = get_keys_and_multipliers(max_pre_decimal, max_post_decimal)
        # sign of the value -> sign token
        self.symbols = {
            -1: '<symbol-1>',
            1: '<symbol-0>',
        }

    def get_vocab(self):
        """Return every token this quantizer can emit: the two sign tokens, then ``<key-0>``..``<key-9>`` per digit key."""
        special_tokens = [*self.symbols.values()]
        for key in self.keys:
            special_tokens.extend([f'<{key}-{i}>' for i in range(10)])
        return special_tokens

    def check_valid(self, token):
        """Return True if ``token`` (or the first token of a token string) is a sign/digit token of this quantizer."""
        # '<symbol-1>' -> 'symbol-1>' -> 'symbol'
        prefix = token.lstrip('<').split('-')[0]
        return prefix == 'symbol' or prefix in self.keys

    def __call__(self, val, pre_decimal=None, post_decimal=None, need_symbol=False):
        """Encode ``val`` as digit tokens, most significant first.

        ``val`` is rounded to ``post_decimal`` places (defaults: the max
        settings). With ``need_symbol`` a sign token is prepended; otherwise
        ``val`` must be >= 0 (asserted).

        Raises:
            ValueError: if the value needs more than ``pre_decimal`` integer digits.
        """
        if pre_decimal is None:
            pre_decimal = self.max_pre_decimal
        if post_decimal is None:
            post_decimal = self.max_post_decimal

        assert pre_decimal <= self.max_pre_decimal
        assert post_decimal <= self.max_post_decimal

        keys, multiplers = get_keys_and_multipliers(pre_decimal, post_decimal)

        # np.sign gives -1, 0 or 1; zero is grouped with the positives.
        symbol = int(np.sign(val))
        if symbol == 0:
            symbol = 1
        val = round(abs(val), post_decimal)

        tokens = []
        if need_symbol:
            tokens.append(self.symbols[symbol])
        else:
            assert symbol >= 0

        # Extract digits from an exact integer count of 10**-post_decimal units.
        # Repeated float division/subtraction is wrong here: 0.29 / 0.1 is
        # 2.8999999999999995, which floor()s to 2 and encoded 0.29 as 0.28.
        scale = 10 ** post_decimal
        remaining = int(round(val * scale))
        for key, multipler in zip(keys, multiplers):
            unit = int(round(multipler * scale))
            digit, remaining = divmod(remaining, unit)
            if digit > 9:
                raise ValueError(f'Invalid value {val} for {pre_decimal} pre_decimal and {post_decimal} post_decimal')
            tokens.append(f'<{key}-{digit}>')

        return ''.join(tokens)

    def parse_token(self, token):
        """Split a single digit token: ``'<hundred-1>'`` -> ``('hundred', 1)``."""
        key, val = token[1:-1].split('-')
        return key, int(val)

    def decode(self, tokens_str):
        """Decode a string produced by ``__call__`` back into a signed number.

        Returns an ``int`` when only integer-position tokens are present and a
        ``float`` otherwise. Digits are summed as an exact integer count of
        ``10 ** -max_post_decimal`` units before the single final division, so
        no floating-point error accumulates across digits.
        """
        # '<a-1><b-2>' -> ['<a-1>', '<b-2>']
        tokens = [piece + '>' for piece in tokens_str.split('>') if piece != '']
        if tokens[0].startswith('<symbol'):
            inv_map = {v: k for k, v in self.symbols.items()}
            symbol = inv_map[tokens[0]]
            tokens = tokens[1:]
        else:
            symbol = 1

        scale = 10 ** self.max_post_decimal
        total = 0
        has_fraction = False
        for token in tokens:
            key, digit = self.parse_token(token)
            multipler = self.multiplers[self.keys.index(key)]
            if not isinstance(multipler, int):
                has_fraction = True
            total += digit * int(round(multipler * scale))
        value = total / scale if has_fraction else total // scale
        return value * symbol

# Value type -> keyword arguments for DecimalQuantizer.__call__, used by
# QuantizerV4.quantize for types listed in `decimal_quantize_types`.
# NOTE: there is no 'font' entry, so 'font' cannot be decimal-quantized
# (QuantizerV4.quantize would raise KeyError).
pre_post_decimals = {
    type_name: {
        'pre_decimal': pre_digits,
        'post_decimal': post_digits,
        'need_symbol': signed,
    }
    for type_name, (pre_digits, post_digits, signed) in {
        'size': (1, 2, False),
        'pos': (1, 2, True),
        'opacity': (1, 1, False),
        'color': (3, 0, False),
        'angle': (1, 2, False),
        'font_size': (3, 0, False),
    }.items()
}

class QuantizerV4(BaseQuantizer):
    """Layout quantizer mapping normalised layout values to special tokens.

    Values are clipped to the per-type range in ``min_max_bins`` and encoded
    either as a bin token such as ``<size-42>`` or, for the value types listed
    in ``decimal_quantize_types``, as a digit-token string produced by
    ``DecimalQuantizer``.
    """

    def __init__(self, quant=True,
                 decimal_quantize_types=None,
                 decimal_quantize_kwargs=None,
                 mask_values=False,
                 **kwargs):
        """
        Args:
            quant: if False, ``quantize`` returns its input unchanged.
            decimal_quantize_types: value types (e.g. 'size', 'pos') encoded with
                ``DecimalQuantizer`` instead of bins; ``None`` means none.
            decimal_quantize_kwargs: keyword arguments for ``DecimalQuantizer``;
                ``None`` means ``{'max_pre_decimal': 3, 'max_post_decimal': 2}``.
            mask_values: if True, ``apply_masking`` may also mask the global
                fields (captions, canvas size, category, keywords, flag).
            **kwargs: forwarded to ``BaseQuantizer``; ``width`` / ``height`` are
                also read here as the canvas size used to normalise coordinates.
        """
        super().__init__(**kwargs)
        # None defaults avoid sharing one mutable list/dict across instances.
        if decimal_quantize_types is None:
            decimal_quantize_types = []
        if decimal_quantize_kwargs is None:
            decimal_quantize_kwargs = {'max_pre_decimal': 3, 'max_post_decimal': 2}
        # NOTE(review): these bind the builtin min/max *functions*, not numbers;
        # kept only in case external code reads them.
        self.min = min
        self.max = max
        self.quant = quant
        self.mask_values = mask_values
        self.text_split_token = '<split-text>'
        self.decimal_quantize_types = decimal_quantize_types
        self.decimal_quantize = len(decimal_quantize_types) > 0
        if self.decimal_quantize:
            print('decimal quantize types', decimal_quantize_types)
            self.decimal_quantizer = DecimalQuantizer(**decimal_quantize_kwargs)
        else:
            self.decimal_quantizer = None

        self.set_min_max_bins(min_max_bins)
        # NOTE(review): the default height (1457) differs from the default
        # width (1456); confirm this is intended.
        self.width = int(kwargs.get('width', 1456))
        self.height = int(kwargs.get('height', 1457))

    def set_min_max_bins(self, min_max_bins):
        """Install ``{type: (min, max, n_bins)}``, storing ``n_bins + 1`` as the bin count.

        Every ``n_bins`` must be even (asserted); with an even count, the
        ``n_bins + 1`` evenly spaced values include the range midpoint. The input
        is deep-copied and the bin cache used by ``get_bins`` is reset.
        """
        min_max_bins = copy.deepcopy(min_max_bins)
        for type_name, (min_val, max_val, n_bins) in min_max_bins.items():
            assert n_bins % 2 == 0  # must be even
            min_max_bins[type_name] = (min_val, max_val, n_bins + 1)
        self.min_max_bins = min_max_bins
        # get_bins memoises per instance; drop entries computed for old bins.
        self._bins_cache = {}

    def setup_tokenizer(self, tokenizer):
        """Register this quantizer's special tokens on ``tokenizer`` and return it.

        Tokens are added in this order, which must be preserved or token ids
        will not line up:

        1. ``'<split-text>'``;
        2. the ``DecimalQuantizer`` vocabulary, if decimal quantization is on;
           these tokens are also excluded from ``load_json`` re-quoting;
        3. ``'<{type}-{i}>'`` for every bin of every non-decimal type;
        4. the mask tokens, if ``num_mask_tokens > 0``.

        The token set is stored in ``self.additional_special_tokens``.
        """
        additional_special_tokens = [self.text_split_token]
        if self.decimal_quantize:
            special_tokens = self.decimal_quantizer.get_vocab()
            self.io_ignore_replace_tokens += special_tokens
            additional_special_tokens += special_tokens
        # the order must be preserved, other wise the tokenizer will be wrong
        rest_types = [key for key in self.min_max_bins.keys() if key not in self.decimal_quantize_types]
        for type_name in rest_types:
            min_val, max_val, n_bins = self.min_max_bins[type_name]
            additional_special_tokens += [f'<{type_name}-{i}>' for i in range(n_bins)]

        if self.num_mask_tokens > 0:
            additional_special_tokens.extend(self.get_additional_mask_tokens())

        print('additional_special_tokens', additional_special_tokens)

        tokenizer.add_special_tokens({'additional_special_tokens': additional_special_tokens})
        self.additional_special_tokens = set(additional_special_tokens)
        return tokenizer

    def get_bins(self, real_type):
        """Return ``(min_val, max_val, np.linspace(min_val, max_val, n_bins))`` for a value type.

        Results are memoised per instance. A per-instance dict replaces
        ``functools.lru_cache`` on the method, which kept every instance alive
        and returned stale bins after ``set_min_max_bins``.
        """
        cache = self.__dict__.setdefault('_bins_cache', {})
        if real_type not in cache:
            min_val, max_val, n_bins = self.min_max_bins[real_type]
            cache[real_type] = (min_val, max_val, np.linspace(min_val, max_val, n_bins))
        return cache[real_type]

    def quantize(self, x, type):
        """Quantize scalar ``x`` of layout field ``type`` into a token string.

        ``type`` is a key of ``specs`` (e.g. 'x', 'width') and is mapped to its
        value type (e.g. 'pos', 'size'). ``x`` is clipped to that type's range,
        then encoded by ``DecimalQuantizer`` when the type is decimal-quantized,
        otherwise as ``<{value type}-{i}>`` where ``bins[i]`` is the last bin
        value <= ``x``. Returns ``x`` unchanged when ``self.quant`` is False.
        """
        if not self.quant:
            return x
        real_type = specs[type]
        min_val, max_val, bins = self.get_bins(real_type)
        x = np.clip(float(x), min_val, max_val)
        if self.decimal_quantize and real_type in self.decimal_quantize_types:
            return self.decimal_quantizer(x, **pre_post_decimals[real_type])
        # digitize returns i with bins[i-1] <= x < bins[i] (len(bins) for x == max_val)
        val = np.digitize(x, bins) - 1
        n_bins = len(bins)
        assert 0 <= val < n_bins
        return f'<{real_type}-{val}>'

    def dequantize(self, x):
        """Map a token string from ``quantize`` back to a number.

        Digit-token strings are decoded by ``DecimalQuantizer`` when decimal
        quantization is on; a bin token ``<type-i>`` returns ``bins[i]``.
        """
        if self.decimal_quantize and self.decimal_quantizer.check_valid(x):
            return self.decimal_quantizer.decode(x)
        # '<pos-1>' -> '1' and 'pos'
        val = x.split('-')[1].strip('>')
        real_type = x.split('-')[0][1:]
        min_val, max_val, bins = self.get_bins(real_type)
        return bins[int(val)]

    def construct_map_dict(self):
        """Return ``{'<size-i>': str(value), ..., '<pos-i>': str(value), ...}`` for every size and pos bin."""
        map_dict = {}
        for real_type in ('size', 'pos'):
            for i in range(self.min_max_bins[real_type][2]):
                name = '<%s-%d>' % (real_type, i)
                map_dict[name] = str(self.dequantize(name))
        return map_dict

    def postprocess_colorandfont(self, json_example):
        """Wrap every ``<font-N>`` and ``<color-N>`` token in the string ``json_example`` in double quotes."""
        import re
        json_example = re.sub(r'(<font-\d+>)', r'"\1"', json_example)
        json_example = re.sub(r'(<color-\d+>)', r'"\1"', json_example)
        return json_example

    def to_str(self, x, type):
        # NOTE(review): get_feature is not defined on this class or on
        # BaseQuantizer in this file; this raises AttributeError unless a
        # subclass provides it.
        feature = self.get_feature(type)
        return feature.int2str(x)

    def convert2layout(self, example):
        """Convert a raw example into a layout whose x/y/width/height are tokens.

        Coordinates are divided by ``self.width`` / ``self.height`` before
        quantization (presumably pixel units; confirm against the data). Only
        ``wholecaption`` and each layer's ``layer`` name are carried over.
        """
        new_example = OrderedDict()
        new_example['wholecaption'] = example['wholecaption']
        new_layout = []
        for meta_layer in example['layout']:
            new_layout.append({
                "layer": meta_layer["layer"],
                "x": self.quantize(meta_layer["x"] / self.width, 'x'),
                "y": self.quantize(meta_layer["y"] / self.height, 'y'),
                "width": self.quantize(meta_layer["width"] / self.width, 'width'),
                "height": self.quantize(meta_layer["height"] / self.height, 'height')
            })
        new_example['layout'] = new_layout
        return new_example

    def apply_masking(self,
                      json_example,
                      mask_all=None,
                      return_meta=False,
                      mask_values=True
                      ):
        """Replace a subset of values in ``json_example`` with mask tokens.

        Candidates are the global fields (globalcaption, canvas_width,
        canvas_height, category, keywords, bglayer.bgcaption, objlayer.flag,
        objlayer.objcaption) when both ``self.mask_values`` and ``mask_values``
        are true, followed by every entry of ``layers['textlayer']``. With
        ``mask_all`` the last ``num_mask_tokens`` candidates are masked;
        otherwise a random count in ``[1, num_mask_tokens]`` is chosen and kept
        in candidate order. The input is deep-copied, not modified.

        Returns:
            ``(masked_example, tuples)`` plus ``meta_infos`` when ``return_meta``
            is true. ``tuples`` holds ``(mask_token, original_value, num_token)``;
            ``meta_infos`` holds ``(shape_i, key)`` with ``shape_i == -1`` for
            global fields, otherwise the text layer's position and
            ``key == 'text'``.

        Raises:
            ValueError: from ``random.randint`` when ``mask_all`` is false and
                ``num_mask_tokens < 1``; or for an unknown ``mask_type``.
        """
        if mask_all is None:
            mask_all = self.mask_all

        json_example = copy.deepcopy(json_example)
        layers = json_example['layers']
        textlayers = layers['textlayer']

        # Each candidate is (shape_i, key, container, field, value); the mask
        # token is written to container[field] in the deep copy.
        candidates = []
        if self.mask_values and mask_values:
            for key in ('globalcaption', 'canvas_width', 'canvas_height', 'category', 'keywords'):
                candidates.append((-1, key, json_example, key, json_example[key]))
            candidates.append((-1, 'bgcaption', layers['bglayer'], 'bgcaption', layers['bglayer']['bgcaption']))
            candidates.append((-1, 'flag', layers['objlayer'], 'flag', layers['objlayer']['flag']))
            candidates.append((-1, 'objcaption', layers['objlayer'], 'objcaption', layers['objlayer']['objcaption']))
        # Use each text layer's actual key rather than a fixed position -> name
        # table, so missing or reordered layers are masked under the right key.
        for layer_i, layer_name in enumerate(textlayers):
            candidates.append((layer_i, 'text', textlayers, layer_name, textlayers[layer_name]))

        if not mask_all:
            target_num_mask_tokens = random.randint(1, self.num_mask_tokens)
            if len(candidates) > target_num_mask_tokens:
                # Shuffle positions (same RNG draws as shuffling the candidates)
                # and sort the picked ones so the subset keeps candidate order;
                # global fields share shape_i == -1, so sorting on shape_i alone
                # could not restore their order.
                order = list(range(len(candidates)))
                random.shuffle(order)
                picked = sorted(order[:target_num_mask_tokens])
                candidates = [candidates[i] for i in picked]
        elif len(candidates) > self.num_mask_tokens:
            # Keep the last num_mask_tokens; an explicit start index avoids the
            # [-0:] slice, which would keep everything when num_mask_tokens == 0.
            candidates = candidates[len(candidates) - self.num_mask_tokens:]

        tuples = []
        meta_infos = []
        for mask_i, (shape_i, key, container, field, value) in enumerate(candidates):
            if self.mask_type == 'cm3':
                mask_token = self.mask_tokens[mask_i]
            elif self.mask_type == 'mask_aug':
                mask_token = self.mask_aug_token
            else:
                raise ValueError(f'Invalid mask type {self.mask_type}')
            # Token strings such as '<one-1><decimal0-1><decimal1-2>' count one
            # token per '<'; plain text counts words (spaces + 1).
            if '<' in value:
                num_token = value.count('<')
            else:
                num_token = value.count(' ') + 1
            container[field] = mask_token
            tuples.append((mask_token, value, num_token))
            meta_infos.append((shape_i, key))
        if return_meta:
            return json_example, tuples, meta_infos
        return json_example, tuples


# Currently unused; originally used for rendering.
def is_font_exists(font_name):
    """Return True if any system font file path contains ``font_name`` (case-insensitive)."""
    return any(
        font_name.lower() in font_path.lower()
        for font_path in font_manager.findSystemFonts()
    )

def print_info(msg):
    """Print ``msg`` in green with an ``[INFO]`` prefix, then reset the terminal colour."""
    # Without RESET_ALL the colour would persist into all later output.
    print(f'{Fore.GREEN}[INFO] {msg}{Style.RESET_ALL}')


def print_warning(msg):
    """Print ``msg`` in yellow with a ``[WARNING]`` prefix, then reset the terminal colour."""
    print(f'{Fore.YELLOW}[WARNING] {msg}{Style.RESET_ALL}')


def print_error(msg):
    """Print ``msg`` in red with an ``[ERROR]`` prefix, then reset the terminal colour."""
    print(f'{Fore.RED}[ERROR] {msg}{Style.RESET_ALL}')

def load_feature(path):
    """Build a ``datasets.ClassLabel`` from a JSON file mapping ``"0"``..``"n-1"`` to class names.

    Raises:
        KeyError: if the keys are not exactly the consecutive strings "0".."n-1".
    """
    # JSON is UTF-8 by spec; the locale-dependent default encoding can
    # mis-decode non-ASCII class names.
    with open(path, encoding='utf-8') as f:
        content = json.load(f)
    names = [content[str(i)] for i in range(len(content))]
    return ClassLabel(num_classes=len(names), names=names)

def get_quantizer(version='v1', update_vocab=False, **kwargs):
    """Build the quantizer for ``version``.

    Only ``'v4'`` (``QuantizerV4``) is implemented, so the default ``'v1'``
    raises. ``update_vocab`` is accepted for call-site compatibility but is
    unused; remaining keyword arguments are passed to ``QuantizerV4``.

    Raises:
        NotImplementedError: for any version other than ``'v4'``.
    """
    if version == 'v4':
        return QuantizerV4(**kwargs)
    raise NotImplementedError(f'Unsupported quantizer version {version!r}; only "v4" is implemented')