File size: 40,953 Bytes
c2bcd10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
import os
import random
import sys
from typing import Sequence, Mapping, Any, Union
import torch
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download
import spaces
from comfy import model_management # We need to import this early
import gc
import requests
import re
import hashlib
import shutil

# --- Startup Dummy Function ---
@spaces.GPU(duration=60)
def dummy_gpu_for_startup():
    """Trivial GPU-decorated function defined at import time.

    It only prints a notice and returns a fixed string. Presumably it exists so
    the Spaces start-up check finds a ``@spaces.GPU`` function even though the
    real generator is wrapped dynamically per request (TODO confirm).
    """
    notice = "Dummy function for startup check executed. This is normal."
    print(notice)
    return "Startup check passed."

# --- ComfyUI Backend Setup ---
def find_path(name: str, path: Union[str, None] = None) -> Union[str, None]:
    """Search ``path`` and then each of its ancestors for an entry called ``name``.

    Args:
        name: File or directory name to look for (exact ``os.listdir`` match).
        path: Directory to start from; defaults to the current working directory.

    Returns:
        ``os.path.join(<dir>, name)`` for the nearest directory containing
        ``name``, or None once the filesystem root (or the top of a relative
        path) has been checked without a match. Directories that cannot be
        listed (permission denied, vanished, not a directory) are skipped
        instead of aborting the search.
    """
    current = os.getcwd() if path is None else path
    while True:
        try:
            if name in os.listdir(current):
                return os.path.join(current, name)
        except OSError:
            # This runs at import time: an unreadable directory must not
            # crash the whole app, so keep climbing instead.
            pass
        parent = os.path.dirname(current)
        if parent == current:
            return None
        current = parent

def add_comfyui_directory_to_sys_path() -> None:
    """Append the nearest ``ComfyUI`` directory (searching upward from CWD) to ``sys.path``.

    Does nothing when no such directory is found.
    """
    located = find_path("ComfyUI")
    if not located or not os.path.isdir(located):
        return
    sys.path.append(located)
    print(f"'{located}' added to sys.path")

def add_extra_model_paths() -> None:
    """Load ComfyUI's ``extra_model_paths.yaml`` when one exists above the CWD.

    ``load_extra_path_config`` is taken from ComfyUI's ``main`` module, or from
    ``utils.extra_config`` when ``main`` does not provide it.
    """
    try:
        from main import load_extra_path_config
    except ImportError:
        from utils.extra_config import load_extra_path_config
    config_file = find_path("extra_model_paths.yaml")
    if not config_file:
        print("Could not find extra_model_paths.yaml")
        return
    load_extra_path_config(config_file)

# Make ComfyUI importable, then register any extra model folders.
# Order matters: add_extra_model_paths imports from ComfyUI's `main` /
# `utils.extra_config`, which presumably only resolve once the ComfyUI
# directory is on sys.path.
add_comfyui_directory_to_sys_path()
add_extra_model_paths()

# Monkey-patch for Sage Attention
# Rebind two predicates on ComfyUI's model_management module so it reports
# Sage Attention as enabled and PyTorch attention as disabled.
# NOTE(review): this only replaces the module attributes; any code that already
# imported these functions by name keeps the originals, and the sageattention
# package must actually be installed for this to work — confirm both.
print("Attempting to monkey-patch ComfyUI for Sage Attention...")
try:
    model_management.sage_attention_enabled = lambda: True
    model_management.pytorch_attention_enabled = lambda: False
    print("Successfully monkey-patched model_management for Sage Attention.")
except Exception as e:
    # Module attribute assignment is not expected to fail; best-effort guard only.
    print(f"An error occurred during monkey-patching: {e}")

# --- Constants & Configuration ---
# Download targets for checkpoints and LoRAs, relative to the current working directory.
# NOTE(review): ComfyUI's loaders look files up by name in its own model folders,
# so this presumably assumes the CWD is the ComfyUI root — confirm.
CHECKPOINT_DIR = "models/checkpoints"
LORA_DIR = "models/loras"
# Create both up front; exist_ok makes this a no-op on restarts.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(LORA_DIR, exist_ok=True)

# --- Model Definitions with Hashes ---
# Format: {Display Name: (Repo ID, Filename, Type, Hash)}
#   Display Name - label used for the model dropdowns and written as "Base Model" in image metadata.
#   Repo ID      - Hugging Face Hub repo passed to hf_hub_download (may differ from the display name).
#   Filename     - checkpoint file inside that repo; also its file name under CHECKPOINT_DIR.
#   Type         - "SDXL" or "SD1.5"; "SD1.5" switches generation to CLIPTextEncode plus CLIP skip.
#   Hash         - written as "Model hash" in image metadata and used to map an image back to its
#                  model (NOTE(review): presumably the 10-char A1111/Civitai "AutoV2" SHA-256 prefix — confirm).
# Models for the "Illustrious" tab.
MODEL_MAP_ILLUSTRIOUS = {
    "Laxhar/noobai-XL-Vpred-1.0": ("Laxhar/noobai-XL-Vpred-1.0", "NoobAI-XL-Vpred-v1.0.safetensors", "SDXL", "ea349eeae8"),
    "Laxhar/noobai-XL-1.1": ("Laxhar/noobai-XL-1.1", "NoobAI-XL-v1.1.safetensors", "SDXL", "6681e8e4b1"),
    "WAI0731/wai-nsfw-illustrious-sdxl-v140": ("Ine007/waiNSFWIllustrious_v140", "waiNSFWIllustrious_v140.safetensors", "SDXL", "bdb59bac77"),
    "Ikena/hassaku-xl-illustrious-v30": ("misri/hassakuXLIllustrious_v30", "hassakuXLIllustrious_v30.safetensors", "SDXL", "b4fb5f829a"),
    "bluepen5805/noob_v_pencil-XL": ("bluepen5805/noob_v_pencil-XL", "noob_v_pencil-XL-v3.0.0.safetensors", "SDXL", "90b7911a78"),
    "RedRayz/hikari_noob_v-pred_1.2.2": ("RedRayz/hikari_noob_v-pred_1.2.2", "Hikari_Noob_v-pred_1.2.2.safetensors", "SDXL", "874170688a"),
}
# Models for the "Animagine" tab.
MODEL_MAP_ANIMAGINE = {
    "cagliostrolab/animagine-xl-4.0": ("cagliostrolab/animagine-xl-4.0", "animagine-xl-4.0.safetensors", "SDXL", "6327eca98b"),
    "cagliostrolab/animagine-xl-3.1": ("cagliostrolab/animagine-xl-3.1", "animagine-xl-3.1.safetensors", "SDXL", "e3c47aedb0"),
}
# Models for the "Pony" tab.
MODEL_MAP_PONY = {
    "PurpleSmartAI/Pony_Diffusion_V6_XL": ("LyliaEngine/Pony_Diffusion_V6_XL", "ponyDiffusionV6XL_v6StartWithThisOne.safetensors", "SDXL", "67ab2fd8ec"),
}
# Models for the "SD1.5" tab (the only tab whose models take the SD1.5 code path).
MODEL_MAP_SD15 = {
    "Yuno779/anything-v3": ("ckpt/anything-v3.0", "Anything-V3.0-pruned.safetensors", "SD1.5", "ddd565f806"),
}

# --- Combined Maps for Global Lookup ---
# Every per-tab map merged into one; on a duplicate display name the later map
# would win (the maps above currently share no names).
ALL_MODEL_MAP = {**MODEL_MAP_ILLUSTRIOUS, **MODEL_MAP_ANIMAGINE, **MODEL_MAP_PONY, **MODEL_MAP_SD15}
# Display name -> "SDXL" / "SD1.5".
MODEL_TYPE_MAP = {name: model_type for name, (_repo, _file, model_type, _hash) in ALL_MODEL_MAP.items()}
# Display name -> short hash written into image metadata.
DISPLAY_NAME_TO_HASH_MAP = {name: model_hash for name, (_repo, _file, _type, model_hash) in ALL_MODEL_MAP.items()}
# Short hash -> display name, used to route a PNG back to its model.
HASH_TO_DISPLAY_NAME_MAP = {model_hash: name for name, (_repo, _file, _type, model_hash) in ALL_MODEL_MAP.items()}

# --- UI Defaults ---
# Default negative prompt text (presumably pre-filled in each tab's negative prompt box).
DEFAULT_NEGATIVE_PROMPT = "monochrome, (low quality, worst quality:1.2), 3d, watermark, signature, ugly, poorly drawn,"
# Number of LoRA rows built by the settings panel; the generator also uses it to
# detect the optional trailing CLIP-skip argument after MAX_LORAS * 4 LoRA values.
MAX_LORAS = 5
# Choices for each LoRA row's "Source" dropdown.
LORA_SOURCE_CHOICES = ["Civitai", "TensorArt", "Custom URL", "File"]

def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Return ``obj[index]``, falling back to ``obj["result"][index]``, else None.

    Node outputs are either indexable directly or, judging by the fallback, a
    mapping that carries them under ``"result"``. Any lookup failure, including
    a non-subscriptable ``obj`` (e.g. None) or a sequence that cannot be indexed
    with ``"result"``, yields None instead of raising.
    """
    try:
        return obj[index]
    except (KeyError, IndexError, TypeError):
        pass
    try:
        return obj["result"][index]
    except (KeyError, IndexError, TypeError):
        # TypeError: e.g. a tuple/list indexed with "result", or obj is None.
        return None

def import_custom_nodes() -> None:
    """Initialise ComfyUI's extra/custom nodes outside the normal server start-up.

    Creates a new asyncio event loop and makes it the current loop, builds a
    ``server.PromptServer`` and an ``execution.PromptQueue`` for it, then runs
    the ``nodes.init_extra_nodes()`` coroutine to completion on that loop.
    NOTE(review): the PromptQueue instance is discarded here; presumably its
    constructor registers itself on the server — confirm against ComfyUI. The
    loop is also never closed.
    """
    import asyncio, execution, server
    from nodes import init_extra_nodes
    # Install the loop before constructing PromptServer (which is handed the loop); keep this order.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    server_instance = server.PromptServer(loop)
    execution.PromptQueue(server_instance)
    loop.run_until_complete(init_extra_nodes())

# --- Import ComfyUI Nodes & Get Choices ---
# `nodes` is ComfyUI's module, so this import presumably has to come after the
# sys.path setup above rather than in the top-of-file import block.
from nodes import CheckpointLoaderSimple, EmptyLatentImage, KSampler, VAEDecode, SaveImage, NODE_CLASS_MAPPINGS
import_custom_nodes()
# Remaining node classes are looked up in the registry (after extra nodes were initialised).
CLIPTextEncodeSDXL = NODE_CLASS_MAPPINGS['CLIPTextEncodeSDXL']
CLIPTextEncode = NODE_CLASS_MAPPINGS['CLIPTextEncode']
LoraLoader = NODE_CLASS_MAPPINGS['LoraLoader']
CLIPSetLastLayer = NODE_CLASS_MAPPINGS['CLIPSetLastLayer']
# Read the valid sampler / scheduler names from KSampler's input spec; if the spec
# is not shaped as expected, fall back to a small hard-coded list.
try:
    SAMPLER_CHOICES = KSampler.INPUT_TYPES()["required"]["sampler_name"][0]
    SCHEDULER_CHOICES = KSampler.INPUT_TYPES()["required"]["scheduler"][0]
except Exception:
    SAMPLER_CHOICES = ['euler', 'dpmpp_2m_sde_gpu']
    SCHEDULER_CHOICES = ['normal', 'karras']

# --- Instantiate Node Objects ---
# One module-level instance per ComfyUI node class, shared by the generation code below.
checkpointloadersimple = CheckpointLoaderSimple()
cliptextencodesdxl = CLIPTextEncodeSDXL()
cliptextencode_sd15 = CLIPTextEncode()
emptylatentimage = EmptyLatentImage()
ksampler = KSampler()
vaedecode = VAEDecode()
saveimage = SaveImage()
loraloader = LoraLoader()
clipsetlastlayer = CLIPSetLastLayer()

# --- LoRA & File Utils ---
def get_civitai_file_info(version_id):
    """Fetch the downloadable file entry for a Civitai *model version*.

    Queries ``https://civitai.com/api/v1/model-versions/{version_id}`` and picks
    the first file whose ``type`` is ``"Model"`` and whose name ends in
    ``.safetensors``; otherwise the first listed file. Entries with a missing or
    null name are skipped rather than aborting the whole lookup.

    Returns:
        The chosen file dict from the API response (callers read its
        ``downloadUrl``), or None when the request fails, the response cannot be
        parsed, or no files are listed. Failures are best-effort: they are
        printed, not raised.
    """
    api_url = f"https://civitai.com/api/v1/model-versions/{version_id}"
    try:
        response = requests.get(api_url, timeout=10)
        response.raise_for_status()
        files = response.json().get('files') or []
        for file_data in files:
            name = file_data.get('name') or ''
            if file_data.get('type') == 'Model' and name.endswith('.safetensors'):
                return file_data
        return files[0] if files else None
    except Exception as e:
        print(f"Civitai lookup failed for model version {version_id}: {e}")
        return None

def get_tensorart_file_info(model_id):
    """Fetch the downloadable file entry for a TensorArt model.

    Queries ``https://tensor.art/api/v1/models/{model_id}`` and looks only at
    the first entry of ``modelVersions``: returns its first file whose name ends
    in ``.safetensors``, otherwise its first file. Entries with a missing or null
    name are skipped rather than aborting the lookup.

    Returns:
        The chosen file dict (callers read its ``downloadUrl``), or None when the
        request fails, the response cannot be parsed, or there are no versions /
        files. Failures are printed, not raised.
    """
    api_url = f"https://tensor.art/api/v1/models/{model_id}"
    try:
        response = requests.get(api_url, timeout=10)
        response.raise_for_status()
        model_versions = response.json().get('modelVersions') or []
        if not model_versions:
            return None
        files = model_versions[0].get('files') or []
        for file_data in files:
            if (file_data.get('name') or '').endswith('.safetensors'):
                return file_data
        return files[0] if files else None
    except Exception as e:
        print(f"TensorArt lookup failed for model {model_id}: {e}")
        return None

def download_file(url, save_path, api_key=None, progress=None, desc=""):
    """Stream ``url`` to ``save_path`` unless that file already exists.

    The body is written to ``<save_path>.part`` and renamed onto ``save_path``
    only after the transfer completes. An interrupted download (exception, or
    the process being killed) therefore never leaves a truncated file at
    ``save_path`` that the existence check would later accept as complete.

    Args:
        url: Direct download URL.
        save_path: Final destination path.
        api_key: Optional bearer token; None or whitespace-only keys are not sent.
        progress: Optional callable ``progress(fraction, desc=...)``.
        desc: Label passed to ``progress``.

    Returns:
        A status string: "File already exists: ...", "Successfully downloaded: ..."
        or "Download failed for ...: <error>". Errors are reported, not raised.
    """
    file_name = os.path.basename(save_path)
    if os.path.exists(save_path):
        return f"File already exists: {file_name}"
    headers = {'Authorization': f'Bearer {api_key}'} if api_key and api_key.strip() else {}
    tmp_path = save_path + ".part"
    try:
        if progress:
            progress(0, desc=desc)
        # Context manager releases the streamed connection even on errors.
        with requests.get(url, stream=True, headers=headers, timeout=15) as response:
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))
            downloaded = 0
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if not chunk:  # keep-alive chunk
                        continue
                    f.write(chunk)
                    downloaded += len(chunk)
                    if progress and total_size > 0:
                        # content-length may be the compressed size; never report > 100%.
                        progress(min(downloaded / total_size, 1.0), desc=desc)
        os.replace(tmp_path, save_path)
        return f"Successfully downloaded: {file_name}"
    except Exception as e:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        return f"Download failed for {file_name}: {e}"

def get_lora_path(source, id_or_url, civitai_key, tensorart_key, progress):
    """Resolve a remote LoRA reference to a local file, downloading it if needed.

    Args:
        source: "Civitai" (model-version ID), "TensorArt" (model ID) or
            "Custom URL" (direct link). Anything else is rejected.
        id_or_url: The ID or URL; surrounding whitespace is ignored.
        civitai_key / tensorart_key: API key sent as a bearer token when
            downloading from that source.
        progress: Progress callback forwarded to download_file.

    Returns:
        (local_path, status). ``local_path`` is None on failure; ``status`` is a
        human-readable message.

    The local cache is checked before the source's API is queried, so a LoRA
    that is already downloaded costs no network round-trip.
    """
    if not id_or_url or not id_or_url.strip():
        return None, "No ID/URL provided."
    if source not in ("Civitai", "TensorArt", "Custom URL"):
        return None, "Invalid source."
    ref = id_or_url.strip()

    if source == "Custom URL":
        # Arbitrary URLs are not safe file names, so cache under a hash of the URL.
        url_hash = hashlib.md5(ref.encode()).hexdigest()
        local_path = os.path.join(LORA_DIR, f"custom_{url_hash}.safetensors")
        source_name = f"URL {ref[:30]}..."
        api_key_to_use = None
    else:
        # IDs are embedded in the file name; a path separator could escape LORA_DIR.
        if "/" in ref or "\\" in ref:
            return None, f"Invalid {source} ID: {ref}"
        prefix = "civitai" if source == "Civitai" else "tensorart"
        local_path = os.path.join(LORA_DIR, f"{prefix}_{ref}.safetensors")
        source_name = f"{source} ID {ref}"
        api_key_to_use = civitai_key if source == "Civitai" else tensorart_key

    if os.path.exists(local_path):
        return local_path, "File already exists."

    # Only now hit the remote API to find the download link.
    if source == "Civitai":
        file_info = get_civitai_file_info(ref)
    elif source == "TensorArt":
        file_info = get_tensorart_file_info(ref)
    else:
        file_info = {'downloadUrl': ref}
    if not file_info or not file_info.get('downloadUrl'):
        return None, f"Could not get download link for {source_name}."
    status = download_file(file_info['downloadUrl'], local_path, api_key_to_use, progress=progress, desc=f"Downloading {source_name}")
    return (local_path, status) if "Successfully" in status else (None, status)

def pre_download_loras(civitai_api_key, tensorart_api_key, *lora_data, progress=gr.Progress(track_tqdm=True)):
    """Download every remote LoRA slot ahead of generation and report per-slot status.

    ``lora_data`` holds groups of four values per slot: (source, id_or_url,
    scale, uploaded_file). A slot is fetched when its source is remote, its
    ID/URL is non-blank and no file was uploaded; the scale is ignored here.

    Returns:
        A bullet list ("* <source> ID <id>: <status>") joined by newlines, or a
        notice when no slot qualifies.
    """
    remote_sources = ("Civitai", "TensorArt", "Custom URL")
    pending = []
    for src, ident, uploaded in zip(lora_data[0::4], lora_data[1::4], lora_data[3::4]):
        if src in remote_sources and ident and ident.strip() and uploaded is None:
            pending.append((src, ident))
    if not pending:
        return "No remote LoRAs specified for pre-downloading."
    report = []
    for src, ident in pending:
        _, status = get_lora_path(src, ident, civitai_api_key, tensorart_api_key, progress)
        report.append(f"* {src} ID {ident}: {status}")
    return "\n".join(report)

# --- Model Management & Core Logic ---
# Single-slot cache: display name and checkpoint tuple of the currently loaded model
# (the generator uses indices 0, 1, 2 of the tuple as model, clip, vae).
current_loaded_model_name = None; loaded_checkpoint_tuple = None
def load_model(model_display_name: str, progress=gr.Progress()):
    """Return the checkpoint tuple for ``model_display_name``, loading it if needed.

    Reuses the cached tuple when the same model is requested again. Otherwise it:
    unloads all models held by ComfyUI and frees Python/CUDA memory (only if a
    model was cached), downloads the checkpoint from the Hugging Face Hub into
    CHECKPOINT_DIR when it is missing, loads it with CheckpointLoaderSimple by
    file name, and moves the model onto the GPU.

    Raises:
        KeyError: if ``model_display_name`` is not in ALL_MODEL_MAP.
    """
    global current_loaded_model_name, loaded_checkpoint_tuple
    if model_display_name == current_loaded_model_name and loaded_checkpoint_tuple: return loaded_checkpoint_tuple
    # Release the previous model before loading another; keep this order (unload -> gc -> empty CUDA cache).
    if loaded_checkpoint_tuple: model_management.unload_all_models(); loaded_checkpoint_tuple = None; gc.collect(); torch.cuda.empty_cache()

    repo_id, filename, _, _ = ALL_MODEL_MAP[model_display_name]
    local_file_path = os.path.join(CHECKPOINT_DIR, filename)

    if not os.path.exists(local_file_path):
        progress(0, desc=f"Downloading model: {model_display_name}")
        # NOTE(review): newer huggingface_hub releases deprecate local_dir_use_symlinks — confirm the pinned version.
        hf_hub_download(repo_id=repo_id, filename=filename, local_dir=CHECKPOINT_DIR, local_dir_use_symlinks=False)

    progress(0.5, desc=f"Loading '{filename}'")
    # ComfyUI resolves ckpt_name in its own checkpoints folder; presumably that is the same
    # directory as CHECKPOINT_DIR (i.e. CWD is the ComfyUI root) — TODO confirm.
    MODEL_TUPLE = checkpointloadersimple.load_checkpoint(ckpt_name=filename)
    model_management.load_models_gpu([get_value_at_index(MODEL_TUPLE, 0)])
    current_loaded_model_name = model_display_name; loaded_checkpoint_tuple = MODEL_TUPLE
    progress(1.0, desc="Model loaded"); return loaded_checkpoint_tuple

def _apply_loras(model, clip, lora_data, civitai_api_key, tensorart_api_key, progress):
    """Chain every enabled LoRA slot onto ``model`` / ``clip``.

    ``lora_data`` holds groups of four values per slot:
    (source, id_or_url, scale, uploaded_file). A slot is applied only when its
    scale is > 0 and it has an uploaded file or a non-blank ID/URL that
    resolves to a local file.

    Returns:
        (model, clip, labels): the patched model/clip and one
        "<source> <id>:<scale>" label per LoRA actually applied.
    """
    labels = []
    slots = zip(lora_data[0::4], lora_data[1::4], lora_data[2::4], lora_data[3::4])
    for source, lora_id, scale, custom_file in slots:
        if not scale > 0:
            continue
        lora_filename = None
        if custom_file:
            # Copy the upload into LORA_DIR and refer to it by base name
            # (presumably LoraLoader resolves names inside the LoRA folder).
            lora_filename = os.path.basename(custom_file.name)
            shutil.copy(custom_file.name, LORA_DIR)
        elif lora_id and lora_id.strip():
            local_path, _ = get_lora_path(source, lora_id, civitai_api_key, tensorart_api_key, progress)
            if local_path:
                lora_filename = os.path.basename(local_path)
        if lora_filename:
            lora_tuple = loraloader.load_lora(model=model, clip=clip, lora_name=lora_filename,
                                              strength_model=scale, strength_clip=scale)
            model, clip = get_value_at_index(lora_tuple, 0), get_value_at_index(lora_tuple, 1)
            labels.append(f"{source} {lora_id}:{scale}")
    return model, clip, labels


def _format_generation_parameters(positive_prompt, negative_prompt, steps, sampler_name, scheduler,
                                  cfg_scale, seed, width, height, model_display_name, model_hash,
                                  clip_skip=None, lora_labels=()):
    """Build the A1111-style ``parameters`` text stored in an output image's ``info``.

    "Clip skip" is appended only when ``clip_skip`` is not None and the LoRA
    list only when ``lora_labels`` is non-empty, so the text never ends with a
    dangling comma.
    """
    text = f"{positive_prompt}\nNegative prompt: {negative_prompt}\n"
    text += (f"Steps: {steps}, Sampler: {sampler_name}, Scheduler: {scheduler}, CFG scale: {cfg_scale}, "
             f"Seed: {seed}, Size: {width}x{height}, Base Model: {model_display_name}, Model hash: {model_hash}")
    if clip_skip is not None:
        text += f", Clip skip: {clip_skip}"
    if lora_labels:
        text += f", LoRAs: [{', '.join(lora_labels)}]"
    return text.strip()


def _generate_image_logic(model_display_name: str, positive_prompt: str, negative_prompt: str,
                          seed: int, batch_size: int, width: int, height: int, guidance_scale: float,
                          num_inference_steps: int, sampler_name: str, scheduler: str,
                          civitai_api_key: str, tensorart_api_key: str, *lora_data,
                          progress=gr.Progress(track_tqdm=True)):
    """Run one txt2img generation through the shared ComfyUI node objects.

    Args:
        model_display_name: Key into ALL_MODEL_MAP; "SD1.5" models use
            CLIPTextEncode plus CLIP skip, all others CLIPTextEncodeSDXL.
        positive_prompt / negative_prompt: Prompt text.
        seed: Sampler seed; -1 (or None) draws a random 64-bit seed.
        batch_size, width, height: Latent batch size and image size.
        guidance_scale, num_inference_steps, sampler_name, scheduler: KSampler settings.
        civitai_api_key / tensorart_api_key: Forwarded for remote LoRA downloads.
        *lora_data: MAX_LORAS groups of (source, id_or_url, scale, uploaded_file);
            for SD1.5 models one extra trailing value is read as the CLIP-skip setting.
        progress: Gradio progress tracker.

    Returns:
        list of PIL images, one per batch item, each with an A1111-style
        ``parameters`` string in ``.info``.
    """
    is_sd15 = MODEL_TYPE_MAP.get(model_display_name) == "SD1.5"
    clip_skip = 1
    if is_sd15 and len(lora_data) > MAX_LORAS * 4:
        # Extra value after the LoRA groups = CLIP-skip setting.
        clip_skip = int(lora_data[-1])
        lora_data = lora_data[:-1]
    # Normalise the seed: a float such as 42.0 would otherwise be recorded as
    # "Seed: 42.0" and fail int() when the PNG info is parsed back.
    seed = -1 if seed is None else int(seed)

    output_images = []
    with torch.inference_mode():
        model_tuple = load_model(model_display_name, progress)
        model, clip, vae = (get_value_at_index(model_tuple, i) for i in range(3))

        if is_sd15:
            clip = get_value_at_index(clipsetlastlayer.set_last_layer(clip=clip, stop_at_clip_layer=-clip_skip), 0)

        model, clip, lora_labels = _apply_loras(model, clip, lora_data, civitai_api_key, tensorart_api_key, progress)

        if is_sd15:
            pos_cond = cliptextencode_sd15.encode(text=positive_prompt, clip=clip)
            neg_cond = cliptextencode_sd15.encode(text=negative_prompt, clip=clip)
        else:
            # Same text for both SDXL text encoders; size conditioning = output size, no crop.
            sdxl_kwargs = dict(width=width, height=height, clip=clip, target_width=width,
                               target_height=height, crop_w=0, crop_h=0)
            pos_cond = cliptextencodesdxl.encode(text_g=positive_prompt, text_l=positive_prompt, **sdxl_kwargs)
            neg_cond = cliptextencodesdxl.encode(text_g=negative_prompt, text_l=negative_prompt, **sdxl_kwargs)

        start_seed = seed if seed != -1 else random.randint(0, 2**64 - 1)
        latent = emptylatentimage.generate(width=width, height=height, batch_size=batch_size)
        sampled = ksampler.sample(
            seed=start_seed,
            steps=num_inference_steps,
            cfg=guidance_scale,
            sampler_name=sampler_name,
            scheduler=scheduler,
            denoise=1.0,
            model=model,
            positive=get_value_at_index(pos_cond, 0),
            negative=get_value_at_index(neg_cond, 0),
            latent_image=get_value_at_index(latent, 0),
        )
        decoded_images_tensor = get_value_at_index(vaedecode.decode(samples=get_value_at_index(sampled, 0), vae=vae), 0)

        model_hash = DISPLAY_NAME_TO_HASH_MAP.get(model_display_name, "N/A")
        for i in range(decoded_images_tensor.shape[0]):
            # Clamp before scaling so out-of-range values cannot wrap around in the uint8 cast.
            pixels = torch.clamp(decoded_images_tensor[i], 0.0, 1.0)
            pil_image = Image.fromarray((pixels.cpu().numpy() * 255.0).astype("uint8"))
            # NOTE(review): the whole batch is sampled from start_seed in one call; whether
            # "start_seed + i" alone reproduces image i depends on ComfyUI's batch noise — confirm.
            pil_image.info = {'parameters': _format_generation_parameters(
                positive_prompt, negative_prompt, num_inference_steps, sampler_name, scheduler,
                guidance_scale, start_seed + i, width, height, model_display_name, model_hash,
                clip_skip=clip_skip if is_sd15 else None, lora_labels=lora_labels)}
            output_images.append(pil_image)

    return output_images

def generate_image_wrapper(*args, **kwargs):
    """Run _generate_image_logic under a ZeroGPU allocation sized per request.

    Positional layout of ``args``: indices 0-10 are forwarded unchanged, index 11
    is the requested GPU duration in seconds (consumed here), and the rest (API
    keys, LoRA data, optional CLIP skip) is forwarded after them. A missing,
    non-positive or non-numeric duration falls back to 60 seconds.
    """
    requested = args[11]
    forwarded = [*args[:11], *args[12:]]
    try:
        gpu_seconds = int(requested) if requested and int(requested) > 0 else 60
    except (ValueError, TypeError):
        gpu_seconds = 60
    gpu_task = spaces.GPU(duration=gpu_seconds)(_generate_image_logic)
    return gpu_task(*forwarded, **kwargs)


# --- PNG Info & UI Logic ---
def _parse_parameters(params_text):
    """Parse an A1111-style ``parameters`` string into a dict of UI values.

    Expected layout (as written by this app's generator)::

        <prompt, possibly several lines>
        Negative prompt: <negative prompt, possibly several lines>
        Steps: 28, Sampler: ..., Seed: ..., Size: WxH, Base Model: ..., Model hash: ...

    The settings section starts at the last line beginning with "Steps:" and
    the prompt ends at the first "Negative prompt:" line before it, so
    multi-line or empty prompts and a missing negative prompt are handled.
    Text without either marker keeps the positional reading (line 0 = prompt,
    lines 2+ = settings). Missing or uncastable values fall back to defaults.

    Returns:
        dict with keys prompt, negative_prompt, steps, sampler, scheduler,
        cfg_scale, seed, clip_skip, base_model, model_hash, width, height.
    """
    lines = params_text.strip().split('\n')
    steps_idx = next((i for i in range(len(lines) - 1, -1, -1) if lines[i].startswith("Steps:")), None)
    neg_idx = next((i for i, line in enumerate(lines[:steps_idx]) if line.startswith("Negative prompt:")), None)

    if neg_idx is not None:
        params_start = steps_idx if steps_idx is not None else neg_idx + 1
        prompt_lines, negative_lines = lines[:neg_idx], lines[neg_idx:params_start]
    elif steps_idx is not None:
        params_start = steps_idx
        prompt_lines, negative_lines = lines[:steps_idx], []
    else:
        # No recognisable markers: keep the historical positional layout.
        params_start = 2
        prompt_lines, negative_lines = lines[:1], []

    data = {'prompt': '\n'.join(prompt_lines)}
    negative_text = '\n'.join(negative_lines)
    data['negative_prompt'] = negative_text[len("Negative prompt:"):].strip() if negative_lines else ""
    params_line = '\n'.join(lines[params_start:])

    def find_param(key, default, cast_type=str):
        """Value after ``key: `` in the settings text cast by ``cast_type``; ``default`` if absent/uncastable."""
        match = re.search(fr"\b{re.escape(key)}: ([^,]+?)(,|$|\n)", params_line)
        if not match:
            return default
        try:
            return cast_type(match.group(1).strip())
        except (TypeError, ValueError):
            return default

    data['steps'] = find_param("Steps", 28, int); data['sampler'] = find_param("Sampler", SAMPLER_CHOICES[0], str)
    data['scheduler'] = find_param("Scheduler", SCHEDULER_CHOICES[0], str); data['cfg_scale'] = find_param("CFG scale", 7.5, float)
    data['seed'] = find_param("Seed", -1, int); data['clip_skip'] = find_param("Clip skip", 1, int)
    data['base_model'] = find_param("Base Model", list(ALL_MODEL_MAP.keys())[0], str); data['model_hash'] = find_param("Model hash", None, str)
    size_match = re.search(r"Size: (\d+)x(\d+)", params_line)
    data['width'], data['height'] = (int(size_match.group(1)), int(size_match.group(2))) if size_match else (1024, 1024)
    return data

def get_png_info(image):
    """Extract prompt, negative prompt and a readable settings list from an image.

    Returns:
        (prompt, negative_prompt, other_params_text). ``other_params_text`` puts
        each comma-separated setting on its own line; commas inside ``[...]``
        (the LoRA list) are not treated as separators. When the image is missing
        or has no ``parameters`` entry, returns
        ("", "", "No metadata found in the image.").
    """
    if not image or not (params := image.info.get('parameters')):
        return "", "", "No metadata found in the image."
    parsed_data = _parse_parameters(params)
    lines = params.strip().split('\n')
    # Settings start at the last "Steps:" line so multi-line prompts are not shown
    # as settings; other layouts keep the old "lines 2+" rule.
    steps_idx = next((i for i in range(len(lines) - 1, -1, -1) if lines[i].startswith("Steps:")), None)
    settings_start = steps_idx if steps_idx is not None else 2
    settings_text = '\n'.join(lines[settings_start:])
    # Split on commas that are not inside a [...] group.
    other_params_text = "\n".join(part.strip() for part in re.split(r",(?![^\[]*\])", settings_text))
    return parsed_data.get('prompt', ''), parsed_data.get('negative_prompt', ''), other_params_text

def apply_data_to_ui(data, target_tab):
    """Map parsed PNG parameters onto the components of one model tab.

    Args:
        data: dict as produced by _parse_parameters (reads prompt,
            negative_prompt, seed, width, height, cfg_scale, steps, sampler,
            scheduler, base_model, and clip_skip for the SD1.5 tab).
        target_tab: "Illustrious", "Animagine", "Pony" or "SD1.5".

    Returns:
        dict mapping Gradio components to new values. A sampler/scheduler not in
        this app's choice lists is replaced by a default; the model dropdown is
        set only when the PNG's base model belongs to the target tab's map. The
        matching inner model tab (index 0-3) is selected and the outer tab set is
        switched to index 0 (the txt2img tab). Any other target_tab only switches
        the outer tab. The component names are module globals created by the UI
        code further down the file.
    """
    final_sampler = data.get('sampler') if data.get('sampler') in SAMPLER_CHOICES else SAMPLER_CHOICES[0]
    # Prefer 'normal' as the fallback scheduler when it is available.
    default_scheduler = 'normal' if 'normal' in SCHEDULER_CHOICES else SCHEDULER_CHOICES[0]
    final_scheduler = data.get('scheduler') if data.get('scheduler') in SCHEDULER_CHOICES else default_scheduler

    updates = {}
    base_model_name = data.get('base_model')

    if target_tab == "Illustrious":
        if base_model_name in MODEL_MAP_ILLUSTRIOUS:
            updates.update({base_model_name_input_illustrious: base_model_name})
        updates.update({prompt_illustrious: data['prompt'], negative_prompt_illustrious: data['negative_prompt'], seed_illustrious: data['seed'], width_illustrious: data['width'], height_illustrious: data['height'], guidance_scale_illustrious: data['cfg_scale'], num_inference_steps_illustrious: data['steps'], sampler_illustrious: final_sampler, schedule_type_illustrious: final_scheduler, model_tabs: gr.Tabs(selected=0)})
    elif target_tab == "Animagine":
        if base_model_name in MODEL_MAP_ANIMAGINE:
            updates.update({base_model_name_input_animagine: base_model_name})
        updates.update({prompt_animagine: data['prompt'], negative_prompt_animagine: data['negative_prompt'], seed_animagine: data['seed'], width_animagine: data['width'], height_animagine: data['height'], guidance_scale_animagine: data['cfg_scale'], num_inference_steps_animagine: data['steps'], sampler_animagine: final_sampler, schedule_type_animagine: final_scheduler, model_tabs: gr.Tabs(selected=1)})
    elif target_tab == "Pony":
        if base_model_name in MODEL_MAP_PONY:
            updates.update({base_model_name_input_pony: base_model_name})
        updates.update({prompt_pony: data['prompt'], negative_prompt_pony: data['negative_prompt'], seed_pony: data['seed'], width_pony: data['width'], height_pony: data['height'], guidance_scale_pony: data['cfg_scale'], num_inference_steps_pony: data['steps'], sampler_pony: final_sampler, schedule_type_pony: final_scheduler, model_tabs: gr.Tabs(selected=2)})
    elif target_tab == "SD1.5":
        if base_model_name in MODEL_MAP_SD15:
            updates.update({base_model_name_input_sd15: base_model_name})
        # Only the SD1.5 tab has a CLIP-skip control.
        updates.update({prompt_sd15: data['prompt'], negative_prompt_sd15: data['negative_prompt'], seed_sd15: data['seed'], width_sd15: data['width'], height_sd15: data['height'], guidance_scale_sd15: data['cfg_scale'], num_inference_steps_sd15: data['steps'], sampler_sd15: final_sampler, schedule_type_sd15: final_scheduler, clip_skip_sd15: data.get('clip_skip', 1), model_tabs: gr.Tabs(selected=3)})

    # Always switch the outer tab set back to txt2img (TabItem id=0).
    updates[tabs] = gr.Tabs(selected=0)
    return updates

def send_info_to_tab(image, target_tab):
    """Copy an image's embedded generation parameters into the named model tab.

    When the image is missing or has no ``parameters`` metadata, every component
    in ``all_ui_components`` receives a no-op ``gr.update()``.
    """
    embedded = image.info.get('parameters', '') if image else ''
    if not embedded:
        return {comp: gr.update() for comp in all_ui_components}
    return apply_data_to_ui(_parse_parameters(embedded), target_tab)

def send_info_by_hash(image):
    """Route an image's generation parameters to the tab that owns its model.

    The model is identified by the "Model hash" field. If that hash is unknown,
    the "Base Model" name recorded in the PNG is used instead, but only when the
    field is actually present in the metadata and names a model of this app.

    Returns:
        The component-update dict from apply_data_to_ui, or no-op updates for
        every component in ``all_ui_components`` when there is no metadata.

    Raises:
        gr.Error: if neither the hash nor the recorded model name identifies a
            model known to this app, or no tab contains the model.
    """
    if not image or not image.info.get('parameters', ''):
        return {comp: gr.update() for comp in all_ui_components}
    params = image.info['parameters']
    data = _parse_parameters(params)
    display_name = HASH_TO_DISPLAY_NAME_MAP.get(data.get('model_hash'))

    if not display_name:
        # _parse_parameters substitutes a default model name when "Base Model" is
        # absent, so only trust the parsed name if the field really is in the PNG.
        recorded_name = data.get('base_model')
        if re.search(r"\bBase Model: ", params) and recorded_name in ALL_MODEL_MAP:
            display_name = recorded_name
        else:
            raise gr.Error("Model hash not found in this app's model list, and the PNG's base model name is not recognized either.")

    tab_maps = (
        ("Illustrious", MODEL_MAP_ILLUSTRIOUS),
        ("Animagine", MODEL_MAP_ANIMAGINE),
        ("Pony", MODEL_MAP_PONY),
        ("SD1.5", MODEL_MAP_SD15),
    )
    for tab_name, model_map in tab_maps:
        if display_name in model_map:
            target_tab = tab_name
            break
    else:
        raise gr.Error("Cannot determine the correct tab for this model.")

    data['base_model'] = display_name
    return apply_data_to_ui(data, target_tab)

# --- UI Generation Functions ---
def create_lora_settings_ui():
    """Build the collapsible "LoRA Settings" panel and return its components.

    Intended to be called while a ``gr.Blocks`` layout is being built. Creates
    API-key fields, usage notes, and MAX_LORAS rows of (source dropdown,
    ID/URL textbox, weight slider, upload button); only the first row starts
    visible and every weight starts at 0.0.

    Returns:
        tuple: (civitai_api_key, tensorart_api_key, lora_rows, sources, ids,
        scales, uploads, add_button, delete_button, count_state, all_components).
        ``all_components`` interleaves source, id, scale, upload for each row in
        order, i.e. the 4-values-per-slot layout the generation and pre-download
        handlers slice with ``[0::4]`` .. ``[3::4]``. The add/delete buttons are
        not wired here.
    """
    with gr.Accordion("LoRA Settings", open=False):
        gr.Markdown("⚠️ **Responsible Use Notice:** Please avoid excessive, rapid, or automated (scripted) use of the pre-download LoRA feature. Overt misuse may lead to service disruption. Thank you for your cooperation.")
        gr.Markdown("For LoRAs that require login to download, you may need to enter the corresponding API Key.")
        with gr.Row():
            civitai_api_key = gr.Textbox(label="Civitai API Key", placeholder="Enter your Civitai API Key", type="password", scale=1)
            tensorart_api_key = gr.Textbox(label="TensorArt API Key", placeholder="Enter your TensorArt API Key", type="password", scale=1)
        gr.Markdown("---")
        gr.Markdown("For each LoRA, choose a source, provide an ID/URL, or upload a file.")
        gr.Markdown("""

        <div style='background-color: #282828; color: #a0aec0; padding: 10px; border-radius: 5px; margin-top: 10px; margin-bottom: 15px;'>

            <b>Input Examples:</b>

            <ul>

                <li><b>Civitai:</b> Enter the <b>Model Version ID</b>, not the Model ID. Example: <code>133755</code> (Found in the URL, e.g., <code>civitai.com/models/122136?modelVersionId=<b>133755</b></code>)</li>

                <li><b>TensorArt:</b> Enter the <b>Model ID</b>. Example: <code>706684852832599558</code> (Found in the URL, e.g., <code>tensor.art/models/<b>706684852832599558</b></code>)</li>

                <li><b>Custom URL:</b> Provide a direct download link to a <code>.safetensors</code> file. Example: <code>https://huggingface.co/path/to/your/lora.safetensors</code></li>

                <li><b>File:</b> Use the "Upload" button. The source will be set automatically.</li>

            </ul>

        </div>

        """)
        gr.Markdown("""

        <div style='background-color: #282828; color: #a0aec0; padding: 10px; border-radius: 5px; margin-bottom: 15px;'>

            <b>Notice:</b>

            <ul style='margin-bottom: 0;'>

                <li>With Gradio, the page may become unresponsive until a file is fully uploaded. Please be patient and wait for the process to complete.</li>

            </ul>

        </div>

        """)
        lora_rows, sources, ids, scales, uploads = [], [], [], [], []
        for i in range(MAX_LORAS):
            # Only the first row is visible initially.
            with gr.Row(visible=(i == 0)) as row:
                source = gr.Dropdown(label=f"LoRA {i+1} Source", choices=LORA_SOURCE_CHOICES, value="Civitai", scale=1)
                lora_id = gr.Textbox(label="ID / URL / File", placeholder="e.g.: 133755", scale=2)
                # Weight 0.0 leaves the slot inactive: the generator only applies LoRAs with scale > 0.
                scale = gr.Slider(label="Weight", minimum=0.0, maximum=2.0, step=0.05, value=0.0, scale=2)
                upload = gr.UploadButton("Upload", file_types=[".safetensors"], scale=1)
                lora_rows.append(row); sources.append(source); ids.append(lora_id); scales.append(scale); uploads.append(upload)
                # After an upload, show the file's base name in the ID box and switch the source to "File";
                # an empty upload leaves both unchanged.
                upload.upload(fn=lambda f: (os.path.basename(f.name), "File") if f else (gr.update(), gr.update()), inputs=[upload], outputs=[lora_id, source])
        with gr.Row(): add_button = gr.Button("✚ Add LoRA"); delete_button = gr.Button("➖ Delete LoRA", visible=False)
        # Number of currently visible LoRA rows (starts with one).
        count_state = gr.State(value=1)
        all_components = [item for sublist in zip(sources, ids, scales, uploads) for item in sublist]
    return (civitai_api_key, tensorart_api_key, lora_rows, sources, ids, scales, uploads, add_button, delete_button, count_state, all_components)

def download_all_models_on_startup():
    """Ensure every base checkpoint listed in ALL_MODEL_MAP exists in CHECKPOINT_DIR.

    Each value of ALL_MODEL_MAP is unpacked as a 4-tuple whose first two items
    are the Hugging Face repo id and the checkpoint filename (the last two are
    ignored here). Checkpoints already present on disk are skipped. A failed
    download is printed and the loop continues with the next model, so one bad
    entry does not abort startup.
    """
    print("--- Starting pre-download of all base models ---")
    for display_name, (repo, ckpt_name, _, _) in ALL_MODEL_MAP.items():
        target_path = os.path.join(CHECKPOINT_DIR, ckpt_name)
        if os.path.exists(target_path):
            print(f"✅ Model '{ckpt_name}' already exists. Skipping download.")
        else:
            try:
                print(f"Downloading: {display_name} ({ckpt_name})...")
                # local_dir_use_symlinks=False requests a real file in
                # CHECKPOINT_DIR rather than a symlink into the HF cache
                # (honoured by older huggingface_hub releases).
                hf_hub_download(
                    repo_id=repo,
                    filename=ckpt_name,
                    local_dir=CHECKPOINT_DIR,
                    local_dir_use_symlinks=False,
                )
                print(f"✅ Successfully downloaded {ckpt_name}.")
            except Exception as err:
                # Best effort: report and move on to the remaining models.
                print(f"❌ Failed to download {ckpt_name} from {repo}: {err}")
    print("--- Finished pre-downloading all base models ---")

# --- Execute model download on startup ---
# Runs at import time, before the UI below is built. Models are fetched one
# after another (already-present files are skipped), so startup blocks until
# every download has finished or failed.
download_all_models_on_startup()

# --- Gradio UI ---
# The layout is declared by nesting context managers; components render in the
# order they are created. Building the UI also binds the tab-suffixed
# module-level names (e.g. prompt_illustrious) used by the event wiring below.
# NOTE(review): no component visible here sets elem_id="col-container", so this
# CSS rule may not match anything — confirm.
with gr.Blocks(css="#col-container {margin: 0 auto; max-width: 1024px;}") as demo:
    gr.Markdown("# Animated T2I with LoRAs")
    # Outer tabs: id=0 is generation, id=1 is PNG Info. The explicit ids are
    # presumably there so the "Send to ..." handlers can select a tab through
    # `tabs` (which is one of their outputs).
    with gr.Tabs(elem_id="tabs_container") as tabs:
        with gr.TabItem("txt2img", id=0):
            with gr.Tabs() as model_tabs:
                # One sub-tab per model family, each with its own copy of the
                # controls. `defaults` holds the initial width/height ('w'/'h')
                # and the Clip Skip slider's visibility/value ('cs_vis'/'cs_val');
                # only SD1.5 shows Clip Skip.
                for tab_name, model_map, defaults in [
                    ("Illustrious", MODEL_MAP_ILLUSTRIOUS, {'w': 1024, 'h': 1024, 'cs_vis': False, 'cs_val': 1}),
                    ("Animagine", MODEL_MAP_ANIMAGINE, {'w': 1024, 'h': 1024, 'cs_vis': False, 'cs_val': 1}),
                    ("Pony", MODEL_MAP_PONY, {'w': 1024, 'h': 1024, 'cs_vis': False, 'cs_val': 1}),
                    ("SD1.5", MODEL_MAP_SD15, {'w': 512, 'h': 768, 'cs_vis': True, 'cs_val': 1})
                ]:
                    # NOTE(review): unlike the outer tabs, these sub-tabs have no
                    # `id`, yet `model_tabs` is an output of the send_* handlers.
                    # Gradio needs an explicit id to select a tab programmatically
                    # — confirm the send_* functions still select the right tab.
                    with gr.TabItem(tab_name):
                        gr.Markdown("💡 **Tip:** Pre-downloading LoRAs before 'Run' can maximize ZeroGPU time.")
                        with gr.Column():
                            # Base-model picker with the Pre-download / Run buttons beside it.
                            with gr.Row():
                                base_model = gr.Dropdown(label="Base Model", choices=list(model_map.keys()), value=list(model_map.keys())[0], scale=3)
                                with gr.Column(scale=1): predownload_lora = gr.Button("Pre-download LoRAs"); run = gr.Button("Run", variant="primary")
                            # Status text written by the pre-download handler.
                            predownload_status = gr.Markdown("")
                            prompt = gr.Text(label="Prompt", lines=3, placeholder="Enter your prompt")
                            neg_prompt = gr.Text(label="Negative prompt", lines=3, value=DEFAULT_NEGATIVE_PROMPT)
                            # Sampling parameters (left, scale 2) next to the result gallery (right, scale 1).
                            with gr.Row():
                                with gr.Column(scale=2):
                                    with gr.Row(): width = gr.Slider(label="Width", minimum=256, maximum=2048, step=64, value=defaults['w']); height = gr.Slider(label="Height", minimum=256, maximum=2048, step=64, value=defaults['h'])
                                    with gr.Row(): 
                                        sampler = gr.Dropdown(label="Sampling method", choices=SAMPLER_CHOICES, value=SAMPLER_CHOICES[0])
                                        # Prefer the 'normal' scheduler; fall back to the first choice.
                                        default_scheduler = 'normal' if 'normal' in SCHEDULER_CHOICES else SCHEDULER_CHOICES[0]
                                        scheduler = gr.Dropdown(label="Scheduler", choices=SCHEDULER_CHOICES, value=default_scheduler)
                                    with gr.Row(): cfg = gr.Slider(label="CFG Scale", minimum=0.0, maximum=20.0, step=0.1, value=7.5); steps = gr.Slider(label="Sampling steps", minimum=1, maximum=50, step=1, value=28)
                                with gr.Column(scale=1): result = gr.Gallery(label="Result", show_label=False, columns=2, object_fit="contain", height="auto")
                            with gr.Row():
                                seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
                                batch_size = gr.Slider(label="Batch size", minimum=1, maximum=8, step=1, value=1)
                                clip_skip = gr.Slider(label="Clip Skip", minimum=1, maximum=2, step=1, value=defaults['cs_val'], visible=defaults['cs_vis'])
                                # Empty (None) means "use the default". The 60s default and
                                # 120s cap in `info` are presumably enforced by the
                                # generation code — confirm.
                                zero_gpu = gr.Number(label="ZeroGPU Duration (s)", value=None, placeholder="Default: 60s", info="Optional: Leave empty for default (60s), max to 120")
                            # This tab's LoRA panel; returns an 11-tuple unpacked below.
                            lora_settings = create_lora_settings_ui()
                            
                            # Loop variables are rebound on every iteration, so copy this
                            # tab's components into tab-suffixed names for the event
                            # handlers. The two `_` slots discard the LoRA source
                            # dropdowns and upload buttons.
                            if tab_name == "Illustrious":
                                base_model_name_input_illustrious, prompt_illustrious, negative_prompt_illustrious, seed_illustrious, batch_size_illustrious, width_illustrious, height_illustrious, guidance_scale_illustrious, num_inference_steps_illustrious, sampler_illustrious, schedule_type_illustrious, zero_gpu_duration_illustrious, result_illustrious = base_model, prompt, neg_prompt, seed, batch_size, width, height, cfg, steps, sampler, scheduler, zero_gpu, result
                                civitai_api_key_illustrious, tensorart_api_key_illustrious, lora_rows_illustrious, _, lora_id_inputs_illustrious, lora_scale_inputs_illustrious, _, add_lora_button_illustrious, delete_lora_button_illustrious, lora_count_state_illustrious, all_lora_components_flat_illustrious = lora_settings
                                predownload_lora_button_illustrious, run_button_illustrious, predownload_status_illustrious = predownload_lora, run, predownload_status
                            elif tab_name == "Animagine":
                                base_model_name_input_animagine, prompt_animagine, negative_prompt_animagine, seed_animagine, batch_size_animagine, width_animagine, height_animagine, guidance_scale_animagine, num_inference_steps_animagine, sampler_animagine, schedule_type_animagine, zero_gpu_duration_animagine, result_animagine = base_model, prompt, neg_prompt, seed, batch_size, width, height, cfg, steps, sampler, scheduler, zero_gpu, result
                                civitai_api_key_animagine, tensorart_api_key_animagine, lora_rows_animagine, _, lora_id_inputs_animagine, lora_scale_inputs_animagine, _, add_lora_button_animagine, delete_lora_button_animagine, lora_count_state_animagine, all_lora_components_flat_animagine = lora_settings
                                predownload_lora_button_animagine, run_button_animagine, predownload_status_animagine = predownload_lora, run, predownload_status
                            elif tab_name == "Pony":
                                base_model_name_input_pony, prompt_pony, negative_prompt_pony, seed_pony, batch_size_pony, width_pony, height_pony, guidance_scale_pony, num_inference_steps_pony, sampler_pony, schedule_type_pony, zero_gpu_duration_pony, result_pony = base_model, prompt, neg_prompt, seed, batch_size, width, height, cfg, steps, sampler, scheduler, zero_gpu, result
                                civitai_api_key_pony, tensorart_api_key_pony, lora_rows_pony, _, lora_id_inputs_pony, lora_scale_inputs_pony, _, add_lora_button_pony, delete_lora_button_pony, lora_count_state_pony, all_lora_components_flat_pony = lora_settings
                                predownload_lora_button_pony, run_button_pony, predownload_status_pony = predownload_lora, run, predownload_status
                            elif tab_name == "SD1.5":
                                # SD1.5 additionally keeps its (visible) Clip Skip slider.
                                base_model_name_input_sd15, prompt_sd15, negative_prompt_sd15, seed_sd15, batch_size_sd15, width_sd15, height_sd15, guidance_scale_sd15, num_inference_steps_sd15, sampler_sd15, schedule_type_sd15, clip_skip_sd15, zero_gpu_duration_sd15, result_sd15 = base_model, prompt, neg_prompt, seed, batch_size, width, height, cfg, steps, sampler, scheduler, clip_skip, zero_gpu, result
                                civitai_api_key_sd15, tensorart_api_key_sd15, lora_rows_sd15, _, lora_id_inputs_sd15, lora_scale_inputs_sd15, _, add_lora_button_sd15, delete_lora_button_sd15, lora_count_state_sd15, all_lora_components_flat_sd15 = lora_settings
                                predownload_lora_button_sd15, run_button_sd15, predownload_status_sd15 = predownload_lora, run, predownload_status
        # PNG Info tab: upload an image, display its embedded prompts/parameters,
        # and send them to a txt2img sub-tab (chosen explicitly or by model hash).
        with gr.TabItem("PNG Info", id=1):
            with gr.Column():
                info_image_input = gr.Image(type="pil", label="Upload Image", height=512)
                with gr.Row():
                    info_get_button = gr.Button("Get Info")
                    send_by_hash_button = gr.Button("Send to txt2img by Model Hash", variant="primary")
                with gr.Row():
                    send_to_illustrious_button = gr.Button("Send to Illustrious")
                    send_to_animagine_button = gr.Button("Send to Animagine")
                    send_to_pony_button = gr.Button("Send to Pony")
                    send_to_sd15_button = gr.Button("Send to SD1.5")
                gr.Markdown("### Positive Prompt"); info_prompt_output = gr.Textbox(lines=3, interactive=False, show_label=False)
                gr.Markdown("### Negative Prompt"); info_neg_prompt_output = gr.Textbox(lines=3, interactive=False, show_label=False)
                gr.Markdown("### Other Parameters"); info_params_output = gr.Textbox(lines=5, interactive=False, show_label=False)
    gr.Markdown("<div style='text-align: center; margin-top: 20px;'>Made by <a href='https://civitai.com/user/RioShiina'>RioShiina</a> with ❤️</div>")
    
    # --- Event Handlers ---
    def create_lora_event_handlers(lora_rows, count_state, add_button, del_button, lora_ids, lora_scales):
        """Wire the "Add LoRA" / "Delete LoRA" buttons of one tab's LoRA panel.

        ``count_state`` holds how many LoRA rows are visible (it starts at 1);
        rows ``0 .. count - 1`` are shown. Adding reveals the next hidden row.
        Deleting hides the last visible row and resets its ID and weight.

        Args:
            lora_rows: Row containers, one per LoRA slot (MAX_LORAS of them).
            count_state: gr.State holding the number of visible rows.
            add_button: Button that reveals the next hidden row.
            del_button: Button that hides the last visible row.
            lora_ids: ID/URL textboxes, parallel to ``lora_rows``.
            lora_scales: Weight sliders, parallel to ``lora_rows``.
        """
        def add_lora_row(c):
            """Show row ``c`` and increment the count; hide Add once every row is shown."""
            if c >= MAX_LORAS:
                # Stale count (e.g. a second click processed after the last row
                # was revealed): lora_rows[c] would raise IndexError. A partial
                # dict leaves every other output, including the count, untouched.
                return {add_button: gr.update(visible=False)}
            return {
                count_state: c + 1,
                lora_rows[c]: gr.update(visible=True),
                del_button: gr.update(visible=True),
                add_button: gr.update(visible=c + 1 < MAX_LORAS),
            }

        def del_lora_row(c):
            """Hide the last visible row, reset its ID/weight and decrement the count."""
            if c <= 1:
                # The first row is permanent. Without this guard, c == 1 would
                # hide row 0 and zero the count, and c == 0 would silently wrap
                # around to lora_rows[-1].
                return {del_button: gr.update(visible=False)}
            last = c - 1
            return {
                count_state: last,
                lora_rows[last]: gr.update(visible=False),
                lora_ids[last]: "",
                lora_scales[last]: 0.0,  # same as the slider's initial value
                add_button: gr.update(visible=True),
                del_button: gr.update(visible=last > 1),
            }

        add_button.click(add_lora_row, [count_state], [count_state, add_button, del_button, *lora_rows])
        del_button.click(del_lora_row, [count_state], [count_state, add_button, del_button, *lora_rows, *lora_ids, *lora_scales])

    # Per-tab wiring. For each model family, this registers:
    #   * the LoRA add/delete row buttons;
    #   * a two-step pre-download: show "⏳ Downloading...", then run
    #     pre_download_loras, which writes its result to the same status box;
    #   * the Run button, which calls generate_image_wrapper.
    # The LoRA inputs are the flattened (source, id, scale, upload) components of
    # every row, in row order. generate_image_wrapper / pre_download_loras take
    # these positionally, so the lists must match their signatures exactly.
    # NOTE(review): Gradio derives default endpoint names from registration
    # order when the same function is registered more than once; keep this
    # order stable if any client depends on those names.
    create_lora_event_handlers(lora_rows_illustrious, lora_count_state_illustrious, add_lora_button_illustrious, delete_lora_button_illustrious, lora_id_inputs_illustrious, lora_scale_inputs_illustrious)
    predownload_lora_button_illustrious.click(lambda: "⏳ Downloading...", None, [predownload_status_illustrious]).then(pre_download_loras, [civitai_api_key_illustrious, tensorart_api_key_illustrious, *all_lora_components_flat_illustrious], [predownload_status_illustrious])
    run_button_illustrious.click(generate_image_wrapper, [base_model_name_input_illustrious, prompt_illustrious, negative_prompt_illustrious, seed_illustrious, batch_size_illustrious, width_illustrious, height_illustrious, guidance_scale_illustrious, num_inference_steps_illustrious, sampler_illustrious, schedule_type_illustrious, zero_gpu_duration_illustrious, civitai_api_key_illustrious, tensorart_api_key_illustrious, *all_lora_components_flat_illustrious], [result_illustrious])
    
    create_lora_event_handlers(lora_rows_animagine, lora_count_state_animagine, add_lora_button_animagine, delete_lora_button_animagine, lora_id_inputs_animagine, lora_scale_inputs_animagine)
    predownload_lora_button_animagine.click(lambda: "⏳ Downloading...", None, [predownload_status_animagine]).then(pre_download_loras, [civitai_api_key_animagine, tensorart_api_key_animagine, *all_lora_components_flat_animagine], [predownload_status_animagine])
    run_button_animagine.click(generate_image_wrapper, [base_model_name_input_animagine, prompt_animagine, negative_prompt_animagine, seed_animagine, batch_size_animagine, width_animagine, height_animagine, guidance_scale_animagine, num_inference_steps_animagine, sampler_animagine, schedule_type_animagine, zero_gpu_duration_animagine, civitai_api_key_animagine, tensorart_api_key_animagine, *all_lora_components_flat_animagine], [result_animagine])

    create_lora_event_handlers(lora_rows_pony, lora_count_state_pony, add_lora_button_pony, delete_lora_button_pony, lora_id_inputs_pony, lora_scale_inputs_pony)
    predownload_lora_button_pony.click(lambda: "⏳ Downloading...", None, [predownload_status_pony]).then(pre_download_loras, [civitai_api_key_pony, tensorart_api_key_pony, *all_lora_components_flat_pony], [predownload_status_pony])
    run_button_pony.click(generate_image_wrapper, [base_model_name_input_pony, prompt_pony, negative_prompt_pony, seed_pony, batch_size_pony, width_pony, height_pony, guidance_scale_pony, num_inference_steps_pony, sampler_pony, schedule_type_pony, zero_gpu_duration_pony, civitai_api_key_pony, tensorart_api_key_pony, *all_lora_components_flat_pony], [result_pony])
    
    create_lora_event_handlers(lora_rows_sd15, lora_count_state_sd15, add_lora_button_sd15, delete_lora_button_sd15, lora_id_inputs_sd15, lora_scale_inputs_sd15)
    predownload_lora_button_sd15.click(lambda: "⏳ Downloading...", None, [predownload_status_sd15]).then(pre_download_loras, [civitai_api_key_sd15, tensorart_api_key_sd15, *all_lora_components_flat_sd15], [predownload_status_sd15])
    # SD1.5 is the only tab that passes Clip Skip, appended after the LoRA inputs.
    run_button_sd15.click(generate_image_wrapper, [base_model_name_input_sd15, prompt_sd15, negative_prompt_sd15, seed_sd15, batch_size_sd15, width_sd15, height_sd15, guidance_scale_sd15, num_inference_steps_sd15, sampler_sd15, schedule_type_sd15, zero_gpu_duration_sd15, civitai_api_key_sd15, tensorart_api_key_sd15, *all_lora_components_flat_sd15, clip_skip_sd15], [result_sd15])
    
    # PNG Info: fill the three read-only text boxes from the uploaded image.
    info_get_button.click(get_png_info, [info_image_input], [info_prompt_output, info_neg_prompt_output, info_params_output])
    # Outputs shared by every "Send to ..." button. Gradio maps the return values
    # of send_info_to_tab / send_info_by_hash onto these components, so those
    # functions must agree with this exact order. Batch size, ZeroGPU duration
    # and the LoRA settings are not included, so sending never changes them.
    all_ui_components = [
        base_model_name_input_illustrious, prompt_illustrious, negative_prompt_illustrious, seed_illustrious, width_illustrious, height_illustrious, guidance_scale_illustrious, num_inference_steps_illustrious, sampler_illustrious, schedule_type_illustrious,
        base_model_name_input_animagine, prompt_animagine, negative_prompt_animagine, seed_animagine, width_animagine, height_animagine, guidance_scale_animagine, num_inference_steps_animagine, sampler_animagine, schedule_type_animagine,
        base_model_name_input_pony, prompt_pony, negative_prompt_pony, seed_pony, width_pony, height_pony, guidance_scale_pony, num_inference_steps_pony, sampler_pony, schedule_type_pony,
        base_model_name_input_sd15, prompt_sd15, negative_prompt_sd15, seed_sd15, width_sd15, height_sd15, guidance_scale_sd15, num_inference_steps_sd15, sampler_sd15, schedule_type_sd15, clip_skip_sd15,
        tabs, model_tabs
    ]
    # Each button targets a fixed tab name. The by-hash button presumably
    # chooses the tab from the model hash stored in the image — see
    # send_info_by_hash.
    send_to_illustrious_button.click(lambda img: send_info_to_tab(img, "Illustrious"), [info_image_input], all_ui_components)
    send_to_animagine_button.click(lambda img: send_info_to_tab(img, "Animagine"), [info_image_input], all_ui_components)
    send_to_pony_button.click(lambda img: send_info_to_tab(img, "Pony"), [info_image_input], all_ui_components)
    send_to_sd15_button.click(lambda img: send_info_to_tab(img, "SD1.5"), [info_image_input], all_ui_components)
    send_by_hash_button.click(send_info_by_hash, [info_image_input], all_ui_components)

if __name__ == "__main__":
    # Enable Gradio's request queue on the app, then start the web server.
    demo.queue().launch()