import spaces
import gradio as gr
from transformers import PreTrainedTokenizerFast, AutoModelForCausalLM, TextIteratorStreamer
import torch
from threading import Thread
import html

# Initialize model and tokenizer
MODEL_ID = "erikbeltran/pydiff"
GGUF_FILE = "unsloth.Q4_K_M.gguf"

try:
    # Use PreTrainedTokenizerFast instead of LlamaTokenizer
    tokenizer = PreTrainedTokenizerFast.from_pretrained(MODEL_ID)
    
    # Ensure the tokenizer has the necessary special tokens
    special_tokens = {
        'eos_token': '<|eot_id|>'
    }
    tokenizer.add_special_tokens(special_tokens)
    
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, gguf_file=GGUF_FILE)
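    # Note: '<|eot_id|>' already exists in Llama-3-style vocabularies, so the
    # add_special_tokens() call above normally just remaps the EOS role. If it
    # ever introduced a genuinely new token, the embeddings would need to grow:
    #   model.resize_token_embeddings(len(tokenizer))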

    # Move model to GPU if available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    
except Exception as e:
    print(f"Error initializing model or tokenizer: {str(e)}")
    raise

def format_diff_response(response):
    """Format the response to look like a diff output."""
    lines = response.split('\n')
    formatted = []
    for line in lines:
        # Escape first so '<' and '>' in generated code can't break the HTML
        escaped = html.escape(line)
        if line.startswith('+'):
            formatted.append(f'<span style="color: green">{escaped}</span>')
        elif line.startswith('-'):
            formatted.append(f'<span style="color: red">{escaped}</span>')
        else:
            formatted.append(escaped)
    return '<br>'.join(formatted)
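
# For example, format_diff_response("+new\n-old\nctx") wraps the first line in a
# green span, the second in a red span, and leaves the third as plain escaped text.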

def create_prompt(request, file_content, system_message):
    """Build a Llama-3 chat prompt with the request and file wrapped in tags."""
    return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 July 2024

{system_message}<|eot_id|><|start_header_id|>user<|end_header_id|>

<request>{request}</request>
<file>{file_content}</file><|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
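
# A possible alternative, assuming this tokenizer ships a Llama-3 chat template
# (kept as a commented sketch so the script's behavior is unchanged):
#
#   def create_prompt_via_template(request, file_content, system_message):
#       messages = [
#           {"role": "system", "content": system_message},
#           {"role": "user",
#            "content": f"<request>{request}</request>\n<file>{file_content}</file>"},
#       ]
#       return tokenizer.apply_chat_template(
#           messages, tokenize=False, add_generation_prompt=True
#       )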

@spaces.GPU
def respond(request, file_content, system_message, max_tokens, temperature, top_p):
    try:
        prompt = create_prompt(request, file_content, system_message)
        
        # Tokenize input; the template already contains <|begin_of_text|>, so
        # don't let the tokenizer prepend a second BOS. A single prompt also
        # needs no padding (and this tokenizer may have no pad token).
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            add_special_tokens=False,
            truncation=True,
            max_length=2048
        ).to(device)
        
        # Generate response with streaming
        response = ""
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        
        generation_kwargs = dict(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            streamer=streamer,
            # Fall back to EOS when no pad token is defined (typical for Llama-style tokenizers)
            pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            do_sample=True,
        )
        
        # Start generation in a separate thread
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        
        # Yield formatted responses as they're generated
        for new_text in streamer:
            response += new_text
            yield format_diff_response(response)
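
        # The streamer is exhausted once generation finishes; join the worker
        # thread so it does not outlive this request
        thread.join()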
            
    except Exception as e:
        yield f"<span style='color: red'>Error generating response: {str(e)}</span>"

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Code Review Assistant")
    
    with gr.Row():
        with gr.Column():
            request_input = gr.Textbox(
                label="Request",
                value="fix the error",
                placeholder="Enter your request (e.g., 'fix the function', 'add error handling')",
                lines=3
            )
            file_input = gr.Code(
                label="File Content",
                value="""def suma(a, b):
    return a + b

print(suma(5, "3"))""",
                language="python",
                lines=10
            )
        with gr.Column():
            output = gr.HTML(label="Diff Output")
    
    with gr.Accordion("Advanced Settings", open=False):
        system_msg = gr.Textbox(
            value="you are a coder asistant, returns the answer to user in diff format",
            label="System Message"
        )
        max_tokens = gr.Slider(
            minimum=1,
            maximum=2048,
            value=128,
            step=1,
            label="Max Tokens"
        )
        temperature = gr.Slider(
            minimum=0.1,
            maximum=4.0,
            value=0.5,
            step=0.1,
            label="Temperature"
        )
        top_p = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=1.0,
            step=0.05,
            label="Top-p"
        )

    submit_btn = gr.Button("Submit")
    submit_btn.click(
        fn=respond,
        inputs=[
            request_input,
            file_input,
            system_msg,
            max_tokens,
            temperature,
            top_p
        ],
        outputs=output
    )

if __name__ == "__main__":
    demo.launch()