import gradio as gr
import spaces
from diffusers import AutoPipelineForText2Image
import torch
import time
# import logging
from threading import Timer
from nsfw_detector import NSFWDetector
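# The nsfw_detector module is local to this Space and not shown here.
# A minimal sketch of the interface this script relies on (hypothetical
# implementation, e.g. wrapping a transformers image-classification pipeline):
#
#     class NSFWDetector:
#         def check_image(self, image):
#             """Return (is_nsfw: bool, category: str, confidence: float);
#             category is "SAFE" for unflagged images."""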

# logging.basicConfig(level=logging.INFO)
# logger = logging.getLogger(__name__)

# Global variables
pipe = None
last_use_time = None
unload_timer = None
TIMEOUT_SECONDS = 120  # 2 minutes
BATCH_SIZE = 4

def chunk_generations(num_images):
    """Split num_images into batch sizes of at most BATCH_SIZE (the last batch may be smaller)."""
    return [min(BATCH_SIZE, num_images - i) for i in range(0, num_images, BATCH_SIZE)]
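# e.g. chunk_generations(10) -> [4, 4, 2] with BATCH_SIZE = 4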

@spaces.GPU(duration=25)  # ZeroGPU: request a GPU slot for up to 25s per call
def generate_image(
    prompt,
    num_inference_steps=1,
    num_images=1,
    height=512,
    width=512,
):
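    """Generate images with SDXL-Turbo, streaming (gallery, status) updates.

    Implemented as a generator so Gradio can surface intermediate status
    messages (model loading, per-batch progress) before the final images.
    """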
    global pipe
    start_time = time.time()
    
    # Load model if needed
    if pipe is None:
        yield None, "Loading model..."
        pipe = AutoPipelineForText2Image.from_pretrained(
            "stabilityai/sdxl-turbo", 
            torch_dtype=torch.float16, 
            variant="fp16"
        ).to("cuda")
        yield None, "Model loaded, starting generation..."
    
    reset_timer()
    
    # Process in batches if more than BATCH_SIZE images
    if num_images > BATCH_SIZE:
        yield None, f"Generating {num_images} images in batches..."
        all_images = []
        batches = chunk_generations(num_images)
        
        for i, batch_size in enumerate(batches):
            yield None, f"Generating batch {i+1}/{len(batches)} ({batch_size} images)..."
            batch_images = pipe(
                prompt=prompt,
                num_inference_steps=num_inference_steps,
                height=height,
                width=width,
                guidance_scale=0.0,  # SDXL-Turbo is distilled to run without classifier-free guidance
                num_images_per_prompt=batch_size
            ).images
            all_images.extend(batch_images)
        images = all_images
    else:
        yield None, f"Generating {num_images} image(s) with {num_inference_steps} steps..."
        images = pipe(
            prompt=prompt,
            num_inference_steps=num_inference_steps,
            height=height,
            width=width,
            guidance_scale=0.0,
            num_images_per_prompt=num_images
        ).images
    
    total_time = time.time() - start_time
    avg_time = total_time / num_images
    status_msg = f"Generated {num_images} image(s) in {total_time:.2f} seconds (avg {avg_time:.2f}s per image)"
    # logger.info(status_msg)
    
    # Check for NSFW content; only the first image is sampled as a fast heuristic
    detector = NSFWDetector()
    is_nsfw, category, confidence = detector.check_image(images[0])

    if is_nsfw:
        # Withhold the batch and report why, rather than returning silently
        yield None, f"Images withheld: content flagged as {category} ({confidence:.2f} confidence)"
    else:
        yield images, status_msg

def unload_model():
    """Timer callback: free the pipeline after TIMEOUT_SECONDS of inactivity."""
    global pipe, last_use_time
    current_time = time.time()
    if last_use_time and (current_time - last_use_time) >= TIMEOUT_SECONDS:
        # logger.info("Unloading model due to inactivity...")
        pipe = None
        torch.cuda.empty_cache()

def reset_timer():
    """Restart the inactivity countdown; called at the start of each generation."""
    global unload_timer, last_use_time
    if unload_timer:
        unload_timer.cancel()
    last_use_time = time.time()
    unload_timer = Timer(TIMEOUT_SECONDS, unload_model)
    unload_timer.daemon = True  # don't let a pending timer block interpreter shutdown
    unload_timer.start()

# Create the Gradio interface
with gr.Blocks() as demo:
    with gr.Column():
        prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
        with gr.Row():
            steps = gr.Slider(minimum=1, maximum=10, value=1, step=1, label="Number of inference steps")
            num_images = gr.Slider(minimum=1, maximum=64, value=1, step=1, label="Number of images to generate")
        with gr.Row():
            height = gr.Slider(minimum=512, maximum=1024, value=512, step=64, label="Height")
            width = gr.Slider(minimum=512, maximum=1024, value=512, step=64, label="Width")
        
        generate_btn = gr.Button("Generate")
        gallery = gr.Gallery()
        status = gr.Textbox(
            label="Status", 
            value="Model not loaded - will load on first generation", 
            interactive=False
        )
        
        generate_btn.click(
            fn=generate_image,
            inputs=[prompt, steps, num_images, height, width],
            outputs=[gallery, status]
        )

    gr.Markdown("""
        SDXL-Turbo works best at 512x512 resolution with 1-4 inference steps;
        values above 4 steps rarely improve quality.
        The model unloads automatically after 2 minutes of inactivity.
    """)

if __name__ == "__main__":
    demo.queue()  # generator outputs stream through the queue (enabled by default in Gradio 4+)
    demo.launch()