from __future__ import annotations

import gc

import numpy as np
import PIL.Image
import torch
from controlnet_aux.util import HWC3
from diffusers import (ControlNetModel, DiffusionPipeline,
                       StableDiffusionControlNetPipeline,
                       UniPCMultistepScheduler)

from cv_utils import resize_image
from preprocessor import Preprocessor
CONTROLNET_MODEL_IDS = {
    'Openpose': 'lllyasviel/control_v11p_sd15_openpose',
    'Canny': 'lllyasviel/control_v11p_sd15_canny',
    'MLSD': 'lllyasviel/control_v11p_sd15_mlsd',
    'scribble': 'lllyasviel/control_v11p_sd15_scribble',
    'softedge': 'lllyasviel/control_v11p_sd15_softedge',
    'segmentation': 'lllyasviel/control_v11p_sd15_seg',
    'depth': 'lllyasviel/control_v11f1p_sd15_depth',
    'NormalBae': 'lllyasviel/control_v11p_sd15_normalbae',
    'lineart': 'lllyasviel/control_v11p_sd15_lineart',
    'lineart_anime': 'lllyasviel/control_v11s2_sd15_lineart_anime'
    if False else 'lllyasviel/control_v11p_sd15s2_lineart_anime',
    'shuffle': 'lllyasviel/control_v11e_sd15_shuffle',
    'ip2p': 'lllyasviel/control_v11e_sd15_ip2p',
    'inpaint': 'lllyasviel/control_v11e_sd15_inpaint',
}
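# The values above are the ControlNet 1.1 checkpoints published by
# lllyasviel on the Hugging Face Hub, keyed by the task names this app uses
# in its UI and in load_controlnet_weight() below.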

def download_all_controlnet_weights() -> None:
    for model_id in CONTROLNET_MODEL_IDS.values():
        ControlNetModel.from_pretrained(model_id)
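# Usage sketch (not part of the original file): prefetching all weights once,
# e.g. in a container build step, avoids paying the download cost on the
# first request. The module name `model` below is an assumption about how
# this file is saved:
#
#   python -c 'from model import download_all_controlnet_weights; download_all_controlnet_weights()'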

class Model:
    def __init__(self,
                 base_model_id: str = 'runwayml/stable-diffusion-v1-5',
                 task_name: str = 'Canny'):
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        self.base_model_id = ''
        self.task_name = ''
        self.pipe = self.load_pipe(base_model_id, task_name)
        self.preprocessor = Preprocessor()
    def load_pipe(self, base_model_id: str,
                  task_name: str) -> DiffusionPipeline:
        if base_model_id == self.base_model_id and task_name == self.task_name and hasattr(
                self, 'pipe') and self.pipe is not None:
            return self.pipe
        model_id = CONTROLNET_MODEL_IDS[task_name]
        controlnet = ControlNetModel.from_pretrained(model_id,
                                                     torch_dtype=torch.float16)
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            base_model_id,
            safety_checker=None,
            controlnet=controlnet,
            torch_dtype=torch.float16)
        pipe.scheduler = UniPCMultistepScheduler.from_config(
            pipe.scheduler.config)
        if self.device.type == 'cuda':
            pipe.enable_xformers_memory_efficient_attention()
        pipe.to(self.device)
        torch.cuda.empty_cache()
        gc.collect()
        self.base_model_id = base_model_id
        self.task_name = task_name
        return pipe
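    # Note: load_pipe is effectively a cache keyed on (base_model_id,
    # task_name); when both match the current state, the existing pipeline
    # is returned and no weights are reloaded.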
    def set_base_model(self, base_model_id: str) -> str:
        if not base_model_id or base_model_id == self.base_model_id:
            return self.base_model_id
        del self.pipe
        torch.cuda.empty_cache()
        gc.collect()
        try:
            self.pipe = self.load_pipe(base_model_id, self.task_name)
        except Exception:
            self.pipe = self.load_pipe(self.base_model_id, self.task_name)
        return self.base_model_id
    def load_controlnet_weight(self, task_name: str) -> None:
        if task_name == self.task_name:
            return
        if self.pipe is not None and hasattr(self.pipe, 'controlnet'):
            del self.pipe.controlnet
        torch.cuda.empty_cache()
        gc.collect()
        model_id = CONTROLNET_MODEL_IDS[task_name]
        controlnet = ControlNetModel.from_pretrained(model_id,
                                                     torch_dtype=torch.float16)
        controlnet.to(self.device)
        torch.cuda.empty_cache()
        gc.collect()
        self.pipe.controlnet = controlnet
        self.task_name = task_name
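    # This hot-swaps only the ControlNet module on the existing pipeline, so
    # switching tasks avoids reloading the much larger Stable Diffusion base
    # weights.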
    def get_prompt(self, prompt: str, additional_prompt: str) -> str:
        if not prompt:
            prompt = additional_prompt
        else:
            prompt = f'{prompt}, {additional_prompt}'
        return prompt
    def run_pipe(
        self,
        prompt: str,
        negative_prompt: str,
        control_image: PIL.Image.Image,
        num_images: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
    ) -> list[PIL.Image.Image]:
        if seed == -1:
            seed = np.random.randint(0, np.iinfo(np.int64).max)
        generator = torch.Generator().manual_seed(seed)
        return self.pipe(prompt=prompt,
                         negative_prompt=negative_prompt,
                         guidance_scale=guidance_scale,
                         num_images_per_prompt=num_images,
                         num_inference_steps=num_steps,
                         generator=generator,
                         image=control_image).images
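    # Minimal end-to-end sketch (hypothetical, not part of the original
    # file), assuming `img` is an RGB uint8 numpy array and a CUDA device is
    # available for the float16 weights:
    #
    #   model = Model(task_name='Canny')
    #   outputs = model.process_canny(
    #       image=img, prompt='a cat', additional_prompt='best quality',
    #       negative_prompt='lowres, bad anatomy', num_images=1,
    #       image_resolution=512, num_steps=20, guidance_scale=9.0, seed=-1,
    #       low_threshold=100, high_threshold=200)
    #   # outputs[0] is the Canny control image; outputs[1:] are the results.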
    def process_canny(
        self,
        image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        low_threshold: int,
        high_threshold: int,
    ) -> list[PIL.Image.Image]:
        self.preprocessor.load('Canny')
        control_image = self.preprocessor(image=image,
                                          low_threshold=low_threshold,
                                          high_threshold=high_threshold,
                                          detect_resolution=image_resolution)
        self.load_controlnet_weight('Canny')
        results = self.run_pipe(
            prompt=self.get_prompt(prompt, additional_prompt),
            negative_prompt=negative_prompt,
            control_image=control_image,
            num_images=num_images,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )
        return [control_image] + results
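    # All process_* methods below follow the same pattern as process_canny:
    # build a control image (optionally via a preprocessor), swap in the
    # matching ControlNet weights, run the pipeline, and return
    # [control_image, *generated_images].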
    def process_mlsd(
        self,
        image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        preprocess_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        value_threshold: float,
        distance_threshold: float,
    ) -> list[PIL.Image.Image]:
        self.preprocessor.load('MLSD')
        control_image = self.preprocessor(
            image=image,
            image_resolution=image_resolution,
            detect_resolution=preprocess_resolution,
            thr_v=value_threshold,
            thr_d=distance_threshold,
        )
        self.load_controlnet_weight('MLSD')
        results = self.run_pipe(
            prompt=self.get_prompt(prompt, additional_prompt),
            negative_prompt=negative_prompt,
            control_image=control_image,
            num_images=num_images,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )
        return [control_image] + results
    def process_scribble(
        self,
        image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        preprocess_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        if preprocessor_name == 'None':
            image = HWC3(image)
            image = resize_image(image, resolution=image_resolution)
            control_image = PIL.Image.fromarray(image)
        elif preprocessor_name == 'HED':
            self.preprocessor.load(preprocessor_name)
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
                scribble=False,
            )
        elif preprocessor_name == 'PidiNet':
            self.preprocessor.load(preprocessor_name)
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
                safe=False,
            )
        else:
            # Guard against an unbound control_image on unknown names.
            raise ValueError(f'Unknown preprocessor: {preprocessor_name}')
        self.load_controlnet_weight('scribble')
        results = self.run_pipe(
            prompt=self.get_prompt(prompt, additional_prompt),
            negative_prompt=negative_prompt,
            control_image=control_image,
            num_images=num_images,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )
        return [control_image] + results
    def process_scribble_interactive(
        self,
        image_and_mask: dict[str, np.ndarray],
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
    ) -> list[PIL.Image.Image]:
        image = image_and_mask['mask']
        image = HWC3(image)
        image = resize_image(image, resolution=image_resolution)
        control_image = PIL.Image.fromarray(image)
        self.load_controlnet_weight('scribble')
        results = self.run_pipe(
            prompt=self.get_prompt(prompt, additional_prompt),
            negative_prompt=negative_prompt,
            control_image=control_image,
            num_images=num_images,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )
        return [control_image] + results
    def process_softedge(
        self,
        image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        preprocess_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        if preprocessor_name == 'None':
            image = HWC3(image)
            image = resize_image(image, resolution=image_resolution)
            control_image = PIL.Image.fromarray(image)
        elif preprocessor_name in ['HED', 'HED safe']:
            safe = 'safe' in preprocessor_name
            self.preprocessor.load('HED')
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
                scribble=safe,
            )
        elif preprocessor_name in ['PidiNet', 'PidiNet safe']:
            safe = 'safe' in preprocessor_name
            self.preprocessor.load('PidiNet')
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
                safe=safe,
            )
        else:
            raise ValueError(f'Unknown preprocessor: {preprocessor_name}')
        self.load_controlnet_weight('softedge')
        results = self.run_pipe(
            prompt=self.get_prompt(prompt, additional_prompt),
            negative_prompt=negative_prompt,
            control_image=control_image,
            num_images=num_images,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )
        return [control_image] + results
    def process_openpose(
        self,
        image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        preprocess_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        if preprocessor_name == 'None':
            image = HWC3(image)
            image = resize_image(image, resolution=image_resolution)
            control_image = PIL.Image.fromarray(image)
        else:
            self.preprocessor.load('Openpose')
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
                hand_and_face=True,
            )
        self.load_controlnet_weight('Openpose')
        results = self.run_pipe(
            prompt=self.get_prompt(prompt, additional_prompt),
            negative_prompt=negative_prompt,
            control_image=control_image,
            num_images=num_images,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )
        return [control_image] + results
    def process_segmentation(
        self,
        image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        preprocess_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        if preprocessor_name == 'None':
            image = HWC3(image)
            image = resize_image(image, resolution=image_resolution)
            control_image = PIL.Image.fromarray(image)
        else:
            self.preprocessor.load(preprocessor_name)
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
            )
        self.load_controlnet_weight('segmentation')
        results = self.run_pipe(
            prompt=self.get_prompt(prompt, additional_prompt),
            negative_prompt=negative_prompt,
            control_image=control_image,
            num_images=num_images,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )
        return [control_image] + results
    def process_depth(
        self,
        image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        preprocess_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        if preprocessor_name == 'None':
            image = HWC3(image)
            image = resize_image(image, resolution=image_resolution)
            control_image = PIL.Image.fromarray(image)
        else:
            self.preprocessor.load(preprocessor_name)
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
            )
        self.load_controlnet_weight('depth')
        results = self.run_pipe(
            prompt=self.get_prompt(prompt, additional_prompt),
            negative_prompt=negative_prompt,
            control_image=control_image,
            num_images=num_images,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )
        return [control_image] + results
    def process_normal(
        self,
        image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        preprocess_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        if preprocessor_name == 'None':
            image = HWC3(image)
            image = resize_image(image, resolution=image_resolution)
            control_image = PIL.Image.fromarray(image)
        else:
            self.preprocessor.load('NormalBae')
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
            )
        self.load_controlnet_weight('NormalBae')
        results = self.run_pipe(
            prompt=self.get_prompt(prompt, additional_prompt),
            negative_prompt=negative_prompt,
            control_image=control_image,
            num_images=num_images,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )
        return [control_image] + results
    def process_lineart(
        self,
        image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        preprocess_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        if preprocessor_name in ['None', 'None (anime)']:
            image = HWC3(image)
            image = resize_image(image, resolution=image_resolution)
            control_image = PIL.Image.fromarray(image)
        elif preprocessor_name in ['Lineart', 'Lineart coarse']:
            coarse = 'coarse' in preprocessor_name
            self.preprocessor.load('Lineart')
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
                coarse=coarse,
            )
        elif preprocessor_name == 'Lineart (anime)':
            self.preprocessor.load('LineartAnime')
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
            )
        else:
            # Guard against an unbound control_image on unknown names.
            raise ValueError(f'Unknown preprocessor: {preprocessor_name}')
        if 'anime' in preprocessor_name:
            self.load_controlnet_weight('lineart_anime')
        else:
            self.load_controlnet_weight('lineart')
        results = self.run_pipe(
            prompt=self.get_prompt(prompt, additional_prompt),
            negative_prompt=negative_prompt,
            control_image=control_image,
            num_images=num_images,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )
        return [control_image] + results
    def process_shuffle(
        self,
        image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        if preprocessor_name == 'None':
            image = HWC3(image)
            image = resize_image(image, resolution=image_resolution)
            control_image = PIL.Image.fromarray(image)
        else:
            self.preprocessor.load(preprocessor_name)
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
            )
        self.load_controlnet_weight('shuffle')
        results = self.run_pipe(
            prompt=self.get_prompt(prompt, additional_prompt),
            negative_prompt=negative_prompt,
            control_image=control_image,
            num_images=num_images,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )
        return [control_image] + results
    def process_ip2p(
        self,
        image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
    ) -> list[PIL.Image.Image]:
        image = HWC3(image)
        image = resize_image(image, resolution=image_resolution)
        control_image = PIL.Image.fromarray(image)
        self.load_controlnet_weight('ip2p')
        results = self.run_pipe(
            prompt=self.get_prompt(prompt, additional_prompt),
            negative_prompt=negative_prompt,
            control_image=control_image,
            num_images=num_images,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )
        return [control_image] + results