File size: 3,007 Bytes
7c93490
 
0f981a7
 
7c93490
002ab21
7c93490
002ab21
0f981a7
002ab21
 
0f981a7
7c93490
002ab21
7c93490
0f981a7
7c93490
 
002ab21
7c93490
0f981a7
002ab21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f981a7
002ab21
 
0f981a7
7c93490
 
 
 
 
0f981a7
7c93490
 
0f981a7
7c93490
 
 
0f981a7
7c93490
002ab21
 
7c93490
0f981a7
 
7c93490
d64bd4a
7c93490
0f981a7
7c93490
 
 
 
 
 
 
 
 
 
0f981a7
002ab21
0f981a7
b96880f
0f981a7
002ab21
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import gradio as gr
import requests
import base64
from io import BytesIO
from PIL import Image
import numpy as np
import os
import cv2  # Import OpenCV

# Hugging Face access token, read from the HF_TOKEN environment variable
# (None when the variable is not set).
HUGGINGFACE_API_KEY = os.environ.get("HF_TOKEN")
# Hub repo id of the text-to-image model sent to the Inference API.
# SDXL-Turbo is published on the Hub as "stabilityai/sdxl-turbo"; the
# previous id "stabilityai/stable-diffusion-xl-turbo" does not resolve.
HF_MODEL = "stabilityai/sdxl-turbo"

# Contours whose area is below this fraction of the canvas area are treated
# as noise (stray dots, anti-aliasing specks) rather than drawn shapes.
_MIN_CONTOUR_AREA_FRACTION = 0.0005

# Seconds to wait for the Inference API. Generous because the request asks
# the API to "wait_for_model", which can block while a cold model loads.
_REQUEST_TIMEOUT_S = 120


def _to_rgb_array(image_data):
    """Normalize a canvas value to a contiguous (H, W, 3) uint8 RGB array.

    Accepts a NumPy array, a PIL image, a data-URL or bare base64 string, or a
    dict wrapping one of these. Returns None when there is nothing to process
    (None, blank string, zero-sized array). Transparent pixels are flattened
    onto white so strokes drawn on a transparent layer stay visible.

    Raises:
        TypeError: for an unsupported value type.
        ValueError: for an array whose shape is not an image.
    """
    if isinstance(image_data, dict):
        # NOTE(review): sketch/editor components may wrap the drawing in a
        # dict ("composite" in newer Gradio, "image" in older ones) — confirm
        # against the installed Gradio version.
        image_data = image_data.get("composite", image_data.get("image"))
    if image_data is None:
        return None
    if isinstance(image_data, str):
        if not image_data.strip():
            return None
        # Accept both "data:image/...;base64,<payload>" and bare base64.
        _, sep, payload = image_data.partition("base64,")
        decoded = base64.b64decode(payload if sep else image_data)
        image_data = Image.open(BytesIO(decoded))
    if hasattr(image_data, "convert"):  # PIL image
        # Keep alpha so transparency is flattened onto white below rather than
        # silently dropped (which can leave black strokes on black).
        image_data = np.array(image_data.convert("RGBA"))
    if not isinstance(image_data, np.ndarray):
        raise TypeError(f"unsupported canvas value of type {type(image_data).__name__}")
    if image_data.size == 0:
        return None

    array = image_data
    if array.ndim == 2:
        array = np.stack([array] * 3, axis=-1)
    elif array.ndim == 3 and array.shape[-1] == 1:
        array = np.repeat(array, 3, axis=-1)
    elif array.ndim == 3 and array.shape[-1] == 4:
        # Assumes 8-bit alpha (0-255); composite over a white background.
        alpha = array[..., 3:].astype(np.float32) / 255.0
        array = array[..., :3].astype(np.float32) * alpha + 255.0 * (1.0 - alpha)
    if array.ndim != 3 or array.shape[-1] != 3:
        raise ValueError(f"unsupported canvas array shape {image_data.shape}")
    if array.dtype != np.uint8:
        # OpenCV's contour functions require 8-bit single-channel input later.
        array = np.clip(np.rint(array), 0, 255).astype(np.uint8)
    return np.ascontiguousarray(array)


def _detect_shapes(rgb):
    """Return a rough shape label for each sizeable outer contour in *rgb*.

    Labels are "triangle", "quadrilateral", "pentagon" or "circle-like shape",
    chosen by the vertex count of the simplified contour. Contours with fewer
    than three vertices (e.g. straight strokes) and tiny specks are ignored.
    """
    gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
    # findContours traces white (255) regions. On a light canvas the strokes
    # are dark, so invert; otherwise the background itself becomes the single
    # outer contour and every drawing reads as one big "quadrilateral".
    mode = cv2.THRESH_BINARY_INV if gray.mean() > 127 else cv2.THRESH_BINARY
    _, thresh = cv2.threshold(gray, 128, 255, mode)
    # [-2] selects the contours in both OpenCV 3.x (3 return values) and
    # OpenCV 4.x (2 return values).
    contours = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

    min_area = _MIN_CONTOUR_AREA_FRACTION * gray.shape[0] * gray.shape[1]
    shapes = []
    for contour in contours:
        if cv2.contourArea(contour) < min_area:
            continue
        approx = cv2.approxPolyDP(contour, 0.02 * cv2.arcLength(contour, True), True)
        vertices = len(approx)
        if vertices == 3:
            shapes.append("triangle")
        elif vertices == 4:
            shapes.append("quadrilateral")  # Could be square, rectangle, etc.
        elif vertices == 5:
            shapes.append("pentagon")
        elif vertices > 5:
            shapes.append("circle-like shape")
    return shapes


def _build_prompt(shapes):
    """Turn detected shape labels into a text-to-image prompt.

    Duplicate labels are collapsed (first-seen order kept). When nothing was
    detected a generic subject is used instead of a dangling "A drawing of ".
    """
    unique = list(dict.fromkeys(shapes))
    if not unique:
        return "A drawing of a simple sketch"
    return "A drawing of " + ", ".join(unique)


def _query_model(prompt):
    """Send *prompt* to HF_MODEL via the Inference API; return a data URL.

    Returns an "API Error: ..." message when the response is not an image.

    Raises:
        requests.exceptions.RequestException: on network failure, timeout, or
            an HTTP error status.
    """
    # Only send credentials when a token is configured; otherwise the header
    # would literally read "Bearer None".
    headers = {"Authorization": f"Bearer {HUGGINGFACE_API_KEY}"} if HUGGINGFACE_API_KEY else {}
    payload = {
        "inputs": prompt,
        "options": {"wait_for_model": True},
    }
    response = requests.post(
        f"https://api-inference.huggingface.co/models/{HF_MODEL}",
        headers=headers,
        json=payload,
        timeout=_REQUEST_TIMEOUT_S,
    )
    response.raise_for_status()

    # Label the data URL with the format actually returned (often JPEG)
    # instead of assuming PNG.
    mime = response.headers.get("Content-Type", "image/png").split(";")[0].strip()
    if not mime.startswith("image/"):
        return f"API Error: expected an image but received {mime!r}"
    encoded = base64.b64encode(response.content).decode("utf-8")
    return f"data:{mime};base64,{encoded}"


def generate_image(image_data):
    """Generate an image from a canvas drawing via the Hugging Face Inference API.

    The drawing is reduced to a text prompt listing the rough shapes detected
    in it; only that prompt (not the drawing itself) is sent to HF_MODEL.

    Args:
        image_data: The canvas value: a NumPy array (as produced by
            ``gr.Image(type="numpy")``), a PIL image, a data-URL or bare base64
            string, or a dict wrapping one of these under "composite"/"image".

    Returns:
        On success, a "data:<mime>;base64,..." URL of the generated image.
        Otherwise a human-readable message:
        - "Please draw something on the canvas." for a missing, empty or
          single-colour canvas;
        - "API Error: ..." for network/HTTP failures or a non-image response;
        - "An error occurred: ..." for anything else.
    """
    try:
        rgb = _to_rgb_array(image_data)
        # A single-colour canvas (e.g. just after it is cleared) contains no
        # strokes, so skip the API call entirely.
        if rgb is None or rgb.min() == rgb.max():
            return "Please draw something on the canvas."

        prompt = _build_prompt(_detect_shapes(rgb))
        # TODO: richer analysis (object recognition, stroke analysis, ...).
        return _query_model(prompt)

    except requests.exceptions.RequestException as e:
        return f"API Error: {e}"
    except Exception as e:
        return f"An error occurred: {e}"


with gr.Blocks() as demo:
    # NOTE(review): a plain gr.Image(type="numpy") is an upload widget, not a
    # drawing surface, despite the "Draw Here" label; depending on the Gradio
    # version a sketch component (e.g. gr.Sketchpad) may be intended — confirm.
    canvas = gr.Image(type="numpy", label="Draw Here")
    output_image = gr.Image(label="Generated Image")

    def throttle(func, interval):
        """Wrap *func* so it runs at most once per *interval* seconds.

        Calls arriving sooner are skipped and return the previous result, so
        the bound output component keeps its last image instead of being
        cleared. Skipped calls are dropped, not deferred, so the very last
        change in a fast burst may not be processed.
        """
        # Local imports keep the module-level import block unchanged. The
        # parameter is named "interval" because the original name "time" was
        # rebound to the time module by "import time", which made every
        # comparison raise TypeError.
        import functools
        import time

        last_call = float("-inf")  # guarantees the first call goes through
        last_result = None

        @functools.wraps(func)
        def throttled(*args, **kwargs):
            nonlocal last_call, last_result
            now = time.monotonic()  # immune to wall-clock adjustments
            if now - last_call > interval:
                last_call = now
                last_result = func(*args, **kwargs)
            return last_result

        return throttled

    throttled_generate = throttle(generate_image, 0.2)  # Adjust throttling

    canvas.change(throttled_generate, inputs=canvas, outputs=output_image)

demo.launch()