"""
Flowchart generator: a Gradio app that asks Google Gemini for a JSON
description of a flowchart and renders that description with Pillow.
"""

import os
from PIL import Image, ImageDraw, ImageFont
import json
import gradio as gr
from google import genai
from google.genai import types

# Initialize the Google Gemini client at import time.
# NOTE: os.environ[...] raises KeyError when GOOGLE_API_KEY is not set, so this
# module cannot be imported without that environment variable.
client = genai.Client(api_key=os.environ['GOOGLE_API_KEY'])
# Model identifier passed to client.models.generate_content in generate_flowchart.
model_name = "gemini-2.0-flash-exp"

# Function to parse JSON output from Gemini
def parse_json(json_output):
    """
    Parse JSON output from the Gemini model.

    The response may be bare JSON or JSON wrapped in a Markdown code fence.
    A fence line is recognised regardless of surrounding whitespace or letter
    case, and a fence without a language tag is accepted as a fallback. When a
    fence is found, only the text between it and the next "```" is decoded.

    Args:
        json_output: Raw response text from the model (may be None).

    Returns:
        The decoded JSON value, or an empty dict when the input is not text
        or cannot be decoded.
    """
    try:
        lines = json_output.splitlines()
        # Normalised copies used only for locating the opening fence.
        markers = [line.strip().lower() for line in lines]
        # Prefer an explicit "```json" fence; fall back to a bare "```" fence.
        for fence in ("```json", "```"):
            if fence in markers:
                start = markers.index(fence) + 1
                # Keep what follows the opening fence, up to the closing one.
                json_output = "\n".join(lines[start:]).split("```")[0]
                break
        return json.loads(json_output)
    except (AttributeError, TypeError, ValueError, RecursionError) as e:
        # AttributeError/TypeError: input was None or not text.
        # ValueError covers json.JSONDecodeError.
        print(f"Error parsing JSON: {e}")
        return {}


# Function to draw a flowchart
def draw_flowchart(image, flowchart_json):
    """
    Draws a flowchart on a copy of the given image based on JSON input.

    Args:
        image: PIL image used as the canvas. It is copied, not modified.
        flowchart_json: Mapping with optional "shapes" and "connections"
            lists. Each shape must provide "x", "y", "width" and "height";
            "type" ("rectangle", "ellipse" or "diamond"), "label" and "color"
            are optional. Each connection refers to shapes through "from" and
            "to" ids and may carry a "type" (defaults to "arrow").

    Returns:
        A new image with the shapes, labels and connections drawn on it. If
        the default font cannot be loaded, the unmodified copy is returned.

    Raises:
        KeyError: If a shape lacks one of its geometry keys.
    """
    im = image.copy()
    draw = ImageDraw.Draw(im)

    # Load default font
    try:
        font = ImageFont.load_default()
    except Exception as e:
        print(f"Error loading font: {e}")
        return im

    # "or []" also covers an explicit JSON null for either list.
    shapes = flowchart_json.get("shapes") or []
    connections = flowchart_json.get("connections") or []

    # Draw shapes
    for shape in shapes:
        x, y, w, h = shape["x"], shape["y"], shape["width"], shape["height"]
        shape_type = shape.get("type", "rectangle").lower()
        raw_label = shape.get("label", "")
        # Labels may arrive as numbers or null; the font API needs text.
        label = "" if raw_label is None else str(raw_label)
        color = shape.get("color", "white")

        # Draw the shape
        if shape_type == "rectangle":
            draw.rectangle([x, y, x + w, y + h], fill=color, outline="black", width=3)
        elif shape_type == "ellipse":
            draw.ellipse([x, y, x + w, y + h], fill=color, outline="black", width=3)
        elif shape_type == "diamond":
            points = [
                (x + w // 2, y),          # Top
                (x + w, y + h // 2),      # Right
                (x + w // 2, y + h),      # Bottom
                (x, y + h // 2)           # Left
            ]
            draw.polygon(points, fill=color, outline="black")

        # Centre the label. getbbox is relative to the drawing origin, so its
        # left/top offsets are subtracted to centre the visible glyphs.
        left, top, right, bottom = font.getbbox(label)
        text_x = x + (w - (right - left)) // 2 - left
        text_y = y + (h - (bottom - top)) // 2 - top

        # Add the label
        draw.text((text_x, text_y), label, fill="black", font=font)

    # Index shapes by id once; the first shape with a given id wins.
    shapes_by_id = {}
    for shape in shapes:
        if "id" in shape:
            shapes_by_id.setdefault(shape["id"], shape)

    # Draw connections
    for conn in connections:
        from_shape = shapes_by_id.get(conn.get("from"))
        to_shape = shapes_by_id.get(conn.get("to"))
        if from_shape is None or to_shape is None:
            # A dangling reference should not abort the whole rendering.
            print(f"Skipping connection with unknown endpoint: {conn}")
            continue

        # Lines run from the bottom-centre of the source to the top-centre of the target.
        x1, y1 = from_shape["x"] + from_shape["width"] // 2, from_shape["y"] + from_shape["height"]
        x2, y2 = to_shape["x"] + to_shape["width"] // 2, to_shape["y"]

        # Draw the line
        draw.line([x1, y1, x2, y2], fill="black", width=2)

        # Add arrowhead for arrows
        if conn.get("type", "arrow") == "arrow":
            head = _arrowhead_points(x1, y1, x2, y2, 10)
            if head is not None:
                draw.polygon(head, fill="black")

    return im


def _arrowhead_points(x1, y1, x2, y2, size):
    """Return a triangle with its tip at (x2, y2) pointing along the line, or None for a zero-length line."""
    dx, dy = x2 - x1, y2 - y1
    length = (dx * dx + dy * dy) ** 0.5
    if length == 0:
        return None
    # Unit vector along the line, from source to target.
    ux, uy = dx / length, dy / length
    # Centre of the triangle's base, `size` pixels behind the tip.
    base_x, base_y = x2 - ux * size, y2 - uy * size
    half = size / 2
    return [
        (x2, y2),
        (base_x - uy * half, base_y + ux * half),
        (base_x + uy * half, base_y - ux * half),
    ]


# Legacy flowchart renderer, kept for backward compatibility
def olddraw_flowchart(image, flowchart_json):
    """
    Draws a flowchart on the given image based on JSON input.

    Deprecated alias of draw_flowchart. The previous body duplicated
    draw_flowchart and differed only in measuring labels with
    ImageFont.getsize, which no longer exists in Pillow 10 and later, so it
    raised AttributeError there. Delegating keeps the name callable and
    removes the duplicated drawing code.

    Args:
        image: PIL image used as the canvas.
        flowchart_json: Mapping with "shapes" and "connections" lists.

    Returns:
        Whatever draw_flowchart returns for the same arguments.
    """
    return draw_flowchart(image, flowchart_json)
# Function to generate flowchart JSON via Gemini
def generate_flowchart(prompt):
    """
    Ask Google Gemini for a JSON description of a flowchart.

    The prompt is sent together with a fixed system instruction; the reply
    text is printed and handed to parse_json. Any failure is printed and
    results in an empty dict.
    """
    # Same text as the original triple-quoted literal, including its
    # leading newline and per-line indentation.
    system_prompt = (
        "\n"
        "                Return a JSON structure describing a flowchart. \n"
        "                Use formal flowchart conventions with shapes like rectangles, ellipses, and diamonds.\n"
        "                Each shape should have attributes: id, label, x, y, width, height, type (e.g., 'rectangle', 'ellipse', 'diamond'), and color.\n"
        "                Also include connections with attributes: from (id), to (id), and type (e.g., 'arrow').\n"
        "                "
    )
    try:
        generation_config = types.GenerateContentConfig(
            system_instruction=system_prompt,
            temperature=0.5,
        )
        reply = client.models.generate_content(
            model=model_name,
            contents=[prompt],
            config=generation_config,
        )
        reply_text = reply.text
        print("Gemini Response:", reply_text)
        return parse_json(reply_text)
    except Exception as e:
        print(f"Error generating flowchart JSON: {e}")
        return {}

# Function to predict the flowchart
def predict_flowchart(prompt):
    """
    Turn the user's prompt into a rendered flowchart image.

    On any failure the error is printed and a white canvas carrying the
    error message in red is returned instead.
    """
    canvas_size = (1000, 800)
    try:
        spec = generate_flowchart(prompt)
        # An empty result means generation or parsing failed upstream.
        if not spec:
            raise ValueError("Could not generate flowchart JSON.")

        blank_canvas = Image.new("RGB", canvas_size, "white")
        return draw_flowchart(blank_canvas, spec)
    except Exception as e:
        print(f"Error during processing: {e}")
        # Fall back to a canvas that shows the error text
        fallback = Image.new("RGB", canvas_size, "white")
        ImageDraw.Draw(fallback).text((50, 50), f"Error: {str(e)}", fill="red")
        return fallback

# Define the Gradio interface for flowcharts
def gradio_interface_flowcharts():
    """
    Build and return the Gradio Blocks app for flowchart generation.
    """
    glass_theme = gr.themes.Glass(secondary_hue="blue")

    with gr.Blocks(glass_theme) as app:
        gr.Markdown("# Flowchart Generator with Gemini")

        with gr.Row():
            # Left column: prompt entry and the trigger button
            with gr.Column():
                gr.Markdown("### Input Section")
                prompt_box = gr.Textbox(
                    lines=2,
                    label="Input Prompt",
                    placeholder="Describe the flowchart process.",
                )
                generate_btn = gr.Button("Generate Flowchart")

            # Right column: rendered flowchart
            with gr.Column():
                gr.Markdown("### Output Section")
                flowchart_view = gr.Image(type="pil", label="Output Flowchart")

        # Clicking the button runs predict_flowchart on the prompt text
        generate_btn.click(
            predict_flowchart,
            inputs=[prompt_box],
            outputs=[flowchart_view],
        )

    return app

# Run the app
if __name__ == "__main__":
    # Build the Blocks UI and launch it only when this file is run as a script.
    demo = gradio_interface_flowcharts()
    demo.launch()