import time
import traceback

import gradio as gr

from gradio_molecule3d import Molecule3D
from run_on_seq import run_on_sample_seqs
from env_consts import RUN_CONFIG_PATH, OUTPUT_PROT_PATH, OUTPUT_LIG_PATH, MODEL_NAME_TO_CKPT


def predict(input_sequence, input_ligand, input_protein, model_variation):
    """Run DockFormer inference for one protein/ligand pair from the UI.

    Args:
        input_sequence: Protein sequence text from the FASTA textbox.
        input_ligand: Ligand SMILES text.
        input_protein: Value of the protein ``gr.File`` input, passed through
            unchanged to ``run_on_sample_seqs``.
        model_variation: Dropdown label, used as a key into ``MODEL_NAME_TO_CKPT``.

    Returns:
        On success, ``([OUTPUT_PROT_PATH, OUTPUT_LIG_PATH], metrics, run_time)``
        where ``metrics`` is whatever ``run_on_sample_seqs`` returns and
        ``run_time`` is the elapsed time in seconds (float).
        On any exception, ``(None, {"error": <message>}, "Error occurred")`` so
        the UI displays the failure instead of the app crashing.
    """
    print("Starting inference", input_sequence, input_ligand, input_protein)
    # perf_counter is monotonic, so the measured duration is not affected by
    # system clock adjustments the way time.time() is.
    start_time = time.perf_counter()
    try:
        # An unknown label raises KeyError, which is reported by the handler below.
        ckpt_path = MODEL_NAME_TO_CKPT[model_variation]
        # Note: run_on_sample_seqs takes the protein input *before* the ligand.
        metrics = run_on_sample_seqs(input_sequence, input_protein, input_ligand, OUTPUT_PROT_PATH, OUTPUT_LIG_PATH,
                                     RUN_CONFIG_PATH, ckpt_path)
    except Exception as e:
        # UI boundary: surface the error in the outputs rather than crash the app.
        print(f"Error during inference: {e}")
        traceback.print_exc()  # full traceback goes to the server log
        return None, {"error": str(e)}, "Error occurred"
    run_time = time.perf_counter() - start_time
    # Protein first, ligand second: the 3D viewer's reps index models in this order.
    return [OUTPUT_PROT_PATH, OUTPUT_LIG_PATH], metrics, run_time


with gr.Blocks() as app:
    print("Starting app!!!!")
    gr.Markdown("DockFormer")

    # Checkpoint selector; predict() looks the chosen label up in MODEL_NAME_TO_CKPT.
    model_variation = gr.Dropdown(
        label="Select model variation",
        choices=["DockFormer-Screen", "DockFormer-PDBBind", "DockFormer-PLINDER"],
        value="DockFormer-Screen",
    )

    # Text inputs shown side by side.
    with gr.Row():
        input_sequence = gr.Textbox(label="Input Protein sequence (FASTA)", lines=3)
        input_ligand = gr.Textbox(label="Input ligand SMILES", lines=3)
    # Structure-file upload on its own row.
    with gr.Row():
        input_protein = gr.File(label="Input protein monomer")

    # Extra user-tunable options would be declared here; automated inference
    # always runs with the defaults.

    run_button = gr.Button("Run Inference")

    # One clickable example that fills (sequence, SMILES, protein file).
    example_row = [
        "MLLLPLPLLLFLLCSRAEAGEIIGGTESKPHSRPYMAYLEIVTSNGPSKFCGGFLIRRNFVLTAAHCAGRSITVTLGAHNITEEEDTWQKLEVIKQFRHPKYNTSTLHHDIMLLKLKEKASLTLAVGTLPFPSQKNFVPPGRMCRVAGWGRTGVLKPGSDTLQEVKLRLMDPQACSHFRDFDHNLQLCVGNPRKTKSAFKGDSGGPLLCAGVAQGIVSYGRSDAKPPAVFTRISHYRPWINQILQAN",
        "[nH]1c5c(c(c1C(c2ccccc2)C3=C(C(CC3=O)c4ccccc4)O)CCNC(=O)C)ccc(c5)CC",
        "resources/example/L1001.pdb",
    ]
    gr.Examples(
        examples=[example_row],
        inputs=[input_sequence, input_ligand, input_protein],
    )

    # Viewer styles: model 0 is the first returned file (protein) as a cartoon,
    # model 1 the second (ligand) as sticks.
    view_reps = [
        dict(model=0, style="cartoon", color="whiteCarbon"),
        dict(model=1, style="stick", color="greenCarbon"),
    ]

    structure_view = Molecule3D(reps=view_reps)
    metrics = gr.JSON(label="Metrics")
    run_time = gr.Textbox(label="Runtime")

    # Outputs line up with predict()'s (files, metrics, run_time) return tuple.
    run_button.click(
        predict,
        inputs=[input_sequence, input_ligand, input_protein, model_variation],
        outputs=[structure_view, metrics, run_time],
    )

app.launch()