KrutikaBM committed on
Commit 852d224 · verified · 1 Parent(s): c590419

Create pipelines/pipeline_tuneavideo.py

tuneavideo/pipelines/pipeline_tuneavideo.py ADDED
@@ -0,0 +1,407 @@
+# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+
+import inspect
+from typing import Callable, List, Optional, Union
+from dataclasses import dataclass
+
+import numpy as np
+import torch
+
+from diffusers.utils import is_accelerate_available
+from packaging import version
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from diffusers.configuration_utils import FrozenDict
+from diffusers.models import AutoencoderKL
+from diffusers.pipeline_utils import DiffusionPipeline
+from diffusers.schedulers import (
+    DDIMScheduler,
+    DPMSolverMultistepScheduler,
+    EulerAncestralDiscreteScheduler,
+    EulerDiscreteScheduler,
+    LMSDiscreteScheduler,
+    PNDMScheduler,
+)
+from diffusers.utils import deprecate, logging, BaseOutput
+
+from einops import rearrange
+
+from ..models.unet import UNet3DConditionModel
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+@dataclass
+class TuneAVideoPipelineOutput(BaseOutput):
+    videos: Union[torch.Tensor, np.ndarray]
+
+
+class TuneAVideoPipeline(DiffusionPipeline):
+    _optional_components = []
+
+    def __init__(
+        self,
+        vae: AutoencoderKL,
+        text_encoder: CLIPTextModel,
+        tokenizer: CLIPTokenizer,
+        unet: UNet3DConditionModel,
+        scheduler: Union[
+            DDIMScheduler,
+            PNDMScheduler,
+            LMSDiscreteScheduler,
+            EulerDiscreteScheduler,
+            EulerAncestralDiscreteScheduler,
+            DPMSolverMultistepScheduler,
+        ],
+    ):
+        super().__init__()
+
+        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+                "to update the config accordingly, as leaving `steps_offset` might lead to incorrect results"
+                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+                " file"
+            )
+            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["steps_offset"] = 1
+            scheduler._internal_dict = FrozenDict(new_config)
+
+        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+            )
+            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["clip_sample"] = False
+            scheduler._internal_dict = FrozenDict(new_config)
+
+        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+            version.parse(unet.config._diffusers_version).base_version
+        ) < version.parse("0.9.0.dev0")
+        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+            deprecation_message = (
+                "The configuration file of the unet has set the default `sample_size` to smaller than"
+                " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+                " in the config might lead to incorrect results in future versions. If you have downloaded this"
+                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+                " the `unet/config.json` file"
+            )
+            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(unet.config)
+            new_config["sample_size"] = 64
+            unet._internal_dict = FrozenDict(new_config)
+
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            unet=unet,
+            scheduler=scheduler,
+        )
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+
+    def enable_vae_slicing(self):
+        self.vae.enable_slicing()
+
+    def disable_vae_slicing(self):
+        self.vae.disable_slicing()
+
+    def enable_sequential_cpu_offload(self, gpu_id=0):
+        if is_accelerate_available():
+            from accelerate import cpu_offload
+        else:
+            raise ImportError("Please install accelerate via `pip install accelerate`")
+
+        device = torch.device(f"cuda:{gpu_id}")
+
+        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
+            if cpu_offloaded_model is not None:
+                cpu_offload(cpu_offloaded_model, device)
+
+
+    @property
+    def _execution_device(self):
+        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+            return self.device
+        for module in self.unet.modules():
+            if (
+                hasattr(module, "_hf_hook")
+                and hasattr(module._hf_hook, "execution_device")
+                and module._hf_hook.execution_device is not None
+            ):
+                return torch.device(module._hf_hook.execution_device)
+        return self.device
+
+    def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
+        batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+        text_inputs = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=self.tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+            logger.warning(
+                "The following part of your input was truncated because CLIP can only handle sequences up to"
+                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+            )
+
+        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+            attention_mask = text_inputs.attention_mask.to(device)
+        else:
+            attention_mask = None
+
+        text_embeddings = self.text_encoder(
+            text_input_ids.to(device),
+            attention_mask=attention_mask,
+        )
+        text_embeddings = text_embeddings[0]
+
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        bs_embed, seq_len, _ = text_embeddings.shape
+        text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
+        text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
+
+        # get unconditional embeddings for classifier free guidance
+        if do_classifier_free_guidance:
+            uncond_tokens: List[str]
+            if negative_prompt is None:
+                uncond_tokens = [""] * batch_size
+            elif type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif isinstance(negative_prompt, str):
+                uncond_tokens = [negative_prompt]
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+            else:
+                uncond_tokens = negative_prompt
+
+            max_length = text_input_ids.shape[-1]
+            uncond_input = self.tokenizer(
+                uncond_tokens,
+                padding="max_length",
+                max_length=max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+
+            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+                attention_mask = uncond_input.attention_mask.to(device)
+            else:
+                attention_mask = None
+
+            uncond_embeddings = self.text_encoder(
+                uncond_input.input_ids.to(device),
+                attention_mask=attention_mask,
+            )
+            uncond_embeddings = uncond_embeddings[0]
+
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = uncond_embeddings.shape[1]
+            uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
+            uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+        return text_embeddings
+
+    def decode_latents(self, latents):
+        video_length = latents.shape[2]
+        latents = 1 / 0.18215 * latents
+        latents = rearrange(latents, "b c f h w -> (b f) c h w")
+        video = self.vae.decode(latents).sample
+        video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
+        video = (video / 2 + 0.5).clamp(0, 1)
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+        video = video.cpu().float().numpy()
+        return video
+
+    def prepare_extra_step_kwargs(self, generator, eta):
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        # check if the scheduler accepts generator
+        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        if accepts_generator:
+            extra_step_kwargs["generator"] = generator
+        return extra_step_kwargs
+
+    def check_inputs(self, prompt, height, width, callback_steps):
+        if not isinstance(prompt, str) and not isinstance(prompt, list):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+        if (callback_steps is None) or (
+            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+        ):
+            raise ValueError(
+                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+                f" {type(callback_steps)}."
+            )
+
+    def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
+        shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
+        if latents is None:
+            rand_device = "cpu" if device.type == "mps" else device
+
+            if isinstance(generator, list):
+                shape = (1,) + shape[1:]
+                latents = [
+                    torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
+                    for i in range(batch_size)
+                ]
+                latents = torch.cat(latents, dim=0).to(device)
+            else:
+                latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
+        else:
+            if latents.shape != shape:
+                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+            latents = latents.to(device)
+
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma
+        return latents
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        prompt: Union[str, List[str]],
+        video_length: Optional[int],
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 7.5,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        num_videos_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        output_type: Optional[str] = "tensor",
+        return_dict: bool = True,
+        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback_steps: Optional[int] = 1,
+        **kwargs,
+    ):
+        # Default height and width to unet
+        height = height or self.unet.config.sample_size * self.vae_scale_factor
+        width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+        # Check inputs. Raise error if not correct
+        self.check_inputs(prompt, height, width, callback_steps)
+
+        # Define call parameters
+        batch_size = 1 if isinstance(prompt, str) else len(prompt)
+        device = self._execution_device
+        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        # Encode input prompt
+        text_embeddings = self._encode_prompt(
+            prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
+        )
+
+        # Prepare timesteps
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps = self.scheduler.timesteps
+
+        # Prepare latent variables
+        num_channels_latents = self.unet.in_channels
+        latents = self.prepare_latents(
+            batch_size * num_videos_per_prompt,
+            num_channels_latents,
+            video_length,
+            height,
+            width,
+            text_embeddings.dtype,
+            device,
+            generator,
+            latents,
+        )
+        latents_dtype = latents.dtype
+
+        # Prepare extra step kwargs.
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # Denoising loop
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                # predict the noise residual
+                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample.to(dtype=latents_dtype)
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        callback(i, t, latents)
+
+        # Post-processing
+        video = self.decode_latents(latents)
+
+        # Convert to tensor
+        if output_type == "tensor":
+            video = torch.from_numpy(video)
+
+        if not return_dict:
+            return video
+
+        return TuneAVideoPipelineOutput(videos=video)
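
A minimal usage sketch of this pipeline (not part of the commit): it assumes base Stable Diffusion weights plus a fine-tuned UNet3DConditionModel checkpoint; the paths and the prompt below are hypothetical placeholders.

    # Usage sketch; `pretrained_model_path` and `finetuned_model_path` are hypothetical.
    import torch
    from tuneavideo.models.unet import UNet3DConditionModel
    from tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline

    pretrained_model_path = "path/to/stable-diffusion-v1-4"     # base Stable Diffusion weights
    finetuned_model_path = "path/to/tune-a-video-output"        # directory with a fine-tuned unet/ subfolder

    # Load the fine-tuned 3D UNet and build the pipeline around the base SD components.
    unet = UNet3DConditionModel.from_pretrained(finetuned_model_path, subfolder="unet", torch_dtype=torch.float16)
    pipe = TuneAVideoPipeline.from_pretrained(pretrained_model_path, unet=unet, torch_dtype=torch.float16).to("cuda")
    pipe.enable_vae_slicing()

    # `videos` is a (batch, channels, frames, height, width) tensor with values in [0, 1].
    video = pipe(
        "a man is skiing",
        video_length=8,
        height=512,
        width=512,
        num_inference_steps=50,
        guidance_scale=7.5,
    ).videos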