# Hugging Face Space (status when captured: Sleeping)
import gradio as gr | |
import os | |
from io import BytesIO | |
from PIL import Image, ImageDraw, ImageFont | |
from PIL import ImageColor | |
import json | |
from google import genai | |
from google.genai import types | |
# Gemini client, authenticated with the GOOGLE_API_KEY environment variable.
# A missing variable raises KeyError as soon as this module is imported.
client = genai.Client(api_key=os.environ['GOOGLE_API_KEY'])
model_name = "gemini-2.0-flash-exp"

# System prompt passed as `system_instruction` on each generate_content call.
bounding_box_system_instructions = """
Return bounding boxes as a JSON array with labels. Never return masks or code fencing. Limit to 25 objects.
If an object is present multiple times, name them according to their unique characteristic (colors, size, position, unique characteristics, etc..).
"""

# Every color name in Pillow's color map; appended to the drawing palette.
additional_colors = list(ImageColor.colormap)
def parse_json(json_output):
    """
    Extract the JSON payload from a Gemini text response.

    If the response contains a Markdown code fence, everything up to and
    including the opening fence line is dropped, and so is everything from
    the closing fence onwards. The opening fence is recognised with or
    without a language tag and regardless of surrounding whitespace or
    letter case ("```json", "  ```JSON  ", "```"). Text without a fence is
    returned unchanged.

    Args:
        json_output: Raw model text, or None / "" when the model returned
            nothing.

    Returns:
        The text expected to hold JSON (not parsed or validated here).
        An empty string when `json_output` is None or empty.
    """
    if not json_output:
        # Nothing to parse; callers treat a falsy result as "no boxes".
        return ""

    lines = json_output.splitlines()
    for i, line in enumerate(lines):
        # Strip so an indented fence or trailing spaces still match.
        if line.strip().startswith("```"):
            # Keep only what sits between the opening and closing fences.
            fenced = "\n".join(lines[i + 1:])
            return fenced.split("```")[0]
    return json_output
def plot_bounding_boxes(im, bounding_boxes):
    """
    Draw labelled bounding boxes on a copy of an image.

    Args:
        im: PIL image to annotate. It is copied; the input is not modified.
        bounding_boxes: JSON text holding an array of objects (a single
            object is also accepted). Each object needs a "box_2d" entry
            whose first four values are read as [y1, x1, y2, x2]; each value
            is divided by 1000 and multiplied by the image height or width.
            An optional "label" entry is drawn inside the box.

    Returns:
        The annotated copy. Drawing is best-effort: invalid JSON yields the
        plain copy, and a malformed box is reported and skipped without
        affecting the other boxes.
    """
    im = im.copy()
    width, height = im.size
    draw = ImageDraw.Draw(im)
    colors = [
        'red', 'green', 'blue', 'yellow', 'orange', 'pink', 'purple', 'cyan',
        'lime', 'magenta', 'violet', 'gold', 'silver'
    ] + additional_colors
    font = ImageFont.load_default()

    try:
        boxes = json.loads(bounding_boxes)
    except (TypeError, ValueError) as e:
        # json.JSONDecodeError is a ValueError; TypeError covers non-text input.
        print(f"Error drawing bounding boxes: {e}")
        return im

    if isinstance(boxes, dict):
        # Tolerate a lone object instead of a one-element array.
        boxes = [boxes]
    if not isinstance(boxes, list):
        print("Error drawing bounding boxes: expected a JSON array of boxes")
        return im

    for i, bounding_box in enumerate(boxes):
        color = colors[i % len(colors)]
        try:
            y1, x1, y2, x2 = bounding_box["box_2d"][:4]
            abs_y1 = int(y1 / 1000 * height)
            abs_x1 = int(x1 / 1000 * width)
            abs_y2 = int(y2 / 1000 * height)
            abs_x2 = int(x2 / 1000 * width)
        except (KeyError, IndexError, TypeError, ValueError, OverflowError) as e:
            # Skip only this box so one bad entry does not hide the rest.
            print(f"Error drawing bounding boxes: {e}")
            continue

        # Normalise so (x1, y1) is the top-left corner.
        if abs_x1 > abs_x2:
            abs_x1, abs_x2 = abs_x2, abs_x1
        if abs_y1 > abs_y2:
            abs_y1, abs_y2 = abs_y2, abs_y1

        # Draw bounding box and label
        draw.rectangle(((abs_x1, abs_y1), (abs_x2, abs_y2)), outline=color, width=4)
        if "label" in bounding_box:
            # str() guards against a non-string label in the JSON.
            draw.text((abs_x1 + 8, abs_y1 + 6), str(bounding_box["label"]), fill=color, font=font)
    return im
def predict_bounding_boxes(image, prompt):
    """
    Send an image and prompt to Gemini and draw the returned bounding boxes.

    Args:
        image: PIL image supplied by the Gradio image input.
        prompt: Text describing what to detect.

    Returns:
        A single PIL image in every case, matching the one output component
        the click handler is wired to: the resized image with boxes drawn on
        success, or the un-annotated image (resized if the failure happened
        after the resize) when anything goes wrong. Errors are printed and
        are not raised.
    """
    try:
        # Resize to 1024 px wide, keeping the aspect ratio. max(1, ...) keeps
        # the height valid for extremely wide images.
        target_height = max(1, int(1024 * image.height / image.width))
        image = image.resize((1024, target_height))

        # Make API request to Gemini. The image object is passed directly, so
        # no intermediate JPEG encoding is needed.
        response = client.models.generate_content(
            model=model_name,
            contents=[prompt, image],
            config=types.GenerateContentConfig(
                system_instruction=bounding_box_system_instructions,
                temperature=0.5,
                safety_settings=[
                    types.SafetySetting(
                        category="HARM_CATEGORY_DANGEROUS_CONTENT",
                        threshold="BLOCK_ONLY_HIGH",
                    )
                ],
            )
        )
        print("Gemini response:", response.text)

        # Parse and plot bounding boxes
        bounding_boxes = parse_json(response.text)
        if not bounding_boxes:
            raise ValueError("No bounding boxes returned.")
        return plot_bounding_boxes(image, bounding_boxes)
    except Exception as e:
        # Top-level UI boundary: log the failure and fall back to the
        # un-annotated image so the return shape stays a single value.
        print(f"Error during processing: {e}")
        return image
def gradio_interface():
    """
    Build the Gradio Blocks UI for the bounding box generator.

    The layout is a title, a row with an input column (image, prompt,
    button) and an output column (annotated image), and a set of example
    image/prompt pairs that fill the two inputs.

    Returns:
        The assembled `gr.Blocks` app; the caller is responsible for
        launching it.
    """
    # Example image + prompt pairs
    example_pairs = [
        ["cookies.jpg", "Detect the cookies and label their types."],
        ["messed_room.jpg", "Find the unorganized item and suggest action in label in the image to fix them."],
        ["yoga.jpg", "Show the different yoga poses and name them."],
        ["zoom_face.png", "Label the tired faces in the image."],
    ]
    glass_theme = gr.themes.Glass(secondary_hue="rose")

    with gr.Blocks(theme=glass_theme) as demo:
        gr.Markdown("# Gemini Bounding Box Generator")

        with gr.Row():
            with gr.Column():
                gr.Markdown("### Input Section")
                image_in = gr.Image(type="pil", label="Input Image")
                prompt_in = gr.Textbox(
                    lines=2,
                    label="Input Prompt",
                    placeholder="Describe what to detect.",
                )
                run_button = gr.Button("Generate")
            with gr.Column():
                gr.Markdown("### Output Section")
                image_out = gr.Image(type="pil", label="Output Image")

        gr.Markdown("### Examples")
        gr.Examples(
            examples=example_pairs,
            inputs=[image_in, prompt_in],
            label="Example Images with Prompts",
        )

        # Clicking the button runs the prediction on the current inputs.
        run_button.click(
            predict_bounding_boxes,
            inputs=[image_in, prompt_in],
            outputs=[image_out],
        )
    return demo
if __name__ == "__main__":
    # Build the UI and start serving it when run as a script.
    gradio_interface().launch()