FloorAI / app.py
LuyangZ's picture
Update app.py
8211204 verified
raw
history blame
4.73 kB
from random import randrange

import gradio
import cv2
from PIL import Image
import numpy as np
# NOTE: keep `spaces` (Hugging Face ZeroGPU) imported ahead of diffusers/torch;
# ZeroGPU is reportedly sensitive to import order — confirm before reordering.
import spaces
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from diffusers.utils import load_image
import torch
import accelerate
import transformers
# Run the transformers cache-migration helper before any model is loaded.
transformers.utils.move_cache()

# First CUDA device when one is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Stable Diffusion v1.5 backbone, conditioned by the FloorAI ControlNet checkpoint.
base_model_id = "runwayml/stable-diffusion-v1-5"
model_id = "LuyangZ/FloorAI"

# All weights stay in float32; no half precision, xformers attention,
# CPU offload or attention slicing is enabled.
controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float32)
controlnet.to(device)

pipeline = StableDiffusionControlNetPipeline.from_pretrained(
    base_model_id,
    controlnet=controlnet,
    torch_dtype=torch.float32,
)
# Replace the default scheduler with UniPC, built from the existing scheduler config.
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline = pipeline.to(device)
def expand2square(ol_img, background_color, pad_ratio=0.2):
    """Centre *ol_img* on a square canvas with a margin around it.

    The canvas side is the image's longer side plus ``int(longer * pad_ratio)``.
    The image is pasted centred on both axes. The longer side therefore gets
    roughly ``pad_ratio / 2`` of its length as margin on each end, and the
    shorter side is padded out to fill the square. Square inputs are padded
    the same way as non-square ones.

    Args:
        ol_img: PIL image to pad.
        background_color: fill colour for the canvas, in a form accepted by
            ``Image.new`` for ``ol_img.mode`` (e.g. ``(255, 255, 255)`` for RGB).
        pad_ratio: margin added to the longer side, as a fraction of that
            side. The default 0.2 matches the previous hard-coded value.

    Returns:
        A new square PIL image in the same mode as ``ol_img``.
    """
    width, height = ol_img.size
    longer = max(width, height)
    side = longer + int(longer * pad_ratio)
    canvas = Image.new(ol_img.mode, (side, side), background_color)
    # Centring on both axes: along the longer side the offset is pad // 2,
    # along the shorter side it is (side - shorter) // 2.
    canvas.paste(ol_img, ((side - width) // 2, (side - height) // 2))
    return canvas
def clean_img(image, mask):
    """Force to white every pixel of *image* that is near-white in *mask*.

    Args:
        image: numpy image array, modified in place. It is presumably
            (H, W, 3) uint8 RGB, as produced by ``np.array`` on a PIL image
            (confirm).
        mask: BGR image with the same height and width as *image*.

    Returns:
        The cleaned image as an RGB PIL image.
    """
    grey = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    # Inverse binary threshold: grey > 250 becomes 0, everything else 255.
    _, inverted = cv2.threshold(grey, 250, 255, cv2.THRESH_BINARY_INV)
    # Pixels that were near-white in the mask.
    background = inverted < 250
    image[background] = (255, 255, 255)
    return Image.fromarray(image).convert('RGB')
def _prepare_outline(outline, side):
    """Turn a raw outline image into the square ControlNet conditioning image.

    The outline is cropped to the bounding box of its drawn pixels (grey
    level 240 or below). The crop is padded to a square on a white background
    with ``expand2square`` and resized to ``side`` x ``side``.

    Args:
        outline: RGB image array, as passed to ``floorplan_generation``.
        side: edge length in pixels of the returned square image.

    Returns:
        An RGB PIL image, or ``None`` when no pixel is at or below the
        threshold (there is nothing to crop to).
    """
    bgr = cv2.cvtColor(outline, cv2.COLOR_RGB2BGR)
    gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    # Inverse threshold: pixels darker than or equal to 240 become 255 (drawing).
    thresh = cv2.threshold(gray, 240, 255, cv2.THRESH_BINARY_INV)[1]
    x, y, w, h = cv2.boundingRect(thresh)
    if w == 0 or h == 0:
        # All-background image: the crop would be empty and cv2 would fail on it.
        return None
    crop = cv2.cvtColor(bgr[y:y + h, x:x + w], cv2.COLOR_BGR2RGB)
    squared = expand2square(Image.fromarray(crop).convert('RGB'), (255, 255, 255))
    return squared.resize((side, side))


# NOTE(review): this decorator line was mangled in the scraped source.
# `@spaces.GPU` is assumed from the `spaces` import; confirm.
@spaces.GPU
def floorplan_generation(outline, num_of_rooms):
    """Generate five candidate floor plans for an outline and a room count.

    The outline is prepared as a square 512x512 conditioning image and used
    with the prompt ``"floor plan, <n> rooms"``. Five images are generated,
    each from a distinct random seed in ``range(500)``. Each result is passed
    through ``clean_img``, which forces to white every pixel that is near-white
    (grey level above 250) in the prepared outline.

    Args:
        outline: outline image from the Gradio ``Image`` input. It is
            presumably an RGB uint8 numpy array (Gradio's default image type);
            TODO confirm if the input component changes. It is ``None`` when
            nothing was uploaded.
        num_of_rooms: room count as typed in the text box. It is converted to
            ``str`` and stripped of surrounding whitespace before being placed
            in the prompt.

    Returns:
        A tuple of five RGB PIL images, one per output component.

    Raises:
        gradio.Error: if no outline was uploaded, or the outline has no drawn
            (non-white) pixels.
    """
    from random import sample

    side = 512
    if outline is None:
        raise gradio.Error("Please upload a floor plan outline image.")
    n_outline = _prepare_outline(outline, side)
    if n_outline is None:
        raise gradio.Error("The outline image is blank; please upload an outline with visible lines.")

    validation_prompt = f"floor plan, {str(num_of_rooms).strip()} rooms"
    # The clean-up mask depends only on the conditioning image, so build it once.
    mask = cv2.cvtColor(np.array(n_outline), cv2.COLOR_RGB2BGR)

    image_lst = []
    # Distinct seeds guarantee the five outputs never share a seed.
    for seed in sample(range(500), 5):
        generator = torch.Generator(device=device).manual_seed(seed)
        image = pipeline(
            validation_prompt,
            n_outline,
            # Pin the output size so it always matches the mask.
            height=side,
            width=side,
            num_inference_steps=20,
            generator=generator,
        ).images[0]
        image_lst.append(clean_img(np.array(image), mask))
    return tuple(image_lst)
# UI: two inputs (outline image, room count) feed floorplan_generation. Its
# five returned images fill the five output image slots, in order.
gradio_interface = gradio.Interface(
    fn=floorplan_generation,
    inputs=[
        gradio.Image(label="Floor Plan Outline, Entrance"),
        gradio.Textbox(type="text", label="number of rooms", placeholder="number of rooms"),
    ],
    outputs=[gradio.Image(label=f"Generated Floor Plan {slot}") for slot in range(1, 6)],
    title="FloorAI",
)
# Request queue capped at 10 waiting jobs; launching starts the web server.
gradio_interface.queue(max_size=10, status_update_rate="auto")
gradio_interface.launch(share=True, show_api=True, show_error=True)