Spaces:
Runtime error
Runtime error
File size: 2,650 Bytes
e276be2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import logging
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, Union
from functools import partial
import math
import torch
from einops import rearrange, repeat
from ...util import append_dims, default, instantiate_from_config
class Guider(ABC):
    """Abstract base class for guiders.

    A guider is a callable taking a tensor ``x`` and a ``sigma`` and returning
    a tensor. Subclasses must implement ``__call__``; ``prepare_inputs`` is an
    optional hook that subclasses may override.
    """

    @abstractmethod
    def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
        """Return the guided result for ``x`` at ``sigma``.

        Args:
            x: Input tensor.
            sigma: Value associated with the current step.

        Returns:
            The guided tensor.
        """

    def prepare_inputs(
        self,
        x: torch.Tensor,
        s: Union[float, torch.Tensor],
        c: Dict,
        uc: Dict,
    ) -> Tuple[torch.Tensor, Union[float, torch.Tensor], Dict]:
        """Build the ``(x, s, conditioning)`` triple used with this guider.

        The base implementation is a placeholder: it does nothing and returns
        ``None``. Overrides are expected to return the annotated 3-tuple.

        Args:
            x: Input tensor.
            s: Per-sample value; the guiders in this module treat it as a
                tensor, hence the widened annotation.
            c: Conditioning dictionary.
            uc: Second conditioning dictionary with the same keys as ``c``
                (presumably the unconditional branch — confirm with callers).
        """
class VanillaCFG:
    """
    implements parallelized CFG

    Both guidance branches travel in one batch: ``prepare_inputs`` stacks the
    ``uc`` entries in front of the ``c`` entries along dim 0, and ``__call__``
    splits the result back into those two halves before combining them.
    """

    def __init__(self, scale, dyn_thresh_config=None):
        """Store the guidance scale and build the thresholding callable.

        Args:
            scale: Guidance scale returned by the (constant) scale schedule.
            dyn_thresh_config: Optional config passed to
                ``instantiate_from_config``; when it is ``None`` the
                ``NoDynamicThresholding`` target is used instead.
        """
        self.scale = scale
        scale_schedule = lambda scale, sigma: scale  # independent of step
        self.scale_schedule = partial(scale_schedule, scale)
        self.dyn_thresh = instantiate_from_config(
            default(
                dyn_thresh_config,
                {"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"},
            )
        )

    def __call__(self, x, sigma, scale=None):
        """Combine the two halves of ``x`` with the thresholding callable.

        Args:
            x: Tensor whose first half along dim 0 is the ``uc`` branch and
                whose second half is the ``c`` branch (the order produced by
                ``prepare_inputs``).
            sigma: Passed to the scale schedule, which ignores it here.
            scale: Optional override for the scheduled scale (assumes
                ``default`` returns its first argument unless it is ``None``
                — confirm against the project helper).

        Returns:
            Whatever ``self.dyn_thresh(x_u, x_c, scale_value)`` returns.
        """
        x_u, x_c = x.chunk(2)
        scale_value = default(scale, self.scale_schedule(sigma))
        return self.dyn_thresh(x_u, x_c, scale_value)

    def prepare_inputs(self, x, s, c, uc):
        """Duplicate ``x`` and ``s`` and merge the two conditioning dicts.

        Args:
            x: Tensor to duplicate along dim 0.
            s: Tensor to duplicate along dim 0.
            c: Conditioning dict; its keys drive the merge.
            uc: Dict with the same keys as ``c``.

        Returns:
            ``(x_doubled, s_doubled, c_out)``. For the keys ``"vector"``,
            ``"crossattn"`` and ``"concat"``, ``c_out`` holds ``uc[k]``
            followed by ``c[k]`` along dim 0; every other key is passed
            through from ``c`` after checking it matches ``uc``.

        Raises:
            AssertionError: If a pass-through key differs between ``c`` and
                ``uc``.
            KeyError: If ``uc`` lacks a key present in ``c``.
        """
        c_out = {}
        for k in c:
            if k in ("vector", "crossattn", "concat"):
                c_out[k] = torch.cat((uc[k], c[k]), 0)
                continue
            same = c[k] == uc[k]
            if isinstance(same, torch.Tensor):
                # An element-wise comparison has no single truth value for
                # multi-element tensors; reduce it explicitly.
                same = bool(same.all())
            if not same:
                # Raised explicitly (instead of `assert`) so the check still
                # runs under `python -O`; the exception type is unchanged.
                raise AssertionError(f"conditioning entry {k!r} differs between c and uc")
            c_out[k] = c[k]
        return torch.cat([x] * 2), torch.cat([s] * 2), c_out
class DynamicCFG(VanillaCFG):
    """CFG variant whose guidance scale depends on the step index.

    The scale follows ``1 + scale * (1 - cos(pi * (step / num_steps) ** exp)) / 2``,
    which is ``1`` at step 0 and ``1 + scale`` at ``step == num_steps``.
    """

    def __init__(self, scale, exp, num_steps, dyn_thresh_config=None):
        """Set up the step-dependent scale schedule.

        Args:
            scale: Amplitude of the schedule (the peak scale is ``1 + scale``).
            exp: Exponent applied to the normalized step ``step / num_steps``.
            num_steps: Step count used to normalize the step index.
            dyn_thresh_config: Forwarded to ``VanillaCFG.__init__``.
        """
        # The parent constructor stores `scale` and already builds
        # `self.dyn_thresh` from `dyn_thresh_config`, so it is not rebuilt here.
        super().__init__(scale, dyn_thresh_config)

        def scale_schedule(scale, sigma, step_index):
            # `sigma` is accepted for signature parity but does not enter the formula.
            return 1 + scale * (1 - math.cos(math.pi * (step_index / num_steps) ** exp)) / 2

        self.scale_schedule = partial(scale_schedule, scale)

    def __call__(self, x, sigma, step_index, scale=None):
        """Combine the two halves of ``x`` using the scale for ``step_index``.

        Args:
            x: Tensor split into two halves along dim 0 (first half ``x_u``,
                second half ``x_c``).
            sigma: Forwarded to the scale schedule, which ignores it.
            step_index: Current step, either a one-element tensor (anything
                exposing ``.item()``) or a plain number.
            scale: Accepted for call-signature compatibility; it is not used,
                the schedule alone determines the scale.

        Returns:
            Whatever ``self.dyn_thresh(x_u, x_c, scale_value)`` returns.
        """
        x_u, x_c = x.chunk(2)
        step = step_index.item() if hasattr(step_index, "item") else step_index
        scale_value = self.scale_schedule(sigma, step)
        return self.dyn_thresh(x_u, x_c, scale_value)
class IdentityGuider:
    """Guider that applies no guidance and returns its input unchanged."""

    def __call__(self, x, sigma, scale=None):
        """Return ``x`` as-is.

        Args:
            x: Value to pass through.
            sigma: Unused.
            scale: Unused; accepted so this guider can be called with the same
                ``scale=`` keyword that ``VanillaCFG.__call__`` takes.
        """
        return x

    def prepare_inputs(self, x, s, c, uc):
        """Return ``x`` and ``s`` untouched plus a shallow copy of ``c``.

        Args:
            x: Passed through unchanged.
            s: Passed through unchanged.
            c: Conditioning mapping; its entries are copied into a new dict
                (values are shared, not cloned).
            uc: Unused.

        Returns:
            ``(x, s, c_out)`` where ``c_out`` is a new ``dict`` with the keys
            and values of ``c``.
        """
        c_out = {key: c[key] for key in c}
        return x, s, c_out
|