import spaces
import os
import time

import torch
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download, list_repo_files

from src_inference.pipeline import FluxPipeline
from src_inference.lora_helper import set_single_lora
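
# Base FLUX checkpoint and local cache directories for built-in and custom LoRAs.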
BASE_PATH = "black-forest-labs/FLUX.1-dev"
LOCAL_LORA_DIR = "./LoRAs"
CUSTOM_LORA_DIR = "./Custom_LoRAs"
os.makedirs(LOCAL_LORA_DIR, exist_ok=True)
os.makedirs(CUSTOM_LORA_DIR, exist_ok=True)
| print("downloading OmniConsistency base LoRA β¦") | |
| omni_consistency_path = hf_hub_download( | |
| repo_id="showlab/OmniConsistency", | |
| filename="OmniConsistency.safetensors", | |
| local_dir="./Model" | |
| ) | |
| print("loading base pipeline β¦") | |
| pipe = FluxPipeline.from_pretrained( | |
| BASE_PATH, torch_dtype=torch.bfloat16 | |
| ).to("cuda") | |
| set_single_lora(pipe.transformer, omni_consistency_path, | |
| lora_weights=[1], cond_size=512) | |
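

# Pre-fetch every built-in style LoRA at startup so switching styles in the UI
# never triggers a download inside a generation request.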
def download_all_loras():
    lora_names = [
        "3D_Chibi", "American_Cartoon", "Chinese_Ink", "Clay_Toy",
        "Fabric", "Ghibli", "Irasutoya", "Jojo", "LEGO", "Line",
        "Macaron", "Oil_Painting", "Origami", "Paper_Cutting",
        "Picasso", "Pixel", "Poly", "Pop_Art", "Rick_Morty",
        "Snoopy", "Van_Gogh", "Vector"
    ]
    for name in lora_names:
        hf_hub_download(
            repo_id="showlab/OmniConsistency",
            filename=f"LoRAs/{name}_rank128_bf16.safetensors",
            local_dir=LOCAL_LORA_DIR,
        )

download_all_loras()
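

# OmniConsistency's attention processors keep per-request key/value banks
# (`bank_kv`); clearing them after every generation keeps requests independent.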
def clear_cache(transformer):
    for _, attn_processor in transformer.attn_processors.items():
        attn_processor.bank_kv.clear()
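
# Assumption (not present in the extracted source): on a ZeroGPU Space the
# generation entry point is usually wrapped in @spaces.GPU so a GPU is attached
# for the duration of the call; the otherwise-unused `import spaces` above
# suggests that intent. Remove the decorator on a dedicated-GPU Space.
@spaces.GPU()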
def generate_image(
    lora_name,
    custom_repo_id,
    prompt,
    uploaded_image,
    width, height,
    guidance_scale,
    num_inference_steps,
    seed
):
    width, height = int(width), int(height)
    generator = torch.Generator("cpu").manual_seed(seed)

    if custom_repo_id and custom_repo_id.strip():
        repo_id = custom_repo_id.strip()
        try:
            files = list_repo_files(repo_id)
            print("using custom LoRA from:", repo_id)
            safetensors_files = [f for f in files if f.endswith(".safetensors")]
            print("found safetensors files:", safetensors_files)
            if not safetensors_files:
                raise ValueError("No .safetensors files were found in this repo")
            fname = safetensors_files[0]
            lora_path = hf_hub_download(
                repo_id=repo_id,
                filename=fname,
                local_dir=CUSTOM_LORA_DIR,
            )
        except Exception as e:
            raise gr.Error(f"Load custom LoRA failed: {e}")
    else:
        lora_path = os.path.join(
            f"{LOCAL_LORA_DIR}/LoRAs", f"{lora_name}_rank128_bf16.safetensors"
        )

    pipe.unload_lora_weights()
    try:
        pipe.load_lora_weights(
            os.path.dirname(lora_path),
            weight_name=os.path.basename(lora_path)
        )
    except Exception as e:
        raise gr.Error(f"Load LoRA failed: {e}")
    spatial_image = [uploaded_image.convert("RGB")]
    subject_images = []

    start = time.time()
    out_img = pipe(
        prompt,
        height=(height // 8) * 8,
        width=(width // 8) * 8,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        max_sequence_length=512,
        generator=generator,
        spatial_images=spatial_image,
        subject_images=subject_images,
        cond_size=512,
    ).images[0]
    print(f"inference time: {time.time()-start:.2f}s")

    clear_cache(pipe.transformer)
    return uploaded_image, out_img
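
# Hypothetical local smoke test (assumed values, not part of the app flow):
#     before, after = generate_image(
#         "Ghibli", "", "Ghibli style, a cat sitting on a windowsill",
#         Image.open("./test_imgs/00.png"), 768, 768, 3.5, 24, 42)
#     after.save("out.png")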

# =============== Gradio UI ===============
def create_interface():
    demo_lora_names = [
        "3D_Chibi", "American_Cartoon", "Chinese_Ink", "Clay_Toy",
        "Fabric", "Ghibli", "Irasutoya", "Jojo", "LEGO", "Line",
        "Macaron", "Oil_Painting", "Origami", "Paper_Cutting",
        "Picasso", "Pixel", "Poly", "Pop_Art", "Rick_Morty",
        "Snoopy", "Van_Gogh", "Vector"
    ]

    # Example data
    examples = [
        ["3D_Chibi", "",
         "3D Chibi style, Two smiling colleagues enthusiastically high-five in front of a whiteboard filled with technical notes about multimodal learning, reflecting a moment of success and collaboration at OpenAI.",
         Image.open("./test_imgs/00.png"), 680, 1024, 3.5, 24, 42],
        ["Clay_Toy", "",
         "Clay Toy style, Three team members from OpenAI are gathered around a laptop in a cozy, festive setting, with holiday decorations in the background; one waves cheerfully while the others engage in light conversation, reflecting a relaxed and collaborative atmosphere.",
         Image.open("./test_imgs/01.png"), 560, 1024, 3.5, 24, 42],
        ["American_Cartoon", "",
         "American Cartoon style, In a dramatic and comedic moment from a classic Chinese film, an intense elder with a white beard and red hat grips a younger man, declaring something with fervor, while the subtitle at the bottom reads 'I want them all', capturing both tension and humor.",
         Image.open("./test_imgs/02.png"), 568, 1024, 3.5, 24, 42],
        ["Origami", "",
         "Origami style, A thrilled fan wearing a Portugal football kit poses energetically with a smiling Cristiano Ronaldo, who gives a thumbs-up, as they stand side by side in a casual, cheerful moment, capturing the excitement of meeting a football legend.",
         Image.open("./test_imgs/03.png"), 768, 672, 3.5, 24, 42],
        ["Vector", "",
         "Vector style, A man glances admiringly at a passing woman, while his girlfriend looks at him in disbelief, perfectly capturing the theme of shifting attention and misplaced priorities in a humorous, relatable way.",
         Image.open("./test_imgs/04.png"), 512, 1024, 3.5, 24, 42]
    ]
| header = """ | |
| <div style="text-align: center; display: flex; justify-content: left; gap: 5px;"> | |
| <a href="https://arxiv.org/abs/2505.18445"><img src="https://img.shields.io/badge/ariXv-2505.18445-A42C25.svg" alt="arXiv"></a> | |
| <a href="https://huggingface.co/showlab/OmniConsistency"><img src="https://img.shields.io/badge/π€_HuggingFace-Model-ffbd45.svg" alt="HuggingFace"></a> | |
| <a href="https://github.com/showlab/OmniConsistency"><img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub"></a> | |
| </div> | |
| """ | |

    with gr.Blocks() as demo:
        gr.Markdown("# OmniConsistency LoRA Image Generation")
        gr.Markdown("Select a LoRA, enter a prompt, and upload an image to generate a new image with OmniConsistency.")
        gr.HTML(header)

        with gr.Row():
            with gr.Column(scale=1):
                lora_dropdown = gr.Dropdown(
                    demo_lora_names, label="Select built-in LoRA")
                custom_repo_box = gr.Textbox(
                    label="Enter Custom LoRA",
                    placeholder="LoRA Hugging Face path (e.g., 'username/repo_name')",
                    info="If you want to use a custom LoRA, enter its Hugging Face repo ID here; it overrides the built-in LoRA selection. Leave empty to use the built-in LoRAs. [Check the list of FLUX LoRAs](https://huggingface.co/models?other=base_model:adapter:black-forest-labs/FLUX.1-dev)"
                )
                prompt_box = gr.Textbox(
                    label="Prompt",
                    placeholder="Enter your prompt here",
                    info="Remember to include the necessary trigger words if you're using a custom LoRA."
                )
                image_input = gr.Image(type="pil", label="Upload Image")
            with gr.Column(scale=1):
                output_image = gr.ImageSlider(label="Generated Image")
                height_box = gr.Textbox(value="1024", label="Height")
                width_box = gr.Textbox(value="1024", label="Width")
                guidance_slider = gr.Slider(
                    0.1, 20, value=3.5, step=0.1, label="Guidance Scale")
                steps_slider = gr.Slider(
                    1, 50, value=25, step=1, label="Inference Steps")
                seed_slider = gr.Slider(
                    1, 2_147_483_647, value=42, step=1, label="Seed")
                gen_btn = gr.Button("Generate")

        gr.Examples(
            examples=examples,
            inputs=[lora_dropdown, custom_repo_box, prompt_box, image_input,
                    height_box, width_box, guidance_slider, steps_slider, seed_slider],
            outputs=output_image,
            fn=generate_image,
            cache_examples=False,
            label="Examples"
        )

        # The example rows list height before width (matching the Examples
        # inputs order above); the Generate button instead passes width_box,
        # then height_box, to line up with generate_image's (width, height)
        # parameter order.
        gen_btn.click(
            fn=generate_image,
            inputs=[lora_dropdown, custom_repo_box, prompt_box, image_input,
                    width_box, height_box, guidance_slider, steps_slider, seed_slider],
            outputs=output_image
        )
    return demo


if __name__ == "__main__":
    demo = create_interface()
    # ssr_mode=False disables Gradio's server-side rendering (Gradio 5+).
    demo.launch(ssr_mode=False)