import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig
from threading import Thread

# HTML/Markdown header rendered at the top of the demo page via gr.Markdown.
description = """
<p style="text-align: center; font-size: 24px; color: #292b47;">
    <strong>🚀 <span style='color: #3264ff;'>DeciCoder-6B: Bigger, Faster, Stronger </span></strong>
</p>
<span style='color: #292b47;'>Welcome to the <a href="https://huggingface.co/Deci/DeciCoder-6B" style="color: #3264ff;">DeciCoder-6B playground</a>! DeciCoder-6B was trained on the Python, Java, Javascript, Rust, C++, C, and C# subset of the Starcoder Training Dataset, and it's released under the Apache 2.0 license. This model is capable of code-completion and instruction following. It surpasses CodeGen 2.5 7B, CodeLlama 7B, and StarCoder 7B in its supported languages on HumanEval, and leads by 3 points in Python over StarCoderBase 15.5B.</span>
"""


# Hugging Face Hub identifier of the model and tokenizer used by this demo.
checkpoint = "Deci/DeciCoder-6B"

# Load the causal LM with 4-bit weights. Quantization is requested through an
# explicit BitsAndBytesConfig: passing `load_in_4bit=True` directly to
# `from_pretrained` is deprecated in favour of `quantization_config`.
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    trust_remote_code=True,
    device_map="auto",
    low_cpu_mem_usage=True,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)

tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)

# Reuse the end-of-sequence token as the padding token and pad on the left.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

# Module-level text-generation pipeline, built once from the already-loaded
# model and tokenizer and reused by every request.
# NOTE(review): `temperature` normally only matters when sampling is enabled,
# and `device_map` may be redundant for a model object that is already
# loaded -- confirm against the installed transformers version.
pipe = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    device_map="auto",
    max_length=2048,
    temperature=1e-3,
)

def code_generation(prompt: str) -> str:
    """Generate a code completion for ``prompt`` with the module-level pipeline.

    Args:
        prompt: The input code to complete. An empty or ``None`` prompt
            (for example, the input box right after "Clear" was pressed)
            is not sent to the model.

    Returns:
        The first candidate's ``generated_text`` as plain text, with every
        ``<|endoftext|>`` marker removed. Returns an empty string when the
        prompt is empty or ``None``.
    """
    # Guard: the Clear button resets the input component to None, and there
    # is nothing to complete for an empty prompt, so skip the model call.
    if not prompt:
        return ""

    completion = pipe(prompt)[0]["generated_text"]
    # Strip the end-of-text marker so it does not show up in the output box.
    return completion.replace("<|endoftext|>", "")


# Build the playground UI: header text, an input/output pair of code editors,
# "generate" and "clear" buttons, and a banner image at the bottom.
with gr.Blocks(
    css=".gradio-container {background-color: #FAFBFF; color: #292b47}"
) as demo:
    gr.Markdown(value=description)
    with gr.Row():
        code = gr.Code(
            lines=10,
            language="python",
            label="👨🏽‍💻 Input",
            # Example prompt shown on load ("Fibonacci" spelled correctly so
            # the model is prompted with the real name).
            value=(
                "def nth_element_in_fibonacci(element):\n"
                '    """Returns the nth element of the Fibonacci sequence."""'
            ),
        )
        output = gr.Code(label="💻 Generated code")
    with gr.Row():
        run = gr.Button(value="👨🏽‍💻 Generate code")
        clear = gr.Button("🗑️ Clear")

    # Clear resets both editors to None; it is wired with queue=False.
    clear.click(lambda: (None, None), None, [code, output], queue=False)
    # Generate sends the input editor's content to code_generation and shows
    # the returned text in the output editor.
    event = run.click(code_generation, [code], output)
    gr.HTML(label="Keep in touch", value="<img src='https://huggingface.co/spaces/Deci/DeciCoder-Demo/resolve/main/deci-coder-banner.png' alt='Keep in touch' style='display: block; color: #292b47; margin: auto; max-width: 800px;'>")

# Start the Gradio server for the app defined above.
demo.launch(debug=True)