# Uploaded by: RioShiina
# Commit message: Upload folder using huggingface_hub
# Commit: c2bcd10 (verified)
import os
import random
import sys
from typing import Sequence, Mapping, Any, Union
import torch
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download
import spaces
from comfy import model_management # We need to import this early
import gc
import requests
import re
import hashlib
import shutil
# --- Startup Dummy Function ---
@spaces.GPU(duration=60)
def dummy_gpu_for_startup():
    """Print a notice and return a fixed status string.

    The function is wrapped with ``spaces.GPU(duration=60)``; as its own
    message states, it exists only as a startup check.
    """
    status = "Startup check passed."
    print("Dummy function for startup check executed. This is normal.")
    return status
# --- ComfyUI Backend Setup ---
def find_path(name: str, path: str = None) -> str:
    """Search ``path`` and each of its ancestors for an entry called ``name``.

    The search starts at ``path`` (the current working directory when it is
    omitted) and moves up one directory at a time.

    Returns:
        The joined path of the first directory entry named ``name``, or
        ``None`` once the filesystem root has been checked without a match.
    """
    current = os.getcwd() if path is None else path
    while True:
        if name in os.listdir(current):
            return os.path.join(current, name)
        parent = os.path.dirname(current)
        # dirname() of the root is the root itself: nothing left to search.
        if parent == current:
            return None
        current = parent
def add_comfyui_directory_to_sys_path() -> None:
    """Locate an entry named ``ComfyUI`` via ``find_path`` and append it to ``sys.path``.

    Nothing happens when no such entry is found or the match is not a directory.
    """
    comfyui_path = find_path("ComfyUI")
    if not comfyui_path or not os.path.isdir(comfyui_path):
        return
    sys.path.append(comfyui_path)
    print(f"'{comfyui_path}' added to sys.path")
def add_extra_model_paths() -> None:
    """Load ``extra_model_paths.yaml`` when one can be located via ``find_path``.

    ``load_extra_path_config`` is taken from ``main`` when that import works,
    otherwise from ``utils.extra_config``. A message is printed when no config
    file is found.
    """
    try:
        from main import load_extra_path_config
    except ImportError:
        from utils.extra_config import load_extra_path_config

    config_file = find_path("extra_model_paths.yaml")
    if not config_file:
        print("Could not find extra_model_paths.yaml")
        return
    load_extra_path_config(config_file)
# Run the path setup at import time: append the ComfyUI directory to sys.path
# (if found) and load extra_model_paths.yaml (if found).
add_comfyui_directory_to_sys_path()
add_extra_model_paths()
# Monkey-patch for Sage Attention
print("Attempting to monkey-patch ComfyUI for Sage Attention...")
try:
    # Replace both predicates on model_management with constant lambdas, so any
    # later call reports sage attention as enabled and PyTorch attention as disabled.
    model_management.sage_attention_enabled = lambda: True
    model_management.pytorch_attention_enabled = lambda: False
    print("Successfully monkey-patched model_management for Sage Attention.")
except Exception as e:
    # Best effort: a failed patch is reported but does not stop startup.
    print(f"An error occurred during monkey-patching: {e}")
# --- Constants & Configuration ---
# Relative paths where base checkpoints and LoRA files are stored.
CHECKPOINT_DIR = "models/checkpoints"
LORA_DIR = "models/loras"
# Create both directories at import time; exist_ok=True makes this a no-op
# when they are already present.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(LORA_DIR, exist_ok=True)
# --- Model Definitions with Hashes ---
# Format: {Display Name: (Repo ID, Filename, Type, Hash)}
# One map per UI tab. Each value is a 4-tuple:
#   [0] Hugging Face repo id, [1] checkpoint filename, [2] model type, [3] short model hash.
MODEL_MAP_ILLUSTRIOUS = {
    "Laxhar/noobai-XL-Vpred-1.0": ("Laxhar/noobai-XL-Vpred-1.0", "NoobAI-XL-Vpred-v1.0.safetensors", "SDXL", "ea349eeae8"),
    "Laxhar/noobai-XL-1.1": ("Laxhar/noobai-XL-1.1", "NoobAI-XL-v1.1.safetensors", "SDXL", "6681e8e4b1"),
    "WAI0731/wai-nsfw-illustrious-sdxl-v140": ("Ine007/waiNSFWIllustrious_v140", "waiNSFWIllustrious_v140.safetensors", "SDXL", "bdb59bac77"),
    "Ikena/hassaku-xl-illustrious-v30": ("misri/hassakuXLIllustrious_v30", "hassakuXLIllustrious_v30.safetensors", "SDXL", "b4fb5f829a"),
    "bluepen5805/noob_v_pencil-XL": ("bluepen5805/noob_v_pencil-XL", "noob_v_pencil-XL-v3.0.0.safetensors", "SDXL", "90b7911a78"),
    "RedRayz/hikari_noob_v-pred_1.2.2": ("RedRayz/hikari_noob_v-pred_1.2.2", "Hikari_Noob_v-pred_1.2.2.safetensors", "SDXL", "874170688a"),
}
MODEL_MAP_ANIMAGINE = {
    "cagliostrolab/animagine-xl-4.0": ("cagliostrolab/animagine-xl-4.0", "animagine-xl-4.0.safetensors", "SDXL", "6327eca98b"),
    "cagliostrolab/animagine-xl-3.1": ("cagliostrolab/animagine-xl-3.1", "animagine-xl-3.1.safetensors", "SDXL", "e3c47aedb0"),
}
MODEL_MAP_PONY = {
    "PurpleSmartAI/Pony_Diffusion_V6_XL": ("LyliaEngine/Pony_Diffusion_V6_XL", "ponyDiffusionV6XL_v6StartWithThisOne.safetensors", "SDXL", "67ab2fd8ec"),
}
MODEL_MAP_SD15 = {
    "Yuno779/anything-v3": ("ckpt/anything-v3.0", "Anything-V3.0-pruned.safetensors", "SD1.5", "ddd565f806"),
}
# --- Combined Maps for Global Lookup ---
# Merge of the four per-tab maps; on a duplicate display name the later map wins.
ALL_MODEL_MAP = {**MODEL_MAP_ILLUSTRIOUS, **MODEL_MAP_ANIMAGINE, **MODEL_MAP_PONY, **MODEL_MAP_SD15}
# display name -> model type (tuple index 2)
MODEL_TYPE_MAP = {k: v[2] for k, v in ALL_MODEL_MAP.items()}
# display name -> hash (tuple index 3)
DISPLAY_NAME_TO_HASH_MAP = {k: v[3] for k, v in ALL_MODEL_MAP.items()}
# hash -> display name; if two models shared a hash, the last one listed would win.
HASH_TO_DISPLAY_NAME_MAP = {v[3]: k for k, v in ALL_MODEL_MAP.items()}
# --- UI Defaults ---
DEFAULT_NEGATIVE_PROMPT = "monochrome, (low quality, worst quality:1.2), 3d, watermark, signature, ugly, poorly drawn,"
# Number of LoRA slots per tab; each slot contributes four values to the flattened LoRA inputs.
MAX_LORAS = 5
LORA_SOURCE_CHOICES = ["Civitai", "TensorArt", "Custom URL", "File"]
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Return ``obj[index]``, falling back to ``obj["result"][index]``.

    Args:
        obj: A sequence, or a mapping that may carry its outputs under a
            ``"result"`` key.
        index: Position (or key) to read.

    Returns:
        The value found, or ``None`` when neither lookup succeeds.
    """
    try:
        return obj[index]
    except (KeyError, IndexError):
        pass
    try:
        return obj["result"][index]
    except (KeyError, IndexError, TypeError):
        # TypeError covers plain sequences: indexing a list/tuple with the
        # string "result" must yield None here rather than escape to the caller.
        return None
def import_custom_nodes() -> None:
    """Run ComfyUI's ``init_extra_nodes`` coroutine to completion.

    A fresh asyncio event loop is created and installed as the current loop,
    a ``PromptServer`` and a ``PromptQueue`` are constructed around it, and
    ``init_extra_nodes()`` is then driven on that loop.
    """
    # Imported inside the function rather than at module top.
    # NOTE(review): presumably because these modules live in the ComfyUI
    # directory added to sys.path at startup — confirm.
    import asyncio, execution, server
    from nodes import init_extra_nodes
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    server_instance = server.PromptServer(loop)
    # The queue object is constructed but not kept; only the construction is used here.
    execution.PromptQueue(server_instance)
    loop.run_until_complete(init_extra_nodes())
# --- Import ComfyUI Nodes & Get Choices ---
# Built-in node classes plus the name -> class registry.
from nodes import CheckpointLoaderSimple, EmptyLatentImage, KSampler, VAEDecode, SaveImage, NODE_CLASS_MAPPINGS
# Run the extra-node initialisation before the registry lookups below.
import_custom_nodes()
CLIPTextEncodeSDXL = NODE_CLASS_MAPPINGS['CLIPTextEncodeSDXL']
CLIPTextEncode = NODE_CLASS_MAPPINGS['CLIPTextEncode']
LoraLoader = NODE_CLASS_MAPPINGS['LoraLoader']
CLIPSetLastLayer = NODE_CLASS_MAPPINGS['CLIPSetLastLayer']
try:
    # Element [0] of each input spec is used as the list of selectable values.
    SAMPLER_CHOICES = KSampler.INPUT_TYPES()["required"]["sampler_name"][0]
    SCHEDULER_CHOICES = KSampler.INPUT_TYPES()["required"]["scheduler"][0]
except Exception:
    # Fall back to a short hard-coded list when the spec cannot be read.
    SAMPLER_CHOICES = ['euler', 'dpmpp_2m_sde_gpu']
    SCHEDULER_CHOICES = ['normal', 'karras']
# --- Instantiate Node Objects ---
# One shared instance per node type, reused by every generation call.
checkpointloadersimple = CheckpointLoaderSimple(); cliptextencodesdxl = CLIPTextEncodeSDXL()
cliptextencode_sd15 = CLIPTextEncode(); emptylatentimage = EmptyLatentImage()
ksampler = KSampler(); vaedecode = VAEDecode(); saveimage = SaveImage(); loraloader = LoraLoader()
clipsetlastlayer = CLIPSetLastLayer()
# --- LoRA & File Utils ---
def get_civitai_file_info(version_id):
    """Look up the downloadable file entry for a Civitai model version.

    Requests ``https://civitai.com/api/v1/model-versions/{version_id}`` and
    picks the first entry of the response's ``files`` list whose ``type`` is
    ``'Model'`` and whose name ends in ``.safetensors``. If no entry
    qualifies, the first listed file is used instead.

    Args:
        version_id: Civitai model *version* id, interpolated into the URL.

    Returns:
        The selected file entry from the response, or ``None`` when the
        request fails, the response cannot be processed, or no files are
        listed. Nothing is raised.
    """
    api_url = f"https://civitai.com/api/v1/model-versions/{version_id}"
    try:
        response = requests.get(api_url, timeout=10)
        response.raise_for_status()
        files = response.json().get('files') or []
        for file_data in files:
            # .get() so a single entry without a 'name' cannot abort the search.
            name = file_data.get('name') or ''
            if file_data.get('type') == 'Model' and name.endswith('.safetensors'):
                return file_data
        return files[0] if files else None
    except Exception as e:
        # Deliberately best-effort: report the reason and let the caller
        # handle the missing result.
        print(f"Civitai lookup failed for version {version_id}: {e}")
        return None
def get_tensorart_file_info(model_id):
    """Look up the downloadable file entry for a TensorArt model.

    Requests ``https://tensor.art/api/v1/models/{model_id}`` and inspects the
    first element of the response's ``modelVersions`` list: the first of its
    ``files`` whose name ends in ``.safetensors`` is chosen, otherwise the
    first listed file.

    Args:
        model_id: TensorArt model id, interpolated into the URL.

    Returns:
        The selected file entry from the response, or ``None`` when the
        request fails, the response cannot be processed, or there are no
        versions/files. Nothing is raised.
    """
    api_url = f"https://tensor.art/api/v1/models/{model_id}"
    try:
        response = requests.get(api_url, timeout=10)
        response.raise_for_status()
        model_versions = response.json().get('modelVersions', [])
        if not model_versions:
            return None
        files = model_versions[0].get('files') or []
        for file_data in files:
            # .get() so a single entry without a 'name' cannot abort the search.
            if (file_data.get('name') or '').endswith('.safetensors'):
                return file_data
        return files[0] if files else None
    except Exception as e:
        # Deliberately best-effort: report the reason and let the caller
        # handle the missing result.
        print(f"TensorArt lookup failed for model {model_id}: {e}")
        return None
def download_file(url, save_path, api_key=None, progress=None, desc=""):
    """Stream ``url`` to ``save_path`` and return a status message.

    Args:
        url: Address to download from.
        save_path: Destination file. When it already exists nothing is
            downloaded.
        api_key: Optional token; when non-blank it is sent as
            ``Authorization: Bearer <api_key>``.
        progress: Optional callable invoked as ``progress(fraction, desc=desc)``.
        desc: Label forwarded to ``progress``.

    Returns:
        A message starting with ``"File already exists"``,
        ``"Successfully downloaded"`` or ``"Download failed for"``. Callers
        look for the word ``"Successfully"``, so the wording is part of the
        contract. Failures are reported through the message, never raised.
    """
    file_name = os.path.basename(save_path)
    if os.path.exists(save_path):
        return f"File already exists: {file_name}"
    headers = {'Authorization': f'Bearer {api_key}'} if api_key and api_key.strip() else {}
    # Write to a sibling temp file and rename only on success, so an
    # interrupted transfer can never leave a truncated file at save_path that
    # a later call would report as "already exists".
    partial_path = f"{save_path}.part"
    try:
        # Explicit None check: a progress object that defines __len__ must not
        # be mistaken for "no progress callback".
        if progress is not None:
            progress(0, desc=desc)
        # The context manager releases the streamed connection on every exit path.
        with requests.get(url, stream=True, headers=headers, timeout=15) as response:
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))
            downloaded = 0
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if not chunk:
                        continue
                    f.write(chunk)
                    if progress is not None and total_size > 0:
                        downloaded += len(chunk)
                        progress(downloaded / total_size, desc=desc)
        os.replace(partial_path, save_path)
        return f"Successfully downloaded: {file_name}"
    except Exception as e:
        if os.path.exists(partial_path):
            os.remove(partial_path)
        return f"Download failed for {file_name}: {e}"
def get_lora_path(source, id_or_url, civitai_key, tensorart_key, progress):
    """Resolve a remote LoRA reference to a local file, downloading it if needed.

    Args:
        source: ``"Civitai"``, ``"TensorArt"`` or ``"Custom URL"``; anything
            else is rejected.
        id_or_url: Model version id (Civitai), model id (TensorArt) or a
            direct download URL (Custom URL).
        civitai_key: API key used for Civitai downloads.
        tensorart_key: API key used for TensorArt downloads.
        progress: Progress callback forwarded to ``download_file``.

    Returns:
        ``(local_path, status)``. ``local_path`` is ``None`` whenever the file
        is not available locally after the call; ``status`` explains why.
    """
    if not id_or_url or not id_or_url.strip():
        return None, "No ID/URL provided."
    ref = id_or_url.strip()

    if source == "Civitai":
        local_name = f"civitai_{ref}.safetensors"
        source_name = f"Civitai ID {ref}"
    elif source == "TensorArt":
        local_name = f"tensorart_{ref}.safetensors"
        source_name = f"TensorArt ID {ref}"
    elif source == "Custom URL":
        # md5 is used only to derive a stable file name from the URL.
        local_name = f"custom_{hashlib.md5(ref.encode()).hexdigest()}.safetensors"
        source_name = f"URL {ref[:30]}..."
    else:
        return None, "Invalid source."

    local_path = os.path.join(LORA_DIR, local_name)
    # Check the local cache BEFORE any metadata request, so a LoRA that is
    # already on disk costs no network round-trip.
    if os.path.exists(local_path):
        return local_path, "File already exists."

    if source == "Civitai":
        file_info, api_key_to_use = get_civitai_file_info(ref), civitai_key
    elif source == "TensorArt":
        file_info, api_key_to_use = get_tensorart_file_info(ref), tensorart_key
    else:
        file_info, api_key_to_use = {'downloadUrl': ref}, None

    if not file_info or not file_info.get('downloadUrl'):
        return None, f"Could not get download link for {source_name}."
    status = download_file(file_info['downloadUrl'], local_path, api_key_to_use, progress=progress, desc=f"Downloading {source_name}")
    return (local_path, status) if "Successfully" in status else (None, status)
def pre_download_loras(civitai_api_key, tensorart_api_key, *lora_data, progress=gr.Progress(track_tqdm=True)):
    """Fetch every remote LoRA named in the UI and return a per-item report.

    ``lora_data`` is the flattened LoRA input list, four values per slot:
    source, id/URL, weight, uploaded file. A slot is processed when its source
    is one of the remote kinds, its id/URL is non-blank and no file was
    uploaded for it.

    Returns:
        One ``"* <source> ID <id>: <status>"`` line per processed slot joined
        with newlines, or a fixed notice when there is nothing to download.
    """
    remote_sources = ["Civitai", "TensorArt", "Custom URL"]
    pending = []
    for src, ref, uploaded in zip(lora_data[0::4], lora_data[1::4], lora_data[3::4]):
        if src in remote_sources and ref and ref.strip() and uploaded is None:
            pending.append((src, ref))

    if not pending:
        return "No remote LoRAs specified for pre-downloading."

    report_lines = []
    for src, ref in pending:
        status = get_lora_path(src, ref, civitai_api_key, tensorart_api_key, progress)[1]
        report_lines.append(f"* {src} ID {ref}: {status}")
    return "\n".join(report_lines)
# --- Model Management & Core Logic ---
# Module-level cache: display name and loader output of the checkpoint that is
# currently loaded (both None until the first successful load).
current_loaded_model_name = None
loaded_checkpoint_tuple = None


def load_model(model_display_name: str, progress=gr.Progress()):
    """Return the loaded checkpoint for ``model_display_name``, loading it if needed.

    The cached result is reused when the same model is requested again. For a
    different model the previous one is unloaded first, the checkpoint file is
    downloaded into ``CHECKPOINT_DIR`` when missing, loaded through the
    checkpoint loader node and handed to ``model_management.load_models_gpu``.

    Raises:
        KeyError: ``model_display_name`` is not a key of ``ALL_MODEL_MAP``.
    """
    global current_loaded_model_name, loaded_checkpoint_tuple

    same_model = model_display_name == current_loaded_model_name
    if same_model and loaded_checkpoint_tuple:
        return loaded_checkpoint_tuple

    if loaded_checkpoint_tuple:
        # Drop the previous model before bringing in the next one.
        model_management.unload_all_models()
        loaded_checkpoint_tuple = None
        gc.collect()
        torch.cuda.empty_cache()

    repo_id, filename, _, _ = ALL_MODEL_MAP[model_display_name]
    checkpoint_file = os.path.join(CHECKPOINT_DIR, filename)
    if not os.path.exists(checkpoint_file):
        progress(0, desc=f"Downloading model: {model_display_name}")
        hf_hub_download(repo_id=repo_id, filename=filename, local_dir=CHECKPOINT_DIR, local_dir_use_symlinks=False)

    progress(0.5, desc=f"Loading '{filename}'")
    checkpoint = checkpointloadersimple.load_checkpoint(ckpt_name=filename)
    model_management.load_models_gpu([get_value_at_index(checkpoint, 0)])

    current_loaded_model_name = model_display_name
    loaded_checkpoint_tuple = checkpoint
    progress(1.0, desc="Model loaded")
    return loaded_checkpoint_tuple
def _generate_image_logic(model_display_name: str, positive_prompt: str, negative_prompt: str,
                          seed: int, batch_size: int, width: int, height: int, guidance_scale: float, num_inference_steps: int,
                          sampler_name: str, scheduler: str, civitai_api_key: str, tensorart_api_key: str, *lora_data,
                          progress=gr.Progress(track_tqdm=True)):
    """Run one text-to-image generation and return the images with metadata attached.

    Steps: load the checkpoint, apply every LoRA slot whose weight is positive,
    encode both prompts, sample an empty latent batch, decode it, and convert
    each decoded image to a PIL image whose ``info['parameters']`` holds a
    text description of the settings used.

    Args:
        model_display_name: Key of ``ALL_MODEL_MAP``.
        positive_prompt: Prompt text.
        negative_prompt: Negative prompt text.
        seed: Sampler seed; ``-1`` picks a random one.
        batch_size: Number of latents to generate.
        width: Image width passed to the latent and (SDXL) text-encode nodes.
        height: Image height, used the same way.
        guidance_scale: Passed to the sampler as ``cfg``.
        num_inference_steps: Passed to the sampler as ``steps``.
        sampler_name: Sampler name.
        scheduler: Scheduler name.
        civitai_api_key: Forwarded to ``get_lora_path``.
        tensorart_api_key: Forwarded to ``get_lora_path``.
        *lora_data: Flattened LoRA inputs, four values per slot (source,
            id/URL, weight, uploaded file). For SD1.5 models, one extra
            trailing value (when present) is read as the clip-skip setting.
        progress: Gradio progress tracker.

    Returns:
        List of PIL images, one per decoded batch item.
    """
    output_images = []
    is_sd15 = MODEL_TYPE_MAP.get(model_display_name) == "SD1.5"
    clip_skip = 1
    if is_sd15 and len(lora_data) > MAX_LORAS * 4:
        # More values than the LoRA slots account for: the last one is clip skip.
        clip_skip = int(lora_data[-1])
        lora_data = lora_data[:-1]

    with torch.inference_mode():
        model_tuple = load_model(model_display_name, progress)
        model, clip, vae = (get_value_at_index(model_tuple, i) for i in range(3))
        if is_sd15:
            clip = get_value_at_index(clipsetlastlayer.set_last_layer(clip=clip, stop_at_clip_layer=-clip_skip), 0)

        active_loras_for_meta = []
        sources, ids, scales, files = lora_data[0::4], lora_data[1::4], lora_data[2::4], lora_data[3::4]
        for source, lora_id, scale, custom_file in zip(sources, ids, scales, files):
            if not scale > 0:
                continue  # weight 0 means the slot is unused
            lora_filename = None
            if custom_file:
                # Uploaded file: copy it into LORA_DIR so the loader can find it by name.
                lora_filename = os.path.basename(custom_file.name)
                shutil.copy(custom_file.name, LORA_DIR)
            elif lora_id and lora_id.strip():
                local_path, _ = get_lora_path(source, lora_id, civitai_api_key, tensorart_api_key, progress)
                if local_path:
                    lora_filename = os.path.basename(local_path)
            if lora_filename:
                lora_tuple = loraloader.load_lora(model=model, clip=clip, lora_name=lora_filename, strength_model=scale, strength_clip=scale)
                model, clip = get_value_at_index(lora_tuple, 0), get_value_at_index(lora_tuple, 1)
                active_loras_for_meta.append(f"{source} {lora_id}:{scale}")
        loras_string = f"LoRAs: [{', '.join(active_loras_for_meta)}]" if active_loras_for_meta else ""

        if is_sd15:
            pos_cond = cliptextencode_sd15.encode(text=positive_prompt, clip=clip)
            neg_cond = cliptextencode_sd15.encode(text=negative_prompt, clip=clip)
        else:
            pos_cond = cliptextencodesdxl.encode(width=width, height=height, text_g=positive_prompt, text_l=positive_prompt, clip=clip, target_width=width, target_height=height, crop_w=0, crop_h=0)
            neg_cond = cliptextencodesdxl.encode(width=width, height=height, text_g=negative_prompt, text_l=negative_prompt, clip=clip, target_width=width, target_height=height, crop_w=0, crop_h=0)

        start_seed = seed if seed != -1 else random.randint(0, 2**64 - 1)
        latent = emptylatentimage.generate(width=width, height=height, batch_size=batch_size)
        sampled = ksampler.sample(
            seed=start_seed,
            steps=num_inference_steps,
            cfg=guidance_scale,
            sampler_name=sampler_name,
            scheduler=scheduler,
            denoise=1.0,
            model=model,
            positive=get_value_at_index(pos_cond, 0),
            negative=get_value_at_index(neg_cond, 0),
            latent_image=get_value_at_index(latent, 0),
        )
        decoded_images_tensor = get_value_at_index(vaedecode.decode(samples=get_value_at_index(sampled, 0), vae=vae), 0)

        # Everything except the seed is identical for all images of the batch.
        model_hash = DISPLAY_NAME_TO_HASH_MAP.get(model_display_name, "N/A")
        params_head = f"{positive_prompt}\nNegative prompt: {negative_prompt}\n"
        params_tail = f", Size: {width}x{height}, Base Model: {model_display_name}, Model hash: {model_hash}"
        if is_sd15:
            params_tail += f", Clip skip: {clip_skip}"
        if loras_string:
            # Appended only when at least one LoRA was applied, so the text
            # never ends in a dangling comma.
            params_tail += f", {loras_string}"

        for image_index in range(decoded_images_tensor.shape[0]):
            # Clamp before the uint8 cast so out-of-range values cannot wrap around.
            img_tensor = decoded_images_tensor[image_index].clamp(0.0, 1.0)
            pil_image = Image.fromarray((img_tensor.cpu().numpy() * 255.0).astype("uint8"))
            # NOTE(review): recording start_seed + index assumes image N of a
            # batch can be reproduced with that seed — confirm against the
            # sampler's batch noise behaviour.
            current_seed = start_seed + image_index
            params_string = params_head
            params_string += f"Steps: {num_inference_steps}, Sampler: {sampler_name}, Scheduler: {scheduler}, CFG scale: {guidance_scale}, Seed: {current_seed}"
            params_string += params_tail
            pil_image.info = {'parameters': params_string.strip()}
            output_images.append(pil_image)
    return output_images
def generate_image_wrapper(*args, **kwargs):
    """Run ``_generate_image_logic`` under ``spaces.GPU`` with a per-call duration.

    Positional layout of ``args``:
        ``args[0:11]``  generation settings, forwarded unchanged;
        ``args[11]``    requested ZeroGPU duration in seconds (may be empty);
        ``args[12:]``   remaining values, forwarded unchanged.

    The duration defaults to 60 s when the request is empty, non-numeric or
    not positive, and is capped at 120 s.

    Returns:
        Whatever ``_generate_image_logic`` returns.
    """
    default_duration, max_duration = 60, 120
    logic_args = [*args[:11], *args[12:]]
    requested = args[11]

    duration = default_duration
    try:
        if requested and int(requested) > 0:
            # The UI field advertises 120 s as the maximum; enforce it here so
            # an arbitrary number typed into the box cannot exceed it.
            duration = min(int(requested), max_duration)
    except (ValueError, TypeError, OverflowError):
        pass  # unusable input -> keep the default
    return spaces.GPU(duration=duration)(_generate_image_logic)(*logic_args, **kwargs)
# --- PNG Info & UI Logic ---
def _parse_parameters(params_text):
data = {}; lines = params_text.strip().split('\n'); data['prompt'] = lines[0]
data['negative_prompt'] = lines[1].replace("Negative prompt:", "").strip() if len(lines) > 1 and lines[1].startswith("Negative prompt:") else ""
params_line = '\n'.join(lines[2:])
def find_param(key, default, cast_type=str):
match = re.search(fr"\b{key}: ([^,]+?)(,|$|\n)", params_line)
return cast_type(match.group(1).strip()) if match else default
data['steps'] = find_param("Steps", 28, int); data['sampler'] = find_param("Sampler", SAMPLER_CHOICES[0], str)
data['scheduler'] = find_param("Scheduler", SCHEDULER_CHOICES[0], str); data['cfg_scale'] = find_param("CFG scale", 7.5, float)
data['seed'] = find_param("Seed", -1, int); data['clip_skip'] = find_param("Clip skip", 1, int)
data['base_model'] = find_param("Base Model", list(ALL_MODEL_MAP.keys())[0], str); data['model_hash'] = find_param("Model hash", None, str)
size_match = re.search(r"Size: (\d+)x(\d+)", params_line)
data['width'], data['height'] = (int(size_match.group(1)), int(size_match.group(2))) if size_match else (1024, 1024)
return data
def get_png_info(image):
    """Return ``(prompt, negative_prompt, other_parameters)`` for display.

    The values come from the image's ``info['parameters']`` text. The third
    item is everything from the third line of that text onward, split on
    commas and shown one trimmed piece per line. When there is no image or no
    such metadata, two empty strings and a notice are returned.
    """
    params = image.info.get('parameters') if image else None
    if not params:
        return "", "", "No metadata found in the image."

    parsed_data = _parse_parameters(params)
    remainder = '\n'.join(params.strip().split('\n')[2:])
    other_params_text = "\n".join(piece.strip() for piece in remainder.split(','))
    return parsed_data.get('prompt', ''), parsed_data.get('negative_prompt', ''), other_params_text
def apply_data_to_ui(data, target_tab):
    """Build the component -> value update dict that loads ``data`` into a tab.

    ``data`` is a dict as produced by ``_parse_parameters``. Sampler and
    scheduler values not offered by the UI are replaced by defaults. The base
    model dropdown is only updated when the model belongs to the target tab.
    The outer tab selector is always switched to index 0; an unknown
    ``target_tab`` yields only that update.
    """
    sampler_value = data.get('sampler')
    if sampler_value not in SAMPLER_CHOICES:
        sampler_value = SAMPLER_CHOICES[0]
    scheduler_value = data.get('scheduler')
    if scheduler_value not in SCHEDULER_CHOICES:
        scheduler_value = 'normal' if 'normal' in SCHEDULER_CHOICES else SCHEDULER_CHOICES[0]

    def shared_fields(prompt_c, negative_c, seed_c, width_c, height_c, cfg_c, steps_c, sampler_c, scheduler_c):
        """Map one tab's common input components to the parsed values."""
        return {
            prompt_c: data['prompt'],
            negative_c: data['negative_prompt'],
            seed_c: data['seed'],
            width_c: data['width'],
            height_c: data['height'],
            cfg_c: data['cfg_scale'],
            steps_c: data['steps'],
            sampler_c: sampler_value,
            scheduler_c: scheduler_value,
        }

    updates = {}
    base_model_name = data.get('base_model')
    if target_tab == "Illustrious":
        if base_model_name in MODEL_MAP_ILLUSTRIOUS:
            updates[base_model_name_input_illustrious] = base_model_name
        updates.update(shared_fields(prompt_illustrious, negative_prompt_illustrious, seed_illustrious, width_illustrious, height_illustrious, guidance_scale_illustrious, num_inference_steps_illustrious, sampler_illustrious, schedule_type_illustrious))
        updates[model_tabs] = gr.Tabs(selected=0)
    elif target_tab == "Animagine":
        if base_model_name in MODEL_MAP_ANIMAGINE:
            updates[base_model_name_input_animagine] = base_model_name
        updates.update(shared_fields(prompt_animagine, negative_prompt_animagine, seed_animagine, width_animagine, height_animagine, guidance_scale_animagine, num_inference_steps_animagine, sampler_animagine, schedule_type_animagine))
        updates[model_tabs] = gr.Tabs(selected=1)
    elif target_tab == "Pony":
        if base_model_name in MODEL_MAP_PONY:
            updates[base_model_name_input_pony] = base_model_name
        updates.update(shared_fields(prompt_pony, negative_prompt_pony, seed_pony, width_pony, height_pony, guidance_scale_pony, num_inference_steps_pony, sampler_pony, schedule_type_pony))
        updates[model_tabs] = gr.Tabs(selected=2)
    elif target_tab == "SD1.5":
        if base_model_name in MODEL_MAP_SD15:
            updates[base_model_name_input_sd15] = base_model_name
        updates.update(shared_fields(prompt_sd15, negative_prompt_sd15, seed_sd15, width_sd15, height_sd15, guidance_scale_sd15, num_inference_steps_sd15, sampler_sd15, schedule_type_sd15))
        # Only the SD1.5 tab has a clip-skip control.
        updates[clip_skip_sd15] = data.get('clip_skip', 1)
        updates[model_tabs] = gr.Tabs(selected=3)

    updates[tabs] = gr.Tabs(selected=0)
    return updates
def send_info_to_tab(image, target_tab):
    """Load an image's stored parameters into the tab named ``target_tab``.

    When there is no image or it carries no ``parameters`` metadata, every
    component in ``all_ui_components`` gets a no-op update.
    """
    has_metadata = image and image.info.get('parameters', '')
    if not has_metadata:
        return {comp: gr.update() for comp in all_ui_components}
    parsed = _parse_parameters(image.info['parameters'])
    return apply_data_to_ui(parsed, target_tab)
def send_info_by_hash(image):
    """Load an image's stored parameters into the tab that owns its model hash.

    The ``Model hash`` from the metadata is mapped to a display name, the tab
    whose model map contains that name is chosen, and the parsed base model is
    replaced by that display name before the UI updates are built.

    Raises:
        gr.Error: The hash is unknown, or no tab lists the resolved model.
    """
    if not image or not image.info.get('parameters', ''):
        return {comp: gr.update() for comp in all_ui_components}

    data = _parse_parameters(image.info['parameters'])
    display_name = HASH_TO_DISPLAY_NAME_MAP.get(data.get('model_hash'))
    if not display_name:
        raise gr.Error("Model hash not found in this app's model list. The original model name from the PNG will be used if it exists in the target tab.")

    tab_lookup = (
        ("Illustrious", MODEL_MAP_ILLUSTRIOUS),
        ("Animagine", MODEL_MAP_ANIMAGINE),
        ("Pony", MODEL_MAP_PONY),
        ("SD1.5", MODEL_MAP_SD15),
    )
    for tab_label, tab_models in tab_lookup:
        if display_name in tab_models:
            target_tab = tab_label
            break
    else:
        raise gr.Error("Cannot determine the correct tab for this model.")

    data['base_model'] = display_name
    return apply_data_to_ui(data, target_tab)
# --- UI Generation Functions ---
def create_lora_settings_ui():
    """Build the "LoRA Settings" accordion and return its components.

    The accordion holds two API-key fields, usage notes, and ``MAX_LORAS``
    rows (only the first visible) of source / id / weight / upload controls,
    plus add and delete buttons and a row counter state.

    Returns:
        Tuple ``(civitai_api_key, tensorart_api_key, lora_rows, sources, ids,
        scales, uploads, add_button, delete_button, count_state,
        all_components)``, where ``all_components`` is the per-row controls
        flattened as source, id, weight, upload for each row in turn.
    """
    def on_upload(f):
        # After an upload, show the file name in the id box and switch the
        # source to "File"; with no file, leave both untouched.
        if f:
            return os.path.basename(f.name), "File"
        return gr.update(), gr.update()

    with gr.Accordion("LoRA Settings", open=False):
        gr.Markdown("⚠️ **Responsible Use Notice:** Please avoid excessive, rapid, or automated (scripted) use of the pre-download LoRA feature. Overt misuse may lead to service disruption. Thank you for your cooperation.")
        gr.Markdown("For LoRAs that require login to download, you may need to enter the corresponding API Key.")
        with gr.Row():
            civitai_api_key = gr.Textbox(label="Civitai API Key", placeholder="Enter your Civitai API Key", type="password", scale=1)
            tensorart_api_key = gr.Textbox(label="TensorArt API Key", placeholder="Enter your TensorArt API Key", type="password", scale=1)
        gr.Markdown("---")
        gr.Markdown("For each LoRA, choose a source, provide an ID/URL, or upload a file.")
        gr.Markdown("""
<div style='background-color: #282828; color: #a0aec0; padding: 10px; border-radius: 5px; margin-top: 10px; margin-bottom: 15px;'>
<b>Input Examples:</b>
<ul>
<li><b>Civitai:</b> Enter the <b>Model Version ID</b>, not the Model ID. Example: <code>133755</code> (Found in the URL, e.g., <code>civitai.com/models/122136?modelVersionId=<b>133755</b></code>)</li>
<li><b>TensorArt:</b> Enter the <b>Model ID</b>. Example: <code>706684852832599558</code> (Found in the URL, e.g., <code>tensor.art/models/<b>706684852832599558</b></code>)</li>
<li><b>Custom URL:</b> Provide a direct download link to a <code>.safetensors</code> file. Example: <code>https://huggingface.co/path/to/your/lora.safetensors</code></li>
<li><b>File:</b> Use the "Upload" button. The source will be set automatically.</li>
</ul>
</div>
""")
        gr.Markdown("""
<div style='background-color: #282828; color: #a0aec0; padding: 10px; border-radius: 5px; margin-bottom: 15px;'>
<b>Notice:</b>
<ul style='margin-bottom: 0;'>
<li>With Gradio, the page may become unresponsive until a file is fully uploaded. Please be patient and wait for the process to complete.</li>
</ul>
</div>
""")
        lora_rows, sources, ids, scales, uploads = [], [], [], [], []
        for slot in range(MAX_LORAS):
            with gr.Row(visible=(slot == 0)) as row:
                source = gr.Dropdown(label=f"LoRA {slot+1} Source", choices=LORA_SOURCE_CHOICES, value="Civitai", scale=1)
                lora_id = gr.Textbox(label="ID / URL / File", placeholder="e.g.: 133755", scale=2)
                scale = gr.Slider(label="Weight", minimum=0.0, maximum=2.0, step=0.05, value=0.0, scale=2)
                upload = gr.UploadButton("Upload", file_types=[".safetensors"], scale=1)
            lora_rows.append(row)
            sources.append(source)
            ids.append(lora_id)
            scales.append(scale)
            uploads.append(upload)
            upload.upload(fn=on_upload, inputs=[upload], outputs=[lora_id, source])
        with gr.Row():
            add_button = gr.Button("✚ Add LoRA")
            delete_button = gr.Button("➖ Delete LoRA", visible=False)
        count_state = gr.State(value=1)
        # Flatten row-wise: source, id, weight, upload for slot 0, then slot 1, ...
        all_components = []
        for row_controls in zip(sources, ids, scales, uploads):
            all_components.extend(row_controls)
    return (civitai_api_key, tensorart_api_key, lora_rows, sources, ids, scales, uploads, add_button, delete_button, count_state, all_components)
def download_all_models_on_startup():
    """Fetch every checkpoint in ALL_MODEL_MAP that is not yet in CHECKPOINT_DIR.

    Files already on disk are skipped. A failed download is reported and the
    loop continues with the next model.
    """
    print("--- Starting pre-download of all base models ---")
    for model_display_name, (repo_id, filename, _, _) in ALL_MODEL_MAP.items():
        target_file = os.path.join(CHECKPOINT_DIR, filename)
        if not os.path.exists(target_file):
            try:
                print(f"Downloading: {model_display_name} ({filename})...")
                hf_hub_download(
                    repo_id=repo_id,
                    filename=filename,
                    local_dir=CHECKPOINT_DIR,
                    local_dir_use_symlinks=False
                )
                print(f"✅ Successfully downloaded {filename}.")
            except Exception as e:
                print(f"❌ Failed to download {filename} from {repo_id}: {e}")
        else:
            print(f"✅ Model '{filename}' already exists. Skipping download.")
    print("--- Finished pre-downloading all base models ---")
# --- Execute model download on startup ---
download_all_models_on_startup()
# --- Gradio UI ---
with gr.Blocks(css="#col-container {margin: 0 auto; max-width: 1024px;}") as demo:
gr.Markdown("# Animated T2I with LoRAs")
with gr.Tabs(elem_id="tabs_container") as tabs:
with gr.TabItem("txt2img", id=0):
with gr.Tabs() as model_tabs:
for tab_name, model_map, defaults in [
("Illustrious", MODEL_MAP_ILLUSTRIOUS, {'w': 1024, 'h': 1024, 'cs_vis': False, 'cs_val': 1}),
("Animagine", MODEL_MAP_ANIMAGINE, {'w': 1024, 'h': 1024, 'cs_vis': False, 'cs_val': 1}),
("Pony", MODEL_MAP_PONY, {'w': 1024, 'h': 1024, 'cs_vis': False, 'cs_val': 1}),
("SD1.5", MODEL_MAP_SD15, {'w': 512, 'h': 768, 'cs_vis': True, 'cs_val': 1})
]:
with gr.TabItem(tab_name):
gr.Markdown("💡 **Tip:** Pre-downloading LoRAs before 'Run' can maximize ZeroGPU time.")
with gr.Column():
with gr.Row():
base_model = gr.Dropdown(label="Base Model", choices=list(model_map.keys()), value=list(model_map.keys())[0], scale=3)
with gr.Column(scale=1): predownload_lora = gr.Button("Pre-download LoRAs"); run = gr.Button("Run", variant="primary")
predownload_status = gr.Markdown("")
prompt = gr.Text(label="Prompt", lines=3, placeholder="Enter your prompt")
neg_prompt = gr.Text(label="Negative prompt", lines=3, value=DEFAULT_NEGATIVE_PROMPT)
with gr.Row():
with gr.Column(scale=2):
with gr.Row(): width = gr.Slider(label="Width", minimum=256, maximum=2048, step=64, value=defaults['w']); height = gr.Slider(label="Height", minimum=256, maximum=2048, step=64, value=defaults['h'])
with gr.Row():
sampler = gr.Dropdown(label="Sampling method", choices=SAMPLER_CHOICES, value=SAMPLER_CHOICES[0])
default_scheduler = 'normal' if 'normal' in SCHEDULER_CHOICES else SCHEDULER_CHOICES[0]
scheduler = gr.Dropdown(label="Scheduler", choices=SCHEDULER_CHOICES, value=default_scheduler)
with gr.Row(): cfg = gr.Slider(label="CFG Scale", minimum=0.0, maximum=20.0, step=0.1, value=7.5); steps = gr.Slider(label="Sampling steps", minimum=1, maximum=50, step=1, value=28)
with gr.Column(scale=1): result = gr.Gallery(label="Result", show_label=False, columns=2, object_fit="contain", height="auto")
with gr.Row():
seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
batch_size = gr.Slider(label="Batch size", minimum=1, maximum=8, step=1, value=1)
clip_skip = gr.Slider(label="Clip Skip", minimum=1, maximum=2, step=1, value=defaults['cs_val'], visible=defaults['cs_vis'])
zero_gpu = gr.Number(label="ZeroGPU Duration (s)", value=None, placeholder="Default: 60s", info="Optional: Leave empty for default (60s), max to 120")
lora_settings = create_lora_settings_ui()
# Assign specific variables for event handlers
if tab_name == "Illustrious":
base_model_name_input_illustrious, prompt_illustrious, negative_prompt_illustrious, seed_illustrious, batch_size_illustrious, width_illustrious, height_illustrious, guidance_scale_illustrious, num_inference_steps_illustrious, sampler_illustrious, schedule_type_illustrious, zero_gpu_duration_illustrious, result_illustrious = base_model, prompt, neg_prompt, seed, batch_size, width, height, cfg, steps, sampler, scheduler, zero_gpu, result
civitai_api_key_illustrious, tensorart_api_key_illustrious, lora_rows_illustrious, _, lora_id_inputs_illustrious, lora_scale_inputs_illustrious, _, add_lora_button_illustrious, delete_lora_button_illustrious, lora_count_state_illustrious, all_lora_components_flat_illustrious = lora_settings
predownload_lora_button_illustrious, run_button_illustrious, predownload_status_illustrious = predownload_lora, run, predownload_status
elif tab_name == "Animagine":
base_model_name_input_animagine, prompt_animagine, negative_prompt_animagine, seed_animagine, batch_size_animagine, width_animagine, height_animagine, guidance_scale_animagine, num_inference_steps_animagine, sampler_animagine, schedule_type_animagine, zero_gpu_duration_animagine, result_animagine = base_model, prompt, neg_prompt, seed, batch_size, width, height, cfg, steps, sampler, scheduler, zero_gpu, result
civitai_api_key_animagine, tensorart_api_key_animagine, lora_rows_animagine, _, lora_id_inputs_animagine, lora_scale_inputs_animagine, _, add_lora_button_animagine, delete_lora_button_animagine, lora_count_state_animagine, all_lora_components_flat_animagine = lora_settings
predownload_lora_button_animagine, run_button_animagine, predownload_status_animagine = predownload_lora, run, predownload_status
elif tab_name == "Pony":
base_model_name_input_pony, prompt_pony, negative_prompt_pony, seed_pony, batch_size_pony, width_pony, height_pony, guidance_scale_pony, num_inference_steps_pony, sampler_pony, schedule_type_pony, zero_gpu_duration_pony, result_pony = base_model, prompt, neg_prompt, seed, batch_size, width, height, cfg, steps, sampler, scheduler, zero_gpu, result
civitai_api_key_pony, tensorart_api_key_pony, lora_rows_pony, _, lora_id_inputs_pony, lora_scale_inputs_pony, _, add_lora_button_pony, delete_lora_button_pony, lora_count_state_pony, all_lora_components_flat_pony = lora_settings
predownload_lora_button_pony, run_button_pony, predownload_status_pony = predownload_lora, run, predownload_status
elif tab_name == "SD1.5":
base_model_name_input_sd15, prompt_sd15, negative_prompt_sd15, seed_sd15, batch_size_sd15, width_sd15, height_sd15, guidance_scale_sd15, num_inference_steps_sd15, sampler_sd15, schedule_type_sd15, clip_skip_sd15, zero_gpu_duration_sd15, result_sd15 = base_model, prompt, neg_prompt, seed, batch_size, width, height, cfg, steps, sampler, scheduler, clip_skip, zero_gpu, result
civitai_api_key_sd15, tensorart_api_key_sd15, lora_rows_sd15, _, lora_id_inputs_sd15, lora_scale_inputs_sd15, _, add_lora_button_sd15, delete_lora_button_sd15, lora_count_state_sd15, all_lora_components_flat_sd15 = lora_settings
predownload_lora_button_sd15, run_button_sd15, predownload_status_sd15 = predownload_lora, run, predownload_status
with gr.TabItem("PNG Info", id=1):
with gr.Column():
info_image_input = gr.Image(type="pil", label="Upload Image", height=512)
with gr.Row():
info_get_button = gr.Button("Get Info")
send_by_hash_button = gr.Button("Send to txt2img by Model Hash", variant="primary")
with gr.Row():
send_to_illustrious_button = gr.Button("Send to Illustrious")
send_to_animagine_button = gr.Button("Send to Animagine")
send_to_pony_button = gr.Button("Send to Pony")
send_to_sd15_button = gr.Button("Send to SD1.5")
gr.Markdown("### Positive Prompt"); info_prompt_output = gr.Textbox(lines=3, interactive=False, show_label=False)
gr.Markdown("### Negative Prompt"); info_neg_prompt_output = gr.Textbox(lines=3, interactive=False, show_label=False)
gr.Markdown("### Other Parameters"); info_params_output = gr.Textbox(lines=5, interactive=False, show_label=False)
gr.Markdown("<div style='text-align: center; margin-top: 20px;'>Made by <a href='https://civitai.com/user/RioShiina'>RioShiina</a> with ❤️</div>")
# --- Event Handlers ---
def create_lora_event_handlers(lora_rows, count_state, add_button, del_button, lora_ids, lora_scales):
    """Wire the add/delete buttons that show and hide LoRA rows for one tab.

    ``count_state`` holds the number of visible rows. Adding reveals the next
    row; deleting hides the last visible row and resets its id and weight.
    """
    def add_lora_row(visible_count):
        new_count = visible_count + 1
        return {
            count_state: new_count,
            lora_rows[visible_count]: gr.update(visible=True),
            del_button: gr.update(visible=True),
            # Hide "add" once the last slot is showing.
            add_button: gr.update(visible=new_count < MAX_LORAS),
        }

    def del_lora_row(visible_count):
        last = visible_count - 1
        return {
            count_state: last,
            lora_rows[last]: gr.update(visible=False),
            lora_ids[last]: "",
            lora_scales[last]: 0.0,
            add_button: gr.update(visible=True),
            # Keep "delete" only while more than one row remains.
            del_button: gr.update(visible=last > 1),
        }

    add_outputs = [count_state, add_button, del_button, *lora_rows]
    del_outputs = [count_state, add_button, del_button, *lora_rows, *lora_ids, *lora_scales]
    add_button.click(add_lora_row, [count_state], add_outputs)
    del_button.click(del_lora_row, [count_state], del_outputs)
# Illustrious tab wiring: LoRA row add/delete, LoRA pre-download, image generation.
create_lora_event_handlers(
    lora_rows_illustrious,
    lora_count_state_illustrious,
    add_lora_button_illustrious,
    delete_lora_button_illustrious,
    lora_id_inputs_illustrious,
    lora_scale_inputs_illustrious,
)

# Pre-download: first show a progress message in the status component, then run
# pre_download_loras and write its return value to the same component.
predownload_inputs_illustrious = [
    civitai_api_key_illustrious,
    tensorart_api_key_illustrious,
    *all_lora_components_flat_illustrious,
]
predownload_click_illustrious = predownload_lora_button_illustrious.click(
    lambda: "⏳ Downloading...",
    None,
    [predownload_status_illustrious],
)
predownload_click_illustrious.then(
    pre_download_loras,
    predownload_inputs_illustrious,
    [predownload_status_illustrious],
)

# Generation: positional inputs for generate_image_wrapper, in the order it receives them.
generation_inputs_illustrious = [
    base_model_name_input_illustrious,
    prompt_illustrious,
    negative_prompt_illustrious,
    seed_illustrious,
    batch_size_illustrious,
    width_illustrious,
    height_illustrious,
    guidance_scale_illustrious,
    num_inference_steps_illustrious,
    sampler_illustrious,
    schedule_type_illustrious,
    zero_gpu_duration_illustrious,
    civitai_api_key_illustrious,
    tensorart_api_key_illustrious,
    *all_lora_components_flat_illustrious,
]
run_button_illustrious.click(generate_image_wrapper, generation_inputs_illustrious, [result_illustrious])
# Animagine tab wiring: LoRA row add/delete, LoRA pre-download, image generation.
create_lora_event_handlers(
    lora_rows_animagine,
    lora_count_state_animagine,
    add_lora_button_animagine,
    delete_lora_button_animagine,
    lora_id_inputs_animagine,
    lora_scale_inputs_animagine,
)

# Pre-download: first show a progress message in the status component, then run
# pre_download_loras and write its return value to the same component.
predownload_inputs_animagine = [
    civitai_api_key_animagine,
    tensorart_api_key_animagine,
    *all_lora_components_flat_animagine,
]
predownload_click_animagine = predownload_lora_button_animagine.click(
    lambda: "⏳ Downloading...",
    None,
    [predownload_status_animagine],
)
predownload_click_animagine.then(
    pre_download_loras,
    predownload_inputs_animagine,
    [predownload_status_animagine],
)

# Generation: positional inputs for generate_image_wrapper, in the order it receives them.
generation_inputs_animagine = [
    base_model_name_input_animagine,
    prompt_animagine,
    negative_prompt_animagine,
    seed_animagine,
    batch_size_animagine,
    width_animagine,
    height_animagine,
    guidance_scale_animagine,
    num_inference_steps_animagine,
    sampler_animagine,
    schedule_type_animagine,
    zero_gpu_duration_animagine,
    civitai_api_key_animagine,
    tensorart_api_key_animagine,
    *all_lora_components_flat_animagine,
]
run_button_animagine.click(generate_image_wrapper, generation_inputs_animagine, [result_animagine])
# Pony tab wiring: LoRA row add/delete, LoRA pre-download, image generation.
create_lora_event_handlers(
    lora_rows_pony,
    lora_count_state_pony,
    add_lora_button_pony,
    delete_lora_button_pony,
    lora_id_inputs_pony,
    lora_scale_inputs_pony,
)

# Pre-download: first show a progress message in the status component, then run
# pre_download_loras and write its return value to the same component.
predownload_inputs_pony = [
    civitai_api_key_pony,
    tensorart_api_key_pony,
    *all_lora_components_flat_pony,
]
predownload_click_pony = predownload_lora_button_pony.click(
    lambda: "⏳ Downloading...",
    None,
    [predownload_status_pony],
)
predownload_click_pony.then(
    pre_download_loras,
    predownload_inputs_pony,
    [predownload_status_pony],
)

# Generation: positional inputs for generate_image_wrapper, in the order it receives them.
generation_inputs_pony = [
    base_model_name_input_pony,
    prompt_pony,
    negative_prompt_pony,
    seed_pony,
    batch_size_pony,
    width_pony,
    height_pony,
    guidance_scale_pony,
    num_inference_steps_pony,
    sampler_pony,
    schedule_type_pony,
    zero_gpu_duration_pony,
    civitai_api_key_pony,
    tensorart_api_key_pony,
    *all_lora_components_flat_pony,
]
run_button_pony.click(generate_image_wrapper, generation_inputs_pony, [result_pony])
# SD1.5 tab wiring: LoRA row add/delete, LoRA pre-download, image generation.
create_lora_event_handlers(
    lora_rows_sd15,
    lora_count_state_sd15,
    add_lora_button_sd15,
    delete_lora_button_sd15,
    lora_id_inputs_sd15,
    lora_scale_inputs_sd15,
)

# Pre-download: first show a progress message in the status component, then run
# pre_download_loras and write its return value to the same component.
predownload_inputs_sd15 = [
    civitai_api_key_sd15,
    tensorart_api_key_sd15,
    *all_lora_components_flat_sd15,
]
predownload_click_sd15 = predownload_lora_button_sd15.click(
    lambda: "⏳ Downloading...",
    None,
    [predownload_status_sd15],
)
predownload_click_sd15.then(
    pre_download_loras,
    predownload_inputs_sd15,
    [predownload_status_sd15],
)

# Generation: same input order as the other tabs, with clip_skip_sd15 appended
# after the flattened LoRA components as the final positional input.
generation_inputs_sd15 = [
    base_model_name_input_sd15,
    prompt_sd15,
    negative_prompt_sd15,
    seed_sd15,
    batch_size_sd15,
    width_sd15,
    height_sd15,
    guidance_scale_sd15,
    num_inference_steps_sd15,
    sampler_sd15,
    schedule_type_sd15,
    zero_gpu_duration_sd15,
    civitai_api_key_sd15,
    tensorart_api_key_sd15,
    *all_lora_components_flat_sd15,
    clip_skip_sd15,
]
run_button_sd15.click(generate_image_wrapper, generation_inputs_sd15, [result_sd15])
# "Get Info": pass the uploaded image to get_png_info and route its results to
# the prompt, negative-prompt and other-parameters text boxes.
png_info_outputs = [
    info_prompt_output,
    info_neg_prompt_output,
    info_params_output,
]
info_get_button.click(get_png_info, [info_image_input], png_info_outputs)
# Output targets shared by every "send" handler below: the settings widgets of
# each model tab, followed by the two tab containers.
# NOTE(review): the order presumably has to match the order of values returned
# by send_info_to_tab / send_info_by_hash — verify before reordering.
illustrious_send_targets = [
    base_model_name_input_illustrious,
    prompt_illustrious,
    negative_prompt_illustrious,
    seed_illustrious,
    width_illustrious,
    height_illustrious,
    guidance_scale_illustrious,
    num_inference_steps_illustrious,
    sampler_illustrious,
    schedule_type_illustrious,
]
animagine_send_targets = [
    base_model_name_input_animagine,
    prompt_animagine,
    negative_prompt_animagine,
    seed_animagine,
    width_animagine,
    height_animagine,
    guidance_scale_animagine,
    num_inference_steps_animagine,
    sampler_animagine,
    schedule_type_animagine,
]
pony_send_targets = [
    base_model_name_input_pony,
    prompt_pony,
    negative_prompt_pony,
    seed_pony,
    width_pony,
    height_pony,
    guidance_scale_pony,
    num_inference_steps_pony,
    sampler_pony,
    schedule_type_pony,
]
# SD1.5 additionally exposes its clip-skip widget.
sd15_send_targets = [
    base_model_name_input_sd15,
    prompt_sd15,
    negative_prompt_sd15,
    seed_sd15,
    width_sd15,
    height_sd15,
    guidance_scale_sd15,
    num_inference_steps_sd15,
    sampler_sd15,
    schedule_type_sd15,
    clip_skip_sd15,
]
all_ui_components = [
    *illustrious_send_targets,
    *animagine_send_targets,
    *pony_send_targets,
    *sd15_send_targets,
    tabs,
    model_tabs,
]
# "Send to <tab>" buttons: each forwards the uploaded image to send_info_to_tab
# together with the fixed name of its destination tab.
def _make_tab_sender(tab_name):
    # A factory gives each handler its own tab_name while keeping the handler a
    # one-argument lambda (it receives only the image).
    return lambda img: send_info_to_tab(img, tab_name)

for _send_button, _tab_name in (
    (send_to_illustrious_button, "Illustrious"),
    (send_to_animagine_button, "Animagine"),
    (send_to_pony_button, "Pony"),
    (send_to_sd15_button, "SD1.5"),
):
    _send_button.click(_make_tab_sender(_tab_name), [info_image_input], all_ui_components)

# "Send by model hash": send_info_by_hash receives only the image; no tab name is passed.
send_by_hash_button.click(send_info_by_hash, [info_image_input], all_ui_components)
# Script entry point: only when executed directly, call demo.queue() and launch the result.
if __name__ == "__main__":
    queued_demo = demo.queue()
    queued_demo.launch()