# Hugging Face Space status: Running on Zero
import gradio as gr | |
import spaces | |
from diffusers import AutoPipelineForText2Image | |
import torch | |
import time | |
# import logging | |
from threading import Timer | |
from nsfw_detector import NSFWDetector | |
# logging.basicConfig(level=logging.INFO) | |
# logger = logging.getLogger(__name__) | |
# Tunables and module-level state shared by generate_image, unload_model and
# reset_timer.
BATCH_SIZE = 4  # upper bound on images requested from one pipeline call
TIMEOUT_SECONDS = 2 * 60  # idle period (two minutes) before the model is dropped
pipe = None  # the loaded diffusion pipeline; None while no model is loaded
last_use_time = None  # time.time() stamp written by reset_timer()
unload_timer = None  # pending threading.Timer that will call unload_model
def chunk_generations(num_images, batch_size=None):
    """Split a requested image count into per-batch counts.

    Args:
        num_images: Total number of images to generate. A value below 1
            produces an empty list.
        batch_size: Maximum number of images per batch. Defaults to the
            module-level ``BATCH_SIZE`` when omitted.

    Returns:
        A list of batch sizes that sums to ``num_images``. Every entry equals
        ``batch_size`` except possibly the last, which holds the remainder.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE
    # A zero step makes range() raise an unrelated-looking error and a
    # negative step silently yields no batches, so reject both up front.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
    return [
        min(batch_size, num_images - start)
        for start in range(0, num_images, batch_size)
    ]
def _render_batch(active_pipe, prompt, num_inference_steps, height, width, count):
    """Run one pipeline call and return the list of images it produced."""
    return active_pipe(
        prompt=prompt,
        num_inference_steps=num_inference_steps,
        height=height,
        width=width,
        guidance_scale=0.0,
        num_images_per_prompt=count,
    ).images


def _filter_safe_images(images):
    """Return only the images the NSFW detector labels "SAFE"."""
    detector = NSFWDetector()
    safe_images = []
    for image in images:
        _is_nsfw, category, _confidence = detector.check_image(image)
        if category == "SAFE":
            safe_images.append(image)
    return safe_images


def generate_image(
    prompt,
    num_inference_steps=1,
    num_images=1,
    height=512,
    width=512,
):
    """Generate images for a prompt, streaming status updates to the UI.

    This is a generator. Every item it yields is an ``(images, status)`` pair
    matching the ``[gallery, status]`` outputs it is wired to; ``images`` is
    ``None`` for progress-only updates.

    The "stabilityai/sdxl-turbo" pipeline is loaded onto CUDA on first use and
    kept in the module-level ``pipe``. Requests for more than ``BATCH_SIZE``
    images are rendered in several pipeline calls.

    Args:
        prompt: Text prompt passed to the pipeline.
        num_inference_steps: Number of denoising steps per image.
        num_images: Total number of images to generate; must be at least 1.
        height: Output height in pixels.
        width: Output width in pixels.

    Yields:
        ``(None, message)`` while loading/generating, then a final pair with
        the images that passed the NSFW check and a timing summary. If every
        image is flagged, the final pair is ``(None, message)``.

    NOTE(review): ``spaces`` is imported by this file but this function is not
    wrapped with a ``spaces`` GPU decorator - confirm that is intended for the
    deployment target.
    """
    global pipe

    # Callers such as UI sliders may hand numbers over as floats; range() and
    # the pipeline arguments need real ints.
    num_inference_steps = int(num_inference_steps)
    num_images = int(num_images)
    height = int(height)
    width = int(width)

    # Guard before doing any work: zero images would otherwise divide by zero
    # in the timing summary and index into an empty result list.
    if num_images < 1:
        yield None, "Number of images must be at least 1"
        return

    start_time = time.time()

    # Work through a local reference so the inactivity timer setting the
    # global to None mid-request cannot break the remaining pipeline calls.
    active_pipe = pipe
    if active_pipe is None:
        yield None, "Loading model..."
        active_pipe = AutoPipelineForText2Image.from_pretrained(
            "stabilityai/sdxl-turbo",
            torch_dtype=torch.float16,
            variant="fp16",
        ).to("cuda")
        pipe = active_pipe
        yield None, "Model loaded, starting generation..."
    reset_timer()

    if num_images > BATCH_SIZE:
        yield None, f"Generating {num_images} images in batches..."
        batches = chunk_generations(num_images)
        images = []
        for i, batch_size in enumerate(batches):
            yield None, f"Generating batch {i+1}/{len(batches)} ({batch_size} images)..."
            # Push the unload deadline out so a long multi-batch run is not
            # counted as inactivity.
            reset_timer()
            images.extend(
                _render_batch(
                    active_pipe, prompt, num_inference_steps, height, width, batch_size
                )
            )
    else:
        yield None, f"Generating {num_images} image(s) with {num_inference_steps} steps..."
        images = _render_batch(
            active_pipe, prompt, num_inference_steps, height, width, num_images
        )

    # Inactivity is measured from the end of the request, not its start.
    reset_timer()

    # The elapsed time includes model loading when it happened in this call.
    total_time = time.time() - start_time
    avg_time = total_time / num_images
    status_msg = f"Generated {num_images} image(s) in {total_time:.2f} seconds (avg {avg_time:.2f}s per image)"

    # Check every image, not just the first, and always tell the user what
    # happened instead of ending the stream silently.
    safe_images = _filter_safe_images(images)
    blocked = len(images) - len(safe_images)
    if blocked == 0:
        yield images, status_msg
    elif safe_images:
        yield safe_images, f"{status_msg}; {blocked} image(s) withheld by the NSFW filter"
    else:
        yield None, "No images shown: all generated images were flagged by the NSFW filter"
def unload_model():
    """Release the global pipeline once it has been idle long enough.

    Serves as the callback of the inactivity timer created in reset_timer().
    Nothing happens when no use has been recorded yet or when less than
    TIMEOUT_SECONDS have passed since the last recorded use.

    Returns:
        "Model unloaded due to inactivity" when the pipeline was dropped,
        otherwise None.
    """
    global pipe, last_use_time

    if not last_use_time:
        return None

    idle_for = time.time() - last_use_time
    if idle_for < TIMEOUT_SECONDS:
        return None

    # logger.info("Unloading model due to inactivity...")
    pipe = None
    # Drop the reference first, then ask torch to release its cached CUDA memory.
    torch.cuda.empty_cache()
    return "Model unloaded due to inactivity"
def reset_timer():
    """Record activity now and restart the inactivity countdown.

    Cancels any pending unload timer, stamps ``last_use_time`` with the
    current time, and schedules ``unload_model`` to run after
    ``TIMEOUT_SECONDS``.
    """
    global unload_timer, last_use_time
    if unload_timer is not None:
        unload_timer.cancel()
    last_use_time = time.time()
    unload_timer = Timer(TIMEOUT_SECONDS, unload_model)
    # Run the countdown on a daemon thread: a pending timer must not keep the
    # interpreter alive for up to TIMEOUT_SECONDS after the app shuts down.
    unload_timer.daemon = True
    unload_timer.start()
# Create the Gradio interface
# NOTE(review): the nesting of the layout blocks below is reconstructed;
# confirm the Column/Row membership against the deployed app.
with gr.Blocks() as demo:
    with gr.Column():
        # Text prompt forwarded to generate_image.
        prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
        with gr.Row():
            steps = gr.Slider(minimum=1, maximum=10, value=1, step=1, label="Number of inference steps")
            # The maximum exceeds BATCH_SIZE, so large requests take the
            # batched path in generate_image.
            num_images = gr.Slider(minimum=1, maximum=64, value=1, step=1, label="Number of images to generate")
        with gr.Row():
            height = gr.Slider(minimum=512, maximum=1024, value=512, step=64, label="Height")
            width = gr.Slider(minimum=512, maximum=1024, value=512, step=64, label="Width")
        generate_btn = gr.Button("Generate")
        gallery = gr.Gallery()
        # Read-only box that receives the status string of each yielded update.
        status = gr.Textbox(
            label="Status",
            value="Model not loaded - will load on first generation",
            interactive=False
        )
    # generate_image is a generator; each (images, status) pair it yields
    # is routed to the gallery and the status box, in that order.
    generate_btn.click(
        fn=generate_image,
        inputs=[prompt, steps, num_images, height, width],
        outputs=[gallery, status]
    )
    # Keep the "2 minutes" wording in sync with TIMEOUT_SECONDS.
    gr.Markdown("""
    This model works best with 512x512 resolution and 1-4 inference steps.
    Values above 4 steps may not improve quality significantly.
    The model will automatically unload after 2 minutes of inactivity.
    """)
# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()