"""Gradio demo for fast text-to-image generation with SDXL-Turbo.

The pipeline is loaded lazily on the first request and released again after
``TIMEOUT_SECONDS`` of inactivity. Every generated image is screened with
``NSFWDetector`` before it is shown.
"""

import gc
import logging
import time
from threading import Timer

import gradio as gr
import spaces
import torch
from diffusers import AutoPipelineForText2Image
from nsfw_detector import NSFWDetector

logger = logging.getLogger(__name__)

# Global state shared between request handlers and the unload timer thread.
pipe = None  # loaded diffusers pipeline, or None while unloaded
last_use_time = None  # time.time() of the most recent generation request
unload_timer = None  # pending threading.Timer that will call unload_model
detector = None  # lazily created NSFWDetector, reused across requests

TIMEOUT_SECONDS = 120  # 2 minutes
BATCH_SIZE = 4
MODEL_ID = "stabilityai/sdxl-turbo"


def chunk_generations(num_images):
    """Split ``num_images`` into consecutive batch sizes of at most ``BATCH_SIZE``.

    For example, with ``BATCH_SIZE == 4``, ``chunk_generations(10)`` returns
    ``[4, 4, 2]``; ``chunk_generations(0)`` returns ``[]``.
    """
    return [min(BATCH_SIZE, num_images - i) for i in range(0, num_images, BATCH_SIZE)]


def _get_detector():
    """Return the shared NSFWDetector, creating it on first use."""
    global detector
    if detector is None:
        detector = NSFWDetector()
    return detector


def _is_safe(image):
    """Return True when the NSFW detector classifies ``image`` as "SAFE"."""
    _is_nsfw, category, _confidence = _get_detector().check_image(image)
    return category == "SAFE"


@spaces.GPU(duration=25)
def generate_image(
    prompt,
    num_inference_steps=1,
    num_images=1,
    height=512,
    width=512,
):
    """Generate images for ``prompt`` while streaming status updates to the UI.

    This is a generator: each ``yield`` is a ``(gallery_images, status_text)``
    pair matching the Gradio outputs ``[gallery, status]``.

    Args:
        prompt: Text prompt passed to the pipeline.
        num_inference_steps: Number of denoising steps.
        num_images: Total number of images. When it exceeds ``BATCH_SIZE``
            the work is split into batches via ``chunk_generations``.
        height: Output height in pixels.
        width: Output width in pixels.
    """
    global pipe
    start_time = time.time()

    # Slider values may arrive as floats; the pipeline and range() need ints.
    num_inference_steps = int(num_inference_steps)
    num_images = int(num_images)
    height = int(height)
    width = int(width)
    if num_images < 1:
        yield None, "Number of images must be at least 1"
        return

    # Work with a local reference: unload_model() runs on a timer thread and
    # may reset the global ``pipe`` to None while this request is generating.
    model = pipe
    if model is None:
        yield None, "Loading model..."
        model = AutoPipelineForText2Image.from_pretrained(
            MODEL_ID, torch_dtype=torch.float16, variant="fp16"
        ).to("cuda")
        yield None, "Model loaded, starting generation..."
    pipe = model
    # Record this request as activity and restart the inactivity countdown.
    reset_timer()

    batched = num_images > BATCH_SIZE
    batches = chunk_generations(num_images)
    if batched:
        yield None, f"Generating {num_images} images in batches..."

    images = []
    for index, batch_size in enumerate(batches, start=1):
        if batched:
            yield None, f"Generating batch {index}/{len(batches)} ({batch_size} images)..."
        else:
            yield None, f"Generating {num_images} image(s) with {num_inference_steps} steps..."
        images.extend(
            model(
                prompt=prompt,
                num_inference_steps=num_inference_steps,
                height=height,
                width=width,
                guidance_scale=0.0,
                num_images_per_prompt=batch_size,
            ).images
        )

    total_time = time.time() - start_time
    avg_time = total_time / num_images
    status_msg = (
        f"Generated {num_images} image(s) in {total_time:.2f} seconds "
        f"(avg {avg_time:.2f}s per image)"
    )
    logger.info("%s", status_msg)

    # Screen every image, not only the first one, and report what was
    # withheld instead of ending the stream with a stale status message.
    safe_images = [image for image in images if _is_safe(image)]
    if not safe_images:
        yield None, "All generated images were flagged as NSFW and withheld."
        return
    withheld = len(images) - len(safe_images)
    if withheld:
        status_msg += f" - {withheld} image(s) withheld as NSFW"
    yield safe_images, status_msg


def unload_model():
    """Release the pipeline if it has been idle for ``TIMEOUT_SECONDS``.

    Scheduled on a ``threading.Timer`` by ``reset_timer``.

    Returns:
        A status string when the model was unloaded, otherwise ``None``.
    """
    global pipe
    if last_use_time is None or time.time() - last_use_time < TIMEOUT_SECONDS:
        return None
    logger.info("Unloading model due to inactivity...")
    pipe = None
    # Collect the dropped pipeline first so CUDA can actually free its blocks.
    gc.collect()
    torch.cuda.empty_cache()
    return "Model unloaded due to inactivity"


def reset_timer():
    """Record activity now and (re)schedule ``unload_model`` after the timeout."""
    global unload_timer, last_use_time
    if unload_timer is not None:
        unload_timer.cancel()
    last_use_time = time.time()
    unload_timer = Timer(TIMEOUT_SECONDS, unload_model)
    # Daemon thread: a pending unload must not keep the process alive on exit.
    unload_timer.daemon = True
    unload_timer.start()


# Create the Gradio interface
with gr.Blocks() as demo:
    with gr.Column():
        prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
        with gr.Row():
            steps = gr.Slider(
                minimum=1, maximum=10, value=1, step=1,
                label="Number of inference steps",
            )
            num_images = gr.Slider(
                minimum=1, maximum=64, value=1, step=1,
                label="Number of images to generate",
            )
        with gr.Row():
            height = gr.Slider(minimum=512, maximum=1024, value=512, step=64, label="Height")
            width = gr.Slider(minimum=512, maximum=1024, value=512, step=64, label="Width")
        generate_btn = gr.Button("Generate")
        gallery = gr.Gallery()
        status = gr.Textbox(
            label="Status",
            value="Model not loaded - will load on first generation",
            interactive=False,
        )

    generate_btn.click(
        fn=generate_image,
        inputs=[prompt, steps, num_images, height, width],
        outputs=[gallery, status],
    )

    gr.Markdown(
        """
This model works best with 512x512 resolution and 1-4 inference steps.
Values above 4 steps may not improve quality significantly.
The model will automatically unload after 2 minutes of inactivity.
"""
    )

if __name__ == "__main__":
    demo.launch()