Spaces:
Paused
Paused
""" | |
Wrapper class to call the stablediffusion.cpp shared library for GGUF support | |
""" | |
import ctypes | |
import platform | |
from ctypes import ( | |
POINTER, | |
c_bool, | |
c_char_p, | |
c_float, | |
c_int, | |
c_int64, | |
c_void_p, | |
) | |
from dataclasses import dataclass | |
from os import path | |
from typing import List, Any | |
import numpy as np | |
from PIL import Image | |
from backend.gguf.sdcpp_types import ( | |
RngType, | |
SampleMethod, | |
Schedule, | |
SDCPPLogLevel, | |
SDImage, | |
SdType, | |
) | |
@dataclass
class ModelConfig:
    """Model file paths and context options for stable-diffusion.cpp.

    Field order mirrors the argument order of the library's ``new_sd_ctx``
    call made by ``GGUFDiffusion``. Path fields left as empty strings are
    sent to the library as empty C strings.
    """

    model_path: str = ""
    clip_l_path: str = ""
    t5xxl_path: str = ""
    diffusion_model_path: str = ""
    vae_path: str = ""
    taesd_path: str = ""
    control_net_path: str = ""
    lora_model_dir: str = ""
    embed_dir: str = ""
    stacked_id_embed_dir: str = ""
    vae_decode_only: bool = True
    vae_tiling: bool = False
    free_params_immediately: bool = False
    n_threads: int = 4
    wtype: SdType = SdType.SD_TYPE_Q4_0  # weight type requested from the library
    rng_type: RngType = RngType.CUDA_RNG
    schedule: Schedule = Schedule.DEFAULT
    keep_clip_on_cpu: bool = False
    keep_control_net_cpu: bool = False
    keep_vae_on_cpu: bool = False
@dataclass
class Txt2ImgConfig:
    """Parameters for one text-to-image call into stable-diffusion.cpp.

    Field order mirrors the arguments of the library's ``txt2img`` call
    (after the context pointer) as made by ``GGUFDiffusion``.
    """

    prompt: str = "a man wearing sun glasses, highly detailed"
    negative_prompt: str = ""
    clip_skip: int = -1
    cfg_scale: float = 2.0
    guidance: float = 3.5
    width: int = 512
    height: int = 512
    sample_method: SampleMethod = SampleMethod.EULER_A
    sample_steps: int = 1
    seed: int = -1
    batch_count: int = 2  # number of images requested and read back
    # Passed unchanged to a POINTER(SDImage) argument: must be None (NULL)
    # or a ctypes pointer to SDImage; a PIL image is not accepted by ctypes.
    control_cond: Any = None
    control_strength: float = 0.90
    style_strength: float = 0.5
    normalize_input: bool = False
    input_id_images_path: bytes = b""  # sent as a C string (c_char_p)
class GGUFDiffusion:
    """GGUF diffusion pipeline backed by the stable-diffusion.cpp shared library.

    Loads the platform-specific shared library through ctypes, creates a
    native context from a ``ModelConfig`` and runs text-to-image generation.
    GGUF format: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md
    Implemented based on stable-diffusion.h.
    """

    def __init__(
        self,
        libpath: str,
        config: ModelConfig,
        logging_enabled: bool = False,
    ):
        """Load the library, validate model files and create a native context.

        Args:
            libpath: Directory containing the stable-diffusion shared library.
            config: Model file paths and context options.
            logging_enabled: If True, library log messages are printed.

        Raises:
            ValueError: If the platform is unsupported, the library cannot be
                loaded, a required model file is missing, or the library
                returns a NULL context.
        """
        sdcpp_shared_lib_path = self._get_sdcpp_shared_lib_path(libpath)
        if not sdcpp_shared_lib_path:
            raise ValueError(
                "Unsupported platform, cannot locate the stable-diffusion library"
            )
        try:
            self.libsdcpp = ctypes.CDLL(sdcpp_shared_lib_path)
        except OSError as e:
            print(f"Failed to load library {sdcpp_shared_lib_path}")
            raise ValueError(f"Error: {e}") from e

        self._validate_model_paths(config)
        self.model_config = config

        # sd_set_log_callback takes no context argument, so registering it
        # before new_sd_ctx also forwards messages logged while loading models.
        if logging_enabled:
            self._set_logcallback()

        self.libsdcpp.new_sd_ctx.argtypes = [
            c_char_p,  # const char* model_path
            c_char_p,  # const char* clip_l_path
            c_char_p,  # const char* t5xxl_path
            c_char_p,  # const char* diffusion_model_path
            c_char_p,  # const char* vae_path
            c_char_p,  # const char* taesd_path
            c_char_p,  # const char* control_net_path_c_str
            c_char_p,  # const char* lora_model_dir
            c_char_p,  # const char* embed_dir_c_str
            c_char_p,  # const char* stacked_id_embed_dir_c_str
            c_bool,  # bool vae_decode_only
            c_bool,  # bool vae_tiling
            c_bool,  # bool free_params_immediately
            c_int,  # int n_threads
            SdType,  # enum sd_type_t wtype
            RngType,  # enum rng_type_t rng_type
            Schedule,  # enum schedule_t s
            c_bool,  # bool keep_clip_on_cpu
            c_bool,  # bool keep_control_net_cpu
            c_bool,  # bool keep_vae_on_cpu
        ]
        self.libsdcpp.new_sd_ctx.restype = POINTER(c_void_p)
        cfg = self.model_config
        self.sd_ctx = self.libsdcpp.new_sd_ctx(
            self._str_to_bytes(cfg.model_path),
            self._str_to_bytes(cfg.clip_l_path),
            self._str_to_bytes(cfg.t5xxl_path),
            self._str_to_bytes(cfg.diffusion_model_path),
            self._str_to_bytes(cfg.vae_path),
            self._str_to_bytes(cfg.taesd_path),
            self._str_to_bytes(cfg.control_net_path),
            self._str_to_bytes(cfg.lora_model_dir),
            self._str_to_bytes(cfg.embed_dir),
            self._str_to_bytes(cfg.stacked_id_embed_dir),
            cfg.vae_decode_only,
            cfg.vae_tiling,
            cfg.free_params_immediately,
            cfg.n_threads,
            cfg.wtype,
            cfg.rng_type,
            cfg.schedule,
            cfg.keep_clip_on_cpu,
            cfg.keep_control_net_cpu,
            cfg.keep_vae_on_cpu,
        )
        # A NULL ctypes pointer is falsy: fail now rather than hand a NULL
        # context to txt2img later.
        if not self.sd_ctx:
            raise ValueError(
                "Failed to create stable diffusion context, please check the model files"
            )

    @staticmethod
    def _validate_model_paths(config: ModelConfig) -> None:
        """Raise ValueError if a model file required for GGUF diffusion is missing."""
        required_files = (
            (config.clip_l_path, "CLIP"),
            (config.t5xxl_path, "T5XXL"),
            (config.diffusion_model_path, "Diffusion"),
            (config.vae_path, "VAE"),
        )
        for file_path, label in required_files:
            if not file_path or not path.exists(file_path):
                raise ValueError(
                    f"{label} model file not found,please check readme.md for GGUF model usage"
                )

    def _set_logcallback(self):
        """Register ``log_callback`` as the library's log sink."""
        print("Setting logging callback")
        # C signature: void (*)(enum sd_log_level_t, const char*, void*)
        SdLogCallbackType = ctypes.CFUNCTYPE(
            None,
            SDCPPLogLevel,
            ctypes.c_char_p,
            ctypes.c_void_p,
        )
        self.libsdcpp.sd_set_log_callback.argtypes = [
            SdLogCallbackType,
            ctypes.c_void_p,
        ]
        self.libsdcpp.sd_set_log_callback.restype = None
        # Keep the C function pointer on the instance so it is not garbage
        # collected while the library may still call it.
        self.c_log_callback = SdLogCallbackType(self.log_callback)
        self.libsdcpp.sd_set_log_callback(self.c_log_callback, None)

    def _get_sdcpp_shared_lib_path(
        self,
        root_path: str,
    ) -> str:
        """Return the shared-library path for this OS under ``root_path``.

        Returns an empty string on an unrecognized platform.
        """
        system_name = platform.system()
        print(f"GGUF Diffusion on {system_name}")
        lib_names = {
            "Windows": "stable-diffusion.dll",
            "Linux": "libstable-diffusion.so",
            "Darwin": "libstable-diffusion.dylib",
        }
        lib_name = lib_names.get(system_name)
        if lib_name is None:
            print("Unknown platform.")
            return ""
        return path.join(root_path, lib_name)

    @staticmethod
    def log_callback(
        level,
        text,
        data,
    ):
        """Print one library log message.

        Called from C with (level, text, user data); ``text`` arrives as bytes
        (or None for a NULL pointer). Must be a staticmethod: as a bound method
        the extra ``self`` argument made every invocation fail.
        """
        if text:
            print(text.decode("utf-8", errors="replace"), end="")

    def _str_to_bytes(self, in_str: str, encoding: str = "utf-8") -> bytes:
        """Encode ``in_str`` for a c_char_p argument; falsy input becomes b""."""
        if in_str:
            return in_str.encode(encoding)
        return b""

    def generate_text2mg(self, txt2img_cfg: Txt2ImgConfig) -> List[Any]:
        """Run txt2img with ``txt2img_cfg`` and return PIL images.

        ``control_cond`` must be None or a ctypes pointer to SDImage.
        ``input_id_images_path`` may be given as bytes or str.

        Returns:
            ``batch_count`` PIL images, or an empty list if the library
            returned a NULL buffer.
        """
        self.libsdcpp.txt2img.restype = POINTER(SDImage)
        self.libsdcpp.txt2img.argtypes = [
            c_void_p,  # sd_ctx_t* sd_ctx (pointer to context object)
            c_char_p,  # const char* prompt
            c_char_p,  # const char* negative_prompt
            c_int,  # int clip_skip
            c_float,  # float cfg_scale
            c_float,  # float guidance
            c_int,  # int width
            c_int,  # int height
            SampleMethod,  # enum sample_method_t sample_method
            c_int,  # int sample_steps
            c_int64,  # int64_t seed
            c_int,  # int batch_count
            POINTER(SDImage),  # const sd_image_t* control_cond (pointer to SDImage)
            c_float,  # float control_strength
            c_float,  # float style_strength
            c_bool,  # bool normalize_input
            c_char_p,  # const char* input_id_images_path
        ]
        input_id_images_path = txt2img_cfg.input_id_images_path
        if isinstance(input_id_images_path, str):
            # c_char_p only accepts bytes (or None).
            input_id_images_path = self._str_to_bytes(input_id_images_path)
        image_buffer = self.libsdcpp.txt2img(
            self.sd_ctx,
            self._str_to_bytes(txt2img_cfg.prompt),
            self._str_to_bytes(txt2img_cfg.negative_prompt),
            txt2img_cfg.clip_skip,
            txt2img_cfg.cfg_scale,
            txt2img_cfg.guidance,
            txt2img_cfg.width,
            txt2img_cfg.height,
            txt2img_cfg.sample_method,
            txt2img_cfg.sample_steps,
            txt2img_cfg.seed,
            txt2img_cfg.batch_count,
            txt2img_cfg.control_cond,
            txt2img_cfg.control_strength,
            txt2img_cfg.style_strength,
            txt2img_cfg.normalize_input,
            input_id_images_path,
        )
        return self._get_sd_images_from_buffer(
            image_buffer,
            txt2img_cfg.batch_count,
        )

    def _get_sd_images_from_buffer(
        self,
        image_buffer: Any,
        batch_count: int,
    ) -> List[Any]:
        """Convert the first ``batch_count`` SDImage entries into PIL images.

        Raises:
            ValueError: If an image has a channel count other than 1, 3 or 4.
        """
        pil_modes = {1: "L", 3: "RGB", 4: "RGBA"}
        images = []
        if not image_buffer:  # NULL pointer returned by the library
            return images
        # NOTE(review): the SDImage array and each image's data buffer are
        # never released here; if the library allocates them per call they
        # leak — confirm and free them with the library's own allocator.
        for i in range(batch_count):
            image = image_buffer[i]
            width = image.width
            height = image.height
            channels = image.channel
            print(
                f"Generated image: {width}x{height} with {channels} channels"
            )
            mode = pil_modes.get(channels)
            if mode is None:
                raise ValueError(f"Unsupported number of channels: {channels}")
            # Copy out of library-owned memory: Image.fromarray may share the
            # array's buffer for some modes (e.g. "L", "RGBA").
            pixel_data = np.ctypeslib.as_array(
                image.data, shape=(height, width, channels)
            ).copy()
            if channels == 1:
                # reshape, not squeeze, so 1-pixel-wide/high images stay 2-D.
                pixel_data = pixel_data.reshape(height, width)
            images.append(Image.fromarray(pixel_data, mode=mode))
        return images

    def terminate(self):
        """Free the native context and drop the library handle.

        Safe to call more than once, and after a partially failed ``__init__``.
        """
        libsdcpp = getattr(self, "libsdcpp", None)
        if not libsdcpp:
            return
        sd_ctx = getattr(self, "sd_ctx", None)
        if sd_ctx:
            libsdcpp.free_sd_ctx.argtypes = [c_void_p]
            libsdcpp.free_sd_ctx.restype = None
            libsdcpp.free_sd_ctx(sd_ctx)
        self.sd_ctx = None
        self.libsdcpp = None