import gradio as gr
import yaml
from gradio_huggingfacehub_search import HuggingfaceHubSearch

# Markdown header rendered at the top of the app. Built line by line; the
# leading empty element yields the leading newline and the trailing empty
# element yields the trailing newline.
MARKDOWN_DESCRIPTION = "\n".join(
    [
        "",
        "# mergekit config.yaml generator",
        "",
        "GUI to template a YAML configuration file for mergekit, which you can then "
        "copy/paste into [mergekit-gui](https://huggingface.co/spaces/arcee-ai/mergekit-gui) 🔥",
        "",
    ]
)

# Initial YAML shown in the "Merge Parameters" editor: per-filter gradients
# for the `t` parameter, plus a fallback value for everything else.
DEFAULT_PARAMETERS = "\n".join(
    [
        "",
        "t:",
        "- filter: self_attn",
        "  value: [0, 0.5, 0.3, 0.7, 1]",
        "- filter: mlp",
        "  value: [1, 0.5, 0.7, 0.3, 0]",
        "- value: 0.5",
        "",
    ]
)


def _parse_yaml_field(text, field_name):
    """Parse one user-supplied YAML snippet from the form.

    Returns the parsed object, or ``None`` when ``text`` is empty/blank.
    Raises ``gr.Error`` (shown to the user by Gradio) naming ``field_name``
    when the snippet is not valid YAML.
    """
    if not text or not text.strip():
        return None
    try:
        return yaml.safe_load(text)
    except yaml.YAMLError as err:
        raise gr.Error(f"Invalid YAML in {field_name}: {err}") from err


def create_config_yaml(
    model1,
    model1_layers,
    model2,
    model2_layers,
    merge_method,
    base_model,
    parameters,
    dtype,
) -> str:
    """Build a two-model mergekit ``slices`` config and return it as YAML.

    Args:
        model1, model2: model identifiers written verbatim as the two sources.
        model1_layers, model2_layers: YAML text for each ``layer_range``
            (e.g. ``"[0, 32]"``); blank text yields ``layer_range: null``.
        merge_method: value for ``merge_method``.
        base_model: value for ``base_model``; omitted when blank.
        parameters: YAML text for ``parameters``; omitted when blank or when
            it parses to nothing (e.g. only comments).
        dtype: value for ``dtype``; omitted when blank.

    Returns:
        The config as a YAML string, keys in insertion order.

    Raises:
        gr.Error: if a layer range or the parameters text is not valid YAML.
    """
    dict_config = {
        "slices": [
            {
                "sources": [
                    {
                        "model": model1,
                        "layer_range": _parse_yaml_field(
                            model1_layers, "Model 1 Layer Range"
                        ),
                    },
                    {
                        "model": model2,
                        "layer_range": _parse_yaml_field(
                            model2_layers, "Model 2 Layer Range"
                        ),
                    },
                ]
            }
        ],
        "merge_method": merge_method,
    }

    # Only emit base_model when one was given: an empty string is not a
    # usable model reference, and previously produced `base_model: ''`.
    base_model = (base_model or "").strip()
    if base_model:
        dict_config["base_model"] = base_model

    if parameters:
        parsed_parameters = _parse_yaml_field(parameters, "Merge Parameters")
        # Comment-only/blank YAML parses to None; don't emit `parameters: null`.
        if parsed_parameters is not None:
            dict_config["parameters"] = parsed_parameters

    dtype = (dtype or "").strip()
    if dtype:
        dict_config["dtype"] = dtype

    return yaml.dump(dict_config, sort_keys=False)


def _model_picker(index, default_model):
    """Create the Hub search box and layer-range textbox for one source model.

    Call this inside an active Gradio layout context; the two components are
    created (and therefore rendered) in that context, search box first.
    Returns the pair ``(search_component, layer_range_textbox)``.
    """
    picker = HuggingfaceHubSearch(
        label=f"Model {index}",
        placeholder=f"Search for model {index} on Huggingface",
        search_type="model",
        value=default_model,
    )
    layers = gr.Textbox(
        label=f"Model {index} Layer Range",
        placeholder="[start, end]",
        value="[0, 32]",
    )
    return picker, layers


# TODO: apply a Gradio theme to this layout.
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN_DESCRIPTION)

    # Both source models share one row: search box + layer range for each.
    with gr.Row():
        model1_input, model1_layers_input = _model_picker(
            1, "BioMistral/BioMistral-7B"
        )
        model2_input, model2_layers_input = _model_picker(
            2, "CorticalStack/pastiche-crown-clown-7b-dare-dpo"
        )

    merge_method_input = gr.Dropdown(
        label="Merge Method", choices=["slerp", "linear"], value="slerp"
    )
    base_model_input = gr.Textbox(label="Base Model", value="BioMistral/BioMistral-7B")
    parameters_input = gr.Code(
        language="yaml", label="Merge Parameters", value=DEFAULT_PARAMETERS
    )
    dtype_input = gr.Textbox(label="Dtype", value="bfloat16")

    create_button = gr.Button("Create config.yaml", variant="primary")
    output_zone = gr.Code(language="yaml", lines=10)

    # Order must match create_config_yaml's positional parameters.
    config_inputs = [
        model1_input,
        model1_layers_input,
        model2_input,
        model2_layers_input,
        merge_method_input,
        base_model_input,
        parameters_input,
        dtype_input,
    ]
    create_button.click(
        fn=create_config_yaml, inputs=config_inputs, outputs=[output_zone]
    )

    gr.Markdown("A Space by [1littlecoder](https://huggingface.co/1littlecoder)")

demo.launch()