import gradio
import cv2
from PIL import Image
import numpy as np
import spaces
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from diffusers.utils import load_image
import torch
import accelerate
import transformers
from random import randrange
from transformers.utils.hub import move_cache

# Migrate any legacy transformers cache layout before the model downloads start.
move_cache()
# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

base_model_id = "runwayml/stable-diffusion-v1-5"
model_id = "LuyangZ/FloorAI"
# model_id = "LuyangZ/controlnet_Neufert4_64_100"

# Earlier dtype experiments, kept for reference:
# controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16)
# controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype="auto")
# controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float32, force_download=True)
controlnet = ControlNetModel.from_pretrained(model_id, force_download=True)
controlnet.to(device)
torch.cuda.empty_cache()
# Earlier dtype experiments, kept for reference:
# pipeline = StableDiffusionControlNetPipeline.from_pretrained(base_model_id, controlnet=controlnet, torch_dtype=torch.float32, force_download=True)
# pipeline = StableDiffusionControlNetPipeline.from_pretrained(base_model_id, controlnet=controlnet, torch_dtype="auto")
# pipeline = StableDiffusionControlNetPipeline.from_pretrained(base_model_id, controlnet=controlnet, torch_dtype=torch.float16)
pipeline = StableDiffusionControlNetPipeline.from_pretrained(base_model_id, controlnet=controlnet, force_download=True)

# Disable the safety checker: floor plans are benign, and this avoids
# false-positive black outputs and saves memory.
pipeline.safety_checker = None
pipeline.requires_safety_checker = False

# UniPC gives good results in few inference steps (20 are used below).
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)

# Optional memory optimizations, currently disabled:
# pipeline.enable_xformers_memory_efficient_attention()
# pipeline.enable_model_cpu_offload()
# pipeline.enable_attention_slicing()
pipeline = pipeline.to(device)
torch.cuda.empty_cache()
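# Hypothetical smoke test (not part of the Space): a single 20-step run on a
# blank white conditioning image should complete without raising:
#     _ = pipeline("floor plan, 3 rooms",
#                  Image.new("RGB", (512, 512), "white"),
#                  num_inference_steps=20).images[0]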
def expand2square(ol_img, background_color):
    """Pad a PIL image with 20% background margin so it becomes square."""
    width, height = ol_img.size
    if width == height:
        pad = int(width * 0.2)
        width_new = width + pad
        halfpad = int(pad / 2)
        ol_result = Image.new(ol_img.mode, (width_new, width_new), background_color)
        ol_result.paste(ol_img, (halfpad, halfpad))
        return ol_result
    elif width > height:
        pad = int(width * 0.2)
        width_new = width + pad
        halfpad = int(pad / 2)
        ol_result = Image.new(ol_img.mode, (width_new, width_new), background_color)
        # Half the padding on each side horizontally; centered vertically.
        ol_result.paste(ol_img, (halfpad, (width_new - height) // 2))
        return ol_result
    else:
        pad = int(height * 0.2)
        height_new = height + pad
        halfpad = int(pad / 2)
        ol_result = Image.new(ol_img.mode, (height_new, height_new), background_color)
        # Half the padding on each side vertically; centered horizontally.
        ol_result.paste(ol_img, ((height_new - width) // 2, halfpad))
        return ol_result
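# Example (hypothetical sizes): a 400x300 outline gets pad = 0.2 * 400 = 80
# and is centered on a 480x480 white canvas:
#     expand2square(Image.new("RGB", (400, 300), "white"), (255, 255, 255)).size == (480, 480)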
def clean_img(image, mask):
    """White out every pixel of the generated plan that falls outside the outline."""
    mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    mask = cv2.threshold(mask, 250, 255, cv2.THRESH_BINARY_INV)[1]
    # After the inverted threshold, near-white (background) pixels are 0 and
    # everything else is 255; reset the background to pure white.
    image[mask < 250] = (255, 255, 255)
    image = Image.fromarray(image).convert('RGB')
    return image
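# Note: clean_img expects `image` as an RGB numpy array and `mask` as a BGR
# numpy array of the same size, which is how floorplan_generation calls it below.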
@spaces.GPU  # requests a GPU for this call on ZeroGPU Spaces; no-op on other hardware
def floorplan_generation(outline, num_of_rooms):
    new_width = 512
    new_height = 512

    # Gradio delivers RGB; OpenCV works in BGR.
    outline = cv2.cvtColor(outline, cv2.COLOR_RGB2BGR)
    outline_original = outline.copy()

    # Crop the drawing to the bounding box of its non-white pixels.
    gray = cv2.cvtColor(outline, cv2.COLOR_BGR2GRAY)
    thresh = cv2.threshold(gray, 240, 255, cv2.THRESH_BINARY_INV)[1]
    x, y, w, h = cv2.boundingRect(thresh)
    n_outline = outline_original[y:y + h, x:x + w]

    # Pad to a square and resize to the 512x512 resolution the ControlNet expects.
    n_outline = cv2.cvtColor(n_outline, cv2.COLOR_BGR2RGB)
    n_outline = Image.fromarray(n_outline).convert('RGB')
    n_outline = expand2square(n_outline, (255, 255, 255))
    n_outline = n_outline.resize((new_width, new_height))

    num_of_rooms = str(num_of_rooms)  # already a string from the Textbox; cast defensively
    validation_prompt = "floor plan, " + num_of_rooms + " rooms"
    validation_image = n_outline

    # Generate five candidate plans, each from a fresh random seed.
    image_lst = []
    for i in range(5):
        seed = randrange(5000)
        generator = torch.Generator(device=device).manual_seed(seed)
        image = pipeline(validation_prompt,
                         validation_image,
                         num_inference_steps=20,
                         generator=generator).images[0]

        # Mask the generated plan with the padded outline.
        image = np.array(image)
        mask = np.array(n_outline)
        mask = cv2.cvtColor(mask, cv2.COLOR_RGB2BGR)
        image = clean_img(image, mask)
        image_lst.append(image)

    return image_lst[0], image_lst[1], image_lst[2], image_lst[3], image_lst[4]
gradio_interface = gradio.Interface(
    fn=floorplan_generation,
    inputs=[gradio.Image(label="Floor Plan Outline, Entrance"),
            gradio.Textbox(type="text", label="Number of Rooms", placeholder="Number of Rooms")],
    outputs=[gradio.Image(label="Generated Floor Plan 1"),
             gradio.Image(label="Generated Floor Plan 2"),
             gradio.Image(label="Generated Floor Plan 3"),
             gradio.Image(label="Generated Floor Plan 4"),
             gradio.Image(label="Generated Floor Plan 5")],
    title="FloorAI",
    examples=[["example_1.png", "4"], ["example_2.png", "3"], ["example_3.png", "2"],
              ["example_4.png", "4"], ["example_5.png", "4"]])

gradio_interface.queue(max_size=10, status_update_rate="auto", api_open=True)
gradio_interface.launch(share=True, show_api=True, show_error=True)
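# When run locally, launch() serves on http://127.0.0.1:7860 by default;
# on Hugging Face Spaces the share=True flag is ignored.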