# `spaces` is imported first, before torch/CUDA-touching imports; kept in the original order.
import spaces
import argparse
import os
import shutil

import cv2
import gradio as gr
import numpy as np
import torch
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
import huggingface_hub
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision.transforms.functional import normalize

from dreamo.dreamo_pipeline import DreamOPipeline
from dreamo.utils import img2tensor, resize_numpy_image_area, tensor2img, resize_numpy_image_long
from tools import BEN2

parser = argparse.ArgumentParser()
# NOTE(review): --port is parsed but never forwarded to demo.launch() below; presumably the
# hosting environment chooses the port. Confirm before relying on this flag.
parser.add_argument('--port', type=int, default=8080)
parser.add_argument('--no_turbo', action='store_true')
args = parser.parse_args()

# Only log in when a token is actually provided: huggingface_hub.login(None) falls back to an
# interactive prompt, which cannot be answered in a headless deployment.
_hf_token = os.getenv('HF_TOKEN')
if _hf_token:
    huggingface_hub.login(_hf_token)

# Remove any example outputs cached by a previous run (presumably so examples are regenerated).
try:
    shutil.rmtree('gradio_cached_examples')
except FileNotFoundError:
    print("cache folder not exist")


class Generator:
    """Bundle of the preprocessing models and the DreamO pipeline, all placed on CUDA.

    Attributes:
        device: the CUDA device every model is moved to.
        bg_rm_model: BEN2 background-removal model (weights from ``PramaLLC/BEN2``).
        face_helper: facexlib helper used to detect, align and parse faces.
        dreamo_pipeline: DreamO pipeline built on ``black-forest-labs/FLUX.1-dev`` (bfloat16).
    """

    def __init__(self):
        self.device = torch.device('cuda')
        device = self.device

        # preprocessing models
        # background remove model: BEN2
        self.bg_rm_model = BEN2.BEN_Base().to(device).eval()
        hf_hub_download(repo_id='PramaLLC/BEN2', filename='BEN2_Base.pth', local_dir='models')
        self.bg_rm_model.loadcheckpoints('models/BEN2_Base.pth')

        # face crop and align tool: facexlib
        self.face_helper = FaceRestoreHelper(
            upscale_factor=1,
            face_size=512,
            crop_ratio=(1, 1),
            det_model='retinaface_resnet50',
            save_ext='png',
            device=device,
        )

        # load dreamo; turbo mode is on unless --no_turbo was passed
        model_root = 'black-forest-labs/FLUX.1-dev'
        dreamo_pipeline = DreamOPipeline.from_pretrained(model_root, torch_dtype=torch.bfloat16)
        dreamo_pipeline.load_dreamo_model(device, use_turbo=not args.no_turbo)
        self.dreamo_pipeline = dreamo_pipeline.to(device)

    @torch.no_grad()
    def get_align_face(self, img):
        """Detect, align and mask the centre face of an RGB image.

        The face preprocessing code is the same as PuLID.

        Args:
            img: RGB image as a numpy array (converted to BGR for facexlib).

        Returns:
            The aligned face crop (``face_size=512``, square crop ratio) with every pixel whose
            parsing label is in ``bg_label`` painted white, as produced by ``tensor2img``;
            or ``None`` when no face is detected.
        """
        self.face_helper.clean_all()
        image_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        self.face_helper.read_image(image_bgr)
        self.face_helper.get_face_landmarks_5(only_center_face=True)
        self.face_helper.align_warp_face()
        if len(self.face_helper.cropped_faces) == 0:
            return None
        align_face = self.face_helper.cropped_faces[0]

        face_tensor = img2tensor(align_face, bgr2rgb=True).unsqueeze(0) / 255.0
        face_tensor = face_tensor.to(self.device)
        # ImageNet mean/std normalisation before running the face-parsing network.
        parsing_out = self.face_helper.face_parse(
            normalize(face_tensor, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        )[0]
        parsing_out = parsing_out.argmax(dim=1, keepdim=True)
        # Parsing classes that are treated as background and whitened out.
        bg_label = [0, 16, 18, 7, 8, 9, 14, 15]
        bg = sum(parsing_out == i for i in bg_label).bool()
        white_image = torch.ones_like(face_tensor)
        # only keep the face features
        face_features_image = torch.where(bg, white_image, face_tensor)
        face_features_image = tensor2img(face_features_image, rgb2bgr=False)
        return face_features_image


generator = Generator()


@spaces.GPU
@torch.inference_mode()
def generate_image(
    ref_image1,
    ref_image2,
    ref_task1,
    ref_task2,
    prompt,
    seed,
    width=1024,
    height=1024,
    ref_res=512,
    num_steps=12,
    guidance=3.5,
    true_cfg=1,
    cfg_start_step=0,
    cfg_end_step=0,
    neg_prompt='',
    neg_guidance=3.5,
    first_step_guidance=0,
):
    """Run DreamO on up to two reference images and a text prompt.

    Each non-``None`` reference image is preprocessed according to its task:
    ``"id"`` -> resized (long side 1024) and reduced to an aligned, masked face;
    ``"style"`` -> used as-is; any other task -> background removed with BEN2.
    Non-``"id"`` references are then resized to an area of ``ref_res * ref_res``.
    Each reference is converted to a tensor and rescaled with ``2 * x / 255 - 1``
    (presumably mapping uint8 pixels to [-1, 1]; confirm against ``img2tensor``).

    Args:
        ref_image1, ref_image2: RGB numpy images, or ``None`` when not provided.
        ref_task1, ref_task2: task names for the matching image (UI offers "ip", "id", "style").
        prompt: text prompt.
        seed: anything ``int()`` accepts; ``-1`` draws a random seed.
        width, height, num_steps, guidance: forwarded to the DreamO pipeline.
        ref_res: side length whose square is the target area for non-"id" references.
        true_cfg, cfg_start_step, cfg_end_step, neg_prompt, neg_guidance: true-CFG settings
            forwarded to the pipeline.
        first_step_guidance: guidance for the first step; values <= 0 fall back to ``guidance``.

    Returns:
        ``(image, debug_images, seed)``: the first pipeline output image, the list of
        preprocessed reference images, and the seed actually used.

    Raises:
        gr.Error: if an "id" reference contains no detectable face, or ``seed`` is not an integer.
    """
    print(prompt)
    ref_conds = []
    debug_images = []
    ref_images = [ref_image1, ref_image2]
    ref_tasks = [ref_task1, ref_task2]

    for idx, (ref_image, ref_task) in enumerate(zip(ref_images, ref_tasks)):
        if ref_image is None:
            continue
        if ref_task == "id":
            ref_image = resize_numpy_image_long(ref_image, 1024)
            ref_image = generator.get_align_face(ref_image)
            # Without this check a None crop would crash later in img2tensor with an opaque error.
            if ref_image is None:
                raise gr.Error(
                    f"No face detected in Reference Image {idx + 1}; "
                    f"please use a clear face photo for the 'id' task."
                )
        elif ref_task != "style":
            ref_image = generator.bg_rm_model.inference(Image.fromarray(ref_image))
        if ref_task != "id":
            ref_image = resize_numpy_image_area(np.array(ref_image), ref_res * ref_res)
        debug_images.append(ref_image)
        ref_image = img2tensor(ref_image, bgr2rgb=False).unsqueeze(0) / 255.0
        ref_image = 2 * ref_image - 1.0
        ref_conds.append(
            {
                'img': ref_image,
                'task': ref_task,
                'idx': idx + 1,
            }
        )

    # The seed comes from a free-text box, so report bad input to the user instead of crashing.
    try:
        seed = int(seed)
    except (TypeError, ValueError):
        raise gr.Error(f"Seed must be an integer (use -1 for random), got {seed!r}.") from None
    if seed == -1:
        seed = torch.Generator(device="cpu").seed()

    image = generator.dreamo_pipeline(
        prompt=prompt,
        width=width,
        height=height,
        num_inference_steps=num_steps,
        guidance_scale=guidance,
        ref_conds=ref_conds,
        generator=torch.Generator(device="cpu").manual_seed(seed),
        true_cfg_scale=true_cfg,
        true_cfg_start_step=cfg_start_step,
        true_cfg_end_step=cfg_end_step,
        negative_prompt=neg_prompt,
        neg_guidance_scale=neg_guidance,
        first_step_guidance_scale=first_step_guidance if first_step_guidance > 0 else guidance,
    ).images[0]

    return image, debug_images, seed


# Custom CSS for pastel theme
_CUSTOM_CSS_ = """
:root {
    --primary-color: #f8c3cd; /* Sakura pink - primary accent */
    --secondary-color: #b3e5fc; /* Pastel blue - secondary accent */
    --background-color: #f5f5f7; /* Very light gray background */
    --card-background: #ffffff; /* White for cards */
    --text-color: #424242; /* Dark gray for text */
    --accent-color: #ffb6c1; /* Light pink for accents */
    --success-color: #c8e6c9; /* Pastel green for success */
    --warning-color: #fff9c4; /* Pastel yellow for warnings */
    --shadow-color: rgba(0, 0, 0, 0.1); /* Shadow color */
    --border-radius: 12px; /* Rounded corners */
}

body {
    background-color: var(--background-color) !important;
    font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif !important;
}

.gradio-container {
    max-width: 1200px !important;
    margin: 0 auto !important;
}

/* Header styling */
h1 {
    color: #9c27b0 !important;
    font-weight: 800 !important;
    text-shadow: 2px 2px 4px rgba(156, 39, 176, 0.2) !important;
    letter-spacing: -0.5px !important;
}

/* Card styling for panels */
.panel-box {
    border-radius: var(--border-radius) !important;
    box-shadow: 0 8px 16px var(--shadow-color) !important;
    background-color: var(--card-background) !important;
    border: none !important;
    overflow: hidden !important;
    padding: 20px !important;
    margin-bottom: 20px !important;
}

/* Button styling */
button.gr-button {
    background: linear-gradient(135deg, var(--primary-color), #e1bee7) !important;
    border-radius: var(--border-radius) !important;
    color: #4a148c !important;
    font-weight: 600 !important;
    border: none !important;
    padding: 10px 20px !important;
    transition: all 0.3s ease !important;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1) !important;
}

button.gr-button:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 6px 10px rgba(0, 0, 0, 0.15) !important;
    background: linear-gradient(135deg, #e1bee7, var(--primary-color)) !important;
}

/* Input fields styling */
input, select, textarea, .gr-input {
    border-radius: 8px !important;
    border: 2px solid #e0e0e0 !important;
    padding: 10px 15px !important;
    transition: all 0.3s ease !important;
    background-color: #fafafa !important;
}

input:focus, select:focus, textarea:focus, .gr-input:focus {
    border-color: var(--primary-color) !important;
    box-shadow: 0 0 0 3px rgba(248, 195, 205, 0.3) !important;
}

/* Slider styling */
.gr-form input[type=range] {
    appearance: none !important;
    width: 100% !important;
    height: 6px !important;
    background: #e0e0e0 !important;
    border-radius: 5px !important;
    outline: none !important;
}

.gr-form input[type=range]::-webkit-slider-thumb {
    appearance: none !important;
    width: 16px !important;
    height: 16px !important;
    background: var(--primary-color) !important;
    border-radius: 50% !important;
    cursor: pointer !important;
    border: 2px solid white !important;
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1) !important;
}

/* Dropdown styling */
.gr-form select {
    background-color: white !important;
    border: 2px solid #e0e0e0 !important;
    border-radius: 8px !important;
    padding: 10px 15px !important;
}

.gr-form select option {
    padding: 10px !important;
}

/* Image upload area */
.gr-image-input {
    border: 2px dashed #b39ddb !important;
    border-radius: var(--border-radius) !important;
    background-color: #f3e5f5 !important;
    padding: 20px !important;
    display: flex !important;
    flex-direction: column !important;
    align-items: center !important;
    justify-content: center !important;
    transition: all 0.3s ease !important;
}

.gr-image-input:hover {
    background-color: #ede7f6 !important;
    border-color: #9575cd !important;
}

/* Add a nice pattern to the background */
body::before {
    content: "" !important;
    position: fixed !important;
    top: 0 !important;
    left: 0 !important;
    width: 100% !important;
    height: 100% !important;
    background:
        radial-gradient(circle at 10% 20%, rgba(248, 195, 205, 0.1) 0%, rgba(245, 245, 247, 0) 20%),
        radial-gradient(circle at 80% 70%, rgba(179, 229, 252, 0.1) 0%, rgba(245, 245, 247, 0) 20%) !important;
    pointer-events: none !important;
    z-index: -1 !important;
}

/* Gallery styling */
.gr-gallery {
    grid-gap: 15px !important;
}

.gr-gallery-item {
    border-radius: var(--border-radius) !important;
    overflow: hidden !important;
    box-shadow: 0 4px 8px var(--shadow-color) !important;
    transition: transform 0.3s ease !important;
}

.gr-gallery-item:hover {
    transform: scale(1.02) !important;
}

/* Label styling */
.gr-form label {
    font-weight: 600 !important;
    color: #673ab7 !important;
    margin-bottom: 5px !important;
}

/* Improve spacing */
.gr-padded {
    padding: 20px !important;
}

.gr-compact {
    gap: 15px !important;
}

.gr-form > div {
    margin-bottom: 16px !important;
}

/* Headings */
.gr-form h3 {
    color: #7b1fa2 !important;
    margin-top: 5px !important;
    margin-bottom: 15px !important;
    border-bottom: 2px solid #e1bee7 !important;
    padding-bottom: 8px !important;
}

/* Examples section */
#examples-panel {
    background-color: #f3e5f5 !important;
    border-radius: var(--border-radius) !important;
    padding: 15px !important;
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.05) !important;
}

#examples-panel h2 {
    color: #7b1fa2 !important;
    font-size: 1.5rem !important;
    margin-bottom: 15px !important;
}

/* Accordion styling */
.gr-accordion {
    border: 1px solid #e0e0e0 !important;
    border-radius: var(--border-radius) !important;
    overflow: hidden !important;
}

.gr-accordion summary {
    padding: 12px 16px !important;
    background-color: #f9f9f9 !important;
    cursor: pointer !important;
    font-weight: 600 !important;
    color: #673ab7 !important;
}

/* Generate button special styling */
#generate-btn {
    background: linear-gradient(135deg, #ff9a9e, #fad0c4) !important;
    font-size: 1.1rem !important;
    padding: 12px 24px !important;
    margin-top: 10px !important;
    margin-bottom: 15px !important;
    width: 100% !important;
}

#generate-btn:hover {
    background: linear-gradient(135deg, #fad0c4, #ff9a9e) !important;
}
"""

_HEADER_ = '''

✨ DreamO Video ✨

Create customized images with advanced AI

Paper: DreamO: A Unified Framework for Image Customization | Codes: GitHub

🚩 Update Notes:

'''

_CITE_ = r"""

If DreamO is helpful, please help to ⭐ the community. Thanks!


📧 Contact

If you have any questions or feedback, feel free to open a discussion or contact arxivgpt@gmail.com

"""


def create_demo():
    """Build the Gradio Blocks UI wiring the input controls to ``generate_image``.

    Returns:
        The constructed ``gr.Blocks`` demo (not launched).
    """
    with gr.Blocks(css=_CUSTOM_CSS_) as demo:
        gr.HTML(_HEADER_)

        with gr.Row():
            with gr.Column(scale=6):
                # Input panel - using a Group div with custom class instead of Box
                with gr.Group(elem_id="input-panel", elem_classes="panel-box"):
                    gr.Markdown("### 📸 Reference Images")
                    with gr.Row():
                        with gr.Column():
                            ref_image1 = gr.Image(label="Reference Image 1", type="numpy", height=256, elem_id="ref-image-1")
                            ref_task1 = gr.Dropdown(choices=["ip", "id", "style"], value="ip", label="Task for Reference Image 1", elem_id="ref-task-1")

                        with gr.Column():
                            ref_image2 = gr.Image(label="Reference Image 2", type="numpy", height=256, elem_id="ref-image-2")
                            ref_task2 = gr.Dropdown(choices=["ip", "id", "style"], value="ip", label="Task for Reference Image 2", elem_id="ref-task-2")

                    gr.Markdown("### ✏️ Generation Parameters")
                    prompt = gr.Textbox(label="Prompt", value="a person playing guitar in the street", elem_id="prompt-input")

                    with gr.Row():
                        width = gr.Slider(768, 1024, 1024, step=16, label="Width", elem_id="width-slider")
                        height = gr.Slider(768, 1024, 1024, step=16, label="Height", elem_id="height-slider")

                    with gr.Row():
                        num_steps = gr.Slider(8, 30, 12, step=1, label="Number of Steps", elem_id="steps-slider")
                        guidance = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="Guidance Scale", elem_id="guidance-slider")

                    seed = gr.Textbox(label="Seed (-1 for random)", value="-1", elem_id="seed-input")

                    with gr.Accordion("Advanced Options", open=False):
                        ref_res = gr.Slider(512, 1024, 512, step=16, label="Resolution for Reference Image")
                        neg_prompt = gr.Textbox(label="Negative Prompt", value="")
                        neg_guidance = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="Negative Guidance")

                        with gr.Row():
                            true_cfg = gr.Slider(1, 5, 1, step=0.1, label="True CFG")
                            first_step_guidance = gr.Slider(0, 10, 0, step=0.1, label="First Step Guidance")

                        with gr.Row():
                            cfg_start_step = gr.Slider(0, 30, 0, step=1, label="CFG Start Step")
                            cfg_end_step = gr.Slider(0, 30, 0, step=1, label="CFG End Step")

                    generate_btn = gr.Button("✨ Generate Image", elem_id="generate-btn")

                gr.HTML(_CITE_)

            with gr.Column(scale=6):
                # Output panel - using a Group div with custom class instead of Box
                with gr.Group(elem_id="output-panel", elem_classes="panel-box"):
                    gr.Markdown("### 🖼️ Generated Result")
                    output_image = gr.Image(label="Generated Image", elem_id="output-image", format='png')
                    seed_output = gr.Textbox(label="Used Seed", elem_id="seed-output")

                    gr.Markdown("### 🔍 Preprocessing")
                    debug_image = gr.Gallery(
                        label="Preprocessing Results (including face crop and background removal)",
                        elem_id="debug-gallery",
                    )

        # Examples panel - using a Group div with custom class instead of Box
        with gr.Group(elem_id="examples-panel", elem_classes="panel-box"):
            gr.Markdown("## 📚 Examples")
            example_inps = [
                [
                    'example_inputs/choi.jpg',
                    None,
                    'ip',
                    'ip',
                    'a woman sitting on the cloud, playing guitar',
                    1206523688721442817,
                ],
                [
                    'example_inputs/choi.jpg',
                    None,
                    'id',
                    'ip',
                    'a woman holding a sign saying "TOP", on the mountain',
                    10441727852953907380,
                ],
                [
                    'example_inputs/perfume.png',
                    None,
                    'ip',
                    'ip',
                    'a perfume under spotlight',
                    116150031980664704,
                ],
                [
                    'example_inputs/choi.jpg',
                    None,
                    'id',
                    'ip',
                    'portrait, in alps',
                    5443415087540486371,
                ],
                [
                    'example_inputs/mickey.png',
                    None,
                    'style',
                    'ip',
                    'generate a same style image. A rooster wearing overalls.',
                    6245580464677124951,
                ],
                [
                    'example_inputs/mountain.png',
                    None,
                    'style',
                    'ip',
                    'generate a same style image. A pavilion by the river, and the distant mountains are endless',
                    5248066378927500767,
                ],
                [
                    'example_inputs/shirt.png',
                    'example_inputs/skirt.jpeg',
                    'ip',
                    'ip',
                    'A girl is wearing a short-sleeved shirt and a short skirt on the beach.',
                    9514069256241143615,
                ],
                [
                    'example_inputs/woman2.png',
                    'example_inputs/dress.png',
                    'id',
                    'ip',
                    'the woman wearing a dress, In the banquet hall',
                    7698454872441022867,
                ],
                [
                    'example_inputs/dog1.png',
                    'example_inputs/dog2.png',
                    'ip',
                    'ip',
                    'two dogs in the jungle',
                    6187006025405083344,
                ],
            ]
            gr.Examples(
                examples=example_inps,
                inputs=[ref_image1, ref_image2, ref_task1, ref_task2, prompt, seed],
                # Row numbers must match the order of example_inps above.
                label='Examples by category: IP task (rows 1, 3), ID task (rows 2, 4), Style task (rows 5-6), Try-On task (rows 7-8), Multi-subject IP task (row 9)',
                cache_examples='lazy',
                outputs=[output_image, debug_image, seed_output],
                fn=generate_image,
            )

        generate_btn.click(
            fn=generate_image,
            inputs=[
                ref_image1,
                ref_image2,
                ref_task1,
                ref_task2,
                prompt,
                seed,
                width,
                height,
                ref_res,
                num_steps,
                guidance,
                true_cfg,
                cfg_start_step,
                cfg_end_step,
                neg_prompt,
                neg_guidance,
                first_step_guidance,
            ],
            outputs=[output_image, debug_image, seed_output],
        )

    return demo


if __name__ == '__main__':
    demo = create_demo()
    demo.launch()