import gc
import glob
import os
import random
from collections import OrderedDict
from typing import Optional, List

import torch
from safetensors.torch import save_file, load_file
from tqdm import tqdm

from toolkit import train_tools
from toolkit.layers import ReductionKernel
from toolkit.stable_diffusion_model import PromptEmbeds
from toolkit.train_tools import get_torch_dtype, apply_noise_offset

from .BaseSDTrainProcess import BaseSDTrainProcess, StableDiffusion


def flush():
    # release cached CUDA memory and run garbage collection between heavy stages
    torch.cuda.empty_cache()
    gc.collect()


class RescaleConfig:
    def __init__(
            self,
            **kwargs
    ):
        self.from_resolution = kwargs.get('from_resolution', 512)
        self.scale = kwargs.get('scale', 0.5)
        self.latent_tensor_dir = kwargs.get('latent_tensor_dir', None)
        self.num_latent_tensors = kwargs.get('num_latent_tensors', 1000)
        self.to_resolution = kwargs.get('to_resolution', int(self.from_resolution * self.scale))
        self.prompt_dropout = kwargs.get('prompt_dropout', 0.1)
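

# A minimal sketch of the 'rescale' config section consumed above, assuming a
# YAML job config (values are illustrative, not shipped defaults):
#
#   rescale:
#     from_resolution: 512       # resolution the parent (teacher) model runs at
#     scale: 0.5                 # reduction factor applied to the latents
#     latent_tensor_dir: /path/to/latent_cache
#     num_latent_tensors: 1000   # size of the cached latent dataset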


class PromptEmbedsCache:
    # note: this dict is a class attribute, so the cache is shared by all instances
    prompts: dict[str, PromptEmbeds] = {}

    def __setitem__(self, __name: str, __value: PromptEmbeds) -> None:
        self.prompts[__name] = __value

    def __getitem__(self, __name: str) -> Optional[PromptEmbeds]:
        if __name in self.prompts:
            return self.prompts[__name]
        else:
            return None
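

# Usage sketch (hypothetical prompt): a cache miss returns None rather than
# raising, so callers can lazily encode and store embeddings:
#
#   cache = PromptEmbedsCache()
#   if cache["a photo of a cat"] is None:
#       cache["a photo of a cat"] = sd.encode_prompt("a photo of a cat")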


class TrainSDRescaleProcess(BaseSDTrainProcess):
    def __init__(self, process_id: int, job, config: OrderedDict):
        # pass our custom pipeline to super so it sets it up
        super().__init__(process_id, job, config)
        self.step_num = 0
        self.start_step = 0
        self.device = self.get_conf('device', self.job.device)
        self.device_torch = torch.device(self.device)
        self.rescale_config = RescaleConfig(**self.get_conf('rescale', required=True))
        self.reduce_size_fn = ReductionKernel(
            in_channels=4,
            kernel_size=int(self.rescale_config.from_resolution // self.rescale_config.to_resolution),
            dtype=get_torch_dtype(self.train_config.dtype),
            device=self.device_torch,
        )
        self.latent_paths: List[str] = []
        self.empty_embedding: Optional[PromptEmbeds] = None
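
    # Worked example of the reduction factor, using the RescaleConfig defaults:
    # from_resolution=512 with scale=0.5 gives to_resolution=256, so
    # kernel_size = 512 // 256 = 2. Assuming ReductionKernel pools
    # kernel_size x kernel_size blocks, the 4-channel SD latents shrink by 2x
    # per spatial dim (e.g. 64x64 latents for a 512px image become 32x32).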

    def before_model_load(self):
        pass

    def get_latent_tensors(self):
        dtype = get_torch_dtype(self.train_config.dtype)
        num_to_generate = 0
        # check if the cache dir exists
        if not os.path.exists(self.rescale_config.latent_tensor_dir):
            os.makedirs(self.rescale_config.latent_tensor_dir)
            num_to_generate = self.rescale_config.num_latent_tensors
        else:
            # find existing cached tensors and only generate the shortfall
            current_tensor_list = glob.glob(os.path.join(self.rescale_config.latent_tensor_dir, "*.safetensors"))
            num_to_generate = self.rescale_config.num_latent_tensors - len(current_tensor_list)
            self.latent_paths = current_tensor_list

        if num_to_generate > 0:
            print(f"Generating {num_to_generate}/{self.rescale_config.num_latent_tensors} latent tensors")
            # unload the training unet to make room
            self.sd.unet.to('cpu')
            # load the parent (teacher) model
            self.sd_parent = StableDiffusion(
                self.device_torch,
                model_config=self.model_config,
                dtype=self.train_config.dtype,
            )
            self.sd_parent.load_model()
            self.sd_parent.unet.to(self.device_torch, dtype=dtype)
            # we don't need the text encoder for this
            del self.sd_parent.text_encoder
            del self.sd_parent.tokenizer
            self.sd_parent.unet.eval()
            self.sd_parent.unet.requires_grad_(False)

            # save the current RNG state so training is unaffected
            rng_state = torch.get_rng_state()
            cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None

            text_embeddings = train_tools.concat_prompt_embeddings(
                self.empty_embedding,  # unconditional (negative prompt)
                self.empty_embedding,  # conditional (positive prompt)
                self.train_config.batch_size,
            )
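
            # Both halves of the classifier-free-guidance pair are the empty
            # prompt, so guidance collapses to the unconditional prediction:
            # uncond + g * (cond - uncond) == uncond when cond == uncond.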

            torch.set_default_device(self.device_torch)

            for i in tqdm(range(num_to_generate)):
                dtype = get_torch_dtype(self.train_config.dtype)
                # get a random seed
                seed = torch.randint(0, 2 ** 32, (1,)).item()
                # zero-pad the seed string to max length
                seed_string = str(seed).zfill(10)
                # set the seed
                torch.manual_seed(seed)
                if torch.cuda.is_available():
                    torch.cuda.manual_seed(seed)

                # use the configured max number of denoising steps
                timesteps_to = self.train_config.max_denoising_steps

                # set the scheduler to that number of steps
                self.sd.noise_scheduler.set_timesteps(
                    timesteps_to, device=self.device_torch
                )

                noise = self.sd.get_latent_noise(
                    pixel_height=self.rescale_config.from_resolution,
                    pixel_width=self.rescale_config.from_resolution,
                    batch_size=self.train_config.batch_size,
                    noise_offset=self.train_config.noise_offset,
                ).to(self.device_torch, dtype=dtype)

                # get latents
                latents = noise * self.sd.noise_scheduler.init_noise_sigma
                latents = latents.to(self.device_torch, dtype=dtype)

                # get a random guidance scale from 1.0 to 10.0 (CFG)
                guidance_scale = torch.rand(1).item() * 9.0 + 1.0

                # do a timestep of 1
                timestep = 1

                noise_pred_target = self.sd_parent.predict_noise(
                    latents,
                    text_embeddings=text_embeddings,
                    timestep=timestep,
                    guidance_scale=guidance_scale
                )
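
                # The parent's single-step noise prediction on the full-resolution
                # latents is the distillation target; caching it to disk means the
                # teacher never has to be resident during the train loop.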

                # build the state dict and cache it to disk
                state_dict = OrderedDict()
                state_dict['noise_pred_target'] = noise_pred_target.to('cpu', dtype=torch.float16)
                state_dict['latents'] = latents.to('cpu', dtype=torch.float16)
                state_dict['guidance_scale'] = torch.tensor(guidance_scale).to('cpu', dtype=torch.float16)
                state_dict['timestep'] = torch.tensor(timestep).to('cpu', dtype=torch.float16)
                state_dict['timesteps_to'] = torch.tensor(timesteps_to).to('cpu', dtype=torch.float16)
                state_dict['seed'] = torch.tensor(seed).to('cpu', dtype=torch.float32)  # must be float32 to prevent overflow
                file_name = f"{seed_string}_{i}.safetensors"
                file_path = os.path.join(self.rescale_config.latent_tensor_dir, file_name)
                save_file(state_dict, file_path)
                self.latent_paths.append(file_path)
| print("Removing parent model") | |
| # delete parent | |
| del self.sd_parent | |
| flush() | |
| torch.set_rng_state(rng_state) | |
| if cuda_rng_state is not None: | |
| torch.cuda.set_rng_state(cuda_rng_state) | |
| self.sd.unet.to(self.device_torch, dtype=dtype) | |
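
    # Each cached .safetensors record holds, in float16 unless noted:
    #   noise_pred_target - teacher noise prediction (the regression target)
    #   latents           - full-resolution noised latents fed to the teacher
    #   guidance_scale    - CFG scale sampled from [1.0, 10.0)
    #   timestep          - always 1 here
    #   timesteps_to      - scheduler step count used during generation
    #   seed              - RNG seed (float32 to avoid float16 overflow)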

    def hook_before_train_loop(self):
        # encode our empty prompt; it is the only embedding this process needs
        self.empty_embedding = self.sd.encode_prompt("")
        self.empty_embedding = self.empty_embedding.to(
            self.device_torch,
            dtype=get_torch_dtype(self.train_config.dtype)
        )

        # move the train model's text encoder(s) to cpu; they are not needed again
        if isinstance(self.sd.text_encoder, list):
            for encoder in self.sd.text_encoder:
                encoder.to('cpu')
                encoder.eval()
                encoder.requires_grad_(False)
        else:
            self.sd.text_encoder.to('cpu')
            self.sd.text_encoder.eval()
            self.sd.text_encoder.requires_grad_(False)
        flush()

        self.get_latent_tensors()
        flush()
        # end hook_before_train_loop
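
    # Training objective, sketched (see the loop below): with both prompts
    # empty, the student unet is trained so that one scheduler step on its
    # reduced latents matches the reduced result of one scheduler step on the
    # teacher's cached full-resolution prediction, i.e. roughly
    #   MSE( step(student(reduce(latents))), reduce(step(teacher(latents))) )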

    def hook_train_loop(self, batch):
        dtype = get_torch_dtype(self.train_config.dtype)
        loss_function = torch.nn.MSELoss()

        # put the student unet in train mode
        self.sd.unet.train()
        self.sd.unet.requires_grad_(True)
        self.sd.unet.to(self.device_torch, dtype=dtype)

        # build the regression target without tracking gradients
        with torch.no_grad():
            self.optimizer.zero_grad()

            # pick a random cached latent tensor
            latent_path = random.choice(self.latent_paths)
            latent_tensor = load_file(latent_path)
            noise_pred_target = latent_tensor['noise_pred_target'].to(self.device_torch, dtype=dtype)
            latents = latent_tensor['latents'].to(self.device_torch, dtype=dtype)
            guidance_scale = latent_tensor['guidance_scale'].item()
            timestep = int(latent_tensor['timestep'].item())
            timesteps_to = int(latent_tensor['timesteps_to'].item())

            text_embeddings = train_tools.concat_prompt_embeddings(
                self.empty_embedding,  # unconditional (negative prompt)
                self.empty_embedding,  # conditional (positive prompt)
                self.train_config.batch_size,
            )

            self.sd.noise_scheduler.set_timesteps(
                timesteps_to, device=self.device_torch
            )

            # take one scheduler step with the teacher's prediction, then reduce
            denoised_target = self.sd.noise_scheduler.step(noise_pred_target, timestep, latents).prev_sample
            denoised_target = self.reduce_size_fn(denoised_target.detach())
            reduced_latents = self.reduce_size_fn(latents.detach())
            denoised_target.requires_grad = False
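
        # from here on gradients are tracked; only the student unet's forward
        # pass contributes to the graph, so backward() updates just the student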
        self.optimizer.zero_grad()
        noise_pred_train = self.sd.predict_noise(
            reduced_latents,
            text_embeddings=text_embeddings,
            timestep=timestep,
            guidance_scale=guidance_scale
        )
        denoised_pred = self.sd.noise_scheduler.step(noise_pred_train, timestep, reduced_latents).prev_sample

        loss = loss_function(denoised_pred, denoised_target)
        loss_float = loss.item()
        loss.backward()
        self.optimizer.step()
        self.lr_scheduler.step()
        self.optimizer.zero_grad()
        flush()

        loss_dict = OrderedDict(
            {'loss': loss_float},
        )

        return loss_dict
        # end hook_train_loop
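

# End-to-end flow, as wired through the BaseSDTrainProcess hooks used above
# (assuming the base class invokes them in the usual order):
# hook_before_train_loop() encodes the empty prompt, offloads the text
# encoder, and builds or loads the cached latent dataset via
# get_latent_tensors(); hook_train_loop() then runs per training step, pulling
# a random cached record and taking one distillation step on reduced latents.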