|
import logging
import time
from typing import Dict, List, Optional, Callable

import gradio as gr
|
|
|
class MultiModelImageGenerator:
    """Run one prompt through several text-to-image models in a Gradio app.

    Every entry of ``models`` is loaded once, at construction time, and
    registered under a 1-based integer index in ``model_functions``. A model
    that cannot be loaded is replaced by a no-op interface so the remaining
    models stay usable. The generated UI has one prompt box, a "Generate"
    and a "Clear" button, and one output image per model.
    """

    def __init__(
        self,
        models: List[str],
        default_model_path: str = 'models/'
    ):
        """Store the configuration and load every model.

        Args:
            models: Model identifiers; each one is appended to
                ``default_model_path`` to build the name handed to
                ``gr.Interface.load``.
            default_model_path: Prefix prepended to every model identifier.
        """
        self.models = models
        self.default_model_path = default_model_path
        # Maps the 1-based position of a model in ``models`` to the callable
        # that turns a prompt into that model's output.
        self.model_functions: Dict[int, Callable] = {}
        self._initialize_models()

    def _initialize_models(self):
        """Load each configured model, substituting a no-op on failure.

        A model whose loading raises is logged (with traceback) and replaced
        by a ``gr.Interface`` whose function always returns ``None``, so one
        bad model does not prevent the others from being registered.
        """
        for model_idx, model_path in enumerate(self.models, 1):
            model_name = f"{self.default_model_path}{model_path}"
            try:
                model_fn = gr.Interface.load(
                    model_name,
                    live=False,
                    preprocess=True,
                    postprocess=False,
                )
            except Exception:
                # Loading is best-effort, but the failure must be visible to
                # whoever operates the app instead of vanishing silently.
                logging.getLogger(__name__).warning(
                    "Could not load model %r; using a no-op fallback.",
                    model_name,
                    exc_info=True,
                )

                def fallback_fn(txt):
                    return None

                model_fn = gr.Interface(
                    fn=fallback_fn,
                    inputs=["text"],
                    outputs=["image"],
                )
            self.model_functions[model_idx] = model_fn

    def generate_with_model(
        self,
        model_idx: int,
        prompt: str
    ) -> "Optional[gr.Image]":
        """Generate an image with the selected model, falling back to model 1.

        Args:
            model_idx: 1-based index of the model to use. Floats such as
                ``2.0`` (what a numeric UI component delivers) are accepted;
                a value that cannot be converted selects model 1.
            prompt: Generation prompt passed to the model unchanged.

        Returns:
            Whatever the selected model callable returns, or ``None`` when no
            model is registered or the model call raises.
        """
        log = logging.getLogger(__name__)

        # The registry is keyed by int, so normalise the index before lookup.
        try:
            key = int(model_idx)
        except (TypeError, ValueError, OverflowError):
            key = 1

        # Explicit ``is None`` checks: the truthiness of a registered
        # interface object must not decide whether it gets used.
        selected_model = self.model_functions.get(key)
        if selected_model is None:
            selected_model = self.model_functions.get(1)
        if selected_model is None:
            log.warning("No model registered for index %r.", model_idx)
            return None

        try:
            return selected_model(prompt)
        except Exception:
            # Honour the documented contract: a failed generation yields None
            # for this model; the traceback is kept in the log.
            log.exception("Image generation failed for model index %r.", key)
            return None

    def _bind_model(self, model_idx: int) -> Callable:
        """Return a one-argument callable that runs ``prompt`` on ``model_idx``."""
        def generate(prompt):
            return self.generate_with_model(model_idx, prompt)

        return generate

    def create_gradio_interface(self) -> "gr.Blocks":
        """Build the Blocks UI: prompt, buttons, and one image per model.

        NOTE(review): the original indentation of this method was lost; the
        nesting below (outputs and hidden row as direct children of the
        Blocks) is a reconstruction and should be confirmed.

        Returns:
            The assembled ``gr.Blocks`` instance (not yet launched).
        """
        with gr.Blocks(
            title="Multi-Model Stable Diffusion",
            theme="Nymbo/Nymbo_Theme",
        ) as interface:
            with gr.Column(scale=12):
                with gr.Row():
                    primary_prompt = gr.Textbox(label="Generation Prompt", value="")

                with gr.Row():
                    run_btn = gr.Button("Generate", variant="primary")
                    clear_btn = gr.Button("Clear")

            # One output image per model, keyed by the same 1-based index
            # used in ``model_functions``.
            sd_outputs = {}
            for model_idx, model_path in enumerate(self.models, 1):
                with gr.Column(scale=3, min_width=320):
                    with gr.Box():
                        sd_outputs[model_idx] = gr.Image(label=model_path)

            # Hidden bookkeeping components used by the task timer below.
            with gr.Row(visible=False):
                start_box = gr.Number(interactive=False)
                end_box = gr.Number(interactive=False)
                task_status_box = gr.Textbox(value=0, interactive=False)

            def start_task():
                """Stamp both timer boxes with "now" and reset the status."""
                t_stamp = time.time()
                return (
                    gr.update(value=t_stamp),
                    gr.update(value=t_stamp),
                    gr.update(value=0)
                )

            def check_task_status(cnt, t_stamp):
                """Advance the timer, or flag a timeout after 60 seconds.

                ``cnt`` is the current ``start_box`` value and ``t_stamp`` the
                ``end_box`` value written by ``start_task``. Returns updates
                for ``(start_box, task_status_box)``.
                """
                current_time = time.time()
                timeout = t_stamp + 60

                if current_time > timeout and t_stamp != 0:
                    # Timed out: stop the timer and set the status to 1.
                    return gr.update(value=0), gr.update(value=1)
                else:
                    return (
                        gr.update(value=current_time if cnt != 0 else 0),
                        gr.update(value=0)
                    )

            def clear_interface():
                """Return ``None`` for the prompt and for every output image."""
                return tuple([None] + [None] * len(self.models))

            start_box.change(
                check_task_status,
                [start_box, end_box],
                [start_box, task_status_box],
                every=1,
                show_progress=False
            )

            primary_prompt.submit(start_task, None, [start_box, end_box, task_status_box])
            run_btn.click(start_task, None, [start_box, end_box, task_status_box])

            # One click handler per model. The model index is bound in a
            # closure rather than passed through a throwaway ``gr.Number``,
            # which would add an extra component to the page for every model.
            generation_tasks = {}
            for model_idx, output_image in sd_outputs.items():
                generation_tasks[model_idx] = run_btn.click(
                    self._bind_model(model_idx),
                    inputs=[primary_prompt],
                    outputs=[output_image]
                )

            # "Clear" empties the prompt and all images and cancels any
            # generation that is still running.
            clear_btn.click(
                clear_interface,
                None,
                [primary_prompt, *list(sd_outputs.values())],
                cancels=list(generation_tasks.values())
            )

        return interface

    def launch(self, **kwargs):
        """Build the interface, enable its queue, and launch it.

        Args:
            **kwargs: Forwarded unchanged to ``gr.Blocks.launch``.
        """
        interface = self.create_gradio_interface()
        interface.queue(concurrency_count=600, status_update_rate=0.1)
        interface.launch(**kwargs)
|
|
|
def main():
    """Launch the demo app with a fixed set of five example models."""
    demo_models = [
        "doohickey/neopian-diffusion",
        "dxli/duck_toy",
        "dxli/bear_plushie",
        "haor/Evt_V4-preview",
        "Yntec/Dreamscapes_n_Dragonfire_v2",
    ]

    # Constructing the generator loads every model; launch() then serves the UI.
    MultiModelImageGenerator(demo_models).launch(inline=True, show_api=False)


# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()