Spaces:
Running
Running
| import gc | |
| import os | |
| from collections import OrderedDict | |
| from typing import ForwardRef, List, Optional, Union | |
| import torch | |
| from safetensors.torch import save_file, load_file | |
| from jobs.process.BaseProcess import BaseProcess | |
| from toolkit.config_modules import ModelConfig, GenerateImageConfig | |
| from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_model_hash_to_meta, \ | |
| add_base_model_info_to_meta | |
| from toolkit.stable_diffusion_model import StableDiffusion | |
| from toolkit.train_tools import get_torch_dtype | |
| import random | |
| class GenerateConfig: | |
| def __init__(self, **kwargs): | |
| self.prompts: List[str] | |
| self.sampler = kwargs.get('sampler', 'ddpm') | |
| self.width = kwargs.get('width', 512) | |
| self.height = kwargs.get('height', 512) | |
| self.size_list: Union[List[int], None] = kwargs.get('size_list', None) | |
| self.neg = kwargs.get('neg', '') | |
| self.seed = kwargs.get('seed', -1) | |
| self.guidance_scale = kwargs.get('guidance_scale', 7) | |
| self.sample_steps = kwargs.get('sample_steps', 20) | |
| self.prompt_2 = kwargs.get('prompt_2', None) | |
| self.neg_2 = kwargs.get('neg_2', None) | |
| self.prompts = kwargs.get('prompts', None) | |
| self.guidance_rescale = kwargs.get('guidance_rescale', 0.0) | |
| self.compile = kwargs.get('compile', False) | |
| self.ext = kwargs.get('ext', 'png') | |
| self.prompt_file = kwargs.get('prompt_file', False) | |
| self.num_repeats = kwargs.get('num_repeats', 1) | |
| self.prompts_in_file = self.prompts | |
| if self.prompts is None: | |
| raise ValueError("Prompts must be set") | |
| if isinstance(self.prompts, str): | |
| if os.path.exists(self.prompts): | |
| with open(self.prompts, 'r', encoding='utf-8') as f: | |
| self.prompts_in_file = f.read().splitlines() | |
| self.prompts_in_file = [p.strip() for p in self.prompts_in_file if len(p.strip()) > 0] | |
| else: | |
| raise ValueError("Prompts file does not exist, put in list if you want to use a list of prompts") | |
| self.random_prompts = kwargs.get('random_prompts', False) | |
| self.max_random_per_prompt = kwargs.get('max_random_per_prompt', 1) | |
| self.max_images = kwargs.get('max_images', 10000) | |
| if self.random_prompts: | |
| self.prompts = [] | |
| for i in range(self.max_images): | |
| num_prompts = random.randint(1, self.max_random_per_prompt) | |
| prompt_list = [random.choice(self.prompts_in_file) for _ in range(num_prompts)] | |
| self.prompts.append(", ".join(prompt_list)) | |
| else: | |
| self.prompts = self.prompts_in_file | |
| if kwargs.get('shuffle', False): | |
| # shuffle the prompts | |
| random.shuffle(self.prompts) | |
| class GenerateProcess(BaseProcess): | |
| process_id: int | |
| config: OrderedDict | |
| progress_bar: ForwardRef('tqdm') = None | |
| sd: StableDiffusion | |
| def __init__( | |
| self, | |
| process_id: int, | |
| job, | |
| config: OrderedDict | |
| ): | |
| super().__init__(process_id, job, config) | |
| self.output_folder = self.get_conf('output_folder', required=True) | |
| self.model_config = ModelConfig(**self.get_conf('model', required=True)) | |
| self.device = self.get_conf('device', self.job.device) | |
| self.generate_config = GenerateConfig(**self.get_conf('generate', required=True)) | |
| self.torch_dtype = get_torch_dtype(self.get_conf('dtype', 'float16')) | |
| self.progress_bar = None | |
| self.sd = StableDiffusion( | |
| device=self.device, | |
| model_config=self.model_config, | |
| dtype=self.model_config.dtype, | |
| ) | |
| print(f"Using device {self.device}") | |
| def clean_prompt(self, prompt: str): | |
| # remove any non alpha numeric characters or ,'" from prompt | |
| return ''.join(e for e in prompt if e.isalnum() or e in ", '\"") | |
| def run(self): | |
| with torch.no_grad(): | |
| super().run() | |
| print("Loading model...") | |
| self.sd.load_model() | |
| self.sd.pipeline.to(self.device, self.torch_dtype) | |
| print("Compiling model...") | |
| # self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead", fullgraph=True) | |
| if self.generate_config.compile: | |
| self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead") | |
| print(f"Generating {len(self.generate_config.prompts)} images") | |
| # build prompt image configs | |
| prompt_image_configs = [] | |
| for _ in range(self.generate_config.num_repeats): | |
| for prompt in self.generate_config.prompts: | |
| width = self.generate_config.width | |
| height = self.generate_config.height | |
| # prompt = self.clean_prompt(prompt) | |
| if self.generate_config.size_list is not None: | |
| # randomly select a size | |
| width, height = random.choice(self.generate_config.size_list) | |
| prompt_image_configs.append(GenerateImageConfig( | |
| prompt=prompt, | |
| prompt_2=self.generate_config.prompt_2, | |
| width=width, | |
| height=height, | |
| num_inference_steps=self.generate_config.sample_steps, | |
| guidance_scale=self.generate_config.guidance_scale, | |
| negative_prompt=self.generate_config.neg, | |
| negative_prompt_2=self.generate_config.neg_2, | |
| seed=self.generate_config.seed, | |
| guidance_rescale=self.generate_config.guidance_rescale, | |
| output_ext=self.generate_config.ext, | |
| output_folder=self.output_folder, | |
| add_prompt_file=self.generate_config.prompt_file | |
| )) | |
| # generate images | |
| self.sd.generate_images(prompt_image_configs, sampler=self.generate_config.sampler) | |
| print("Done generating images") | |
| # cleanup | |
| del self.sd | |
| gc.collect() | |
| torch.cuda.empty_cache() | |