VamooseBambel commited on
Commit
2ac128d
·
verified ·
1 Parent(s): c1063d8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +139 -0
app.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import spaces
3
+ from diffusers import AutoPipelineForText2Image
4
+ import torch
5
+ import time
6
+ # import logging
7
+ from threading import Timer
8
+ from nsfw_detector import NSFWDetector
9
+
10
+ # logging.basicConfig(level=logging.INFO)
11
+ # logger = logging.getLogger(__name__)
12
+
13
+ # Global variables
14
+ pipe = None
15
+ last_use_time = None
16
+ unload_timer = None
17
+ TIMEOUT_SECONDS = 120 # 2 minutes
18
+ BATCH_SIZE = 4
19
+
20
+ def chunk_generations(num_images):
21
+ """Split number of images into batches of BATCH_SIZE"""
22
+ return [min(BATCH_SIZE, num_images - i) for i in range(0, num_images, BATCH_SIZE)]
23
+
24
+ @spaces.GPU(duration=20)
25
+ def generate_image(
26
+ prompt,
27
+ num_inference_steps=1,
28
+ num_images=1,
29
+ height=512,
30
+ width=512,
31
+ ):
32
+ global pipe
33
+ start_time = time.time()
34
+
35
+ # Load model if needed
36
+ if pipe is None:
37
+ yield None, "Loading model..."
38
+ pipe = AutoPipelineForText2Image.from_pretrained(
39
+ "stabilityai/sdxl-turbo",
40
+ torch_dtype=torch.float16,
41
+ variant="fp16"
42
+ ).to("cuda")
43
+ yield None, "Model loaded, starting generation..."
44
+
45
+ reset_timer()
46
+
47
+ # Process in batches if more than BATCH_SIZE images
48
+ if num_images > BATCH_SIZE:
49
+ yield None, f"Generating {num_images} images in batches..."
50
+ all_images = []
51
+ batches = chunk_generations(num_images)
52
+
53
+ for i, batch_size in enumerate(batches):
54
+ yield None, f"Generating batch {i+1}/{len(batches)} ({batch_size} images)..."
55
+ batch_images = pipe(
56
+ prompt=prompt,
57
+ num_inference_steps=num_inference_steps,
58
+ height=height,
59
+ width=width,
60
+ guidance_scale=0.0,
61
+ num_images_per_prompt=batch_size
62
+ ).images
63
+ all_images.extend(batch_images)
64
+ images = all_images
65
+ else:
66
+ yield None, f"Generating {num_images} image(s) with {num_inference_steps} steps..."
67
+ images = pipe(
68
+ prompt=prompt,
69
+ num_inference_steps=num_inference_steps,
70
+ height=height,
71
+ width=width,
72
+ guidance_scale=0.0,
73
+ num_images_per_prompt=num_images
74
+ ).images
75
+
76
+ total_time = time.time() - start_time
77
+ avg_time = total_time / num_images
78
+ status_msg = f"Generated {num_images} image(s) in {total_time:.2f} seconds (avg {avg_time:.2f}s per image)"
79
+ # logger.info(status_msg)
80
+
81
+ # Check for NSFW content
82
+ detector = NSFWDetector()
83
+ is_nsfw, category, confidence = detector.check_image(images[0])
84
+
85
+ if category == "SAFE":
86
+ yield images, status_msg
87
+ else:
88
+ return
89
+
90
+ def unload_model():
91
+ global pipe, last_use_time
92
+ current_time = time.time()
93
+ if last_use_time and (current_time - last_use_time) >= TIMEOUT_SECONDS:
94
+ # logger.info("Unloading model due to inactivity...")
95
+ pipe = None
96
+ torch.cuda.empty_cache()
97
+ return "Model unloaded due to inactivity"
98
+
99
+ def reset_timer():
100
+ global unload_timer, last_use_time
101
+ if unload_timer:
102
+ unload_timer.cancel()
103
+ last_use_time = time.time()
104
+ unload_timer = Timer(TIMEOUT_SECONDS, unload_model)
105
+ unload_timer.start()
106
+
107
+ # Create the Gradio interface
108
+ with gr.Blocks() as demo:
109
+ with gr.Column():
110
+ prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
111
+ with gr.Row():
112
+ steps = gr.Slider(minimum=1, maximum=10, value=1, step=1, label="Number of inference steps")
113
+ num_images = gr.Slider(minimum=1, maximum=64, value=1, step=1, label="Number of images to generate")
114
+ with gr.Row():
115
+ height = gr.Slider(minimum=512, maximum=1024, value=512, step=64, label="Height")
116
+ width = gr.Slider(minimum=512, maximum=1024, value=512, step=64, label="Width")
117
+
118
+ generate_btn = gr.Button("Generate")
119
+ gallery = gr.Gallery()
120
+ status = gr.Textbox(
121
+ label="Status",
122
+ value="Model not loaded - will load on first generation",
123
+ interactive=False
124
+ )
125
+
126
+ generate_btn.click(
127
+ fn=generate_image,
128
+ inputs=[prompt, steps, num_images, height, width],
129
+ outputs=[gallery, status]
130
+ )
131
+
132
+ gr.Markdown("""
133
+ This model works best with 512x512 resolution and 1-4 inference steps.
134
+ Values above 4 steps may not improve quality significantly.
135
+ The model will automatically unload after 2 minutes of inactivity.
136
+ """)
137
+
138
+ if __name__ == "__main__":
139
+ demo.launch()