# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
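"""Pretraining data module: collators and a LightningDataModule that pair molecule graphs
with tokenized reaction/abstract text prompts for (optionally Galactica-style) language models."""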
import torch
from pytorch_lightning import LightningDataModule
import torch_geometric
# from torch_geometric.loader import DataLoader
from torch.utils.data import DataLoader
from torch_geometric.loader.dataloader import Collater
from data_provider.molecule_abstract_dataset import MoleculeAbstract
import re
from transformers import BatchEncoding

# split the individual characters inside custom sequences wrapped by special tokens such as [START_DNA] ... [END_DNA]
CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES|I_SMILES|AMINO)])(.*?)(\[END_\2])")
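# e.g. '[START_I_SMILES]CCO[END_I_SMILES]' matches with group(1) = '[START_I_SMILES]',
# group(2) = 'I_SMILES', group(3) = 'CCO', and group(4) = '[END_I_SMILES]'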

# Token used to implement custom sequence tokenization. It is inserted during corpus
# cleaning and removed again at pretokenization. The digits make an accidental occurrence
# in the corpus unlikely, and they are escaped below so that the token never appears
# literally in the source code in case we ever include this file in the training data.
SPLIT_MARKER = f"SPL{1}T-TH{1}S-Pl3A5E"

def _insert_split_marker(m: re.Match):
    """
    Applies split marker based on a regex match of special tokens such as
    [START_DNA].

    Parameters
    ----------
    m : re.Match
        Regex match of a custom sequence such as [START_DNA]...[END_DNA]

    Returns
    ----------
    str - the text with the split token added
    """
    start_token, _, sequence, end_token = m.groups()
    sequence = re.sub(r"(.)", fr"{SPLIT_MARKER}\1", sequence, flags=re.DOTALL)
    return f"{start_token}{sequence}{SPLIT_MARKER}{end_token}"


def smiles_handler(text, mol_ph, is_gal=True):
    """
    Extracts the sequences wrapped in [START_*]...[END_*] custom tokens (SMILES in this
    codebase) and appends the molecule placeholder string mol_ph after each occurrence.
    Galactica-style models keep the special tokens and escape the sequence for
    character-level tokenization; other models drop the special tokens entirely.
    """
    smiles_list = []
    for match in CUSTOM_SEQ_RE.finditer(text):
        smiles_list.append(match.group(3))
    if is_gal:
        text = CUSTOM_SEQ_RE.sub(r'\1\3\4%s' % (mol_ph), text)
        text = escape_custom_split_sequence(text)
        return text, smiles_list
    else:
        text = CUSTOM_SEQ_RE.sub(r'\3%s' % (mol_ph), text)
        return text, smiles_list
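
# Illustrative example (<SPL> abbreviates SPLIT_MARKER, '<mol><mol>' stands in for mol_ph):
#   smiles_handler('[START_I_SMILES]CCO[END_I_SMILES]', '<mol><mol>', is_gal=True)
#   -> ('[START_I_SMILES]<SPL>C<SPL>C<SPL>O<SPL>[END_I_SMILES]<mol><mol>', ['CCO'])
#   smiles_handler('[START_I_SMILES]CCO[END_I_SMILES]', '<mol><mol>', is_gal=False)
#   -> ('CCO<mol><mol>', ['CCO'])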


def escape_custom_split_sequence(text):
    """
    Applies custom splitting to the text for Galactica's character-level tokenization

    Parameters
    ----------
    text : str
        Input text to split

    Returns
    ----------
    str - the text with the split token added
    """
    return CUSTOM_SEQ_RE.sub(_insert_split_marker, text)
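
# e.g. escape_custom_split_sequence('[START_DNA]ACGT[END_DNA]') inserts SPLIT_MARKER before
# every character of 'ACGT' and before [END_DNA], so each base is tokenized on its own once
# the marker is stripped during pretokenization.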


def tokenize_and_merge_batched_qa_pairs(tokenizer, qa_pairs_list, max_length):
    """
    Tokenizes each sample's QA strings separately, splitting max_length evenly among them,
    then concatenates the token ids and right-pads every sample to max_length.
    """
    tokenized_batches = {
        'input_ids': [],
        'attention_mask': []
    }
    for qa_pairs in qa_pairs_list:
        max_length_per_qa = max_length // len(qa_pairs)
        batch_input_ids = []
        batch_attention_mask = []
        for qa in qa_pairs:
            # qa is expected to be a plain string
            tokens = tokenizer(qa,
                               truncation=True,
                               padding=False,
                               add_special_tokens=False,
                               max_length=max_length_per_qa,
                               return_tensors='pt',
                               return_attention_mask=True)
            # squeeze(0) rather than squeeze() so a single-token result still yields a list
            batch_input_ids.extend(tokens['input_ids'].squeeze(0).tolist())
            batch_attention_mask.extend(tokens['attention_mask'].squeeze(0).tolist())

        # Pad the batch to max_length
        padding_length = max_length - len(batch_input_ids)
        batch_input_ids.extend([tokenizer.pad_token_id] * padding_length)
        batch_attention_mask.extend([0] * padding_length)

        tokenized_batches['input_ids'].append(torch.tensor(batch_input_ids).unsqueeze(0))
        tokenized_batches['attention_mask'].append(torch.tensor(batch_attention_mask).unsqueeze(0))

    tokenized_batches['input_ids'] = torch.cat(tokenized_batches['input_ids'], dim=0)
    tokenized_batches['attention_mask'] = torch.cat(tokenized_batches['attention_mask'], dim=0)

    tokenized_batch = BatchEncoding(data=tokenized_batches, tensor_type='pt')
    return tokenized_batch
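
# Sketch of the expected shapes (assuming a HuggingFace tokenizer with pad_token_id set):
#   qa_pairs_list = [['<q1> <a1>', '<q2> <a2>'], ['<q3> <a3>']]   # B = 2 samples
#   batch = tokenize_and_merge_batched_qa_pairs(tokenizer, qa_pairs_list, max_length=512)
#   batch.input_ids.shape == batch.attention_mask.shape == (2, 512)
# Each sample's QA strings share the budget equally (max_length // len(qa_pairs)), are
# concatenated without special tokens, and the remainder is right-padded with pad_token_id.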

class TrainCollater:
    def __init__(self, tokenizer, text_max_len, mol_ph, mol_token_id, is_gal=True, disable_graphs=False):
        self.text_max_len = text_max_len
        self.tokenizer = tokenizer
        self.collater = Collater([], [])
        self.mol_ph = mol_ph
        self.mol_token_id = mol_token_id
        self.is_gal = is_gal
        self.disable_graphs = disable_graphs

    def __call__(self, batch):
        graphs, mol_prompt, text_prompt = zip(*batch)
        if not self.disable_graphs:
            graphs = [graph for graph_batch in graphs for graph in graph_batch]
            graphs = self.collater(graphs)

        qa_pairs = []
        for mol_batch, text_batch in zip(mol_prompt, text_prompt):
            qa_list = []
            for mol_p, text_p in zip(mol_batch, text_batch):
                smiles_prompt = smiles_handler(mol_p, self.mol_ph, self.is_gal)[0]
                qa_list.append(f'{smiles_prompt} {text_p}')
            qa_pairs.append(qa_list)

        self.tokenizer.padding_side = 'right'
        qa_batch = tokenize_and_merge_batched_qa_pairs(self.tokenizer, qa_pairs, self.text_max_len)

        is_mol_token = qa_batch.input_ids == self.mol_token_id
        qa_batch['is_mol_token'] = is_mol_token

        return graphs, qa_batch
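
# Hedged usage sketch (names other than TrainCollater are illustrative):
#   collate_fn = TrainCollater(tokenizer, text_max_len=512, mol_ph='<mol>' * args.num_query_token,
#                              mol_token_id=tokenizer.mol_token_id, is_gal=True)
#   graphs, qa_batch = collate_fn(batch)  # batch: list of (graph_list, mol_prompts, text_prompts)
# qa_batch is a BatchEncoding with input_ids, attention_mask, and a boolean is_mol_token mask
# marking the '<mol>' placeholder positions (typically filled with graph embeddings downstream).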

class InferenceCollater:
    def __init__(self, tokenizer, text_max_len, mol_ph, mol_token_id, is_gal=True, disable_graphs=False, last_only=False):
        self.text_max_len = text_max_len
        self.tokenizer = tokenizer
        self.collater = Collater([], [])
        self.mol_ph = mol_ph
        self.mol_token_id = mol_token_id
        self.is_gal = is_gal
        self.disable_graphs = disable_graphs
        self.last_only = last_only

    def __call__(self, batch):
        graphs, mol_prompt, text_prompt = zip(*batch)
        rxn_ids = [0] * len(mol_prompt)
        if self.last_only:
            mol_prompt = [[mol_batch[-1]] for mol_batch in mol_prompt]
            text_prompt = [[text_batch[-1]] for text_batch in text_prompt]
            graphs = [[graph_batch[-1]] for graph_batch in graphs]
        if not self.disable_graphs:
            graphs = [graph for graph_batch in graphs for graph in graph_batch]
            graphs = self.collater(graphs)

        input_text, output_text = [], []
        for mol_batch, text_batch in zip(mol_prompt, text_prompt):
            qa_list = []
            # keep every QA pair except the last as context; the final answer is held out
            for mol_p, text_p in list(zip(mol_batch, text_batch))[:-1]:
                smiles_prompt = smiles_handler(mol_p, self.mol_ph, self.is_gal)[0]
                qa_list.append(f'{smiles_prompt} {text_p}')
            qa_list.append(f'{smiles_handler(mol_batch[-1], self.mol_ph, self.is_gal)[0]} ')
            output_text.append(text_batch[-1])
            input_text.append(qa_list)

        self.tokenizer.padding_side = 'right'
        input_batch = tokenize_and_merge_batched_qa_pairs(self.tokenizer, input_text, self.text_max_len)

        is_mol_token = input_batch.input_ids == self.mol_token_id
        input_batch['is_mol_token'] = is_mol_token
        
        return rxn_ids, graphs, input_batch, output_text, input_text
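
# The collated inference batch is (rxn_ids, graphs, input_batch, output_text, input_text):
# rxn_ids are placeholder zeros, input_batch tokenizes the context (every QA pair except the
# last answer, plus the final molecule prompt), output_text holds the held-out final answer
# for use as the generation reference, and input_text keeps the raw prompt strings.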


class PretrainDM(LightningDataModule):
    def __init__(
        self,
        num_workers: int = 0,
        batch_size: int = 256,
        root: str = 'data/',
        text_max_len: int = 128,
        rxn_max_len: int = 128,
        smi_max_len: int = 128,
        tokenizer=None,
        args=None,
    ):
        super().__init__()
        self.args = args
        self.batch_size = batch_size
        self.inference_batch_size = args.inference_batch_size
        self.num_workers = num_workers
        self.text_max_len = text_max_len
        self.rxn_max_len = rxn_max_len
        self.pretrain_dataset = MoleculeAbstract(
            root,
            rxn_num=args.pretrain_rxn_num,
            rxn_batch_size=args.rxn_batch_size,
            smi_max_len=smi_max_len,
            disable_graph_cache=args.disable_graph_cache,
            context_style=args.context_style,
            disable_graphs=args.disable_graphs,
            use_caption_dataset=args.pretrain_use_caption,
            caption_batch_num=args.caption_batch_num,
            synthesis_datasetpath=args.pretrain_synthesis_path,
            synthesis_batch_num=args.synthesis_batch_num,
            reverse_ratio=args.reverse_ratio,
            enable_abstract=not args.disable_abstract,
            enable_property=not args.disable_property,
            smiles_type=args.smiles_type,
        )
        self.test_dataset = MoleculeAbstract(
            root,
            rxn_num=args.pretrain_rxn_num,
            rxn_batch_size=args.rxn_batch_size,
            smi_max_len=smi_max_len,
            disable_graph_cache=args.disable_graph_cache,
            context_style=args.context_style,
            disable_graphs=args.disable_graphs,
            use_caption_dataset=args.pretrain_use_caption,
            caption_batch_num=args.caption_batch_num,
            reverse_ratio=args.reverse_ratio,
            enable_abstract=not args.disable_abstract,
            enable_property=not args.disable_property,
            smiles_type=args.smiles_type,
            mode='test',
        )
        self.init_tokenizer(tokenizer)
        self.mol_ph_token = '<mol>' * self.args.num_query_token
        self.is_gal = args.opt_model.find('galactica') >= 0
        self.disable_graphs = args.disable_graphs
        self.last_only = args.pretrain_eval_last_only

    def init_tokenizer(self, tokenizer):
        self.tokenizer = tokenizer
        self.pretrain_dataset.tokenizer = tokenizer
        self.test_dataset.tokenizer = tokenizer
        self.mol_token_id = self.tokenizer.mol_token_id
        # self.tokenizer.mol_token_id = tokenizer("<mol>", add_special_tokens=False).input_ids[0]

    def train_dataloader(self):
        self.pretrain_dataset.reload_data_list()
        loader = DataLoader(
            self.pretrain_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=True,
            persistent_workers=True,
            collate_fn=TrainCollater(
                tokenizer=self.tokenizer,
                text_max_len=self.text_max_len,
                mol_ph=self.mol_ph_token,
                mol_token_id=self.mol_token_id,
                is_gal=self.is_gal,
                disable_graphs=self.disable_graphs,
            ),
        )
        return loader

    def val_dataloader(self):
        test_loader = DataLoader(
            self.test_dataset,
            batch_size=self.inference_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=False,
            persistent_workers=True,
            collate_fn=InferenceCollater(
                tokenizer=self.tokenizer, 
                text_max_len=self.text_max_len,
                mol_ph=self.mol_ph_token, 
                mol_token_id=self.mol_token_id, 
                is_gal=self.is_gal,
                disable_graphs=self.disable_graphs,
                last_only=self.last_only,
            ),
        )
        return [test_loader]

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = parent_parser.add_argument_group("Data module")
        parser.add_argument('--num_workers', type=int, default=2)
        parser.add_argument('--batch_size', type=int, default=4)
        parser.add_argument('--inference_batch_size', type=int, default=4)
        parser.add_argument('--use_smiles', action='store_true', default=False)
        parser.add_argument('--root', type=str, default='data/action_data')
        parser.add_argument('--context_style', type=str, default='weighted_rxn', choices=['weighted_rxn', 'uniform_rxn', 'uniform_mol', 'single_mol', 'hybrid'])
        parser.add_argument('--rxn_max_len', type=int, default=512)
        parser.add_argument('--text_max_len', type=int, default=512)
        parser.add_argument('--smi_max_len', type=int, default=128)
        parser.add_argument('--pretrain_rxn_num', type=int, default=50000)
        parser.add_argument('--reverse_ratio', type=float, default=0.5, help='ratio of reversed reactions (retro reactions)')
        parser.add_argument('--disable_abstract', action='store_true', default=False)
        parser.add_argument('--disable_property', action='store_true', default=False)
        parser.add_argument('--pretrain_use_caption', action='store_true', default=False)
        parser.add_argument('--caption_batch_num', type=int, default=5000)
        parser.add_argument('--pretrain_synthesis_path', type=str, default=None)
        parser.add_argument('--synthesis_batch_num', type=int, default=5000)
        parser.add_argument('--rxn_batch_size', type=int, default=4)
        parser.add_argument('--roundrobin_train', action='store_true', default=False)
        parser.add_argument('--test_subset', type=int, default=-1)
        parser.add_argument('--pretrain_eval_last_only', default=False, action='store_true')
        parser.add_argument('--prompt', type=str, default=None)
        return parent_parser
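
# Hedged usage sketch (assumes a driver script that also adds the model-level flags this module
# reads from args, e.g. --num_query_token, --opt_model, --disable_graphs, and a tokenizer that
# exposes mol_token_id):
#   parser = argparse.ArgumentParser()
#   parser = PretrainDM.add_model_specific_args(parser)
#   args = parser.parse_args()
#   dm = PretrainDM(num_workers=args.num_workers, batch_size=args.batch_size, root=args.root,
#                   text_max_len=args.text_max_len, rxn_max_len=args.rxn_max_len,
#                   smi_max_len=args.smi_max_len, tokenizer=tokenizer, args=args)
#   trainer.fit(model, datamodule=dm)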