import uuid
import gradio as gr
import torch
import os
import pandas as pd
from rdkit import Chem
from scripts.pla_net_inference import main
from utils.args import ArgsInit

os.system("nvidia-smi")
print("TORCH_CUDA", torch.cuda.is_available())

PROJECT_URL = "https://www.nature.com/articles/s41598-022-12180-x"

DEFAULT_PATH_DOCKER = "/home/user/app"

ENABLED_MODELS = [
    'aa2ar', 'abl1', 'ace', 'aces', 'ada', 'ada17', 'adrb1', 'adrb2',
    'akt1', 'akt2', 'aldr', 'ampc', 'andr', 'aofb', 'bace1', 'braf',
    'cah2', 'casp3', 'cdk2', 'comt', 'cp2c9', 'cp3a4', 'csf1r',
    'cxcr4', 'def', 'dhi1', 'dpp4', 'drd3', 'dyr', 'egfr', 'esr1',
    'esr2', 'fa10', 'fa7', 'fabp4', 'fak1', 'fgfr1', 'fkb1a', 'fnta',
    'fpps', 'gcr', 'glcm', 'gria2', 'grik1', 'hdac2', 'hdac8',
    'hivint', 'hivpr', 'hivrt', 'hmdh', 'hs90a', 'hxk4', 'igf1r',
    'inha', 'ital', 'jak2', 'kif11', 'kit', 'kith', 'kpcb', 'lck',
    'lkha4', 'mapk2', 'mcr', 'met', 'mk01', 'mk10', 'mk14', 'mmp13',
    'mp2k1', 'nos1', 'nram', 'pa2ga', 'parp1', 'pde5a', 'pgh1', 'pgh2',
    'plk1', 'pnph', 'ppara', 'ppard', 'pparg', 'prgr', 'ptn1', 'pur2',
    'pygm', 'pyrd', 'reni', 'rock1', 'rxra', 'sahh', 'src', 'tgfr1',
    'thb', 'thrb', 'try1', 'tryb1', 'tysy', 'urok', 'vgfr2', 'wee1',
    'xiap'
]

def load_and_filter_data(protein_id, ligand_smiles):
    
    # generate random short id, make short
    random_id = str(uuid.uuid4())[:8]
    
    print("Inference ID: ", random_id)
    
    # check that ligand_smiles is not empty
    if not ligand_smiles or ligand_smiles.strip() == "":
        error_msg = f"!SMILES string is required 💥"
        raise gr.Error(error_msg, duration=5)
    
    
    if protein_id not in ENABLED_MODELS:
        error_msg = f"!Invalid 💥 target protein ID, the available options are: {ENABLED_MODELS}. To do inference other proteins, you can run the model locally an train the model for each target protein."
        raise gr.Error(error_msg, duration=5)
    
    # Split the input SMILES string by ':' to get a list
    smiles_list = ligand_smiles.split(':')
    
    
    
    print("Smiles to predict: ", smiles_list)
    print("Target Protein ID: ", protein_id)
    
    # Validate SMILES
    invalid_smiles = []
    for smiles in smiles_list:
        mol = Chem.MolFromSmiles(smiles.strip())
        if mol is None:
            invalid_smiles.append(smiles.strip())
            

    
    if invalid_smiles:
        error_msg = f"!Invalid 💥 SMILES string(s) : {', '.join(invalid_smiles)}"
        raise gr.Error(error_msg, duration=5)
    
        # Create tmp folder
    os.makedirs(f"{DEFAULT_PATH_DOCKER}/example/tmp", exist_ok=True)
    
    # Save SMILES to CSV
    df = pd.DataFrame({"smiles": [s.strip() for s in smiles_list if s.strip()]})
    df.to_csv(f"{DEFAULT_PATH_DOCKER}/example/tmp/{random_id}_input_smiles.csv", index=False)
    
    # Run inference
    args = ArgsInit().args
    args.nclasses = 2
    args.batch_size = 10
    args.use_prot = True
    args.freeze_molecule = True
    args.conv_encode_edge = True
    args.learn_t = True
    args.binary = True
    
    args.use_gpu = True
    args.target = protein_id
    args.target_list = f"{DEFAULT_PATH_DOCKER}/data/datasets/AD/Targets_Fasta.csv"
    args.target_checkpoint_path = f"{DEFAULT_PATH_DOCKER}/checkpoints/PLA-Net/BINARY_{protein_id}"
    args.input_file_smiles = f"{DEFAULT_PATH_DOCKER}/example/tmp/{random_id}_input_smiles.csv"
    args.output_file = f"{DEFAULT_PATH_DOCKER}/example/tmp/{random_id}_output_predictions.csv"
    
   
    print("Args: ", args)
    main(args)
    
    # Load the CSV file
    df = pd.read_csv(f'{DEFAULT_PATH_DOCKER}/example/tmp/{random_id}_output_predictions.csv')
    
    print("Prediction Results output: ", df)
    return df

def load_description(fp):
    with open(fp, 'r', encoding='utf-8') as f:
        content = f.read()
    return content

def run_inference(protein_id, ligand_smile):
    result_df = load_and_filter_data(protein_id, ligand_smile)
    return result_df

def create_interface():
    with gr.Blocks(title="PLA-Net Web Inference") as inference:
        gr.HTML(load_description("gradio/title.md"))

        gr.Markdown("### Input")
        with gr.Row():
            with gr.Column():
                gr.Markdown("#### Target Protein")
                protein_id = gr.Dropdown(
                    choices=ENABLED_MODELS,
                    label="Target Protein ID",
                    info="Select the target protein from the dropdown menu.",
                    value="ada"
                )
                gr.Markdown(" Check the available target proteins [here](https://github.com/juliocesar-io/PLA-Net/blob/main/data/targets.md). The corresponding protein sequences are available in [here](https://github.com/juliocesar-io/PLA-Net/blob/main/data/datasets/AD/Targets_Fasta.csv).")
            with gr.Column():
                gr.Markdown("#### Ligand")
                ligand_smile = gr.Textbox(
                    info="Provide SMILES input (separate multiple SMILES with ':' )",
                    placeholder="SMILES input",
                    label="SMILES string(s)",
                )
                gr.Examples(
                    examples=[
                        "Cn4c(CCC(=O)Nc3ccc2ccn(CC[C@H](CO)n1cnc(C(N)=O)c1)c2c3)nc5ccccc45",
                        "OCCCCCn1cnc2C(O)CN=CNc12",
                        "Nc4nc(c1ccco1)c3ncn(C(=O)NCCc2ccccc2)c3n4"
                    ],
                    inputs=ligand_smile,
                    label="Example SMILES"
                )
        btn = gr.Button("Run")
        gr.Markdown("### Output")   
        out = gr.Dataframe(
            headers=["target", "smiles", "interaction_probability", "interaction_class"],
            datatype=["str", "str", "number", "number"],
            label="Prediction Results"
        )
        
        btn.click(fn=run_inference, inputs=[protein_id, ligand_smile], outputs=out)
        
        gr.Markdown("""
                    PLA-Net model for predicting interactions
                    between small organic molecules and one of the 102 target proteins in the AD dataset. Graph representations
                    of the molecule and a given target protein are generated from SMILES and FASTA sequences and are used as
                    input to the Ligand Module (LM) and Protein Module (PM), respectively. Each module comprises a deep GCN
                    followed by an average pooling layer, which extracts relevant features of their corresponding input graph. Both
                    representations are finally concatenated and combined through a fully connected layer to predict the target–
                    ligand interaction probability.
                    """)
        
        gr.Markdown("""
        Ruiz Puentes, P., Rueda-Gensini, L., Valderrama, N. et al. 
        Predicting target–ligand interactions with graph convolutional networks 
        for interpretable pharmaceutical discovery. Sci Rep 12, 8434 (2022). 
        [https://doi.org/10.1038/s41598-022-12180-x](https://doi.org/10.1038/s41598-022-12180-x)
        """)
    
    return inference

if __name__ == "__main__":
    interface = create_interface()
    interface.launch()