# Hugging Face Space app (recovered from a page scrape).
# Space metadata: runs on ZeroGPU ("Running on Zero"); file size ~4,561 bytes;
# commits referenced: 2ac128d, 1e26c4c.
import gradio as gr
import spaces
from diffusers import AutoPipelineForText2Image
import torch
import time
# import logging
from threading import Timer
from nsfw_detector import NSFWDetector
# Module-level state shared by the generation and idle-unload logic.
pipe = None            # lazily-loaded diffusers pipeline (None = not loaded)
last_use_time = None   # wall-clock timestamp of the most recent model use
unload_timer = None    # threading.Timer that triggers the idle unload
TIMEOUT_SECONDS = 120  # idle period (2 minutes) before the model is unloaded
BATCH_SIZE = 4         # maximum images generated per pipeline call


def chunk_generations(num_images):
    """Return the per-batch image counts covering `num_images` in chunks of BATCH_SIZE."""
    full_batches, remainder = divmod(num_images, BATCH_SIZE)
    sizes = [BATCH_SIZE] * full_batches
    if remainder:
        sizes.append(remainder)
    return sizes
@spaces.GPU(duration=25)
def generate_image(
    prompt,
    num_inference_steps=1,
    num_images=1,
    height=512,
    width=512,
):
    """Generate images from a text prompt with SDXL-Turbo, streaming progress.

    Yields `(images, status_message)` tuples so the Gradio UI can show
    intermediate status; `images` is None until generation finishes (or
    the result is blocked by the content check).

    Args:
        prompt: Text prompt for the diffusion model.
        num_inference_steps: Denoising steps (SDXL-Turbo targets 1-4).
        num_images: Total number of images to generate.
        height: Output height in pixels.
        width: Output width in pixels.
    """
    global pipe
    start_time = time.time()

    # Lazy-load the pipeline on first use; the inactivity timer
    # (see unload_model) frees it again after TIMEOUT_SECONDS idle.
    if pipe is None:
        yield None, "Loading model..."
        pipe = AutoPipelineForText2Image.from_pretrained(
            "stabilityai/sdxl-turbo",
            torch_dtype=torch.float16,
            variant="fp16"
        ).to("cuda")
        yield None, "Model loaded, starting generation..."

    reset_timer()

    # Split large requests into batches of BATCH_SIZE to bound peak VRAM use.
    if num_images > BATCH_SIZE:
        yield None, f"Generating {num_images} images in batches..."
        all_images = []
        batches = chunk_generations(num_images)
        for i, batch_size in enumerate(batches):
            yield None, f"Generating batch {i+1}/{len(batches)} ({batch_size} images)..."
            batch_images = pipe(
                prompt=prompt,
                num_inference_steps=num_inference_steps,
                height=height,
                width=width,
                guidance_scale=0.0,  # SDXL-Turbo samples without classifier-free guidance
                num_images_per_prompt=batch_size
            ).images
            all_images.extend(batch_images)
        images = all_images
    else:
        yield None, f"Generating {num_images} image(s) with {num_inference_steps} steps..."
        images = pipe(
            prompt=prompt,
            num_inference_steps=num_inference_steps,
            height=height,
            width=width,
            guidance_scale=0.0,
            num_images_per_prompt=num_images
        ).images

    total_time = time.time() - start_time
    avg_time = total_time / num_images
    status_msg = f"Generated {num_images} image(s) in {total_time:.2f} seconds (avg {avg_time:.2f}s per image)"

    # Restart the inactivity window now that generation has finished, so a
    # long multi-batch run does not eat into the idle timeout.
    reset_timer()

    # Content check. NOTE(review): only the first image is inspected — in
    # multi-image runs the remaining images are returned unchecked; consider
    # running the detector on every image. TODO confirm intended policy.
    detector = NSFWDetector()
    is_nsfw, category, confidence = detector.check_image(images[0])
    if category == "SAFE":
        yield images, status_msg
    else:
        # Fix: previously this branch returned without yielding anything,
        # leaving the UI stuck on the last progress message with no
        # explanation. Surface an explicit "blocked" status instead.
        yield None, f"Generation blocked: content flagged as {category}"
def unload_model():
    """Timer callback: drop the pipeline and free VRAM once idle long enough.

    Returns a status string when the model was actually unloaded, otherwise
    None (the Timer that invokes this ignores the return value either way).
    """
    global pipe, last_use_time
    now = time.time()
    if last_use_time and now - last_use_time >= TIMEOUT_SECONDS:
        pipe = None
        torch.cuda.empty_cache()  # release cached CUDA allocations back to the driver
        return "Model unloaded due to inactivity"
def reset_timer():
    """Mark the model as just used and restart the idle-unload countdown."""
    global unload_timer, last_use_time
    last_use_time = time.time()
    # Cancel any pending unload before scheduling a fresh one.
    if unload_timer:
        unload_timer.cancel()
    replacement = Timer(TIMEOUT_SECONDS, unload_model)
    replacement.start()
    unload_timer = replacement
# ----- Gradio UI -----
with gr.Blocks() as demo:
    with gr.Column():
        # Prompt entry.
        prompt_box = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
        # Sampling controls.
        with gr.Row():
            steps_slider = gr.Slider(minimum=1, maximum=10, value=1, step=1, label="Number of inference steps")
            count_slider = gr.Slider(minimum=1, maximum=64, value=1, step=1, label="Number of images to generate")
        # Output resolution controls.
        with gr.Row():
            height_slider = gr.Slider(minimum=512, maximum=1024, value=512, step=64, label="Height")
            width_slider = gr.Slider(minimum=512, maximum=1024, value=512, step=64, label="Width")
        run_button = gr.Button("Generate")
        output_gallery = gr.Gallery()
        status_box = gr.Textbox(
            label="Status",
            value="Model not loaded - will load on first generation",
            interactive=False
        )

    # generate_image is a generator, so each yielded (images, status) pair
    # streams into the gallery and status box as it is produced.
    run_button.click(
        fn=generate_image,
        inputs=[prompt_box, steps_slider, count_slider, height_slider, width_slider],
        outputs=[output_gallery, status_box]
    )

    gr.Markdown("""
    This model works best with 512x512 resolution and 1-4 inference steps.
    Values above 4 steps may not improve quality significantly.
    The model will automatically unload after 2 minutes of inactivity.
    """)
# Script entry point. Fix: removed a stray trailing " |" scrape artifact
# after demo.launch() that made the line a syntax error.
if __name__ == "__main__":
    demo.launch()