# Sketch-2-IMG / app.py
# Author: Donmill — commit d64bd4a ("fix", verified)
import base64
import functools
import os
import time
from collections import Counter
from io import BytesIO

import cv2  # OpenCV: contour-based shape detection
import gradio as gr
import numpy as np
import requests
from PIL import Image
# Hugging Face access token, read from the HF_TOKEN environment variable so it
# never lives in source control. An empty value is treated as "not set".
HUGGINGFACE_API_KEY = os.environ.get("HF_TOKEN") or None
# Text-to-image model queried through the Inference API. The SDXL-Turbo repo
# id is "stabilityai/sdxl-turbo"; set the HF_MODEL environment variable to
# use a different model without editing the code.
HF_MODEL = os.environ.get("HF_MODEL") or "stabilityai/sdxl-turbo"
# Vertex count (after polygon approximation) -> shape name. Counts above the
# largest key are reported as round shapes; fewer than 3 (e.g. lone strokes)
# are ignored.
_SHAPE_NAMES = {3: "triangle", 4: "quadrilateral", 5: "pentagon"}
_ROUND_SHAPE = "circle-like shape"


def _to_rgb_array(image_data):
    """Normalise a sketch input to an RGB numpy array.

    Accepts a base64 data URL string (``"data:image/png;base64,..."``), a
    ``PIL.Image.Image``, or an array-like image (grayscale, RGB or RGBA) such
    as the one ``gr.Image(type="numpy")`` delivers.

    Raises:
        ValueError: a string input that contains no ``base64,`` payload.
    """
    if isinstance(image_data, str):
        _, _, payload = image_data.partition("base64,")
        if not payload:
            raise ValueError("expected a 'data:image/...;base64,' URL")
        image = Image.open(BytesIO(base64.b64decode(payload)))
    elif isinstance(image_data, Image.Image):
        image = image_data
    else:
        image = Image.fromarray(np.asarray(image_data))
    if "A" in image.getbands() or "transparency" in image.info:
        # Flatten transparency onto white so a transparent background reads
        # as blank paper instead of turning black in the RGB conversion.
        rgba = image.convert("RGBA")
        paper = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
        image = Image.alpha_composite(paper, rgba)
    return np.array(image.convert("RGB"))


def _detect_shapes(np_image, min_area=100.0):
    """Return one shape label per closed figure found in an RGB sketch.

    Assumes dark strokes on a light background (TODO confirm for uploaded
    images). Contours enclosing less than *min_area* square pixels are
    treated as noise and skipped.
    """
    gray = cv2.cvtColor(np_image, cv2.COLOR_RGB2GRAY)
    # THRESH_BINARY_INV makes the dark strokes the white foreground that
    # findContours traces. With plain THRESH_BINARY the light background is
    # the foreground, and RETR_EXTERNAL then typically yields a single
    # canvas-sized rectangle regardless of what was drawn.
    _, thresh = cv2.threshold(gray, 128, 255, cv2.THRESH_BINARY_INV)
    # OpenCV 3 returns (image, contours, hierarchy) while OpenCV 4 returns
    # (contours, hierarchy); index -2 selects the contours in both.
    contours = cv2.findContours(
        thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )[-2]
    shapes = []
    for contour in contours:
        if cv2.contourArea(contour) < min_area:
            continue
        epsilon = 0.02 * cv2.arcLength(contour, True)
        num_vertices = len(cv2.approxPolyDP(contour, epsilon, True))
        if num_vertices in _SHAPE_NAMES:
            shapes.append(_SHAPE_NAMES[num_vertices])
        elif num_vertices > max(_SHAPE_NAMES):
            shapes.append(_ROUND_SHAPE)
    return shapes


def _build_prompt(shapes):
    """Turn shape labels into a prompt like 'A drawing of 2 triangles and a pentagon'.

    Duplicate labels are counted instead of repeated, which keeps the prompt
    short when the sketch contains many shapes.
    """
    if not shapes:
        return "A drawing"
    # Every label starts with a consonant, so the article "a" is always right.
    parts = [
        f"a {name}" if count == 1 else f"{count} {name}s"
        for name, count in Counter(shapes).items()
    ]
    if len(parts) == 1:
        return f"A drawing of {parts[0]}"
    return f"A drawing of {', '.join(parts[:-1])} and {parts[-1]}"


def _request_image(prompt, timeout=120):
    """POST *prompt* to the Hugging Face Inference API for ``HF_MODEL``.

    Returns:
        ``(image_bytes, mime_type)``.

    Raises:
        requests.exceptions.RequestException: network failure, timeout, or
            an HTTP error status.
        ValueError: a 2xx response whose body is not an image.
    """
    headers = {}
    if HUGGINGFACE_API_KEY:
        # Without a token, omit the header rather than sending "Bearer None".
        headers["Authorization"] = f"Bearer {HUGGINGFACE_API_KEY}"
    payload = {"inputs": prompt, "options": {"wait_for_model": True}}
    # The timeout is generous because wait_for_model may block while the model
    # cold-starts; tune as needed. Without one, a stalled request hangs forever.
    response = requests.post(
        f"https://api-inference.huggingface.co/models/{HF_MODEL}",
        headers=headers,
        json=payload,
        timeout=timeout,
    )
    response.raise_for_status()
    content_type = response.headers.get("Content-Type") or "image/png"
    mime_type = content_type.split(";")[0].strip()
    if not mime_type.startswith("image/"):
        raise ValueError(
            f"API returned non-image content ({mime_type}): {response.text[:200]}"
        )
    return response.content, mime_type


def generate_image(image_data):
    """Generate an image from a sketch via the Hugging Face Inference API.

    Simple closed shapes in the sketch are detected with OpenCV and turned
    into a short text prompt, which is sent to ``HF_MODEL``.

    Args:
        image_data: The sketch. This is a numpy array as produced by
            ``gr.Image(type="numpy")``, a ``PIL.Image.Image``, or a
            ``data:image/...;base64,`` URL string.

    Returns:
        On success, a ``data:<mime>;base64,...`` URL of the generated image.
        On failure, a human-readable message string; errors are reported
        through the return value rather than raised.
    """
    if image_data is None:
        return "Please draw something on the canvas."
    try:
        np_image = _to_rgb_array(image_data)
        prompt = _build_prompt(_detect_shapes(np_image))
        image_bytes, mime_type = _request_image(prompt)
    except requests.exceptions.RequestException as e:
        return f"API Error: {e}"
    except Exception as e:  # UI boundary: report every failure as text
        return f"An error occurred: {e}"
    encoded = base64.b64encode(image_bytes).decode("utf-8")
    return f"data:{mime_type};base64,{encoded}"
def throttle(func, min_interval):
    """Return a wrapper that runs *func* at most once per *min_interval* seconds.

    A call that arrives less than *min_interval* seconds after the previous
    executed call is skipped. The wrapper then returns that previous call's
    result, so a Gradio output keeps its last value instead of being cleared
    by ``None``.

    Note: this is a throttle, not a debounce. The most recent input is
    dropped if it falls inside the window.
    """
    last_call = float("-inf")  # guarantees the very first call runs
    last_result = None

    @functools.wraps(func)
    def throttled(*args, **kwargs):
        nonlocal last_call, last_result
        # monotonic() is unaffected by wall-clock adjustments.
        now = time.monotonic()
        if now - last_call >= min_interval:
            last_call = now
            last_result = func(*args, **kwargs)
        return last_result

    return throttled


def _generate_for_display(image):
    """Gradio callback: run generate_image and adapt its result for gr.Image.

    generate_image returns either a ``data:image/...;base64,`` URL or a plain
    text message. gr.Image cannot render text, so messages are raised as
    ``gr.Error`` for Gradio to display, and data URLs are decoded into a PIL
    image.
    """
    if image is None:
        # Canvas was cleared: clear the output instead of reporting an error.
        return None
    result = generate_image(image)
    if not result.startswith("data:image"):
        raise gr.Error(result)
    _, _, payload = result.partition("base64,")
    return Image.open(BytesIO(base64.b64decode(payload)))


with gr.Blocks() as demo:
    # NOTE(review): gr.Image is primarily an upload widget. Depending on the
    # installed Gradio version, freehand drawing may need gr.Sketchpad /
    # gr.ImageEditor (Gradio 4) or tool="sketch" (Gradio 3). Confirm this
    # against the Space's requirements.
    canvas = gr.Image(type="numpy", label="Draw Here")
    output_image = gr.Image(label="Generated Image")
    throttled_generate = throttle(_generate_for_display, 0.2)  # seconds; adjust
    canvas.change(throttled_generate, inputs=canvas, outputs=output_image)

if __name__ == "__main__":
    demo.launch()