Spaces:
Running
on
Zero
Running
on
Zero
import gradio as gr | |
import torch | |
import spaces | |
from PIL import Image, ImageDraw, ImageFont | |
from src.condition import Condition | |
from diffusers.pipelines import FluxPipeline | |
import numpy as np | |
from src.generate import seed_everything, generate | |
# Load FLUX.1-schnell once at import time (bfloat16, on CUDA) and attach both
# OminiControl subject LoRAs under distinct adapter names. The request handler
# reuses this module-level `pipe` instead of reloading weights per call.
# NOTE(review): nothing visible here calls `pipe.set_adapters(...)` to pick
# "subject_512" vs "subject_1024" by requested resolution — confirm which
# adapter(s) are active when `generate` runs.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
)
pipe = pipe.to("cuda")
pipe.load_lora_weights(
    "Yuanshi/OminiControl",
    weight_name="omini/subject_512.safetensors",
    adapter_name="subject_512",
)
pipe.load_lora_weights(
    "Yuanshi/OminiControl",
    weight_name="omini/subject_1024_beta.safetensors",
    adapter_name="subject_1024",
)
@spaces.GPU
def process_image_and_text(image, resolution, text):
    """Run OminiControl subject-driven generation for one UI request.

    Decorated with ``spaces.GPU`` so that on a Hugging Face ZeroGPU Space the
    CUDA work below runs inside a GPU-allocated call.

    Args:
        image: Condition image as a ``PIL.Image`` (the Gradio input is declared
            with ``type="pil"``), or ``None`` when nothing was uploaded.
        resolution: Output height and width in pixels (the UI offers 512 or
            1024).
        text: Text prompt; leading/trailing whitespace is stripped.

    Returns:
        The first image of the output returned by ``generate``.

    Raises:
        gr.Error: If no condition image was provided, so the user sees a
            readable message instead of an ``AttributeError`` traceback.
    """
    if image is None:
        raise gr.Error("Please upload a condition image.")

    # Center-crop to the largest square, then resize: the condition image is
    # always fed in at 512x512, independent of the requested output size.
    width, height = image.size
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    image = image.crop((left, top, left + side, top + side))
    image = image.resize((512, 512))

    condition = Condition("subject", image)
    result_img = generate(
        pipe,
        prompt=(text or "").strip(),
        conditions=[condition],
        num_inference_steps=8,
        height=resolution,
        width=resolution,
    ).images[0]
    return result_img
def get_samples():
    """Return the example rows shown beneath the app.

    Each row is ``[image, resolution, prompt]``, in the same order as the
    ``inputs`` wired to ``gr.Examples``. Images are read from the ``assets/``
    paths below (relative to the current working directory) and resized to
    512x512 for display.
    """
    sample_specs = [
        {
            "image": "assets/oranges.jpg",
            "resolution": 512,
            "text": "A very close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show. With text on the screen that reads 'Omini Control!'",
        },
        {
            "image": "assets/penguin.jpg",
            "resolution": 512,
            "text": "On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat, holding a sign that reads 'Omini Control!'",
        },
        {
            "image": "assets/rc_car.jpg",
            "resolution": 1024,
            "text": "A film style shot. On the moon, this item drives across the moon surface. The background is that Earth looms large in the foreground.",
        },
        {
            "image": "assets/clock.jpg",
            "resolution": 1024,
            "text": "In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.",
        },
    ]
    rows = []
    for spec in sample_specs:
        preview = Image.open(spec["image"]).resize((512, 512))
        rows.append([preview, spec["resolution"], spec["text"]])
    return rows
# Markdown/HTML banner rendered at the top of the app: a title plus badge links
# to the paper (arXiv), the model weights (Hugging Face) and the code (GitHub).
header = """
# 🌍 OminiControl / FLUX
<div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
<a href="https://arxiv.org/abs/2411.15098"><img src="https://img.shields.io/badge/arXiv-Paper-A42C25.svg" alt="arXiv"></a>
<a href="https://huggingface.co/Yuanshi/OminiControl"><img src="https://img.shields.io/badge/🤗-Model-ffbd45.svg" alt="HuggingFace"></a>
<a href="https://github.com/Yuanshi9815/OminiControl"><img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub"></a>
</div>
"""
def create_app():
    """Build the Gradio Blocks UI for subject-driven generation.

    Layout: an input panel (condition image, resolution radio, text prompt,
    Run button) beside an output panel, followed by a row of clickable
    examples from ``get_samples()``. Clicking Run calls
    ``process_image_and_text`` with the three inputs and shows the result in
    the output image.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks() as app:
        gr.Markdown(header, elem_id="header")
        with gr.Row(equal_height=False):
            with gr.Column(variant="panel", elem_classes="inputPanel"):
                original_image = gr.Image(
                    type="pil", label="Condition Image", width=300, elem_id="input"
                )
                resolution = gr.Radio(
                    [("512", 512), ("1024(beta)", 1024)],
                    label="Resolution",
                    value=512,
                    elem_id="resolution",
                )
                text = gr.Textbox(lines=2, label="Text Prompt", elem_id="text")
                submit_btn = gr.Button("Run", elem_id="submit_btn")
            with gr.Column(variant="panel", elem_classes="outputPanel"):
                output_image = gr.Image(type="pil", elem_id="output")
        with gr.Row():
            # Created for its side effect of rendering inside the Blocks
            # context; the returned object itself is not needed.
            gr.Examples(
                examples=get_samples(),
                inputs=[original_image, resolution, text],
                label="Examples",
            )
        submit_btn.click(
            fn=process_image_and_text,
            inputs=[original_image, resolution, text],
            outputs=output_image,
        )
    return app
if __name__ == "__main__":
    # Script entry point: build the UI, then start the Gradio server with
    # debug output enabled and server-side rendering (ssr_mode) turned off.
    demo = create_app()
    demo.launch(debug=True, ssr_mode=False)