Spaces:
Running
Running
import gradio as gr | |
import requests | |
import base64 | |
from io import BytesIO | |
from PIL import Image | |
import numpy as np | |
import os | |
import cv2 # Import OpenCV | |
# Hugging Face access token, read from the environment so it is never
# committed to source. This is None when HF_TOKEN is unset.
HUGGINGFACE_API_KEY = os.environ.get("HF_TOKEN")
# Hub repository id of the text-to-image model queried through the
# Inference API (the SDXL-Turbo repo is published as "stabilityai/sdxl-turbo").
HF_MODEL = "stabilityai/sdxl-turbo"
# Seconds to wait for the Inference API. This is generous because
# "wait_for_model" may keep the request open while a cold model loads.
_REQUEST_TIMEOUT_S = 120

# Label for each polygon vertex count produced by cv2.approxPolyDP.
# Counts above 5 fall back to _ROUND_SHAPE_LABEL; counts below 3 are ignored.
_SHAPE_BY_VERTICES = {
    3: "triangle",
    4: "quadrilateral",  # Could be square, rectangle, etc.
    5: "pentagon",
}
_ROUND_SHAPE_LABEL = "circle-like shape"


def _to_rgb_array(image_data):
    """Convert the canvas value into an RGB NumPy array.

    Accepts a "data:<mime>;base64,<payload>" string, a NumPy array (what
    gr.Image(type="numpy") delivers) or a PIL image.

    Raises:
        ValueError: if a string input does not contain a base64 payload.
    """
    if isinstance(image_data, str):
        _, sep, payload = image_data.partition("base64,")
        if not sep:
            raise ValueError("expected a base64 data URL")
        image = Image.open(BytesIO(base64.b64decode(payload)))
    elif isinstance(image_data, np.ndarray):
        image = Image.fromarray(image_data)
    else:
        # Presumably already a PIL image; anything without .convert() fails
        # below and is reported by the caller's error handler.
        image = image_data
    return np.array(image.convert("RGB"))


def _describe_shapes(rgb):
    """Return a shape label for each outer contour found in *rgb*."""
    gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
    # Treat the pixels that differ from the background as foreground.
    # On a light canvas the strokes are the dark pixels, so the threshold is
    # inverted. Without this, a white canvas typically yields only the contour
    # hugging the image border, and every drawing reads as "quadrilateral".
    # Assumes the background covers most of the canvas.
    if gray.mean() > 127:
        mode = cv2.THRESH_BINARY_INV
    else:
        mode = cv2.THRESH_BINARY
    _, mask = cv2.threshold(gray, 128, 255, mode)
    # [-2] selects the contour list under both the OpenCV 3 (3-tuple) and
    # OpenCV 4 (2-tuple) return conventions.
    contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    labels = []
    for contour in contours:
        epsilon = 0.02 * cv2.arcLength(contour, True)
        vertices = len(cv2.approxPolyDP(contour, epsilon, True))
        label = _SHAPE_BY_VERTICES.get(vertices)
        if label is None and vertices > 5:
            label = _ROUND_SHAPE_LABEL
        if label is not None:
            labels.append(label)
    return labels


def _query_model(prompt):
    """POST *prompt* to the HF Inference API for HF_MODEL.

    Returns:
        A tuple (image_bytes, mime_type). The MIME type comes from the
        response's image/* Content-Type, falling back to "image/png".

    Raises:
        requests.exceptions.RequestException: on network failure, timeout or
        a non-2xx status.
    """
    headers = {}
    if HUGGINGFACE_API_KEY:
        # Omit the header rather than sending the literal "Bearer None".
        headers["Authorization"] = f"Bearer {HUGGINGFACE_API_KEY}"
    payload = {"inputs": prompt, "options": {"wait_for_model": True}}
    response = requests.post(
        f"https://api-inference.huggingface.co/models/{HF_MODEL}",
        headers=headers,
        json=payload,
        timeout=_REQUEST_TIMEOUT_S,
    )
    response.raise_for_status()
    content_type = response.headers.get("Content-Type", "")
    if content_type.startswith("image/"):
        mime = content_type.split(";")[0].strip()
    else:
        mime = "image/png"
    return response.content, mime


def generate_image(image_data):
    """Generate an image from a canvas drawing via the HF Inference API.

    The drawing is reduced to a text prompt ("A drawing of triangle, ...")
    from the outer contours of its strokes. That prompt is then sent to
    HF_MODEL.

    Args:
        image_data: The canvas value. This is a base64 data URL string, a
            NumPy image array or a PIL image; None or "" means an empty
            canvas.

    Returns:
        On success, a "data:<mime>;base64,..." URL of the generated image.
        Otherwise one of these human-readable message strings:
        "Please draw something on the canvas.", "API Error: ..." or
        "An error occurred: ...".
    """
    if image_data is None or (isinstance(image_data, str) and not image_data):
        return "Please draw something on the canvas."
    try:
        shapes = _describe_shapes(_to_rgb_array(image_data))
        prompt = ("A drawing of " + ", ".join(shapes)) if shapes else "A drawing"
        image_bytes, mime = _query_model(prompt)
    except requests.exceptions.RequestException as e:
        return f"API Error: {e}"
    except Exception as e:
        # UI-handler boundary: report any decode/vision failure to the user
        # instead of crashing the event.
        return f"An error occurred: {e}"
    encoded = base64.b64encode(image_bytes).decode("utf-8")
    return f"data:{mime};base64,{encoded}"
def throttle(func, interval):
    """Wrap *func* so it executes at most once per *interval* seconds.

    A call arriving sooner than *interval* after the last executed call is
    skipped and returns the most recent result, not None. A Gradio output
    bound to the wrapper therefore keeps its previous value instead of being
    cleared. A skipped final call is not replayed later, so the output may
    lag the latest input.
    """
    # Imported locally. The original signature named a parameter "time",
    # which this import rebound and broke the interval comparison.
    import functools
    import time

    last_call = float("-inf")  # guarantees the first call always executes
    last_result = None

    @functools.wraps(func)
    def throttled(*args, **kwargs):
        nonlocal last_call, last_result
        # monotonic() is immune to wall-clock adjustments, unlike time.time().
        now = time.monotonic()
        if now - last_call > interval:
            last_call = now
            last_result = func(*args, **kwargs)
        return last_result

    return throttled


with gr.Blocks() as demo:
    # NOTE(review): newer Gradio releases provide freehand drawing through
    # gr.Sketchpad / gr.ImageEditor; confirm gr.Image supports drawing in the
    # installed version.
    canvas = gr.Image(type="numpy", label="Draw Here")
    output_image = gr.Image(label="Generated Image")
    throttled_generate = throttle(generate_image, 0.2)  # Adjust throttling
    # Event listeners must be registered inside the Blocks context.
    canvas.change(throttled_generate, inputs=canvas, outputs=output_image)

if __name__ == "__main__":
    demo.launch()