import os
import random
import sys
from typing import Sequence, Mapping, Any, Union
import torch
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download
import spaces
from comfy import model_management  # We need to import this early
import gc
import requests
import re
import hashlib
import shutil
# --- Startup Dummy Function ---
def dummy_gpu_for_startup():
    print("Dummy function for startup check executed. This is normal.")
    return "Startup check passed."
# --- ComfyUI Backend Setup ---
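# find_path() walks up from the working directory, checking each parent until it
# finds an entry with the given name, so the script can locate the ComfyUI
# checkout regardless of where this file is launched from.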
def find_path(name: str, path: str = None) -> str:
    if path is None:
        path = os.getcwd()
    if name in os.listdir(path):
        return os.path.join(path, name)
    parent_directory = os.path.dirname(path)
    if parent_directory == path:
        return None
    return find_path(name, parent_directory)
def add_comfyui_directory_to_sys_path() -> None:
    comfyui_path = find_path("ComfyUI")
    if comfyui_path and os.path.isdir(comfyui_path):
        sys.path.append(comfyui_path)
        print(f"'{comfyui_path}' added to sys.path")
def add_extra_model_paths() -> None:
    try:
        from main import load_extra_path_config
    except ImportError:
        from utils.extra_config import load_extra_path_config
    extra_model_paths = find_path("extra_model_paths.yaml")
    if extra_model_paths:
        load_extra_path_config(extra_model_paths)
    else:
        print("Could not find extra_model_paths.yaml")
add_comfyui_directory_to_sys_path()
add_extra_model_paths()
# Monkey-patch for Sage Attention
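# Overriding these two predicates makes ComfyUI report Sage Attention as enabled
# and PyTorch SDPA as disabled; this assumes the sageattention package is
# available in the Space's environment.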
print("Attempting to monkey-patch ComfyUI for Sage Attention...") | |
try: | |
model_management.sage_attention_enabled = lambda: True | |
model_management.pytorch_attention_enabled = lambda: False | |
print("Successfully monkey-patched model_management for Sage Attention.") | |
except Exception as e: | |
print(f"An error occurred during monkey-patching: {e}") | |
# --- Constants & Configuration ---
CHECKPOINT_DIR = "models/checkpoints"
LORA_DIR = "models/loras"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(LORA_DIR, exist_ok=True)
# --- Model Definitions with Hashes ---
# Format: {Display Name: (Repo ID, Filename, Type, Hash)}
MODEL_MAP_ILLUSTRIOUS = {
    "Laxhar/noobai-XL-Vpred-1.0": ("Laxhar/noobai-XL-Vpred-1.0", "NoobAI-XL-Vpred-v1.0.safetensors", "SDXL", "ea349eeae8"),
    "Laxhar/noobai-XL-1.1": ("Laxhar/noobai-XL-1.1", "NoobAI-XL-v1.1.safetensors", "SDXL", "6681e8e4b1"),
    "WAI0731/wai-nsfw-illustrious-sdxl-v140": ("Ine007/waiNSFWIllustrious_v140", "waiNSFWIllustrious_v140.safetensors", "SDXL", "bdb59bac77"),
    "Ikena/hassaku-xl-illustrious-v30": ("misri/hassakuXLIllustrious_v30", "hassakuXLIllustrious_v30.safetensors", "SDXL", "b4fb5f829a"),
    "bluepen5805/noob_v_pencil-XL": ("bluepen5805/noob_v_pencil-XL", "noob_v_pencil-XL-v3.0.0.safetensors", "SDXL", "90b7911a78"),
    "RedRayz/hikari_noob_v-pred_1.2.2": ("RedRayz/hikari_noob_v-pred_1.2.2", "Hikari_Noob_v-pred_1.2.2.safetensors", "SDXL", "874170688a"),
}
MODEL_MAP_ANIMAGINE = {
    "cagliostrolab/animagine-xl-4.0": ("cagliostrolab/animagine-xl-4.0", "animagine-xl-4.0.safetensors", "SDXL", "6327eca98b"),
    "cagliostrolab/animagine-xl-3.1": ("cagliostrolab/animagine-xl-3.1", "animagine-xl-3.1.safetensors", "SDXL", "e3c47aedb0"),
}
MODEL_MAP_PONY = {
    "PurpleSmartAI/Pony_Diffusion_V6_XL": ("LyliaEngine/Pony_Diffusion_V6_XL", "ponyDiffusionV6XL_v6StartWithThisOne.safetensors", "SDXL", "67ab2fd8ec"),
}
MODEL_MAP_SD15 = {
    "Yuno779/anything-v3": ("ckpt/anything-v3.0", "Anything-V3.0-pruned.safetensors", "SD1.5", "ddd565f806"),
}
# --- Combined Maps for Global Lookup ---
ALL_MODEL_MAP = {**MODEL_MAP_ILLUSTRIOUS, **MODEL_MAP_ANIMAGINE, **MODEL_MAP_PONY, **MODEL_MAP_SD15}
MODEL_TYPE_MAP = {k: v[2] for k, v in ALL_MODEL_MAP.items()}
DISPLAY_NAME_TO_HASH_MAP = {k: v[3] for k, v in ALL_MODEL_MAP.items()}
HASH_TO_DISPLAY_NAME_MAP = {v[3]: k for k, v in ALL_MODEL_MAP.items()}
# --- UI Defaults ---
DEFAULT_NEGATIVE_PROMPT = "monochrome, (low quality, worst quality:1.2), 3d, watermark, signature, ugly, poorly drawn,"
MAX_LORAS = 5
LORA_SOURCE_CHOICES = ["Civitai", "TensorArt", "Custom URL", "File"]
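# ComfyUI node methods return their outputs as tuples (or occasionally dicts
# wrapping a 'result' tuple); this helper fetches one output from either form,
# returning None when the index is absent.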
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    try:
        return obj[index]
    except (KeyError, IndexError):
        try:
            return obj["result"][index]
        except (KeyError, IndexError):
            return None
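# Custom nodes register themselves against a running PromptServer, so we stand
# up a minimal server instance on a fresh event loop before calling
# init_extra_nodes().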
def import_custom_nodes() -> None:
    import asyncio, execution, server
    from nodes import init_extra_nodes
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    server_instance = server.PromptServer(loop)
    execution.PromptQueue(server_instance)
    loop.run_until_complete(init_extra_nodes())
# --- Import ComfyUI Nodes & Get Choices ---
from nodes import CheckpointLoaderSimple, EmptyLatentImage, KSampler, VAEDecode, SaveImage, NODE_CLASS_MAPPINGS
import_custom_nodes()
CLIPTextEncodeSDXL = NODE_CLASS_MAPPINGS['CLIPTextEncodeSDXL']
CLIPTextEncode = NODE_CLASS_MAPPINGS['CLIPTextEncode']
LoraLoader = NODE_CLASS_MAPPINGS['LoraLoader']
CLIPSetLastLayer = NODE_CLASS_MAPPINGS['CLIPSetLastLayer']
try:
    SAMPLER_CHOICES = KSampler.INPUT_TYPES()["required"]["sampler_name"][0]
    SCHEDULER_CHOICES = KSampler.INPUT_TYPES()["required"]["scheduler"][0]
except Exception:
    SAMPLER_CHOICES = ['euler', 'dpmpp_2m_sde_gpu']
    SCHEDULER_CHOICES = ['normal', 'karras']
# --- Instantiate Node Objects ---
checkpointloadersimple = CheckpointLoaderSimple()
cliptextencodesdxl = CLIPTextEncodeSDXL()
cliptextencode_sd15 = CLIPTextEncode()
emptylatentimage = EmptyLatentImage()
ksampler = KSampler()
vaedecode = VAEDecode()
saveimage = SaveImage()
loraloader = LoraLoader()
clipsetlastlayer = CLIPSetLastLayer()
# --- LoRA & File Utils ---
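# The two helpers below resolve a Civitai model-version ID or a TensorArt model
# ID to a downloadable file record, preferring the primary .safetensors file
# when one is listed.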
def get_civitai_file_info(version_id):
    api_url = f"https://civitai.com/api/v1/model-versions/{version_id}"
    try:
        response = requests.get(api_url, timeout=10)
        response.raise_for_status()
        data = response.json()
        for file_data in data.get('files', []):
            if file_data.get('type') == 'Model' and file_data['name'].endswith('.safetensors'):
                return file_data
        if data.get('files'):
            return data['files'][0]
    except Exception:
        return None
def get_tensorart_file_info(model_id):
    api_url = f"https://tensor.art/api/v1/models/{model_id}"
    try:
        response = requests.get(api_url, timeout=10)
        response.raise_for_status()
        data = response.json()
        model_versions = data.get('modelVersions', [])
        if not model_versions:
            return None
        for file_data in model_versions[0].get('files', []):
            if file_data['name'].endswith('.safetensors'):
                return file_data
        return model_versions[0]['files'][0] if model_versions[0].get('files') else None
    except Exception:
        return None
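# Streams a remote file to disk in 8 KiB chunks, reporting fractional progress
# when the server supplies a Content-Length, and deletes any partial file on
# failure.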
def download_file(url, save_path, api_key=None, progress=None, desc=""):
    if os.path.exists(save_path):
        return f"File already exists: {os.path.basename(save_path)}"
    headers = {'Authorization': f'Bearer {api_key}'} if api_key and api_key.strip() else {}
    try:
        if progress:
            progress(0, desc=desc)
        response = requests.get(url, stream=True, headers=headers, timeout=15)
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))
        with open(save_path, "wb") as f:
            downloaded = 0
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
                downloaded += len(chunk)
                if progress and total_size > 0:
                    progress(downloaded / total_size, desc=desc)
        return f"Successfully downloaded: {os.path.basename(save_path)}"
    except Exception as e:
        if os.path.exists(save_path):
            os.remove(save_path)
        return f"Download failed for {os.path.basename(save_path)}: {e}"
def get_lora_path(source, id_or_url, civitai_key, tensorart_key, progress):
    if not id_or_url or not id_or_url.strip():
        return None, "No ID/URL provided."
    if source == "Civitai":
        version_id = id_or_url.strip()
        local_path = os.path.join(LORA_DIR, f"civitai_{version_id}.safetensors")
        file_info, api_key_to_use = get_civitai_file_info(version_id), civitai_key
        source_name = f"Civitai ID {version_id}"
    elif source == "TensorArt":
        model_id = id_or_url.strip()
        local_path = os.path.join(LORA_DIR, f"tensorart_{model_id}.safetensors")
        file_info, api_key_to_use = get_tensorart_file_info(model_id), tensorart_key
        source_name = f"TensorArt ID {model_id}"
    elif source == "Custom URL":
        url = id_or_url.strip()
        url_hash = hashlib.md5(url.encode()).hexdigest()
        local_path = os.path.join(LORA_DIR, f"custom_{url_hash}.safetensors")
        file_info, api_key_to_use = {'downloadUrl': url}, None
        source_name = f"URL {url[:30]}..."
    else:
        return None, "Invalid source."
    if os.path.exists(local_path):
        return local_path, "File already exists."
    if not file_info or not file_info.get('downloadUrl'):
        return None, f"Could not get download link for {source_name}."
    status = download_file(file_info['downloadUrl'], local_path, api_key_to_use, progress=progress, desc=f"Downloading {source_name}")
    return (local_path, status) if "Successfully" in status else (None, status)
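# *lora_data arrives as a flat sequence in groups of four (source, ID/URL,
# weight, uploaded file), matching the per-row component order built by
# create_lora_settings_ui(); the [0::4]-style slices below unpack those groups.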
def pre_download_loras(civitai_api_key, tensorart_api_key, *lora_data, progress=gr.Progress(track_tqdm=True)):
    sources, ids, _, files = lora_data[0::4], lora_data[1::4], lora_data[2::4], lora_data[3::4]
    active_loras = [(s, i) for s, i, f in zip(sources, ids, files) if s in ["Civitai", "TensorArt", "Custom URL"] and i and i.strip() and f is None]
    if not active_loras:
        return "No remote LoRAs specified for pre-downloading."
    log = [f"* {s} ID {i}: {get_lora_path(s, i, civitai_api_key, tensorart_api_key, progress)[1]}" for s, i in active_loras]
    return "\n".join(log)
# --- Model Management & Core Logic ---
current_loaded_model_name = None
loaded_checkpoint_tuple = None
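# Only one checkpoint is kept resident at a time: switching models unloads the
# previous one and clears CUDA memory before the new checkpoint is loaded.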
def load_model(model_display_name: str, progress=gr.Progress()):
    global current_loaded_model_name, loaded_checkpoint_tuple
    if model_display_name == current_loaded_model_name and loaded_checkpoint_tuple:
        return loaded_checkpoint_tuple
    if loaded_checkpoint_tuple:
        model_management.unload_all_models()
        loaded_checkpoint_tuple = None
        gc.collect()
        torch.cuda.empty_cache()
    repo_id, filename, _, _ = ALL_MODEL_MAP[model_display_name]
    local_file_path = os.path.join(CHECKPOINT_DIR, filename)
    if not os.path.exists(local_file_path):
        progress(0, desc=f"Downloading model: {model_display_name}")
        hf_hub_download(repo_id=repo_id, filename=filename, local_dir=CHECKPOINT_DIR, local_dir_use_symlinks=False)
    progress(0.5, desc=f"Loading '{filename}'")
    MODEL_TUPLE = checkpointloadersimple.load_checkpoint(ckpt_name=filename)
    model_management.load_models_gpu([get_value_at_index(MODEL_TUPLE, 0)])
    current_loaded_model_name = model_display_name
    loaded_checkpoint_tuple = MODEL_TUPLE
    progress(1.0, desc="Model loaded")
    return loaded_checkpoint_tuple
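# Full generation pipeline: load the checkpoint, apply CLIP skip for SD1.5
# models, stack any active LoRAs onto the model/CLIP pair, encode the prompts
# (SDXL and SD1.5 use different encoders), run the KSampler, VAE-decode the
# latents, and embed A1111-style 'parameters' metadata in each output image.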
def _generate_image_logic(model_display_name: str, positive_prompt: str, negative_prompt: str,
                          seed: int, batch_size: int, width: int, height: int, guidance_scale: float, num_inference_steps: int,
                          sampler_name: str, scheduler: str, civitai_api_key: str, tensorart_api_key: str, *lora_data,
                          progress=gr.Progress(track_tqdm=True)):
    output_images = []
    is_sd15 = MODEL_TYPE_MAP.get(model_display_name) == "SD1.5"
    clip_skip = 1
    if is_sd15 and len(lora_data) > MAX_LORAS * 4:
        clip_skip = int(lora_data[-1])
        lora_data = lora_data[:-1]
    with torch.inference_mode():
        model_tuple = load_model(model_display_name, progress)
        model, clip, vae = (get_value_at_index(model_tuple, i) for i in range(3))
        if is_sd15:
            clip = get_value_at_index(clipsetlastlayer.set_last_layer(clip=clip, stop_at_clip_layer=-clip_skip), 0)
        active_loras_for_meta = []
        sources, ids, scales, files = lora_data[0::4], lora_data[1::4], lora_data[2::4], lora_data[3::4]
        for i, (source, lora_id, scale, custom_file) in enumerate(zip(sources, ids, scales, files)):
            if scale > 0:
                lora_filename = None
                if custom_file:
                    lora_filename = os.path.basename(custom_file.name)
                    shutil.copy(custom_file.name, LORA_DIR)
                elif lora_id and lora_id.strip():
                    local_path, _ = get_lora_path(source, lora_id, civitai_api_key, tensorart_api_key, progress)
                    if local_path:
                        lora_filename = os.path.basename(local_path)
                if lora_filename:
                    lora_tuple = loraloader.load_lora(model=model, clip=clip, lora_name=lora_filename, strength_model=scale, strength_clip=scale)
                    model, clip = get_value_at_index(lora_tuple, 0), get_value_at_index(lora_tuple, 1)
                    active_loras_for_meta.append(f"{source} {lora_id}:{scale}")
        loras_string = f"LoRAs: [{', '.join(active_loras_for_meta)}]" if active_loras_for_meta else ""
        if is_sd15:
            pos_cond = cliptextencode_sd15.encode(text=positive_prompt, clip=clip)
            neg_cond = cliptextencode_sd15.encode(text=negative_prompt, clip=clip)
        else:
            pos_cond = cliptextencodesdxl.encode(width=width, height=height, text_g=positive_prompt, text_l=positive_prompt, clip=clip, target_width=width, target_height=height, crop_w=0, crop_h=0)
            neg_cond = cliptextencodesdxl.encode(width=width, height=height, text_g=negative_prompt, text_l=negative_prompt, clip=clip, target_width=width, target_height=height, crop_w=0, crop_h=0)
        start_seed = seed if seed != -1 else random.randint(0, 2**64 - 1)
        latent = emptylatentimage.generate(width=width, height=height, batch_size=batch_size)
        sampled = ksampler.sample(
            seed=start_seed,
            steps=num_inference_steps,
            cfg=guidance_scale,
            sampler_name=sampler_name,
            scheduler=scheduler,
            denoise=1.0,
            model=model,
            positive=get_value_at_index(pos_cond, 0),
            negative=get_value_at_index(neg_cond, 0),
            latent_image=get_value_at_index(latent, 0)
        )
        decoded_images_tensor = get_value_at_index(vaedecode.decode(samples=get_value_at_index(sampled, 0), vae=vae), 0)
        for i in range(decoded_images_tensor.shape[0]):
            img_tensor = decoded_images_tensor[i]
            pil_image = Image.fromarray((img_tensor.cpu().numpy() * 255.0).astype("uint8"))
            current_seed = start_seed + i
            model_hash = DISPLAY_NAME_TO_HASH_MAP.get(model_display_name, "N/A")
            params_string = f"{positive_prompt}\nNegative prompt: {negative_prompt}\n"
            params_string += f"Steps: {num_inference_steps}, Sampler: {sampler_name}, Scheduler: {scheduler}, CFG scale: {guidance_scale}, Seed: {current_seed}, Size: {width}x{height}, Base Model: {model_display_name}, Model hash: {model_hash}"
            if is_sd15:
                params_string += f", Clip skip: {clip_skip}"
            if loras_string:  # only append when at least one LoRA was actually applied
                params_string += f", {loras_string}"
            pil_image.info = {'parameters': params_string.strip()}
            output_images.append(pil_image)
    return output_images
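# spaces.GPU is applied per call rather than as a static decorator so the
# ZeroGPU slot duration can follow the user-supplied value (positional arg 11);
# everything else is forwarded unchanged to _generate_image_logic.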
def generate_image_wrapper(*args, **kwargs):
    logic_args_list = list(args[:11])
    zero_gpu_duration = args[11]
    logic_args_list.extend(args[12:])
    duration = 60
    try:
        if zero_gpu_duration and int(zero_gpu_duration) > 0:
            duration = min(int(zero_gpu_duration), 120)  # clamp to the 120 s maximum advertised in the UI
    except (ValueError, TypeError):
        pass
    return spaces.GPU(duration=duration)(_generate_image_logic)(*logic_args_list, **kwargs)
# --- PNG Info & UI Logic ---
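# _parse_parameters() reads the A1111-style 'parameters' text written above:
# the first line is the prompt, an optional 'Negative prompt:' line follows,
# and the remainder is a comma-separated list of 'Key: value' pairs.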
def _parse_parameters(params_text):
    data = {}
    lines = params_text.strip().split('\n')
    data['prompt'] = lines[0]
    data['negative_prompt'] = lines[1].replace("Negative prompt:", "").strip() if len(lines) > 1 and lines[1].startswith("Negative prompt:") else ""
    params_line = '\n'.join(lines[2:])
    def find_param(key, default, cast_type=str):
        match = re.search(fr"\b{key}: ([^,]+?)(,|$|\n)", params_line)
        return cast_type(match.group(1).strip()) if match else default
    data['steps'] = find_param("Steps", 28, int)
    data['sampler'] = find_param("Sampler", SAMPLER_CHOICES[0], str)
    data['scheduler'] = find_param("Scheduler", SCHEDULER_CHOICES[0], str)
    data['cfg_scale'] = find_param("CFG scale", 7.5, float)
    data['seed'] = find_param("Seed", -1, int)
    data['clip_skip'] = find_param("Clip skip", 1, int)
    data['base_model'] = find_param("Base Model", list(ALL_MODEL_MAP.keys())[0], str)
    data['model_hash'] = find_param("Model hash", None, str)
    size_match = re.search(r"Size: (\d+)x(\d+)", params_line)
    data['width'], data['height'] = (int(size_match.group(1)), int(size_match.group(2))) if size_match else (1024, 1024)
    return data
def get_png_info(image):
    if not image or not (params := image.info.get('parameters')):
        return "", "", "No metadata found in the image."
    parsed_data = _parse_parameters(params)
    other_params_text = "\n".join([p.strip() for p in '\n'.join(params.strip().split('\n')[2:]).split(',')])
    return parsed_data.get('prompt', ''), parsed_data.get('negative_prompt', ''), other_params_text
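# Builds a {component: value} update dict for the chosen txt2img sub-tab,
# validating the sampler/scheduler against the available choices and switching
# the UI to the matching model tab.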
def apply_data_to_ui(data, target_tab):
    final_sampler = data.get('sampler') if data.get('sampler') in SAMPLER_CHOICES else SAMPLER_CHOICES[0]
    default_scheduler = 'normal' if 'normal' in SCHEDULER_CHOICES else SCHEDULER_CHOICES[0]
    final_scheduler = data.get('scheduler') if data.get('scheduler') in SCHEDULER_CHOICES else default_scheduler
    updates = {}
    base_model_name = data.get('base_model')
    if target_tab == "Illustrious":
        if base_model_name in MODEL_MAP_ILLUSTRIOUS:
            updates.update({base_model_name_input_illustrious: base_model_name})
        updates.update({prompt_illustrious: data['prompt'], negative_prompt_illustrious: data['negative_prompt'], seed_illustrious: data['seed'], width_illustrious: data['width'], height_illustrious: data['height'], guidance_scale_illustrious: data['cfg_scale'], num_inference_steps_illustrious: data['steps'], sampler_illustrious: final_sampler, schedule_type_illustrious: final_scheduler, model_tabs: gr.Tabs(selected=0)})
    elif target_tab == "Animagine":
        if base_model_name in MODEL_MAP_ANIMAGINE:
            updates.update({base_model_name_input_animagine: base_model_name})
        updates.update({prompt_animagine: data['prompt'], negative_prompt_animagine: data['negative_prompt'], seed_animagine: data['seed'], width_animagine: data['width'], height_animagine: data['height'], guidance_scale_animagine: data['cfg_scale'], num_inference_steps_animagine: data['steps'], sampler_animagine: final_sampler, schedule_type_animagine: final_scheduler, model_tabs: gr.Tabs(selected=1)})
    elif target_tab == "Pony":
        if base_model_name in MODEL_MAP_PONY:
            updates.update({base_model_name_input_pony: base_model_name})
        updates.update({prompt_pony: data['prompt'], negative_prompt_pony: data['negative_prompt'], seed_pony: data['seed'], width_pony: data['width'], height_pony: data['height'], guidance_scale_pony: data['cfg_scale'], num_inference_steps_pony: data['steps'], sampler_pony: final_sampler, schedule_type_pony: final_scheduler, model_tabs: gr.Tabs(selected=2)})
    elif target_tab == "SD1.5":
        if base_model_name in MODEL_MAP_SD15:
            updates.update({base_model_name_input_sd15: base_model_name})
        updates.update({prompt_sd15: data['prompt'], negative_prompt_sd15: data['negative_prompt'], seed_sd15: data['seed'], width_sd15: data['width'], height_sd15: data['height'], guidance_scale_sd15: data['cfg_scale'], num_inference_steps_sd15: data['steps'], sampler_sd15: final_sampler, schedule_type_sd15: final_scheduler, clip_skip_sd15: data.get('clip_skip', 1), model_tabs: gr.Tabs(selected=3)})
    updates[tabs] = gr.Tabs(selected=0)
    return updates
def send_info_to_tab(image, target_tab):
    if not image or not image.info.get('parameters', ''):
        return {comp: gr.update() for comp in all_ui_components}
    data = _parse_parameters(image.info['parameters'])
    return apply_data_to_ui(data, target_tab)
def send_info_by_hash(image):
    if not image or not image.info.get('parameters', ''):
        return {comp: gr.update() for comp in all_ui_components}
    data = _parse_parameters(image.info['parameters'])
    model_hash = data.get('model_hash')
    display_name = HASH_TO_DISPLAY_NAME_MAP.get(model_hash)
    if not display_name:
        raise gr.Error("Model hash not found in this app's model list. Use one of the 'Send to ...' buttons to apply the parameters to a specific tab instead.")
    if display_name in MODEL_MAP_ILLUSTRIOUS:
        target_tab = "Illustrious"
    elif display_name in MODEL_MAP_ANIMAGINE:
        target_tab = "Animagine"
    elif display_name in MODEL_MAP_PONY:
        target_tab = "Pony"
    elif display_name in MODEL_MAP_SD15:
        target_tab = "SD1.5"
    else:
        raise gr.Error("Cannot determine the correct tab for this model.")
    data['base_model'] = display_name
    return apply_data_to_ui(data, target_tab)
# --- UI Generation Functions ---
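# Builds the shared LoRA accordion (API-key fields plus MAX_LORAS initially
# hidden rows) and returns every component so each model tab can wire its own
# event handlers.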
def create_lora_settings_ui():
    with gr.Accordion("LoRA Settings", open=False):
        gr.Markdown("⚠️ **Responsible Use Notice:** Please avoid excessive, rapid, or automated (scripted) use of the pre-download LoRA feature. Misuse may lead to service disruption. Thank you for your cooperation.")
        gr.Markdown("For LoRAs that require a login to download, you may need to enter the corresponding API key.")
        with gr.Row():
            civitai_api_key = gr.Textbox(label="Civitai API Key", placeholder="Enter your Civitai API Key", type="password", scale=1)
            tensorart_api_key = gr.Textbox(label="TensorArt API Key", placeholder="Enter your TensorArt API Key", type="password", scale=1)
        gr.Markdown("---")
        gr.Markdown("For each LoRA, choose a source, provide an ID/URL, or upload a file.")
        gr.Markdown("""
        <div style='background-color: #282828; color: #a0aec0; padding: 10px; border-radius: 5px; margin-top: 10px; margin-bottom: 15px;'>
        <b>Input Examples:</b>
        <ul>
            <li><b>Civitai:</b> Enter the <b>Model Version ID</b>, not the Model ID. Example: <code>133755</code> (found in the URL, e.g., <code>civitai.com/models/122136?modelVersionId=<b>133755</b></code>)</li>
            <li><b>TensorArt:</b> Enter the <b>Model ID</b>. Example: <code>706684852832599558</code> (found in the URL, e.g., <code>tensor.art/models/<b>706684852832599558</b></code>)</li>
            <li><b>Custom URL:</b> Provide a direct download link to a <code>.safetensors</code> file. Example: <code>https://huggingface.co/path/to/your/lora.safetensors</code></li>
            <li><b>File:</b> Use the "Upload" button. The source will be set automatically.</li>
        </ul>
        </div>
        """)
        gr.Markdown("""
        <div style='background-color: #282828; color: #a0aec0; padding: 10px; border-radius: 5px; margin-bottom: 15px;'>
        <b>Notice:</b>
        <ul style='margin-bottom: 0;'>
            <li>With Gradio, the page may become unresponsive until a file is fully uploaded. Please be patient and wait for the upload to complete.</li>
        </ul>
        </div>
        """)
        lora_rows, sources, ids, scales, uploads = [], [], [], [], []
        for i in range(MAX_LORAS):
            with gr.Row(visible=(i == 0)) as row:
                source = gr.Dropdown(label=f"LoRA {i+1} Source", choices=LORA_SOURCE_CHOICES, value="Civitai", scale=1)
                lora_id = gr.Textbox(label="ID / URL / File", placeholder="e.g.: 133755", scale=2)
                scale = gr.Slider(label="Weight", minimum=0.0, maximum=2.0, step=0.05, value=0.0, scale=2)
                upload = gr.UploadButton("Upload", file_types=[".safetensors"], scale=1)
            lora_rows.append(row)
            sources.append(source)
            ids.append(lora_id)
            scales.append(scale)
            uploads.append(upload)
            upload.upload(fn=lambda f: (os.path.basename(f.name), "File") if f else (gr.update(), gr.update()), inputs=[upload], outputs=[lora_id, source])
        with gr.Row():
            add_button = gr.Button("✚ Add LoRA")
            delete_button = gr.Button("➖ Delete LoRA", visible=False)
        count_state = gr.State(value=1)
        all_components = [item for sublist in zip(sources, ids, scales, uploads) for item in sublist]
    return (civitai_api_key, tensorart_api_key, lora_rows, sources, ids, scales, uploads, add_button, delete_button, count_state, all_components)
def download_all_models_on_startup():
    """Downloads all base models listed in ALL_MODEL_MAP when the app starts."""
    print("--- Starting pre-download of all base models ---")
    for model_display_name, model_info in ALL_MODEL_MAP.items():
        repo_id, filename, _, _ = model_info
        local_file_path = os.path.join(CHECKPOINT_DIR, filename)
        if os.path.exists(local_file_path):
            print(f"✅ Model '{filename}' already exists. Skipping download.")
            continue
        try:
            print(f"Downloading: {model_display_name} ({filename})...")
            hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                local_dir=CHECKPOINT_DIR,
                local_dir_use_symlinks=False
            )
            print(f"✅ Successfully downloaded {filename}.")
        except Exception as e:
            print(f"❌ Failed to download {filename} from {repo_id}: {e}")
    print("--- Finished pre-downloading all base models ---")
# --- Execute model download on startup ---
download_all_models_on_startup()
# --- Gradio UI ---
with gr.Blocks(css="#col-container {margin: 0 auto; max-width: 1024px;}") as demo:
    gr.Markdown("# Animated T2I with LoRAs")
    with gr.Tabs(elem_id="tabs_container") as tabs:
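        # One txt2img layout is stamped out per model family; the loop below
        # parameterizes identical controls with each family's model map and
        # defaults (resolution, Clip Skip visibility).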
with gr.TabItem("txt2img", id=0): | |
with gr.Tabs() as model_tabs: | |
for tab_name, model_map, defaults in [ | |
("Illustrious", MODEL_MAP_ILLUSTRIOUS, {'w': 1024, 'h': 1024, 'cs_vis': False, 'cs_val': 1}), | |
("Animagine", MODEL_MAP_ANIMAGINE, {'w': 1024, 'h': 1024, 'cs_vis': False, 'cs_val': 1}), | |
("Pony", MODEL_MAP_PONY, {'w': 1024, 'h': 1024, 'cs_vis': False, 'cs_val': 1}), | |
("SD1.5", MODEL_MAP_SD15, {'w': 512, 'h': 768, 'cs_vis': True, 'cs_val': 1}) | |
]: | |
with gr.TabItem(tab_name): | |
gr.Markdown("💡 **Tip:** Pre-downloading LoRAs before 'Run' can maximize ZeroGPU time.") | |
with gr.Column(): | |
with gr.Row(): | |
base_model = gr.Dropdown(label="Base Model", choices=list(model_map.keys()), value=list(model_map.keys())[0], scale=3) | |
with gr.Column(scale=1): predownload_lora = gr.Button("Pre-download LoRAs"); run = gr.Button("Run", variant="primary") | |
predownload_status = gr.Markdown("") | |
prompt = gr.Text(label="Prompt", lines=3, placeholder="Enter your prompt") | |
neg_prompt = gr.Text(label="Negative prompt", lines=3, value=DEFAULT_NEGATIVE_PROMPT) | |
with gr.Row(): | |
with gr.Column(scale=2): | |
with gr.Row(): width = gr.Slider(label="Width", minimum=256, maximum=2048, step=64, value=defaults['w']); height = gr.Slider(label="Height", minimum=256, maximum=2048, step=64, value=defaults['h']) | |
with gr.Row(): | |
sampler = gr.Dropdown(label="Sampling method", choices=SAMPLER_CHOICES, value=SAMPLER_CHOICES[0]) | |
default_scheduler = 'normal' if 'normal' in SCHEDULER_CHOICES else SCHEDULER_CHOICES[0] | |
scheduler = gr.Dropdown(label="Scheduler", choices=SCHEDULER_CHOICES, value=default_scheduler) | |
with gr.Row(): cfg = gr.Slider(label="CFG Scale", minimum=0.0, maximum=20.0, step=0.1, value=7.5); steps = gr.Slider(label="Sampling steps", minimum=1, maximum=50, step=1, value=28) | |
with gr.Column(scale=1): result = gr.Gallery(label="Result", show_label=False, columns=2, object_fit="contain", height="auto") | |
with gr.Row(): | |
seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0) | |
batch_size = gr.Slider(label="Batch size", minimum=1, maximum=8, step=1, value=1) | |
clip_skip = gr.Slider(label="Clip Skip", minimum=1, maximum=2, step=1, value=defaults['cs_val'], visible=defaults['cs_vis']) | |
zero_gpu = gr.Number(label="ZeroGPU Duration (s)", value=None, placeholder="Default: 60s", info="Optional: Leave empty for default (60s), max to 120") | |
                        lora_settings = create_lora_settings_ui()
                        # Assign specific variables for event handlers
                        if tab_name == "Illustrious":
                            base_model_name_input_illustrious, prompt_illustrious, negative_prompt_illustrious, seed_illustrious, batch_size_illustrious, width_illustrious, height_illustrious, guidance_scale_illustrious, num_inference_steps_illustrious, sampler_illustrious, schedule_type_illustrious, zero_gpu_duration_illustrious, result_illustrious = base_model, prompt, neg_prompt, seed, batch_size, width, height, cfg, steps, sampler, scheduler, zero_gpu, result
                            civitai_api_key_illustrious, tensorart_api_key_illustrious, lora_rows_illustrious, _, lora_id_inputs_illustrious, lora_scale_inputs_illustrious, _, add_lora_button_illustrious, delete_lora_button_illustrious, lora_count_state_illustrious, all_lora_components_flat_illustrious = lora_settings
                            predownload_lora_button_illustrious, run_button_illustrious, predownload_status_illustrious = predownload_lora, run, predownload_status
                        elif tab_name == "Animagine":
                            base_model_name_input_animagine, prompt_animagine, negative_prompt_animagine, seed_animagine, batch_size_animagine, width_animagine, height_animagine, guidance_scale_animagine, num_inference_steps_animagine, sampler_animagine, schedule_type_animagine, zero_gpu_duration_animagine, result_animagine = base_model, prompt, neg_prompt, seed, batch_size, width, height, cfg, steps, sampler, scheduler, zero_gpu, result
                            civitai_api_key_animagine, tensorart_api_key_animagine, lora_rows_animagine, _, lora_id_inputs_animagine, lora_scale_inputs_animagine, _, add_lora_button_animagine, delete_lora_button_animagine, lora_count_state_animagine, all_lora_components_flat_animagine = lora_settings
                            predownload_lora_button_animagine, run_button_animagine, predownload_status_animagine = predownload_lora, run, predownload_status
                        elif tab_name == "Pony":
                            base_model_name_input_pony, prompt_pony, negative_prompt_pony, seed_pony, batch_size_pony, width_pony, height_pony, guidance_scale_pony, num_inference_steps_pony, sampler_pony, schedule_type_pony, zero_gpu_duration_pony, result_pony = base_model, prompt, neg_prompt, seed, batch_size, width, height, cfg, steps, sampler, scheduler, zero_gpu, result
                            civitai_api_key_pony, tensorart_api_key_pony, lora_rows_pony, _, lora_id_inputs_pony, lora_scale_inputs_pony, _, add_lora_button_pony, delete_lora_button_pony, lora_count_state_pony, all_lora_components_flat_pony = lora_settings
                            predownload_lora_button_pony, run_button_pony, predownload_status_pony = predownload_lora, run, predownload_status
                        elif tab_name == "SD1.5":
                            base_model_name_input_sd15, prompt_sd15, negative_prompt_sd15, seed_sd15, batch_size_sd15, width_sd15, height_sd15, guidance_scale_sd15, num_inference_steps_sd15, sampler_sd15, schedule_type_sd15, clip_skip_sd15, zero_gpu_duration_sd15, result_sd15 = base_model, prompt, neg_prompt, seed, batch_size, width, height, cfg, steps, sampler, scheduler, clip_skip, zero_gpu, result
                            civitai_api_key_sd15, tensorart_api_key_sd15, lora_rows_sd15, _, lora_id_inputs_sd15, lora_scale_inputs_sd15, _, add_lora_button_sd15, delete_lora_button_sd15, lora_count_state_sd15, all_lora_components_flat_sd15 = lora_settings
                            predownload_lora_button_sd15, run_button_sd15, predownload_status_sd15 = predownload_lora, run, predownload_status
with gr.TabItem("PNG Info", id=1): | |
with gr.Column(): | |
info_image_input = gr.Image(type="pil", label="Upload Image", height=512) | |
with gr.Row(): | |
info_get_button = gr.Button("Get Info") | |
send_by_hash_button = gr.Button("Send to txt2img by Model Hash", variant="primary") | |
with gr.Row(): | |
send_to_illustrious_button = gr.Button("Send to Illustrious") | |
send_to_animagine_button = gr.Button("Send to Animagine") | |
send_to_pony_button = gr.Button("Send to Pony") | |
send_to_sd15_button = gr.Button("Send to SD1.5") | |
gr.Markdown("### Positive Prompt"); info_prompt_output = gr.Textbox(lines=3, interactive=False, show_label=False) | |
gr.Markdown("### Negative Prompt"); info_neg_prompt_output = gr.Textbox(lines=3, interactive=False, show_label=False) | |
gr.Markdown("### Other Parameters"); info_params_output = gr.Textbox(lines=5, interactive=False, show_label=False) | |
gr.Markdown("<div style='text-align: center; margin-top: 20px;'>Made by <a href='https://civitai.com/user/RioShiina'>RioShiina</a> with ❤️</div>") | |
    # --- Event Handlers ---
    def create_lora_event_handlers(lora_rows, count_state, add_button, del_button, lora_ids, lora_scales):
        def add_lora_row(c):
            return {count_state: c + 1, lora_rows[c]: gr.update(visible=True), del_button: gr.update(visible=True), add_button: gr.update(visible=c + 1 < MAX_LORAS)}
        def del_lora_row(c):
            c -= 1
            return {count_state: c, lora_rows[c]: gr.update(visible=False), lora_ids[c]: "", lora_scales[c]: 0.0, add_button: gr.update(visible=True), del_button: gr.update(visible=c > 1)}
        add_button.click(add_lora_row, [count_state], [count_state, add_button, del_button, *lora_rows])
        del_button.click(del_lora_row, [count_state], [count_state, add_button, del_button, *lora_rows, *lora_ids, *lora_scales])
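    # Wire each tab's LoRA row controls, pre-download button, and Run button.
    # The Run inputs must stay in this order: 11 sampling args, the ZeroGPU
    # duration, the two API keys, then the flattened LoRA components (plus
    # Clip Skip for SD1.5), matching the positional slicing in
    # generate_image_wrapper.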
    create_lora_event_handlers(lora_rows_illustrious, lora_count_state_illustrious, add_lora_button_illustrious, delete_lora_button_illustrious, lora_id_inputs_illustrious, lora_scale_inputs_illustrious)
    predownload_lora_button_illustrious.click(lambda: "⏳ Downloading...", None, [predownload_status_illustrious]).then(pre_download_loras, [civitai_api_key_illustrious, tensorart_api_key_illustrious, *all_lora_components_flat_illustrious], [predownload_status_illustrious])
    run_button_illustrious.click(generate_image_wrapper, [base_model_name_input_illustrious, prompt_illustrious, negative_prompt_illustrious, seed_illustrious, batch_size_illustrious, width_illustrious, height_illustrious, guidance_scale_illustrious, num_inference_steps_illustrious, sampler_illustrious, schedule_type_illustrious, zero_gpu_duration_illustrious, civitai_api_key_illustrious, tensorart_api_key_illustrious, *all_lora_components_flat_illustrious], [result_illustrious])
    create_lora_event_handlers(lora_rows_animagine, lora_count_state_animagine, add_lora_button_animagine, delete_lora_button_animagine, lora_id_inputs_animagine, lora_scale_inputs_animagine)
    predownload_lora_button_animagine.click(lambda: "⏳ Downloading...", None, [predownload_status_animagine]).then(pre_download_loras, [civitai_api_key_animagine, tensorart_api_key_animagine, *all_lora_components_flat_animagine], [predownload_status_animagine])
    run_button_animagine.click(generate_image_wrapper, [base_model_name_input_animagine, prompt_animagine, negative_prompt_animagine, seed_animagine, batch_size_animagine, width_animagine, height_animagine, guidance_scale_animagine, num_inference_steps_animagine, sampler_animagine, schedule_type_animagine, zero_gpu_duration_animagine, civitai_api_key_animagine, tensorart_api_key_animagine, *all_lora_components_flat_animagine], [result_animagine])
    create_lora_event_handlers(lora_rows_pony, lora_count_state_pony, add_lora_button_pony, delete_lora_button_pony, lora_id_inputs_pony, lora_scale_inputs_pony)
    predownload_lora_button_pony.click(lambda: "⏳ Downloading...", None, [predownload_status_pony]).then(pre_download_loras, [civitai_api_key_pony, tensorart_api_key_pony, *all_lora_components_flat_pony], [predownload_status_pony])
    run_button_pony.click(generate_image_wrapper, [base_model_name_input_pony, prompt_pony, negative_prompt_pony, seed_pony, batch_size_pony, width_pony, height_pony, guidance_scale_pony, num_inference_steps_pony, sampler_pony, schedule_type_pony, zero_gpu_duration_pony, civitai_api_key_pony, tensorart_api_key_pony, *all_lora_components_flat_pony], [result_pony])
    create_lora_event_handlers(lora_rows_sd15, lora_count_state_sd15, add_lora_button_sd15, delete_lora_button_sd15, lora_id_inputs_sd15, lora_scale_inputs_sd15)
    predownload_lora_button_sd15.click(lambda: "⏳ Downloading...", None, [predownload_status_sd15]).then(pre_download_loras, [civitai_api_key_sd15, tensorart_api_key_sd15, *all_lora_components_flat_sd15], [predownload_status_sd15])
    run_button_sd15.click(generate_image_wrapper, [base_model_name_input_sd15, prompt_sd15, negative_prompt_sd15, seed_sd15, batch_size_sd15, width_sd15, height_sd15, guidance_scale_sd15, num_inference_steps_sd15, sampler_sd15, schedule_type_sd15, zero_gpu_duration_sd15, civitai_api_key_sd15, tensorart_api_key_sd15, *all_lora_components_flat_sd15, clip_skip_sd15], [result_sd15])
    info_get_button.click(get_png_info, [info_image_input], [info_prompt_output, info_neg_prompt_output, info_params_output])
    all_ui_components = [
        base_model_name_input_illustrious, prompt_illustrious, negative_prompt_illustrious, seed_illustrious, width_illustrious, height_illustrious, guidance_scale_illustrious, num_inference_steps_illustrious, sampler_illustrious, schedule_type_illustrious,
        base_model_name_input_animagine, prompt_animagine, negative_prompt_animagine, seed_animagine, width_animagine, height_animagine, guidance_scale_animagine, num_inference_steps_animagine, sampler_animagine, schedule_type_animagine,
        base_model_name_input_pony, prompt_pony, negative_prompt_pony, seed_pony, width_pony, height_pony, guidance_scale_pony, num_inference_steps_pony, sampler_pony, schedule_type_pony,
        base_model_name_input_sd15, prompt_sd15, negative_prompt_sd15, seed_sd15, width_sd15, height_sd15, guidance_scale_sd15, num_inference_steps_sd15, sampler_sd15, schedule_type_sd15, clip_skip_sd15,
        tabs, model_tabs
    ]
    send_to_illustrious_button.click(lambda img: send_info_to_tab(img, "Illustrious"), [info_image_input], all_ui_components)
    send_to_animagine_button.click(lambda img: send_info_to_tab(img, "Animagine"), [info_image_input], all_ui_components)
    send_to_pony_button.click(lambda img: send_info_to_tab(img, "Pony"), [info_image_input], all_ui_components)
    send_to_sd15_button.click(lambda img: send_info_to_tab(img, "SD1.5"), [info_image_input], all_ui_components)
    send_by_hash_button.click(send_info_by_hash, [info_image_input], all_ui_components)
if __name__ == "__main__":
    demo.queue().launch()