import spaces
import gradio as gr
import torch
import html
from threading import Thread
from transformers import PreTrainedTokenizerFast, AutoModelForCausalLM, TextIteratorStreamer

# Initialize model and tokenizer
MODEL_ID = "erikbeltran/pydiff"
GGUF_FILE = "unsloth.Q4_K_M.gguf"

try:
    # Use PreTrainedTokenizerFast instead of LlamaTokenizer
    tokenizer = PreTrainedTokenizerFast.from_pretrained(MODEL_ID)

    # Ensure the tokenizer has the necessary special tokens
    special_tokens = {
        # 'pad_token': '[PAD]',
        'eos_token': '<|eot_id|>',
        # 'bos_token': '<s>',
        # 'unk_token': '<unk>',
    }
    tokenizer.add_special_tokens(special_tokens)

    # No pad token is defined; fall back to EOS so generate() gets a
    # valid pad_token_id later on.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, gguf_file=GGUF_FILE)

    # Move model to GPU if available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
except Exception as e:
    print(f"Error initializing model or tokenizer: {e}")
    raise
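
# Note: loading a GGUF checkpoint through transformers requires the `gguf`
# package and dequantizes the weights back to standard torch tensors, so the
# Q4_K_M quantization saves download size but not inference memory here.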

def format_diff_response(response):
    """Format the response to look like a diff output."""
    lines = response.split('\n')
    formatted = []
    for line in lines:
        # Escape first so code containing < or > renders safely inside gr.HTML
        escaped = html.escape(line)
        if line.startswith('+'):
            formatted.append(f'<span style="color: green">{escaped}</span>')
        elif line.startswith('-'):
            formatted.append(f'<span style="color: red">{escaped}</span>')
        else:
            formatted.append(escaped)
    return '<br>'.join(formatted)
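
# For example, format_diff_response('-print(suma(5, "3"))\n+print(suma(5, 3))')
# wraps the removed line in a red <span>, the added line in a green one, and
# joins them with <br> for the HTML output component below.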

def create_prompt(request, file_content, system_message):
    # Earlier, simpler prompt format, kept for reference:
    # return f"""<system>{system_message}</system>
    # <request>{request}</request>
    # <file>{file_content}</file>"""
    return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 July 2024

{system_message}<|eot_id|><|start_header_id|>user<|end_header_id|>

<request>{request}</request>
<file>{file_content}</file><|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
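
# `spaces` is imported above but never used. On a ZeroGPU Space the function
# that touches the model is normally decorated with @spaces.GPU so a GPU is
# attached for the duration of the call. This assumes the Space targets the
# ZeroGPU hardware tier; the decorator is a no-op on regular hardware.
@spaces.GPU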
def respond(request, file_content, system_message, max_tokens, temperature, top_p):
    try:
        prompt = create_prompt(request, file_content, system_message)

        # Tokenize input. The prompt already begins with <|begin_of_text|>,
        # so skip add_special_tokens to avoid inserting a second BOS token.
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            add_special_tokens=False,
            padding=True,
            truncation=True,
            max_length=2048,
        ).to(device)

        # Generate response with streaming
        response = ""
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            streamer=streamer,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            do_sample=True,
        )

        # Start generation in a separate thread so we can consume the streamer
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        # Yield formatted responses as they're generated
        for new_text in streamer:
            response += new_text
            yield format_diff_response(response)
    except Exception as e:
        yield f"<span style='color: red'>Error generating response: {e}</span>"

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Code Review Assistant")

    with gr.Row():
        with gr.Column():
            request_input = gr.Textbox(
                label="Request",
                value="fix the error",
                placeholder="Enter your request (e.g., 'fix the function', 'add error handling')",
                lines=3,
            )
            file_input = gr.Code(
                label="File Content",
                value="""def suma(a, b):
    return a + b

print(suma(5, "3"))
""",
                language="python",
                lines=10,
            )
        with gr.Column():
            output = gr.HTML(label="Diff Output")

    with gr.Accordion("Advanced Settings", open=False):
        system_msg = gr.Textbox(
            value="you are a coder assistant, return the answer to the user in diff format",
            label="System Message",
        )
        max_tokens = gr.Slider(
            minimum=1,
            maximum=2048,
            value=128,
            step=1,
            label="Max Tokens",
        )
        temperature = gr.Slider(
            minimum=0.1,
            maximum=4.0,
            value=0.5,
            step=0.5,
            label="Temperature",
        )
        top_p = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=1.0,
            step=0.05,
            label="Top-p",
        )

    submit_btn = gr.Button("Submit")
    submit_btn.click(
        fn=respond,
        inputs=[
            request_input,
            file_input,
            system_msg,
            max_tokens,
            temperature,
            top_p,
        ],
        outputs=output,
    )

if __name__ == "__main__":
    # queue() makes the streaming (generator) handler work on older Gradio
    # releases as well; on Gradio 4.x the queue is enabled by default.
    demo.queue().launch()