import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import gradio as gr
import spaces

# Read the adapter's PeftConfig. Only the config is fetched here; the base
# model and tokenizer it names (config.base_model_name_or_path) are loaded below.
# peft_model_id is the Hub-style "owner/name" id of the LoRA adapter.
peft_model_id = "rootxhacker/CodeAstra-7B"
config = PeftConfig.from_pretrained(peft_model_id)

def to_cpu(obj):
    """Return a copy of ``obj`` with every ``torch.Tensor`` moved to the CPU.

    Walks nested dicts, lists and tuples recursively. Dicts come back as plain
    ``dict``, lists as ``list`` and tuples as plain ``tuple``. Any other value is
    returned unchanged.
    """
    if isinstance(obj, torch.Tensor):
        return obj.cpu()
    if isinstance(obj, dict):
        return {name: to_cpu(item) for name, item in obj.items()}
    if isinstance(obj, (list, tuple)):
        converted = [to_cpu(element) for element in obj]
        # Preserve the list/tuple distinction of the input container.
        return converted if isinstance(obj, list) else tuple(converted)
    return obj

# Load the base model quantized to 4 bits. Quantization is requested through
# ``quantization_config``; passing ``load_in_4bit=True`` straight to
# ``from_pretrained`` is deprecated in transformers. This is the same 4-bit
# setup with default BitsAndBytesConfig options.
model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    return_dict=True,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    device_map="auto",
)
# The tokenizer is loaded from the base model named in the adapter's config.
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

# Wrap the quantized base model with the LoRA adapter weights.
model = PeftModel.from_pretrained(model, peft_model_id)

@spaces.GPU()
def get_completion(query, model, tokenizer, *, max_new_tokens=1024, temperature=0.7):
    """Generate a sampled completion for ``query`` on the GPU.

    The model is moved to CUDA for the duration of the call. Afterwards it is
    moved back to the CPU on a best-effort basis and the CUDA cache is emptied.

    Args:
        query: Prompt text passed to ``tokenizer``.
        model: Model exposing ``cuda``/``cpu``/``generate`` (the module-level
            ``PeftModel`` in this app).
        tokenizer: Tokenizer matching ``model``; also used to decode the output.
        max_new_tokens: Maximum number of tokens to generate (default 1024).
        temperature: Sampling temperature; sampling is always enabled
            (default 0.7).

    Returns:
        ``outputs[0]`` decoded with special tokens skipped. The prompt tokens
        are not sliced off, so for decoder-only models the text normally
        starts with the prompt. If moving the model, tokenizing or generating
        raises, returns ``"An error occurred: <message>"`` instead.
    """
    import logging

    logger = logging.getLogger(__name__)
    try:
        model = model.cuda()
        # Inputs must live on the same device as the model.
        inputs = tokenizer(query, return_tensors="pt").to("cuda")
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
            )
        # Decode from CPU copies of the generated ids.
        outputs = to_cpu(outputs)
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # UI boundary: report the failure as text, but keep the traceback in the logs.
        logger.exception("Generation failed")
        return f"An error occurred: {e}"
    finally:
        # Best-effort cleanup. An exception raised inside ``finally`` would
        # replace whatever the try/except above returned (including the error
        # message), so a failed move back to CPU is logged, not propagated.
        try:
            model.cpu()
        except Exception:
            logger.warning("Could not move model back to CPU", exc_info=True)
        torch.cuda.empty_cache()
        

@spaces.GPU()
def code_review(code_to_analyze):
    """Ask the model to find vulnerabilities in ``code_to_analyze``.

    The submitted code is appended to a fixed instruction line. Whatever
    ``get_completion`` returns is passed straight back to the caller, with no
    trimming or parsing.
    """
    # NOTE: the instruction wording (and its trailing spaces) is kept verbatim;
    # editing it changes what the model sees.
    prompt = f"find all vulnerabilities which in the code \n{code_to_analyze} "
    return get_completion(prompt, model, tokenizer)

# Gradio UI: a 10-line textbox for the code goes in, the model's review text
# comes out. Each submission is routed through code_review.
code_input = gr.Textbox(lines=10, label="Enter code to analyze")
review_output = gr.Textbox(label="Code Review Result")

iface = gr.Interface(
    fn=code_review,
    inputs=code_input,
    outputs=review_output,
    title="Code Review Expert",
    description=(
        "This tool analyzes code for potential security flaws, logic "
        "vulnerabilities, and provides guidance on secure coding practices."
    ),
)

# Start serving the app.
iface.launch()