import time
import json
import requests

import gradio as gr

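# Custom CSS applied to the Gradio Blocks UI defined below.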
STYLE = """
.no-border {
    border: none !important;
}

.group-border {
  padding: 10px;
  border-width: 1px;
  border-radius: 10px;
  border-color: gray;
  border-style: solid;
  box-shadow: 1px 1px 3px;
}
.control-label-font {
  font-size: 13pt !important;
}
.control-button {
  background: none !important;
  border-color: #69ade2 !important;
  border-width: 2px !important;
  color: #69ade2 !important;
}
.center {
  text-align: center;
}
.right {
  text-align: right;
}
.no-label {
  padding: 0px !important;
}
.no-label > label > span {
  display: none;
}
.small-big {
  font-size: 12pt !important;
}

"""

def available_providers():
    """Fetch the available cloud providers, regions, and compute options
    from the Inference Endpoints API."""
    headers = {
        "Content-Type": "application/json",
    }
    endpoint_url = "https://api.endpoints.huggingface.cloud/v2/provider"
    response = requests.get(endpoint_url, headers=headers)

    providers = {}

    for provider in response.json()['vendors']:
        if provider['status'] == 'available':
            regions = {}

            availability = False
            for region in provider['regions']:
                if region["status"] == "available":
                    regions[region['name']] = {
                        "label": region['label'],
                        "computes": region['computes']
                    }
                    availability = True

            # Keep only providers with at least one available region.
            if availability:
                providers[provider['name']] = regions

    return providers

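# Cached catalog of available providers; roughly
# {provider_name: {region_name: {"label": ..., "computes": [...]}}}
# (illustrative shape, based on the fields read above).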
providers = available_providers()

def update_regions(provider):
    available_regions = []
    regions = providers[provider]

    # Regions are displayed as "name[label]".
    for region, attributes in regions.items():
        available_regions.append(f"{region}[{attributes['label']}]")

    return gr.Dropdown.update(
        choices=available_regions,
        value=available_regions[0] if len(available_regions) > 0 else None
    )

def update_compute_options(provider, region):
    available_compute_options = []
    computes = providers[provider][region.split("[")[0].strip()]["computes"]

    for compute in computes:
        if compute['status'] == 'available':
            accelerator = compute['accelerator']
            numAccelerators = compute['numAccelerators']
            memoryGb = compute['memoryGb']
            architecture = compute['architecture']
            instanceType = compute['instanceType']
            pricePerHour = compute['pricePerHour']

            spec = f"{numAccelerators}vCPU {memoryGb} · {architecture}" if accelerator == "cpu" else f"{numAccelerators}x {architecture}"

            available_compute_options.append(
                f"{accelerator.upper()} [{compute['instanceSize']}] · {spec} · {instanceType} · ${pricePerHour}/hour"
            )

    return gr.Dropdown.update(
        choices=available_compute_options,
        value=available_compute_options[0] if len(available_compute_options) > 0 else None
    )

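# Build the endpoint-creation payload from the form values, POST it to the
# Inference Endpoints API, and return a human-readable status message.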
def submit(
    hf_account_input,
    hf_token_input,
    endpoint_name_input,
    provider_selector,
    region_selector,
    repository_selector,
    task_selector,
    framework_selector,
    compute_selector,
    min_node_selector,
    max_node_selector,
    security_selector,
    custom_kernel,
    max_input_length,
    max_tokens,
    max_batch_prefill_token,
    max_batch_total_token    
):
    # The compute selection has the form
    # "<ACCELERATOR> [<instanceSize>] · <spec> · <instanceType> · $<pricePerHour>/hour".
    compute_resources = compute_selector.split("·")
    accelerator = compute_resources[0][:3].strip()

    size_l_index = compute_resources[0].index("[") + 1
    size_r_index = compute_resources[0].index("]")
    size = compute_resources[0][size_l_index : size_r_index].strip()

    instance_type = compute_resources[-2].strip()

    payload = {
      "accountId": hf_account_input.strip(),
      "compute": {
        "accelerator": accelerator.lower(),
        "instanceSize": size,
        "instanceType": instance_type,
        "scaling": {
          "maxReplica": int(max_node_selector),
          "minReplica": int(min_node_selector)
        }
      },
      "model": {
        "framework": framework_selector.lower(),
        "image": {
          "custom": {
            "health_route": "/health",
            "env": {
                # Disable custom kernels only when the user chose "Disabled".
                "DISABLE_CUSTOM_KERNELS": "false" if custom_kernel == "Enabled" else "true",
                "MAX_BATCH_PREFILL_TOKENS": str(max_batch_prefill_token),
                "MAX_BATCH_TOTAL_TOKENS": str(max_batch_total_token),
                "MAX_INPUT_LENGTH": str(max_input_length),
                "MAX_TOTAL_TOKENS": str(max_tokens),
                "MODEL_ID": repository_selector.lower(),
                # QUANTIZE: 'bitsandbytes' | 'gptq';
            },
            "url": "ghcr.io/huggingface/text-generation-inference:1.0.1",
          }
        },
        "repository": repository_selector.lower(),
        # "revision": "main",
        "task": task_selector.lower()
      },
      "name": endpoint_name_input.strip().lower(),
      "provider": {
        "region": region_selector.split("[")[0].lower(),
        "vendor": provider_selector.lower()
      },
      "type": security_selector.lower()
    }

    print(payload)

    payload = json.dumps(payload)
    print(payload)

    headers = {
        "Authorization": f"Bearer {hf_token_input.strip()}",
        "Content-Type": "application/json",
    }
    endpoint_url = "https://api.endpoints.huggingface.cloud/v2/endpoint/"  # optionally suffix with hf_account_input.strip()
    print(endpoint_url)

    response = requests.post(endpoint_url, headers=headers, data=payload)

    if response.status_code == 400:
        return f"{response.text}. Malformed data in {payload}"
    elif response.status_code == 401:
        return "Invalid token"
    elif response.status_code == 409:
        return f"Endpoint {endpoint_name_input} already exists"
    elif response.status_code == 202:
        return f"Endpoint {endpoint_name_input} created successfully on {provider_selector.lower()} using {repository_selector.lower()}@main.\nPlease check out the progress at https://ui.endpoints.huggingface.co/endpoints."
    else:
        return f"Something went wrong ({response.status_code}): {response.text}"

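# Gradio UI: collect the endpoint configuration and submit it to the API.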
with gr.Blocks(css=STYLE) as hf_endpoint:
    with gr.Tab("Hugging Face", elem_classes=["no-border"]):
        gr.Markdown("# Deploy LLM on 🤗 Hugging Face Inference Endpoint", elem_classes=["center"])

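        # Hugging Face credentials and target model.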
        with gr.Column(elem_classes=["group-border"]):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("""### Hugging Face account ID (name)""")
                    hf_account_input = gr.Textbox(show_label=False, elem_classes=["no-label", "small-big"])

                with gr.Column():
                    gr.Markdown("### Hugging Face access token")
                    hf_token_input = gr.Textbox(show_label=False, type="password", elem_classes=["no-label", "small-big"])

            with gr.Row():
                with gr.Column():
                    gr.Markdown("""### Target model

Model from the Hugging Face Hub""")
                    repository_selector = gr.Textbox(
                        value="NousResearch/Nous-Hermes-Llama2-13b",
                        interactive=False,
                        show_label=False,
                        elem_classes=["no-label", "small-big"]
                    )

                with gr.Column():
                    gr.Markdown("""### Target model version (branch)

Branch name of the model""")
                    revision_selector = gr.Textbox(
                        value="main",
                        interactive=False,
                        show_label=False,
                        elem_classes=["no-label", "small-big"]
                    )

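        # Endpoint configuration: name, provider, region, compute, scaling, and security level.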
        with gr.Column(elem_classes=["group-border"]):
            with gr.Column():
                gr.Markdown("""### Endpoint name

Name for your new endpoint""")
                endpoint_name_input = gr.Textbox(show_label=False, elem_classes=["no-label", "small-big"])

            with gr.Row():
                with gr.Column():
                    gr.Markdown("""### Cloud Provider""")
                    provider_selector = gr.Dropdown(
                        choices=list(providers.keys()),
                        interactive=True,
                        show_label=False,
                        elem_classes=["no-label", "small-big"]
                    )

                with gr.Column():
                    gr.Markdown("""### Cloud Region""")
                    region_selector = gr.Dropdown(
                        [],
                        value="",
                        interactive=True,
                        show_label=False,
                        elem_classes=["no-label", "small-big"]
                    )

            with gr.Row(visible=False):
                with gr.Column():
                    gr.Markdown("### Task")
                    task_selector = gr.Textbox(
                        value="text-generation",
                        interactive=False,
                        show_label=False,
                        elem_classes=["no-label", "small-big"]
                    )

                with gr.Column():
                    gr.Markdown("### Framework")
                    framework_selector = gr.Textbox(
                        value="PyTorch",
                        interactive=False,
                        show_label=False,
                        elem_classes=["no-label", "small-big"]
                    )

            with gr.Column():
                gr.Markdown("""### Compute Instance Type""")
                compute_selector = gr.Dropdown(
                    [],
                    value="",
                    interactive=True,
                    show_label=False,
                    elem_classes=["no-label", "small-big"]
                )

            with gr.Row():
                with gr.Column():
                    gr.Markdown("""### Min Number of Nodes""")
                    min_node_selector = gr.Number(
                        value=1,
                        interactive=True,
                        show_label=False,
                        elem_classes=["no-label", "small-big"]
                    )

                with gr.Column():
                    gr.Markdown("""### Max Number of Nodes""")
                    max_node_selector = gr.Number(
                        value=1,
                        interactive=True,
                        show_label=False,
                        elem_classes=["no-label", "small-big"]
                    )

                with gr.Column():
                    gr.Markdown("""### Security Level""")
                    security_selector = gr.Radio(
                        choices=["Protected", "Public", "Private"],
                        value="Public",
                        interactive=True,
                        show_label=False,
                        elem_classes=["no-label", "small-big"]
                    )

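        # Text Generation Inference container settings.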
        with gr.Column(elem_classes=["group-border"]):
            with gr.Accordion("Serving Container", open=False, elem_classes=["no-border"]):
                with gr.Column():
                    gr.Markdown("""### Container Type

Text Generation Inference is an optimized container for the text-generation task""")
                    _ = gr.Textbox("Text Generation Inference", show_label=False, elem_classes=["no-label", "small-big"])
    
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("""### Custom CUDA Kernels

TGI uses custom kernels to speed up inference for some models. You can try disabling them if you encounter issues.""")
                        custom_kernel = gr.Dropdown(
                            value="Enabled",
                            choices=["Enabled", "Disabled"],
                            interactive=True,
                            show_label=False,
                            elem_classes=["no-label", "small-big"]
                        )
    
                    with gr.Column():
                        gr.Markdown("""### Quantization

Quantization can reduce the model size and improve latency, with little degradation in model accuracy.""")
                        _ = gr.Dropdown(
                            value="None",
                            choices=["None", "Bitsandbytes", "GPTQ"],
                            interactive=True,
                            show_label=False,
                            elem_classes=["no-label", "small-big"]
                        )
    
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("""### Max Input Length (per Query)

Increasing this value can impact the amount of RAM required. Some models can only handle sequences up to a fixed maximum length.""")
                        max_input_length = gr.Number(
                            value=1024,
                            interactive=True,
                            show_label=False,
                            elem_classes=["no-label", "small-big"]
                        )
    
                    with gr.Column():
                        gr.Markdown("""### Max Number of Tokens (per Query)

The larger this value, the more memory each request will consume and the less effective batching can be.""")
                        max_tokens = gr.Number(
                            value=1512,
                            interactive=True,
                            show_label=False,
                            elem_classes=["no-label", "small-big"]
                        )
    
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("""### Max Batch Prefill Tokens

Number of prefill tokens used during continuous batching. It can be useful to adjust this number since the prefill operation is memory-intensive and compute-bound.""")
                        max_batch_prefill_token = gr.Number(
                            value=2048,
                            interactive=True,
                            show_label=False,
                            elem_classes=["no-label", "small-big"]
                        )
    
                    with gr.Column():
                        gr.Markdown("""### Max Batch Total Tokens

Maximum number of tokens a batch can hold across all queries; waiting queries are only added to the batch while it stays under this limit. A value of 1000 can fit 10 queries of 100 tokens or a single query of 1000 tokens.""")
                        max_batch_total_token = gr.Number(
                            value=None,
                            interactive=True,
                            show_label=False,
                            elem_classes=["no-label", "small-big"]
                        )

        submit_button = gr.Button(
            value="Submit",
            elem_classes=["control-label-font", "control-button"]
        )

        status_txt = gr.Textbox(
            value="Any status update will be displayed here.",
            interactive=False,
            elem_classes=["no-label"]
        )

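        # Refresh the region list when the provider changes, and the compute
        # options when the region changes.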
        provider_selector.change(update_regions, inputs=provider_selector, outputs=region_selector)
        region_selector.change(update_compute_options, inputs=[provider_selector, region_selector], outputs=compute_selector)

        submit_button.click(
            submit,
            inputs=[
                hf_account_input,
                hf_token_input,
                endpoint_name_input,
                provider_selector,
                region_selector,
                repository_selector,
                task_selector,
                framework_selector,
                compute_selector,
                min_node_selector,
                max_node_selector,
                security_selector,
                custom_kernel,
                max_input_length,
                max_tokens,
                max_batch_prefill_token,
                max_batch_total_token],
            outputs=status_txt)

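    # The remaining provider tabs currently only display a heading.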
    with gr.Tab("AWS", elem_classes=["no-border"]):
        gr.Markdown("# Deploy LLM on 🤗 Hugging Face Inference Endpoint", elem_classes=["center"])

    with gr.Tab("GCP", elem_classes=["no-border"]):
        gr.Markdown("# Deploy LLM on 🤗 Hugging Face Inference Endpoint", elem_classes=["center"])

    with gr.Tab("Azure", elem_classes=["no-border"]):
        gr.Markdown("# Deploy LLM on 🤗 Hugging Face Inference Endpoint", elem_classes=["center"])

    with gr.Tab("Lambda Labs", elem_classes=["no-border"]):
        gr.Markdown("# Deploy LLM on 🤗 Hugging Face Inference Endpoint", elem_classes=["center"])

hf_endpoint.launch(enable_queue=True, debug=True)