Spaces:
Configuration error
Configuration error
migrating to zero gpu
Browse files- .gitattributes +35 -0
- README.md +1 -3
- app.py +181 -625
- config.py +105 -0
- lora.toml +0 -28
- lora_diffusers.py +0 -478
- requirements.txt +6 -7
- style.css +4 -30
- utils.py +173 -1
.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -4,13 +4,11 @@ emoji: 🌍
|
|
| 4 |
colorFrom: gray
|
| 5 |
colorTo: purple
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version: 4.
|
| 8 |
app_file: app.py
|
| 9 |
license: mit
|
| 10 |
pinned: false
|
| 11 |
suggested_hardware: a10g-small
|
| 12 |
-
duplicated_from: hysts/SD-XL
|
| 13 |
-
hf_oauth: true
|
| 14 |
---
|
| 15 |
|
| 16 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 4 |
colorFrom: gray
|
| 5 |
colorTo: purple
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 4.20.0
|
| 8 |
app_file: app.py
|
| 9 |
license: mit
|
| 10 |
pinned: false
|
| 11 |
suggested_hardware: a10g-small
|
|
|
|
|
|
|
| 12 |
---
|
| 13 |
|
| 14 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
CHANGED
|
@@ -1,244 +1,71 @@
|
|
| 1 |
-
#!/usr/bin/env python
|
| 2 |
-
|
| 3 |
-
from __future__ import annotations
|
| 4 |
-
|
| 5 |
import os
|
| 6 |
-
import random
|
| 7 |
import gc
|
| 8 |
-
import toml
|
| 9 |
import gradio as gr
|
| 10 |
import numpy as np
|
| 11 |
-
import utils
|
| 12 |
import torch
|
| 13 |
import json
|
| 14 |
-
import
|
| 15 |
-
import
|
| 16 |
-
import
|
| 17 |
-
|
| 18 |
-
from
|
| 19 |
from datetime import datetime
|
| 20 |
-
from PIL import PngImagePlugin
|
| 21 |
-
import gradio_user_history as gr_user_history
|
| 22 |
-
from huggingface_hub import hf_hub_download
|
| 23 |
-
from safetensors.torch import load_file
|
| 24 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
|
| 25 |
-
from lora_diffusers import LoRANetwork, create_network_from_weights
|
| 26 |
from diffusers.models import AutoencoderKL
|
| 27 |
-
from diffusers import
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
DPMSolverSinglestepScheduler,
|
| 32 |
-
KDPM2DiscreteScheduler,
|
| 33 |
-
EulerDiscreteScheduler,
|
| 34 |
-
EulerAncestralDiscreteScheduler,
|
| 35 |
-
HeunDiscreteScheduler,
|
| 36 |
-
LMSDiscreteScheduler,
|
| 37 |
-
DDIMScheduler,
|
| 38 |
-
DEISMultistepScheduler,
|
| 39 |
-
UniPCMultistepScheduler,
|
| 40 |
-
)
|
| 41 |
|
| 42 |
DESCRIPTION = "Animagine XL 3.0"
|
| 43 |
if not torch.cuda.is_available():
|
| 44 |
DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU. </p>"
|
| 45 |
IS_COLAB = utils.is_google_colab() or os.getenv("IS_COLAB") == "1"
|
| 46 |
-
MAX_SEED = np.iinfo(np.int32).max
|
| 47 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 48 |
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
|
| 49 |
MIN_IMAGE_SIZE = int(os.getenv("MIN_IMAGE_SIZE", "512"))
|
| 50 |
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "2048"))
|
| 51 |
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1"
|
| 52 |
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
|
|
|
|
| 53 |
|
| 54 |
-
MODEL = os.getenv(
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
torch.backends.cudnn.deterministic = True
|
| 57 |
torch.backends.cudnn.benchmark = False
|
| 58 |
|
| 59 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 60 |
|
| 61 |
-
|
|
|
|
| 62 |
vae = AutoencoderKL.from_pretrained(
|
| 63 |
"madebyollin/sdxl-vae-fp16-fix",
|
| 64 |
torch_dtype=torch.float16,
|
| 65 |
)
|
| 66 |
-
pipeline =
|
| 67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
pipe = pipeline(
|
| 69 |
-
|
| 70 |
vae=vae,
|
| 71 |
torch_dtype=torch.float16,
|
| 72 |
custom_pipeline="lpw_stable_diffusion_xl",
|
| 73 |
use_safetensors=True,
|
|
|
|
| 74 |
use_auth_token=HF_TOKEN,
|
| 75 |
variant="fp16",
|
| 76 |
)
|
| 77 |
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
else:
|
| 81 |
-
pipe.to(device)
|
| 82 |
-
if USE_TORCH_COMPILE:
|
| 83 |
-
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
| 84 |
-
else:
|
| 85 |
-
pipe = None
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
|
| 89 |
-
if randomize_seed:
|
| 90 |
-
seed = random.randint(0, MAX_SEED)
|
| 91 |
-
return seed
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
def seed_everything(seed):
|
| 95 |
-
torch.manual_seed(seed)
|
| 96 |
-
torch.cuda.manual_seed_all(seed)
|
| 97 |
-
np.random.seed(seed)
|
| 98 |
-
generator = torch.Generator()
|
| 99 |
-
generator.manual_seed(seed)
|
| 100 |
-
return generator
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
def get_image_path(base_path: str):
|
| 104 |
-
extensions = [".jpg", ".jpeg", ".png", ".bmp", ".gif"]
|
| 105 |
-
for ext in extensions:
|
| 106 |
-
image_path = base_path + ext
|
| 107 |
-
if os.path.exists(image_path):
|
| 108 |
-
return image_path
|
| 109 |
-
return None
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
def update_selection(selected_state: gr.SelectData):
|
| 113 |
-
lora_repo = sdxl_loras[selected_state.index]["repo"]
|
| 114 |
-
lora_weight = sdxl_loras[selected_state.index]["multiplier"]
|
| 115 |
-
updated_selected_info = f"{lora_repo}"
|
| 116 |
-
|
| 117 |
-
return (
|
| 118 |
-
updated_selected_info,
|
| 119 |
-
selected_state,
|
| 120 |
-
lora_weight,
|
| 121 |
-
)
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
def parse_aspect_ratio(aspect_ratio):
|
| 125 |
-
if aspect_ratio == "Custom":
|
| 126 |
-
return None, None
|
| 127 |
-
width, height = aspect_ratio.split(" x ")
|
| 128 |
-
return int(width), int(height)
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
def aspect_ratio_handler(aspect_ratio, custom_width, custom_height):
|
| 132 |
-
if aspect_ratio == "Custom":
|
| 133 |
-
return custom_width, custom_height
|
| 134 |
-
else:
|
| 135 |
-
width, height = parse_aspect_ratio(aspect_ratio)
|
| 136 |
-
return width, height
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
def create_network(text_encoders, unet, state_dict, multiplier, device):
|
| 140 |
-
network = create_network_from_weights(
|
| 141 |
-
text_encoders,
|
| 142 |
-
unet,
|
| 143 |
-
state_dict,
|
| 144 |
-
multiplier,
|
| 145 |
-
)
|
| 146 |
-
network.load_state_dict(state_dict)
|
| 147 |
-
network.to(device, dtype=unet.dtype)
|
| 148 |
-
network.apply_to(multiplier=multiplier)
|
| 149 |
-
|
| 150 |
-
return network
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
def get_scheduler(scheduler_config, name):
|
| 154 |
-
scheduler_map = {
|
| 155 |
-
"DPM++ 2M Karras": lambda: DPMSolverMultistepScheduler.from_config(
|
| 156 |
-
scheduler_config, use_karras_sigmas=True
|
| 157 |
-
),
|
| 158 |
-
"DPM++ SDE Karras": lambda: DPMSolverSinglestepScheduler.from_config(
|
| 159 |
-
scheduler_config, use_karras_sigmas=True
|
| 160 |
-
),
|
| 161 |
-
"DPM++ 2M SDE Karras": lambda: DPMSolverMultistepScheduler.from_config(
|
| 162 |
-
scheduler_config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
|
| 163 |
-
),
|
| 164 |
-
"Euler": lambda: EulerDiscreteScheduler.from_config(scheduler_config),
|
| 165 |
-
"Euler a": lambda: EulerAncestralDiscreteScheduler.from_config(
|
| 166 |
-
scheduler_config
|
| 167 |
-
),
|
| 168 |
-
"DDIM": lambda: DDIMScheduler.from_config(scheduler_config),
|
| 169 |
-
}
|
| 170 |
-
return scheduler_map.get(name, lambda: None)()
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
def free_memory():
|
| 174 |
-
torch.cuda.empty_cache()
|
| 175 |
-
gc.collect()
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
def preprocess_prompt(
|
| 179 |
-
style_dict,
|
| 180 |
-
style_name: str,
|
| 181 |
-
positive: str,
|
| 182 |
-
negative: str = "",
|
| 183 |
-
add_style: bool = True,
|
| 184 |
-
) -> Tuple[str, str]:
|
| 185 |
-
p, n = style_dict.get(style_name, style_dict["(None)"])
|
| 186 |
-
|
| 187 |
-
if add_style and positive.strip():
|
| 188 |
-
formatted_positive = p.format(prompt=positive)
|
| 189 |
-
else:
|
| 190 |
-
formatted_positive = positive
|
| 191 |
-
|
| 192 |
-
combined_negative = n + negative
|
| 193 |
-
return formatted_positive, combined_negative
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
def common_upscale(samples, width, height, upscale_method):
|
| 197 |
-
return torch.nn.functional.interpolate(
|
| 198 |
-
samples, size=(height, width), mode=upscale_method
|
| 199 |
-
)
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
def upscale(samples, upscale_method, scale_by):
|
| 203 |
-
width = round(samples.shape[3] * scale_by)
|
| 204 |
-
height = round(samples.shape[2] * scale_by)
|
| 205 |
-
s = common_upscale(samples, width, height, upscale_method)
|
| 206 |
-
return s
|
| 207 |
-
|
| 208 |
|
| 209 |
-
def load_and_convert_thumbnail(model_path: str):
|
| 210 |
-
with safetensors.safe_open(model_path, framework="pt") as f:
|
| 211 |
-
metadata = f.metadata()
|
| 212 |
-
if "modelspec.thumbnail" in metadata:
|
| 213 |
-
base64_data = metadata["modelspec.thumbnail"]
|
| 214 |
-
prefix, encoded = base64_data.split(",", 1)
|
| 215 |
-
image_data = base64.b64decode(encoded)
|
| 216 |
-
image = PIL.Image.open(BytesIO(image_data))
|
| 217 |
-
return image
|
| 218 |
-
return None
|
| 219 |
-
|
| 220 |
-
def load_wildcard_files(wildcard_dir):
|
| 221 |
-
wildcard_files = {}
|
| 222 |
-
for file in os.listdir(wildcard_dir):
|
| 223 |
-
if file.endswith(".txt"):
|
| 224 |
-
key = f"__{file.split('.')[0]}__" # Create a key like __character__
|
| 225 |
-
wildcard_files[key] = os.path.join(wildcard_dir, file)
|
| 226 |
-
return wildcard_files
|
| 227 |
-
|
| 228 |
-
def get_random_line_from_file(file_path):
|
| 229 |
-
with open(file_path, 'r') as file:
|
| 230 |
-
lines = file.readlines()
|
| 231 |
-
if not lines:
|
| 232 |
-
return ""
|
| 233 |
-
return random.choice(lines).strip()
|
| 234 |
-
|
| 235 |
-
def add_wildcard(prompt, wildcard_files):
|
| 236 |
-
for key, file_path in wildcard_files.items():
|
| 237 |
-
if key in prompt:
|
| 238 |
-
wildcard_line = get_random_line_from_file(file_path)
|
| 239 |
-
prompt = prompt.replace(key, wildcard_line)
|
| 240 |
-
return prompt
|
| 241 |
|
|
|
|
| 242 |
def generate(
|
| 243 |
prompt: str,
|
| 244 |
negative_prompt: str = "",
|
|
@@ -247,90 +74,40 @@ def generate(
|
|
| 247 |
custom_height: int = 1024,
|
| 248 |
guidance_scale: float = 7.0,
|
| 249 |
num_inference_steps: int = 28,
|
| 250 |
-
use_lora: bool = False,
|
| 251 |
-
lora_weight: float = 1.0,
|
| 252 |
-
selected_state: str = "",
|
| 253 |
sampler: str = "Euler a",
|
| 254 |
aspect_ratio_selector: str = "896 x 1152",
|
| 255 |
style_selector: str = "(None)",
|
| 256 |
quality_selector: str = "Standard",
|
| 257 |
use_upscaler: bool = False,
|
| 258 |
-
upscaler_strength: float = 0.
|
| 259 |
upscale_by: float = 1.5,
|
| 260 |
add_quality_tags: bool = True,
|
| 261 |
-
profile: gr.OAuthProfile | None = None,
|
| 262 |
progress=gr.Progress(track_tqdm=True),
|
| 263 |
-
) ->
|
| 264 |
-
generator = seed_everything(seed)
|
| 265 |
|
| 266 |
-
|
| 267 |
-
network_state = {"current_lora": None, "multiplier": None}
|
| 268 |
-
|
| 269 |
-
width, height = aspect_ratio_handler(
|
| 270 |
aspect_ratio_selector,
|
| 271 |
custom_width,
|
| 272 |
custom_height,
|
| 273 |
)
|
| 274 |
|
| 275 |
-
prompt = add_wildcard(prompt, wildcard_files)
|
| 276 |
|
| 277 |
-
|
| 278 |
-
prompt, negative_prompt = preprocess_prompt(
|
| 279 |
quality_prompt, quality_selector, prompt, negative_prompt, add_quality_tags
|
| 280 |
)
|
| 281 |
-
prompt, negative_prompt = preprocess_prompt(
|
| 282 |
styles, style_selector, prompt, negative_prompt
|
| 283 |
)
|
| 284 |
|
| 285 |
-
|
| 286 |
-
width = width - (width % 8)
|
| 287 |
-
if height % 8 != 0:
|
| 288 |
-
height = height - (height % 8)
|
| 289 |
-
|
| 290 |
-
if use_lora:
|
| 291 |
-
if not selected_state:
|
| 292 |
-
raise Exception("You must Select a LoRA")
|
| 293 |
-
repo_name = sdxl_loras[selected_state.index]["repo"]
|
| 294 |
-
full_path_lora = saved_names[selected_state.index]
|
| 295 |
-
weight_name = sdxl_loras[selected_state.index]["weights"]
|
| 296 |
-
|
| 297 |
-
lora_sd = load_file(full_path_lora)
|
| 298 |
-
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
|
| 299 |
-
|
| 300 |
-
if network_state["current_lora"] != repo_name:
|
| 301 |
-
network = create_network(
|
| 302 |
-
text_encoders,
|
| 303 |
-
pipe.unet,
|
| 304 |
-
lora_sd,
|
| 305 |
-
lora_weight,
|
| 306 |
-
device,
|
| 307 |
-
)
|
| 308 |
-
network_state["current_lora"] = repo_name
|
| 309 |
-
network_state["multiplier"] = lora_weight
|
| 310 |
-
elif network_state["multiplier"] != lora_weight:
|
| 311 |
-
network = create_network(
|
| 312 |
-
text_encoders,
|
| 313 |
-
pipe.unet,
|
| 314 |
-
lora_sd,
|
| 315 |
-
lora_weight,
|
| 316 |
-
device,
|
| 317 |
-
)
|
| 318 |
-
network_state["multiplier"] = lora_weight
|
| 319 |
-
else:
|
| 320 |
-
if network:
|
| 321 |
-
network.unapply_to()
|
| 322 |
-
network = None
|
| 323 |
-
network_state = {
|
| 324 |
-
"current_lora": None,
|
| 325 |
-
"multiplier": None,
|
| 326 |
-
}
|
| 327 |
|
| 328 |
backup_scheduler = pipe.scheduler
|
| 329 |
-
pipe.scheduler = get_scheduler(pipe.scheduler.config, sampler)
|
| 330 |
|
| 331 |
if use_upscaler:
|
| 332 |
upscaler_pipe = StableDiffusionXLImg2ImgPipeline(**pipe.components)
|
| 333 |
-
|
| 334 |
metadata = {
|
| 335 |
"prompt": prompt,
|
| 336 |
"negative_prompt": negative_prompt,
|
|
@@ -344,11 +121,6 @@ def generate(
|
|
| 344 |
"quality_tags": quality_selector,
|
| 345 |
}
|
| 346 |
|
| 347 |
-
if use_lora:
|
| 348 |
-
metadata["use_lora"] = {"selected_lora": repo_name, "multiplier": lora_weight}
|
| 349 |
-
else:
|
| 350 |
-
metadata["use_lora"] = None
|
| 351 |
-
|
| 352 |
if use_upscaler:
|
| 353 |
new_width = int(width * upscale_by)
|
| 354 |
new_height = int(height * upscale_by)
|
|
@@ -360,8 +132,7 @@ def generate(
|
|
| 360 |
}
|
| 361 |
else:
|
| 362 |
metadata["use_upscaler"] = None
|
| 363 |
-
|
| 364 |
-
print(json.dumps(metadata, indent=4))
|
| 365 |
|
| 366 |
try:
|
| 367 |
if use_upscaler:
|
|
@@ -375,8 +146,8 @@ def generate(
|
|
| 375 |
generator=generator,
|
| 376 |
output_type="latent",
|
| 377 |
).images
|
| 378 |
-
upscaled_latents = upscale(latents, "nearest-exact", upscale_by)
|
| 379 |
-
|
| 380 |
prompt=prompt,
|
| 381 |
negative_prompt=negative_prompt,
|
| 382 |
image=upscaled_latents,
|
|
@@ -385,9 +156,9 @@ def generate(
|
|
| 385 |
strength=upscaler_strength,
|
| 386 |
generator=generator,
|
| 387 |
output_type="pil",
|
| 388 |
-
).images
|
| 389 |
else:
|
| 390 |
-
|
| 391 |
prompt=prompt,
|
| 392 |
negative_prompt=negative_prompt,
|
| 393 |
width=width,
|
|
@@ -396,194 +167,38 @@ def generate(
|
|
| 396 |
num_inference_steps=num_inference_steps,
|
| 397 |
generator=generator,
|
| 398 |
output_type="pil",
|
| 399 |
-
).images
|
| 400 |
-
if network:
|
| 401 |
-
network.unapply_to()
|
| 402 |
-
network = None
|
| 403 |
-
if profile is not None:
|
| 404 |
-
gr_user_history.save_image(
|
| 405 |
-
label=prompt,
|
| 406 |
-
image=image,
|
| 407 |
-
profile=profile,
|
| 408 |
-
metadata=metadata,
|
| 409 |
-
)
|
| 410 |
-
if image and IS_COLAB:
|
| 411 |
-
current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 412 |
-
output_directory = "./outputs"
|
| 413 |
-
os.makedirs(output_directory, exist_ok=True)
|
| 414 |
-
filename = f"image_{current_time}.png"
|
| 415 |
-
filepath = os.path.join(output_directory, filename)
|
| 416 |
-
|
| 417 |
-
# Convert metadata to a string and save as a text chunk in the PNG
|
| 418 |
-
metadata_str = json.dumps(metadata)
|
| 419 |
-
info = PngImagePlugin.PngInfo()
|
| 420 |
-
info.add_text("metadata", metadata_str)
|
| 421 |
-
image.save(filepath, "PNG", pnginfo=info)
|
| 422 |
-
print(f"Image saved as {filepath} with metadata")
|
| 423 |
|
| 424 |
-
|
|
|
|
|
|
|
|
|
|
| 425 |
|
|
|
|
| 426 |
except Exception as e:
|
| 427 |
-
|
| 428 |
raise
|
| 429 |
finally:
|
| 430 |
-
if network:
|
| 431 |
-
network.unapply_to()
|
| 432 |
-
network = None
|
| 433 |
-
if use_lora:
|
| 434 |
-
del lora_sd, text_encoders
|
| 435 |
if use_upscaler:
|
| 436 |
del upscaler_pipe
|
| 437 |
pipe.scheduler = backup_scheduler
|
| 438 |
-
free_memory()
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
examples = [
|
| 442 |
-
"1girl, arima kana, oshi no ko, solo, idol, idol clothes, one eye closed, red shirt, black skirt, black headwear, gloves, stage light, singing, open mouth, crowd, smile, pointing at viewer",
|
| 443 |
-
"1girl, c.c., code geass, white shirt, long sleeves, turtleneck, sitting, looking at viewer, eating, pizza, plate, fork, knife, table, chair, table, restaurant, cinematic angle, cinematic lighting",
|
| 444 |
-
"1girl, sakurauchi riko, \(love live\), queen hat, noble coat, red coat, noble shirt, sitting, crossed legs, gentle smile, parted lips, throne, cinematic angle",
|
| 445 |
-
"1girl, amiya \(arknights\), arknights, dirty face, outstretched hand, close-up, cinematic angle, foreshortening, dark, dark background",
|
| 446 |
-
"A boy and a girl, Emiya Shirou and Artoria Pendragon from fate series, having their breakfast in the dining room. Emiya Shirou wears white t-shirt and jacket. Artoria Pendragon wears white dress with blue neck ribbon. Rice, soup, and minced meats are served on the table. They look at each other while smiling happily",
|
| 447 |
-
]
|
| 448 |
-
|
| 449 |
-
quality_prompt_list = [
|
| 450 |
-
{
|
| 451 |
-
"name": "(None)",
|
| 452 |
-
"prompt": "{prompt}",
|
| 453 |
-
"negative_prompt": "nsfw, lowres, ",
|
| 454 |
-
},
|
| 455 |
-
{
|
| 456 |
-
"name": "Standard",
|
| 457 |
-
"prompt": "{prompt}, masterpiece, best quality",
|
| 458 |
-
"negative_prompt": "nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name, ",
|
| 459 |
-
},
|
| 460 |
-
{
|
| 461 |
-
"name": "Light",
|
| 462 |
-
"prompt": "{prompt}, (masterpiece), best quality, perfect face",
|
| 463 |
-
"negative_prompt": "nsfw, (low quality, worst quality:1.2), 3d, watermark, signature, ugly, poorly drawn, ",
|
| 464 |
-
},
|
| 465 |
-
{
|
| 466 |
-
"name": "Heavy",
|
| 467 |
-
"prompt": "{prompt}, (masterpiece), (best quality), (ultra-detailed), illustration, disheveled hair, perfect composition, moist skin, intricate details, earrings",
|
| 468 |
-
"negative_prompt": "nsfw, longbody, lowres, bad anatomy, bad hands, missing fingers, pubic hair, extra digit, fewer digits, cropped, worst quality, low quality, ",
|
| 469 |
-
},
|
| 470 |
-
]
|
| 471 |
-
|
| 472 |
-
sampler_list = [
|
| 473 |
-
"DPM++ 2M Karras",
|
| 474 |
-
"DPM++ SDE Karras",
|
| 475 |
-
"DPM++ 2M SDE Karras",
|
| 476 |
-
"Euler",
|
| 477 |
-
"Euler a",
|
| 478 |
-
"DDIM",
|
| 479 |
-
]
|
| 480 |
-
|
| 481 |
-
aspect_ratios = [
|
| 482 |
-
"1024 x 1024",
|
| 483 |
-
"1152 x 896",
|
| 484 |
-
"896 x 1152",
|
| 485 |
-
"1216 x 832",
|
| 486 |
-
"832 x 1216",
|
| 487 |
-
"1344 x 768",
|
| 488 |
-
"768 x 1344",
|
| 489 |
-
"1536 x 640",
|
| 490 |
-
"640 x 1536",
|
| 491 |
-
"Custom",
|
| 492 |
-
]
|
| 493 |
-
|
| 494 |
-
style_list = [
|
| 495 |
-
{
|
| 496 |
-
"name": "(None)",
|
| 497 |
-
"prompt": "{prompt}",
|
| 498 |
-
"negative_prompt": "",
|
| 499 |
-
},
|
| 500 |
-
{
|
| 501 |
-
"name": "Cinematic",
|
| 502 |
-
"prompt": "{prompt}, cinematic still, emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
|
| 503 |
-
"negative_prompt": "nsfw, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
|
| 504 |
-
},
|
| 505 |
-
{
|
| 506 |
-
"name": "Photographic",
|
| 507 |
-
"prompt": "{prompt}, cinematic photo, 35mm photograph, film, bokeh, professional, 4k, highly detailed",
|
| 508 |
-
"negative_prompt": "nsfw, drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
|
| 509 |
-
},
|
| 510 |
-
{
|
| 511 |
-
"name": "Anime",
|
| 512 |
-
"prompt": "{prompt}, anime artwork, anime style, key visual, vibrant, studio anime, highly detailed",
|
| 513 |
-
"negative_prompt": "nsfw, photo, deformed, black and white, realism, disfigured, low contrast",
|
| 514 |
-
},
|
| 515 |
-
{
|
| 516 |
-
"name": "Manga",
|
| 517 |
-
"prompt": "{prompt}, manga style, vibrant, high-energy, detailed, iconic, Japanese comic style",
|
| 518 |
-
"negative_prompt": "nsfw, ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
|
| 519 |
-
},
|
| 520 |
-
{
|
| 521 |
-
"name": "Digital Art",
|
| 522 |
-
"prompt": "{prompt}, concept art, digital artwork, illustrative, painterly, matte painting, highly detailed",
|
| 523 |
-
"negative_prompt": "nsfw, photo, photorealistic, realism, ugly",
|
| 524 |
-
},
|
| 525 |
-
{
|
| 526 |
-
"name": "Pixel art",
|
| 527 |
-
"prompt": "{prompt}, pixel-art, low-res, blocky, pixel art style, 8-bit graphics",
|
| 528 |
-
"negative_prompt": "nsfw, sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
|
| 529 |
-
},
|
| 530 |
-
{
|
| 531 |
-
"name": "Fantasy art",
|
| 532 |
-
"prompt": "{prompt}, ethereal fantasy concept art, magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
|
| 533 |
-
"negative_prompt": "nsfw, photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
|
| 534 |
-
},
|
| 535 |
-
{
|
| 536 |
-
"name": "Neonpunk",
|
| 537 |
-
"prompt": "{prompt}, neonpunk style, cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
|
| 538 |
-
"negative_prompt": "nsfw, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
|
| 539 |
-
},
|
| 540 |
-
{
|
| 541 |
-
"name": "3D Model",
|
| 542 |
-
"prompt": "{prompt}, professional 3d model, octane render, highly detailed, volumetric, dramatic lighting",
|
| 543 |
-
"negative_prompt": "nsfw, ugly, deformed, noisy, low poly, blurry, painting",
|
| 544 |
-
},
|
| 545 |
-
]
|
| 546 |
-
|
| 547 |
-
thumbnail_cache = {}
|
| 548 |
|
| 549 |
-
with open("lora.toml", "r") as file:
|
| 550 |
-
data = toml.load(file)
|
| 551 |
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
|
| 558 |
-
if model_path not in thumbnail_cache:
|
| 559 |
-
thumbnail_image = load_and_convert_thumbnail(model_path)
|
| 560 |
-
thumbnail_cache[model_path] = thumbnail_image
|
| 561 |
-
else:
|
| 562 |
-
thumbnail_image = thumbnail_cache[model_path]
|
| 563 |
-
|
| 564 |
-
sdxl_loras.append(
|
| 565 |
-
{
|
| 566 |
-
"image": thumbnail_image, # Storing the PIL image object
|
| 567 |
-
"title": item["title"],
|
| 568 |
-
"repo": item["repo"],
|
| 569 |
-
"weights": item["weights"],
|
| 570 |
-
"multiplier": item.get("multiplier", "1.0"),
|
| 571 |
-
}
|
| 572 |
-
)
|
| 573 |
|
| 574 |
-
styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
|
| 575 |
quality_prompt = {
|
| 576 |
-
k["name"]: (k["prompt"], k["negative_prompt"]) for k in quality_prompt_list
|
| 577 |
}
|
| 578 |
|
| 579 |
-
|
| 580 |
-
# hf_hub_download(item["repo"], item["weights"], token=HF_TOKEN)
|
| 581 |
-
# for item in sdxl_loras
|
| 582 |
-
# ]
|
| 583 |
-
|
| 584 |
-
wildcard_files = load_wildcard_files("wildcard")
|
| 585 |
|
| 586 |
-
with gr.Blocks(css="style.css"
|
| 587 |
title = gr.HTML(
|
| 588 |
f"""<h1><span>{DESCRIPTION}</span></h1>""",
|
| 589 |
elem_id="title",
|
|
@@ -592,187 +207,131 @@ with gr.Blocks(css="style.css", theme="NoCrypt/[email protected]") as demo:
|
|
| 592 |
f"""Gradio demo for [cagliostrolab/animagine-xl-3.0](https://huggingface.co/cagliostrolab/animagine-xl-3.0)""",
|
| 593 |
elem_id="subtitle",
|
| 594 |
)
|
| 595 |
-
gr.Markdown(
|
| 596 |
-
f"""Prompting is a bit different in this iteration, we train the model like this:
|
| 597 |
-
```
|
| 598 |
-
1girl/1boy, character name, from what series, everything else in any order.
|
| 599 |
-
```
|
| 600 |
-
Prompting Tips
|
| 601 |
-
```
|
| 602 |
-
1. Quality Tags: `masterpiece, best quality, high quality, normal quality, worst quality, low quality`
|
| 603 |
-
2. Year Tags: `oldest, early, mid, late, newest`
|
| 604 |
-
3. Rating tags: `rating: general, rating: sensitive, rating: questionable, rating: explicit, nsfw`
|
| 605 |
-
4. Escape character: `character name \(series\)`
|
| 606 |
-
5. Recommended settings: `Euler a, cfg 5-7, 25-28 steps`
|
| 607 |
-
6. It's recommended to use the exact danbooru tags for more accurate result
|
| 608 |
-
7. To use character wildcard, add this syntax to the prompt `__character__`.
|
| 609 |
-
```
|
| 610 |
-
""",
|
| 611 |
-
elem_id="subtitle",
|
| 612 |
-
)
|
| 613 |
gr.DuplicateButton(
|
| 614 |
value="Duplicate Space for private use",
|
| 615 |
elem_id="duplicate-button",
|
| 616 |
visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
|
| 617 |
)
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
|
| 621 |
-
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
|
| 625 |
-
|
| 626 |
-
|
| 627 |
-
|
| 628 |
-
|
| 629 |
-
|
| 630 |
-
|
| 631 |
-
|
| 632 |
-
|
| 633 |
-
|
| 634 |
-
|
| 635 |
-
|
| 636 |
-
|
| 637 |
-
|
| 638 |
-
|
| 639 |
-
|
| 640 |
-
|
| 641 |
-
|
| 642 |
-
|
| 643 |
-
|
| 644 |
-
|
| 645 |
-
|
| 646 |
-
|
| 647 |
-
|
| 648 |
-
|
| 649 |
-
|
| 650 |
-
|
| 651 |
-
|
| 652 |
-
|
| 653 |
-
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
|
| 657 |
-
|
| 658 |
-
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
|
| 662 |
-
|
| 663 |
-
|
| 664 |
-
|
| 665 |
-
|
| 666 |
-
|
| 667 |
-
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
|
| 673 |
-
|
| 674 |
-
|
| 675 |
-
|
| 676 |
-
|
| 677 |
-
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
|
| 683 |
-
|
| 684 |
-
|
| 685 |
-
|
| 686 |
-
|
| 687 |
-
|
| 688 |
-
|
| 689 |
-
|
| 690 |
-
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
|
| 694 |
-
|
| 695 |
-
|
| 696 |
-
|
| 697 |
-
|
| 698 |
-
|
| 699 |
-
|
| 700 |
-
|
| 701 |
-
label="Width",
|
| 702 |
-
minimum=MIN_IMAGE_SIZE,
|
| 703 |
-
maximum=MAX_IMAGE_SIZE,
|
| 704 |
-
step=8,
|
| 705 |
-
value=1024,
|
| 706 |
-
)
|
| 707 |
-
custom_height = gr.Slider(
|
| 708 |
-
label="Height",
|
| 709 |
-
minimum=MIN_IMAGE_SIZE,
|
| 710 |
-
maximum=MAX_IMAGE_SIZE,
|
| 711 |
-
step=8,
|
| 712 |
-
value=1024,
|
| 713 |
-
)
|
| 714 |
-
with gr.Group():
|
| 715 |
-
sampler = gr.Dropdown(
|
| 716 |
-
label="Sampler",
|
| 717 |
-
choices=sampler_list,
|
| 718 |
-
interactive=True,
|
| 719 |
-
value="Euler a",
|
| 720 |
-
)
|
| 721 |
-
with gr.Group():
|
| 722 |
-
seed = gr.Slider(
|
| 723 |
-
label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0
|
| 724 |
-
)
|
| 725 |
-
|
| 726 |
-
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
| 727 |
-
with gr.Group():
|
| 728 |
-
with gr.Row():
|
| 729 |
-
guidance_scale = gr.Slider(
|
| 730 |
-
label="Guidance scale",
|
| 731 |
-
minimum=1,
|
| 732 |
-
maximum=12,
|
| 733 |
-
step=0.1,
|
| 734 |
-
value=7.0,
|
| 735 |
-
)
|
| 736 |
-
num_inference_steps = gr.Slider(
|
| 737 |
-
label="Number of inference steps",
|
| 738 |
-
minimum=1,
|
| 739 |
-
maximum=50,
|
| 740 |
-
step=1,
|
| 741 |
-
value=28,
|
| 742 |
-
)
|
| 743 |
-
|
| 744 |
-
with gr.Tab("Past Generation"):
|
| 745 |
-
gr_user_history.render()
|
| 746 |
-
with gr.Column(scale=3):
|
| 747 |
-
with gr.Blocks():
|
| 748 |
-
run_button = gr.Button("Generate", variant="primary")
|
| 749 |
-
result = gr.Image(label="Result", show_label=False)
|
| 750 |
-
with gr.Accordion(label="Generation Parameters", open=False):
|
| 751 |
-
gr_metadata = gr.JSON(label="Metadata", show_label=False)
|
| 752 |
-
gr.Examples(
|
| 753 |
-
examples=examples,
|
| 754 |
-
inputs=prompt,
|
| 755 |
-
outputs=[result, gr_metadata],
|
| 756 |
-
fn=generate,
|
| 757 |
-
cache_examples=CACHE_EXAMPLES,
|
| 758 |
)
|
| 759 |
|
| 760 |
-
|
| 761 |
-
|
| 762 |
-
|
| 763 |
-
|
| 764 |
-
|
| 765 |
-
|
| 766 |
-
|
| 767 |
-
|
| 768 |
-
|
| 769 |
-
|
| 770 |
-
|
| 771 |
-
|
| 772 |
-
|
| 773 |
-
|
| 774 |
-
|
| 775 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 776 |
)
|
| 777 |
use_upscaler.change(
|
| 778 |
fn=lambda x: [gr.update(visible=x), gr.update(visible=x)],
|
|
@@ -797,9 +356,6 @@ with gr.Blocks(css="style.css", theme="NoCrypt/[email protected]") as demo:
|
|
| 797 |
custom_height,
|
| 798 |
guidance_scale,
|
| 799 |
num_inference_steps,
|
| 800 |
-
use_lora,
|
| 801 |
-
lora_weight,
|
| 802 |
-
selected_state,
|
| 803 |
sampler,
|
| 804 |
aspect_ratio_selector,
|
| 805 |
style_selector,
|
|
@@ -807,11 +363,11 @@ with gr.Blocks(css="style.css", theme="NoCrypt/[email protected]") as demo:
|
|
| 807 |
use_upscaler,
|
| 808 |
upscaler_strength,
|
| 809 |
upscale_by,
|
| 810 |
-
add_quality_tags
|
| 811 |
]
|
| 812 |
|
| 813 |
prompt.submit(
|
| 814 |
-
fn=randomize_seed_fn,
|
| 815 |
inputs=[seed, randomize_seed],
|
| 816 |
outputs=seed,
|
| 817 |
queue=False,
|
|
@@ -823,7 +379,7 @@ with gr.Blocks(css="style.css", theme="NoCrypt/[email protected]") as demo:
|
|
| 823 |
api_name="run",
|
| 824 |
)
|
| 825 |
negative_prompt.submit(
|
| 826 |
-
fn=randomize_seed_fn,
|
| 827 |
inputs=[seed, randomize_seed],
|
| 828 |
outputs=seed,
|
| 829 |
queue=False,
|
|
@@ -835,7 +391,7 @@ with gr.Blocks(css="style.css", theme="NoCrypt/[email protected]") as demo:
|
|
| 835 |
api_name=False,
|
| 836 |
)
|
| 837 |
run_button.click(
|
| 838 |
-
fn=randomize_seed_fn,
|
| 839 |
inputs=[seed, randomize_seed],
|
| 840 |
outputs=seed,
|
| 841 |
queue=False,
|
|
@@ -846,4 +402,4 @@ with gr.Blocks(css="style.css", theme="NoCrypt/[email protected]") as demo:
|
|
| 846 |
outputs=[result, gr_metadata],
|
| 847 |
api_name=False,
|
| 848 |
)
|
| 849 |
-
demo.queue(max_size=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
|
|
|
| 2 |
import gc
|
|
|
|
| 3 |
import gradio as gr
|
| 4 |
import numpy as np
|
|
|
|
| 5 |
import torch
|
| 6 |
import json
|
| 7 |
+
import spaces
|
| 8 |
+
import config
|
| 9 |
+
import utils
|
| 10 |
+
import logging
|
| 11 |
+
from PIL import Image, PngImagePlugin
|
| 12 |
from datetime import datetime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
from diffusers.models import AutoencoderKL
|
| 14 |
+
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
|
| 15 |
+
|
| 16 |
+
logging.basicConfig(level=logging.INFO)
|
| 17 |
+
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
DESCRIPTION = "Animagine XL 3.0"
|
| 20 |
if not torch.cuda.is_available():
|
| 21 |
DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU. </p>"
|
| 22 |
IS_COLAB = utils.is_google_colab() or os.getenv("IS_COLAB") == "1"
|
|
|
|
| 23 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 24 |
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
|
| 25 |
MIN_IMAGE_SIZE = int(os.getenv("MIN_IMAGE_SIZE", "512"))
|
| 26 |
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "2048"))
|
| 27 |
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1"
|
| 28 |
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
|
| 29 |
+
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./outputs")
|
| 30 |
|
| 31 |
+
MODEL = os.getenv(
|
| 32 |
+
"MODEL",
|
| 33 |
+
"https://huggingface.co/cagliostrolab/animagine-xl-3.0/blob/main/animagine-xl-3.0.safetensors",
|
| 34 |
+
)
|
| 35 |
|
| 36 |
torch.backends.cudnn.deterministic = True
|
| 37 |
torch.backends.cudnn.benchmark = False
|
| 38 |
|
| 39 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 40 |
|
| 41 |
+
|
| 42 |
+
def load_pipeline(model_name):
|
| 43 |
vae = AutoencoderKL.from_pretrained(
|
| 44 |
"madebyollin/sdxl-vae-fp16-fix",
|
| 45 |
torch_dtype=torch.float16,
|
| 46 |
)
|
| 47 |
+
pipeline = (
|
| 48 |
+
StableDiffusionXLPipeline.from_single_file
|
| 49 |
+
if MODEL.endswith(".safetensors")
|
| 50 |
+
else StableDiffusionXLPipeline.from_pretrained
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
pipe = pipeline(
|
| 54 |
+
model_name,
|
| 55 |
vae=vae,
|
| 56 |
torch_dtype=torch.float16,
|
| 57 |
custom_pipeline="lpw_stable_diffusion_xl",
|
| 58 |
use_safetensors=True,
|
| 59 |
+
add_watermarker=False,
|
| 60 |
use_auth_token=HF_TOKEN,
|
| 61 |
variant="fp16",
|
| 62 |
)
|
| 63 |
|
| 64 |
+
pipe.to(device)
|
| 65 |
+
return pipe
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
+
@spaces.GPU
|
| 69 |
def generate(
|
| 70 |
prompt: str,
|
| 71 |
negative_prompt: str = "",
|
|
|
|
| 74 |
custom_height: int = 1024,
|
| 75 |
guidance_scale: float = 7.0,
|
| 76 |
num_inference_steps: int = 28,
|
|
|
|
|
|
|
|
|
|
| 77 |
sampler: str = "Euler a",
|
| 78 |
aspect_ratio_selector: str = "896 x 1152",
|
| 79 |
style_selector: str = "(None)",
|
| 80 |
quality_selector: str = "Standard",
|
| 81 |
use_upscaler: bool = False,
|
| 82 |
+
upscaler_strength: float = 0.55,
|
| 83 |
upscale_by: float = 1.5,
|
| 84 |
add_quality_tags: bool = True,
|
|
|
|
| 85 |
progress=gr.Progress(track_tqdm=True),
|
| 86 |
+
) -> Image:
|
| 87 |
+
generator = utils.seed_everything(seed)
|
| 88 |
|
| 89 |
+
width, height = utils.aspect_ratio_handler(
|
|
|
|
|
|
|
|
|
|
| 90 |
aspect_ratio_selector,
|
| 91 |
custom_width,
|
| 92 |
custom_height,
|
| 93 |
)
|
| 94 |
|
| 95 |
+
prompt = utils.add_wildcard(prompt, wildcard_files)
|
| 96 |
|
| 97 |
+
prompt, negative_prompt = utils.preprocess_prompt(
|
|
|
|
| 98 |
quality_prompt, quality_selector, prompt, negative_prompt, add_quality_tags
|
| 99 |
)
|
| 100 |
+
prompt, negative_prompt = utils.preprocess_prompt(
|
| 101 |
styles, style_selector, prompt, negative_prompt
|
| 102 |
)
|
| 103 |
|
| 104 |
+
width, height = utils.preprocess_image_dimensions(width, height)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
backup_scheduler = pipe.scheduler
|
| 107 |
+
pipe.scheduler = utils.get_scheduler(pipe.scheduler.config, sampler)
|
| 108 |
|
| 109 |
if use_upscaler:
|
| 110 |
upscaler_pipe = StableDiffusionXLImg2ImgPipeline(**pipe.components)
|
|
|
|
| 111 |
metadata = {
|
| 112 |
"prompt": prompt,
|
| 113 |
"negative_prompt": negative_prompt,
|
|
|
|
| 121 |
"quality_tags": quality_selector,
|
| 122 |
}
|
| 123 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
if use_upscaler:
|
| 125 |
new_width = int(width * upscale_by)
|
| 126 |
new_height = int(height * upscale_by)
|
|
|
|
| 132 |
}
|
| 133 |
else:
|
| 134 |
metadata["use_upscaler"] = None
|
| 135 |
+
logger.info(json.dumps(metadata, indent=4))
|
|
|
|
| 136 |
|
| 137 |
try:
|
| 138 |
if use_upscaler:
|
|
|
|
| 146 |
generator=generator,
|
| 147 |
output_type="latent",
|
| 148 |
).images
|
| 149 |
+
upscaled_latents = utils.upscale(latents, "nearest-exact", upscale_by)
|
| 150 |
+
images = upscaler_pipe(
|
| 151 |
prompt=prompt,
|
| 152 |
negative_prompt=negative_prompt,
|
| 153 |
image=upscaled_latents,
|
|
|
|
| 156 |
strength=upscaler_strength,
|
| 157 |
generator=generator,
|
| 158 |
output_type="pil",
|
| 159 |
+
).images
|
| 160 |
else:
|
| 161 |
+
images = pipe(
|
| 162 |
prompt=prompt,
|
| 163 |
negative_prompt=negative_prompt,
|
| 164 |
width=width,
|
|
|
|
| 167 |
num_inference_steps=num_inference_steps,
|
| 168 |
generator=generator,
|
| 169 |
output_type="pil",
|
| 170 |
+
).images
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
|
| 172 |
+
if images and IS_COLAB:
|
| 173 |
+
for image in images:
|
| 174 |
+
filepath = utils.save_image(image, metadata, OUTPUT_DIR)
|
| 175 |
+
logger.info(f"Image saved as {filepath} with metadata")
|
| 176 |
|
| 177 |
+
return images, metadata
|
| 178 |
except Exception as e:
|
| 179 |
+
logger.exception(f"An error occurred: {e}")
|
| 180 |
raise
|
| 181 |
finally:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
if use_upscaler:
|
| 183 |
del upscaler_pipe
|
| 184 |
pipe.scheduler = backup_scheduler
|
| 185 |
+
utils.free_memory()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
|
|
|
|
|
|
|
| 187 |
|
| 188 |
+
if torch.cuda.is_available():
|
| 189 |
+
pipe = load_pipeline(MODEL)
|
| 190 |
+
logger.info("Loaded on Device!")
|
| 191 |
+
else:
|
| 192 |
+
pipe = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
|
| 194 |
+
styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in config.style_list}
|
| 195 |
quality_prompt = {
|
| 196 |
+
k["name"]: (k["prompt"], k["negative_prompt"]) for k in config.quality_prompt_list
|
| 197 |
}
|
| 198 |
|
| 199 |
+
wildcard_files = utils.load_wildcard_files("wildcard")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
|
| 201 |
+
with gr.Blocks(css="style.css") as demo:
|
| 202 |
title = gr.HTML(
|
| 203 |
f"""<h1><span>{DESCRIPTION}</span></h1>""",
|
| 204 |
elem_id="title",
|
|
|
|
| 207 |
f"""Gradio demo for [cagliostrolab/animagine-xl-3.0](https://huggingface.co/cagliostrolab/animagine-xl-3.0)""",
|
| 208 |
elem_id="subtitle",
|
| 209 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
gr.DuplicateButton(
|
| 211 |
value="Duplicate Space for private use",
|
| 212 |
elem_id="duplicate-button",
|
| 213 |
visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
|
| 214 |
)
|
| 215 |
+
with gr.Group():
|
| 216 |
+
with gr.Row():
|
| 217 |
+
prompt = gr.Text(
|
| 218 |
+
label="Prompt",
|
| 219 |
+
show_label=False,
|
| 220 |
+
max_lines=5,
|
| 221 |
+
placeholder="Enter your prompt",
|
| 222 |
+
container=False,
|
| 223 |
+
)
|
| 224 |
+
run_button = gr.Button(
|
| 225 |
+
"Generate",
|
| 226 |
+
variant="primary",
|
| 227 |
+
scale=0
|
| 228 |
+
)
|
| 229 |
+
result = gr.Gallery(
|
| 230 |
+
label="Result",
|
| 231 |
+
columns=1,
|
| 232 |
+
preview=True,
|
| 233 |
+
show_label=False
|
| 234 |
+
)
|
| 235 |
+
with gr.Accordion(label="Advanced Settings", open=False):
|
| 236 |
+
negative_prompt = gr.Text(
|
| 237 |
+
label="Negative Prompt",
|
| 238 |
+
max_lines=5,
|
| 239 |
+
placeholder="Enter a negative prompt",
|
| 240 |
+
)
|
| 241 |
+
with gr.Row():
|
| 242 |
+
add_quality_tags = gr.Checkbox(
|
| 243 |
+
label="Add Quality Tags",
|
| 244 |
+
value=True
|
| 245 |
+
)
|
| 246 |
+
quality_selector = gr.Dropdown(
|
| 247 |
+
label="Quality Tags Presets",
|
| 248 |
+
interactive=True,
|
| 249 |
+
choices=list(quality_prompt.keys()),
|
| 250 |
+
value="Standard",
|
| 251 |
+
)
|
| 252 |
+
style_selector = gr.Radio(
|
| 253 |
+
label="Style Preset",
|
| 254 |
+
container=True,
|
| 255 |
+
interactive=True,
|
| 256 |
+
choices=list(styles.keys()),
|
| 257 |
+
value="(None)",
|
| 258 |
+
)
|
| 259 |
+
aspect_ratio_selector = gr.Radio(
|
| 260 |
+
label="Aspect Ratio",
|
| 261 |
+
choices=config.aspect_ratios,
|
| 262 |
+
value="896 x 1152",
|
| 263 |
+
container=True,
|
| 264 |
+
)
|
| 265 |
+
with gr.Group(visible=False) as custom_resolution:
|
| 266 |
+
with gr.Row():
|
| 267 |
+
custom_width = gr.Slider(
|
| 268 |
+
label="Width",
|
| 269 |
+
minimum=MIN_IMAGE_SIZE,
|
| 270 |
+
maximum=MAX_IMAGE_SIZE,
|
| 271 |
+
step=8,
|
| 272 |
+
value=1024,
|
| 273 |
+
)
|
| 274 |
+
custom_height = gr.Slider(
|
| 275 |
+
label="Height",
|
| 276 |
+
minimum=MIN_IMAGE_SIZE,
|
| 277 |
+
maximum=MAX_IMAGE_SIZE,
|
| 278 |
+
step=8,
|
| 279 |
+
value=1024,
|
| 280 |
+
)
|
| 281 |
+
use_upscaler = gr.Checkbox(label="Use Upscaler", value=False)
|
| 282 |
+
with gr.Row() as upscaler_row:
|
| 283 |
+
upscaler_strength = gr.Slider(
|
| 284 |
+
label="Strength",
|
| 285 |
+
minimum=0,
|
| 286 |
+
maximum=1,
|
| 287 |
+
step=0.05,
|
| 288 |
+
value=0.55,
|
| 289 |
+
visible=False,
|
| 290 |
+
)
|
| 291 |
+
upscale_by = gr.Slider(
|
| 292 |
+
label="Upscale by",
|
| 293 |
+
minimum=1,
|
| 294 |
+
maximum=1.5,
|
| 295 |
+
step=0.1,
|
| 296 |
+
value=1.5,
|
| 297 |
+
visible=False,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
)
|
| 299 |
|
| 300 |
+
sampler = gr.Dropdown(
|
| 301 |
+
label="Sampler",
|
| 302 |
+
choices=config.sampler_list,
|
| 303 |
+
interactive=True,
|
| 304 |
+
value="Euler a",
|
| 305 |
+
)
|
| 306 |
+
with gr.Row():
|
| 307 |
+
seed = gr.Slider(
|
| 308 |
+
label="Seed", minimum=0, maximum=utils.MAX_SEED, step=1, value=0
|
| 309 |
+
)
|
| 310 |
+
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
| 311 |
+
with gr.Group():
|
| 312 |
+
with gr.Row():
|
| 313 |
+
guidance_scale = gr.Slider(
|
| 314 |
+
label="Guidance scale",
|
| 315 |
+
minimum=1,
|
| 316 |
+
maximum=12,
|
| 317 |
+
step=0.1,
|
| 318 |
+
value=7.0,
|
| 319 |
+
)
|
| 320 |
+
num_inference_steps = gr.Slider(
|
| 321 |
+
label="Number of inference steps",
|
| 322 |
+
minimum=1,
|
| 323 |
+
maximum=50,
|
| 324 |
+
step=1,
|
| 325 |
+
value=28,
|
| 326 |
+
)
|
| 327 |
+
with gr.Accordion(label="Generation Parameters", open=False):
|
| 328 |
+
gr_metadata = gr.JSON(label="Metadata", show_label=False)
|
| 329 |
+
gr.Examples(
|
| 330 |
+
examples=config.examples,
|
| 331 |
+
inputs=prompt,
|
| 332 |
+
outputs=[result, gr_metadata],
|
| 333 |
+
fn=lambda *args, **kwargs: generate(*args, use_upscaler=True, **kwargs),
|
| 334 |
+
cache_examples=CACHE_EXAMPLES,
|
| 335 |
)
|
| 336 |
use_upscaler.change(
|
| 337 |
fn=lambda x: [gr.update(visible=x), gr.update(visible=x)],
|
|
|
|
| 356 |
custom_height,
|
| 357 |
guidance_scale,
|
| 358 |
num_inference_steps,
|
|
|
|
|
|
|
|
|
|
| 359 |
sampler,
|
| 360 |
aspect_ratio_selector,
|
| 361 |
style_selector,
|
|
|
|
| 363 |
use_upscaler,
|
| 364 |
upscaler_strength,
|
| 365 |
upscale_by,
|
| 366 |
+
add_quality_tags,
|
| 367 |
]
|
| 368 |
|
| 369 |
prompt.submit(
|
| 370 |
+
fn=utils.randomize_seed_fn,
|
| 371 |
inputs=[seed, randomize_seed],
|
| 372 |
outputs=seed,
|
| 373 |
queue=False,
|
|
|
|
| 379 |
api_name="run",
|
| 380 |
)
|
| 381 |
negative_prompt.submit(
|
| 382 |
+
fn=utils.randomize_seed_fn,
|
| 383 |
inputs=[seed, randomize_seed],
|
| 384 |
outputs=seed,
|
| 385 |
queue=False,
|
|
|
|
| 391 |
api_name=False,
|
| 392 |
)
|
| 393 |
run_button.click(
|
| 394 |
+
fn=utils.randomize_seed_fn,
|
| 395 |
inputs=[seed, randomize_seed],
|
| 396 |
outputs=seed,
|
| 397 |
queue=False,
|
|
|
|
| 402 |
outputs=[result, gr_metadata],
|
| 403 |
api_name=False,
|
| 404 |
)
|
| 405 |
+
demo.queue(max_size=20).launch(debug=IS_COLAB, share=IS_COLAB)
|
config.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
examples = [
|
| 2 |
+
"1girl, arima kana, oshi no ko, solo, idol, idol clothes, one eye closed, red shirt, black skirt, black headwear, gloves, stage light, singing, open mouth, crowd, smile, pointing at viewer",
|
| 3 |
+
"1girl, c.c., code geass, white shirt, long sleeves, turtleneck, sitting, looking at viewer, eating, pizza, plate, fork, knife, table, chair, table, restaurant, cinematic angle, cinematic lighting",
|
| 4 |
+
"1girl, sakurauchi riko, \(love live\), queen hat, noble coat, red coat, noble shirt, sitting, crossed legs, gentle smile, parted lips, throne, cinematic angle",
|
| 5 |
+
"1girl, amiya \(arknights\), arknights, dirty face, outstretched hand, close-up, cinematic angle, foreshortening, dark, dark background",
|
| 6 |
+
"A boy and a girl, Emiya Shirou and Artoria Pendragon from fate series, having their breakfast in the dining room. Emiya Shirou wears white t-shirt and jacket. Artoria Pendragon wears white dress with blue neck ribbon. Rice, soup, and minced meats are served on the table. They look at each other while smiling happily",
|
| 7 |
+
]
|
| 8 |
+
|
| 9 |
+
quality_prompt_list = [
|
| 10 |
+
{
|
| 11 |
+
"name": "(None)",
|
| 12 |
+
"prompt": "{prompt}",
|
| 13 |
+
"negative_prompt": "nsfw, lowres, ",
|
| 14 |
+
},
|
| 15 |
+
{
|
| 16 |
+
"name": "Standard",
|
| 17 |
+
"prompt": "{prompt}, masterpiece, best quality",
|
| 18 |
+
"negative_prompt": "nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name, ",
|
| 19 |
+
},
|
| 20 |
+
{
|
| 21 |
+
"name": "Light",
|
| 22 |
+
"prompt": "{prompt}, (masterpiece), best quality, perfect face",
|
| 23 |
+
"negative_prompt": "nsfw, (low quality, worst quality:1.2), 3d, watermark, signature, ugly, poorly drawn, ",
|
| 24 |
+
},
|
| 25 |
+
{
|
| 26 |
+
"name": "Heavy",
|
| 27 |
+
"prompt": "{prompt}, (masterpiece), (best quality), (ultra-detailed), illustration, disheveled hair, perfect composition, moist skin, intricate details, earrings",
|
| 28 |
+
"negative_prompt": "nsfw, longbody, lowres, bad anatomy, bad hands, missing fingers, pubic hair, extra digit, fewer digits, cropped, worst quality, low quality, ",
|
| 29 |
+
},
|
| 30 |
+
]
|
| 31 |
+
|
| 32 |
+
sampler_list = [
|
| 33 |
+
"DPM++ 2M Karras",
|
| 34 |
+
"DPM++ SDE Karras",
|
| 35 |
+
"DPM++ 2M SDE Karras",
|
| 36 |
+
"Euler",
|
| 37 |
+
"Euler a",
|
| 38 |
+
"DDIM",
|
| 39 |
+
]
|
| 40 |
+
|
| 41 |
+
aspect_ratios = [
|
| 42 |
+
"1024 x 1024",
|
| 43 |
+
"1152 x 896",
|
| 44 |
+
"896 x 1152",
|
| 45 |
+
"1216 x 832",
|
| 46 |
+
"832 x 1216",
|
| 47 |
+
"1344 x 768",
|
| 48 |
+
"768 x 1344",
|
| 49 |
+
"1536 x 640",
|
| 50 |
+
"640 x 1536",
|
| 51 |
+
"Custom",
|
| 52 |
+
]
|
| 53 |
+
|
| 54 |
+
style_list = [
|
| 55 |
+
{
|
| 56 |
+
"name": "(None)",
|
| 57 |
+
"prompt": "{prompt}",
|
| 58 |
+
"negative_prompt": "",
|
| 59 |
+
},
|
| 60 |
+
{
|
| 61 |
+
"name": "Cinematic",
|
| 62 |
+
"prompt": "{prompt}, cinematic still, emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
|
| 63 |
+
"negative_prompt": "nsfw, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
|
| 64 |
+
},
|
| 65 |
+
{
|
| 66 |
+
"name": "Photographic",
|
| 67 |
+
"prompt": "{prompt}, cinematic photo, 35mm photograph, film, bokeh, professional, 4k, highly detailed",
|
| 68 |
+
"negative_prompt": "nsfw, drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
|
| 69 |
+
},
|
| 70 |
+
{
|
| 71 |
+
"name": "Anime",
|
| 72 |
+
"prompt": "{prompt}, anime artwork, anime style, key visual, vibrant, studio anime, highly detailed",
|
| 73 |
+
"negative_prompt": "nsfw, photo, deformed, black and white, realism, disfigured, low contrast",
|
| 74 |
+
},
|
| 75 |
+
{
|
| 76 |
+
"name": "Manga",
|
| 77 |
+
"prompt": "{prompt}, manga style, vibrant, high-energy, detailed, iconic, Japanese comic style",
|
| 78 |
+
"negative_prompt": "nsfw, ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
|
| 79 |
+
},
|
| 80 |
+
{
|
| 81 |
+
"name": "Digital Art",
|
| 82 |
+
"prompt": "{prompt}, concept art, digital artwork, illustrative, painterly, matte painting, highly detailed",
|
| 83 |
+
"negative_prompt": "nsfw, photo, photorealistic, realism, ugly",
|
| 84 |
+
},
|
| 85 |
+
{
|
| 86 |
+
"name": "Pixel art",
|
| 87 |
+
"prompt": "{prompt}, pixel-art, low-res, blocky, pixel art style, 8-bit graphics",
|
| 88 |
+
"negative_prompt": "nsfw, sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
|
| 89 |
+
},
|
| 90 |
+
{
|
| 91 |
+
"name": "Fantasy art",
|
| 92 |
+
"prompt": "{prompt}, ethereal fantasy concept art, magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
|
| 93 |
+
"negative_prompt": "nsfw, photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
|
| 94 |
+
},
|
| 95 |
+
{
|
| 96 |
+
"name": "Neonpunk",
|
| 97 |
+
"prompt": "{prompt}, neonpunk style, cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
|
| 98 |
+
"negative_prompt": "nsfw, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
|
| 99 |
+
},
|
| 100 |
+
{
|
| 101 |
+
"name": "3D Model",
|
| 102 |
+
"prompt": "{prompt}, professional 3d model, octane render, highly detailed, volumetric, dramatic lighting",
|
| 103 |
+
"negative_prompt": "nsfw, ugly, deformed, noisy, low poly, blurry, painting",
|
| 104 |
+
},
|
| 105 |
+
]
|
lora.toml
DELETED
|
@@ -1,28 +0,0 @@
|
|
| 1 |
-
[[data]]
|
| 2 |
-
title = "Style Enhancer XL"
|
| 3 |
-
repo = "Linaqruf/style-enhancer-xl-lora"
|
| 4 |
-
weights = "style-enhancer-xl.safetensors"
|
| 5 |
-
multiplier = 0.6
|
| 6 |
-
[[data]]
|
| 7 |
-
title = "Anime Detailer XL"
|
| 8 |
-
repo = "Linaqruf/anime-detailer-xl-lora"
|
| 9 |
-
weights = "anime-detailer-xl.safetensors"
|
| 10 |
-
multiplier = 2.0
|
| 11 |
-
|
| 12 |
-
[[data]]
|
| 13 |
-
title = "Sketch Style XL"
|
| 14 |
-
repo = "Linaqruf/sketch-style-xl-lora"
|
| 15 |
-
weights = "sketch-style-xl.safetensors"
|
| 16 |
-
multiplier = 0.6
|
| 17 |
-
|
| 18 |
-
[[data]]
|
| 19 |
-
title = "Pastel Style XL 2.0"
|
| 20 |
-
repo = "Linaqruf/pastel-style-xl-lora"
|
| 21 |
-
weights = "pastel-style-xl-v2.safetensors"
|
| 22 |
-
multiplier = 0.6
|
| 23 |
-
|
| 24 |
-
[[data]]
|
| 25 |
-
title = "Anime Nouveau XL"
|
| 26 |
-
repo = "Linaqruf/anime-nouveau-xl-lora"
|
| 27 |
-
weights = "anime-nouveau-xl.safetensors"
|
| 28 |
-
multiplier = 0.6
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lora_diffusers.py
DELETED
|
@@ -1,478 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
LoRA module for Diffusers
|
| 3 |
-
==========================
|
| 4 |
-
|
| 5 |
-
This file works independently and is designed to operate with Diffusers.
|
| 6 |
-
|
| 7 |
-
Credits
|
| 8 |
-
-------
|
| 9 |
-
- Modified from: https://github.com/vladmandic/automatic/blob/master/modules/lora_diffusers.py
|
| 10 |
-
- Originally from: https://github.com/kohya-ss/sd-scripts/blob/sdxl/networks/lora_diffusers.py
|
| 11 |
-
"""
|
| 12 |
-
|
| 13 |
-
import bisect
|
| 14 |
-
import math
|
| 15 |
-
import random
|
| 16 |
-
from typing import Any, Dict, List, Mapping, Optional, Union
|
| 17 |
-
from diffusers import UNet2DConditionModel
|
| 18 |
-
import numpy as np
|
| 19 |
-
from tqdm import tqdm
|
| 20 |
-
from transformers import CLIPTextModel
|
| 21 |
-
import torch
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
def make_unet_conversion_map() -> Dict[str, str]:
|
| 25 |
-
unet_conversion_map_layer = []
|
| 26 |
-
|
| 27 |
-
for i in range(3): # num_blocks is 3 in sdxl
|
| 28 |
-
# loop over downblocks/upblocks
|
| 29 |
-
for j in range(2):
|
| 30 |
-
# loop over resnets/attentions for downblocks
|
| 31 |
-
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
| 32 |
-
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
| 33 |
-
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
| 34 |
-
|
| 35 |
-
if i < 3:
|
| 36 |
-
# no attention layers in down_blocks.3
|
| 37 |
-
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
| 38 |
-
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
| 39 |
-
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
| 40 |
-
|
| 41 |
-
for j in range(3):
|
| 42 |
-
# loop over resnets/attentions for upblocks
|
| 43 |
-
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
| 44 |
-
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
| 45 |
-
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
| 46 |
-
|
| 47 |
-
# if i > 0: commentout for sdxl
|
| 48 |
-
# no attention layers in up_blocks.0
|
| 49 |
-
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
| 50 |
-
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
| 51 |
-
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
| 52 |
-
|
| 53 |
-
if i < 3:
|
| 54 |
-
# no downsample in down_blocks.3
|
| 55 |
-
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
| 56 |
-
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
| 57 |
-
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
| 58 |
-
|
| 59 |
-
# no upsample in up_blocks.3
|
| 60 |
-
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
| 61 |
-
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
|
| 62 |
-
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
| 63 |
-
|
| 64 |
-
hf_mid_atn_prefix = "mid_block.attentions.0."
|
| 65 |
-
sd_mid_atn_prefix = "middle_block.1."
|
| 66 |
-
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
| 67 |
-
|
| 68 |
-
for j in range(2):
|
| 69 |
-
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
| 70 |
-
sd_mid_res_prefix = f"middle_block.{2*j}."
|
| 71 |
-
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
| 72 |
-
|
| 73 |
-
unet_conversion_map_resnet = [
|
| 74 |
-
# (stable-diffusion, HF Diffusers)
|
| 75 |
-
("in_layers.0.", "norm1."),
|
| 76 |
-
("in_layers.2.", "conv1."),
|
| 77 |
-
("out_layers.0.", "norm2."),
|
| 78 |
-
("out_layers.3.", "conv2."),
|
| 79 |
-
("emb_layers.1.", "time_emb_proj."),
|
| 80 |
-
("skip_connection.", "conv_shortcut."),
|
| 81 |
-
]
|
| 82 |
-
|
| 83 |
-
unet_conversion_map = []
|
| 84 |
-
for sd, hf in unet_conversion_map_layer:
|
| 85 |
-
if "resnets" in hf:
|
| 86 |
-
for sd_res, hf_res in unet_conversion_map_resnet:
|
| 87 |
-
unet_conversion_map.append((sd + sd_res, hf + hf_res))
|
| 88 |
-
else:
|
| 89 |
-
unet_conversion_map.append((sd, hf))
|
| 90 |
-
|
| 91 |
-
for j in range(2):
|
| 92 |
-
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
|
| 93 |
-
sd_time_embed_prefix = f"time_embed.{j*2}."
|
| 94 |
-
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
|
| 95 |
-
|
| 96 |
-
for j in range(2):
|
| 97 |
-
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
|
| 98 |
-
sd_label_embed_prefix = f"label_emb.0.{j*2}."
|
| 99 |
-
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
|
| 100 |
-
|
| 101 |
-
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
|
| 102 |
-
unet_conversion_map.append(("out.0.", "conv_norm_out."))
|
| 103 |
-
unet_conversion_map.append(("out.2.", "conv_out."))
|
| 104 |
-
|
| 105 |
-
sd_hf_conversion_map = {sd.replace(".", "_")[:-1]: hf.replace(".", "_")[:-1] for sd, hf in unet_conversion_map}
|
| 106 |
-
return sd_hf_conversion_map
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
UNET_CONVERSION_MAP = make_unet_conversion_map()
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
class LoRAModule(torch.nn.Module):
|
| 113 |
-
"""
|
| 114 |
-
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
| 115 |
-
"""
|
| 116 |
-
|
| 117 |
-
def __init__(
|
| 118 |
-
self,
|
| 119 |
-
lora_name,
|
| 120 |
-
org_module: torch.nn.Module,
|
| 121 |
-
multiplier=1.0,
|
| 122 |
-
lora_dim=4,
|
| 123 |
-
alpha=1,
|
| 124 |
-
):
|
| 125 |
-
"""if alpha == 0 or None, alpha is rank (no scaling)."""
|
| 126 |
-
super().__init__()
|
| 127 |
-
self.lora_name = lora_name
|
| 128 |
-
|
| 129 |
-
if org_module.__class__.__name__ == "Conv2d" or org_module.__class__.__name__ == "LoRACompatibleConv":
|
| 130 |
-
in_dim = org_module.in_channels
|
| 131 |
-
out_dim = org_module.out_channels
|
| 132 |
-
else:
|
| 133 |
-
in_dim = org_module.in_features
|
| 134 |
-
out_dim = org_module.out_features
|
| 135 |
-
|
| 136 |
-
self.lora_dim = lora_dim
|
| 137 |
-
|
| 138 |
-
if org_module.__class__.__name__ == "Conv2d" or org_module.__class__.__name__ == "LoRACompatibleConv":
|
| 139 |
-
kernel_size = org_module.kernel_size
|
| 140 |
-
stride = org_module.stride
|
| 141 |
-
padding = org_module.padding
|
| 142 |
-
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
|
| 143 |
-
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
|
| 144 |
-
else:
|
| 145 |
-
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
|
| 146 |
-
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
|
| 147 |
-
|
| 148 |
-
if type(alpha) == torch.Tensor:
|
| 149 |
-
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
| 150 |
-
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
|
| 151 |
-
self.scale = alpha / self.lora_dim
|
| 152 |
-
self.register_buffer("alpha", torch.tensor(alpha)) # 勾配計算に含めない / not included in gradient calculation
|
| 153 |
-
|
| 154 |
-
# same as microsoft's
|
| 155 |
-
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
|
| 156 |
-
torch.nn.init.zeros_(self.lora_up.weight)
|
| 157 |
-
|
| 158 |
-
self.multiplier = multiplier
|
| 159 |
-
self.org_module = [org_module]
|
| 160 |
-
self.enabled = True
|
| 161 |
-
self.network: LoRANetwork = None
|
| 162 |
-
self.org_forward = None
|
| 163 |
-
|
| 164 |
-
# override org_module's forward method
|
| 165 |
-
def apply_to(self, multiplier=None):
|
| 166 |
-
if multiplier is not None:
|
| 167 |
-
self.multiplier = multiplier
|
| 168 |
-
if self.org_forward is None:
|
| 169 |
-
self.org_forward = self.org_module[0].forward
|
| 170 |
-
self.org_module[0].forward = self.forward
|
| 171 |
-
|
| 172 |
-
# restore org_module's forward method
|
| 173 |
-
def unapply_to(self):
|
| 174 |
-
if self.org_forward is not None:
|
| 175 |
-
self.org_module[0].forward = self.org_forward
|
| 176 |
-
|
| 177 |
-
# forward with lora
|
| 178 |
-
# scale is used LoRACompatibleConv, but we ignore it because we have multiplier
|
| 179 |
-
def forward(self, x, scale=1.0):
|
| 180 |
-
if not self.enabled:
|
| 181 |
-
return self.org_forward(x)
|
| 182 |
-
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
| 183 |
-
|
| 184 |
-
def set_network(self, network):
|
| 185 |
-
self.network = network
|
| 186 |
-
|
| 187 |
-
# merge lora weight to org weight
|
| 188 |
-
def merge_to(self, multiplier=1.0):
|
| 189 |
-
# get lora weight
|
| 190 |
-
lora_weight = self.get_weight(multiplier)
|
| 191 |
-
|
| 192 |
-
# get org weight
|
| 193 |
-
org_sd = self.org_module[0].state_dict()
|
| 194 |
-
org_weight = org_sd["weight"]
|
| 195 |
-
weight = org_weight + lora_weight.to(org_weight.device, dtype=org_weight.dtype)
|
| 196 |
-
|
| 197 |
-
# set weight to org_module
|
| 198 |
-
org_sd["weight"] = weight
|
| 199 |
-
self.org_module[0].load_state_dict(org_sd)
|
| 200 |
-
|
| 201 |
-
# restore org weight from lora weight
|
| 202 |
-
def restore_from(self, multiplier=1.0):
|
| 203 |
-
# get lora weight
|
| 204 |
-
lora_weight = self.get_weight(multiplier)
|
| 205 |
-
|
| 206 |
-
# get org weight
|
| 207 |
-
org_sd = self.org_module[0].state_dict()
|
| 208 |
-
org_weight = org_sd["weight"]
|
| 209 |
-
weight = org_weight - lora_weight.to(org_weight.device, dtype=org_weight.dtype)
|
| 210 |
-
|
| 211 |
-
# set weight to org_module
|
| 212 |
-
org_sd["weight"] = weight
|
| 213 |
-
self.org_module[0].load_state_dict(org_sd)
|
| 214 |
-
|
| 215 |
-
# return lora weight
|
| 216 |
-
def get_weight(self, multiplier=None):
    """Return the LoRA weight delta (up x down), scaled, as a float32 tensor.

    Args:
        multiplier: Strength applied to the delta. ``None`` means "use
            ``self.multiplier``".

    Returns:
        ``multiplier * (up combined with down) * self.scale``. The way the two
        projections are combined depends on the rank of the down weight:
        a matrix product for 2-D (linear) weights, a matrix product on the
        squeezed weights for 1x1 convolutions, and a convolution of the two
        kernels otherwise.
    """
    if multiplier is None:
        multiplier = self.multiplier

    # Compute in float32 regardless of the dtype the LoRA weights are stored in.
    up_weight = self.lora_up.weight.to(torch.float)
    down_weight = self.lora_down.weight.to(torch.float)

    if len(down_weight.size()) == 2:
        # linear
        delta = up_weight @ down_weight
    elif down_weight.size()[2:4] == (1, 1):
        # conv2d 1x1: drop the two trailing singleton dims, multiply, restore them
        delta = (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
    else:
        # conv2d 3x3
        delta = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)

    # Use the resolved argument here: the previous code always multiplied by
    # self.multiplier, so an explicitly passed multiplier was silently ignored.
    return multiplier * delta * self.scale
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
# Create network from weights for inference, weights are not loaded here
|
| 244 |
-
def create_network_from_weights(
    text_encoder: Union[CLIPTextModel, List[CLIPTextModel]], unet: UNet2DConditionModel, weights_sd: Dict, multiplier: float = 1.0
):
    """Build a LoRANetwork whose per-module rank and alpha come from ``weights_sd``.

    Only the layout is derived here; the weights themselves are not loaded.
    Keys without a "." are ignored. The part of a key before the first "." is
    used as the LoRA module name.
    """
    modules_dim = {}
    modules_alpha = {}
    for weight_key, tensor in weights_sd.items():
        if "." in weight_key:
            lora_name = weight_key.split(".")[0]
            if "alpha" in weight_key:
                modules_alpha[lora_name] = tensor
            elif "lora_down" in weight_key:
                # the rank is the first dimension of the down projection
                modules_dim[lora_name] = tensor.size()[0]

    # support old LoRA without alpha: fall back to the rank
    for lora_name, dim in modules_dim.items():
        modules_alpha.setdefault(lora_name, dim)

    return LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
def merge_lora_weights(pipe, weights_sd: Dict, multiplier: float = 1.0):
    """Merge the LoRA weights in ``weights_sd`` into ``pipe``'s models.

    A network is created for ``pipe.unet`` and ``pipe.text_encoder`` (plus
    ``pipe.text_encoder_2`` when the pipeline has one), the weights are loaded
    into it, and the network is merged into the original weights.
    """
    text_encoders = [pipe.text_encoder]
    if hasattr(pipe, "text_encoder_2"):
        text_encoders.append(pipe.text_encoder_2)

    lora_network = create_network_from_weights(text_encoders, pipe.unet, weights_sd, multiplier=multiplier)
    lora_network.load_state_dict(weights_sd)
    lora_network.merge_to(multiplier=multiplier)
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
# block weightや学習に対応しない簡易版 / simple version without block weight and training
|
| 280 |
-
class LoRANetwork(torch.nn.Module):
    """Inference-only set of LoRA modules for the text encoder(s) and the U-Net.

    One LoRAModule is created for every Linear/Conv2d (or LoRACompatible*)
    layer found inside the target module classes, provided ``modules_dim``
    holds a rank for it. Block weights and training are not supported.
    """

    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"

    # SDXL: must start with LORA_PREFIX_TEXT_ENCODER
    LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
    LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"

    def __init__(
        self,
        text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
        unet: UNet2DConditionModel,
        multiplier: float = 1.0,
        modules_dim: Optional[Dict[str, int]] = None,
        modules_alpha: Optional[Dict[str, int]] = None,
        varbose: Optional[bool] = False,
    ) -> None:
        """Create (but do not apply or load) the LoRA modules.

        Args:
            text_encoder: One text encoder or a list of them. With more than
                one, the "lora_te1"/"lora_te2" prefixes are used.
            unet: The U-Net whose layers receive LoRA modules.
            multiplier: Strength handed to every created LoRAModule.
            modules_dim: LoRA module name -> rank. Rewritten in place when
                Stability AI style U-Net names are converted.
            modules_alpha: LoRA module name -> alpha. Rewritten in place
                together with ``modules_dim``.
            varbose: Unused; kept (including its spelling) for compatibility.
        """
        super().__init__()
        self.multiplier = multiplier

        # The signature advertises None defaults, so tolerate them instead of
        # failing on the first dictionary access.
        if modules_dim is None:
            modules_dim = {}
        if modules_alpha is None:
            modules_alpha = {}

        print("create LoRA network from weights")

        # convert SDXL Stability AI's U-Net modules to Diffusers
        converted = self.convert_unet_modules(modules_dim, modules_alpha)
        if converted:
            print(f"converted {converted} Stability AI's U-Net LoRA modules to Diffusers (SDXL)")

        supported_layers = ("Linear", "LoRACompatibleLinear", "Conv2d", "LoRACompatibleConv")

        def create_modules(
            is_unet: bool,
            text_encoder_idx: Optional[int],  # None, 1, 2
            root_module: torch.nn.Module,
            target_replace_modules: List[str],
        ):
            """Return (created LoRA modules, names skipped for lack of a rank)."""
            if is_unet:
                prefix = self.LORA_PREFIX_UNET
            elif text_encoder_idx is None:
                prefix = self.LORA_PREFIX_TEXT_ENCODER
            elif text_encoder_idx == 1:
                prefix = self.LORA_PREFIX_TEXT_ENCODER1
            else:
                prefix = self.LORA_PREFIX_TEXT_ENCODER2

            loras = []
            skipped = []
            for name, module in root_module.named_modules():
                if module.__class__.__name__ not in target_replace_modules:
                    continue
                for child_name, child_module in module.named_modules():
                    if child_module.__class__.__name__ not in supported_layers:
                        continue

                    lora_name = (prefix + "." + name + "." + child_name).replace(".", "_")
                    if lora_name not in modules_dim:
                        skipped.append(lora_name)
                        continue

                    loras.append(
                        LoRAModule(
                            lora_name,
                            child_module,
                            self.multiplier,
                            modules_dim[lora_name],
                            modules_alpha[lora_name],
                        )
                    )
            return loras, skipped

        # isinstance also accepts list subclasses, unlike the exact type comparison.
        text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]

        # create LoRA for text encoder
        # NOTE: creating all modules every time is wasteful and may be worth revisiting
        self.text_encoder_loras: List[LoRAModule] = []
        skipped_te = []
        for i, encoder in enumerate(text_encoders):
            index = i + 1 if len(text_encoders) > 1 else None
            text_encoder_loras, skipped = create_modules(False, index, encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
            self.text_encoder_loras.extend(text_encoder_loras)
            skipped_te += skipped
        print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
        if len(skipped_te) > 0:
            print(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.")

        # extend U-Net target modules to include Conv2d 3x3
        target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3

        self.unet_loras: List[LoRAModule]
        self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
        print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
        if len(skipped_un) > 0:
            print(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.")

        # every rank that was supplied must have produced a module
        names = set()
        for lora in self.text_encoder_loras + self.unet_loras:
            names.add(lora.lora_name)
        for lora_name in modules_dim.keys():
            assert lora_name in names, f"{lora_name} is not found in created LoRA modules."

        # register the modules so that load_state_dict can find them
        for lora in self.text_encoder_loras + self.unet_loras:
            self.add_module(lora.lora_name, lora)

    @staticmethod
    def _find_conversion_prefix(key, sorted_map_keys):
        """Return the UNET_CONVERSION_MAP key that prefixes ``key``, or None.

        ``key`` must start with "lora_unet_"; ``sorted_map_keys`` must be the
        sorted keys of UNET_CONVERSION_MAP.
        """
        search_key = key.replace(LoRANetwork.LORA_PREFIX_UNET + "_", "")
        position = bisect.bisect_right(sorted_map_keys, search_key)
        if position == 0:
            # Nothing sorts at or before search_key, so nothing can prefix it
            # (this also covers an empty map instead of indexing [-1]).
            return None
        candidate = sorted_map_keys[position - 1]
        return candidate if search_key.startswith(candidate) else None

    # SDXL: convert SDXL Stability AI's U-Net modules to Diffusers
    def convert_unet_modules(self, modules_dim, modules_alpha):
        """Rename Stability AI style U-Net entries to Diffusers names, in place.

        Returns the number of converted entries. Raises AssertionError when
        some "lora_unet_" entries convert and others do not.
        """
        converted_count = 0
        not_converted_count = 0

        map_keys = sorted(UNET_CONVERSION_MAP.keys())
        unet_prefix = LoRANetwork.LORA_PREFIX_UNET + "_"

        for key in list(modules_dim.keys()):
            if not key.startswith(unet_prefix):
                continue
            map_key = self._find_conversion_prefix(key, map_keys)
            if map_key is None:
                not_converted_count += 1
                continue
            new_key = key.replace(map_key, UNET_CONVERSION_MAP[map_key])
            modules_dim[new_key] = modules_dim.pop(key)
            modules_alpha[new_key] = modules_alpha.pop(key)
            converted_count += 1

        assert (
            converted_count == 0 or not_converted_count == 0
        ), f"some modules are not converted: {converted_count} converted, {not_converted_count} not converted"
        return converted_count

    def set_multiplier(self, multiplier):
        """Set the multiplier on the network and on every LoRA module."""
        self.multiplier = multiplier
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.multiplier = self.multiplier

    def apply_to(self, multiplier=1.0, apply_text_encoder=True, apply_unet=True):
        """Call ``apply_to(multiplier)`` on the selected groups of LoRA modules."""
        if apply_text_encoder:
            print("enable LoRA for text encoder")
            for lora in self.text_encoder_loras:
                lora.apply_to(multiplier)
        if apply_unet:
            print("enable LoRA for U-Net")
            for lora in self.unet_loras:
                lora.apply_to(multiplier)

    def unapply_to(self):
        """Call ``unapply_to()`` on every LoRA module."""
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.unapply_to()

    def merge_to(self, multiplier=1.0):
        """Merge every LoRA module into its original weight."""
        print("merge LoRA weights to original weights")
        for lora in tqdm(self.text_encoder_loras + self.unet_loras):
            lora.merge_to(multiplier)
        print("weights are merged")

    def restore_from(self, multiplier=1.0):
        """Undo ``merge_to`` for every LoRA module."""
        print("restore LoRA weights from original weights")
        for lora in tqdm(self.text_encoder_loras + self.unet_loras):
            lora.restore_from(multiplier)
        print("weights are restored")

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        """Load LoRA weights, accepting Stability AI style U-Net key names.

        The mapping passed in is left untouched; renaming and reshaping are
        done on a copy. Returns whatever ``torch.nn.Module.load_state_dict``
        returns.
        """
        # Work on a copy: the argument is typed as a (possibly read-only)
        # Mapping and the caller may reuse it after this call.
        state_dict = dict(state_dict)

        # convert SDXL Stability AI's state dict to Diffusers' based state dict
        map_keys = sorted(UNET_CONVERSION_MAP.keys())  # prefix of U-Net modules
        unet_prefix = LoRANetwork.LORA_PREFIX_UNET + "_"
        for key in list(state_dict.keys()):
            if not key.startswith(unet_prefix):
                continue
            map_key = self._find_conversion_prefix(key, map_keys)
            if map_key is not None:
                new_key = key.replace(map_key, UNET_CONVERSION_MAP[map_key])
                state_dict[new_key] = state_dict.pop(key)

        # in case of V2, some weights have different shape, so we need to convert them
        # because V2 LoRA is based on U-Net created by use_linear_projection=False
        my_state_dict = self.state_dict()
        for key, value in state_dict.items():
            expected = my_state_dict.get(key)
            # Unknown keys are left for the parent implementation to report,
            # so that strict=False keeps working.
            if expected is not None and value.size() != expected.size():
                state_dict[key] = value.view(expected.size())

        return super().load_state_dict(state_dict, strict)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
|
@@ -1,11 +1,10 @@
|
|
| 1 |
-
accelerate==0.
|
| 2 |
-
diffusers==0.
|
| 3 |
-
gradio==4.
|
| 4 |
invisible-watermark==0.2.0
|
| 5 |
-
Pillow==10.
|
|
|
|
| 6 |
torch==2.0.1
|
| 7 |
-
transformers==4.
|
| 8 |
-
toml==0.10.2
|
| 9 |
omegaconf==2.3.0
|
| 10 |
timm==0.9.10
|
| 11 |
-
git+https://huggingface.co/spaces/Wauplin/gradio-user-history
|
|
|
|
| 1 |
+
accelerate==0.27.2
|
| 2 |
+
diffusers==0.26.3
|
| 3 |
+
gradio==4.20.0
|
| 4 |
invisible-watermark==0.2.0
|
| 5 |
+
Pillow==10.2.0
|
| 6 |
+
spaces==0.24.0
|
| 7 |
torch==2.0.1
|
| 8 |
+
transformers==4.38.1
|
|
|
|
| 9 |
omegaconf==2.3.0
|
| 10 |
timm==0.9.10
|
|
|
style.css
CHANGED
|
@@ -1,11 +1,6 @@
|
|
| 1 |
h1 {
|
| 2 |
text-align: center;
|
| 3 |
-
|
| 4 |
-
}
|
| 5 |
-
|
| 6 |
-
h2 {
|
| 7 |
-
text-align: center;
|
| 8 |
-
font-size: 10vw; /* relative to the viewport width */
|
| 9 |
}
|
| 10 |
|
| 11 |
#duplicate-button {
|
|
@@ -15,24 +10,12 @@ h2 {
|
|
| 15 |
border-radius: 100vh;
|
| 16 |
}
|
| 17 |
|
| 18 |
-
|
| 19 |
-
max-width:
|
| 20 |
margin: auto;
|
| 21 |
padding-top: 1.5rem;
|
| 22 |
}
|
| 23 |
|
| 24 |
-
/* You can also use media queries to adjust your style for different screen sizes */
|
| 25 |
-
@media (max-width: 600px) {
|
| 26 |
-
#component-0 {
|
| 27 |
-
max-width: 90%;
|
| 28 |
-
padding-top: 1rem;
|
| 29 |
-
}
|
| 30 |
-
}
|
| 31 |
-
|
| 32 |
-
#gallery .grid-wrap{
|
| 33 |
-
min-height: 25%;
|
| 34 |
-
}
|
| 35 |
-
|
| 36 |
#title-container {
|
| 37 |
display: flex;
|
| 38 |
justify-content: center;
|
|
@@ -43,18 +26,9 @@ h2 {
|
|
| 43 |
#title {
|
| 44 |
font-size: 3em;
|
| 45 |
text-align: center;
|
| 46 |
-
color: #333;
|
| 47 |
-
font-family: 'Helvetica Neue', sans-serif;
|
| 48 |
-
text-transform: uppercase;
|
| 49 |
background: transparent;
|
| 50 |
}
|
| 51 |
|
| 52 |
-
#title span {
|
| 53 |
-
background: -webkit-linear-gradient(45deg, #4EACEF, #28b485);
|
| 54 |
-
-webkit-background-clip: text;
|
| 55 |
-
-webkit-text-fill-color: transparent;
|
| 56 |
-
}
|
| 57 |
-
|
| 58 |
#subtitle {
|
| 59 |
text-align: center;
|
| 60 |
-
}
|
|
|
|
| 1 |
h1 {
|
| 2 |
text-align: center;
|
| 3 |
+
display: block;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
}
|
| 5 |
|
| 6 |
#duplicate-button {
|
|
|
|
| 10 |
border-radius: 100vh;
|
| 11 |
}
|
| 12 |
|
| 13 |
+
.gradio-container {
|
| 14 |
+
max-width: 730px !important;
|
| 15 |
margin: auto;
|
| 16 |
padding-top: 1.5rem;
|
| 17 |
}
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
#title-container {
|
| 20 |
display: flex;
|
| 21 |
justify-content: center;
|
|
|
|
| 26 |
#title {
|
| 27 |
font-size: 3em;
|
| 28 |
text-align: center;
|
|
|
|
|
|
|
|
|
|
| 29 |
background: transparent;
|
| 30 |
}
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
#subtitle {
|
| 33 |
text-align: center;
|
| 34 |
+
}
|
utils.py
CHANGED
|
@@ -1,7 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
def is_google_colab():
    """Return True when the ``google.colab`` package can be imported."""
    try:
        import google.colab  # noqa: F401  (imported only to probe availability)
    except ImportError:
        # Catch only a failed import; a bare ``except`` would also swallow
        # KeyboardInterrupt and SystemExit.
        return False
    return True
|
|
|
|
| 1 |
+
import gc
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
import numpy as np
|
| 5 |
+
import json
|
| 6 |
+
import torch
|
| 7 |
+
from PIL import Image, PngImagePlugin
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
from typing import Callable, Dict, Optional, Tuple
|
| 11 |
+
from diffusers import (
|
| 12 |
+
DDIMScheduler,
|
| 13 |
+
DPMSolverMultistepScheduler,
|
| 14 |
+
DPMSolverSinglestepScheduler,
|
| 15 |
+
EulerAncestralDiscreteScheduler,
|
| 16 |
+
EulerDiscreteScheduler,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
MAX_SEED = np.iinfo(np.int32).max
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
class StyleConfig:
    """Pair of prompt strings: a positive prompt and a negative prompt."""

    # positive prompt text
    prompt: str
    # negative prompt text
    negative_prompt: str
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return ``seed``, or a random value in [0, MAX_SEED] when ``randomize_seed`` is true."""
    return random.randint(0, MAX_SEED) if randomize_seed else seed
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def seed_everything(seed: int) -> torch.Generator:
    """Seed torch (CPU and all CUDA devices) and NumPy with ``seed``.

    Returns a new ``torch.Generator`` seeded with the same value.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    return torch.Generator().manual_seed(seed)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def parse_aspect_ratio(aspect_ratio: str) -> Optional[Tuple[int, int]]:
    """Parse a "<width> x <height>" string into ``(width, height)``.

    Returns None for the special value "Custom".
    """
    if aspect_ratio != "Custom":
        width_text, height_text = aspect_ratio.split(" x ")
        return int(width_text), int(height_text)
    return None
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def aspect_ratio_handler(
    aspect_ratio: str, custom_width: int, custom_height: int
) -> Tuple[int, int]:
    """Resolve the output size for the chosen aspect ratio.

    "Custom" yields ``(custom_width, custom_height)``; any other value is
    parsed with ``parse_aspect_ratio``.
    """
    if aspect_ratio == "Custom":
        return custom_width, custom_height
    parsed_width, parsed_height = parse_aspect_ratio(aspect_ratio)
    return parsed_width, parsed_height
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def get_scheduler(scheduler_config: Dict, name: str) -> Optional[Callable]:
    """Build the scheduler registered under ``name`` from ``scheduler_config``.

    Returns None when ``name`` is not one of the known sampler names.
    """
    if name == "DPM++ 2M Karras":
        return DPMSolverMultistepScheduler.from_config(
            scheduler_config, use_karras_sigmas=True
        )
    if name == "DPM++ SDE Karras":
        return DPMSolverSinglestepScheduler.from_config(
            scheduler_config, use_karras_sigmas=True
        )
    if name == "DPM++ 2M SDE Karras":
        return DPMSolverMultistepScheduler.from_config(
            scheduler_config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
        )
    if name == "Euler":
        return EulerDiscreteScheduler.from_config(scheduler_config)
    if name == "Euler a":
        return EulerAncestralDiscreteScheduler.from_config(scheduler_config)
    if name == "DDIM":
        return DDIMScheduler.from_config(scheduler_config)
    return None
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def free_memory() -> None:
    """Run the garbage collector, then release PyTorch's cached CUDA memory."""
    # Collect first: memory held by objects that are only freed by the
    # collector cannot be returned by empty_cache() until they are gone.
    gc.collect()
    torch.cuda.empty_cache()
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def preprocess_prompt(
    style_dict,
    style_name: str,
    positive: str,
    negative: str = "",
    add_style: bool = True,
) -> Tuple[str, str]:
    """Apply a style's prompt template and negative prompt.

    ``style_dict`` maps a style name to a ``(template, negative)`` pair and
    must contain the key "(None)", which is used for unknown style names.
    The template is formatted with ``prompt=positive`` only when ``add_style``
    is true and ``positive`` is not blank. A non-blank ``negative`` is appended
    to the style's negative prompt, separated by ", ".

    Returns:
        ``(positive_prompt, negative_prompt)``.
    """
    template, style_negative = style_dict.get(style_name, style_dict["(None)"])

    use_template = add_style and positive.strip()
    formatted_positive = template.format(prompt=positive) if use_template else positive

    if not negative.strip():
        return formatted_positive, style_negative
    if style_negative:
        return formatted_positive, style_negative + ", " + negative
    return formatted_positive, negative
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def common_upscale(
    samples: torch.Tensor,
    width: int,
    height: int,
    upscale_method: str,
) -> torch.Tensor:
    """Resize ``samples`` to ``(height, width)`` with ``torch.nn.functional.interpolate``.

    ``upscale_method`` is passed through as the interpolation ``mode``.
    """
    target_size = (height, width)
    return torch.nn.functional.interpolate(samples, size=target_size, mode=upscale_method)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def upscale(
    samples: torch.Tensor, upscale_method: str, scale_by: float
) -> torch.Tensor:
    """Resize ``samples`` by the factor ``scale_by``.

    The target size is dimension 3 (used as width) and dimension 2 (used as
    height) of ``samples``, each multiplied by ``scale_by`` and rounded.
    """
    target_width = round(samples.shape[3] * scale_by)
    target_height = round(samples.shape[2] * scale_by)
    return common_upscale(samples, target_width, target_height, upscale_method)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def load_wildcard_files(wildcard_dir: str) -> Dict[str, str]:
    """Map ``__<name>__`` placeholders to the ``.txt`` files in ``wildcard_dir``.

    Args:
        wildcard_dir: Directory to scan (not recursive).

    Returns:
        A dict such as ``{"__character__": "<wildcard_dir>/character.txt"}``.
        Files without a ``.txt`` extension are ignored.

    Raises:
        OSError: If ``wildcard_dir`` cannot be listed.
    """
    wildcard_files = {}
    for file_name in os.listdir(wildcard_dir):
        # splitext removes only the final extension, so "hair.color.txt"
        # becomes "__hair.color__" instead of colliding with "hair.txt".
        stem, extension = os.path.splitext(file_name)
        if extension == ".txt":
            wildcard_files[f"__{stem}__"] = os.path.join(wildcard_dir, file_name)
    return wildcard_files
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def get_random_line_from_file(file_path: str) -> str:
    """Return one randomly chosen non-blank line of ``file_path``, stripped.

    The file is read as UTF-8 so the result does not depend on the platform's
    default encoding. Blank lines are never chosen; an empty string is
    returned only when the file has no non-blank line.
    """
    with open(file_path, "r", encoding="utf-8") as file:
        candidates = [line.strip() for line in file if line.strip()]
    if not candidates:
        return ""
    return random.choice(candidates)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def add_wildcard(prompt: str, wildcard_files: Dict[str, str]) -> str:
    """Replace each wildcard placeholder found in ``prompt`` with a random line.

    For every placeholder in ``wildcard_files`` that occurs in ``prompt``, one
    line is drawn from its file and substituted for all of its occurrences.
    """
    for placeholder, source_path in wildcard_files.items():
        if placeholder not in prompt:
            continue
        replacement = get_random_line_from_file(source_path)
        prompt = prompt.replace(placeholder, replacement)
    return prompt
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def preprocess_image_dimensions(width, height):
    """Round ``width`` and ``height`` down to the nearest multiple of 8."""
    return width - width % 8, height - height % 8
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def save_image(image, metadata, output_dir):
    """Save ``image`` as a PNG in ``output_dir`` and return the file path.

    ``metadata`` is serialized with ``json.dumps`` and stored in a PNG text
    chunk named "metadata". ``output_dir`` is created when missing. The file
    is named ``image_<YYYYmmdd_HHMMSS>.png``; when that name is already taken
    a numeric suffix (``_1``, ``_2``, ...) is appended so that images saved
    within the same second do not overwrite each other.
    """
    current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
    os.makedirs(output_dir, exist_ok=True)

    filepath = os.path.join(output_dir, f"image_{current_time}.png")
    suffix = 1
    while os.path.exists(filepath):
        # The timestamp has one-second resolution, so pick the next free name.
        filepath = os.path.join(output_dir, f"image_{current_time}_{suffix}.png")
        suffix += 1

    metadata_str = json.dumps(metadata)
    info = PngImagePlugin.PngInfo()
    info.add_text("metadata", metadata_str)
    image.save(filepath, "PNG", pnginfo=info)
    return filepath
|
| 172 |
+
|
| 173 |
+
|
| 174 |
def is_google_colab():
    """Return True when the ``google.colab`` package can be imported."""
    try:
        import google.colab  # noqa: F401  (imported only to probe availability)
    except ImportError:
        # Catch only a failed import; a bare ``except`` would also swallow
        # KeyboardInterrupt and SystemExit.
        return False
    return True
|