from data_provider.context_gen import *

def parse_args(argv=None):
    """Parse the command-line options of this analysis script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so the
            command-line behaviour is unchanged; passing a list allows the
            script to be driven programmatically.

    Returns:
        argparse.Namespace with the attributes ``name``, ``seed``, ``epochs``,
        ``chunk_size``, ``rxn_num``, ``k`` and ``root``.
    """
    parser = argparse.ArgumentParser(description="A simple argument parser")

    # Script arguments
    # --name selects the output sub-directory results/<name>/.
    parser.add_argument('--name', default='none', type=str)
    # --seed: RNG seed; callers only apply it when it is non-zero.
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--epochs', default=100, type=int)
    # --chunk_size: number of raw values averaged into one plotted bar.
    parser.add_argument('--chunk_size', default=100, type=int)
    parser.add_argument('--rxn_num', default=50000, type=int)
    parser.add_argument('--k', default=4, type=int)
    # --root: data directory handed to Reaction_Cluster.
    parser.add_argument('--root', default='data/pretrain_data', type=str)

    args = parser.parse_args(argv)
    return args

def pad_shorter_array(arr1, arr2):
    """Zero-pad whichever array is shorter so both share the same first-axis length.

    The pad width ``(0, difference)`` is handed to ``np.pad`` in 'constant'
    mode, so for 1-D inputs zeros are appended at the end. Arrays of equal
    length are returned untouched.

    Returns:
        The pair ``(arr1, arr2)`` in the original order, one of them
        possibly padded.
    """
    diff = arr1.shape[0] - arr2.shape[0]
    if diff > 0:
        # arr2 is shorter: extend it by the difference.
        arr2 = np.pad(arr2, (0, diff), 'constant')
    elif diff < 0:
        # arr1 is shorter: extend it by the (positive) difference.
        arr1 = np.pad(arr1, (0, -diff), 'constant')
    return arr1, arr2

def plot_distribution(values, target_path, x_lim=None, y_lim=None, chunk_size=100, color='blue'):
    """Save a bar chart of chunk-averaged values, sorted in descending order.

    ``values`` is split into consecutive chunks of ``chunk_size`` elements
    (a trailing partial chunk is dropped). Each chunk is averaged, the means
    are sorted from largest to smallest, and one bar is drawn per chunk.

    Args:
        values: 1-D numeric sequence (converted with ``np.asarray``).
        target_path: File the figure is written to via ``plt.savefig``.
        x_lim: Optional ``(min, max)`` x-axis limits, in chunk units.
        y_lim: Optional ``(min, max)`` y-axis limits.
        chunk_size: Number of raw values averaged into one bar.
        color: Bar colour.
    """
    values = np.asarray(values)
    num_full_chunks = len(values) // chunk_size
    chunk_means = np.mean(values[:num_full_chunks * chunk_size].reshape(-1, chunk_size), axis=1)
    chunk_means = np.sort(chunk_means)[::-1]

    fig = plt.figure(figsize=(10, 4), dpi=100)
    try:
        x = np.arange(len(chunk_means))
        plt.bar(x, chunk_means, color=color)
        # Bars are indexed per chunk, but ticks are labelled in raw-element units.
        tick_labels = np.array([0, 200000, 400000, 600000, 800000, 1000000], dtype=int)
        plt.xticks((tick_labels / chunk_size).astype(int), tick_labels)
        plt.ylabel('Molecule Frequency', fontsize=20)
        if x_lim:
            plt.xlim(*x_lim)
        if y_lim:
            plt.ylim(*y_lim)
        plt.tick_params(axis='both', which='major', labelsize=12)
        plt.tight_layout(pad=0.5)
        plt.savefig(target_path)
    finally:
        # Close (not merely clear) the figure so repeated calls do not
        # accumulate open figures in pyplot's figure manager.
        plt.close(fig)
    print(f'Figure saved to {target_path}')

def plot_compare_distribution(list1, list2, target_path, x_lim=None, y_lim=None, labels=('Random', 'Ours'), colors=('blue', 'orange'), chunk_size=100):
    """Overlay two chunk-averaged, descending-sorted distributions in one bar chart.

    The shorter input is zero-padded to the longer one's length. Each input is
    then split into chunks of ``chunk_size`` (trailing partial chunk dropped),
    averaged per chunk and sorted from largest to smallest.

    Args:
        list1: First 1-D numeric sequence; drawn first with ``labels[0]`` /
            ``colors[0]``.
        list2: Second 1-D numeric sequence; drawn on top with ``labels[1]`` /
            ``colors[1]``.
        target_path: File the figure is written to via ``plt.savefig``.
        x_lim: Optional ``(min, max)`` x-axis limits, in chunk units.
        y_lim: Optional ``(min, max)`` y-axis limits.
        labels: Two legend labels (any indexable of length 2).
        colors: Two bar colours (any indexable of length 2).
        chunk_size: Number of raw values averaged into one bar.
    """
    list1, list2 = pad_shorter_array(np.asarray(list1), np.asarray(list2))
    # Count chunks only after padding; counting on the original list1 would
    # truncate list2's extra values whenever list2 is the longer input.
    num_full_chunks = len(list1) // chunk_size
    values1, values2 = [
        np.sort(np.mean(values[:num_full_chunks * chunk_size].reshape(-1, chunk_size), axis=1))[::-1]
        for values in (list1, list2)]

    fig = plt.figure(figsize=(10, 6), dpi=100)
    try:
        x = np.arange(len(values1))
        plt.bar(x, values1, color=colors[0], label=labels[0], alpha=0.6)
        plt.bar(x, values2, color=colors[1], label=labels[1], alpha=0.5)
        # Bars are indexed per chunk, but ticks are labelled in raw-element units.
        tick_labels = np.array([0, 200000, 400000, 600000, 800000, 1000000], dtype=int)
        plt.xticks((tick_labels / chunk_size).astype(int), tick_labels)
        plt.ylabel('Molecule Frequency', fontsize=20)
        if x_lim:
            plt.xlim(*x_lim)
        if y_lim:
            plt.ylim(*y_lim)
        plt.tick_params(axis='both', which='major', labelsize=18)
        plt.tight_layout(pad=0.5)
        plt.legend(fontsize=24, loc='upper right')
        plt.savefig(target_path)
    finally:
        # Close (not merely clear) the figure so repeated calls do not leak figures.
        plt.close(fig)
    print(f'Figure saved to {target_path}')

def statistics(args):
    """Print summary statistics of the reaction and molecule-property data.

    Args:
        args: Parsed CLI namespace. Reads ``args.seed`` (the RNG is seeded only
            when it is non-zero) and ``args.root`` (passed to Reaction_Cluster).

    Prints:
    - the number of reactions;
    - the number of distinct molecules listed under 'REACTANT', 'CATALYST',
      'SOLVENT' or 'PRODUCT' in any reaction;
    - how many property records carry an 'abstract' and/or a 'property'
      entry, and within the latter how many have 'Experimental Properties' /
      'Computed Properties'.

    Item totals are divided by the distinct-molecule count. A zero
    denominator prints 0.00 instead of raising ZeroDivisionError.
    """
    if args.seed:
        set_random_seed(args.seed)
    # 1141864 rxns from ord
    # 1120773 rxns from uspto
    cluster = Reaction_Cluster(args.root)

    def _ratio(num, den):
        # Empty datasets would otherwise raise ZeroDivisionError.
        return num / den if den else 0.0

    rxn_num = len(cluster.reaction_data)
    mol_set = set()
    for rxn_dict in cluster.reaction_data:
        for key in ('REACTANT', 'CATALYST', 'SOLVENT', 'PRODUCT'):
            mol_set.update(rxn_dict[key])
    mol_num = len(mol_set)

    abstract_num = 0
    property_num = 0
    experimental_property_num = 0
    calculated_property_num = 0
    # Running totals of property items (printed divided by mol_num).
    total_experimental_property_len = 0
    total_calculated_property_len = 0
    for mol_dict in cluster.property_data:
        if 'abstract' in mol_dict:
            abstract_num += 1
        if 'property' not in mol_dict:
            continue
        property_num += 1
        props = mol_dict['property']
        if 'Experimental Properties' in props:
            experimental_property_num += 1
            total_experimental_property_len += len(props['Experimental Properties'])
        if 'Computed Properties' in props:
            calculated_property_num += 1
            total_calculated_property_len += len(props['Computed Properties'])

    print(f'Reaction Number: {rxn_num}')
    print(f'Molecule Number: {mol_num}')
    print(f'Abstract Number: {abstract_num}/{mol_num}({_ratio(abstract_num, mol_num)*100:.2f}%)')
    print(f'Property Number: {property_num}/{mol_num}({_ratio(property_num, mol_num)*100:.2f}%)')
    print(f'- Experimental Properties Number: {experimental_property_num}/{property_num}({_ratio(experimental_property_num, property_num)*100:.2f}%), {_ratio(total_experimental_property_len, mol_num):.2f} items per molecule')
    print(f'- Computed Properties: {calculated_property_num}/{property_num}({_ratio(calculated_property_num, property_num)*100:.2f}%), {_ratio(total_calculated_property_len, mol_num):.2f} items per molecule')

def visualize(args):
    """Plot molecule/reaction distributions for the default and the random variant.

    Calls ``cluster.visualize_mol_distribution()`` directly and through
    ``cluster._randomly(...)`` (presumably the same computation under random
    sampling — confirm against Reaction_Cluster). Writes PDFs to
    ``results/<args.name>/``, creating that directory if needed.

    Args:
        args: Parsed CLI namespace; reads ``seed``, ``root`` and ``name``.
    """
    import os

    if args.seed:
        set_random_seed(args.seed)
    cluster = Reaction_Cluster(args.root)
    prob_values, rxn_weights = cluster.visualize_mol_distribution()
    rand_prob_values, rand_rxn_weights = cluster._randomly(
        cluster.visualize_mol_distribution
    )
    fig_root = f'results/{args.name}/'
    os.makedirs(fig_root, exist_ok=True)

    plot_distribution(prob_values, fig_root+'mol_distribution.pdf')
    plot_distribution(rxn_weights, fig_root+'rxns_distribution.pdf')
    plot_distribution(rand_prob_values, fig_root+'mol_distribution_random.pdf')
    plot_distribution(rand_rxn_weights, fig_root+'rxns_distribution_random.pdf')

    # plot_compare_distribution's default labels are ('Random', 'Ours') for
    # (first, second) input, so the random baseline must be passed first for
    # the legend to match the data.
    plot_compare_distribution(rand_prob_values, prob_values, fig_root+'Compare_mol.pdf', y_lim=(-0.5,15.5))
    plot_compare_distribution(rand_rxn_weights, rxn_weights, fig_root+'Compare_rxns.pdf')


def visualize_frequency(args):
    """Plot molecule sampling frequencies for the default and the random variant.

    Results of ``cluster.visualize_mol_frequency`` (and of the same call
    wrapped in ``cluster._randomly``) are cached in
    ``results/<name>/freq_E<epochs>_Rxn<rxn_num>_K<k>.npy``. An existing
    cache is reused instead of recomputing. Figures are written to the same
    directory, which is created if missing.

    Args:
        args: Parsed CLI namespace; reads ``seed``, ``name``, ``epochs``,
            ``rxn_num``, ``k``, ``root`` and ``chunk_size``.
    """
    if args.seed:
        set_random_seed(args.seed)
    fig_root = f'results/{args.name}/'
    os.makedirs(fig_root, exist_ok=True)
    name_suffix = f'E{args.epochs}_Rxn{args.rxn_num}_K{args.k}'
    cache_path = os.path.join(fig_root, f'freq_{name_suffix}.npy')
    if os.path.exists(cache_path):
        # allow_pickle is needed for the object array; only load caches
        # this script wrote itself.
        mol_freq, rxn_freq, rand_mol_freq, rand_rxn_freq = np.load(cache_path, allow_pickle=True)
    else:
        cluster = Reaction_Cluster(args.root)
        mol_freq, rxn_freq = cluster.visualize_mol_frequency(rxn_num=args.rxn_num, k=args.k, epochs=args.epochs)
        rand_mol_freq, rand_rxn_freq = cluster._randomly(
            cluster.visualize_mol_frequency,
            rxn_num=args.rxn_num, k=args.k, epochs=args.epochs
        )
        # Build an explicit 4-element object container. np.array([...],
        # dtype=object) would instead produce a 2-D (4, N) object matrix
        # whenever the four results happen to have equal length.
        cache = np.empty(4, dtype=object)
        for i, freq in enumerate((mol_freq, rxn_freq, rand_mol_freq, rand_rxn_freq)):
            cache[i] = freq
        np.save(cache_path, cache, allow_pickle=True)

    # Normalise to numeric arrays; rows loaded from older 2-D object caches
    # would otherwise keep an object dtype.
    mol_freq = np.asarray(mol_freq, dtype=float)
    rand_mol_freq = np.asarray(rand_mol_freq, dtype=float)

    color1 = '#FA7F6F'
    color2 = '#80AFBF'
    color3 = '#FFBE7A'
    # Axis limits shared by the frequency plots (x in chunk units).
    x_lim = (-50000 // args.chunk_size, 1200000 // args.chunk_size)
    y_lim = (-2, 62)
    plot_distribution(mol_freq, fig_root+f'mol_frequency_{name_suffix}.pdf', x_lim=x_lim, y_lim=y_lim, chunk_size=args.chunk_size, color=color2)
    # plot_distribution(rxn_freq, fig_root+f'rxns_frequency_{name_suffix}.pdf', chunk_size=args.chunk_size, color=color1)
    plot_distribution(rand_mol_freq, fig_root+f'mol_frequency_random_{name_suffix}.pdf', x_lim=x_lim, y_lim=y_lim, chunk_size=args.chunk_size, color=color2)
    # plot_distribution(rand_rxn_freq, fig_root+f'rxns_frequency_random_{name_suffix}.pdf', chunk_size=args.chunk_size, color=color1)

    plot_compare_distribution(rand_mol_freq, mol_freq, fig_root+f'Compare_mol_{name_suffix}.pdf', y_lim=y_lim, labels=['Before Adjustment', 'After Adjustment'], colors=[color1, color2], chunk_size=args.chunk_size)
    # plot_compare_distribution(rand_rxn_freq, rxn_freq, fig_root+f'Compare_rxns_{name_suffix}.pdf', chunk_size=args.chunk_size)

if __name__ == '__main__':
    # Script entry: parse CLI options, echo them for the run log, then run the
    # default analysis. statistics() and visualize() are alternative entry
    # points taking the same namespace.
    cli_args = parse_args()
    print(cli_args, flush=True)
    visualize_frequency(cli_args)