|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Optional, Tuple, Union |
|
|
|
import torch |
|
|
|
from ...configuration_utils import FrozenDict |
|
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput |
|
from ...utils import deprecate |
|
|
|
|
|
class DDPMPipeline(DiffusionPipeline):
    r"""
    Pipeline that generates images by starting from Gaussian noise and repeatedly applying `unet` and `scheduler`.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Parameters:
        unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
            [`DDPMScheduler`], or [`DDIMScheduler`].
    """

    def __init__(self, unet, scheduler):
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        generator: Optional[torch.Generator] = None,
        num_inference_steps: int = 1000,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        **kwargs,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        Args:
            batch_size (`int`, *optional*, defaults to 1):
                The number of images to generate.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic. If its device type differs from the pipeline device type (and the pipeline is not on
                `mps`), a deprecation warning is emitted and the generator is ignored.
            num_inference_steps (`int`, *optional*, defaults to 1000):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. `"pil"` returns a list of
                [PIL](https://pillow.readthedocs.io/en/stable/) `PIL.Image.Image` objects; any other value returns
                an `np.array` laid out as `(batch, height, width, channels)`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
            kwargs:
                Only the deprecated `predict_epsilon` key is read here. When given, it is written into the
                scheduler's config (this persists on the scheduler after the call) and forwarded to
                `scheduler.step`.

        Returns:
            [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipeline_utils.ImagePipelineOutput`] if
            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
            generated images.
        """
        message = (
            "Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
            " DDPMScheduler.from_pretrained(<model_id>, predict_epsilon=True)`."
        )
        predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)

        if predict_epsilon is not None:
            # Rebuild the (frozen) scheduler config with the caller's override.
            new_config = dict(self.scheduler.config)
            new_config["predict_epsilon"] = predict_epsilon
            self.scheduler._internal_dict = FrozenDict(new_config)

        if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
            message = (
                f"The `generator` device is `{generator.device}` and does not match the pipeline "
                f"device `{self.device}`, so the `generator` will be ignored. "
                f'Please use `torch.Generator(device="{self.device}")` instead.'
            )
            deprecate(
                "generator.device == 'cpu'",
                "0.11.0",
                message,
            )
            generator = None

        # Build the shape of the initial noise. An int `sample_size` means a square image; a
        # (height, width) sequence is unpacked so non-square models do not produce a nested shape.
        sample_size = self.unet.sample_size
        if isinstance(sample_size, int):
            spatial_dims = (sample_size, sample_size)
        else:
            spatial_dims = tuple(sample_size)
        image_shape = (batch_size, self.unet.in_channels, *spatial_dims)

        if self.device.type == "mps":
            # Sample on the default device first, then move the tensor over.
            # NOTE(review): presumably because seeded sampling directly on mps is unreliable - confirm.
            image = torch.randn(image_shape, generator=generator)
            image = image.to(self.device)
        else:
            image = torch.randn(image_shape, generator=generator, device=self.device)

        self.scheduler.set_timesteps(num_inference_steps)

        # Forward the deprecated flag only when the caller actually supplied it; otherwise every step
        # would receive a meaningless `predict_epsilon=None`. The value is already stored in the
        # scheduler config above.
        step_kwargs = {"generator": generator}
        if predict_epsilon is not None:
            step_kwargs["predict_epsilon"] = predict_epsilon

        for t in self.progress_bar(self.scheduler.timesteps):
            # 1. run the U-Net on the current sample at timestep t
            model_output = self.unet(image, t).sample

            # 2. let the scheduler compute the sample for the previous timestep
            image = self.scheduler.step(model_output, t, image, **step_kwargs).prev_sample

        # Rescale (x / 2 + 0.5), clamp to [0, 1], then go channels-first -> channels-last on CPU as numpy.
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
|
|