# draptic-demo / app.py
# matteomarjanovic's picture
# change way of loading image
# 871938d
import gradio as gr
import numpy as np
import random
import spaces #[uncomment to use ZeroGPU]
# from diffusers import DiffusionPipeline
from diffusers import FluxControlPipeline
from controlnet_aux import CannyDetector
from huggingface_hub import login
import torch
import subprocess
from groq import Groq
import base64
from io import BytesIO
import os
from PIL import Image
from google import genai
from google.genai import types
def clear_zerogpu_offload(offload_dir="/data-nvme/zerogpu-offload"):
    """Delete everything directly inside *offload_dir*, like ``rm -rf offload_dir/*``.

    Mirrors the shell command this replaces, without spawning a shell:
    entries whose names start with a dot are left alone (the shell ``*``
    glob skips them), a missing directory is not an error, failures on
    individual entries are ignored (``rm -f`` semantics), and symlinks are
    unlinked rather than followed.

    Args:
        offload_dir: Directory whose contents should be removed.
    """
    import glob
    import shutil

    pattern = os.path.join(glob.escape(offload_dir), "*")
    for entry in glob.glob(pattern):
        if os.path.isdir(entry) and not os.path.islink(entry):
            shutil.rmtree(entry, ignore_errors=True)
        else:
            try:
                os.remove(entry)
            except OSError:
                pass  # deliberate best-effort, like `rm -f`


# Authenticate to the Hugging Face Hub only when a token is configured:
# `login(token=None)` falls back to an interactive prompt, which cannot be
# answered in a headless deployment.
_hf_token = os.environ.get("HF_API_KEY")
if _hf_token:
    login(token=_hf_token)

# NOTE(review): this path is specific to ZeroGPU hosts — confirm it is still needed.
clear_zerogpu_offload()
# Load the FLUX image generator and the two LoRA adapters used by the app.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_repo_id = "black-forest-labs/FLUX.1-schnell"  # Replace with the model you would like to use
flat_lora_path = "matteomarjanovic/flatsketcher"
# NOTE(review): this Canny LoRA is published for FLUX.1-dev while the base
# model above is FLUX.1-schnell — confirm the pairing is intended.
canny_lora_path = "black-forest-labs/FLUX.1-Canny-dev-lora"
flat_weigths_file = "lora.safetensors"
canny_weigths_file = "flux1-canny-dev-lora.safetensors"

# Canny edge detector used to build the control image from the input photo.
processor = CannyDetector()

# The diffusers Flux documentation uses bfloat16; float16 requires clipping
# some text-encoder activations and produces different outputs. CPU stays
# in float32.
torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

pipe = FluxControlPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
pipe = pipe.to(device)
# Register both LoRAs as named adapters, then activate them together with
# the flat-sketch adapter weighted 0.8 and the Canny adapter 0.4.
pipe.load_lora_weights(flat_lora_path, weight_name=flat_weigths_file, adapter_name="flat")
pipe.load_lora_weights(canny_lora_path, weight_name=canny_weigths_file, adapter_name="canny")
pipe.set_adapters(["flat", "canny"], adapter_weights=[0.8, 0.4])

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
def encode_image(pil_image):
    """Return *pil_image* serialized and base64-encoded as an ASCII string.

    The image is written in its original file format when it has one
    (``pil_image.format``), otherwise as PNG. If saving in the original
    format fails (e.g. Pillow can read but not write that format, or the
    image mode is not supported by it), the image is re-encoded as PNG
    instead of raising.

    Args:
        pil_image: A PIL image (anything exposing ``.format`` and
            ``.save(fp, format=...)``).

    Returns:
        The base64 encoding of the serialized bytes, as ``str``.
    """
    fmt = pil_image.format or "PNG"
    buffered = BytesIO()
    try:
        pil_image.save(buffered, format=fmt)
    except (KeyError, OSError, ValueError):
        if fmt == "PNG":
            raise
        # Use a fresh buffer: the failed attempt may have written partial data.
        buffered = BytesIO()
        pil_image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
# @spaces.GPU #[uncomment to use ZeroGPU]
# def infer(
# prompt,
# progress=gr.Progress(track_tqdm=True),
# ):
# # seed = random.randint(0, MAX_SEED)
# # generator = torch.Generator().manual_seed(seed)
# image = pipe(
# prompt=prompt,
# guidance_scale=0.,
# num_inference_steps=4,
# width=1420,
# height=1080,
# max_sequence_length=256,
# ).images[0]
# return image
# Instruction prompt sent to Gemini together with the garment photo; the
# reply becomes the text prompt for the FLUX flat-sketch generator.
description_prompt = """
I want you to imagine what the technical flat sketch of the garment you see in the picture would look like, and describe it in rich detail, in one paragraph.
Don't add any additional comments.
Specify that the flat sketch is black and white (even if the original garment has a color) and that it doesn't include the person who wears the garment.
Clarify that it's not made on a paper sheet, but it's digitally made, so it has a plain white background, not paper.
Describe only the part that is visible in the picture (front or back of the garment, not both).
It should start with "The technical flat sketch of..."
The style of the result should look somewhat like the following example:
The technical flat sketch of the dress depicts a midi-length, off-the-shoulder design with a smocked bodice and short puff sleeves that have elasticized cuffs.
The elastic neckline sits straight across the chest, ensuring a secure fit.
The bodice transitions into a flowy, tiered skirt with three evenly spaced gathered panels, creating soft volume.
Elasticized areas are marked with textured lines, while the gathers and drape are indicated through subtle curved strokes, ensuring clarity in construction details.
The flat sketch does NOT include any person and it's only in black and white, being a technical drawing.
"""
@spaces.GPU  # ZeroGPU decorator
def generate_description_fn(
    image,
    progress=gr.Progress(track_tqdm=True),
):
    """Describe the garment in *image* with Gemini, then render its flat sketch.

    Steps:
      1. Gemini (``gemini-2.0-flash``) writes a flat-sketch description from
         ``description_prompt`` and the image.
      2. The description, suffixed with " In the style of FLTSKC", becomes
         the FLUX prompt.
      3. A Canny edge map of the input is used as the control image, and its
         size sets the output width and height.

    Args:
        image: The uploaded PIL image, or ``None`` if nothing was uploaded.
        progress: Gradio progress tracker; ``track_tqdm=True`` lets Gradio
            mirror tqdm progress bars in the UI.

    Returns:
        A ``(prompt, sketch)`` tuple: the text prompt used for generation and
        the generated image.

    Raises:
        gr.Error: If no image was provided or Gemini returned no text.
    """
    if image is None:
        raise gr.Error("Please upload an image of a garment first.")

    # NOTE(review): this network call runs while the GPU slot is held;
    # consider moving it out of the @spaces.GPU function to save GPU time.
    client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY"))
    response = client.models.generate_content(
        model="gemini-2.0-flash",
        contents=[description_prompt, image],
    )
    # `response.text` can be None (e.g. no text part in the reply); fail with
    # a readable message instead of a TypeError on concatenation.
    if not response.text:
        raise gr.Error("The description model returned no text; please try another image.")
    prompt = response.text + " In the style of FLTSKC"

    control_image = processor(
        image,
        low_threshold=50,
        high_threshold=200,
        detect_resolution=1024,
        image_resolution=1024,
    )
    width, height = control_image.size
    sketch = pipe(
        prompt=prompt,
        control_image=control_image,
        guidance_scale=3.,
        num_inference_steps=4,
        width=width,
        height=height,
        max_sequence_length=256,
    ).images[0]
    return prompt, sketch
# Custom CSS passed to gr.Blocks(css=css): centers and caps the width of the
# element with elem_id "col-container", sets the page background color, and
# colors elements carrying the "btn-primary" class (the Generate button).
css = """
#col-container {
margin: 0 auto;
max-width: 640px;
}
.gradio-container {
background-color: oklch(98% 0 0);
}
.btn-primary {
background-color: #422ad5;
outline-color: #422ad5;
}
"""
def load_image(image_path="hoodie.png"):
    """Load the example garment image shown in the input box on page load.

    Relative paths are resolved against the directory containing this file,
    so the demo works regardless of the current working directory.

    Args:
        image_path: Path to the image file; defaults to ``hoodie.png`` next
            to this module.

    Returns:
        A fully loaded PIL image whose underlying file handle is closed.
    """
    if not os.path.isabs(image_path):
        image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), image_path)
    # Image.open is lazy and keeps the file open; read the pixel data now and
    # let the context manager release the file handle.
    with Image.open(image_path) as default_img:
        default_img.load()
    return default_img
# Build the UI: the input image in the first column; the generate button,
# the generated sketch and the generated description in the second.
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Row():
        with gr.Column(elem_id="col-input-image"):
            input_image = gr.Image(type="pil", sources=["upload", "clipboard"])
        with gr.Column(elem_id="col-container"):
            generate_button = gr.Button(
                "Generate flat sketch",
                scale=0,
                variant="primary",
                elem_classes="btn btn-primary",
            )
            result = gr.Image(label="Result", show_label=False)
            # Heading plus a Markdown area that receives the generated prompt.
            gr.Markdown("## Description of the garment:")
            generated_prompt = gr.Markdown("")

    # Clicking the button runs the describe-then-render pipeline: the text
    # prompt goes to `generated_prompt` and the sketch to `result`.
    gr.on(
        triggers=[generate_button.click],
        fn=generate_description_fn,
        inputs=[input_image],
        outputs=[generated_prompt, result],
    )
    # Pre-fill the input with the example image when the page loads.
    demo.load(load_image, inputs=[], outputs=[input_image])
def _main():
    """Start the Gradio server for the demo."""
    demo.launch()


if __name__ == "__main__":
    _main()