ai-toolkit / toolkit /samplers /mean_flow_scheduler.py
jbilcke-hf's picture
jbilcke-hf HF Staff
Upload 430 files
3cc1e25 verified
raw
history blame
3.04 kB
from typing import Union
from diffusers import FlowMatchEulerDiscreteScheduler
import torch
from toolkit.timestep_weighing.default_weighing_scheme import default_weighing_scheme
from dataclasses import dataclass
from typing import Optional, Tuple
from diffusers.utils import BaseOutput
@dataclass
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's `step` function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
"""
prev_sample: torch.FloatTensor
class MeanFlowScheduler(FlowMatchEulerDiscreteScheduler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.init_noise_sigma = 1.0
self.timestep_type = "linear"
with torch.no_grad():
# create weights for timesteps
num_timesteps = 1000
# Create linear timesteps from 1000 to 1
timesteps = torch.linspace(1000, 1, num_timesteps, device="cpu")
self.linear_timesteps = timesteps
pass
def get_weights_for_timesteps(
self, timesteps: torch.Tensor, v2=False, timestep_type="linear"
) -> torch.Tensor:
# Get the indices of the timesteps
step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps]
weights = 1.0
# Get the weights for the timesteps
if timestep_type == "weighted":
weights = torch.tensor(
[default_weighing_scheme[i] for i in step_indices],
device=timesteps.device,
dtype=timesteps.dtype,
)
return weights
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
t_01 = (timesteps / 1000).to(original_samples.device)
noisy_model_input = (1.0 - t_01) * original_samples + t_01 * noise
return noisy_model_input
def scale_model_input(
self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]
) -> torch.Tensor:
return sample
def set_train_timesteps(self, num_timesteps, device, **kwargs):
timesteps = torch.linspace(1000, 1, num_timesteps, device=device)
self.timesteps = timesteps
return timesteps
def step(
self,
model_output: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
sample: torch.FloatTensor,
return_dict: bool = True,
**kwargs: Optional[dict],
) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
# single euler step (Eq. 5 ⇒ x₀ = x₁ − uθ)
output = sample - model_output
if not return_dict:
return (output,)
return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=output)