Fabrice-TIERCELIN committed on
Commit
fc6764a
·
verified ·
1 Parent(s): 9ecbf68

z scheduling_flow_match_discrete.py

Browse files
hyvideo/diffusion/schedulers/scheduling_flow_match_discrete.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ #
16
+ # Modified from diffusers==0.29.2
17
+ #
18
+ # ==============================================================================
19
+
20
+ from dataclasses import dataclass
21
+ from typing import Optional, Tuple, Union
22
+
23
+ import numpy as np
24
+ import torch
25
+
26
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
27
+ from diffusers.utils import BaseOutput, logging
28
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
29
+
30
+
31
# Module-level logger, obtained through the `diffusers.utils.logging` helper imported above.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
32
+
33
+
34
@dataclass
class FlowMatchDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """

    prev_sample: torch.FloatTensor
46
+
47
+
48
class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Euler scheduler operating on a discrete sigma grid.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        shift (`float`, defaults to 1.0):
            The shift value for the timestep schedule (see `sd3_time_shift`).
        reverse (`bool`, defaults to `True`):
            Whether to reverse the timestep schedule. When `True` the sigmas run from 1 down to 0; when `False`
            they run from 0 up to 1.
        solver (`str`, defaults to `"euler"`):
            Name of the integration scheme. Only `"euler"` is supported.
        n_tokens (`int`, *optional*):
            Number of tokens in the input sequence. Registered in the config but not otherwise used by this class.
    """

    _compatibles = []
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
        reverse: bool = True,
        solver: str = "euler",
        n_tokens: Optional[int] = None,
    ):
        # Default (unshifted) schedule with one sigma per training step plus the terminal value.
        sigmas = torch.linspace(1, 0, num_train_timesteps + 1)

        if not reverse:
            sigmas = sigmas.flip(0)

        self.sigmas = sigmas
        # the value fed to model
        self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32)

        self._step_index = None
        self._begin_index = None

        self.supported_solver = ["euler"]
        if solver not in self.supported_solver:
            raise ValueError(
                f"Solver {solver} not supported. Supported solvers: {self.supported_solver}"
            )

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def _sigma_to_t(self, sigma):
        """Convert a sigma into the timestep scale by multiplying with `num_train_timesteps`."""
        return sigma * self.config.num_train_timesteps

    def set_timesteps(
        self,
        num_inference_steps: int,
        device: Optional[Union[str, torch.device]] = None,
        n_tokens: Optional[int] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            n_tokens (`int`, *optional*):
                Number of tokens in the input sequence. Accepted for interface compatibility; not used here.
        """
        self.num_inference_steps = num_inference_steps

        # num_inference_steps + 1 sigmas so that every step has both a start and an end value.
        sigmas = torch.linspace(1, 0, num_inference_steps + 1)
        sigmas = self.sd3_time_shift(sigmas)

        if not self.config.reverse:
            sigmas = 1 - sigmas

        self.sigmas = sigmas
        # The terminal sigma is not a model input, so it is dropped from the timesteps.
        self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to(
            dtype=torch.float32, device=device
        )

        # Reset step index
        self._step_index = None

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """
        Return the position of `timestep` inside `schedule_timesteps` (defaults to `self.timesteps`).

        Raises:
            IndexError: If `timestep` does not occur in the schedule.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        if len(indices) == 0:
            # Same exception type the bare indexing below would raise, but with an actionable message.
            raise IndexError(
                f"Timestep {timestep} was not found in the scheduler's timesteps."
                " Make sure to pass one of the `scheduler.timesteps` as a timestep."
            )

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    def _init_step_index(self, timestep):
        """Initialise `_step_index` from `begin_index` when set, otherwise by looking up `timestep`."""
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                # Compare on the device the schedule lives on.
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def scale_model_input(
        self, sample: torch.Tensor, timestep: Optional[int] = None
    ) -> torch.Tensor:
        """Return `sample` unchanged; `timestep` is ignored."""
        return sample

    def sd3_time_shift(self, t: torch.Tensor):
        """Apply the configured shift: `shift * t / (1 + (shift - 1) * t)` (identity when `shift == 1`)."""
        return (self.config.shift * t) / (1 + (self.config.shift - 1) * t)

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        return_dict: bool = True,
    ) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]:
        """
        Advance `sample` by one Euler step along the sigma schedule using the model output.

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float` or `torch.FloatTensor`):
                The current discrete timestep in the diffusion chain. Must be one of `scheduler.timesteps`.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            return_dict (`bool`):
                Whether or not to return a [`FlowMatchDiscreteSchedulerOutput`] or tuple.

        Returns:
            [`FlowMatchDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`FlowMatchDiscreteSchedulerOutput`] is returned, otherwise a tuple is
                returned where the first element is the sample tensor.

        Raises:
            ValueError: If `timestep` is an integer (or integer tensor), or the configured solver is unsupported.
        """

        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `FlowMatchDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        # Signed sigma increment between the current and the next schedule point.
        dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]

        if self.config.solver == "euler":
            prev_sample = sample + model_output.to(torch.float32) * dt
        else:
            raise ValueError(
                f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}"
            )

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample)

    def __len__(self):
        """Return the configured number of training timesteps."""
        return self.config.num_train_timesteps