|
import os

import gradio as gr
import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
# Optional Hugging Face Hub access token taken from the environment;
# evaluates to None when the HF_TOKEN variable is not set.
HF_TOKEN = os.getenv("HF_TOKEN")
|
|
|
|
|
# Model 1 = base checkpoint + fine-tuned PEFT adapter; Model 2 = a plain
# instruction-tuned Gemma checkpoint used as the comparison baseline.
base_model_name = "google/gemma-2b-it"
adapter_model_name = "akhaliq/gemma-3-270m-gradio-coder-adapter"

# Loading options shared by both models: fp16 with automatic device placement
# when a GPU is available, otherwise fp32 on CPU. The token is forwarded to
# every Hub download.
_use_cuda = torch.cuda.is_available()
_load_kwargs = {
    "torch_dtype": torch.float16 if _use_cuda else torch.float32,
    "device_map": "auto" if _use_cuda else None,
    "token": HF_TOKEN,
}

print("Loading Model 1 with adapter...")
# Adapter weights only fit the architecture they were trained on, and the
# adapter's config records that checkpoint. The adapter name suggests a
# Gemma-3 270M base, which does not match the hard-coded `base_model_name`,
# so prefer the adapter's own record and fall back only if it is missing.
_adapter_config = PeftConfig.from_pretrained(adapter_model_name, token=HF_TOKEN)
base_model_name = _adapter_config.base_model_name_or_path or base_model_name

tokenizer1 = AutoTokenizer.from_pretrained(adapter_model_name, token=HF_TOKEN)
base_model1 = AutoModelForCausalLM.from_pretrained(base_model_name, **_load_kwargs)
model1 = PeftModel.from_pretrained(base_model1, adapter_model_name, token=HF_TOKEN)
model1.eval()

print("Loading Model 2...")
# NOTE(review): for a like-for-like "adapter vs. base" comparison this may
# be intended to equal the adapter's base checkpoint — confirm.
model2_name = "google/gemma-2b-it"
tokenizer2 = AutoTokenizer.from_pretrained(model2_name, token=HF_TOKEN)
model2 = AutoModelForCausalLM.from_pretrained(model2_name, **_load_kwargs)
model2.eval()
|
|
|
def generate_code(
    user_input,
    model,
    tokenizer,
    model_name="Model",
    *,
    max_new_tokens=512,
    temperature=0.7,
    top_p=0.9,
):
    """
    Generate code based on user input using the selected model.

    The request is wrapped in Gemma chat-turn markup and a single sampled
    completion is produced.

    Args:
        user_input: Natural-language description of the code to write.
        model: Causal LM (plain ``transformers`` model or ``PeftModel``)
            providing ``generate``.
        tokenizer: Tokenizer matching ``model``.
        model_name: Display label; unused here, kept for API compatibility.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The model's reply only (prompt excluded), whitespace-stripped.
    """
    prompt = f"<start_of_turn>user\n{user_input}<end_of_turn>\n<start_of_turn>model\n"

    # NOTE(review): truncation keeps the left part of the prompt, so inputs
    # longer than 512 tokens lose the trailing "<start_of_turn>model" marker.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)

    # Place the inputs on the model's device (correct for CPU, single GPU and
    # device_map="auto"); fall back to the old CUDA heuristic if the model
    # object does not expose a device.
    device = getattr(model, "device", None)
    if device is not None:
        inputs = {k: v.to(device) for k, v in inputs.items()}
    elif torch.cuda.is_available():
        inputs = {k: v.cuda() for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=True,
            top_p=top_p,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # generate() returns prompt + continuation, so decode only the new tokens.
    # The chat markers are typically special tokens removed by
    # skip_special_tokens=True, so they cannot be used to find the reply in
    # the fully decoded text.
    prompt_length = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    return response.replace("<end_of_turn>", "").strip()
|
|
|
def _generate_or_report(user_input, model, tokenizer, model_name, error_label):
    """Run ``generate_code``; on failure return a displayable error string instead."""
    try:
        return generate_code(user_input, model, tokenizer, model_name)
    except Exception as e:  # best-effort: one failing model must not blank the other
        return f"Error with {error_label}: {str(e)}"


def generate_both(user_input):
    """
    Generate code from both models for comparison.

    Args:
        user_input: Prompt text from the UI or the ``generate`` API endpoint.
            ``None`` and whitespace-only input mean "nothing to generate".

    Returns:
        ``(model1_output, model2_output)`` strings. If a model raises, its slot
        holds an error message so the other model's result is still shown.
    """
    if user_input is None or not user_input.strip():
        return "", ""

    output1 = _generate_or_report(user_input, model1, tokenizer1, "Model 1 (Adapter)", "Model 1")
    output2 = _generate_or_report(user_input, model2, tokenizer2, "Model 2 (Base)", "Model 2")
    return output1, output2
|
|
|
|
|
# Client-side click handler for a "copy code" button. The placeholder is
# replaced with the button's DOM id (its Gradio ``elem_id``).
_COPY_JS_TEMPLATE = """
(code) => {
    navigator.clipboard.writeText(code ?? "");
    const btn = document.getElementById("__BUTTON_ID__");
    if (btn) {
        btn.dataset.label ??= btn.textContent;
        btn.textContent = "✅ Copied!";
        setTimeout(() => { btn.textContent = btn.dataset.label; }, 2000);
    }
    return null;
}
"""


def _copy_js(button_id):
    """Return the JS handler that copies the bound code and relabels ``button_id``.

    The button is looked up by DOM id instead of the Playwright-only
    ``:has-text()`` pseudo-class, which ``document.querySelector`` rejects
    with a SyntaxError. The original label is stored once, so repeated clicks
    within the 2 s window still restore it correctly.
    """
    return _COPY_JS_TEMPLATE.replace("__BUTTON_ID__", button_id)


with gr.Blocks(title="Text to Code Generator - Model Comparison", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🚀 Text to Code Generator - Model Comparison

        Compare code generation from two different Gemma models:

        - **Model 1**: Gemma with Gradio Coder Adapter (Fine-tuned)
        - **Model 2**: Base Gemma Model

        Simply describe what you want to build, and see how each model responds!
        """
    )

    with gr.Row():
        # Left column: prompt entry, actions and example prompts.
        with gr.Column(scale=1):
            input_text = gr.Textbox(
                label="Describe what you want to code",
                placeholder="e.g., Create a Python function that calculates the factorial of a number",
                lines=5,
                max_lines=10,
            )

            with gr.Row():
                generate_btn = gr.Button("Generate from Both Models", variant="primary", scale=2)
                clear_btn = gr.ClearButton([input_text], value="Clear", scale=1)

            gr.Examples(
                examples=[
                    ["Create a Python function to check if a number is prime"],
                    ["Write a JavaScript function to reverse a string"],
                    ["Create a React component for a todo list item"],
                    ["Write a SQL query to find the top 5 customers by total purchase amount"],
                    ["Create a Python class for a bank account with deposit and withdraw methods"],
                    ["Build a simple Gradio interface for text summarization"],
                ],
                inputs=input_text,
                label="Example Prompts",
            )

        # Right column: side-by-side outputs of the two models.
        with gr.Column(scale=2):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Model 1: With Gradio Coder Adapter")
                    output_code1 = gr.Code(
                        label="Generated Code (Model 1)",
                        language="python",
                        lines=15,
                        interactive=True,
                        show_label=False,
                    )
                    copy_btn1 = gr.Button("📋 Copy Code", size="sm", elem_id="copy-btn-1")

                with gr.Column():
                    gr.Markdown("### Model 2: Base Gemma Model")
                    output_code2 = gr.Code(
                        label="Generated Code (Model 2)",
                        language="python",
                        lines=15,
                        interactive=True,
                        show_label=False,
                    )
                    copy_btn2 = gr.Button("📋 Copy Code", size="sm", elem_id="copy-btn-2")

    # Generation: button click (also exposed as the "generate" API) and Enter.
    generate_btn.click(
        fn=generate_both,
        inputs=input_text,
        outputs=[output_code1, output_code2],
        api_name="generate",
    )
    input_text.submit(
        fn=generate_both,
        inputs=input_text,
        outputs=[output_code1, output_code2],
    )

    # Copy buttons run purely in the browser (no Python fn).
    copy_btn1.click(None, inputs=output_code1, outputs=None, js=_copy_js("copy-btn-1"))
    copy_btn2.click(None, inputs=output_code2, outputs=None, js=_copy_js("copy-btn-2"))

    gr.Markdown(
        """
        ---
        💡 **Tips:**

        - Be specific about the programming language you want
        - Include details about inputs, outputs, and edge cases
        - You can edit the generated code directly in the output box

        **Note:** The adapter model is specifically fine-tuned for generating Gradio code!
        """
    )
|
|
|
|
|
def _main():
    """Start the Gradio server hosting the comparison UI."""
    demo.launch()


if __name__ == "__main__":
    _main()