|
from typing import List, Optional, Tuple, Union |
|
import torch |
|
|
|
|
|
from diffusers.utils.torch_utils import randn_tensor |
|
|
|
|
|
def random_noise(
    tensor: torch.Tensor = None,
    shape: Tuple[int] = None,
    dtype: torch.dtype = None,
    device: torch.device = None,
    generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
    noise_offset: Optional[float] = None,
) -> torch.Tensor:
    """Sample standard Gaussian noise, optionally with a per-sample channel offset.

    Either ``tensor`` or ``shape`` must be provided. When ``tensor`` is given,
    its shape/device/dtype override the explicit ``shape``/``device``/``dtype``
    arguments.

    Args:
        tensor: Reference tensor whose shape, device and dtype are reused.
        shape: Target shape when no reference tensor is given.
        dtype: Target dtype when no reference tensor is given.
        device: Target device (a ``torch.device`` or a device string).
        generator: Optional generator(s) forwarded to ``randn_tensor`` for
            reproducible sampling.
        noise_offset: If set, adds ``noise_offset`` times a per-(batch, channel)
            constant Gaussian draw, shifting each channel's mean.

    Returns:
        A noise tensor of the requested shape, dtype and device.
    """
    if tensor is not None:
        shape = tensor.shape
        device = tensor.device
        dtype = tensor.dtype
    if isinstance(device, str):
        device = torch.device(device)
    noise = randn_tensor(shape, dtype=dtype, device=device, generator=generator)
    if noise_offset is not None:
        # One scalar per (batch, channel), broadcast over the remaining dims.
        # Fixes the original call, which passed `device` positionally (a
        # TypeError) and read `tensor.shape` even when only `shape` was given.
        offset_shape = (shape[0], shape[1]) + (1,) * (len(shape) - 2)
        noise += noise_offset * torch.randn(offset_shape, device=device, dtype=dtype)
    return noise
|
|
|
|
|
def video_fusion_noise(
    tensor: torch.Tensor = None,
    shape: Tuple[int] = None,
    dtype: torch.dtype = None,
    device: torch.device = None,
    w_ind_noise: float = 0.5,
    generator: Optional[Union[List[torch.Generator], torch.Generator]] = None,
    initial_common_noise: torch.Tensor = None,
) -> torch.Tensor:
    """Blend a shared per-video noise with per-frame independent noise.

    Produces ``sqrt(1 - w) * common + sqrt(w) * independent`` where ``common``
    is one noise slice broadcast across the temporal axis and ``independent``
    is fresh noise per frame (``w`` is ``w_ind_noise``). Expects a 5-D layout
    of ``(batch, channels, frames, height, width)``. Either ``tensor`` or
    ``shape`` must be provided; ``tensor`` wins for shape/device/dtype.

    Args:
        tensor: Reference tensor supplying shape, device and dtype.
        shape: Explicit 5-D target shape when no reference tensor is given.
        dtype: Target dtype when no reference tensor is given.
        device: Target device (``torch.device`` or device string).
        w_ind_noise: Weight of the independent (per-frame) component.
        generator: A single generator, or one generator per batch element.
        initial_common_noise: Optional precomputed shared-noise tensor; used
            in place of sampling the common component.

    Returns:
        The fused noise tensor with the requested shape.

    Raises:
        ValueError: If a generator list's length differs from the batch size.
    """
    if tensor is not None:
        shape = tensor.shape
        device = tensor.device
        dtype = tensor.dtype
    if isinstance(device, str):
        device = torch.device(device)

    batch_size, c, t, h, w = shape

    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if isinstance(generator, list):
        # One generator per sample: recurse with batch size 1, then stack.
        per_sample = [
            video_fusion_noise(
                shape=(1, c, t, h, w),
                dtype=dtype,
                device=device,
                w_ind_noise=w_ind_noise,
                generator=generator[idx],
                initial_common_noise=initial_common_noise,
            )
            for idx in range(batch_size)
        ]
        return torch.cat(per_sample, dim=0).to(device)

    # Shared component: a single temporal slice, broadcast over frames.
    if initial_common_noise is not None:
        common_noise = initial_common_noise.to(device, dtype=dtype)
    else:
        common_noise = randn_tensor(
            (batch_size, c, 1, h, w),
            generator=generator,
            device=device,
            dtype=dtype,
        )

    # Independent component: fresh noise for every frame.
    ind_noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

    weight = torch.tensor(w_ind_noise, device=device, dtype=dtype)
    return torch.sqrt(1 - weight) * common_noise + torch.sqrt(weight) * ind_noise
|
|