import os
import folder_paths
import comfy.diffusers_load
import comfy.samplers
import comfy.sample
import comfy.utils
import comfy.controlnet
import comfy.clip_vision
import comfy.model_management
from comfy.cli_args import args
import torch
import torch.nn as nn
import numpy as np
import latent_preview
from PIL import Image
from einops import rearrange
import scipy.ndimage
import sys
import cv2
from magic_utils import HWC3, apply_color, common_input_validate, resize_image_with_pad
from pidi import pidinet
# File extensions accepted for model weight files.
supported_pt_extensions = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl'}

# Model folders are resolved relative to the directory containing this file.
base_path = os.path.dirname(os.path.realpath(__file__))
models_dir = os.path.join(base_path, "../models")

# Maps a model category to a pair: (directories to search, accepted extensions).
# Most categories share the same `supported_pt_extensions` set object.
folder_names_and_paths = {
    "checkpoints": ([os.path.join(models_dir, "checkpoints")], supported_pt_extensions),
    "configs": ([os.path.join(models_dir, "configs")], [".yaml"]),
    "loras": ([os.path.join(models_dir, "loras")], supported_pt_extensions),
    "vae": ([os.path.join(models_dir, "vae")], supported_pt_extensions),
    "clip": ([os.path.join(models_dir, "clip")], supported_pt_extensions),
    "unet": ([os.path.join(models_dir, "unet")], supported_pt_extensions),
    "clip_vision": ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions),
    "style_models": ([os.path.join(models_dir, "style_models")], supported_pt_extensions),
    "embeddings": ([os.path.join(models_dir, "embeddings")], supported_pt_extensions),
    "diffusers": ([os.path.join(models_dir, "diffusers")], ["folder"]),
    "vae_approx": ([os.path.join(models_dir, "vae_approx")], supported_pt_extensions),
    # Control nets are searched in two directories.
    "controlnet": (
        [os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")],
        supported_pt_extensions,
    ),
    "gligen": ([os.path.join(models_dir, "gligen")], supported_pt_extensions),
    "upscale_models": ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions),
    "hypernetworks": ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions),
    "photomaker": ([os.path.join(models_dir, "photomaker")], supported_pt_extensions),
    "classifiers": ([os.path.join(models_dir, "classifiers")], {""}),
}
def common_annotator_call(model, tensor_image, input_batch=False, show_pbar=True, **kwargs):
    """Run an annotator over an image tensor and return its output as a float tensor.

    Args:
        model: Callable invoked as ``model(np_image, output_type="np",
            detect_resolution=..., **kwargs)``; it must return a NumPy array.
        tensor_image: Image tensor. Values are multiplied by 255 and cast to
            uint8 before being handed to ``model`` (presumably floats in
            [0, 1] -- confirm against callers).
        input_batch: When true, pass the whole tensor to ``model`` in one call
            instead of iterating over the first dimension.
        show_pbar: When true, report per-image progress through
            ``comfy.utils.ProgressBar`` (non-batch path only).
        **kwargs: Forwarded to ``model``. ``resolution`` is consumed here and
            becomes ``detect_resolution``; a caller-supplied
            ``detect_resolution`` is discarded.

    Returns:
        A float32 tensor holding the annotator output divided by 255. In the
        per-image path the result is ``None`` when ``tensor_image`` is empty.
    """
    # A caller-supplied detect_resolution is always ignored; only `resolution`
    # is honoured.
    kwargs.pop("detect_resolution", None)

    # Fall back to 512 when the resolution is missing, not an int, or too small.
    resolution = kwargs.pop("resolution", None)
    if isinstance(resolution, int) and resolution >= 64:
        detect_resolution = resolution
    else:
        detect_resolution = 512

    if input_batch:
        np_images = np.asarray(tensor_image * 255., dtype=np.uint8)
        np_results = model(np_images, output_type="np", detect_resolution=detect_resolution, **kwargs)
        return torch.from_numpy(np_results.astype(np.float32) / 255.0)

    batch_size = tensor_image.shape[0]
    pbar = comfy.utils.ProgressBar(batch_size) if show_pbar else None

    out_tensor = None
    for i, image in enumerate(tensor_image):
        np_image = np.asarray(image.cpu() * 255., dtype=np.uint8)
        np_result = model(np_image, output_type="np", detect_resolution=detect_resolution, **kwargs)
        out = torch.from_numpy(np_result.astype(np.float32) / 255.0)
        if out_tensor is None:
            # The output shape is only known once the first image is processed.
            out_tensor = torch.zeros(batch_size, *out.shape, dtype=torch.float32)
        out_tensor[i] = out
        if pbar is not None:
            pbar.update(1)
    return out_tensor
class CheckpointLoaderSimple:
    """Loads a checkpoint file from the "checkpoints" model folder."""

    def load_checkpoint(self, ckpt_name):
        """Load the checkpoint named ``ckpt_name``.

        Args:
            ckpt_name: File name resolved through
                ``folder_paths.get_full_path("checkpoints", ...)``.

        Returns:
            The first three items of the loader result (presumably the model,
            CLIP and VAE objects -- confirm against the loader's contract).
        """
        # Imported here because the file's import header does not import
        # comfy.sd explicitly.
        import comfy.sd

        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
        print("Loading checkpoint from:", ckpt_path)
        # NOTE(review): the callee was missing from this statement; it is
        # restored as load_checkpoint_guess_config, whose keyword arguments
        # match the ones that survived -- confirm.
        out = comfy.sd.load_checkpoint_guess_config(
            ckpt_path,
            output_vae=True,
            output_clip=True,
            embedding_directory=folder_paths.get_folder_paths("embeddings"),
        )
        return out[:3]
class ControlNetLoader:
    """Loads a control net from the "controlnet" model folder."""

    def load_controlnet(self, control_net_name):
        """Resolve ``control_net_name`` to a path and load it.

        Returns:
            A one-element tuple holding the loaded control net.
        """
        resolved_path = folder_paths.get_full_path("controlnet", control_net_name)
        loaded = comfy.controlnet.load_controlnet(resolved_path)
        return (loaded,)
class ControlNetApplyAdvanced:
    """Attaches a control net to positive and negative conditioning."""

    def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent):
        """Return copies of both conditionings with ``control_net`` attached.

        Args:
            positive: Iterable of ``[cond, options_dict]`` entries.
            negative: Iterable of ``[cond, options_dict]`` entries.
            control_net: Object providing ``copy()``; the copy must provide
                ``set_cond_hint(hint, strength, (start, end))``.
            image: Hint image; its last dimension is moved to position 1.
            strength: Control strength. ``0`` returns the inputs untouched.
            start_percent: Start of the range passed to ``set_cond_hint``.
            end_percent: End of the range passed to ``set_cond_hint``.

        Returns:
            ``(positive, negative)`` with new entry lists. The option dicts are
            shallow copies, so the caller's dicts are not modified.
        """
        if strength == 0:
            return (positive, negative)

        control_hint = image.movedim(-1, 1)
        # Cache keyed by the control net already attached to an entry, so that
        # entries sharing one (including "none") also share the configured copy.
        # NOTE(review): the previous control net is used only as a cache key
        # here and is not chained onto the new one -- confirm this is intended.
        cnets = {}

        out = []
        for conditioning in [positive, negative]:
            c = []
            for t in conditioning:
                d = t[1].copy()

                prev_cnet = d.get('control', None)
                if prev_cnet in cnets:
                    c_net = cnets[prev_cnet]
                else:
                    c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent))
                    cnets[prev_cnet] = c_net

                d['control'] = c_net
                d['control_apply_to_uncond'] = False
                c.append([t[0], d])
            out.append(c)
        return (out[0], out[1])
class CLIPTextEncode:
    """Encodes a text prompt into a conditioning list using a CLIP object."""

    def encode(self, clip, text):
        """Tokenize ``text`` with ``clip`` and encode the tokens.

        Returns:
            A one-element tuple holding ``[[cond, {"pooled_output": pooled}]]``.
        """
        cond, pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled=True)
        conditioning = [[cond, {"pooled_output": pooled}]]
        return (conditioning,)
class KSampler:
    """Thin wrapper around ``comfy.sample.sample`` operating on latent dicts."""

    def common_ksampler(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
        """Sample from ``model`` starting at ``latent`` and return the new latent.

        Args:
            latent: Dict with a ``"samples"`` tensor and optional
                ``"batch_index"`` and ``"noise_mask"`` entries.
            disable_noise: When true, all-zero noise is used instead of noise
                generated from ``seed``.
            The remaining arguments are forwarded to ``comfy.sample.sample``.

        Returns:
            A one-element tuple holding a shallow copy of ``latent`` whose
            ``"samples"`` entry is replaced by the sampled result.
        """
        latent_image = latent["samples"]
        latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image)

        if disable_noise:
            noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
        else:
            # Without this branch the zero noise above was always overwritten.
            batch_inds = latent.get("batch_index")
            noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds)

        noise_mask = latent.get("noise_mask")

        callback = latent_preview.prepare_callback(model, steps)
        disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
        samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
                                      denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
                                      force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
        out = latent.copy()
        out["samples"] = samples
        return (out, )

    def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0):
        """Run ``common_ksampler`` with its default noise and step options."""
        return self.common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise)
class VAEDecode:
    """Decodes a latent dict with a VAE."""

    def decode(self, vae, samples):
        """Decode ``samples["samples"]`` with ``vae``.

        Returns:
            A one-element tuple holding the decoded result.
        """
        decoded = vae.decode(samples["samples"])
        return (decoded,)
class ColorDetector:
    """Annotator that builds a color map from an image via ``apply_color``."""

    def __call__(self, input_image=None, detect_resolution=2048, output_type=None, **kwargs):
        """Validate the input, apply ``apply_color`` and return the color map.

        The result is wrapped in a PIL image when the validated
        ``output_type`` equals ``"pil"``; otherwise the array is returned as is.
        """
        input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
        normalized = HWC3(input_image)
        color_map = HWC3(apply_color(normalized, detect_resolution))
        if output_type != "pil":
            return color_map
        return Image.fromarray(color_map)
class Color_Preprocessor:
    """Node wrapper that runs ``ColorDetector`` over an image tensor."""

    def execute(self, image, resolution=512, **kwargs):
        """Run the color annotator on ``image``.

        Extra ``kwargs`` are accepted but not forwarded.

        Returns:
            A one-element tuple holding the annotated tensor.
        """
        detector = ColorDetector()
        annotated = common_annotator_call(detector, image, resolution=resolution)
        return (annotated,)
# Normalization layer used by the generator building blocks below.
norm_layer = nn.InstanceNorm2d


class ResidualBlock(nn.Module):
    """Two padded 3x3 convolutions whose result is added to the input."""

    def __init__(self, in_features):
        """Build the block; input and output both have ``in_features`` channels."""
        super().__init__()

        # Each 3x3 convolution is preceded by a 1-pixel reflection pad so the
        # spatial size is preserved and the residual addition is shape-valid.
        # NOTE(review): the layer list was truncated; it is restored to the
        # pad/conv/norm/relu/pad/conv/norm layout so the convolutions sit at
        # indices 1 and 5 -- confirm against the pretrained checkpoints.
        conv_block = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features),
        ]

        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        """Return ``x`` plus the output of the convolution stack."""
        return x + self.conv_block(x)
class Generator(nn.Module):
    """Convolutional encoder / residual / decoder image-to-image network."""

    def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
        """Build the network.

        Args:
            input_nc: Number of input channels.
            output_nc: Number of output channels.
            n_residual_blocks: Number of ``ResidualBlock`` modules in the middle.
            sigmoid: When true, a ``Sigmoid`` is appended to the output stage.
        """
        super().__init__()

        # NOTE(review): normalization layers are restored after each
        # (transposed) convolution of the first three stages. ``norm_layer``
        # is defined at module level but was otherwise unused, and layer
        # indices inside each Sequential determine the state-dict keys --
        # confirm against the pretrained checkpoints.

        # Initial convolution block
        model0 = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, 7),
            norm_layer(64),
            nn.ReLU(inplace=True),
        ]
        self.model0 = nn.Sequential(*model0)

        # Downsampling: two stride-2 convolutions, doubling channels each time.
        model1 = []
        in_features = 64
        out_features = in_features * 2
        for _ in range(2):
            model1 += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                norm_layer(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features
            out_features = in_features * 2
        self.model1 = nn.Sequential(*model1)

        # Residual blocks
        model2 = [ResidualBlock(in_features) for _ in range(n_residual_blocks)]
        self.model2 = nn.Sequential(*model2)

        # Upsampling: two stride-2 transposed convolutions, halving channels.
        model3 = []
        out_features = in_features // 2
        for _ in range(2):
            model3 += [
                nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                norm_layer(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features
            out_features = in_features // 2
        self.model3 = nn.Sequential(*model3)

        # Output layer
        model4 = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, output_nc, 7),
        ]
        if sigmoid:
            model4 += [nn.Sigmoid()]
        self.model4 = nn.Sequential(*model4)

    def forward(self, x, cond=None):
        """Run ``x`` through the five stages in order; ``cond`` is unused."""
        out = self.model0(x)
        out = self.model1(out)
        out = self.model2(out)
        out = self.model3(out)
        out = self.model4(out)
        return out
class LineartDetector:
    """Line-art annotator backed by a fine and a coarse ``Generator`` network."""

    def __init__(self, model, coarse_model):
        """Store the two networks; the detector starts on the CPU."""
        self.model = model
        self.model_coarse = coarse_model
        self.device = "cpu"

    @classmethod
    def from_pretrained(cls):
        """Build a detector from the weight files under ``../models/preprocessor``.

        Both networks are created as ``Generator(3, 1, 3)`` and loaded onto
        the CPU.
        """
        current_dir = os.path.dirname(os.path.abspath(__file__))
        model_path = os.path.join(current_dir, "../models/preprocessor/sk_model.pth")
        coarse_model_path = os.path.join(current_dir, "../models/preprocessor/sk_model2.pth")

        # NOTE: torch.load unpickles the file; only load trusted weight files.
        model = Generator(3, 1, 3)
        model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
        model.eval()

        coarse_model = Generator(3, 1, 3)
        coarse_model.load_state_dict(torch.load(coarse_model_path, map_location=torch.device('cpu')))
        coarse_model.eval()

        return cls(model, coarse_model)

    def to(self, device):
        """Move both networks to ``device`` and remember it; returns ``self``.

        ``__call__`` places its input on ``self.device``, so the networks
        have to live on the same device.
        """
        self.model.to(device)
        self.model_coarse.to(device)
        self.device = device
        return self

    def __call__(self, input_image, coarse=False, detect_resolution=512, output_type="pil", upscale_method="INTER_CUBIC", **kwargs):
        """Produce a line-art map for ``input_image``.

        Args:
            input_image: Image accepted by ``common_input_validate``.
            coarse: Use the coarse network instead of the fine one.
            detect_resolution: Resolution passed to ``resize_image_with_pad``.
            output_type: ``"pil"`` returns a PIL image; any other validated
                value returns the array.
            upscale_method: Passed to ``resize_image_with_pad``.
            **kwargs: Forwarded to ``common_input_validate``.

        Returns:
            The inverted (``255 - value``) line map with padding removed.
        """
        input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
        detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)

        model = self.model_coarse if coarse else self.model
        assert detected_map.ndim == 3
        with torch.no_grad():
            image = torch.from_numpy(detected_map).float().to(self.device)
            image = image / 255.0
            image = rearrange(image, 'h w c -> 1 c h w')
            # First batch item, first channel of the network output.
            line = model(image)[0][0]

            line = line.cpu().numpy()
            line = (line * 255.0).clip(0, 255).astype(np.uint8)

        detected_map = HWC3(line)
        detected_map = remove_pad(255 - detected_map)

        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)

        return detected_map
class LineArt_Preprocessor:
    """Node wrapper that runs ``LineartDetector`` over an image tensor."""

    def execute(self, image, resolution=512, **kwargs):
        """Run the line-art annotator on ``image``.

        Args:
            image: Image tensor handed to ``common_annotator_call``.
            resolution: Requested detection resolution.
            **kwargs: ``coarse="enable"`` selects the coarse network; any
                other value, or omitting it, selects the fine one.

        Returns:
            A one-element tuple holding the annotated tensor.
        """
        model = LineartDetector.from_pretrained().to(comfy.model_management.get_torch_device())
        print("model.device:", model.device)
        # .get() so that omitting the optional "coarse" option does not raise.
        use_coarse = kwargs.get("coarse") == "enable"
        out = common_annotator_call(model, image, resolution=resolution, apply_filter=False, coarse=use_coarse)
        del model
        return (out, )
def nms(x, t, s):
    """Threshold the directional local maxima of a blurred map.

    Args:
        x: Input array; converted to float32 and Gaussian-blurred with sigma ``s``.
        t: Threshold applied to the retained maxima.
        s: Sigma of the Gaussian blur.

    Returns:
        A uint8 array that is 255 where a retained maximum exceeds ``t`` and
        0 elsewhere.
    """
    blurred = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)

    # 3x3 line-shaped kernels: horizontal, vertical and the two diagonals.
    line_kernels = [
        np.array(rows, dtype=np.uint8)
        for rows in (
            [[0, 0, 0], [1, 1, 1], [0, 0, 0]],
            [[0, 1, 0], [0, 1, 0], [0, 1, 0]],
            [[1, 0, 0], [0, 1, 0], [0, 0, 1]],
            [[0, 0, 1], [0, 1, 0], [1, 0, 0]],
        )
    ]

    # Keep a pixel when it equals its dilation along at least one direction.
    maxima = np.zeros_like(blurred)
    for kernel in line_kernels:
        is_peak = cv2.dilate(blurred, kernel=kernel) == blurred
        np.putmask(maxima, is_peak, blurred)

    binary = np.zeros_like(maxima, dtype=np.uint8)
    binary[maxima > t] = 255
    return binary
class PidiNetDetector:
    """Edge annotator backed by a ``pidinet`` network."""

    def __init__(self, netNetwork):
        """Store the network; the detector starts on the CPU."""
        self.netNetwork = netNetwork
        self.device = "cpu"

    @classmethod
    def from_pretrained(cls, filename="table5_pidinet.pth"):
        """Build a detector from ``../models/preprocessor/<filename>``.

        The checkpoint's ``'state_dict'`` entry is loaded after removing
        ``'module.'`` from every key.
        """
        current_dir = os.path.dirname(os.path.abspath(__file__))
        model_path = os.path.join(current_dir, f"../models/preprocessor/{filename}")

        netNetwork = pidinet()
        # map_location keeps loading independent of the device the checkpoint
        # was saved from, matching LineartDetector.from_pretrained.
        # NOTE: torch.load unpickles the file; only load trusted weight files.
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
        netNetwork.load_state_dict({k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items()})
        netNetwork.eval()

        return cls(netNetwork)

    def to(self, device):
        """Move the network to ``device`` and remember it; returns ``self``.

        ``__call__`` places its input on ``self.device``, so the network has
        to live on the same device.
        """
        self.netNetwork.to(device)
        self.device = device
        return self

    def __call__(self, input_image, detect_resolution=512, safe=False, output_type="pil", scribble=False, apply_filter=True, upscale_method="INTER_CUBIC", **kwargs):
        """Produce an edge map for ``input_image``.

        Args:
            input_image: Image accepted by ``common_input_validate``.
            detect_resolution: Resolution passed to ``resize_image_with_pad``.
            safe: Accepted for interface compatibility; not used in this body.
            output_type: ``"pil"`` returns a PIL image; any other validated
                value returns the array.
            scribble: Post-process the edges with ``nms``, a blur and a
                binarization.
            apply_filter: Binarize the network output at 0.5 before scaling.
            upscale_method: Passed to ``resize_image_with_pad``.
            **kwargs: Forwarded to ``common_input_validate``.
        """
        input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
        detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)

        # Reverse the channel order (presumably RGB -> BGR for the network -- confirm).
        detected_map = detected_map[:, :, ::-1].copy()
        with torch.no_grad():
            image_pidi = torch.from_numpy(detected_map).float().to(self.device)
            image_pidi = image_pidi / 255.0
            image_pidi = rearrange(image_pidi, 'h w c -> 1 c h w')
            # The last element of the network output is used as the edge map.
            edge = self.netNetwork(image_pidi)[-1]
            edge = edge.cpu().numpy()
            if apply_filter:
                edge = edge > 0.5
            edge = (edge * 255.0).clip(0, 255).astype(np.uint8)

        detected_map = edge[0, 0]

        if scribble:
            detected_map = nms(detected_map, 127, 3.0)
            detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0)
            # Binarize: anything above 4 becomes 255, everything else 0.
            detected_map[detected_map > 4] = 255
            detected_map[detected_map < 255] = 0

        detected_map = HWC3(remove_pad(detected_map))

        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)

        return detected_map
class GrowMask:
    """Grows or shrinks masks with repeated 3x3 grey-scale morphology."""

    def expand_mask(self, mask, expand, tapered_corners):
        """Dilate (``expand > 0``) or erode (``expand < 0``) each mask.

        Args:
            mask: Tensor whose last two dimensions are the mask plane; all
                leading dimensions are flattened into one batch dimension.
            expand: Number of 3x3 steps. Positive dilates, negative erodes,
                zero leaves the values unchanged.
            tapered_corners: When true the footprint omits its four corners
                (plus shape); otherwise the full 3x3 square is used.

        Returns:
            A one-element tuple holding the processed masks stacked along
            dimension 0.
        """
        c = 0 if tapered_corners else 1
        kernel = np.array([[c, 1, c],
                           [1, 1, 1],
                           [c, 1, c]])
        mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))

        # Pick the operation once; the sign of `expand` does not change per step.
        morph = scipy.ndimage.grey_erosion if expand < 0 else scipy.ndimage.grey_dilation

        out = []
        for m in mask:
            output = m.cpu().numpy()
            for _ in range(abs(expand)):
                output = morph(output, footprint=kernel)
            out.append(torch.from_numpy(output))
        return (torch.stack(out, dim=0),)
class PIDINET_Preprocessor:
    """Node wrapper that runs ``PidiNetDetector`` over an image tensor."""

    def execute(self, image, resolution=512, **kwargs):
        """Run the PiDiNet annotator on ``image`` with ``safe=True``.

        Extra ``kwargs`` are accepted but not forwarded.

        Returns:
            A one-element tuple holding the annotated tensor.
        """
        detector = PidiNetDetector.from_pretrained()
        detector = detector.to(comfy.model_management.get_torch_device())
        annotated = common_annotator_call(detector, image, resolution=resolution, safe=True)
        # Drop the local reference to the detector before returning.
        del detector
        return (annotated,)