Upload folder using huggingface_hub

Files changed:
- .gitignore +168 -0
- Dockerfile +8 -0
- LICENSE +21 -0
- README.md +5 -9
- app.py +53 -0
- diffusion/README.md +0 -0
- diffusion/__init__.py +3 -0
- diffusion/dataset/__init__.py +3 -0
- diffusion/dataset/celeba.py +72 -0
- diffusion/dataset/cifar10.py +73 -0
- diffusion/dataset/mnist.py +73 -0
- diffusion/model/__init__.py +2 -0
- diffusion/model/diffusion/__init__.py +4 -0
- diffusion/model/diffusion/model.py +188 -0
- diffusion/model/diffusion/sampling.py +82 -0
- diffusion/model/diffusion/scheduler.py +20 -0
- diffusion/model/diffusion/unet.py +227 -0
- diffusion/model/ldm/__init__.py +1 -0
- diffusion/model/ldm/model.py +5 -0
- diffusion/model/ldm/tests/__init__.py +0 -0
- diffusion/tests/__init__.py +0 -0
- diffusion/train/__init__.py +0 -0
- diffusion/train/__main__.py +164 -0
- diffusion/utils/__init__.py +2 -0
- diffusion/utils/callback.py +38 -0
- diffusion/utils/ema.py +75 -0
- poetry.lock +0 -0
- pyproject.toml +22 -0
- script/setup.sh +1 -0
.gitignore
ADDED
@@ -0,0 +1,168 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+*.jpeg
+*.gz
+cifar-10-batches-py
+checkpoints
+MNIST
+*.ipynb
+data
+wandb
Dockerfile
ADDED
@@ -0,0 +1,8 @@
+FROM python:latest
+
+COPY . /app
+WORKDIR /app
+
+RUN pip install .
+
+CMD ["python", "-m", "diffusion.train"]
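As committed, the image installs the package and runs training (`python -m diffusion.train`) by default; serving the Gradio demo would mean overriding the command, e.g. with `python app.py`. Note also that `python:latest` is unpinned, so builds are not reproducible across base-image updates.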
LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 Võ Đình Đạt
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
README.md
CHANGED
@@ -1,12 +1,8 @@
 ---
-title:
-emoji: 🦀
-colorFrom: blue
-colorTo: yellow
-sdk: gradio
-sdk_version: 4.19.0
+title: diffusion-model
 app_file: app.py
-
+sdk: gradio
+sdk_version: 4.18.0
 ---
-
-
+# latent-diffusion-model
+Coming Soon!
app.py
ADDED
@@ -0,0 +1,53 @@
+import torch
+import argparse
+import gradio as gr
+import diffusion
+from torchvision import transforms
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--ckpt_path", type=str, default="./checkpoints/mnist.ckpt")
+parser.add_argument("--map_location", type=str, default="cpu")
+parser.add_argument("--share", action='store_true')
+args = parser.parse_args()
+
+if __name__ == "__main__":
+    model = diffusion.DiffusionModel.load_from_checkpoint(
+        args.ckpt_path, in_channels=1, map_location=args.map_location, num_classes=10
+    )
+    to_pil = transforms.ToPILImage()
+
+    def reset(image):
+        image = to_pil((torch.randn(1, 32, 32) * 255).type(torch.uint8))
+        return image
+
+    def denoise(label):
+        labels = torch.tensor([label]).to(model.device)
+        for img in model.sampling_demo(labels=labels):
+            image = to_pil(img[0])
+            yield image
+
+    with gr.Blocks(theme=gr.themes.Soft(primary_hue="green")) as demo:
+        gr.Markdown("# Simple Diffusion Model")
+
+        gr.Markdown("## MNIST")
+        with gr.Row():
+            with gr.Column(scale=2):
+                label = gr.Dropdown(
+                    label='Label',
+                    choices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
+                    value=0
+                )
+                with gr.Row():
+                    sample_btn = gr.Button("Sampling")
+                    reset_btn = gr.Button("Reset")
+            output = gr.Image(
+                value=to_pil((torch.randn(1, 32, 32) * 255).type(torch.uint8)),
+                scale=2,
+                image_mode="L",
+                type='pil',
+            )
+        sample_btn.click(denoise, [label], outputs=output)
+        reset_btn.click(reset, [output], outputs=output)
+
+    demo.launch(share=args.share)
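Because `denoise` is a generator that yields a PIL frame after each reverse-diffusion step, Gradio streams the intermediate denoising states into the `gr.Image` output instead of waiting for the final sample; `reset` simply replaces the display with fresh Gaussian noise.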
diffusion/README.md
ADDED
File without changes
diffusion/__init__.py
ADDED
@@ -0,0 +1,3 @@
+from .model import *
+from .dataset import *
+from .utils import *
diffusion/dataset/__init__.py
ADDED
@@ -0,0 +1,3 @@
+from .cifar10 import *
+from .mnist import *
+from .celeba import *
diffusion/dataset/celeba.py
ADDED
@@ -0,0 +1,72 @@
+import pytorch_lightning as pl
+import torch
+import os
+from PIL import Image
+from torch.utils.data import DataLoader, Dataset, random_split
+from torchvision import transforms
+from functools import partial
+
+
+class CelebADataset(Dataset):
+    def __init__(
+        self,
+        data_dir: str,
+    ):
+        self.list_path = os.listdir(data_dir)
+        self.data_dir = data_dir
+        self.transform = transforms.Compose(
+            [
+                transforms.Resize((64, 64)),
+                transforms.ToTensor(),
+                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+            ]
+        )
+
+    def __len__(self):
+        return len(self.list_path)
+
+    def __getitem__(self, index):
+        img = Image.open(os.path.join(self.data_dir, self.list_path[index]))
+        return self.transform(img)
+
+
+class CelebADataModule(pl.LightningDataModule):
+    def __init__(
+        self,
+        data_dir: str = "./",
+        batch_size: int = 32,
+        num_workers: int = 0,
+        seed: int = 42,
+        train_ratio: float = 0.99
+    ):
+        super().__init__()
+        self.data_dir = data_dir
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.train_ratio = min(train_ratio, 0.99)
+        self.seed = seed
+
+        self.loader = partial(
+            DataLoader,
+            batch_size=self.batch_size,
+            pin_memory=True,
+            num_workers=self.num_workers,
+            persistent_workers=True
+        )
+
+    def setup(self, stage: str):
+        if stage == "fit":
+            dataset = CelebADataset(self.data_dir)
+            self.CelebA_train, self.CelebA_val, _ = random_split(
+                dataset=dataset,
+                lengths=[self.train_ratio, 0.01, 1 - 0.01 - self.train_ratio],
+                generator=torch.Generator().manual_seed(self.seed)
+            )
+        else:
+            pass
+
+    def train_dataloader(self):
+        return self.loader(dataset=self.CelebA_train)
+
+    def val_dataloader(self):
+        return self.loader(dataset=self.CelebA_val)
diffusion/dataset/cifar10.py
ADDED
@@ -0,0 +1,73 @@
+import pytorch_lightning as pl
+import torch
+from torchvision.datasets import CIFAR10
+from torch.utils.data import DataLoader, random_split
+from torchvision import transforms
+from functools import partial
+
+
+class CIFAR10DataModule(pl.LightningDataModule):
+    def __init__(
+        self,
+        data_dir: str = "./",
+        batch_size: int = 32,
+        num_workers: int = 0,
+        seed: int = 42,
+        train_ratio: float = 0.99
+    ):
+        super().__init__()
+        self.data_dir = data_dir
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.seed = seed
+        self.train_ratio = min(train_ratio, 0.99)
+        self.transform = transforms.Compose(
+            [
+                transforms.Resize((32, 32)),
+                transforms.ToTensor(),
+                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+            ]
+        )
+        self.loader = partial(
+            DataLoader,
+            batch_size=self.batch_size,
+            pin_memory=True,
+            num_workers=self.num_workers,
+            persistent_workers=True
+        )
+
+    def setup(self, stage: str):
+        cifar_partial = partial(
+            CIFAR10,
+            root=self.data_dir, transform=self.transform, download=True
+        )
+        if stage == "fit":
+            retrying = True
+            while retrying:
+                try:
+                    cifar_full = cifar_partial(train=True)
+                    retrying = False
+                except:
+                    pass
+            self.cifar_train, self.cifar_val, _ = random_split(
+                dataset=cifar_full,
+                lengths=[self.train_ratio, 0.01, 1 - 0.01 - self.train_ratio],
+                generator=torch.Generator().manual_seed(self.seed)
+            )
+        else:
+            retrying = True
+            while retrying:
+                try:
+                    self.cifar_test = cifar_partial(train=False)
+                    retrying = False
+                except:
+                    pass
+
+    def train_dataloader(self):
+        return self.loader(dataset=self.cifar_train)
+
+    def val_dataloader(self):
+        return self.loader(dataset=self.cifar_val)
+
+    def test_dataloader(self):
+        return self.loader(dataset=self.cifar_test)
diffusion/dataset/mnist.py
ADDED
@@ -0,0 +1,73 @@
+import pytorch_lightning as pl
+import torch
+from torchvision.datasets import MNIST
+from torch.utils.data import DataLoader, random_split
+from torchvision import transforms
+from functools import partial
+
+
+class MNISTDataModule(pl.LightningDataModule):
+    def __init__(
+        self,
+        data_dir: str = "./",
+        batch_size: int = 32,
+        num_workers: int = 0,
+        seed: int = 42,
+        train_ratio: float = 0.99
+    ):
+        super().__init__()
+        self.data_dir = data_dir
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.train_ratio = min(train_ratio, 0.99)
+        self.seed = seed
+        self.transform = transforms.Compose(
+            [
+                transforms.Resize((32, 32)),
+                transforms.ToTensor(),
+                transforms.Normalize(mean=(0.5), std=(0.5))
+            ]
+        )
+        self.loader = partial(
+            DataLoader,
+            batch_size=self.batch_size,
+            pin_memory=True,
+            num_workers=self.num_workers,
+            persistent_workers=True
+        )
+
+    def setup(self, stage: str):
+        mnist_partial = partial(
+            MNIST,
+            root=self.data_dir, transform=self.transform, download=True
+        )
+        if stage == "fit":
+            retrying = True
+            while retrying:
+                try:
+                    mnist_full = mnist_partial(train=True)
+                    retrying = False
+                except:
+                    pass
+            self.mnist_train, self.mnist_val, _ = random_split(
+                dataset=mnist_full,
+                lengths=[self.train_ratio, 0.01, 1 - 0.01 - self.train_ratio],
+                generator=torch.Generator().manual_seed(self.seed)
+            )
+        else:
+            retrying = True
+            while retrying:
+                try:
+                    self.mnist_test = mnist_partial(train=False)
+                    retrying = False
+                except:
+                    pass
+
+    def train_dataloader(self):
+        return self.loader(dataset=self.mnist_train)
+
+    def val_dataloader(self):
+        return self.loader(dataset=self.mnist_val)
+
+    def test_dataloader(self):
+        return self.loader(dataset=self.mnist_test)
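A minimal usage sketch for these datamodules (illustrative only; assumes the package is installed and MNIST downloads into `./data/`):

import torch
from diffusion import MNISTDataModule  # re-exported through diffusion.dataset

dm = MNISTDataModule(data_dir="./data/", batch_size=32, num_workers=2)
dm.setup(stage="fit")
x, y = next(iter(dm.train_dataloader()))
print(x.shape)  # torch.Size([32, 1, 32, 32]), normalized to roughly [-1, 1]
print(y.shape)  # torch.Size([32]), integer digit labels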
diffusion/model/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .diffusion import *
+from .ldm import *
diffusion/model/diffusion/__init__.py
ADDED
@@ -0,0 +1,4 @@
+from .unet import *
+from .model import *
+from .sampling import *
+from .scheduler import *
diffusion/model/diffusion/model.py
ADDED
@@ -0,0 +1,188 @@
+import torch
+import torch.nn as nn
+import numpy as np
+import pytorch_lightning as pl
+import diffusion
+import wandb
+from torchvision.utils import make_grid
+from torch.optim.lr_scheduler import OneCycleLR
+
+
+class DiffusionModel(pl.LightningModule):
+    def __init__(
+        self,
+        lr: float = 1e-4,
+        max_timesteps: int = 1000,
+        beta_1: float = 0.0001,
+        beta_2: float = 0.02,
+        in_channels: int = 3,
+        dim: int = 32,
+        num_classes: int | None = 10,
+        sample_per_epochs: int = 50,
+        n_samples: int = 16
+    ):
+        super().__init__()
+        self.save_hyperparameters()
+
+        self.model = diffusion.ConditionalUNet(
+            c_in=in_channels,
+            c_out=in_channels,
+            num_classes=num_classes
+        )
+        self.lr = lr
+        self.max_timesteps = max_timesteps
+        self.in_channels = in_channels
+        self.dim = dim
+        self.num_classes = num_classes
+
+        self.scheduler = diffusion.LinearScheduler(
+            max_timesteps, beta_1, beta_2
+        )
+
+        self.criterion = nn.MSELoss()
+
+        self.spe = sample_per_epochs
+        self.n_samples = n_samples
+        self.epoch_count = 0
+        self.train_loss = []
+        self.val_loss = []
+
+        self.sampling_kwargs = {
+            'model': self.model,
+            'scheduler': self.scheduler,
+            'max_timesteps': self.max_timesteps,
+            'in_channels': self.in_channels,
+            'dim': self.dim,
+        }
+
+    def _batch_index_select(
+        self,
+        x: torch.Tensor,
+        t: torch.Tensor,
+        device: torch.device
+    ):
+        # x.shape = [T,]
+        # t.shape = [B,]
+        if x.device != device:
+            x = x.to(device)
+        if t.device != device:
+            t = t.to(device)
+        x_select = x.gather(dim=-1, index=t)
+        return x_select[:, None, None, None]  # [B, 1, 1, 1]
+
+    def noising(
+        self,
+        x_0: torch.Tensor,
+        t: torch.Tensor
+    ):
+        noise = torch.randn_like(x_0, device=x_0.device)
+        new_x = self.scheduler.get('sqrt_alpha_hat', t) * x_0
+        new_noise = self.scheduler.get('sqrt_one_minus_alpha_hat', t) * noise
+        return new_x + new_noise, noise
+
+    def sampling(self, labels=None, n_samples: int = 16):
+        return diffusion.ddpm_sampling(
+            n_samples=n_samples,
+            labels=labels,
+            **self.sampling_kwargs
+        )
+
+    def sampling_demo(self, labels=None, n_samples: int = 16):
+        return diffusion.ddpm_sampling_demo(
+            n_samples=n_samples,
+            labels=labels,
+            **self.sampling_kwargs
+        )
+
+    def forward(self, x_0, labels):
+        t = torch.randint(
+            low=0, high=self.max_timesteps, size=(x_0.shape[0],), device=x_0.device
+        )
+        x_noise, noise = self.noising(x_0, t)
+        noise_pred = self.model(x_noise, t, labels)
+        return noise, noise_pred
+
+    def training_step(self, batch, idx):
+        if isinstance(batch, torch.Tensor):
+            x_0 = batch
+            labels = None
+        else:
+            x_0, labels = batch
+            if np.random.random() < 0.1:
+                labels = None
+        noise, noise_pred = self(x_0, labels)
+        loss = self.criterion(noise, noise_pred)
+        self.train_loss.append(loss)
+        return loss
+
+    def validation_step(self, batch, idx):
+        if isinstance(batch, torch.Tensor):
+            x_0 = batch
+            labels = None
+        else:
+            x_0, labels = batch
+        noise, noise_pred = self(x_0, labels)
+        loss = self.criterion(noise, noise_pred)
+        self.val_loss.append(loss)
+        return loss
+
+    def on_train_epoch_end(self) -> None:
+        self.log_dict(
+            {
+                "train_loss": sum(self.train_loss) / len(self.train_loss)
+            },
+            sync_dist=True
+        )
+        self.train_loss.clear()
+
+        if self.epoch_count % self.spe == 0:
+            wandblog = self.logger.experiment
+            x_t = self.sampling(n_samples=self.n_samples)
+            img_array = [x_t[i] for i in range(x_t.shape[0])]
+
+            wandblog.log(
+                {
+                    "sampling": wandb.Image(
+                        make_grid(img_array, nrow=4).permute(1, 2, 0).cpu().numpy(),
+                        caption="Sampled Image!"
+                    )
+                }
+            )
+
+        self.epoch_count += 1
+
+    def on_validation_epoch_end(self):
+        self.log_dict(
+            {
+                "val_loss": sum(self.val_loss) / len(self.val_loss)
+            },
+            sync_dist=True
+        )
+        self.val_loss.clear()
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.AdamW(
+            params=self.parameters(),
+            lr=self.lr,
+            weight_decay=0.001,
+            betas=(0.9, 0.999)
+        )
+        scheduler = OneCycleLR(
+            optimizer=optimizer,
+            max_lr=self.lr,
+            total_steps=self.trainer.estimated_stepping_batches,
+        )
+        return {
+            'optimizer': optimizer,
+            'lr_scheduler': scheduler
+        }
+
+
+if __name__ == "__main__":
+    a = torch.randn(32, 3, 32, 32)
+    model = DiffusionModel(max_timesteps=10)
+    n, n_pred = model(a, None)  # forward requires a labels argument; None = unconditional
+    print(n.shape, n_pred.shape)
+    print(torch.mean((n - n_pred) ** 2))
+    print(model.sampling(n_samples=1))  # first positional parameter is labels, so pass n_samples by keyword
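For reference, `noising` and `training_step` above implement the standard DDPM training objective: corrupt x_0 in closed form at a uniformly sampled timestep and regress the injected noise (the 10% label dropout in `training_step` is the usual classifier-free-guidance training trick). In LaTeX notation:

x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\,\epsilon,
\qquad \epsilon \sim \mathcal{N}(0, I),
\qquad \mathcal{L} = \mathbb{E}_{x_0, t, \epsilon}\left[\lVert \epsilon - \epsilon_\theta(x_t, t, y) \rVert^2\right]

where \sqrt{\bar{\alpha}_t} and \sqrt{1 - \bar{\alpha}_t} are the precomputed `sqrt_alpha_hat` and `sqrt_one_minus_alpha_hat` tables from `LinearScheduler`.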
diffusion/model/diffusion/sampling.py
ADDED
@@ -0,0 +1,82 @@
+import torch
+
+
+def ddpm_sampling_timestep(
+    x_t,
+    model,
+    scheduler,
+    labels,
+    t,
+    n_samples: int = 16,
+    cfg_scale: int = 3,
+):
+    time = torch.full((n_samples,), fill_value=t, device=model.device)
+    pred_noise = model(x_t, time, labels)
+    if cfg_scale > 0:
+        uncond_pred_noise = model(x_t, time, None)
+        pred_noise = torch.lerp(uncond_pred_noise, pred_noise, cfg_scale)
+    alpha = scheduler.get('alpha', time)
+    sqrt_alpha = scheduler.get('sqrt_alpha', time)
+    somah = scheduler.get('sqrt_one_minus_alpha_hat', time)
+    sqrt_beta = scheduler.get('sqrt_beta', time)
+    if t > 0:
+        noise = torch.randn_like(x_t, device=model.device)
+    else:
+        noise = torch.zeros_like(x_t, device=model.device)
+
+    x_t_new = 1 / sqrt_alpha * (x_t - (1 - alpha) / somah * pred_noise) + sqrt_beta * noise
+    return x_t_new.clamp(-1, 1)
+
+
+@torch.no_grad()
+def ddpm_sampling(
+    model,
+    scheduler,
+    n_samples: int = 16,
+    max_timesteps: int = 1000,
+    in_channels: int = 3,
+    dim: int = 32,
+    cfg_scale: int = 3,
+    labels=None
+):
+    if labels is not None:
+        n_samples = labels.shape[0]
+
+    x_t = torch.randn(
+        n_samples, in_channels, dim, dim, device=model.device
+    )
+    model.eval()
+    for t in range(max_timesteps - 1, -1, -1):
+        x_t = ddpm_sampling_timestep(x_t=x_t, model=model, scheduler=scheduler,
+                                     labels=labels, t=t, n_samples=n_samples,
+                                     cfg_scale=cfg_scale)
+
+    model.train()
+    x_t = (x_t + 1) / 2 * 255.  # range [0, 255]
+    return x_t.type(torch.uint8)
+
+
+@torch.no_grad()
+def ddpm_sampling_demo(
+    model,
+    scheduler,
+    n_samples: int = 16,
+    max_timesteps: int = 1000,
+    in_channels: int = 3,
+    dim: int = 32,
+    cfg_scale: int = 3,
+    labels=None
+):
+    if labels is not None:
+        n_samples = labels.shape[0]
+
+    x_t = torch.randn(
+        n_samples, in_channels, dim, dim, device=model.device
+    )
+    model.eval()
+    for t in range(max_timesteps - 1, -1, -1):
+        x_t = ddpm_sampling_timestep(x_t=x_t, model=model, scheduler=scheduler,
+                                     labels=labels, t=t, n_samples=n_samples,
+                                     cfg_scale=cfg_scale)
+        yield ((x_t + 1) / 2 * 255).type(torch.uint8)
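`ddpm_sampling_timestep` is the standard DDPM reverse update with classifier-free guidance folded in via `torch.lerp` (recall lerp(a, b, w) = a + w(b - a), so `cfg_scale` plays the role of the guidance weight w):

\epsilon = \epsilon_\theta(x_t, t, \varnothing)
  + w\,\bigl(\epsilon_\theta(x_t, t, y) - \epsilon_\theta(x_t, t, \varnothing)\bigr)

x_{t-1} = \frac{1}{\sqrt{\alpha_t}}
  \left(x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}}\,\epsilon\right)
  + \sqrt{\beta_t}\, z,
\qquad z \sim \mathcal{N}(0, I) \text{ for } t > 0,\; z = 0 \text{ at } t = 0.

The per-step `clamp(-1, 1)` is a choice of this implementation rather than part of the textbook sampler.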
diffusion/model/diffusion/scheduler.py
ADDED
@@ -0,0 +1,20 @@
+import torch
+
+
+class LinearScheduler:
+    def __init__(
+        self,
+        max_timesteps: int = 1000,
+        beta_1: float = 0.0001,
+        beta_2: float = 0.02
+    ) -> None:
+        self.beta = torch.linspace(beta_1, beta_2, max_timesteps)
+        self.sqrt_beta = torch.sqrt(self.beta)[:, None, None, None]
+        self.alpha = (1 - self.beta)[:, None, None, None]
+        self.sqrt_alpha = torch.sqrt(self.alpha)
+        self.alpha_hat = torch.cumprod(1 - self.beta, dim=0)[:, None, None, None]
+        self.sqrt_alpha_hat = torch.sqrt(self.alpha_hat)
+        self.sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat)
+
+    def get(self, key: str, t: torch.Tensor):
+        return self.__dict__[key].to(t.device)[t]
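A short sketch of how the lookup broadcasts (illustrative usage, not part of the commit): every table is stored with shape [T, 1, 1, 1], so indexing with a batch of timesteps returns a [B, 1, 1, 1] coefficient that broadcasts against [B, C, H, W] image batches.

import torch
from diffusion import LinearScheduler  # exported via diffusion.model.diffusion

scheduler = LinearScheduler(max_timesteps=1000)
t = torch.randint(0, 1000, (8,))           # one timestep per sample
coef = scheduler.get('sqrt_alpha_hat', t)  # shape [8, 1, 1, 1]
x0 = torch.randn(8, 3, 32, 32)
print((coef * x0).shape)                   # torch.Size([8, 3, 32, 32])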
diffusion/model/diffusion/unet.py
ADDED
@@ -0,0 +1,227 @@
+
+import torch
+import torch.nn as nn
+import pytorch_lightning as pl
+from einops import rearrange, repeat
+
+
+class SelfAttention(nn.Module):
+    def __init__(
+        self,
+        channels: int
+    ):
+        super(SelfAttention, self).__init__()
+        self.channels = channels
+        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
+        self.ln = nn.LayerNorm([channels])
+        self.ff_self = nn.Sequential(
+            nn.LayerNorm([channels]),
+            nn.Linear(channels, channels),
+            nn.GELU(),
+            nn.Linear(channels, channels),
+        )
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
+        x_ln = self.ln(x)
+        attention_value, _ = self.mha(x_ln, x_ln, x_ln)
+        attention_value = attention_value + x
+        attention_value = self.ff_self(attention_value) + attention_value
+        return rearrange(attention_value, 'b (h w) c -> b c h w', h=H, w=W).contiguous()
+
+
+class DoubleConv(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        mid_channels: int | None = None,
+        residual: bool = False
+    ):
+        super().__init__()
+        self.residual = residual
+        if not mid_channels:
+            mid_channels = out_channels
+        self.double_conv = nn.Sequential(
+            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
+            nn.GroupNorm(8, mid_channels),
+            nn.GELU(),
+            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
+            nn.GroupNorm(8, out_channels),
+        )
+
+    def forward(self, x):
+        if self.residual:
+            return (x + self.double_conv(x)) / 1.414
+        else:
+            return self.double_conv(x)
+
+
+class DownSample(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        emb_dim: int = 256
+    ):
+        super().__init__()
+        self.maxpool_conv = nn.Sequential(
+            nn.MaxPool2d(2),
+            DoubleConv(in_channels, in_channels, residual=True),
+            DoubleConv(in_channels, out_channels),
+        )
+
+        self.emb_layer = nn.Sequential(
+            nn.SiLU(),
+            nn.Linear(emb_dim, out_channels),
+        )
+
+    def forward(self, x, t):
+        x = self.maxpool_conv(x)
+        _, _, H, W = x.shape
+        emb = repeat(self.emb_layer(t), 'b d -> b d h w', h=H, w=W).contiguous()
+        return x + emb
+
+
+class UpSample(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        emb_dim: int = 256
+    ):
+        super().__init__()
+
+        self.up = nn.ConvTranspose2d(in_channels, out_channels, 2, stride=2)
+        self.conv = nn.Sequential(
+            DoubleConv(in_channels, in_channels, residual=True),
+            DoubleConv(in_channels, out_channels, in_channels // 2),
+        )
+
+        self.emb_layer = nn.Sequential(
+            nn.SiLU(),
+            nn.Linear(emb_dim, out_channels)
+        )
+
+    def forward(self, x, skip_x, t):
+        x = self.up(x)
+        x = torch.cat([skip_x, x], dim=1)
+        x = self.conv(x)
+        _, _, H, W = x.shape
+        emb = repeat(self.emb_layer(t), 'b d -> b d h w', h=H, w=W).contiguous()
+        return x + emb
+
+
+class UNet(pl.LightningModule):
+    def __init__(
+        self,
+        c_in: int = 3,
+        c_out: int = 3,
+        time_dim: int = 256
+    ):
+        super().__init__()
+        self.time_dim = time_dim
+
+        self.time_embed = nn.Sequential(
+            nn.Linear(time_dim, time_dim),
+            nn.SiLU(),
+            nn.Linear(time_dim, time_dim),
+        )
+        self.inc = DoubleConv(in_channels=c_in, out_channels=64)
+        self.down1 = DownSample(in_channels=64, out_channels=128)
+        self.sa1 = SelfAttention(channels=128)
+        self.down2 = DownSample(in_channels=128, out_channels=256)
+        self.sa2 = SelfAttention(channels=256)
+        self.down3 = DownSample(in_channels=256, out_channels=256)
+        self.sa3 = SelfAttention(channels=256)
+
+        self.mid1 = DoubleConv(in_channels=256, out_channels=512)
+        self.mid2 = DoubleConv(in_channels=512, out_channels=512)
+
+        self.up1 = UpSample(in_channels=512, out_channels=256)
+        self.sa4 = SelfAttention(channels=256)
+        self.up2 = UpSample(in_channels=256, out_channels=128)
+        self.sa5 = SelfAttention(channels=128)
+        self.up3 = UpSample(in_channels=128, out_channels=64)
+        self.sa6 = SelfAttention(channels=64)
+        self.outc = nn.Conv2d(64, c_out, kernel_size=1)
+
+    def pos_encoding(self, t, channels):
+        inv_freq = 1.0 / (
+            10000
+            ** (torch.arange(0, channels, 2).float().to(t.device) / channels)
+        ) * t.repeat(1, channels // 2)
+
+        pos_enc = torch.zeros((t.shape[0], channels), device=t.device)
+        pos_enc[:, 0::2] = torch.sin(inv_freq)
+        pos_enc[:, 1::2] = torch.cos(inv_freq)
+        return pos_enc
+
+    def forward_unet(self, x, t):
+        x1 = self.inc(x)
+        x2 = self.down1(x1, t)
+        x2 = self.sa1(x2)
+        x3 = self.down2(x2, t)
+        x3 = self.sa2(x3)
+        x4 = self.down3(x3, t)
+        x4 = self.sa3(x4)
+
+        x4 = self.mid1(x4)
+        x4 = self.mid2(x4)
+
+        x = self.up1(x4, x3, t)
+        x = self.sa4(x)
+        x = self.up2(x, x2, t)
+        x = self.sa5(x)
+        x = self.up3(x, x1, t)
+        x = self.sa6(x)
+        output = self.outc(x)
+        return output
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        t: torch.Tensor
+    ):
+        t = t.unsqueeze(-1).type(torch.float)
+        t = self.pos_encoding(t, self.time_dim)
+        t = self.time_embed(t)
+        return self.forward_unet(x, t)
+
+
+class ConditionalUNet(UNet):
+    def __init__(
+        self,
+        c_in: int = 3,
+        c_out: int = 3,
+        time_dim: int = 256,
+        num_classes: int | None = None,
+    ):
+        super().__init__(c_in, c_out, time_dim)
+        self.num_classes = num_classes
+        if num_classes is not None:
+            self.cls_embed = nn.Embedding(num_classes, time_dim)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        t: torch.Tensor,
+        label: torch.Tensor | None = None
+    ):
+        t = t.unsqueeze(-1).type(torch.float)
+        t = self.pos_encoding(t, self.time_dim)
+        t = self.time_embed(t)
+
+        if label is not None and self.num_classes is not None:
+            t += self.cls_embed(label)
+        return self.forward_unet(x, t)
+
+
+if __name__ == '__main__':
+    net = ConditionalUNet()
+    print(sum([p.numel() for p in net.parameters()]))
+    x = torch.randn(2, 3, 32, 32)
+    t = x.new_tensor([500] * x.shape[0]).long()
+    print(t)
+    print(net(x, t).shape)
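A self-contained check of the sinusoidal timestep embedding above (a sketch that mirrors `UNet.pos_encoding` rather than importing it):

import torch

def pos_encoding(t: torch.Tensor, channels: int) -> torch.Tensor:
    # t: [B, 1] float timesteps -> [B, channels] interleaved sin/cos embedding
    inv_freq = 1.0 / (
        10000 ** (torch.arange(0, channels, 2).float() / channels)
    ) * t.repeat(1, channels // 2)
    pos_enc = torch.zeros((t.shape[0], channels))
    pos_enc[:, 0::2] = torch.sin(inv_freq)
    pos_enc[:, 1::2] = torch.cos(inv_freq)
    return pos_enc

print(pos_encoding(torch.tensor([[0.0], [500.0]]), 256).shape)  # torch.Size([2, 256])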
diffusion/model/ldm/__init__.py
ADDED
@@ -0,0 +1 @@
+from .model import *
diffusion/model/ldm/model.py
ADDED
@@ -0,0 +1,5 @@
+import torch
+import pytorch_lightning as pl
+
+class LatentDiffusionModel(pl.LightningModule):
+    pass
diffusion/model/ldm/tests/__init__.py
ADDED
File without changes
diffusion/tests/__init__.py
ADDED
File without changes
diffusion/train/__init__.py
ADDED
File without changes
diffusion/train/__main__.py
ADDED
@@ -0,0 +1,164 @@
+from pytorch_lightning.loggers import WandbLogger
+import diffusion
+import torch
+import wandb
+import pytorch_lightning as pl
+import argparse
+import os
+
+torch.multiprocessing.set_sharing_strategy('file_system')
+
+
+def main():
+    # PARSERS
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--dataset', '-d', type=str, default='mnist',
+        help='choose dataset'
+    )
+    parser.add_argument(
+        '--data_dir', '-dd', type=str, default='./data/',
+        help='data directory'
+    )
+    parser.add_argument(
+        '--max_epochs', '-me', type=int, default=200,
+        help='max epochs'
+    )
+    parser.add_argument(
+        '--batch_size', '-bs', type=int, default=32,
+        help='batch size'
+    )
+    parser.add_argument(
+        '--train_ratio', '-tr', type=float, default=0.99,
+        help='train split ratio'
+    )
+    parser.add_argument(
+        '--timesteps', '-ts', type=int, default=1000,
+        help='max timesteps diffusion'
+    )
+    parser.add_argument(
+        '--max_batch_size', '-mbs', type=int, default=32,
+        help='max batch size'
+    )
+    parser.add_argument(
+        '--lr', '-l', type=float, default=1e-4,
+        help='learning rate'
+    )
+    parser.add_argument(
+        '--num_workers', '-nw', type=int, default=4,
+        help='number of workers'
+    )
+    parser.add_argument(
+        '--seed', '-s', type=int, default=42,
+        help='seed'
+    )
+    parser.add_argument(
+        '--name', '-n', type=str, default=None,
+        help='name of the experiment'
+    )
+    parser.add_argument(
+        '--pbar', action='store_true',
+        help='progress bar'
+    )
+    parser.add_argument(
+        '--precision', '-p', type=str, default='32',
+        help='numerical precision'
+    )
+    parser.add_argument(
+        '--sample_per_epochs', '-spe', type=int, default=25,
+        help='sample every n epochs'
+    )
+    parser.add_argument(
+        '--n_samples', '-ns', type=int, default=4,
+        help='number of samples'
+    )
+    parser.add_argument(
+        '--monitor', '-m', type=str, default='val_loss',
+        help='callbacks monitor'
+    )
+    parser.add_argument(
+        '--wandb', '-wk', type=str, default=None,
+        help='wandb API key'
+    )
+
+    args = parser.parse_args()
+
+    # SEED
+    pl.seed_everything(args.seed, workers=True)
+
+    # WANDB (OPTIONAL)
+    if args.wandb is not None:
+        wandb.login(key=args.wandb)  # API KEY
+        name = args.name or f"diffusion-{args.max_epochs}-{args.batch_size}-{args.lr}"
+        logger = WandbLogger(
+            project="diffusion-model",
+            name=name,
+            log_model=False
+        )
+    else:
+        logger = None
+
+    # DATAMODULE
+    if args.dataset == "mnist":
+        DATAMODULE = diffusion.MNISTDataModule
+        img_dim = 32
+        num_classes = 10
+    elif args.dataset == "cifar10":
+        DATAMODULE = diffusion.CIFAR10DataModule
+        img_dim = 32
+        num_classes = 10
+    elif args.dataset == "celeba":
+        DATAMODULE = diffusion.CelebADataModule
+        img_dim = 64
+        num_classes = None
+
+    datamodule = DATAMODULE(
+        data_dir=args.data_dir,
+        batch_size=args.batch_size,
+        num_workers=args.num_workers,
+        seed=args.seed,
+        train_ratio=args.train_ratio
+    )
+
+    # MODEL
+    in_channels = 1 if args.dataset == "mnist" else 3
+    model = diffusion.DiffusionModel(
+        lr=args.lr,
+        in_channels=in_channels,
+        sample_per_epochs=args.sample_per_epochs,
+        max_timesteps=args.timesteps,
+        dim=img_dim,
+        num_classes=num_classes,
+        n_samples=args.n_samples
+    )
+
+    # CALLBACK
+    root_path = os.path.join(os.getcwd(), "checkpoints")
+    callback = diffusion.ModelCallback(
+        root_path=root_path,
+        ckpt_monitor=args.monitor
+    )
+
+    # STRATEGY
+    strategy = 'ddp_find_unused_parameters_true' if torch.cuda.is_available() else 'auto'
+
+    # TRAINER
+    trainer = pl.Trainer(
+        default_root_dir=root_path,
+        logger=logger,
+        callbacks=callback.get_callback(),
+        gradient_clip_val=0.5,
+        max_epochs=args.max_epochs,
+        enable_progress_bar=args.pbar,
+        deterministic=False,
+        precision=args.precision,
+        strategy=strategy,
+        accumulate_grad_batches=max(int(args.max_batch_size / args.batch_size), 1)
+    )
+
+    # FIT MODEL
+    trainer.fit(model=model, datamodule=datamodule)
+
+
+if __name__ == '__main__':
+    main()
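The module is meant to be launched as `python -m diffusion.train` (matching the Dockerfile's CMD), e.g. `python -m diffusion.train --dataset mnist --batch_size 64 --wandb <API_KEY>` — flag values here are illustrative. Note that `on_train_epoch_end` calls `self.logger.experiment` unconditionally on sampling epochs, so runs without a W&B key will fail there unless that call is guarded.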
diffusion/utils/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .callback import *
+from .ema import *
diffusion/utils/callback.py
ADDED
@@ -0,0 +1,38 @@
+import os
+import diffusion
+from pytorch_lightning.callbacks import (
+    ModelCheckpoint,
+    LearningRateMonitor
+)
+
+
+class ModelCallback:
+    def __init__(
+        self,
+        root_path: str,
+        ckpt_monitor: str = "val_loss",
+        ckpt_mode: str = "min",
+    ):
+        ckpt_path = os.path.join(root_path, "model/")
+        if not os.path.exists(root_path):
+            os.makedirs(root_path)
+        if not os.path.exists(ckpt_path):
+            os.makedirs(ckpt_path)
+
+        self.ckpt_callback = ModelCheckpoint(
+            monitor=ckpt_monitor,
+            dirpath=ckpt_path,
+            filename="model",
+            save_top_k=1,
+            mode=ckpt_mode,
+            save_weights_only=True
+        )
+
+        self.lr_callback = LearningRateMonitor("step")
+
+        self.ema_callback = diffusion.EMACallback(decay=0.995)
+
+    def get_callback(self):
+        return [
+            self.ckpt_callback, self.lr_callback, self.ema_callback
+        ]
diffusion/utils/ema.py
ADDED
@@ -0,0 +1,75 @@
+from pytorch_lightning.callbacks import Callback
+from timm.utils.model import get_state_dict, unwrap_model
+from timm.utils.model_ema import ModelEmaV2
+# Cell
+
+
+class EMACallback(Callback):
+    """
+    Model Exponential Moving Average. Empirically it has been found that using the moving average
+    of the trained parameters of a deep network is better than using its trained parameters directly.
+
+    If `use_ema_weights`, then the EMA parameters of the network are set after training ends.
+    """
+
+    def __init__(self, decay=0.9999, use_ema_weights: bool = True):
+        self.decay = decay
+        self.ema = None
+        self.use_ema_weights = use_ema_weights
+
+    def on_fit_start(self, trainer, pl_module, *args):
+        "Initialize `ModelEmaV2` from timm to keep a copy of the moving average of the weights"
+        self.ema = ModelEmaV2(pl_module, decay=self.decay, device=None)
+
+    def on_train_batch_end(
+        self, trainer, pl_module, *args
+    ):
+        "Update the stored parameters using a moving average"
+        # Update currently maintained parameters.
+        self.ema.update(pl_module)
+
+    def on_validation_epoch_start(self, trainer, pl_module, *args):
+        "do validation using the stored parameters"
+        # save original parameters before replacing with EMA version
+        self.store(pl_module.parameters())
+
+        # update the LightningModule with the EMA weights
+        # ~ Copy EMA parameters to LightningModule
+        self.copy_to(self.ema.module.parameters(), pl_module.parameters())
+
+    def on_validation_end(self, trainer, pl_module, *args):
+        "Restore original parameters to resume training later"
+        self.restore(pl_module.parameters())
+
+    def on_train_end(self, trainer, pl_module, *args):
+        # update the LightningModule with the EMA weights
+        if self.use_ema_weights:
+            self.copy_to(self.ema.module.parameters(), pl_module.parameters())
+            msg = "Model weights replaced with the EMA version."
+
+    def on_save_checkpoint(self, trainer, pl_module, checkpoint, *args):
+        if self.ema is not None:
+            return {"state_dict_ema": get_state_dict(self.ema, unwrap_model)}
+
+    def on_load_checkpoint(self, callback_state, *args):
+        if self.ema is not None:
+            self.ema.module.load_state_dict(callback_state["state_dict_ema"])
+
+    def store(self, parameters):
+        "Save the current parameters for restoring later."
+        self.collected_params = [param.clone() for param in parameters]
+
+    def restore(self, parameters):
+        """
+        Restore the parameters stored with the `store` method.
+        Useful to validate the model with EMA parameters without affecting the
+        original optimization process.
+        """
+        for c_param, param in zip(self.collected_params, parameters):
+            param.data.copy_(c_param.data)
+
+    def copy_to(self, shadow_parameters, parameters):
+        "Copy current parameters into given collection of parameters."
+        for s_param, param in zip(shadow_parameters, parameters):
+            if param.requires_grad:
+                param.data.copy_(s_param.data)
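The rule `ModelEmaV2` applies after every batch is the usual exponential moving average; a minimal sketch of the update (illustrative, not timm's actual implementation):

import torch

@torch.no_grad()
def ema_update(ema_params, model_params, decay: float = 0.995):
    # theta_ema <- decay * theta_ema + (1 - decay) * theta
    for e, p in zip(ema_params, model_params):
        e.mul_(decay).add_(p, alpha=1.0 - decay)

With decay = 0.995 the EMA averages over roughly the last 1 / (1 - 0.995) = 200 updates.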
poetry.lock
ADDED
The diff for this file is too large to render.
pyproject.toml
ADDED
@@ -0,0 +1,22 @@
+[tool.poetry]
+name = "diffusion"
+version = "0.1.0"
+description = ""
+authors = ["Võ Đình Đạt <[email protected]>"]
+readme = "README.md"
+
+[tool.poetry.dependencies]
+python = "^3.10"
+torch = "*"
+lightning = "*"
+einops = "^0.7.0"
+torchvision = "*"
+wandb = "^0.16.3"
+torchaudio = "*"
+tqdm = "^4.66.2"
+timm = "^0.9.12"
+gradio = "^4.18.0"
+
+[build-system]
+requires = ["poetry-core"]
+build-backend = "poetry.core.masonry.api"
script/setup.sh
ADDED
@@ -0,0 +1 @@
+pip install .