import torch
from torch_geometric.data import Dataset
import os
from .context_gen import Reaction_Cluster
import json
from .data_utils import smiles2data, reformat_smiles
from collections import defaultdict
import random
from data_provider.caption_dataset import PretrainCaptionDataset
from data_provider.synthesis_dataset import SynthesisDataset

def format_float_from_string(s):
    """Round a numeric string to two decimal places; return non-numeric strings unchanged."""
    try:
        float_value = float(s)
        return f'{float_value:.2f}'
    except ValueError:
        return s

class MoleculeAbstract(Dataset):
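    """Reaction-context pretraining dataset built on torch_geometric's Dataset.

    Each item is a batch of molecules drawn from reaction clusters, optionally
    interleaved with caption and synthesis samples, and is returned as a
    (graph_list, mol_prompt_list, text_prompt_list) triple.
    """
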
    def __init__(self, 
            root,
            rxn_num=1000,
            rxn_batch_size=4,
            smi_max_len=128,
            prompt=None,
            disable_graph_cache=False,
            disable_graphs=False,
            context_style='weighted_rxn',
            use_caption_dataset=False,
            caption_batch_num=10000,
            synthesis_datasetpath=None,
            synthesis_batch_num=10000,
            reverse_ratio=0.5,
            enable_abstract=True,
            enable_property=True,
            smiles_type='default',
            mode='train'
        ):
        super().__init__(root)
        self.root = root
        self.rxn_num = rxn_num
        self.rxn_batch_size = rxn_batch_size
        self.smi_max_len = smi_max_len
        self.context_style = context_style
        self.tokenizer = None
        self.disable_graph_cache = disable_graph_cache
        self.disable_graphs = disable_graphs
        self.use_caption_dataset = use_caption_dataset
        self.smiles_type = smiles_type
        if use_caption_dataset:
            self.caption_dataset = PretrainCaptionDataset(
                os.path.join(root, '../caption_data'),
                smi_max_len=smi_max_len,
                use_graph=not self.disable_graphs,
                disable_graph_cache=disable_graph_cache,
                smiles_type=smiles_type,
            )
            self.caption_batch_num = caption_batch_num
        self.use_synthesis_dataset = bool(synthesis_datasetpath)
        if self.use_synthesis_dataset:
            self.synthesis_dataset = SynthesisDataset(
                synthesis_datasetpath,
                'train',
                smi_max_len,
                roundrobin_train=True,
                use_graph=not disable_graphs,
                disable_graph_cache=disable_graph_cache,
                smiles_type='default',
            )
            self.synthesis_batch_num = synthesis_batch_num
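        # Pre-computed molecule graphs keyed by canonical SMILES (cached as mol_graph_map.pt).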
        if not self.disable_graphs:
            self.mol_graph_map = torch.load(os.path.join(self.root, 'mol_graph_map.pt'))
        reaction_filename = 'reactions/reactions_test.json' if (mode=='test') else 'reactions/reactions.json'
        if smiles_type=='r_smiles':
            reaction_filename = 'reactions/reactions_wRSMILES.json'
        self.cluster = Reaction_Cluster(self.root, reaction_filename=reaction_filename, reverse_ratio=reverse_ratio)
        self.reload_data_list()
        self.abstract_max_len = 10240
        self.property_max_len = 10240
        self.enable_abstract = enable_abstract
        self.enable_property = enable_property

    def get(self, index):
        # torch_geometric's Dataset declares get() as abstract; delegate to __getitem__.
        return self.__getitem__(index)

    def len(self):
        # torch_geometric's Dataset declares len() as abstract; delegate to __len__.
        return len(self)

    def __len__(self):
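        """Total length: reaction batches plus any caption and synthesis items."""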
        data_len = len(self.data_list)
        if self.use_caption_dataset:
            data_len += len(self.caption_index_list)
        if self.use_synthesis_dataset:
            data_len += len(self.synthesis_index_list)
        return data_len
    
    def reload_data_list(self):
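        """Resample reaction batches (and caption/synthesis indices) for a new epoch."""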
        k = self.rxn_batch_size
        if self.context_style == 'weighted_rxn':
            self.data_list = self.cluster(self.rxn_num, k=k)
        elif self.context_style == 'uniform_rxn':
            self.data_list = self.cluster.generate_batch_uniform_rxn(self.rxn_num, k=k)
        elif self.context_style == 'uniform_mol':
            self.data_list = self.cluster.generate_batch_uniform_mol(self.rxn_num, k=k)
        elif self.context_style == 'single_mol':
            self.data_list = self.cluster.generate_batch_single(self.rxn_num)
        elif self.context_style == 'hybrid':
            self.data_list = self.cluster(self.rxn_num//2, k=k)
            self.data_list += self.cluster.generate_batch_uniform_mol(self.rxn_num//2, k=k)
        else:
            raise NotImplementedError(f'Unknown context_style: {self.context_style}')
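        # Sample caption indices and chunk them into groups of k so caption batches
        # mirror the reaction batch size.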
        if self.use_caption_dataset:
            assert self.caption_batch_num*k <= len(self.caption_dataset)
            caption_index_list = random.sample(range(len(self.caption_dataset)), self.caption_batch_num*k)
            self.caption_index_list = [caption_index_list[i*k:(i+1)*k] for i in range(self.caption_batch_num)]
        else:
            self.caption_index_list = []
        if self.use_synthesis_dataset:
            if self.synthesis_dataset.roundrobin_train:
                self.synthesis_dataset.reload_data()
            assert self.synthesis_batch_num <= len(self.synthesis_dataset)
            self.synthesis_index_list = random.sample(range(len(self.synthesis_dataset)), self.synthesis_batch_num)
        else:
            self.synthesis_index_list = []

    def make_prompt(self, mol_batch, smi_max_len=128):
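        """Build SMILES prompts and paired abstract/property texts for a molecule batch."""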
        mol_prompt_list, text_prompt_list = [], []
        last_role = None
        for mol_dict in mol_batch:
            smiles = mol_dict['canon_smiles']
            if self.smiles_type=='r_smiles':
                if 'r_smiles' in mol_dict:
                    smiles = mol_dict['r_smiles']
                # else:
                #     smiles = reformat_smiles(smiles, smiles_type='restricted')
            else:
                smiles = reformat_smiles(smiles, smiles_type=self.smiles_type)
            mol_prompt = f'[START_SMILES]{smiles[:smi_max_len]}[END_SMILES]. '
            if 'role' in mol_dict:
                role = {
                    'REACTANT': 'Reactant',
                    'CATALYST': 'Catalyst',
                    'SOLVENT': 'Solvent',
                    'PRODUCT': 'Product',
                }[mol_dict['role']]
                if last_role != role:
                    # Prefix the role label only when it changes between consecutive molecules.
                    mol_prompt = f'{role}: {mol_prompt}'
                    last_role = role
            text_prompt = self.make_abstract(mol_dict)
            mol_prompt_list.append(mol_prompt)
            text_prompt_list.append(text_prompt)
        return mol_prompt_list, text_prompt_list

    def make_abstract(self, mol_dict):
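        """Compose the [Abstract] and [Properties] text segments for one molecule."""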
        prompt = ''
        if self.enable_abstract and 'abstract' in mol_dict:
            abstract_string = mol_dict['abstract'][:self.abstract_max_len]
            prompt += f'[Abstract] {abstract_string} '

        if self.enable_property:
            property_string = ''
            property_dict = mol_dict.get('property', {})
            for property_key in ['Experimental Properties', 'Computed Properties']:
                if property_key not in property_dict:
                    continue
                for key, value in property_dict[property_key].items():
                    if isinstance(value, float):
                        key_value_string = f'{key}: {value:.2f}; '
                    elif isinstance(value, str):
                        float_value = format_float_from_string(value)
                        key_value_string = f'{key}: {float_value}; '
                    else:
                        key_value_string = f'{key}: {value}; '
                    if len(property_string + key_value_string) > self.property_max_len:
                        break  # budget exhausted; skip the rest of this property group
                    property_string += key_value_string
            if property_string:
                property_string = property_string[:self.property_max_len]
                prompt += f'[Properties] {property_string}. '
        return prompt

    def get_caption_data(self, index):
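        """Return one pre-sampled caption batch as (graphs, SMILES prompts, captions)."""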
        caption_index = self.caption_index_list[index]
        graph_list, mol_prompt_list, text_prompt_list = [], [], []
        for idx in caption_index:
            graph_item, text, smiles_prompt = self.caption_dataset[idx]
            graph_list.append(graph_item)
            mol_prompt_list.append(smiles_prompt)
            text_prompt_list.append(text)

        return graph_list, mol_prompt_list, text_prompt_list
    
    def get_synthesis_data(self, index):
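        """Return one pre-sampled synthesis example as single-element prompt lists."""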
        synthesis_index = self.synthesis_index_list[index]
        _, graph_list, output_text, input_text = self.synthesis_dataset[synthesis_index]
        return graph_list, [input_text], [output_text]

    def __getitem__(self, index):
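        # Index layout: [reaction batches | caption batches | synthesis items].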
        if index < len(self.data_list):
            mol_batch = self.data_list[index]
        elif index < len(self.data_list)+len(self.caption_index_list):
            assert self.use_caption_dataset
            return self.get_caption_data(index-len(self.data_list))
        else:
            assert self.use_synthesis_dataset
            return self.get_synthesis_data(index-(len(self.data_list)+len(self.caption_index_list)))

        graph_list = []
        for mol_dict in mol_batch:
            smiles = mol_dict['canon_smiles']
            if self.disable_graphs:
                graph_item = None
            else:
                if self.disable_graph_cache:
                    graph_item = smiles2data(smiles)
                else:
                    assert smiles in self.mol_graph_map, f'Missing cached graph for {smiles}'
                    graph_item = self.mol_graph_map[smiles]
            graph_list.append(graph_item)
        mol_prompt_list, text_prompt_list = self.make_prompt(mol_batch, smi_max_len=self.smi_max_len)

        return graph_list, mol_prompt_list, text_prompt_list
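

# A minimal usage sketch, not part of the original module: the path below is a
# hypothetical placeholder, and the expected layout (mol_graph_map.pt plus a
# reactions/ directory under root) must already exist. Run with `python -m ...`
# since this module uses relative imports.
if __name__ == '__main__':
    dataset = MoleculeAbstract(root='data/pretrain', rxn_num=100, rxn_batch_size=4)
    graphs, mol_prompts, text_prompts = dataset[0]
    print(len(dataset), mol_prompts[0])
    dataset.reload_data_list()  # resample reaction batches between epochs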