import torch
import torch.nn.functional as F
from typing import Tuple

from model.base import BaseModel

class CausVid(BaseModel):
    def __init__(self, args, device):
        """
        Initialize the DMD (Distribution Matching Distillation) module.
        This class is self-contained and computes the generator and fake
        score losses in the forward pass.
        """
        super().__init__(args, device)

        self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
        self.num_training_frames = getattr(args, "num_training_frames", 21)
        if self.num_frame_per_block > 1:
            self.generator.model.num_frame_per_block = self.num_frame_per_block

        self.independent_first_frame = getattr(args, "independent_first_frame", False)
        if self.independent_first_frame:
            self.generator.model.independent_first_frame = True

        if args.gradient_checkpointing:
            self.generator.enable_gradient_checkpointing()
            self.fake_score.enable_gradient_checkpointing()

        # Initialize all DMD hyperparameters
        self.num_train_timestep = args.num_train_timestep
        self.min_step = int(0.02 * self.num_train_timestep)
        self.max_step = int(0.98 * self.num_train_timestep)

        if hasattr(args, "real_guidance_scale"):
            self.real_guidance_scale = args.real_guidance_scale
            self.fake_guidance_scale = args.fake_guidance_scale
        else:
            self.real_guidance_scale = args.guidance_scale
            self.fake_guidance_scale = 0.0

        self.timestep_shift = getattr(args, "timestep_shift", 1.0)
        self.teacher_forcing = getattr(args, "teacher_forcing", False)

        # Keep alphas_cumprod on the right device; ensure the attribute exists
        # even for schedulers without it (e.g. flow-matching schedulers).
        if getattr(self.scheduler, "alphas_cumprod", None) is not None:
            self.scheduler.alphas_cumprod = self.scheduler.alphas_cumprod.to(device)
        else:
            self.scheduler.alphas_cumprod = None
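
    # A minimal sketch of the config this constructor reads. The field names
    # are taken from the attribute accesses above; the values are illustrative
    # assumptions only, and BaseModel may require additional fields:
    #
    #     from types import SimpleNamespace
    #     args = SimpleNamespace(
    #         num_train_timestep=1000, guidance_scale=5.0,
    #         gradient_checkpointing=False, timestep_shift=1.0,
    #         num_frame_per_block=3, num_training_frames=21,
    #         independent_first_frame=False, teacher_forcing=False,
    #     )
    #     model = CausVid(args, device=torch.device("cuda"))
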
    def _compute_kl_grad(
        self,
        noisy_image_or_video: torch.Tensor,
        estimated_clean_image_or_video: torch.Tensor,
        timestep: torch.Tensor,
        conditional_dict: dict,
        unconditional_dict: dict,
        normalization: bool = True
    ) -> Tuple[torch.Tensor, dict]:
        """
        Compute the KL gradient (eq. 7 in https://arxiv.org/abs/2311.18828).
        Input:
            - noisy_image_or_video: a tensor with shape [B, F, C, H, W], where the number of frames F is 1 for images.
            - estimated_clean_image_or_video: a tensor with shape [B, F, C, H, W] representing the estimated clean image or video.
            - timestep: a tensor with shape [B, F] containing the randomly sampled timesteps.
            - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
            - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
            - normalization: a boolean indicating whether to normalize the gradient.
        Output:
            - kl_grad: a tensor representing the KL gradient.
            - kl_log_dict: a dictionary containing intermediate tensors for logging.
        """
        # Step 1: Compute the fake score
        _, pred_fake_image_cond = self.fake_score(
            noisy_image_or_video=noisy_image_or_video,
            conditional_dict=conditional_dict,
            timestep=timestep
        )

        if self.fake_guidance_scale != 0.0:
            _, pred_fake_image_uncond = self.fake_score(
                noisy_image_or_video=noisy_image_or_video,
                conditional_dict=unconditional_dict,
                timestep=timestep
            )
            pred_fake_image = pred_fake_image_cond + (
                pred_fake_image_cond - pred_fake_image_uncond
            ) * self.fake_guidance_scale
        else:
            pred_fake_image = pred_fake_image_cond

        # Step 2: Compute the real score.
        # We compute both the conditional and unconditional predictions
        # and combine them to apply classifier-free guidance
        # (https://arxiv.org/abs/2207.12598).
        _, pred_real_image_cond = self.real_score(
            noisy_image_or_video=noisy_image_or_video,
            conditional_dict=conditional_dict,
            timestep=timestep
        )

        _, pred_real_image_uncond = self.real_score(
            noisy_image_or_video=noisy_image_or_video,
            conditional_dict=unconditional_dict,
            timestep=timestep
        )

        pred_real_image = pred_real_image_cond + (
            pred_real_image_cond - pred_real_image_uncond
        ) * self.real_guidance_scale
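        # Note: cond + s * (cond - uncond) is algebraically the standard CFG
        # combination uncond + w * (cond - uncond) with w = 1 + s, so the
        # guidance scale here plays the role of w - 1.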

        # Step 3: Compute the DMD gradient (DMD paper eq. 7).
        grad = pred_fake_image - pred_real_image

        # TODO: Change the normalizer for causal teacher
        if normalization:
            # Step 4: Gradient normalization (DMD paper eq. 8).
            p_real = estimated_clean_image_or_video - pred_real_image
            normalizer = torch.abs(p_real).mean(dim=[1, 2, 3, 4], keepdim=True)
            grad = grad / normalizer
        grad = torch.nan_to_num(grad)

        return grad, {
            "dmdtrain_gradient_norm": torch.mean(torch.abs(grad)).detach(),
            "timestep": timestep.detach()
        }
    def compute_distribution_matching_loss(
        self,
        image_or_video: torch.Tensor,
        conditional_dict: dict,
        unconditional_dict: dict,
        gradient_mask: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, dict]:
        """
        Compute the DMD loss (eq. 7 in https://arxiv.org/abs/2311.18828).
        Input:
            - image_or_video: a tensor with shape [B, F, C, H, W], where the number of frames F is 1 for images.
            - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
            - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
            - gradient_mask: a boolean tensor with the same shape as image_or_video indicating which entries contribute to the loss.
        Output:
            - dmd_loss: a scalar tensor representing the DMD loss.
            - dmd_log_dict: a dictionary containing intermediate tensors for logging.
        """
        original_latent = image_or_video
        batch_size, num_frame = image_or_video.shape[:2]

        with torch.no_grad():
            # Step 1: Randomly sample a timestep from the given schedule
            # and the corresponding noise
            timestep = self._get_timestep(
                0,
                self.num_train_timestep,
                batch_size,
                num_frame,
                self.num_frame_per_block,
                uniform_timestep=True
            )
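
            # The shift below applies the warping t' = s * t / (1 + (s - 1) * t)
            # with t normalized to [0, 1]; e.g. with s = 5 and t = 500 it gives
            # 5 * 0.5 / (1 + 4 * 0.5) * 1000 ≈ 833, biasing training toward
            # higher noise levels.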
            if self.timestep_shift > 1:
                timestep = self.timestep_shift * (timestep / 1000) / \
                    (1 + (self.timestep_shift - 1) * (timestep / 1000)) * 1000
            timestep = timestep.clamp(self.min_step, self.max_step)

            noise = torch.randn_like(image_or_video)
            noisy_latent = self.scheduler.add_noise(
                image_or_video.flatten(0, 1),
                noise.flatten(0, 1),
                timestep.flatten(0, 1)
            ).detach().unflatten(0, (batch_size, num_frame))

            # Step 2: Compute the KL gradient
            grad, dmd_log_dict = self._compute_kl_grad(
                noisy_image_or_video=noisy_latent,
                estimated_clean_image_or_video=original_latent,
                timestep=timestep,
                conditional_dict=conditional_dict,
                unconditional_dict=unconditional_dict
            )
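
        # The surrogate below, 0.5 * ||x - (x - grad).detach()||^2, has a
        # gradient w.r.t. x proportional to `grad`, so backprop pushes the
        # generator output along the DMD direction without differentiating
        # through either score network.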
        if gradient_mask is not None:
            dmd_loss = 0.5 * F.mse_loss(
                original_latent.double()[gradient_mask],
                (original_latent.double() - grad.double()).detach()[gradient_mask],
                reduction="mean"
            )
        else:
            dmd_loss = 0.5 * F.mse_loss(
                original_latent.double(),
                (original_latent.double() - grad.double()).detach(),
                reduction="mean"
            )
        return dmd_loss, dmd_log_dict

    def _run_generator(
        self,
        image_or_video_shape,
        conditional_dict: dict,
        clean_latent: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Simulate the generator's input from noise using backward simulation
        and then run the generator for one step.
        Input:
            - image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
            - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
            - clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Needs to be passed when no backward simulation is used.
        Output:
            - pred_image: a tensor with shape [B, F, C, H, W].
            - gradient_mask: an optional boolean mask (currently always None).
        """
        # Step 1: Construct a noisy version of the clean latent at every
        # denoising timestep
        simulated_noisy_input = []
        for timestep in self.denoising_step_list:
            noise = torch.randn(
                image_or_video_shape, device=self.device, dtype=self.dtype)
            noisy_timestep = timestep * torch.ones(
                image_or_video_shape[:2], device=self.device, dtype=torch.long)

            if timestep != 0:
                noisy_image = self.scheduler.add_noise(
                    clean_latent.flatten(0, 1),
                    noise.flatten(0, 1),
                    noisy_timestep.flatten(0, 1)
                ).unflatten(0, image_or_video_shape[:2])
            else:
                noisy_image = clean_latent

            simulated_noisy_input.append(noisy_image)

        simulated_noisy_input = torch.stack(simulated_noisy_input, dim=1)

        # Step 2: Randomly sample a timestep index and pick the corresponding
        # noisy input from the stacked tensor [B, T, F, C, H, W]
        index = self._get_timestep(
            0,
            len(self.denoising_step_list),
            image_or_video_shape[0],
            image_or_video_shape[1],
            self.num_frame_per_block,
            uniform_timestep=False
        )

        noisy_input = torch.gather(
            simulated_noisy_input, dim=1,
            index=index.reshape(index.shape[0], 1, index.shape[1], 1, 1, 1).expand(
                -1, -1, -1, *image_or_video_shape[2:]).to(self.device)
        ).squeeze(1)
        timestep = self.denoising_step_list[index].to(self.device)

        # Step 3: Run the generator for one step on the selected noisy input
        _, pred_image_or_video = self.generator(
            noisy_image_or_video=noisy_input,
            conditional_dict=conditional_dict,
            timestep=timestep,
            clean_x=clean_latent if self.teacher_forcing else None,
        )

        gradient_mask = None  # alternatively: timestep != 0, to mask out clean frames
        pred_image_or_video = pred_image_or_video.type_as(noisy_input)
        return pred_image_or_video, gradient_mask
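
    # A minimal usage sketch; `model`, `text_embeds`, and `latents` are
    # hypothetical placeholders, and the conditional_dict keys depend on the
    # model wrapper in use:
    #
    #     pred, mask = model._run_generator(
    #         image_or_video_shape=list(latents.shape),
    #         conditional_dict={"prompt_embeds": text_embeds},
    #         clean_latent=latents,
    #     )
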
    def generator_loss(
        self,
        image_or_video_shape,
        conditional_dict: dict,
        unconditional_dict: dict,
        clean_latent: torch.Tensor,
        initial_latent: torch.Tensor = None
    ) -> Tuple[torch.Tensor, dict]:
        """
        Generate images/videos from noise and compute the DMD loss.
        The noisy input to the generator is backward simulated, which removes
        the need for any dataset during distillation.
        See Sec. 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
        Input:
            - image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
            - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
            - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
            - clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Needs to be passed when no backward simulation is used.
            - initial_latent: a tensor containing the initial latents [B, F, C, H, W] (currently unused).
        Output:
            - loss: a scalar tensor representing the generator loss.
            - generator_log_dict: a dictionary containing intermediate tensors for logging.
        """
        # Step 1: Run the generator on backward-simulated noisy input
        pred_image, gradient_mask = self._run_generator(
            image_or_video_shape=image_or_video_shape,
            conditional_dict=conditional_dict,
            clean_latent=clean_latent
        )

        # Step 2: Compute the DMD loss
        dmd_loss, dmd_log_dict = self.compute_distribution_matching_loss(
            image_or_video=pred_image,
            conditional_dict=conditional_dict,
            unconditional_dict=unconditional_dict,
            gradient_mask=gradient_mask
        )

        # Step 3: TODO: Implement the GAN loss
        return dmd_loss, dmd_log_dict

    def critic_loss(
        self,
        image_or_video_shape,
        conditional_dict: dict,
        unconditional_dict: dict,
        clean_latent: torch.Tensor,
        initial_latent: torch.Tensor = None
    ) -> Tuple[torch.Tensor, dict]:
        """
        Generate images/videos from noise and train the critic on the
        generated samples.
        The noisy input to the generator is backward simulated, which removes
        the need for any dataset during distillation.
        See Sec. 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
        Input:
            - image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
            - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
            - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
            - clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Needs to be passed when no backward simulation is used.
            - initial_latent: a tensor containing the initial latents [B, F, C, H, W] (currently unused).
        Output:
            - loss: a scalar tensor representing the critic's denoising loss.
            - critic_log_dict: a dictionary containing intermediate tensors for logging.
        """
        # Step 1: Run the generator on backward-simulated noisy input
        with torch.no_grad():
            generated_image, _ = self._run_generator(
                image_or_video_shape=image_or_video_shape,
                conditional_dict=conditional_dict,
                clean_latent=clean_latent
            )

        # Step 2: Compute the fake prediction at a randomly sampled timestep
        critic_timestep = self._get_timestep(
            0,
            self.num_train_timestep,
            image_or_video_shape[0],
            image_or_video_shape[1],
            self.num_frame_per_block,
            uniform_timestep=True
        )

        if self.timestep_shift > 1:
            # Same timestep warping as in compute_distribution_matching_loss
            critic_timestep = self.timestep_shift * (critic_timestep / 1000) / \
                (1 + (self.timestep_shift - 1) * (critic_timestep / 1000)) * 1000

        critic_timestep = critic_timestep.clamp(self.min_step, self.max_step)

        critic_noise = torch.randn_like(generated_image)
        noisy_generated_image = self.scheduler.add_noise(
            generated_image.flatten(0, 1),
            critic_noise.flatten(0, 1),
            critic_timestep.flatten(0, 1)
        ).unflatten(0, image_or_video_shape[:2])

        _, pred_fake_image = self.fake_score(
            noisy_image_or_video=noisy_generated_image,
            conditional_dict=conditional_dict,
            timestep=critic_timestep
        )

        # Step 3: Compute the denoising loss for the fake critic
        if self.args.denoising_loss_type == "flow":
            from utils.wan_wrapper import WanDiffusionWrapper
            flow_pred = WanDiffusionWrapper._convert_x0_to_flow_pred(
                scheduler=self.scheduler,
                x0_pred=pred_fake_image.flatten(0, 1),
                xt=noisy_generated_image.flatten(0, 1),
                timestep=critic_timestep.flatten(0, 1)
            )
            pred_fake_noise = None
        else:
            flow_pred = None
            pred_fake_noise = self.scheduler.convert_x0_to_noise(
                x0=pred_fake_image.flatten(0, 1),
                xt=noisy_generated_image.flatten(0, 1),
                timestep=critic_timestep.flatten(0, 1)
            ).unflatten(0, image_or_video_shape[:2])
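
        # Exactly one of flow_pred / pred_fake_noise is non-None; the loss
        # function below selects the regression target that matches the
        # scheduler's parameterization (flow matching vs. epsilon prediction).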
        denoising_loss = self.denoising_loss_func(
            x=generated_image.flatten(0, 1),
            x_pred=pred_fake_image.flatten(0, 1),
            noise=critic_noise.flatten(0, 1),
            noise_pred=pred_fake_noise,
            alphas_cumprod=self.scheduler.alphas_cumprod,
            timestep=critic_timestep.flatten(0, 1),
            flow_pred=flow_pred
        )

        # Step 4: TODO: Compute the GAN loss

        # Step 5: Debugging log
        critic_log_dict = {
            "critic_timestep": critic_timestep.detach()
        }

        return denoising_loss, critic_log_dict
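

# A minimal sketch of the alternating update this class supports; `model`,
# `shape`, `cond`, `uncond`, `latents`, `gen_opt`, and `critic_opt` are
# hypothetical placeholders, and the actual trainer lives elsewhere in the repo:
#
#     for step in range(num_steps):
#         if step % 2 == 0:
#             loss, _ = model.generator_loss(shape, cond, uncond, latents)
#             gen_opt.zero_grad(); loss.backward(); gen_opt.step()
#         else:
#             loss, _ = model.critic_loss(shape, cond, uncond, latents)
#             critic_opt.zero_grad(); loss.backward(); critic_opt.step()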