dinhdat1110 committed on
Commit 9457143 · verified · 1 Parent(s): a2b808f

Upload folder using huggingface_hub

.gitignore ADDED
@@ -0,0 +1,168 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+ *.jpeg
+ *.gz
+ cifar-10-batches-py
+ checkpoints
+ MNIST
+ *.ipynb
+ data
+ wandb
Dockerfile ADDED
@@ -0,0 +1,8 @@
+ FROM python:latest
+
+ COPY . /app
+ WORKDIR /app
+
+ RUN pip install .
+
+ CMD ["python", "-m", "diffusion.train"]
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 Võ Đình Đạt
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,12 +1,8 @@
  ---
- title: Diffusion Model
- emoji: 🦀
- colorFrom: blue
- colorTo: yellow
- sdk: gradio
- sdk_version: 4.19.0
+ title: diffusion-model
  app_file: app.py
- pinned: false
+ sdk: gradio
+ sdk_version: 4.18.0
  ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # latent-diffusion-model
+ Coming Soon!
app.py ADDED
@@ -0,0 +1,53 @@
+ import torch
+ import argparse
+ import gradio as gr
+ import diffusion
+ from torchvision import transforms
+
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--ckpt_path", type=str, default="./checkpoints/mnist.ckpt")
+ parser.add_argument("--map_location", type=str, default="cpu")
+ parser.add_argument("--share", action='store_true')
+ args = parser.parse_args()
+
+ if __name__ == "__main__":
+     model = diffusion.DiffusionModel.load_from_checkpoint(
+         args.ckpt_path, in_channels=1, map_location=args.map_location, num_classes=10
+     )
+     to_pil = transforms.ToPILImage()
+
+     def reset(image):
+         image = to_pil((torch.randn(1, 32, 32)*255).type(torch.uint8))
+         return image
+
+     def denoise(label):
+         labels = torch.tensor([label]).to(model.device)
+         for img in model.sampling_demo(labels=labels):
+             image = to_pil(img[0])
+             yield image
+
+     with gr.Blocks(theme=gr.themes.Soft(primary_hue="green")) as demo:
+         gr.Markdown("# Simple Diffusion Model")
+
+         gr.Markdown("## MNIST")
+         with gr.Row():
+             with gr.Column(scale=2):
+                 label = gr.Dropdown(
+                     label='Label',
+                     choices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
+                     value=0
+                 )
+                 with gr.Row():
+                     sample_btn = gr.Button("Sampling")
+                     reset_btn = gr.Button("Reset")
+             output = gr.Image(
+                 value=to_pil((torch.randn(1, 32, 32)*255).type(torch.uint8)),
+                 scale=2,
+                 image_mode="L",
+                 type='pil',
+             )
+         sample_btn.click(denoise, [label], outputs=output)
+         reset_btn.click(reset, [output], outputs=output)
+
+     demo.launch(share=args.share)
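
The demo above drives `DiffusionModel.sampling_demo` through Gradio. For reference, a minimal sketch of using the same checkpoint without the UI, assuming the default `./checkpoints/mnist.ckpt` path from the parser above and a CPU-only environment:

```python
import torch
import diffusion
from torchvision import transforms

# Load the Lightning checkpoint the same way app.py does.
model = diffusion.DiffusionModel.load_from_checkpoint(
    "./checkpoints/mnist.ckpt", in_channels=1, map_location="cpu", num_classes=10
)

# sampling() runs the full reverse process and returns uint8 images in [0, 255].
labels = torch.arange(10)  # one sample per MNIST class
images = model.sampling(labels=labels)
transforms.ToPILImage()(images[0]).save("sample_0.png")
```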
diffusion/README.md ADDED
File without changes
diffusion/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .model import *
+ from .dataset import *
+ from .utils import *
diffusion/dataset/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .cifar10 import *
+ from .mnist import *
+ from .celeba import *
diffusion/dataset/celeba.py ADDED
@@ -0,0 +1,72 @@
+ import pytorch_lightning as pl
+ import torch
+ import os
+ from PIL import Image
+ from torch.utils.data import DataLoader, Dataset, random_split
+ from torchvision import transforms
+ from functools import partial
+
+
+ class CelebADataset(Dataset):
+     def __init__(
+         self,
+         data_dir: str,
+     ):
+         self.list_path = os.listdir(data_dir)
+         self.data_dir = data_dir
+         self.transform = transforms.Compose(
+             [
+                 transforms.Resize((64, 64)),
+                 transforms.ToTensor(),
+                 transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+             ]
+         )
+
+     def __len__(self):
+         return len(self.list_path)
+
+     def __getitem__(self, index):
+         img = Image.open(os.path.join(self.data_dir, self.list_path[index]))
+         return self.transform(img)
+
+
+ class CelebADataModule(pl.LightningDataModule):
+     def __init__(
+         self,
+         data_dir: str = "./",
+         batch_size: int = 32,
+         num_workers: int = 0,
+         seed: int = 42,
+         train_ratio: float = 0.99
+     ):
+         super().__init__()
+         self.data_dir = data_dir
+         self.batch_size = batch_size
+         self.num_workers = num_workers
+         self.train_ratio = min(train_ratio, 0.99)
+         self.seed = seed
+
+         self.loader = partial(
+             DataLoader,
+             batch_size=self.batch_size,
+             pin_memory=True,
+             num_workers=self.num_workers,
+             persistent_workers=self.num_workers > 0  # persistent workers require num_workers > 0
+         )
+
+     def setup(self, stage: str):
+         if stage == "fit":
+             dataset = CelebADataset(self.data_dir)
+             self.CelebA_train, self.CelebA_val, _ = random_split(
+                 dataset=dataset,
+                 lengths=[self.train_ratio, 0.01, 1 - 0.01 - self.train_ratio],
+                 generator=torch.Generator().manual_seed(self.seed)
+             )
+         else:
+             pass
+
+     def train_dataloader(self):
+         return self.loader(dataset=self.CelebA_train)
+
+     def val_dataloader(self):
+         return self.loader(dataset=self.CelebA_val)
diffusion/dataset/cifar10.py ADDED
@@ -0,0 +1,73 @@
+ import pytorch_lightning as pl
+ import torch
+ from torchvision.datasets import CIFAR10
+ from torch.utils.data import DataLoader, random_split
+ from torchvision import transforms
+ from functools import partial
+
+
+ class CIFAR10DataModule(pl.LightningDataModule):
+     def __init__(
+         self,
+         data_dir: str = "./",
+         batch_size: int = 32,
+         num_workers: int = 0,
+         seed: int = 42,
+         train_ratio: float = 0.99
+     ):
+         super().__init__()
+         self.data_dir = data_dir
+         self.batch_size = batch_size
+         self.num_workers = num_workers
+         self.seed = seed
+         self.train_ratio = min(train_ratio, 0.99)
+         self.transform = transforms.Compose(
+             [
+                 transforms.Resize((32, 32)),
+                 transforms.ToTensor(),
+                 transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+             ]
+         )
+         self.loader = partial(
+             DataLoader,
+             batch_size=self.batch_size,
+             pin_memory=True,
+             num_workers=self.num_workers,
+             persistent_workers=self.num_workers > 0  # persistent workers require num_workers > 0
+         )
+
+     def setup(self, stage: str):
+         cifar_partial = partial(
+             CIFAR10,
+             root=self.data_dir, transform=self.transform, download=True
+         )
+         if stage == "fit":
+             retrying = True
+             while retrying:
+                 try:
+                     cifar_full = cifar_partial(train=True)
+                     retrying = False
+                 except Exception:  # retry on transient download errors
+                     pass
+             self.cifar_train, self.cifar_val, _ = random_split(
+                 dataset=cifar_full,
+                 lengths=[self.train_ratio, 0.01, 1 - 0.01 - self.train_ratio],
+                 generator=torch.Generator().manual_seed(self.seed)
+             )
+         else:
+             retrying = True
+             while retrying:
+                 try:
+                     self.cifar_test = cifar_partial(train=False)
+                     retrying = False
+                 except Exception:  # retry on transient download errors
+                     pass
+
+     def train_dataloader(self):
+         return self.loader(dataset=self.cifar_train)
+
+     def val_dataloader(self):
+         return self.loader(dataset=self.cifar_val)
+
+     def test_dataloader(self):
+         return self.loader(dataset=self.cifar_test)
diffusion/dataset/mnist.py ADDED
@@ -0,0 +1,73 @@
+ import pytorch_lightning as pl
+ import torch
+ from torchvision.datasets import MNIST
+ from torch.utils.data import DataLoader, random_split
+ from torchvision import transforms
+ from functools import partial
+
+
+ class MNISTDataModule(pl.LightningDataModule):
+     def __init__(
+         self,
+         data_dir: str = "./",
+         batch_size: int = 32,
+         num_workers: int = 0,
+         seed: int = 42,
+         train_ratio: float = 0.99
+     ):
+         super().__init__()
+         self.data_dir = data_dir
+         self.batch_size = batch_size
+         self.num_workers = num_workers
+         self.train_ratio = min(train_ratio, 0.99)
+         self.seed = seed
+         self.transform = transforms.Compose(
+             [
+                 transforms.Resize((32, 32)),
+                 transforms.ToTensor(),
+                 transforms.Normalize(mean=(0.5,), std=(0.5,))
+             ]
+         )
+         self.loader = partial(
+             DataLoader,
+             batch_size=self.batch_size,
+             pin_memory=True,
+             num_workers=self.num_workers,
+             persistent_workers=self.num_workers > 0  # persistent workers require num_workers > 0
+         )
+
+     def setup(self, stage: str):
+         mnist_partial = partial(
+             MNIST,
+             root=self.data_dir, transform=self.transform, download=True
+         )
+         if stage == "fit":
+             retrying = True
+             while retrying:
+                 try:
+                     mnist_full = mnist_partial(train=True)
+                     retrying = False
+                 except Exception:  # retry on transient download errors
+                     pass
+             self.mnist_train, self.mnist_val, _ = random_split(
+                 dataset=mnist_full,
+                 lengths=[self.train_ratio, 0.01, 1 - 0.01 - self.train_ratio],
+                 generator=torch.Generator().manual_seed(self.seed)
+             )
+         else:
+             retrying = True
+             while retrying:
+                 try:
+                     self.mnist_test = mnist_partial(train=False)
+                     retrying = False
+                 except Exception:  # retry on transient download errors
+                     pass
+
+     def train_dataloader(self):
+         return self.loader(dataset=self.mnist_train)
+
+     def val_dataloader(self):
+         return self.loader(dataset=self.mnist_val)
+
+     def test_dataloader(self):
+         return self.loader(dataset=self.mnist_test)
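
The three datamodules share the same interface, so a quick smoke test outside the `Trainer` looks the same for each of them; a sketch for MNIST (paths and sizes are illustrative):

```python
import diffusion

# Downloads MNIST into ./data/ on first use, then yields normalized 32x32 batches.
dm = diffusion.MNISTDataModule(data_dir="./data/", batch_size=32, num_workers=4)
dm.setup("fit")
x, y = next(iter(dm.train_dataloader()))
print(x.shape, y.shape)  # expected: torch.Size([32, 1, 32, 32]) torch.Size([32])
```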
diffusion/model/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .diffusion import *
+ from .ldm import *
diffusion/model/diffusion/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .unet import *
+ from .model import *
+ from .sampling import *
+ from .scheduler import *
diffusion/model/diffusion/model.py ADDED
@@ -0,0 +1,188 @@
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ import pytorch_lightning as pl
+ import diffusion
+ import wandb
+ from torchvision.utils import make_grid
+ from torch.optim.lr_scheduler import OneCycleLR
+
+
+ class DiffusionModel(pl.LightningModule):
+     def __init__(
+         self,
+         lr: float = 1e-4,
+         max_timesteps: int = 1000,
+         beta_1: float = 0.0001,
+         beta_2: float = 0.02,
+         in_channels: int = 3,
+         dim: int = 32,
+         num_classes: int | None = 10,
+         sample_per_epochs: int = 50,
+         n_samples: int = 16
+     ):
+         super().__init__()
+         self.save_hyperparameters()
+
+         self.model = diffusion.ConditionalUNet(
+             c_in=in_channels,
+             c_out=in_channels,
+             num_classes=num_classes
+         )
+         self.lr = lr
+         self.max_timesteps = max_timesteps
+         self.in_channels = in_channels
+         self.dim = dim
+         self.num_classes = num_classes
+
+         self.scheduler = diffusion.LinearScheduler(
+             max_timesteps, beta_1, beta_2
+         )
+
+         self.criterion = nn.MSELoss()
+
+         self.spe = sample_per_epochs
+         self.n_samples = n_samples
+         self.epoch_count = 0
+         self.train_loss = []
+         self.val_loss = []
+
+         self.sampling_kwargs = {
+             'model': self.model,
+             'scheduler': self.scheduler,
+             'max_timesteps': self.max_timesteps,
+             'in_channels': self.in_channels,
+             'dim': self.dim,
+         }
+
+     def _batch_index_select(
+         self,
+         x: torch.Tensor,
+         t: torch.Tensor,
+         device: torch.device
+     ):
+         # x.shape = [T,]
+         # t.shape = [B,]
+         if x.device != device:
+             x = x.to(device)
+         if t.device != device:
+             t = t.to(device)
+         x_select = x.gather(dim=-1, index=t)
+         return x_select[:, None, None, None]  # [B,1,1,1]
+
+     def noising(
+         self,
+         x_0: torch.Tensor,
+         t: torch.Tensor
+     ):
+         noise = torch.randn_like(x_0, device=x_0.device)
+         new_x = self.scheduler.get('sqrt_alpha_hat', t) * x_0
+         new_noise = self.scheduler.get('sqrt_one_minus_alpha_hat', t) * noise
+         return new_x + new_noise, noise
+
+     def sampling(self, labels=None, n_samples: int = 16):
+         return diffusion.ddpm_sampling(
+             n_samples=n_samples,
+             labels=labels,
+             **self.sampling_kwargs
+         )
+
+     def sampling_demo(self, labels=None, n_samples: int = 16):
+         return diffusion.ddpm_sampling_demo(
+             n_samples=n_samples,
+             labels=labels,
+             **self.sampling_kwargs
+         )
+
+     def forward(self, x_0, labels):
+         t = torch.randint(
+             low=0, high=self.max_timesteps, size=(x_0.shape[0],), device=x_0.device
+         )
+         x_noise, noise = self.noising(x_0, t)
+         noise_pred = self.model(x_noise, t, labels)
+         return noise, noise_pred
+
+     def training_step(self, batch, idx):
+         if isinstance(batch, torch.Tensor):
+             x_0 = batch
+             labels = None
+         else:
+             x_0, labels = batch
+             if np.random.random() < 0.1:
+                 labels = None
+         noise, noise_pred = self(x_0, labels)
+         loss = self.criterion(noise, noise_pred)
+         self.train_loss.append(loss)
+         return loss
+
+     def validation_step(self, batch, idx):
+         if isinstance(batch, torch.Tensor):
+             x_0 = batch
+             labels = None
+         else:
+             x_0, labels = batch
+         noise, noise_pred = self(x_0, labels)
+         loss = self.criterion(noise, noise_pred)
+         self.val_loss.append(loss)
+         return loss
+
+     def on_train_epoch_end(self) -> None:
+         self.log_dict(
+             {
+                 "train_loss": sum(self.train_loss) / len(self.train_loss)
+             },
+             sync_dist=True
+         )
+         self.train_loss.clear()
+
+         if self.epoch_count % self.spe == 0:
+             wandblog = self.logger.experiment
+             x_t = self.sampling(n_samples=self.n_samples)
+             img_array = [x_t[i] for i in range(x_t.shape[0])]
+
+             wandblog.log(
+                 {
+                     "sampling": wandb.Image(
+                         make_grid(img_array, nrow=4).permute(1, 2, 0).cpu().numpy(),
+                         caption="Sampled Image!"
+                     )
+                 }
+             )
+
+         self.epoch_count += 1
+
+     def on_validation_epoch_end(self):
+         self.log_dict(
+             {
+                 "val_loss": sum(self.val_loss) / len(self.val_loss)
+             },
+             sync_dist=True
+         )
+         self.val_loss.clear()
+
+     def configure_optimizers(self):
+         optimizer = torch.optim.AdamW(
+             params=self.parameters(),
+             lr=self.lr,
+             weight_decay=0.001,
+             betas=(0.9, 0.999)
+         )
+         scheduler = OneCycleLR(
+             optimizer=optimizer,
+             max_lr=self.lr,
+             total_steps=self.trainer.estimated_stepping_batches,
+         )
+         return {
+             'optimizer': optimizer,
+             'lr_scheduler': scheduler
+         }
+
+
+ if __name__ == "__main__":
+     a = torch.randn(32, 3, 32, 32)
+     model = DiffusionModel(max_timesteps=10)
+     n, n_pred = model(a, None)  # labels=None exercises the unconditional path
+     print(n.shape, n_pred.shape)
+     print(torch.mean((n-n_pred)**2))
+     print(model.sampling(n_samples=1))
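
`noising` is the closed-form DDPM forward process: with `sqrt_alpha_hat` and `sqrt_one_minus_alpha_hat` taken from `LinearScheduler`, it draws

```latex
x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\, \epsilon,
\qquad \epsilon \sim \mathcal{N}(0, I),
```

and the training objective is the MSE between $\epsilon$ and the network's prediction $\epsilon_\theta(x_t, t, y)$, with labels dropped 10% of the time so that classifier-free guidance can be used at sampling time.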
diffusion/model/diffusion/sampling.py ADDED
@@ -0,0 +1,82 @@
+ import torch
+
+
+ def ddpm_sampling_timestep(
+     x_t,
+     model,
+     scheduler,
+     labels,
+     t,
+     n_samples: int = 16,
+     cfg_scale: int = 3,
+ ):
+     time = torch.full((n_samples,), fill_value=t, device=model.device)
+     pred_noise = model(x_t, time, labels)
+     if cfg_scale > 0:
+         uncond_pred_noise = model(x_t, time, None)
+         pred_noise = torch.lerp(uncond_pred_noise, pred_noise, cfg_scale)
+     alpha = scheduler.get('alpha', time)
+     sqrt_alpha = scheduler.get('sqrt_alpha', time)
+     somah = scheduler.get('sqrt_one_minus_alpha_hat', time)
+     sqrt_beta = scheduler.get('sqrt_beta', time)
+     if t > 0:
+         noise = torch.randn_like(x_t, device=model.device)
+     else:
+         noise = torch.zeros_like(x_t, device=model.device)
+
+     x_t_new = 1 / sqrt_alpha * (x_t - (1-alpha) / somah * pred_noise) + sqrt_beta * noise
+     return x_t_new.clamp(-1, 1)
+
+
+ @torch.no_grad()
+ def ddpm_sampling(
+     model,
+     scheduler,
+     n_samples: int = 16,
+     max_timesteps: int = 1000,
+     in_channels: int = 3,
+     dim: int = 32,
+     cfg_scale: int = 3,
+     labels=None
+ ):
+     if labels is not None:
+         n_samples = labels.shape[0]
+
+     x_t = torch.randn(
+         n_samples, in_channels, dim, dim, device=model.device
+     )
+     model.eval()
+     for t in range(max_timesteps-1, -1, -1):
+         x_t = ddpm_sampling_timestep(x_t=x_t, model=model, scheduler=scheduler,
+                                      labels=labels, t=t, n_samples=n_samples,
+                                      cfg_scale=cfg_scale)
+
+     model.train()
+     x_t = (x_t + 1) / 2 * 255.  # range [0,255]
+     return x_t.type(torch.uint8)
+
+
+ @torch.no_grad()
+ def ddpm_sampling_demo(
+     model,
+     scheduler,
+     n_samples: int = 16,
+     max_timesteps: int = 1000,
+     in_channels: int = 3,
+     dim: int = 32,
+     cfg_scale: int = 3,
+     labels=None
+ ):
+     if labels is not None:
+         n_samples = labels.shape[0]
+
+     x_t = torch.randn(
+         n_samples, in_channels, dim, dim, device=model.device
+     )
+     model.eval()
+     for t in range(max_timesteps-1, -1, -1):
+         x_t = ddpm_sampling_timestep(x_t=x_t, model=model, scheduler=scheduler,
+                                      labels=labels, t=t, n_samples=n_samples,
+                                      cfg_scale=cfg_scale)
+         yield ((x_t + 1) / 2 * 255).type(torch.uint8)
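
`ddpm_sampling_timestep` implements the standard DDPM reverse update with classifier-free guidance. Since `torch.lerp(uncond, cond, w) = uncond + w * (cond - uncond)`, the guided noise estimate and the per-step update correspond to

```latex
\hat\epsilon = (1 - w)\,\epsilon_\theta(x_t, t) + w\,\epsilon_\theta(x_t, t, y),
\qquad
x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar\alpha_t}}\,\hat\epsilon\right) + \sqrt{\beta_t}\, z,
\quad z \sim \mathcal{N}(0, I)\ \text{for}\ t > 0,\ z = 0\ \text{at}\ t = 0,
```

with $w$ = `cfg_scale`; clamping each intermediate $x_{t-1}$ to $[-1, 1]$ is a choice specific to this implementation.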
diffusion/model/diffusion/scheduler.py ADDED
@@ -0,0 +1,20 @@
+ import torch
+
+
+ class LinearScheduler:
+     def __init__(
+         self,
+         max_timesteps: int = 1000,
+         beta_1: float = 0.0001,
+         beta_2: float = 0.02
+     ) -> None:
+         self.beta = torch.linspace(beta_1, beta_2, max_timesteps)
+         self.sqrt_beta = torch.sqrt(self.beta)[:, None, None, None]
+         self.alpha = (1 - self.beta)[:, None, None, None]
+         self.sqrt_alpha = torch.sqrt(self.alpha)
+         self.alpha_hat = torch.cumprod(1 - self.beta, dim=0)[:, None, None, None]
+         self.sqrt_alpha_hat = torch.sqrt(self.alpha_hat)
+         self.sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat)
+
+     def get(self, key: str, t: torch.Tensor):
+         return self.__dict__[key].to(t.device)[t]
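
A short sanity check of the precomputed schedule, assuming the defaults above (exact values depend on `beta_1` and `beta_2`):

```python
import torch
import diffusion

sched = diffusion.LinearScheduler(max_timesteps=1000, beta_1=1e-4, beta_2=0.02)
t = torch.tensor([0, 500, 999])
# alpha_t = 1 - beta_t stays close to 1 early in the schedule,
# while sqrt(alpha_hat_t) decays monotonically toward 0.
print(sched.get('alpha', t).flatten())
print(sched.get('sqrt_alpha_hat', t).flatten())
```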
diffusion/model/diffusion/unet.py ADDED
@@ -0,0 +1,227 @@
+
+ import torch
+ import torch.nn as nn
+ import pytorch_lightning as pl
+ from einops import rearrange, repeat
+
+
+ class SelfAttention(nn.Module):
+     def __init__(
+         self,
+         channels: int
+     ):
+         super(SelfAttention, self).__init__()
+         self.channels = channels
+         self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
+         self.ln = nn.LayerNorm([channels])
+         self.ff_self = nn.Sequential(
+             nn.LayerNorm([channels]),
+             nn.Linear(channels, channels),
+             nn.GELU(),
+             nn.Linear(channels, channels),
+         )
+
+     def forward(self, x):
+         B, C, H, W = x.shape
+         x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
+         x_ln = self.ln(x)
+         attention_value, _ = self.mha(x_ln, x_ln, x_ln)
+         attention_value = attention_value + x
+         attention_value = self.ff_self(attention_value) + attention_value
+         return rearrange(attention_value, 'b (h w) c -> b c h w', h=H, w=W).contiguous()
+
+
+ class DoubleConv(nn.Module):
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         mid_channels: int | None = None,
+         residual: bool = False
+     ):
+         super().__init__()
+         self.residual = residual
+         if not mid_channels:
+             mid_channels = out_channels
+         self.double_conv = nn.Sequential(
+             nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
+             nn.GroupNorm(8, mid_channels),
+             nn.GELU(),
+             nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
+             nn.GroupNorm(8, out_channels),
+         )
+
+     def forward(self, x):
+         if self.residual:
+             return (x + self.double_conv(x)) / 1.414
+         else:
+             return self.double_conv(x)
+
+
+ class DownSample(nn.Module):
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         emb_dim: int = 256
+     ):
+         super().__init__()
+         self.maxpool_conv = nn.Sequential(
+             nn.MaxPool2d(2),
+             DoubleConv(in_channels, in_channels, residual=True),
+             DoubleConv(in_channels, out_channels),
+         )
+
+         self.emb_layer = nn.Sequential(
+             nn.SiLU(),
+             nn.Linear(emb_dim, out_channels),
+         )
+
+     def forward(self, x, t):
+         x = self.maxpool_conv(x)
+         _, _, H, W = x.shape
+         emb = repeat(self.emb_layer(t), 'b d -> b d h w', h=H, w=W).contiguous()
+         return x + emb
+
+
+ class UpSample(nn.Module):
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         emb_dim: int = 256
+     ):
+         super().__init__()
+
+         self.up = nn.ConvTranspose2d(in_channels, out_channels, 2, stride=2)
+         self.conv = nn.Sequential(
+             DoubleConv(in_channels, in_channels, residual=True),
+             DoubleConv(in_channels, out_channels, in_channels // 2),
+         )
+
+         self.emb_layer = nn.Sequential(
+             nn.SiLU(),
+             nn.Linear(emb_dim, out_channels)
+         )
+
+     def forward(self, x, skip_x, t):
+         x = self.up(x)
+         x = torch.cat([skip_x, x], dim=1)
+         x = self.conv(x)
+         _, _, H, W = x.shape
+         emb = repeat(self.emb_layer(t), 'b d -> b d h w', h=H, w=W).contiguous()
+         return x + emb
+
+
+ class UNet(pl.LightningModule):
+     def __init__(
+         self,
+         c_in: int = 3,
+         c_out: int = 3,
+         time_dim: int = 256
+     ):
+         super().__init__()
+         self.time_dim = time_dim
+
+         self.time_embed = nn.Sequential(
+             nn.Linear(time_dim, time_dim),
+             nn.SiLU(),
+             nn.Linear(time_dim, time_dim),
+         )
+         self.inc = DoubleConv(in_channels=c_in, out_channels=64)
+         self.down1 = DownSample(in_channels=64, out_channels=128)
+         self.sa1 = SelfAttention(channels=128)
+         self.down2 = DownSample(in_channels=128, out_channels=256)
+         self.sa2 = SelfAttention(channels=256)
+         self.down3 = DownSample(in_channels=256, out_channels=256)
+         self.sa3 = SelfAttention(channels=256)
+
+         self.mid1 = DoubleConv(in_channels=256, out_channels=512)
+         self.mid2 = DoubleConv(in_channels=512, out_channels=512)
+
+         self.up1 = UpSample(in_channels=512, out_channels=256)
+         self.sa4 = SelfAttention(channels=256)
+         self.up2 = UpSample(in_channels=256, out_channels=128)
+         self.sa5 = SelfAttention(channels=128)
+         self.up3 = UpSample(in_channels=128, out_channels=64)
+         self.sa6 = SelfAttention(channels=64)
+         self.outc = nn.Conv2d(64, c_out, kernel_size=1)
+
+     def pos_encoding(self, t, channels):
+         inv_freq = 1.0 / (
+             10000
+             ** (torch.arange(0, channels, 2).float().to(t.device) / channels)
+         ) * t.repeat(1, channels // 2)
+
+         pos_enc = torch.zeros((t.shape[0], channels), device=t.device)
+         pos_enc[:, 0::2] = torch.sin(inv_freq)
+         pos_enc[:, 1::2] = torch.cos(inv_freq)
+         return pos_enc
+
+     def forward_unet(self, x, t):
+         x1 = self.inc(x)
+         x2 = self.down1(x1, t)
+         x2 = self.sa1(x2)
+         x3 = self.down2(x2, t)
+         x3 = self.sa2(x3)
+         x4 = self.down3(x3, t)
+         x4 = self.sa3(x4)
+
+         x4 = self.mid1(x4)
+         x4 = self.mid2(x4)
+
+         x = self.up1(x4, x3, t)
+         x = self.sa4(x)
+         x = self.up2(x, x2, t)
+         x = self.sa5(x)
+         x = self.up3(x, x1, t)
+         x = self.sa6(x)
+         output = self.outc(x)
+         return output
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         t: torch.Tensor
+     ):
+         t = t.unsqueeze(-1).type(torch.float)
+         t = self.pos_encoding(t, self.time_dim)
+         t = self.time_embed(t)
+         return self.forward_unet(x, t)
+
+
+ class ConditionalUNet(UNet):
+     def __init__(
+         self,
+         c_in: int = 3,
+         c_out: int = 3,
+         time_dim: int = 256,
+         num_classes: int | None = None,
+     ):
+         super().__init__(c_in, c_out, time_dim)
+         self.num_classes = num_classes
+         if num_classes is not None:
+             self.cls_embed = nn.Embedding(num_classes, time_dim)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         t: torch.Tensor,
+         label: torch.Tensor | None = None
+     ):
+         t = t.unsqueeze(-1).type(torch.float)
+         t = self.pos_encoding(t, self.time_dim)
+         t = self.time_embed(t)
+
+         if label is not None and self.num_classes is not None:
+             t += self.cls_embed(label)
+         return self.forward_unet(x, t)
+
+
+ if __name__ == '__main__':
+     net = ConditionalUNet()
+     print(sum([p.numel() for p in net.parameters()]))
+     x = torch.randn(2, 3, 32, 32)
+     t = x.new_tensor([500] * x.shape[0]).long()
+     print(t)
+     print(net(x, t).shape)
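
`pos_encoding` is the usual sinusoidal timestep embedding: for embedding width $d$ (= `time_dim`) and timestep $t$ it fills the even and odd channels with

```latex
\mathrm{PE}(t)_{2k} = \sin\!\left(t \cdot 10000^{-2k/d}\right),
\qquad
\mathrm{PE}(t)_{2k+1} = \cos\!\left(t \cdot 10000^{-2k/d}\right),
\qquad k = 0, \dots, d/2 - 1,
```

which is then passed through `time_embed` and, in `ConditionalUNet`, summed with the class embedding before being broadcast into every `DownSample`/`UpSample` block.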
diffusion/model/ldm/__init__.py ADDED
@@ -0,0 +1 @@
+ from .model import *
diffusion/model/ldm/model.py ADDED
@@ -0,0 +1,5 @@
+ import torch
+ import pytorch_lightning as pl
+
+ class LatentDiffusionModel(pl.LightningModule):
+     pass
diffusion/model/ldm/tests/__init__.py ADDED
File without changes
diffusion/tests/__init__.py ADDED
File without changes
diffusion/train/__init__.py ADDED
File without changes
diffusion/train/__main__.py ADDED
@@ -0,0 +1,164 @@
+ from pytorch_lightning.loggers import WandbLogger
+ import diffusion
+ import torch
+ import wandb
+ import pytorch_lightning as pl
+ import argparse
+ import os
+
+ torch.multiprocessing.set_sharing_strategy('file_system')
+
+
+ def main():
+     # PARSER
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         '--dataset', '-d', type=str, default='mnist',
+         help='choose dataset'
+     )
+     parser.add_argument(
+         '--data_dir', '-dd', type=str, default='./data/',
+         help='data directory'
+     )
+     parser.add_argument(
+         '--max_epochs', '-me', type=int, default=200,
+         help='max epoch'
+     )
+     parser.add_argument(
+         '--batch_size', '-bs', type=int, default=32,
+         help='batch size'
+     )
+     parser.add_argument(
+         '--train_ratio', '-tr', type=float, default=0.99,
+         help='train split ratio'
+     )
+     parser.add_argument(
+         '--timesteps', '-ts', type=int, default=1000,
+         help='max timesteps diffusion'
+     )
+     parser.add_argument(
+         '--max_batch_size', '-mbs', type=int, default=32,
+         help='max batch size'
+     )
+     parser.add_argument(
+         '--lr', '-l', type=float, default=1e-4,
+         help='learning rate'
+     )
+     parser.add_argument(
+         '--num_workers', '-nw', type=int, default=4,
+         help='number of workers'
+     )
+     parser.add_argument(
+         '--seed', '-s', type=int, default=42,
+         help='seed'
+     )
+     parser.add_argument(
+         '--name', '-n', type=str, default=None,
+         help='name of the experiment'
+     )
+     parser.add_argument(
+         '--pbar', action='store_true',
+         help='progress bar'
+     )
+     parser.add_argument(
+         '--precision', '-p', type=str, default='32',
+         help='numerical precision'
+     )
+     parser.add_argument(
+         '--sample_per_epochs', '-spe', type=int, default=25,
+         help='sample every n epochs'
+     )
+     parser.add_argument(
+         '--n_samples', '-ns', type=int, default=4,
+         help='number of samples to generate'
+     )
+     parser.add_argument(
+         '--monitor', '-m', type=str, default='val_loss',
+         help='callbacks monitor'
+     )
+     parser.add_argument(
+         '--wandb', '-wk', type=str, default=None,
+         help='wandb API key'
+     )
+
+     args = parser.parse_args()
+
+     # SEED
+     pl.seed_everything(args.seed, workers=True)
+
+     # WANDB (OPTIONAL)
+     if args.wandb is not None:
+         wandb.login(key=args.wandb)  # API KEY
+         name = args.name or f"diffusion-{args.max_epochs}-{args.batch_size}-{args.lr}"
+         logger = WandbLogger(
+             project="diffusion-model",
+             name=name,
+             log_model=False
+         )
+     else:
+         logger = None
+
+     # DATAMODULE
+     if args.dataset == "mnist":
+         DATAMODULE = diffusion.MNISTDataModule
+         img_dim = 32
+         num_classes = 10
+     elif args.dataset == "cifar10":
+         DATAMODULE = diffusion.CIFAR10DataModule
+         img_dim = 32
+         num_classes = 10
+     elif args.dataset == "celeba":
+         DATAMODULE = diffusion.CelebADataModule
+         img_dim = 64
+         num_classes = None
+     else:
+         raise ValueError(f"unknown dataset: {args.dataset}")
+
+     datamodule = DATAMODULE(
+         data_dir=args.data_dir,
+         batch_size=args.batch_size,
+         num_workers=args.num_workers,
+         seed=args.seed,
+         train_ratio=args.train_ratio
+     )
+
+     # MODEL
+     in_channels = 1 if args.dataset == "mnist" else 3
+     model = diffusion.DiffusionModel(
+         lr=args.lr,
+         in_channels=in_channels,
+         sample_per_epochs=args.sample_per_epochs,
+         max_timesteps=args.timesteps,
+         dim=img_dim,
+         num_classes=num_classes,
+         n_samples=args.n_samples
+     )
+
+     # CALLBACK
+     root_path = os.path.join(os.getcwd(), "checkpoints")
+     callback = diffusion.ModelCallback(
+         root_path=root_path,
+         ckpt_monitor=args.monitor
+     )
+
+     # STRATEGY
+     strategy = 'ddp_find_unused_parameters_true' if torch.cuda.is_available() else 'auto'
+
+     # TRAINER
+     trainer = pl.Trainer(
+         default_root_dir=root_path,
+         logger=logger,
+         callbacks=callback.get_callback(),
+         gradient_clip_val=0.5,
+         max_epochs=args.max_epochs,
+         enable_progress_bar=args.pbar,
+         deterministic=False,
+         precision=args.precision,
+         strategy=strategy,
+         accumulate_grad_batches=max(int(args.max_batch_size / args.batch_size), 1)
+     )
+
+     # FIT MODEL
+     trainer.fit(model=model, datamodule=datamodule)
+
+
+ if __name__ == '__main__':
+     main()
diffusion/utils/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .callback import *
+ from .ema import *
diffusion/utils/callback.py ADDED
@@ -0,0 +1,38 @@
+ import os
+ import diffusion
+ from pytorch_lightning.callbacks import (
+     ModelCheckpoint,
+     LearningRateMonitor
+ )
+
+
+ class ModelCallback:
+     def __init__(
+         self,
+         root_path: str,
+         ckpt_monitor: str = "val_loss",
+         ckpt_mode: str = "min",
+     ):
+         ckpt_path = os.path.join(root_path, "model/")
+         if not os.path.exists(root_path):
+             os.makedirs(root_path)
+         if not os.path.exists(ckpt_path):
+             os.makedirs(ckpt_path)
+
+         self.ckpt_callback = ModelCheckpoint(
+             monitor=ckpt_monitor,
+             dirpath=ckpt_path,
+             filename="model",
+             save_top_k=1,
+             mode=ckpt_mode,
+             save_weights_only=True
+         )
+
+         self.lr_callback = LearningRateMonitor("step")
+
+         self.ema_callback = diffusion.EMACallback(decay=0.995)
+
+     def get_callback(self):
+         return [
+             self.ckpt_callback, self.lr_callback, self.ema_callback
+         ]
diffusion/utils/ema.py ADDED
@@ -0,0 +1,75 @@
+ from pytorch_lightning.callbacks import Callback
+ from timm.utils.model import get_state_dict, unwrap_model
+ from timm.utils.model_ema import ModelEmaV2
+
+
+ class EMACallback(Callback):
+     """
+     Model Exponential Moving Average. Empirically it has been found that using the moving average
+     of the trained parameters of a deep network is better than using its trained parameters directly.
+
+     If `use_ema_weights`, then the EMA parameters of the network are set after training ends.
+     """
+
+     def __init__(self, decay=0.9999, use_ema_weights: bool = True):
+         self.decay = decay
+         self.ema = None
+         self.use_ema_weights = use_ema_weights
+
+     def on_fit_start(self, trainer, pl_module, *args):
+         "Initialize `ModelEmaV2` from timm to keep a copy of the moving average of the weights"
+         self.ema = ModelEmaV2(pl_module, decay=self.decay, device=None)
+
+     def on_train_batch_end(
+         self, trainer, pl_module, *args
+     ):
+         "Update the stored parameters using a moving average"
+         # Update currently maintained parameters.
+         self.ema.update(pl_module)
+
+     def on_validation_epoch_start(self, trainer, pl_module, *args):
+         "Do validation using the stored parameters"
+         # save original parameters before replacing with EMA version
+         self.store(pl_module.parameters())
+
+         # update the LightningModule with the EMA weights
+         # ~ Copy EMA parameters to LightningModule
+         self.copy_to(self.ema.module.parameters(), pl_module.parameters())
+
+     def on_validation_end(self, trainer, pl_module, *args):
+         "Restore original parameters to resume training later"
+         self.restore(pl_module.parameters())
+
+     def on_train_end(self, trainer, pl_module, *args):
+         # update the LightningModule with the EMA weights
+         if self.use_ema_weights:
+             self.copy_to(self.ema.module.parameters(), pl_module.parameters())
+
+     def on_save_checkpoint(self, trainer, pl_module, checkpoint, *args):
+         if self.ema is not None:
+             return {"state_dict_ema": get_state_dict(self.ema, unwrap_model)}
+
+     def on_load_checkpoint(self, callback_state, *args):
+         if self.ema is not None:
+             self.ema.module.load_state_dict(callback_state["state_dict_ema"])
+
+     def store(self, parameters):
+         "Save the current parameters for restoring later."
+         self.collected_params = [param.clone() for param in parameters]
+
+     def restore(self, parameters):
+         """
+         Restore the parameters stored with the `store` method.
+         Useful to validate the model with EMA parameters without affecting the
+         original optimization process.
+         """
+         for c_param, param in zip(self.collected_params, parameters):
+             param.data.copy_(c_param.data)
+
+     def copy_to(self, shadow_parameters, parameters):
+         "Copy current parameters into given collection of parameters."
+         for s_param, param in zip(shadow_parameters, parameters):
+             if param.requires_grad:
+                 param.data.copy_(s_param.data)
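
`ModelEmaV2` from timm maintains a shadow copy of the weights, updated after every training batch as

```latex
\theta_{\text{EMA}} \leftarrow \lambda\,\theta_{\text{EMA}} + (1 - \lambda)\,\theta,
```

where $\lambda$ is the `decay` argument (0.995 as wired up in `ModelCallback`); validation and, if `use_ema_weights` is set, the final weights then use $\theta_{\text{EMA}}$ instead of the raw parameters.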
poetry.lock ADDED
The diff for this file is too large to render. See raw diff
 
pyproject.toml ADDED
@@ -0,0 +1,22 @@
+ [tool.poetry]
+ name = "diffusion"
+ version = "0.1.0"
+ description = ""
+ authors = ["Võ Đình Đạt <[email protected]>"]
+ readme = "README.md"
+
+ [tool.poetry.dependencies]
+ python = "^3.10"
+ torch = "*"
+ lightning = "*"
+ einops = "^0.7.0"
+ torchvision = "*"
+ wandb = "^0.16.3"
+ torchaudio = "*"
+ tqdm = "^4.66.2"
+ timm = "^0.9.12"
+ gradio = "^4.18.0"
+
+ [build-system]
+ requires = ["poetry-core"]
+ build-backend = "poetry.core.masonry.api"
script/setup.sh ADDED
@@ -0,0 +1 @@
+ pip install .