import os
import random
import sys
from typing import Sequence, Mapping, Any, Union, Optional
import torch
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download
import spaces
from comfy import model_management  # We need to import this early
import gc
import requests
import re
import hashlib
import shutil


# --- Startup Dummy Function ---
@spaces.GPU(duration=60)
def dummy_gpu_for_startup():
    """Trivial GPU-decorated function so the Space registers a GPU entry point at startup."""
    print("Dummy function for startup check executed. This is normal.")
    return "Startup check passed."


# --- ComfyUI Backend Setup ---
def find_path(name: str, path: Optional[str] = None) -> Optional[str]:
    """Search *path* (default: the current working directory) and its ancestors for *name*.

    Args:
        name: File or directory name to look for.
        path: Directory to start from; defaults to ``os.getcwd()``.

    Returns:
        The joined path of the first directory (walking upward) whose listing
        contains *name*, or None once the filesystem root has been checked.
    """
    if path is None:
        path = os.getcwd()
    while True:
        if name in os.listdir(path):
            return os.path.join(path, name)
        parent_directory = os.path.dirname(path)
        # dirname() of the root returns the root itself: nothing left to search.
        if parent_directory == path:
            return None
        path = parent_directory


def add_comfyui_directory_to_sys_path() -> None:
    """Append the nearest ``ComfyUI`` directory (if any) to ``sys.path`` once."""
    comfyui_path = find_path("ComfyUI")
    if comfyui_path and os.path.isdir(comfyui_path) and comfyui_path not in sys.path:
        sys.path.append(comfyui_path)
        print(f"'{comfyui_path}' added to sys.path")


def add_extra_model_paths() -> None:
    """Load ComfyUI's ``extra_model_paths.yaml`` if one can be found upward from the CWD."""
    try:
        from main import load_extra_path_config
    except ImportError:
        # Fallback location of the loader used by some ComfyUI versions.
        from utils.extra_config import load_extra_path_config
    extra_model_paths = find_path("extra_model_paths.yaml")
    if extra_model_paths:
        load_extra_path_config(extra_model_paths)
    else:
        print("Could not find extra_model_paths.yaml")


add_comfyui_directory_to_sys_path()
add_extra_model_paths()

# Monkey-patch for Sage Attention
print("Attempting to monkey-patch ComfyUI for Sage Attention...")
try:
    model_management.sage_attention_enabled = lambda: True
    model_management.pytorch_attention_enabled = lambda: False
    print("Successfully monkey-patched model_management for Sage Attention.")
except Exception as e:
    print(f"An error occurred during monkey-patching: {e}")

# --- Constants & Configuration ---
CHECKPOINT_DIR = "models/checkpoints"
LORA_DIR = "models/loras"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(LORA_DIR, exist_ok=True)

# --- Model Definitions with Hashes ---
# Format: {Display Name: (Repo ID, Filename, Type, Hash)}
MODEL_MAP_ILLUSTRIOUS = {
    "Laxhar/noobai-XL-Vpred-1.0": ("Laxhar/noobai-XL-Vpred-1.0", "NoobAI-XL-Vpred-v1.0.safetensors", "SDXL", "ea349eeae8"),
    "Laxhar/noobai-XL-1.1": ("Laxhar/noobai-XL-1.1", "NoobAI-XL-v1.1.safetensors", "SDXL", "6681e8e4b1"),
    "WAI0731/wai-nsfw-illustrious-sdxl-v140": ("Ine007/waiNSFWIllustrious_v140", "waiNSFWIllustrious_v140.safetensors", "SDXL", "bdb59bac77"),
    "Ikena/hassaku-xl-illustrious-v30": ("misri/hassakuXLIllustrious_v30", "hassakuXLIllustrious_v30.safetensors", "SDXL", "b4fb5f829a"),
    "bluepen5805/noob_v_pencil-XL": ("bluepen5805/noob_v_pencil-XL", "noob_v_pencil-XL-v3.0.0.safetensors", "SDXL", "90b7911a78"),
    "RedRayz/hikari_noob_v-pred_1.2.2": ("RedRayz/hikari_noob_v-pred_1.2.2", "Hikari_Noob_v-pred_1.2.2.safetensors", "SDXL", "874170688a"),
}
MODEL_MAP_ANIMAGINE = {
    "cagliostrolab/animagine-xl-4.0": ("cagliostrolab/animagine-xl-4.0", "animagine-xl-4.0.safetensors", "SDXL", "6327eca98b"),
    "cagliostrolab/animagine-xl-3.1": ("cagliostrolab/animagine-xl-3.1", "animagine-xl-3.1.safetensors", "SDXL", "e3c47aedb0"),
}
MODEL_MAP_PONY = {
    "PurpleSmartAI/Pony_Diffusion_V6_XL": ("LyliaEngine/Pony_Diffusion_V6_XL", "ponyDiffusionV6XL_v6StartWithThisOne.safetensors", "SDXL", "67ab2fd8ec"),
}
MODEL_MAP_SD15 = {
    "Yuno779/anything-v3": ("ckpt/anything-v3.0", "Anything-V3.0-pruned.safetensors", "SD1.5", "ddd565f806"),
}

# --- Combined Maps for Global Lookup ---
ALL_MODEL_MAP = {**MODEL_MAP_ILLUSTRIOUS, **MODEL_MAP_ANIMAGINE, **MODEL_MAP_PONY, **MODEL_MAP_SD15}
MODEL_TYPE_MAP = {k: v[2] for k, v in ALL_MODEL_MAP.items()}
DISPLAY_NAME_TO_HASH_MAP = {k: v[3] for k, v in ALL_MODEL_MAP.items()}
# Reverse lookup; assumes hashes are unique across all maps (a duplicate would be silently dropped).
HASH_TO_DISPLAY_NAME_MAP = {v[3]: k for k, v in ALL_MODEL_MAP.items()}

# --- UI Defaults ---
DEFAULT_NEGATIVE_PROMPT = "monochrome, (low quality, worst quality:1.2), 3d, watermark, signature, ugly, poorly drawn,"
MAX_LORAS = 5
LORA_SOURCE_CHOICES = ["Civitai", "TensorArt", "Custom URL", "File"]
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Return ``obj[index]``, falling back to ``obj["result"][index]``, else None.

    ComfyUI node methods return tuples here; the ``"result"`` fallback covers
    nodes that return a dict with a ``result`` key instead.
    """
    try:
        return obj[index]
    except (KeyError, IndexError):
        try:
            return obj["result"][index]
        except (KeyError, IndexError):
            return None


def import_custom_nodes() -> None:
    """Initialise ComfyUI's extra nodes.

    Creates a fresh asyncio loop, a ``PromptServer`` and a ``PromptQueue``
    (presumably required by ``init_extra_nodes``), then runs
    ``init_extra_nodes()`` to completion on that loop.
    """
    import asyncio, execution, server
    from nodes import init_extra_nodes
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    server_instance = server.PromptServer(loop)
    execution.PromptQueue(server_instance)
    loop.run_until_complete(init_extra_nodes())


# --- Import ComfyUI Nodes & Get Choices ---
from nodes import CheckpointLoaderSimple, EmptyLatentImage, KSampler, VAEDecode, SaveImage, NODE_CLASS_MAPPINGS
import_custom_nodes()
CLIPTextEncodeSDXL = NODE_CLASS_MAPPINGS['CLIPTextEncodeSDXL']
CLIPTextEncode = NODE_CLASS_MAPPINGS['CLIPTextEncode']
LoraLoader = NODE_CLASS_MAPPINGS['LoraLoader']
CLIPSetLastLayer = NODE_CLASS_MAPPINGS['CLIPSetLastLayer']
try:
    SAMPLER_CHOICES = KSampler.INPUT_TYPES()["required"]["sampler_name"][0]
    SCHEDULER_CHOICES = KSampler.INPUT_TYPES()["required"]["scheduler"][0]
except Exception:
    # Minimal fallback if the node's INPUT_TYPES layout differs from what is expected.
    SAMPLER_CHOICES = ['euler', 'dpmpp_2m_sde_gpu']
    SCHEDULER_CHOICES = ['normal', 'karras']

# --- Instantiate Node Objects ---
checkpointloadersimple = CheckpointLoaderSimple()
cliptextencodesdxl = CLIPTextEncodeSDXL()
cliptextencode_sd15 = CLIPTextEncode()
emptylatentimage = EmptyLatentImage()
ksampler = KSampler()
vaedecode = VAEDecode()
saveimage = SaveImage()
loraloader = LoraLoader()
clipsetlastlayer = CLIPSetLastLayer()

# LoRA sources that are fetched over the network (as opposed to "File" uploads).
REMOTE_LORA_SOURCES = ("Civitai", "TensorArt", "Custom URL")
# ZeroGPU duration bounds (seconds); the UI advertises "Default: 60s" and "max to 120".
DEFAULT_GPU_DURATION = 60
MAX_GPU_DURATION = 120


# --- LoRA & File Utils ---
def _file_path(file_obj):
    """Return the filesystem path of a Gradio upload value.

    Accepts either a plain path string or an object exposing the path via a
    ``.name`` attribute (tempfile-like objects, or str subclasses that carry
    ``.name``).
    """
    return getattr(file_obj, "name", file_obj)


def get_civitai_file_info(version_id):
    """Look up the downloadable file entry for a Civitai model version.

    Prefers the first ``type == "Model"`` ``.safetensors`` file, otherwise the
    first listed file. Returns None if the request fails or no files exist.
    """
    api_url = f"https://civitai.com/api/v1/model-versions/{version_id}"
    try:
        response = requests.get(api_url, timeout=10)
        response.raise_for_status()
        data = response.json()
        files = data.get('files') or []
        for file_data in files:
            if file_data.get('type') == 'Model' and file_data.get('name', '').endswith('.safetensors'):
                return file_data
        return files[0] if files else None
    except Exception:
        return None


def get_tensorart_file_info(model_id):
    """Look up the downloadable file entry for a TensorArt model.

    Uses only the first entry of ``modelVersions``; prefers a ``.safetensors``
    file, otherwise the first file. Returns None on any failure.
    """
    api_url = f"https://tensor.art/api/v1/models/{model_id}"
    try:
        response = requests.get(api_url, timeout=10)
        response.raise_for_status()
        data = response.json()
        model_versions = data.get('modelVersions', [])
        if not model_versions:
            return None
        for file_data in model_versions[0].get('files', []):
            if file_data['name'].endswith('.safetensors'):
                return file_data
        return model_versions[0]['files'][0] if model_versions[0].get('files') else None
    except Exception:
        return None


def download_file(url, save_path, api_key=None, progress=None, desc=""):
    """Stream *url* to *save_path*, reporting progress if a callback is given.

    The body is written to ``save_path + ".part"`` and atomically moved into
    place only after the whole stream has been written, so an interrupted
    download never leaves a truncated file that later looks "already present".

    Returns:
        A human-readable status string; it contains "Successfully" on success.
    """
    if os.path.exists(save_path):
        return f"File already exists: {os.path.basename(save_path)}"
    token = api_key.strip() if api_key else ""
    headers = {'Authorization': f'Bearer {token}'} if token else {}
    tmp_path = save_path + ".part"
    try:
        if progress:
            progress(0, desc=desc)
        # `with` guarantees the streamed connection is released even on error.
        with requests.get(url, stream=True, headers=headers, timeout=15) as response:
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))
            downloaded = 0
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
                    if progress and total_size > 0:
                        downloaded += len(chunk)
                        progress(downloaded / total_size, desc=desc)
        os.replace(tmp_path, save_path)
        return f"Successfully downloaded: {os.path.basename(save_path)}"
    except Exception as e:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        return f"Download failed for {os.path.basename(save_path)}: {e}"


def get_lora_path(source, id_or_url, civitai_key, tensorart_key, progress):
    """Resolve a LoRA reference to a local file, downloading it if needed.

    Args:
        source: One of "Civitai", "TensorArt" or "Custom URL".
        id_or_url: Version/model ID, or a direct URL for "Custom URL".
        civitai_key / tensorart_key: Optional API keys for the respective site.
        progress: Progress callback forwarded to ``download_file``.

    Returns:
        ``(local_path, status)`` on success, ``(None, status)`` on failure.
    """
    if not id_or_url or not id_or_url.strip():
        return None, "No ID/URL provided."
    ident = id_or_url.strip()
    if source == "Civitai":
        local_path = os.path.join(LORA_DIR, f"civitai_{ident}.safetensors")
        source_name = f"Civitai ID {ident}"
    elif source == "TensorArt":
        local_path = os.path.join(LORA_DIR, f"tensorart_{ident}.safetensors")
        source_name = f"TensorArt ID {ident}"
    elif source == "Custom URL":
        url_hash = hashlib.md5(ident.encode()).hexdigest()
        local_path = os.path.join(LORA_DIR, f"custom_{url_hash}.safetensors")
        source_name = f"URL {ident[:30]}..."
    else:
        return None, "Invalid source."

    # Check the local cache before contacting any remote API.
    if os.path.exists(local_path):
        return local_path, "File already exists."

    if source == "Civitai":
        file_info, api_key_to_use = get_civitai_file_info(ident), civitai_key
    elif source == "TensorArt":
        file_info, api_key_to_use = get_tensorart_file_info(ident), tensorart_key
    else:
        file_info, api_key_to_use = {'downloadUrl': ident}, None

    if not file_info or not file_info.get('downloadUrl'):
        return None, f"Could not get download link for {source_name}."
    status = download_file(file_info['downloadUrl'], local_path, api_key_to_use, progress=progress, desc=f"Downloading {source_name}")
    return (local_path, status) if "Successfully" in status else (None, status)


def pre_download_loras(civitai_api_key, tensorart_api_key, *lora_data, progress=gr.Progress(track_tqdm=True)):
    """Download every remote LoRA configured in the UI ahead of generation.

    *lora_data* is the flat ``(source, id, scale, upload)`` * MAX_LORAS list
    produced by ``create_lora_settings_ui``. Rows with an uploaded file are
    skipped. Returns a Markdown bullet list of per-LoRA statuses.
    """
    sources, ids, files = lora_data[0::4], lora_data[1::4], lora_data[3::4]
    active_loras = [
        (s, i) for s, i, f in zip(sources, ids, files)
        if s in REMOTE_LORA_SOURCES and i and i.strip() and f is None
    ]
    if not active_loras:
        return "No remote LoRAs specified for pre-downloading."
    log = [f"* {s} ID {i}: {get_lora_path(s, i, civitai_api_key, tensorart_api_key, progress)[1]}" for s, i in active_loras]
    return "\n".join(log)


# --- Model Management & Core Logic ---
current_loaded_model_name = None
loaded_checkpoint_tuple = None


def load_model(model_display_name: str, progress=gr.Progress()):
    """Return the (model, clip, vae) checkpoint tuple for *model_display_name*.

    Reuses the cached tuple when the same model is requested again; otherwise
    unloads the previous one, downloads the checkpoint from the Hugging Face
    Hub if missing, and loads it through ComfyUI.
    """
    global current_loaded_model_name, loaded_checkpoint_tuple
    if model_display_name == current_loaded_model_name and loaded_checkpoint_tuple:
        return loaded_checkpoint_tuple
    if loaded_checkpoint_tuple:
        model_management.unload_all_models()
        loaded_checkpoint_tuple = None
        current_loaded_model_name = None
        gc.collect()
        torch.cuda.empty_cache()
    repo_id, filename, _, _ = ALL_MODEL_MAP[model_display_name]
    local_file_path = os.path.join(CHECKPOINT_DIR, filename)
    if not os.path.exists(local_file_path):
        progress(0, desc=f"Downloading model: {model_display_name}")
        hf_hub_download(repo_id=repo_id, filename=filename, local_dir=CHECKPOINT_DIR, local_dir_use_symlinks=False)
    progress(0.5, desc=f"Loading '{filename}'")
    MODEL_TUPLE = checkpointloadersimple.load_checkpoint(ckpt_name=filename)
    model_management.load_models_gpu([get_value_at_index(MODEL_TUPLE, 0)])
    current_loaded_model_name = model_display_name
    loaded_checkpoint_tuple = MODEL_TUPLE
    progress(1.0, desc="Model loaded")
    return loaded_checkpoint_tuple


def _generate_image_logic(model_display_name: str, positive_prompt: str, negative_prompt: str, seed: int, batch_size: int, width: int, height: int, guidance_scale: float, num_inference_steps: int, sampler_name: str, scheduler: str, civitai_api_key: str, tensorart_api_key: str, *lora_data, progress=gr.Progress(track_tqdm=True)):
    """Run one txt2img batch and return PIL images with A1111-style metadata.

    *lora_data* is the flat ``(source, id, scale, upload)`` * MAX_LORAS list;
    the SD1.5 tab appends its Clip Skip value as one extra trailing item.
    Each image gets ``info['parameters']`` so PNG Info can read it back.
    """
    output_images = []
    is_sd15 = MODEL_TYPE_MAP.get(model_display_name) == "SD1.5"
    clip_skip = 1
    if is_sd15 and len(lora_data) > MAX_LORAS * 4:
        clip_skip = int(lora_data[-1])
        lora_data = lora_data[:-1]

    # Numeric widgets may deliver floats (or None for an empty Number field);
    # normalise so node inputs and the "Size: WxH" metadata are clean ints.
    width, height = int(width), int(height)
    batch_size = int(batch_size)
    num_inference_steps = int(num_inference_steps)
    if seed is None or int(seed) == -1:
        start_seed = random.randint(0, 2**64 - 1)
    else:
        start_seed = int(seed)

    with torch.inference_mode():
        model_tuple = load_model(model_display_name, progress)
        model, clip, vae = (get_value_at_index(model_tuple, i) for i in range(3))
        if is_sd15:
            clip = get_value_at_index(clipsetlastlayer.set_last_layer(clip=clip, stop_at_clip_layer=-clip_skip), 0)

        active_loras_for_meta = []
        sources, ids, scales, files = lora_data[0::4], lora_data[1::4], lora_data[2::4], lora_data[3::4]
        for source, lora_id, scale, custom_file in zip(sources, ids, scales, files):
            if not scale or scale <= 0:
                continue
            lora_filename = None
            if custom_file:
                upload_path = _file_path(custom_file)
                lora_filename = os.path.basename(upload_path)
                shutil.copy(upload_path, LORA_DIR)
            elif lora_id and lora_id.strip():
                local_path, _ = get_lora_path(source, lora_id, civitai_api_key, tensorart_api_key, progress)
                if local_path:
                    lora_filename = os.path.basename(local_path)
            if lora_filename:
                lora_tuple = loraloader.load_lora(model=model, clip=clip, lora_name=lora_filename, strength_model=scale, strength_clip=scale)
                model, clip = get_value_at_index(lora_tuple, 0), get_value_at_index(lora_tuple, 1)
                active_loras_for_meta.append(f"{source} {lora_id}:{scale}")
        loras_string = f"LoRAs: [{', '.join(active_loras_for_meta)}]" if active_loras_for_meta else ""

        if is_sd15:
            pos_cond = cliptextencode_sd15.encode(text=positive_prompt, clip=clip)
            neg_cond = cliptextencode_sd15.encode(text=negative_prompt, clip=clip)
        else:
            pos_cond = cliptextencodesdxl.encode(width=width, height=height, text_g=positive_prompt, text_l=positive_prompt, clip=clip, target_width=width, target_height=height, crop_w=0, crop_h=0)
            neg_cond = cliptextencodesdxl.encode(width=width, height=height, text_g=negative_prompt, text_l=negative_prompt, clip=clip, target_width=width, target_height=height, crop_w=0, crop_h=0)

        latent = emptylatentimage.generate(width=width, height=height, batch_size=batch_size)
        sampled = ksampler.sample(
            seed=start_seed,
            steps=num_inference_steps,
            cfg=guidance_scale,
            sampler_name=sampler_name,
            scheduler=scheduler,
            denoise=1.0,
            model=model,
            positive=get_value_at_index(pos_cond, 0),
            negative=get_value_at_index(neg_cond, 0),
            latent_image=get_value_at_index(latent, 0),
        )
        decoded_images_tensor = get_value_at_index(vaedecode.decode(samples=get_value_at_index(sampled, 0), vae=vae), 0)

        model_hash = DISPLAY_NAME_TO_HASH_MAP.get(model_display_name, "N/A")
        for i in range(decoded_images_tensor.shape[0]):
            img_tensor = decoded_images_tensor[i]
            # Clip before the uint8 cast so out-of-range values cannot wrap around.
            pixels = (img_tensor.cpu().numpy() * 255.0).clip(0, 255).astype("uint8")
            pil_image = Image.fromarray(pixels)
            # NOTE(review): the whole batch is sampled from start_seed, so
            # start_seed + i may not reproduce image i on its own — confirm.
            current_seed = start_seed + i
            params_string = f"{positive_prompt}\nNegative prompt: {negative_prompt}\n"
            params_string += f"Steps: {num_inference_steps}, Sampler: {sampler_name}, Scheduler: {scheduler}, CFG scale: {guidance_scale}, Seed: {current_seed}, Size: {width}x{height}, Base Model: {model_display_name}, Model hash: {model_hash}"
            if is_sd15:
                params_string += f", Clip skip: {clip_skip}"
            if loras_string:
                params_string += f", {loras_string}"
            pil_image.info = {'parameters': params_string.strip()}
            output_images.append(pil_image)
    return output_images


def generate_image_wrapper(*args, **kwargs):
    """Gradio entry point: strip the ZeroGPU duration argument and run on GPU.

    ``args[11]`` is the "ZeroGPU Duration (s)" field. Empty, non-numeric or
    non-positive values fall back to DEFAULT_GPU_DURATION; larger values are
    capped at MAX_GPU_DURATION. All other args go to ``_generate_image_logic``.
    """
    logic_args_list = list(args[:11]) + list(args[12:])
    zero_gpu_duration = args[11]
    duration = DEFAULT_GPU_DURATION
    try:
        if zero_gpu_duration and int(zero_gpu_duration) > 0:
            duration = min(int(zero_gpu_duration), MAX_GPU_DURATION)
    except (ValueError, TypeError, OverflowError):
        pass
    return spaces.GPU(duration=duration)(_generate_image_logic)(*logic_args_list, **kwargs)


# --- PNG Info & UI Logic ---
# One "Key: value" pair of an A1111-style settings line (colon followed by a space).
_PARAM_PAIR_RE = re.compile(r"\s*(\w[\w \-/]+): ([^,]*)(?:,|$)")
_NEGATIVE_PREFIX = "Negative prompt:"


def _split_parameters_text(params_text):
    """Split A1111-style metadata into ``(prompt, negative_prompt, params_line)``.

    The last line is taken as the settings line only if it holds at least
    three "Key: value" pairs; prompt and negative prompt may span several lines.
    """
    *text_lines, last_line = params_text.strip().split("\n")
    if len(_PARAM_PAIR_RE.findall(last_line)) < 3:
        text_lines.append(last_line)
        last_line = ""
    prompt_lines, negative_lines = [], []
    in_negative = False
    for line in text_lines:
        if not in_negative and line.startswith(_NEGATIVE_PREFIX):
            in_negative = True
            line = line[len(_NEGATIVE_PREFIX):].strip()
        (negative_lines if in_negative else prompt_lines).append(line)
    return "\n".join(prompt_lines), "\n".join(negative_lines), last_line


def _parse_parameters(params_text):
    """Parse PNG ``parameters`` text into a dict of UI-ready values.

    Missing or malformed fields fall back to app defaults. The raw settings
    line is also returned under ``'params_line'``.
    """
    data = {}
    prompt, negative_prompt, params_line = _split_parameters_text(params_text)
    data['prompt'] = prompt
    data['negative_prompt'] = negative_prompt
    data['params_line'] = params_line

    def find_param(key, default, cast_type=str):
        # Value runs up to the next comma or newline.
        match = re.search(rf"\b{re.escape(key)}: ([^,\n]+)", params_line)
        if not match:
            return default
        try:
            return cast_type(match.group(1).strip())
        except ValueError:
            return default

    data['steps'] = find_param("Steps", 28, int)
    data['sampler'] = find_param("Sampler", SAMPLER_CHOICES[0], str)
    data['scheduler'] = find_param("Scheduler", SCHEDULER_CHOICES[0], str)
    data['cfg_scale'] = find_param("CFG scale", 7.5, float)
    data['seed'] = find_param("Seed", -1, int)
    data['clip_skip'] = find_param("Clip skip", 1, int)
    data['base_model'] = find_param("Base Model", next(iter(ALL_MODEL_MAP)), str)
    data['model_hash'] = find_param("Model hash", None, str)
    size_match = re.search(r"Size: (\d+)x(\d+)", params_line)
    data['width'], data['height'] = (int(size_match.group(1)), int(size_match.group(2))) if size_match else (1024, 1024)
    return data


def get_png_info(image):
    """Return ``(prompt, negative_prompt, other_params)`` for the PNG Info tab.

    ``other_params`` lists the comma-separated settings one per line.
    """
    if not image or not (params := image.info.get('parameters')):
        return "", "", "No metadata found in the image."
    parsed_data = _parse_parameters(params)
    other_params_text = "\n".join(
        part.strip() for part in parsed_data['params_line'].split(',') if part.strip()
    )
    return parsed_data.get('prompt', ''), parsed_data.get('negative_prompt', ''), other_params_text


def apply_data_to_ui(data, target_tab):
    """Build a ``{component: value}`` update dict filling *target_tab* from parsed metadata.

    The base model is only set if it belongs to the target tab's model map.
    Unknown samplers/schedulers fall back to defaults. Also switches the UI
    to the txt2img tab and the chosen model sub-tab.
    """
    final_sampler = data.get('sampler') if data.get('sampler') in SAMPLER_CHOICES else SAMPLER_CHOICES[0]
    default_scheduler = 'normal' if 'normal' in SCHEDULER_CHOICES else SCHEDULER_CHOICES[0]
    final_scheduler = data.get('scheduler') if data.get('scheduler') in SCHEDULER_CHOICES else default_scheduler
    updates = {}
    base_model_name = data.get('base_model')
    if target_tab == "Illustrious":
        if base_model_name in MODEL_MAP_ILLUSTRIOUS:
            updates[base_model_name_input_illustrious] = base_model_name
        updates.update({
            prompt_illustrious: data['prompt'], negative_prompt_illustrious: data['negative_prompt'],
            seed_illustrious: data['seed'], width_illustrious: data['width'], height_illustrious: data['height'],
            guidance_scale_illustrious: data['cfg_scale'], num_inference_steps_illustrious: data['steps'],
            sampler_illustrious: final_sampler, schedule_type_illustrious: final_scheduler,
            model_tabs: gr.Tabs(selected=0),
        })
    elif target_tab == "Animagine":
        if base_model_name in MODEL_MAP_ANIMAGINE:
            updates[base_model_name_input_animagine] = base_model_name
        updates.update({
            prompt_animagine: data['prompt'], negative_prompt_animagine: data['negative_prompt'],
            seed_animagine: data['seed'], width_animagine: data['width'], height_animagine: data['height'],
            guidance_scale_animagine: data['cfg_scale'], num_inference_steps_animagine: data['steps'],
            sampler_animagine: final_sampler, schedule_type_animagine: final_scheduler,
            model_tabs: gr.Tabs(selected=1),
        })
    elif target_tab == "Pony":
        if base_model_name in MODEL_MAP_PONY:
            updates[base_model_name_input_pony] = base_model_name
        updates.update({
            prompt_pony: data['prompt'], negative_prompt_pony: data['negative_prompt'],
            seed_pony: data['seed'], width_pony: data['width'], height_pony: data['height'],
            guidance_scale_pony: data['cfg_scale'], num_inference_steps_pony: data['steps'],
            sampler_pony: final_sampler, schedule_type_pony: final_scheduler,
            model_tabs: gr.Tabs(selected=2),
        })
    elif target_tab == "SD1.5":
        if base_model_name in MODEL_MAP_SD15:
            updates[base_model_name_input_sd15] = base_model_name
        updates.update({
            prompt_sd15: data['prompt'], negative_prompt_sd15: data['negative_prompt'],
            seed_sd15: data['seed'], width_sd15: data['width'], height_sd15: data['height'],
            guidance_scale_sd15: data['cfg_scale'], num_inference_steps_sd15: data['steps'],
            sampler_sd15: final_sampler, schedule_type_sd15: final_scheduler,
            clip_skip_sd15: data.get('clip_skip', 1),
            model_tabs: gr.Tabs(selected=3),
        })
    updates[tabs] = gr.Tabs(selected=0)
    return updates


def send_info_to_tab(image, target_tab):
    """Fill *target_tab* from the image's metadata; no-op updates if there is none."""
    if not image or not image.info.get('parameters', ''):
        return {comp: gr.update() for comp in all_ui_components}
    data = _parse_parameters(image.info['parameters'])
    return apply_data_to_ui(data, target_tab)


def send_info_by_hash(image):
    """Pick the target tab from the metadata's model hash, then fill it.

    Raises:
        gr.Error: if the hash is unknown to this app or maps to no tab.
    """
    if not image or not image.info.get('parameters', ''):
        return {comp: gr.update() for comp in all_ui_components}
    data = _parse_parameters(image.info['parameters'])
    model_hash = data.get('model_hash')
    display_name = HASH_TO_DISPLAY_NAME_MAP.get(model_hash)
    if not display_name:
        raise gr.Error("Model hash not found in this app's model list. Use one of the 'Send to ...' buttons instead; the original model name from the PNG will be used if it exists in the target tab.")
    if display_name in MODEL_MAP_ILLUSTRIOUS:
        target_tab = "Illustrious"
    elif display_name in MODEL_MAP_ANIMAGINE:
        target_tab = "Animagine"
    elif display_name in MODEL_MAP_PONY:
        target_tab = "Pony"
    elif display_name in MODEL_MAP_SD15:
        target_tab = "SD1.5"
    else:
        raise gr.Error("Cannot determine the correct tab for this model.")
    data['base_model'] = display_name
    return apply_data_to_ui(data, target_tab)


# --- UI Generation Functions ---
def create_lora_settings_ui():
    """Build the "LoRA Settings" accordion with MAX_LORAS rows (only the first visible).

    Returns:
        ``(civitai_api_key, tensorart_api_key, lora_rows, sources, ids, scales,
        uploads, add_button, delete_button, count_state, all_components)``,
        where ``all_components`` is the flat ``(source, id, scale, upload)`` * MAX_LORAS list.
    """
    with gr.Accordion("LoRA Settings", open=False):
        gr.Markdown("⚠️ **Responsible Use Notice:** Please avoid excessive, rapid, or automated (scripted) use of the pre-download LoRA feature. Overt misuse may lead to service disruption. Thank you for your cooperation.")
        gr.Markdown("For LoRAs that require login to download, you may need to enter the corresponding API Key.")
        with gr.Row():
            civitai_api_key = gr.Textbox(label="Civitai API Key", placeholder="Enter your Civitai API Key", type="password", scale=1)
            tensorart_api_key = gr.Textbox(label="TensorArt API Key", placeholder="Enter your TensorArt API Key", type="password", scale=1)
        gr.Markdown("---")
        gr.Markdown("For each LoRA, choose a source, provide an ID/URL, or upload a file.")
        gr.Markdown("""
Input Examples:
""")
        gr.Markdown("""
Notice:
""")
        lora_rows, sources, ids, scales, uploads = [], [], [], [], []
        for i in range(MAX_LORAS):
            with gr.Row(visible=(i == 0)) as row:
                source = gr.Dropdown(label=f"LoRA {i+1} Source", choices=LORA_SOURCE_CHOICES, value="Civitai", scale=1)
                lora_id = gr.Textbox(label="ID / URL / File", placeholder="e.g.: 133755", scale=2)
                scale = gr.Slider(label="Weight", minimum=0.0, maximum=2.0, step=0.05, value=0.0, scale=2)
                upload = gr.UploadButton("Upload", file_types=[".safetensors"], scale=1)
            lora_rows.append(row)
            sources.append(source)
            ids.append(lora_id)
            scales.append(scale)
            uploads.append(upload)
            # After an upload, show the file name in the ID box and switch the source to "File".
            upload.upload(
                fn=lambda f: (os.path.basename(_file_path(f)), "File") if f else (gr.update(), gr.update()),
                inputs=[upload],
                outputs=[lora_id, source],
            )
        with gr.Row():
            add_button = gr.Button("✚ Add LoRA")
            delete_button = gr.Button("➖ Delete LoRA", visible=False)
        count_state = gr.State(value=1)
    all_components = [item for sublist in zip(sources, ids, scales, uploads) for item in sublist]
    return (civitai_api_key, tensorart_api_key, lora_rows, sources, ids, scales, uploads, add_button, delete_button, count_state, all_components)


def download_all_models_on_startup():
    """Downloads all base models listed in ALL_MODEL_MAP when the app starts."""
    print("--- Starting pre-download of all base models ---")
    for model_display_name, model_info in ALL_MODEL_MAP.items():
        repo_id, filename, _, _ = model_info
        local_file_path = os.path.join(CHECKPOINT_DIR, filename)
        if os.path.exists(local_file_path):
            print(f"✅ Model '{filename}' already exists. Skipping download.")
            continue
        try:
            print(f"Downloading: {model_display_name} ({filename})...")
            hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                local_dir=CHECKPOINT_DIR,
                local_dir_use_symlinks=False,
            )
            print(f"✅ Successfully downloaded {filename}.")
        except Exception as e:
            # Best effort: a failed model is retried lazily by load_model().
            print(f"❌ Failed to download {filename} from {repo_id}: {e}")
    print("--- Finished pre-downloading all base models ---")


# --- Execute model download on startup ---
download_all_models_on_startup()

# --- Gradio UI ---
with gr.Blocks(css="#col-container {margin: 0 auto; max-width: 1024px;}") as demo:
    gr.Markdown("# Animated T2I with LoRAs")
    with gr.Tabs(elem_id="tabs_container") as tabs:
        with gr.TabItem("txt2img", id=0):
            with gr.Tabs() as model_tabs:
                for tab_index, (tab_name, model_map, defaults) in enumerate([
                    ("Illustrious", MODEL_MAP_ILLUSTRIOUS, {'w': 1024, 'h': 1024, 'cs_vis': False, 'cs_val': 1}),
                    ("Animagine", MODEL_MAP_ANIMAGINE, {'w': 1024, 'h': 1024, 'cs_vis': False, 'cs_val': 1}),
                    ("Pony", MODEL_MAP_PONY, {'w': 1024, 'h': 1024, 'cs_vis': False, 'cs_val': 1}),
                    ("SD1.5", MODEL_MAP_SD15, {'w': 512, 'h': 768, 'cs_vis': True, 'cs_val': 1}),
                ]):
                    # Explicit ids so apply_data_to_ui's gr.Tabs(selected=n) targets these tabs.
                    with gr.TabItem(tab_name, id=tab_index):
                        gr.Markdown("💡 **Tip:** Pre-downloading LoRAs before 'Run' can maximize ZeroGPU time.")
                        with gr.Column():
                            with gr.Row():
                                base_model = gr.Dropdown(label="Base Model", choices=list(model_map.keys()), value=list(model_map.keys())[0], scale=3)
                                with gr.Column(scale=1):
                                    predownload_lora = gr.Button("Pre-download LoRAs")
                                    run = gr.Button("Run", variant="primary")
                                    predownload_status = gr.Markdown("")
                            prompt = gr.Text(label="Prompt", lines=3, placeholder="Enter your prompt")
                            neg_prompt = gr.Text(label="Negative prompt", lines=3, value=DEFAULT_NEGATIVE_PROMPT)
                            with gr.Row():
                                with gr.Column(scale=2):
                                    with gr.Row():
                                        width = gr.Slider(label="Width", minimum=256, maximum=2048, step=64, value=defaults['w'])
                                        height = gr.Slider(label="Height", minimum=256, maximum=2048, step=64, value=defaults['h'])
                                    with gr.Row():
                                        sampler = gr.Dropdown(label="Sampling method", choices=SAMPLER_CHOICES, value=SAMPLER_CHOICES[0])
                                        default_scheduler = 'normal' if 'normal' in SCHEDULER_CHOICES else SCHEDULER_CHOICES[0]
                                        scheduler = gr.Dropdown(label="Scheduler", choices=SCHEDULER_CHOICES, value=default_scheduler)
                                    with gr.Row():
                                        cfg = gr.Slider(label="CFG Scale", minimum=0.0, maximum=20.0, step=0.1, value=7.5)
                                        steps = gr.Slider(label="Sampling steps", minimum=1, maximum=50, step=1, value=28)
                                with gr.Column(scale=1):
                                    result = gr.Gallery(label="Result", show_label=False, columns=2, object_fit="contain", height="auto")
                            with gr.Row():
                                seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
                                batch_size = gr.Slider(label="Batch size", minimum=1, maximum=8, step=1, value=1)
                                clip_skip = gr.Slider(label="Clip Skip", minimum=1, maximum=2, step=1, value=defaults['cs_val'], visible=defaults['cs_vis'])
                                zero_gpu = gr.Number(label="ZeroGPU Duration (s)", value=None, placeholder="Default: 60s", info="Optional: Leave empty for default (60s), max to 120")
                            lora_settings = create_lora_settings_ui()

                        # Assign specific variables for event handlers
                        if tab_name == "Illustrious":
                            (base_model_name_input_illustrious, prompt_illustrious, negative_prompt_illustrious, seed_illustrious,
                             batch_size_illustrious, width_illustrious, height_illustrious, guidance_scale_illustrious,
                             num_inference_steps_illustrious, sampler_illustrious, schedule_type_illustrious,
                             zero_gpu_duration_illustrious, result_illustrious) = (
                                base_model, prompt, neg_prompt, seed, batch_size, width, height, cfg, steps,
                                sampler, scheduler, zero_gpu, result)
                            (civitai_api_key_illustrious, tensorart_api_key_illustrious, lora_rows_illustrious, _,
                             lora_id_inputs_illustrious, lora_scale_inputs_illustrious, _, add_lora_button_illustrious,
                             delete_lora_button_illustrious, lora_count_state_illustrious,
                             all_lora_components_flat_illustrious) = lora_settings
                            predownload_lora_button_illustrious, run_button_illustrious, predownload_status_illustrious = predownload_lora, run, predownload_status
                        elif tab_name == "Animagine":
                            (base_model_name_input_animagine, prompt_animagine, negative_prompt_animagine, seed_animagine,
                             batch_size_animagine, width_animagine, height_animagine, guidance_scale_animagine,
                             num_inference_steps_animagine, sampler_animagine, schedule_type_animagine,
                             zero_gpu_duration_animagine, result_animagine) = (
                                base_model, prompt, neg_prompt, seed, batch_size, width, height, cfg, steps,
                                sampler, scheduler, zero_gpu, result)
                            (civitai_api_key_animagine, tensorart_api_key_animagine, lora_rows_animagine, _,
                             lora_id_inputs_animagine, lora_scale_inputs_animagine, _, add_lora_button_animagine,
                             delete_lora_button_animagine, lora_count_state_animagine,
                             all_lora_components_flat_animagine) = lora_settings
                            predownload_lora_button_animagine, run_button_animagine, predownload_status_animagine = predownload_lora, run, predownload_status
                        elif tab_name == "Pony":
                            (base_model_name_input_pony, prompt_pony, negative_prompt_pony, seed_pony,
                             batch_size_pony, width_pony, height_pony, guidance_scale_pony,
                             num_inference_steps_pony, sampler_pony, schedule_type_pony,
                             zero_gpu_duration_pony, result_pony) = (
                                base_model, prompt, neg_prompt, seed, batch_size, width, height, cfg, steps,
                                sampler, scheduler, zero_gpu, result)
                            (civitai_api_key_pony, tensorart_api_key_pony, lora_rows_pony, _,
                             lora_id_inputs_pony, lora_scale_inputs_pony, _, add_lora_button_pony,
                             delete_lora_button_pony, lora_count_state_pony,
                             all_lora_components_flat_pony) = lora_settings
                            predownload_lora_button_pony, run_button_pony, predownload_status_pony = predownload_lora, run, predownload_status
                        elif tab_name == "SD1.5":
                            (base_model_name_input_sd15, prompt_sd15, negative_prompt_sd15, seed_sd15,
                             batch_size_sd15, width_sd15, height_sd15, guidance_scale_sd15,
                             num_inference_steps_sd15, sampler_sd15, schedule_type_sd15, clip_skip_sd15,
                             zero_gpu_duration_sd15, result_sd15) = (
                                base_model, prompt, neg_prompt, seed, batch_size, width, height, cfg, steps,
                                sampler, scheduler, clip_skip, zero_gpu, result)
                            (civitai_api_key_sd15, tensorart_api_key_sd15, lora_rows_sd15, _,
                             lora_id_inputs_sd15, lora_scale_inputs_sd15, _, add_lora_button_sd15,
                             delete_lora_button_sd15, lora_count_state_sd15,
                             all_lora_components_flat_sd15) = lora_settings
                            predownload_lora_button_sd15, run_button_sd15, predownload_status_sd15 = predownload_lora, run, predownload_status

        with gr.TabItem("PNG Info", id=1):
            with gr.Column():
                info_image_input = gr.Image(type="pil", label="Upload Image", height=512)
                with gr.Row():
                    info_get_button = gr.Button("Get Info")
                    send_by_hash_button = gr.Button("Send to txt2img by Model Hash", variant="primary")
                with gr.Row():
                    send_to_illustrious_button = gr.Button("Send to Illustrious")
                    send_to_animagine_button = gr.Button("Send to Animagine")
                    send_to_pony_button = gr.Button("Send to Pony")
                    send_to_sd15_button = gr.Button("Send to SD1.5")
                gr.Markdown("### Positive Prompt")
                info_prompt_output = gr.Textbox(lines=3, interactive=False, show_label=False)
                gr.Markdown("### Negative Prompt")
                info_neg_prompt_output = gr.Textbox(lines=3, interactive=False, show_label=False)
                gr.Markdown("### Other Parameters")
                info_params_output = gr.Textbox(lines=5, interactive=False, show_label=False)
    gr.Markdown("\nMade by RioShiina with ❤️\n")

    # --- Event Handlers ---
    def create_lora_event_handlers(lora_rows, count_state, add_button, del_button, lora_ids, lora_scales):
        """Wire the Add/Delete LoRA buttons of one tab to show/hide LoRA rows."""
        def add_lora_row(c):
            # Guard against a stale click once all rows are already shown.
            if c >= MAX_LORAS:
                return {add_button: gr.update(visible=False)}
            return {count_state: c + 1, lora_rows[c]: gr.update(visible=True), del_button: gr.update(visible=True), add_button: gr.update(visible=c + 1 < MAX_LORAS)}

        def del_lora_row(c):
            # Never hide the first row.
            if c <= 1:
                return {del_button: gr.update(visible=False)}
            c -= 1
            # Reset the hidden row so it no longer contributes to generation.
            return {count_state: c, lora_rows[c]: gr.update(visible=False), lora_ids[c]: "", lora_scales[c]: 0.0, add_button: gr.update(visible=True), del_button: gr.update(visible=c > 1)}

        add_button.click(add_lora_row, [count_state], [count_state, add_button, del_button, *lora_rows])
        del_button.click(del_lora_row, [count_state], [count_state, add_button, del_button, *lora_rows, *lora_ids, *lora_scales])

    create_lora_event_handlers(lora_rows_illustrious, lora_count_state_illustrious, add_lora_button_illustrious, delete_lora_button_illustrious, lora_id_inputs_illustrious, lora_scale_inputs_illustrious)
    predownload_lora_button_illustrious.click(lambda: "⏳ Downloading...", None, [predownload_status_illustrious]).then(pre_download_loras, [civitai_api_key_illustrious, tensorart_api_key_illustrious, *all_lora_components_flat_illustrious], [predownload_status_illustrious])
    run_button_illustrious.click(generate_image_wrapper, [base_model_name_input_illustrious, prompt_illustrious, negative_prompt_illustrious, seed_illustrious, batch_size_illustrious, width_illustrious, height_illustrious, guidance_scale_illustrious, num_inference_steps_illustrious, sampler_illustrious, schedule_type_illustrious, zero_gpu_duration_illustrious, civitai_api_key_illustrious, tensorart_api_key_illustrious, *all_lora_components_flat_illustrious], [result_illustrious])

    create_lora_event_handlers(lora_rows_animagine, lora_count_state_animagine, add_lora_button_animagine, delete_lora_button_animagine, lora_id_inputs_animagine, lora_scale_inputs_animagine)
    predownload_lora_button_animagine.click(lambda: "⏳ Downloading...", None, [predownload_status_animagine]).then(pre_download_loras, [civitai_api_key_animagine, tensorart_api_key_animagine, *all_lora_components_flat_animagine], [predownload_status_animagine])
    run_button_animagine.click(generate_image_wrapper, [base_model_name_input_animagine, prompt_animagine, negative_prompt_animagine, seed_animagine, batch_size_animagine, width_animagine, height_animagine, guidance_scale_animagine, num_inference_steps_animagine, sampler_animagine, schedule_type_animagine, zero_gpu_duration_animagine, civitai_api_key_animagine, tensorart_api_key_animagine, *all_lora_components_flat_animagine], [result_animagine])

    create_lora_event_handlers(lora_rows_pony, lora_count_state_pony, add_lora_button_pony, delete_lora_button_pony, lora_id_inputs_pony, lora_scale_inputs_pony)
    predownload_lora_button_pony.click(lambda: "⏳ Downloading...", None, [predownload_status_pony]).then(pre_download_loras, [civitai_api_key_pony, tensorart_api_key_pony, *all_lora_components_flat_pony], [predownload_status_pony])
    run_button_pony.click(generate_image_wrapper, [base_model_name_input_pony, prompt_pony, negative_prompt_pony, seed_pony, batch_size_pony, width_pony, height_pony, guidance_scale_pony, num_inference_steps_pony, sampler_pony, schedule_type_pony, zero_gpu_duration_pony, civitai_api_key_pony, tensorart_api_key_pony, *all_lora_components_flat_pony], [result_pony])

    create_lora_event_handlers(lora_rows_sd15, lora_count_state_sd15, add_lora_button_sd15, delete_lora_button_sd15, lora_id_inputs_sd15, lora_scale_inputs_sd15)
    predownload_lora_button_sd15.click(lambda: "⏳ Downloading...", None, [predownload_status_sd15]).then(pre_download_loras, [civitai_api_key_sd15, tensorart_api_key_sd15, *all_lora_components_flat_sd15], [predownload_status_sd15])
    # The SD1.5 tab appends clip_skip after the LoRA inputs; _generate_image_logic pops it off.
    run_button_sd15.click(generate_image_wrapper, [base_model_name_input_sd15, prompt_sd15, negative_prompt_sd15, seed_sd15, batch_size_sd15, width_sd15, height_sd15, guidance_scale_sd15, num_inference_steps_sd15, sampler_sd15, schedule_type_sd15, zero_gpu_duration_sd15, civitai_api_key_sd15, tensorart_api_key_sd15, *all_lora_components_flat_sd15, clip_skip_sd15], [result_sd15])

    info_get_button.click(get_png_info, [info_image_input], [info_prompt_output, info_neg_prompt_output, info_params_output])

    all_ui_components = [
        base_model_name_input_illustrious, prompt_illustrious, negative_prompt_illustrious, seed_illustrious, width_illustrious, height_illustrious, guidance_scale_illustrious, num_inference_steps_illustrious, sampler_illustrious, schedule_type_illustrious,
        base_model_name_input_animagine, prompt_animagine, negative_prompt_animagine, seed_animagine, width_animagine, height_animagine, guidance_scale_animagine, num_inference_steps_animagine, sampler_animagine, schedule_type_animagine,
        base_model_name_input_pony, prompt_pony, negative_prompt_pony, seed_pony, width_pony, height_pony, guidance_scale_pony, num_inference_steps_pony, sampler_pony, schedule_type_pony,
        base_model_name_input_sd15, prompt_sd15, negative_prompt_sd15, seed_sd15, width_sd15, height_sd15, guidance_scale_sd15, num_inference_steps_sd15, sampler_sd15, schedule_type_sd15, clip_skip_sd15,
        tabs, model_tabs,
    ]
    send_to_illustrious_button.click(lambda img: send_info_to_tab(img, "Illustrious"), [info_image_input], all_ui_components)
    send_to_animagine_button.click(lambda img: send_info_to_tab(img, "Animagine"), [info_image_input], all_ui_components)
    send_to_pony_button.click(lambda img: send_info_to_tab(img, "Pony"), [info_image_input], all_ui_components)
    send_to_sd15_button.click(lambda img: send_info_to_tab(img, "SD1.5"), [info_image_input], all_ui_components)
    send_by_hash_button.click(send_info_by_hash, [info_image_input], all_ui_components)

if __name__ == "__main__":
    demo.queue().launch()