import sys
import os
import logging as log
from typing import Generator

import gradio as gr
from gradio.themes.utils import sizes
from text_generation import Client
from src.request import StarCoderRequest, StarCoderRequestConfig

from src.utils import (
    get_file_as_string,
    get_sections,
    get_url_from_env_or_default_path,
    preview
)
from constants import (
    FIM_MIDDLE,
    FIM_PREFIX,
    FIM_SUFFIX,
    END_OF_TEXT,
    MIN_TEMPERATURE,
)
from settings import (
    DEFAULT_PORT,
    DEFAULT_STARCODER_API_PATH,
    DEFAULT_STARCODER_BASE_API_PATH,
)

# The app cannot talk to the inference endpoints without a Hugging Face token.
# When it is missing (unset or empty), explain how to obtain one on stderr and
# exit with status 1 rather than letting an exception traceback surface.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    ERR_MSG = """
        Please set the HF_TOKEN environment variable with your Hugging Face API token.
        You can get one by signing up at https://huggingface.co/join and then visiting
        https://huggingface.co/settings/tokens."""
    sys.stderr.write(ERR_MSG + "\n")
    sys.exit(1)

# Resolve both inference endpoints via `get_url_from_env_or_default_path`
# (presumably: the environment variable if set, else the default path from
# `settings`), in the order StarCoder, then StarCoderBase.
API_URL_STAR, API_URL_BASE = (
    get_url_from_env_or_default_path(env_var, default_path)
    for env_var, default_path in (
        ("STARCODER_API", DEFAULT_STARCODER_API_PATH),
        ("STARCODER_BASE_API", DEFAULT_STARCODER_BASE_API_PATH),
    )
)

# Echo the effective configuration at startup. The token goes through
# `preview`'s own (sic) `ofuscate` flag, presumably so it is masked.
preview("StarCoder Model URL", API_URL_STAR)
preview("StarCoderBase Model URL", API_URL_BASE)
preview("HF Token", HF_TOKEN, ofuscate=True)

# Front-end assets, read as strings in this order:
#   - the CSS passed to gr.Blocks;
#   - the JS run by the share button;
#   - the two SVG icons shown next to the share button.
_styles, _script, _sharing_icon_svg, _loading_icon_svg = (
    get_file_as_string(asset_name)
    for asset_name in (
        "styles.css",
        "community-btn.js",
        "community-icon.svg",
        "loading-icon.svg",
    )
)

# README.md doubles as the source of the UI text: its content is split on
# "---" by `get_sections` and unpacked into exactly four sections (the
# unpacking raises if a different number comes back).
readme_file_content = get_file_as_string("README.md", path='./')
manifest, description, disclaimer, formats = get_sections(
    readme_file_content, "---", up_to=4
)

# Monochrome base theme:
#   - IBM Plex Sans (weights 400/600), with generic sans-serif fallbacks;
#   - indigo/blue/slate hues;
#   - small corner radii and large text.
theme = gr.themes.Monochrome(
    font=[
        gr.themes.GoogleFont("IBM Plex Sans", [400, 600]),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    radius_size=sizes.radius_sm,
    text_size=sizes.text_lg,
)

# Both inference clients share one header dict carrying the bearer token.
HEADERS = dict(Authorization=f"Bearer {HF_TOKEN}")
client_star, client_base = (
    Client(endpoint, headers=HEADERS)
    for endpoint in (API_URL_STAR, API_URL_BASE)
)

def get_tokens_collector(request: StarCoderRequest) -> Generator[str, None, None]:
    """Stream the text of each generated token for *request*.

    The "StarCoder" version is served by ``client_star``; every other version
    is sent to ``client_base``. Generation settings come from
    ``request.settings.kwargs()``.

    Tokens with id 0 are dropped. NOTE(review): id 0 is assumed to be the
    end-of-text token (cf. the ``END_OF_TEXT`` constant); confirm against the
    model tokenizer.
    """
    if request.settings.version == "StarCoder":
        model_client = client_star
    else:
        model_client = client_base

    prompt_text = request.prompt
    generation_kwargs = request.settings.kwargs()
    for chunk in model_client.generate_stream(prompt_text, **generation_kwargs):
        token = chunk.token
        if token.id == 0:
            continue
        yield token.text

def get_tokens_accumulator(request: StarCoderRequest) -> Generator[str, None, None]:
    """Yield cumulative snapshots of the output text for *request*.

    The text starts from the FIM prefix when ``request.fim_mode`` is set,
    otherwise from the full prompt. A snapshot is yielded after every streamed
    token. In FIM mode the suffix is then appended and one more snapshot is
    yielded. The final value is always the text plus a trailing newline.
    """
    if request.fim_mode:
        text = request.prefix
    else:
        text = request.prompt

    for piece in get_tokens_collector(request=request):
        text += piece
        yield text

    if request.fim_mode:
        text += request.suffix
        yield text

    yield text + '\n'

def get_tokens_linker(request: StarCoderRequest) -> str:
    """Return all generated tokens for *request* joined into one string.

    Only the generated text is returned; the prompt is not included.
    """
    pieces = get_tokens_collector(request)
    return "".join(pieces)

def _build_request(
        prompt: str,
        temperature,
        max_new_tokens,
        top_p,
        repetition_penalty,
        version,
    ) -> "StarCoderRequest":
    """Bundle a prompt and its sampling settings into a ``StarCoderRequest``.

    Shared by ``generate`` and ``process_example`` so both entry points build
    their requests identically. Values are forwarded unchanged.

    NOTE(review): the UI sliders allow temperature=0.0 and top_p=1.0. Also,
    ``MIN_TEMPERATURE`` is imported but not used in this module. Confirm that
    ``StarCoderRequestConfig`` validates or clamps these values before they
    reach the inference client.
    """
    return StarCoderRequest(
        prompt=prompt,
        settings=StarCoderRequestConfig(
            version=version,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        ),
    )

def generate(
        prompt: str,
        temperature = 0.9,
        max_new_tokens = 256,
        top_p = 0.95,
        repetition_penalty = 1.0,
        version = "StarCoder",
    ) -> Generator[str, None, None]:
    """Stream the completion for *prompt*; handler of the "Generate" button.

    Args:
        prompt: Code entered by the user.
        temperature, max_new_tokens, top_p, repetition_penalty: Sampling
            settings forwarded to ``StarCoderRequestConfig``.
        version: "StarCoder" routes to ``client_star``; any other value
            routes to ``client_base`` (see ``get_tokens_collector``).

    Yields:
        Cumulative snapshots of the output text from
        ``get_tokens_accumulator``: the prompt (or FIM prefix) followed by the
        tokens generated so far.
    """
    request = _build_request(
        prompt=prompt,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        version=version,
    )
    yield from get_tokens_accumulator(request)

def process_example(
        prompt: str,
        temperature = 0.9,
        max_new_tokens = 256,
        top_p = 0.95,
        repetition_penalty = 1.0,
        version = "StarCoder",
    ) -> Generator[str, None, None]:
    """Generate the completion for an example prompt in one shot.

    Takes the same arguments as ``generate``.

    Yields:
        Exactly one value: the concatenated generated tokens from
        ``get_tokens_linker``. Unlike ``generate``, the prompt (or FIM prefix)
        is not included.
    """
    request = _build_request(
        prompt=prompt,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        version=version,
    )
    # `get_tokens_linker` returns a str. `yield from` on a str would emit it
    # one character at a time, with each UI update overwriting the last.
    # Yield the whole completion as a single update instead.
    yield get_tokens_linker(request)

# Example prompts offered under the input box. The third one demonstrates
# fill-in-the-middle via the <FILL_HERE> marker.
# TODO: move these examples into the README, like the rest of the UI text.
examples = [
    (
        "X_train, y_train, X_test, y_test = train_test_split(X, y, test_size=0.1)\n"
        "\n"
        "# Train a logistic regression model, predict the labels on the test set "
        "and compute the accuracy score"
    ),
    (
        "// Returns every other value in the array as a new array.\n"
        "function everyOther(arr) {"
    ),
    (
        "def alternating(list1, list2):\n"
        "   results = []\n"
        "   for i in range(min(len(list1), len(list2))):\n"
        "       results.append(list1[i])\n"
        "       results.append(list2[i])\n"
        "   if len(list1) > len(list2):\n"
        "       <FILL_HERE>\n"
        "   else:\n"
        "       results.extend(list2[i+1:])\n"
        "   return results"
    ),
]

# Build the Gradio UI. Components are created in sequence inside nested layout
# contexts, so the statement order below determines the on-page layout.
with gr.Blocks(theme=theme, analytics_enabled=False, css=_styles) as demo:
    with gr.Column():
        # Intro text: the `description` section of README.md.
        gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                # Prompt input. It is the first positional input of `generate`.
                instruction = gr.Textbox(
                    placeholder="Enter your code here",
                    label="Code",
                    elem_id="q-input",
                )
                submit = gr.Button("Generate", variant="primary")
                # Target of both `generate` (submit.click) and the examples' fn.
                output = gr.Code(elem_id="q-output", lines=30)
                with gr.Row():
                    with gr.Column():
                        # Sampling controls, collapsed by default.
                        # NOTE(review): the sliders permit temperature=0.0 and
                        # top_p=1.0. Confirm these are clamped downstream; the
                        # TGI client may reject temperature <= 0 or top_p >= 1.
                        with gr.Accordion("Advanced settings", open=False):
                            with gr.Row():
                                # Both columns are created first, then filled
                                # via their `with` blocks below.
                                column_1, column_2 = gr.Column(), gr.Column()
                                with column_1:
                                    temperature = gr.Slider(
                                        label="Temperature",
                                        value=0.2,
                                        minimum=0.0,
                                        maximum=1.0,
                                        step=0.05,
                                        interactive=True,
                                        info="Higher values produce more diverse outputs",
                                    )
                                    max_new_tokens = gr.Slider(
                                        label="Max new tokens",
                                        value=256,
                                        minimum=0,
                                        maximum=8192,
                                        step=64,
                                        interactive=True,
                                        info="The maximum numbers of new tokens",
                                    )
                                with column_2:
                                    top_p = gr.Slider(
                                        label="Top-p (nucleus sampling)",
                                        value=0.90,
                                        minimum=0.0,
                                        maximum=1,
                                        step=0.05,
                                        interactive=True,
                                        info="Higher values sample more low-probability tokens",
                                    )
                                    repetition_penalty = gr.Slider(
                                        label="Repetition penalty",
                                        value=1.2,
                                        minimum=1.0,
                                        maximum=2.0,
                                        step=0.05,
                                        interactive=True,
                                        info="Penalize repeated tokens",
                                    )
                    with gr.Column():
                        # Model selector. `get_tokens_collector` sends "StarCoder"
                        # to client_star and anything else to client_base.
                        version = gr.Dropdown(
                                    ["StarCoderBase", "StarCoder"],
                                    value="StarCoder",
                                    label="Version",
                                    info="",
                                    )
                gr.Markdown(disclaimer)
                # Share-to-community group. Its button runs the JS loaded from
                # community-btn.js (wired below).
                with gr.Group(elem_id="share-btn-container"):
                    community_icon = gr.HTML(_sharing_icon_svg, visible=True)
                    loading_icon = gr.HTML(_loading_icon_svg, visible=True)
                    share_button = gr.Button(
                        "Share to community", elem_id="share-btn", visible=True
                    )
                # Clicking an example fills the Code textbox.
                # NOTE(review): with cache_examples=False, Gradio presumably
                # does not call `fn` on click, so process_example may never run.
                # Confirm for the pinned Gradio version.
                gr.Examples(
                    examples=examples,
                    inputs=[instruction],
                    cache_examples=False,
                    fn=process_example,
                    outputs=[output],
                )
                gr.Markdown(formats)

    # Inputs are passed positionally, so their order must match generate's
    # parameters: (prompt, temperature, max_new_tokens, top_p,
    # repetition_penalty, version).
    # NOTE(review): max_batch_size presumably has an effect only with
    # batch=True, which is not set here.
    submit.click(
        generate,
        inputs=[instruction, temperature, max_new_tokens, top_p, repetition_penalty, version],
        outputs=[output],
        # preprocess=False,
        max_batch_size=8,
        show_progress=True
    )
    # No Python callback (fn=None); only the JS in `_script` is run.
    share_button.click(None, [], [], _js=_script)

# Enable the request queue (concurrency_count=16) and start the server on
# DEFAULT_PORT. NOTE(review): this runs at import time; there is no
# `if __name__ == "__main__":` guard.
demo.queue(concurrency_count=16).launch(debug=True, server_port=DEFAULT_PORT)