from contextlib import contextmanager

import gradio as gr
from markup import get_text, highlight
from template import get_templates


# NOTE(review): `templates` is not referenced anywhere else in this module;
# confirm nothing depends on it (or on get_templates() side effects) before removing.
templates = get_templates()


def fill_tab(title, explanation):
    """
    Create a heading and a body as two consecutive Markdown components.

    Args:
        title: Markdown text for the first (heading) component.
        explanation: Markdown text for the second (body) component.

    Returns:
        `tuple`: ``(heading, body)`` -- the two `gr.Markdown` components, in
        the order they were created.
    """
    heading = gr.Markdown(title)
    body = gr.Markdown(explanation)
    return heading, body


@contextmanager
def new_section():
    """
    Open a layout section: a `gr.Row` containing a single `gr.Column`.

    Components created inside ``with new_section():`` are placed in that
    column. This is shorthand for::

        with gr.Row():
            with gr.Column():
                ...
    """
    # Multi-item `with` enters Row first, then Column, and exits in reverse,
    # exactly like the nested form.
    with gr.Row(), gr.Column():
        yield


def change(inp, textbox):
    """Based on an `inp`, render and highlight the appropriate code sample.

    Args:
        inp (`str`):
            The option selected in the tab's radio button.
        textbox (`str`):
            The tab name from the hidden textbox: ``"base"`` or
            ``"training_configuration"``.

    Returns:
        `tuple`: The values for the tab's output components, which depend on
        ``textbox``:

        - ``"base"``: ``(highlighted_code, title, explanation, docs)``, where
          ``title`` is ``"## Accelerate Code (Base Integration)"`` for
          ``"Basic"`` and ``"## Accelerate Code (<inp>)"`` otherwise.
        - ``"training_configuration"``:
          ``(highlighted_yaml, highlighted_changes, command, explanation, docs)``.

    Raises:
        ValueError: If ``textbox`` is not a known tab name.
    """
    if textbox == "base":
        code, explanation, docs = get_text(inp, textbox)
        # "Basic" gets a descriptive label; every other feature echoes its name.
        label = "Base Integration" if inp == "Basic" else inp
        return highlight(code), f"## Accelerate Code ({label})", explanation, docs
    if textbox == "training_configuration":
        yaml, changes, command, explanation, docs = get_text(inp, textbox)
        return highlight(yaml), highlight(changes), command, explanation, docs
    raise ValueError(f"Invalid tab name: {textbox}")


# Initial content for each tab, computed once at import time and used as the
# starting values of the tab components before any radio `.change` fires.
default_base, default_training_config = (
    change("Basic", "base"),
    change("Multi GPU", "training_configuration"),
)


def base_features(textbox):
    """
    Build the "Basic Training Integration" tab contents.

    Creates a radio of Accelerate features followed by code, explanation and
    documentation sections, and wires the radio so that selecting a feature
    re-renders all of them via `change`.

    Args:
        textbox (`gr.Textbox`):
            Hidden textbox holding the tab name (``"base"``), forwarded to
            `change` as its second input.
    """
    inp = gr.Radio(
        [
            "Basic",
            "Calculating Metrics",
            "Checkpointing",
            "Experiment Tracking",
            "Gradient Accumulation",
        ],
        label="Select a feature you would like to integrate",
        value="Basic",
    )
    with new_section():
        # Use the title `change` itself produced for the default selection so
        # the header shown on first load matches the one shown after the user
        # re-selects "Basic" (previously it started as a bare "## Accelerate Code").
        feature, out = fill_tab(default_base[1], default_base[0])
    with new_section():
        _, explanation = fill_tab("## Explanation", default_base[2])
    with new_section():
        _, docs = fill_tab("## Documentation Links", default_base[3])
    # `change` returns (code, title, explanation, docs) for the "base" tab.
    inp.change(
        fn=change, inputs=[inp, textbox], outputs=[out, feature, explanation, docs]
    )


def training_config(textbox):
    """
    Build the "Launch Configuration" tab contents.

    Creates a radio of distributed setups followed by one section per value
    that `change` returns for the ``"training_configuration"`` tab (YAML,
    script changes, launch command, explanation, docs), and wires the radio
    so that selecting a setup re-renders them all.

    Args:
        textbox (`gr.Textbox`):
            Hidden textbox holding the tab name
            (``"training_configuration"``), forwarded to `change`.
    """
    inp = gr.Radio(
        [
            "AWS SageMaker",
            "DeepSpeed",
            "Megatron-LM",
            "Multi GPU",
            "Multi Node Multi GPU",
            "PyTorch FSDP",
        ],
        label="Select a distributed YAML configuration you would like to view.",
        value="Multi GPU",
    )
    # Headings in the same order as the values `change` returns for this tab.
    section_headings = (
        "## Example YAML Configuration",
        "## Changes to Training Script",
        "## Command to Run Training",
        "## Explanation",
        "## Documentation Links",
    )
    section_bodies = []
    for heading, initial_value in zip(section_headings, default_training_config):
        with new_section():
            _, body = fill_tab(heading, initial_value)
        section_bodies.append(body)
    inp.change(
        fn=change,
        inputs=[inp, textbox],
        outputs=section_bodies,
    )


# def big_model_inference():
#     inp = gr.Radio(
#         ["Accelerate's Big Model Inference",], # "DeepSpeed ZeRO Stage-3 Offload"
#         label="Select a feature you would like to integrate",
#         value="Basic",
#     )
#     with gr.Row():
#         with gr.Column():
#             feature = gr.Markdown("## Accelerate Code")
#             out = gr.Markdown(default[0])
#     with gr.Row():
#         with gr.Column():
#             gr.Markdown(default[1])
#             explanation = gr.Markdown(default[2])
#     with gr.Row():
#         with gr.Column():
#             gr.Markdown("## Documentation Links")
#             docs = gr.Markdown(default[3])
#     inp.change(fn=change, inputs=[inp, "big_model_inference"], outputs=[out, feature, explanation, docs])


# def notebook_launcher():
#     inp = gr.Radio(
#         ["Colab GPU", "Colab TPU", "Kaggle GPU", "Kaggle Multi GPU", "Kaggle TPU", "Multi GPU VMs"],
#         label="Select a feature you would like to integrate",
#         value="Basic",
#     )
#     with gr.Row():
#         with gr.Column():
#             feature = gr.Markdown("## Accelerate Code")
#             out = gr.Markdown(default[0])
#     with gr.Row():
#         with gr.Column():
#             gr.Markdown(default[1])
#             explanation = gr.Markdown(default[2])
#     with gr.Row():
#         with gr.Column():
#             gr.Markdown("## Documentation Links")
#             docs = gr.Markdown(default[3])
#     inp.change(fn=change, inputs=[inp, "notebook_launcher"], outputs=[out, feature, explanation, docs])


with gr.Blocks() as demo:
    with gr.Tabs():
        # (tab label, tab name passed to `change`, builder for the tab body)
        for tab_label, tab_name, build_tab in (
            ("Basic Training Integration", "base", base_features),
            ("Launch Configuration", "training_configuration", training_config),
        ):
            with gr.TabItem(tab_label):
                # Hidden textbox: its value tells `change` which tab fired.
                textbox = gr.Textbox(label="tab_name", visible=False, value=tab_name)
                build_tab(textbox)

demo.launch()