from typing import List, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import ListConfig
import math
from ...modules.diffusionmodules.sampling import VideoDDIMSampler, VPSDEDPMPP2MSampler
from ...util import append_dims, instantiate_from_config
from ...modules.autoencoding.lpips.loss.lpips import LPIPS
# import rearrange
from einops import rearrange
import random
from sat import mpu


class StandardDiffusionLoss(nn.Module):
    def __init__(
        self,
        sigma_sampler_config,
        type="l2",
        offset_noise_level=0.0,
        batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None,
    ):
        super().__init__()

        assert type in ["l2", "l1", "lpips"]

        self.sigma_sampler = instantiate_from_config(sigma_sampler_config)

        self.type = type
        self.offset_noise_level = offset_noise_level

        if type == "lpips":
            self.lpips = LPIPS().eval()

        if not batch2model_keys:
            batch2model_keys = []
        if isinstance(batch2model_keys, str):
            batch2model_keys = [batch2model_keys]
        self.batch2model_keys = set(batch2model_keys)

    def __call__(self, network, denoiser, conditioner, input, batch):
        cond = conditioner(batch)
        # Forward any extra batch entries listed in batch2model_keys to the model.
        additional_model_inputs = {key: batch[key] for key in self.batch2model_keys.intersection(batch)}

        # Sample one noise level per batch element and noise the input.
        sigmas = self.sigma_sampler(input.shape[0]).to(input.device)
        noise = torch.randn_like(input)
        if self.offset_noise_level > 0.0:
            # Offset noise: add a per-sample constant shift broadcast over all
            # non-batch dimensions.
            noise = (
                noise + append_dims(torch.randn(input.shape[0]).to(input.device), input.ndim) * self.offset_noise_level
            )
            noise = noise.to(input.dtype)
        noised_input = input.float() + noise * append_dims(sigmas, input.ndim)
        model_output = denoiser(network, noised_input, sigmas, cond, **additional_model_inputs)
        w = append_dims(denoiser.w(sigmas), input.ndim)
        return self.get_loss(model_output, input, w)

    def get_loss(self, model_output, target, w):
        if self.type == "l2":
            return torch.mean((w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1)
        elif self.type == "l1":
            return torch.mean((w * (model_output - target).abs()).reshape(target.shape[0], -1), 1)
        elif self.type == "lpips":
            loss = self.lpips(model_output, target).reshape(-1)
            return loss
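

# A minimal instantiation sketch (not part of the original file; the sigma-sampler
# target and parameters below are hypothetical and would normally come from the
# training YAML via instantiate_from_config):
#
#     loss_fn = StandardDiffusionLoss(
#         sigma_sampler_config={
#             "target": "sgm.modules.diffusionmodules.sigma_sampling.EDMSampling",
#             "params": {"p_mean": -1.2, "p_std": 1.2},
#         },
#         type="l2",
#         offset_noise_level=0.0,
#     )
#     per_sample_loss = loss_fn(network, denoiser, conditioner, latents, batch)  # shape (B,)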


class VideoDiffusionLoss(StandardDiffusionLoss):
    def __init__(self, block_scale=None, block_size=None, min_snr_value=None, fixed_frames=0, **kwargs):
        self.fixed_frames = fixed_frames
        self.block_scale = block_scale
        self.block_size = block_size
        self.min_snr_value = min_snr_value
        super().__init__(**kwargs)

    def __call__(self, network, denoiser, conditioner, input, batch):
        cond = conditioner(batch)
        additional_model_inputs = {key: batch[key] for key in self.batch2model_keys.intersection(batch)}

        alphas_cumprod_sqrt, idx = self.sigma_sampler(input.shape[0], return_idx=True)
        alphas_cumprod_sqrt = alphas_cumprod_sqrt.to(input.device)
        idx = idx.to(input.device)

        noise = torch.randn_like(input)

        # Broadcast the sampled timestep indices, noise, and noise levels from the
        # first rank of each model-parallel group so all ranks use identical values.
        mp_size = mpu.get_model_parallel_world_size()
        global_rank = torch.distributed.get_rank() // mp_size
        src = global_rank * mp_size
        torch.distributed.broadcast(idx, src=src, group=mpu.get_model_parallel_group())
        torch.distributed.broadcast(noise, src=src, group=mpu.get_model_parallel_group())
        torch.distributed.broadcast(alphas_cumprod_sqrt, src=src, group=mpu.get_model_parallel_group())

        additional_model_inputs["idx"] = idx

        if self.offset_noise_level > 0.0:
            noise = (
                noise + append_dims(torch.randn(input.shape[0]).to(input.device), input.ndim) * self.offset_noise_level
            )

        # VP forward process: x_t = sqrt(alpha_bar) * x_0 + sqrt(1 - alpha_bar) * noise.
        noised_input = input.float() * append_dims(alphas_cumprod_sqrt, input.ndim) + noise * append_dims(
            (1 - alphas_cumprod_sqrt**2) ** 0.5, input.ndim
        )

        model_output = denoiser(network, noised_input, alphas_cumprod_sqrt, cond, **additional_model_inputs)
        w = append_dims(1 / (1 - alphas_cumprod_sqrt**2), input.ndim)  # v-pred

        if self.min_snr_value is not None:
            # Cap the per-sample weight at min_snr_value (min-SNR weighting);
            # the built-in min() is ambiguous for multi-element tensors.
            w = torch.clamp(w, max=self.min_snr_value)
        return self.get_loss(model_output, input, w)

    def get_loss(self, model_output, target, w):
        if self.type == "l2":
            return torch.mean((w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1)
        elif self.type == "l1":
            return torch.mean((w * (model_output - target).abs()).reshape(target.shape[0], -1), 1)
        elif self.type == "lpips":
            loss = self.lpips(model_output, target).reshape(-1)
            return loss


def get_3d_position_ids(frame_len, h, w):
    # One (frame, row, column) index triple per position of the 3D latent grid,
    # flattened to shape (frame_len * h * w, 3).
    i = torch.arange(frame_len).view(frame_len, 1, 1).expand(frame_len, h, w)
    j = torch.arange(h).view(1, h, 1).expand(frame_len, h, w)
    k = torch.arange(w).view(1, 1, w).expand(frame_len, h, w)
    position_ids = torch.stack([i, j, k], dim=-1).reshape(-1, 3)
    return position_ids
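

# Illustrative smoke test (not part of the original training code). It assumes the
# surrounding package and its dependencies are importable, e.g. run via `python -m`;
# all shapes and values below are arbitrary examples.
if __name__ == "__main__":
    # One (frame, row, column) triple per token of a 4-frame 8x8 latent grid.
    pos = get_3d_position_ids(4, 8, 8)
    assert pos.shape == (4 * 8 * 8, 3)
    assert pos[0].tolist() == [0, 0, 0] and pos[-1].tolist() == [3, 7, 7]

    # v-prediction weighting used by VideoDiffusionLoss: w = 1 / (1 - alpha_bar).
    alphas_cumprod_sqrt = torch.tensor([0.9, 0.5])
    weight = append_dims(1 / (1 - alphas_cumprod_sqrt**2), 4)  # -> shape (2, 1, 1, 1)
    model_output = torch.randn(2, 3, 4, 4)
    target = torch.randn(2, 3, 4, 4)
    per_sample_l2 = torch.mean((weight * (model_output - target) ** 2).reshape(2, -1), 1)
    print(pos.shape, per_sample_l2.shape)  # torch.Size([256, 3]) torch.Size([2])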