Spaces:
Sleeping
Sleeping
import json
import math
import os

import gradio as gr
from google import genai
from google.genai import types
from PIL import Image, ImageColor, ImageDraw, ImageFont
# Gemini API client. The key is read from the GOOGLE_API_KEY environment
# variable; if it is unset, the os.environ lookup raises KeyError on import.
client = genai.Client(
    api_key=os.environ["GOOGLE_API_KEY"],
)
# Model identifier passed to every generate_content call.
model_name = "gemini-2.0-flash-exp"
# Function to parse JSON output from Gemini
def parse_json(json_output):
    """
    Parse JSON output from the Gemini model.

    Gemini often wraps its answer in a Markdown code fence. If a fence line
    ("```json" or a bare "```", compared case- and whitespace-insensitively)
    is found, only the text between it and the next "```" is parsed;
    otherwise the whole string is parsed as-is.

    Args:
        json_output: Raw model text; may be None when the model returned
            no text.

    Returns:
        The decoded JSON value, or ``{}`` if there is no text or it is not
        valid JSON.
    """
    if not isinstance(json_output, str) or not json_output.strip():
        print("Error parsing JSON: no text to parse")
        return {}
    lines = json_output.splitlines()
    for i, line in enumerate(lines):
        if line.strip().lower() in ("```json", "```"):
            # Keep only the body between the opening fence and the closing "```".
            json_output = "\n".join(lines[i + 1:]).split("```")[0]
            break
    try:
        return json.loads(json_output)
    except (ValueError, RecursionError) as e:  # JSONDecodeError is a ValueError
        print(f"Error parsing JSON: {e}")
        return {}
# Function to draw a flowchart
def _arrowhead_points(x1, y1, x2, y2, size):
    """
    Return the three vertices of an arrowhead whose tip sits at (x2, y2).

    The head points along the segment from (x1, y1) to (x2, y2). A
    zero-length segment has no direction, so the head then points straight
    down (image y grows downward).
    """
    if (x1, y1) == (x2, y2):
        angle = math.pi / 2
    else:
        angle = math.atan2(y2 - y1, x2 - x1)
    spread = math.pi / 6  # each barb is 30 degrees off the shaft
    return [
        (x2, y2),
        (x2 - size * math.cos(angle - spread), y2 - size * math.sin(angle - spread)),
        (x2 - size * math.cos(angle + spread), y2 - size * math.sin(angle + spread)),
    ]


def _safe_color(color, default="white"):
    """Return ``color`` if PIL can parse it as a color, otherwise ``default``."""
    try:
        ImageColor.getrgb(color)
    except (ValueError, TypeError, AttributeError):
        return default
    return color


def draw_flowchart(image, flowchart_json):
    """
    Draw a flowchart on a copy of ``image`` based on JSON input.

    ``flowchart_json`` is read as a dict with optional lists under
    "shapes" and "connections":

    * each shape needs "x", "y", "width", "height" and may have "id",
      "label", "type" ("rectangle", "ellipse" or "diamond"; anything else
      is drawn as a rectangle) and "color" (unparseable colors fall back
      to white);
    * each connection has "from"/"to" shape ids and an optional "type";
      it is drawn from the bottom-centre of the source shape to the
      top-centre of the target, with an arrowhead at the target when its
      type is "arrow" (the default). Connections naming an unknown id are
      skipped.

    Args:
        image: PIL image to draw on; it is copied, not modified.
        flowchart_json: The flowchart description described above.

    Returns:
        A new PIL image with the flowchart drawn on it.
    """
    im = image.copy()
    draw = ImageDraw.Draw(im)
    # Load default font
    try:
        font = ImageFont.load_default()
    except Exception as e:
        print(f"Error loading font: {e}")
        return im
    shapes = flowchart_json.get("shapes", [])
    connections = flowchart_json.get("connections", [])

    # Draw shapes
    for shape in shapes:
        x, y, w, h = shape["x"], shape["y"], shape["width"], shape["height"]
        shape_type = str(shape.get("type") or "rectangle").lower()
        label = shape.get("label")
        label = "" if label is None else str(label)
        color = _safe_color(shape.get("color", "white"))
        box = [x, y, x + w, y + h]
        if shape_type == "ellipse":
            draw.ellipse(box, fill=color, outline="black", width=3)
        elif shape_type == "diamond":
            points = [
                (x + w // 2, y),  # Top
                (x + w, y + h // 2),  # Right
                (x + w // 2, y + h),  # Bottom
                (x, y + h // 2),  # Left
            ]
            draw.polygon(points, fill=color, outline="black", width=3)
        else:
            # "rectangle" and any unrecognized type fall back to a rectangle
            # so the label is never drawn without a frame.
            draw.rectangle(box, fill=color, outline="black", width=3)
        # Centre the label. The bbox origin is subtracted because its
        # left/top offsets are not necessarily zero.
        left, top, right, bottom = draw.textbbox((0, 0), label, font=font)
        text_x = x + (w - (right - left)) // 2 - left
        text_y = y + (h - (bottom - top)) // 2 - top
        draw.text((text_x, text_y), label, fill="black", font=font)

    # Index shapes by id once. setdefault keeps the first shape for a
    # duplicated id, matching a first-match linear search.
    shapes_by_id = {}
    for shape in shapes:
        if "id" in shape:
            shapes_by_id.setdefault(shape["id"], shape)

    # Draw connections
    arrow_size = 10
    for conn in connections:
        from_shape = shapes_by_id.get(conn.get("from"))
        to_shape = shapes_by_id.get(conn.get("to"))
        if from_shape is None or to_shape is None:
            # The model referenced an id it never defined; skip this edge
            # instead of aborting the whole drawing.
            print(f"Skipping connection with unknown endpoint: {conn}")
            continue
        # Bottom-centre of the source to top-centre of the target.
        x1 = from_shape["x"] + from_shape["width"] // 2
        y1 = from_shape["y"] + from_shape["height"]
        x2 = to_shape["x"] + to_shape["width"] // 2
        y2 = to_shape["y"]
        draw.line([x1, y1, x2, y2], fill="black", width=2)
        # Add an arrowhead pointing along the line, tip at the target.
        if conn.get("type", "arrow") == "arrow":
            draw.polygon(_arrowhead_points(x1, y1, x2, y2, arrow_size), fill="black")
    return im
# Legacy flowchart renderer, kept only for backward compatibility.
def olddraw_flowchart(image, flowchart_json):
    """
    Deprecated alias for ``draw_flowchart``.

    The previous body was a copy of ``draw_flowchart`` that measured labels
    with ``ImageFont.getsize``. That method was removed in Pillow 10, so the
    copy failed on current Pillow; it now simply delegates.

    Args:
        image: PIL image to draw on (copied, not modified).
        flowchart_json: Flowchart description with "shapes" and "connections".

    Returns:
        Whatever ``draw_flowchart`` returns for the same arguments.
    """
    return draw_flowchart(image, flowchart_json)
# Function to generate flowchart JSON via Gemini
def generate_flowchart(prompt, canvas_size=(1000, 800)):
    """
    Use Google Gemini to generate JSON for a flowchart.

    The system instruction spells out the exact schema that
    ``draw_flowchart`` reads: top-level "shapes" and "connections" lists.
    It also gives the canvas the shapes must fit inside. Without this, the
    model may pick other key names or off-canvas coordinates, and the result
    renders as an empty image. JSON mode is requested so the reply is plain
    JSON; ``parse_json`` still copes with a fenced reply.

    Args:
        prompt: Natural-language description of the process to chart.
        canvas_size: (width, height) in pixels that the layout must fit in;
            defaults to the 1000x800 canvas ``predict_flowchart`` draws on.

    Returns:
        The parsed JSON (normally a dict), or ``{}`` if the request or
        parsing fails.
    """
    width, height = canvas_size
    system_instruction = f"""
Return a single JSON object describing a flowchart with exactly two top-level keys: "shapes" and "connections".
Use formal flowchart conventions with shapes like rectangles, ellipses, and diamonds.
Each item in "shapes" must have the attributes: id, label, x, y, width, height, type (one of 'rectangle', 'ellipse', 'diamond'), and color (a CSS color name such as 'lightblue' or a hex code such as '#aaccff').
x, y, width and height are integer pixel values, where (x, y) is the top-left corner of the shape. Every shape must lie entirely inside a {width}x{height} canvas and must not overlap another shape.
Each item in "connections" must have the attributes: from (a shape id), to (a shape id), and type (e.g., 'arrow').
"""
    try:
        response = client.models.generate_content(
            model=model_name,
            contents=[prompt],
            config=types.GenerateContentConfig(
                system_instruction=system_instruction,
                temperature=0.5,
                # JSON mode: ask the API for a bare JSON body.
                response_mime_type="application/json",
            ),
        )
        print("Gemini Response:", response.text)
        return parse_json(response.text)
    except Exception as e:
        print(f"Error generating flowchart JSON: {e}")
        return {}
# Function to predict the flowchart
def predict_flowchart(prompt):
    """
    Turn the user's prompt into a rendered flowchart image.

    Any failure (generation, empty JSON, drawing) is caught. The message is
    printed and also written in red onto a blank 1000x800 white canvas,
    which is returned instead of raising.
    """
    canvas_size = (1000, 800)
    try:
        flowchart_json = generate_flowchart(prompt)
        if not flowchart_json:
            raise ValueError("Could not generate flowchart JSON.")
        canvas = Image.new("RGB", canvas_size, "white")
        return draw_flowchart(canvas, flowchart_json)
    except Exception as err:
        print(f"Error during processing: {err}")
        fallback = Image.new("RGB", canvas_size, "white")
        ImageDraw.Draw(fallback).text((50, 50), f"Error: {str(err)}", fill="red")
        return fallback
# Define the Gradio interface for flowcharts
def gradio_interface_flowcharts():
    """
    Build the Gradio Blocks UI for flowchart generation.

    Left column: a prompt textbox and a "Generate Flowchart" button.
    Right column: the output image. Clicking the button runs
    ``predict_flowchart`` on the prompt and shows the resulting PIL image.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    theme = gr.themes.Glass(secondary_hue="blue")
    with gr.Blocks(theme=theme) as demo:
        gr.Markdown("# Flowchart Generator with Gemini")
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Input Section")
                input_prompt = gr.Textbox(
                    lines=2,
                    label="Input Prompt",
                    placeholder="Describe the flowchart process.",
                )
                submit_btn = gr.Button("Generate Flowchart")
            with gr.Column():
                gr.Markdown("### Output Section")
                output_image = gr.Image(type="pil", label="Output Flowchart")
        # Wire the button to the generator; the listener is registered
        # inside the Blocks context.
        submit_btn.click(
            fn=predict_flowchart,
            inputs=[input_prompt],
            outputs=[output_image],
        )
    return demo
# Entry point when the file is executed as a script.
if __name__ == "__main__":
    # Build the UI, then start serving it.
    demo = gradio_interface_flowcharts()
    demo.launch()