K00B404 commited on
Commit
8125207
·
verified ·
1 Parent(s): 4a79c4b

Update mmig.py

Browse files
Files changed (1) hide show
  1. mmig.py +201 -0
mmig.py CHANGED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import time
3
+ from typing import Dict, List, Optional, Callable
4
+
5
+ class MultiModelImageGenerator:
6
+ """
7
+ ## Multi-Model Stable Diffusion Image Generation Framework
8
+
9
+ ### Core Design Principles
10
+ - Flexible model loading and management
11
+ - Concurrent image generation support
12
+ - Robust error handling
13
+ - Configurable generation strategies
14
+
15
+ ### Technical Components
16
+ - Dynamic model function registration
17
+ - Fallback mechanism for model loading
18
+ - Task tracking and management
19
+ """
20
+
21
+ def __init__(
22
+ self,
23
+ models: List[str],
24
+ default_model_path: str = 'models/'
25
+ ):
26
+ """
27
+ Initialize multi-model image generation system.
28
+
29
+ Args:
30
+ models (List[str]): List of model paths for image generation
31
+ default_model_path (str): Base path for model loading
32
+ """
33
+ self.models = models
34
+ self.default_model_path = default_model_path
35
+ self.model_functions: Dict[int, Callable] = {}
36
+ self._initialize_models()
37
+
38
+ def _initialize_models(self):
39
+ """
40
+ Load and initialize image generation models with fallback mechanism.
41
+
42
+ Strategy:
43
+ - Attempt to load each model
44
+ - Provide default no-op function if loading fails
45
+ """
46
+ for model_idx, model_path in enumerate(self.models, 1):
47
+ try:
48
+ # Attempt to load model with Gradio interface
49
+ model_fn = gr.Interface.load(
50
+ f"{self.default_model_path}{model_path}",
51
+ live=False,
52
+ preprocess=True,
53
+ postprocess=False
54
+ )
55
+ self.model_functions[model_idx] = model_fn
56
+ except Exception as error:
57
+ # Fallback: Create a no-op function
58
+ def fallback_fn(txt):
59
+ return None
60
+
61
+ self.model_functions[model_idx] = gr.Interface(
62
+ fn=fallback_fn,
63
+ inputs=["text"],
64
+ outputs=["image"]
65
+ )
66
+
67
+ def generate_with_model(
68
+ self,
69
+ model_idx: int,
70
+ prompt: str
71
+ ) -> Optional[gr.Image]:
72
+ """
73
+ Generate image using specified model with intelligent fallback.
74
+
75
+ Args:
76
+ model_idx (int): Index of model to use
77
+ prompt (str): Generation prompt
78
+
79
+ Returns:
80
+ Generated image or None if generation fails
81
+ """
82
+ # Use specified model, fallback to first model if not available
83
+ selected_model = (
84
+ self.model_functions.get(str(model_idx)) or
85
+ self.model_functions.get(str(1))
86
+ )
87
+
88
+ return selected_model(prompt)
89
+
90
+ def create_gradio_interface(self) -> gr.Blocks:
91
+ """
92
+ Create Gradio interface for multi-model image generation.
93
+
94
+ Returns:
95
+ Configurable Gradio Blocks interface
96
+ """
97
+ with gr.Blocks(title="Multi-Model Stable Diffusion", theme="Nymbo/Nymbo_Theme") as interface:
98
+ with gr.Column(scale=12):
99
+ with gr.Row():
100
+ primary_prompt = gr.Textbox(label="Generation Prompt", value="")
101
+
102
+ with gr.Row():
103
+ run_btn = gr.Button("Generate", variant="primary")
104
+ clear_btn = gr.Button("Clear")
105
+
106
+ # Dynamic output image grid
107
+ sd_outputs = {}
108
+ for model_idx, model_path in enumerate(self.models, 1):
109
+ with gr.Column(scale=3, min_width=320):
110
+ with gr.Box():
111
+ sd_outputs[model_idx] = gr.Image(label=model_path)
112
+
113
+ # Task tracking components
114
+ with gr.Row(visible=False):
115
+ start_box = gr.Number(interactive=False)
116
+ end_box = gr.Number(interactive=False)
117
+ task_status_box = gr.Textbox(value=0, interactive=False)
118
+
119
+ # Event bindings
120
+ def start_task():
121
+ t_stamp = time.time()
122
+ return (
123
+ gr.update(value=t_stamp),
124
+ gr.update(value=t_stamp),
125
+ gr.update(value=0)
126
+ )
127
+
128
+ def check_task_status(cnt, t_stamp):
129
+ current_time = time.time()
130
+ timeout = t_stamp + 60
131
+
132
+ if current_time > timeout and t_stamp != 0:
133
+ return gr.update(value=0), gr.update(value=1)
134
+ else:
135
+ return (
136
+ gr.update(value=current_time if cnt != 0 else 0),
137
+ gr.update(value=0)
138
+ )
139
+
140
+ def clear_interface():
141
+ return tuple([None] + [None] * len(self.models))
142
+
143
+ # Task management events
144
+ start_box.change(
145
+ check_task_status,
146
+ [start_box, end_box],
147
+ [start_box, task_status_box],
148
+ every=1,
149
+ show_progress=False
150
+ )
151
+
152
+ primary_prompt.submit(start_task, None, [start_box, end_box, task_status_box])
153
+ run_btn.click(start_task, None, [start_box, end_box, task_status_box])
154
+
155
+ # Dynamic model generation events
156
+ generation_tasks = {}
157
+ for model_idx, model_path in enumerate(self.models, 1):
158
+ generation_tasks[model_idx] = run_btn.click(
159
+ self.generate_with_model,
160
+ inputs=[gr.Number(model_idx), primary_prompt],
161
+ outputs=[sd_outputs[model_idx]]
162
+ )
163
+
164
+ # Clear button handler
165
+ clear_btn.click(
166
+ clear_interface,
167
+ None,
168
+ [primary_prompt, *list(sd_outputs.values())],
169
+ cancels=list(generation_tasks.values())
170
+ )
171
+
172
+ return interface
173
+
174
+ def launch(self, **kwargs):
175
+ """
176
+ Launch Gradio interface with configurable parameters.
177
+
178
+ Args:
179
+ **kwargs: Gradio launch configuration parameters
180
+ """
181
+ interface = self.create_gradio_interface()
182
+ interface.queue(concurrency_count=600, status_update_rate=0.1)
183
+ interface.launch(**kwargs)
184
+
185
+ def main():
186
+ """
187
+ Demonstration of Multi-Model Image Generation Framework
188
+ """
189
+ models = [
190
+ "doohickey/neopian-diffusion",
191
+ "dxli/duck_toy",
192
+ "dxli/bear_plushie",
193
+ "haor/Evt_V4-preview",
194
+ "Yntec/Dreamscapes_n_Dragonfire_v2"
195
+ ]
196
+
197
+ image_generator = MultiModelImageGenerator(models)
198
+ image_generator.launch(inline=True, show_api=False)
199
+
200
+ if __name__ == "__main__":
201
+ main()