|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
from torch import nn |
|
|
|
from ...configuration_utils import ConfigMixin, register_to_config |
|
from ...models.modeling_utils import ModelMixin |
|
|
|
|
|
class StableUnCLIPImageNormalizer(ModelMixin, ConfigMixin):
    """
    Container for the mean and standard deviation of the CLIP embedder used in stable unCLIP.

    The statistics are stored as two learnable-parameter tensors of shape `(1, embedding_dim)`: `mean` starts out as
    all zeros and `std` as all ones. `scale` normalizes image embeddings with them (done before noise is applied) and
    `unscale` reverses that normalization on the noised image embeddings.

    Args:
        embedding_dim (`int`, defaults to 768):
            Size of the second dimension of the stored `mean` and `std` tensors.
    """

    @register_to_config
    def __init__(
        self,
        embedding_dim: int = 768,
    ):
        super().__init__()

        # Both statistics share one shape; `mean` is registered before `std`.
        stats_shape = (1, embedding_dim)
        self.mean = nn.Parameter(torch.zeros(stats_shape))
        self.std = nn.Parameter(torch.ones(stats_shape))

    def scale(self, embeds):
        """
        Normalize `embeds`: subtract `self.mean`, then divide by `self.std`.

        NOTE(review): `embeds` is presumably a tensor whose last dimension equals `embedding_dim`, so that it
        broadcasts against the `(1, embedding_dim)` statistics -- confirm against callers.
        """
        centered = embeds - self.mean
        return centered * 1.0 / self.std

    def unscale(self, embeds):
        """
        Invert `scale`: multiply `embeds` by `self.std`, then add `self.mean`.
        """
        rescaled = embeds * self.std
        return rescaled + self.mean
|
|