# Kit / app.py
# Author: SIGMitch
# Last commit: "Update app.py" (aca52b5, verified)
from typing import Tuple
import requests
import random
import numpy as np
import gradio as gr
import spaces
import os
import torch
from PIL import Image
from diffusers import FluxInpaintPipeline
from diffusers import FluxImg2ImgPipeline
# Upper bound for user-selectable seeds: the largest signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max
# Length, in pixels, that the longer side of an input image is scaled to
# before it is handed to the pipeline.
IMAGE_SIZE = 1024
# Prefer the GPU when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
def remove_background(image: Image.Image, threshold: int = 50) -> Image.Image:
    """Make dark pixels fully transparent.

    The image is converted to RGBA. Every pixel whose mean over the first
    three channels (R, G, B) is below ``threshold`` is replaced with
    transparent black ``(0, 0, 0, 0)``. All other pixels are kept unchanged.

    Args:
        image: Source image; it is converted to RGBA before processing.
        threshold: Cut-off compared against the R/G/B channel average.

    Returns:
        The converted RGBA image with dark pixels cleared.
    """
    rgba = image.convert("RGBA")
    transparent = (0, 0, 0, 0)
    pixels = [
        transparent if sum(px[:3]) / 3 < threshold else px
        for px in rgba.getdata()
    ]
    rgba.putdata(pixels)
    return rgba
# Inpainting variant (currently disabled):
# pipe = FluxInpaintPipeline.from_pretrained(
#     "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(DEVICE)

# Image-to-image FLUX.1-dev pipeline shared by every request. It is built once
# at import time with bfloat16 weights and then moved to DEVICE.
pipe2 = FluxImg2ImgPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
)
pipe2 = pipe2.to(DEVICE)
def resize_image_dimensions(
    original_resolution_wh: Tuple[int, int],
    maximum_dimension: int = IMAGE_SIZE
) -> Tuple[int, int]:
    """Compute pipeline-friendly ``(width, height)`` for an image.

    The longer side is scaled to exactly ``maximum_dimension``, so small
    images are upscaled and large ones downscaled. The shorter side is scaled
    by the same aspect ratio. A square image becomes
    ``maximum_dimension x maximum_dimension``. Both sides are then rounded
    down to a multiple of 8 and clamped to at least 8 pixels, so that
    extreme aspect ratios never yield a zero-sized image.

    Args:
        original_resolution_wh: ``(width, height)`` of the source image.
        maximum_dimension: Target length of the longer side. Defaults to
            ``IMAGE_SIZE``. Previously this argument was ignored and
            ``IMAGE_SIZE`` was always used.

    Returns:
        ``(width, height)`` as integers.
    """
    w, h = original_resolution_wh
    if w > h:
        # Keep the exact `max * (short / long)` form so results match the
        # previous implementation bit-for-bit at multiple-of-8 boundaries.
        new_w, new_h = maximum_dimension, maximum_dimension * (h / w)
    elif h > w:
        new_w, new_h = maximum_dimension * (w / h), maximum_dimension
    else:
        new_w = new_h = maximum_dimension

    # Floor to a multiple of 8. Never return 0, which Image.resize rejects.
    new_w = max(8, int(new_w - new_w % 8))
    new_h = max(8, int(new_h - new_h % 8))
    return new_w, new_h
# Hugging Face repo id of the LoRA applied on top of the FLUX.1-dev pipeline.
LORA_REPO_ID = "SIGMitch/KIT"
# True once the LoRA has been loaded into `pipe2` in the current process.
_lora_loaded = False


def _ensure_lora_loaded() -> None:
    """Load the LoRA into ``pipe2`` only the first time it is needed.

    Calling ``load_lora_weights`` on every request repeats the whole load
    each time (and may register an additional adapter per call). If the
    process state does not persist between calls, this simply falls back to
    loading again, which matches the previous behaviour.
    """
    global _lora_loaded
    if _lora_loaded:
        return
    pipe2.load_lora_weights(LORA_REPO_ID)
    _lora_loaded = True


@spaces.GPU(duration=80)
def process(
    input_image_editor: dict,
    input_text: str,
    seed_slicer: int,
    randomize_seed_checkbox: bool,
    strength_slider: float,
    num_inference_steps_slider: int,
    num_influence: float,
    progress=gr.Progress(track_tqdm=True)
):
    """Run FLUX.1-dev image-to-image generation on the uploaded image.

    Args:
        input_image_editor: Value of the ``gr.ImageEditor`` component. Only
            its ``'background'`` entry (the uploaded image) is used; a
            missing value is treated as "no image uploaded".
        input_text: User prompt; it is prefixed with ``"A military COR2 "``.
        seed_slicer: Seed used when ``randomize_seed_checkbox`` is False.
        randomize_seed_checkbox: If True, draw a fresh seed in
            ``[0, MAX_SEED]`` instead of using ``seed_slicer``.
        strength_slider: Img2img ``strength`` passed to the pipeline.
        num_inference_steps_slider: Number of denoising steps.
        num_influence: Passed as ``joint_attention_kwargs={"scale": ...}``;
            presumably the LoRA strength (UI label "Influence").
        progress: Gradio progress tracker; ``track_tqdm=True`` mirrors the
            pipeline's tqdm bars in the UI.

    Returns:
        The first generated image, or ``None`` (after showing an info toast)
        when no image was uploaded.
    """
    # Validate before touching anything else. The editor value itself may be
    # None, and its layer list is not needed (no mask is used by img2img).
    image = input_image_editor.get('background') if input_image_editor else None
    if image is None:
        gr.Info("Please upload an image.")
        return None

    prompt = "A military COR2 " + (input_text or "")

    width, height = resize_image_dimensions(original_resolution_wh=image.size)
    resized_image = image.resize((width, height), Image.LANCZOS)

    # Slider values may arrive as floats; manual_seed and the step count need ints.
    seed = random.randint(0, MAX_SEED) if randomize_seed_checkbox else int(seed_slicer)
    generator = torch.Generator().manual_seed(seed)

    _ensure_lora_loaded()
    result = pipe2(
        prompt=prompt,
        image=resized_image,
        width=width,
        height=height,
        num_images_per_prompt=1,
        strength=strength_slider,
        generator=generator,
        joint_attention_kwargs={"scale": num_influence},
        num_inference_steps=int(num_inference_steps_slider),
        guidance_scale=3.5,
    )
    print('INFERENCE DONE')
    return result.images[0]
# Gradio UI: image editor + prompt + "Influence" slider on the left, the
# generated image on the right; "Submit" runs `process`.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_image_editor_component = gr.ImageEditor(
                label='Image',
                type='pil',
                sources=["upload"],
                image_mode='RGB',
                layers=False
            )
            with gr.Row():
                input_text_component = gr.Text(
                    label="Prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter your prompt",
                    container=False,
                )
                # Forwarded to `process` as `num_influence`.
                lorasteps = gr.Slider(label="Influence", minimum=0, maximum=2, step=0.1, value=1)
                submit_button_component = gr.Button(
                    value='Submit', variant='primary', scale=0)
            with gr.Accordion("Advanced Settings", open=False):
                seed_slicer_component = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=42,
                )
                randomize_seed_checkbox_component = gr.Checkbox(
                    label="Randomize seed", value=True)
                with gr.Row():
                    strength_slider_component = gr.Slider(
                        label="Strength",
                        info="Indicates extent to transform the reference `image`. "
                             "Must be between 0 and 1. `image` is used as a starting "
                             "point and more noise is added the higher the `strength`.",
                        minimum=0,
                        maximum=1,
                        step=0.01,
                        value=0.8,
                    )
                    num_inference_steps_slider_component = gr.Slider(
                        label="Number of inference steps",
                        # The help text used to stop mid-sentence ("...at the").
                        info="The number of denoising steps. More denoising steps "
                             "usually lead to a higher quality image at the "
                             "expense of slower inference.",
                        minimum=1,
                        maximum=50,
                        step=1,
                        value=28,
                    )
        with gr.Column():
            output_image_component = gr.Image(
                type='pil', image_mode='RGB', label='Generated image', format="png")

    # Input order must match the positional parameters of `process`.
    submit_button_component.click(
        fn=process,
        inputs=[
            input_image_editor_component,
            input_text_component,
            seed_slicer_component,
            randomize_seed_checkbox_component,
            strength_slider_component,
            num_inference_steps_slider_component,
            lorasteps
        ],
        outputs=[
            output_image_component
        ]
    )

# demo.launch(auth=("user", os.getenv('Login')), share=True, debug=False, show_error=True)
demo.launch(share=True, debug=False, show_error=True)