import gradio as gr
import numpy as np
import spaces  # needed for the @spaces.GPU decorator when running on ZeroGPU
from diffusers import FluxControlPipeline
from controlnet_aux import CannyDetector
from huggingface_hub import login
import torch
import subprocess
import base64
from io import BytesIO
import os
from PIL import Image
from google import genai
# Authenticate with the Hub and clear any stale ZeroGPU offload cache.
login(token=os.environ.get("HF_API_KEY"))
subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
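# Required Space secrets (read from the environment here and in the Gemini call below):
#   HF_API_KEY     - Hugging Face token passed to login(); some FLUX repos are gated
#   GEMINI_API_KEY - Google AI Studio key used by the genai.Client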
# Load the FLUX image generator and the Canny edge detector.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_repo_id = "black-forest-labs/FLUX.1-schnell"  # Replace with the model you would like to use
flat_lora_path = "matteomarjanovic/flatsketcher"
canny_lora_path = "black-forest-labs/FLUX.1-Canny-dev-lora"
flat_weights_file = "lora.safetensors"
canny_weights_file = "flux1-canny-dev-lora.safetensors"

processor = CannyDetector()
if torch.cuda.is_available():
    torch_dtype = torch.float16
else:
    torch_dtype = torch.float32

pipe = FluxControlPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
pipe = pipe.to(device)
pipe.load_lora_weights(flat_lora_path, weight_name=flat_weights_file, adapter_name="flat")
pipe.load_lora_weights(canny_lora_path, weight_name=canny_weights_file, adapter_name="canny")
pipe.set_adapters(["flat", "canny"], adapter_weights=[0.8, 0.4])
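# Both adapters are active at once: "flat" pulls the output toward the flat-sketch
# style of the flatsketcher LoRA, while "canny" adds edge-conditioned structure
# control. The 0.8/0.4 split is this app's choice; raising the canny weight should
# trade stylization for tighter structural fidelity.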
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
def encode_image(pil_image):
    """Base64-encode a PIL image in its own format, falling back to PNG."""
    buffered = BytesIO()
    pil_image.save(buffered, format=pil_image.format or "PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
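# Hypothetical usage of the helper above (it is not called anywhere else in this
# file, since the Gemini client accepts PIL images directly):
#   b64 = encode_image(Image.open("hoodie.png"))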
description_prompt = """
I want you to imagine what the technical flat sketch of the garment you see in the picture would look like, and describe it in rich detail, in one paragraph.
Don't add any additional comment.
Specify that the flat sketch is black and white (even if the original garment has a color) and that it doesn't include the person wearing the garment.
Clarify that it's not made on a paper sheet, but digitally made, so it has a plain white background, not paper.
Describe only the part that is visible in the picture (front or back of the garment, not both).
It should start with "The technical flat sketch of..."
The style of the result should look somewhat like the following example:

The technical flat sketch of the dress depicts a midi-length, off-the-shoulder design with a smocked bodice and short puff sleeves that have elasticized cuffs.
The elastic neckline sits straight across the chest, ensuring a secure fit.
The bodice transitions into a flowy, tiered skirt with three evenly spaced gathered panels, creating soft volume.
Elasticized areas are marked with textured lines, while the gathers and drape are indicated through subtle curved strokes, ensuring clarity in construction details.
The flat sketch does NOT include any person and is only in black and white, being a technical drawing.
"""
# @spaces.GPU  # [uncomment to use ZeroGPU]
def generate_description_fn(
    image,
    progress=gr.Progress(track_tqdm=True),
):
    # Ask Gemini to describe the garment as a technical flat sketch.
    client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY"))
    response = client.models.generate_content(
        model="gemini-2.0-flash",
        contents=[description_prompt, image],
    )
    prompt = response.text + " In the style of FLTSKC"  # append the flatsketcher LoRA trigger token
    # Extract a Canny edge map from the photo to condition generation on the garment's outline.
    control_image = processor(
        image,
        low_threshold=50,
        high_threshold=200,
        detect_resolution=1024,
        image_resolution=1024,
    )
    width, height = control_image.size
    # Generate the flat sketch, guided by both the Gemini description and the edge map.
    image = pipe(
        prompt=prompt,
        control_image=control_image,
        guidance_scale=3.,
        num_inference_steps=4,
        width=width,
        height=height,
        max_sequence_length=256,
    ).images[0]
    return prompt, image
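# Hypothetical standalone call outside the Gradio UI, assuming the bundled
# "hoodie.png" demo image (loaded as the default input below):
#   description, sketch = generate_description_fn(Image.open("hoodie.png"))
#   sketch.save("flat_sketch.png")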
css = """ | |
#col-container { | |
margin: 0 auto; | |
max-width: 640px; | |
} | |
.gradio-container { | |
background-color: oklch(98% 0 0); | |
} | |
.btn-primary { | |
background-color: #422ad5; | |
outline-color: #422ad5; | |
} | |
""" | |
def load_image():
    """Load the bundled example image shown when the demo starts."""
    return Image.open("hoodie.png")
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    # gr.Markdown("# Draptic: from garment image to technical flat sketch")
    with gr.Row():
        with gr.Column(elem_id="col-input-image"):
            # gr.Markdown(" ## Drop your image here")
            input_image = gr.Image(type="pil", sources=["upload", "clipboard"])
        with gr.Column(elem_id="col-container"):
            generate_button = gr.Button("Generate flat sketch", scale=0, variant="primary", elem_classes="btn btn-primary")
            result = gr.Image(label="Result", show_label=False)
            gr.Markdown("## Description of the garment:")
            generated_prompt = gr.Markdown("")
    # Wire the button: one click produces both the Gemini description and the sketch.
    gr.on(
        triggers=[generate_button.click],
        fn=generate_description_fn,
        inputs=[
            input_image,
        ],
        outputs=[generated_prompt, result],
    )

    demo.load(load_image, inputs=[], outputs=[input_image])
if __name__ == "__main__":
    demo.launch()