Update mmig.py
Browse files
mmig.py
CHANGED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import time
|
3 |
+
from typing import Dict, List, Optional, Callable
|
4 |
+
|
5 |
+
class MultiModelImageGenerator:
|
6 |
+
"""
|
7 |
+
## Multi-Model Stable Diffusion Image Generation Framework
|
8 |
+
|
9 |
+
### Core Design Principles
|
10 |
+
- Flexible model loading and management
|
11 |
+
- Concurrent image generation support
|
12 |
+
- Robust error handling
|
13 |
+
- Configurable generation strategies
|
14 |
+
|
15 |
+
### Technical Components
|
16 |
+
- Dynamic model function registration
|
17 |
+
- Fallback mechanism for model loading
|
18 |
+
- Task tracking and management
|
19 |
+
"""
|
20 |
+
|
21 |
+
def __init__(
|
22 |
+
self,
|
23 |
+
models: List[str],
|
24 |
+
default_model_path: str = 'models/'
|
25 |
+
):
|
26 |
+
"""
|
27 |
+
Initialize multi-model image generation system.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
models (List[str]): List of model paths for image generation
|
31 |
+
default_model_path (str): Base path for model loading
|
32 |
+
"""
|
33 |
+
self.models = models
|
34 |
+
self.default_model_path = default_model_path
|
35 |
+
self.model_functions: Dict[int, Callable] = {}
|
36 |
+
self._initialize_models()
|
37 |
+
|
38 |
+
def _initialize_models(self):
|
39 |
+
"""
|
40 |
+
Load and initialize image generation models with fallback mechanism.
|
41 |
+
|
42 |
+
Strategy:
|
43 |
+
- Attempt to load each model
|
44 |
+
- Provide default no-op function if loading fails
|
45 |
+
"""
|
46 |
+
for model_idx, model_path in enumerate(self.models, 1):
|
47 |
+
try:
|
48 |
+
# Attempt to load model with Gradio interface
|
49 |
+
model_fn = gr.Interface.load(
|
50 |
+
f"{self.default_model_path}{model_path}",
|
51 |
+
live=False,
|
52 |
+
preprocess=True,
|
53 |
+
postprocess=False
|
54 |
+
)
|
55 |
+
self.model_functions[model_idx] = model_fn
|
56 |
+
except Exception as error:
|
57 |
+
# Fallback: Create a no-op function
|
58 |
+
def fallback_fn(txt):
|
59 |
+
return None
|
60 |
+
|
61 |
+
self.model_functions[model_idx] = gr.Interface(
|
62 |
+
fn=fallback_fn,
|
63 |
+
inputs=["text"],
|
64 |
+
outputs=["image"]
|
65 |
+
)
|
66 |
+
|
67 |
+
def generate_with_model(
|
68 |
+
self,
|
69 |
+
model_idx: int,
|
70 |
+
prompt: str
|
71 |
+
) -> Optional[gr.Image]:
|
72 |
+
"""
|
73 |
+
Generate image using specified model with intelligent fallback.
|
74 |
+
|
75 |
+
Args:
|
76 |
+
model_idx (int): Index of model to use
|
77 |
+
prompt (str): Generation prompt
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
Generated image or None if generation fails
|
81 |
+
"""
|
82 |
+
# Use specified model, fallback to first model if not available
|
83 |
+
selected_model = (
|
84 |
+
self.model_functions.get(str(model_idx)) or
|
85 |
+
self.model_functions.get(str(1))
|
86 |
+
)
|
87 |
+
|
88 |
+
return selected_model(prompt)
|
89 |
+
|
90 |
+
def create_gradio_interface(self) -> gr.Blocks:
|
91 |
+
"""
|
92 |
+
Create Gradio interface for multi-model image generation.
|
93 |
+
|
94 |
+
Returns:
|
95 |
+
Configurable Gradio Blocks interface
|
96 |
+
"""
|
97 |
+
with gr.Blocks(title="Multi-Model Stable Diffusion", theme="Nymbo/Nymbo_Theme") as interface:
|
98 |
+
with gr.Column(scale=12):
|
99 |
+
with gr.Row():
|
100 |
+
primary_prompt = gr.Textbox(label="Generation Prompt", value="")
|
101 |
+
|
102 |
+
with gr.Row():
|
103 |
+
run_btn = gr.Button("Generate", variant="primary")
|
104 |
+
clear_btn = gr.Button("Clear")
|
105 |
+
|
106 |
+
# Dynamic output image grid
|
107 |
+
sd_outputs = {}
|
108 |
+
for model_idx, model_path in enumerate(self.models, 1):
|
109 |
+
with gr.Column(scale=3, min_width=320):
|
110 |
+
with gr.Box():
|
111 |
+
sd_outputs[model_idx] = gr.Image(label=model_path)
|
112 |
+
|
113 |
+
# Task tracking components
|
114 |
+
with gr.Row(visible=False):
|
115 |
+
start_box = gr.Number(interactive=False)
|
116 |
+
end_box = gr.Number(interactive=False)
|
117 |
+
task_status_box = gr.Textbox(value=0, interactive=False)
|
118 |
+
|
119 |
+
# Event bindings
|
120 |
+
def start_task():
|
121 |
+
t_stamp = time.time()
|
122 |
+
return (
|
123 |
+
gr.update(value=t_stamp),
|
124 |
+
gr.update(value=t_stamp),
|
125 |
+
gr.update(value=0)
|
126 |
+
)
|
127 |
+
|
128 |
+
def check_task_status(cnt, t_stamp):
|
129 |
+
current_time = time.time()
|
130 |
+
timeout = t_stamp + 60
|
131 |
+
|
132 |
+
if current_time > timeout and t_stamp != 0:
|
133 |
+
return gr.update(value=0), gr.update(value=1)
|
134 |
+
else:
|
135 |
+
return (
|
136 |
+
gr.update(value=current_time if cnt != 0 else 0),
|
137 |
+
gr.update(value=0)
|
138 |
+
)
|
139 |
+
|
140 |
+
def clear_interface():
|
141 |
+
return tuple([None] + [None] * len(self.models))
|
142 |
+
|
143 |
+
# Task management events
|
144 |
+
start_box.change(
|
145 |
+
check_task_status,
|
146 |
+
[start_box, end_box],
|
147 |
+
[start_box, task_status_box],
|
148 |
+
every=1,
|
149 |
+
show_progress=False
|
150 |
+
)
|
151 |
+
|
152 |
+
primary_prompt.submit(start_task, None, [start_box, end_box, task_status_box])
|
153 |
+
run_btn.click(start_task, None, [start_box, end_box, task_status_box])
|
154 |
+
|
155 |
+
# Dynamic model generation events
|
156 |
+
generation_tasks = {}
|
157 |
+
for model_idx, model_path in enumerate(self.models, 1):
|
158 |
+
generation_tasks[model_idx] = run_btn.click(
|
159 |
+
self.generate_with_model,
|
160 |
+
inputs=[gr.Number(model_idx), primary_prompt],
|
161 |
+
outputs=[sd_outputs[model_idx]]
|
162 |
+
)
|
163 |
+
|
164 |
+
# Clear button handler
|
165 |
+
clear_btn.click(
|
166 |
+
clear_interface,
|
167 |
+
None,
|
168 |
+
[primary_prompt, *list(sd_outputs.values())],
|
169 |
+
cancels=list(generation_tasks.values())
|
170 |
+
)
|
171 |
+
|
172 |
+
return interface
|
173 |
+
|
174 |
+
def launch(self, **kwargs):
|
175 |
+
"""
|
176 |
+
Launch Gradio interface with configurable parameters.
|
177 |
+
|
178 |
+
Args:
|
179 |
+
**kwargs: Gradio launch configuration parameters
|
180 |
+
"""
|
181 |
+
interface = self.create_gradio_interface()
|
182 |
+
interface.queue(concurrency_count=600, status_update_rate=0.1)
|
183 |
+
interface.launch(**kwargs)
|
184 |
+
|
185 |
+
def main():
|
186 |
+
"""
|
187 |
+
Demonstration of Multi-Model Image Generation Framework
|
188 |
+
"""
|
189 |
+
models = [
|
190 |
+
"doohickey/neopian-diffusion",
|
191 |
+
"dxli/duck_toy",
|
192 |
+
"dxli/bear_plushie",
|
193 |
+
"haor/Evt_V4-preview",
|
194 |
+
"Yntec/Dreamscapes_n_Dragonfire_v2"
|
195 |
+
]
|
196 |
+
|
197 |
+
image_generator = MultiModelImageGenerator(models)
|
198 |
+
image_generator.launch(inline=True, show_api=False)
|
199 |
+
|
200 |
+
if __name__ == "__main__":
|
201 |
+
main()
|