import time

import gradio as gr

from gradio_molecule3d import Molecule3D




def predict(input_sequence, input_ligand, input_msa, input_protein):
    """Run (placeholder) inference for one protein-ligand input.

    This is a template: the inputs are currently ignored and fixed placeholder
    outputs are returned. Replace the marked section with real inference.

    Args:
        input_sequence: Protein sequence text from the FASTA textbox.
        input_ligand: Ligand SMILES string from the SMILES textbox.
        input_msa: Value of the MSA ``gr.File`` upload. Presumably ``None``
            when nothing was uploaded, otherwise a file path or file wrapper
            depending on the Gradio version -- confirm before using.
        input_protein: Value of the protein-monomer ``gr.File`` upload; same
            caveat as ``input_msa``.

    Returns:
        A 3-tuple ``(files, metrics, run_time)``:

        * ``files`` -- list of output file paths for the 3D viewer: a PDB with
          the protein and ligand, followed by an SDF docking pose.
        * ``metrics`` -- dict of informational metrics shown to the user
          (not used for evaluation).
        * ``run_time`` -- seconds spent inside this function, as a float.
    """
    # perf_counter is the clock intended for measuring durations: unlike
    # time.time() it is not affected by system clock adjustments, so the
    # measured run time cannot jump or go negative.
    start_time = time.perf_counter()

    # Do inference here.
    # Return an output PDB file containing the protein and the ligand (ligand
    # residue name LIG or UNK), plus the docking pose.
    # Also return any metrics you want to log; metrics are not used for
    # evaluation but might be useful for users.
    metrics = {"mean_plddt": 80, "binding_affinity": -2}

    run_time = time.perf_counter() - start_time
    return ["test_out.pdb", "test_docking_pose.sdf"], metrics, run_time

with gr.Blocks() as app:
    # --- Header -----------------------------------------------------------
    gr.Markdown("# Template for inference")
    gr.Markdown("Title, description, and other information about the model")

    # --- Inputs -----------------------------------------------------------
    with gr.Row():
        input_sequence = gr.Textbox(
            lines=3,
            label="Input Protein sequence (FASTA)",
        )
        input_ligand = gr.Textbox(
            lines=3,
            label="Input ligand SMILES",
        )
    with gr.Row():
        input_msa = gr.File(label="Input Protein MSA (A3M)")
        input_protein = gr.File(label="Input protein monomer")

    # --- Options ----------------------------------------------------------
    # Define any extra options here. Automated inference always runs with the
    # default option values. For example:
    # slider_option = gr.Slider(0, 10, label="Slider Option")
    # checkbox_option = gr.Checkbox(label="Checkbox Option")
    # dropdown_option = gr.Dropdown(["Option 1", "Option 2", "Option 3"], label="Dropdown Option")

    btn = gr.Button("Run Inference")

    # A single clickable example. Its three values populate, in order, the
    # sequence textbox, the ligand textbox and the protein-file input; the MSA
    # input is not part of the example.
    gr.Examples(
        examples=[
            [
                "SVKSEYAEAAAVGQEAVAVFNTMKAAFQNGDKEAVAQYLARLASLYTRHEELLNRILEKARREGNKEAVTLMNEFTATFQTGKSIFNAMVAAFKNGDDDSFESYLQALEKVTAKGETLADQIAKAL:SVKSEYAEAAAVGQEAVAVFNTMKAAFQNGDKEAVAQYLARLASLYTRHEELLNRILEKARREGNKEAVTLMNEFTATFQTGKSIFNAMVAAFKNGDDDSFESYLQALEKVTAKGETLADQIAKAL",
                "COc1ccc(cc1)n2c3c(c(n2)C(=O)N)CCN(C3=O)c4ccc(cc4)N5CCCCC5=O",
                "test_out.pdb",
            ],
        ],
        inputs=[input_sequence, input_ligand, input_protein],
    )

    # --- Outputs ----------------------------------------------------------
    # Viewer representations: model 0 drawn as a white-carbon cartoon, model 1
    # as green-carbon sticks. The model index presumably follows the order of
    # the files returned by predict -- confirm against gradio_molecule3d docs.
    reps = [
        dict(model=0, style="cartoon", color="whiteCarbon"),
        dict(model=1, style="stick", color="greenCarbon"),
    ]
    out = Molecule3D(reps=reps)
    metrics = gr.JSON(label="Metrics")
    run_time = gr.Textbox(label="Runtime")

    # predict's 3-tuple maps positionally onto (viewer, metrics, runtime).
    btn.click(
        fn=predict,
        inputs=[input_sequence, input_ligand, input_msa, input_protein],
        outputs=[out, metrics, run_time],
    )

app.launch()