import threading
from collections import deque
from dataclasses import dataclass
from typing import Optional

import gradio as gr
from PIL import Image

from constants import DESCRIPTION, LOGO
from gradio_examples import EXAMPLES
from model import get_pipeline
from utils import replace_background
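
# Bounded queues: deque(maxlen=...) drops the oldest entry once the queue is full.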
MAX_QUEUE_SIZE = 4

pipeline = get_pipeline()
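

# Per-session state: queues of pending prompts and finished generations.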
@dataclass
class GenerationState:
    prompts: deque
    generations: deque


def get_initial_state() -> GenerationState:
    return GenerationState(
        prompts=deque(maxlen=MAX_QUEUE_SIZE),
        generations=deque(maxlen=MAX_QUEUE_SIZE),
    )


def load_initial_state(request: gr.Request) -> GenerationState:
    print("Loading initial state for", request.client.host)
    print("Total number of active threads", threading.active_count())

    return get_initial_state()


async def put_to_queue(
    image: Optional[Image.Image],
    prompt: str,
    seed: int,
    strength: float,
    state: GenerationState,
) -> GenerationState:
    prompts_queue = state.prompts
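
    # Only enqueue complete requests: both a sketch and a non-empty prompt.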
    if prompt and image is not None:
        prompts_queue.append((image, prompt, seed, strength))

    return state


def inference(state: GenerationState) -> GenerationState:
    prompts_queue = state.prompts
    generations_queue = state.generations

    if len(prompts_queue) == 0:
        return state

    image, prompt, seed, strength = prompts_queue.popleft()
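
    # The pipeline runs at 512x512; keep the original size so the output can
    # be scaled back before it is shown.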
    original_image_size = image.size
    image = replace_background(image.resize((512, 512)))

    result = pipeline(
        prompt=prompt,
        image=image,
        strength=strength,
        seed=seed,
        guidance_scale=1,
        num_inference_steps=4,
    )

    output_image = result.images[0].resize(original_image_size)

    generations_queue.append(output_image)

    return state


def update_output_image(state: GenerationState):
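    # gr.update() with no arguments is a no-op, so the image only refreshes
    # when a new generation is ready.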
    image_update = gr.update()

    generations_queue = state.generations

    if len(generations_queue) > 0:
        generated_image = generations_queue.popleft()
        image_update = gr.update(value=generated_image)

    return image_update, state


with gr.Blocks(css="style.css", title="Realtime Latent Consistency Model") as demo:
    generation_state = gr.State(get_initial_state())

    gr.HTML(f'<div style="width: 70px;">{LOGO}</div>')
    gr.Markdown(DESCRIPTION)
    with gr.Row(variant="default"):
        input_image = gr.Image(
            tool="color-sketch",
            source="canvas",
            label="Initial Image",
            type="pil",
            height=512,
            width=512,
            brush_radius=40.0,
        )

        output_image = gr.Image(
            label="Generated Image",
            type="pil",
            interactive=False,
            elem_id="output_image",
        )

    with gr.Row():
        with gr.Column():
            prompt_box = gr.Textbox(label="Prompt", value=EXAMPLES[0])

    with gr.Accordion(label="Advanced Options", open=False):
        with gr.Row():
            with gr.Column():
                strength = gr.Slider(
                    label="Strength",
                    minimum=0.1,
                    maximum=1.0,
                    step=0.05,
                    value=0.8,
                    info="Strength of the initial image that will be applied during inference.",
                )
            with gr.Column():
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=2**31 - 1,
                    step=1,
                    randomize=True,
                    info="Seed for the random number generator.",
                )

    demo.load(
        load_initial_state,
        outputs=[generation_state],
    )
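    # Polling loops: run inference and refresh the output image every 100 ms.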
    demo.load(
        inference,
        inputs=[generation_state],
        outputs=[generation_state],
        every=0.1,
    )
    demo.load(
        update_output_image,
        inputs=[generation_state],
        outputs=[output_image, generation_state],
        every=0.1,
    )
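    # Any change to the sketch, prompt, or sliders enqueues a new request.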
    for event in [input_image.change, prompt_box.change, strength.change, seed.change]:
        event(
            put_to_queue,
            [input_image, prompt_box, seed, strength, generation_state],
            [generation_state],
            show_progress=False,
            queue=True,
        )

    gr.Markdown("## Example Prompts")
    gr.Examples(examples=EXAMPLES, inputs=[prompt_box], label="Examples")


if __name__ == "__main__":
    demo.queue(concurrency_count=20, api_open=False).launch(max_threads=1024)