upload
This view is limited to 50 files because the commit contains too many changes.
- Dockerfile +58 -0
- README.md +1 -1
- configs/autoregressive_config.yaml +77 -0
- configs/onenode_config.yaml +12 -0
- configs/tokenizer_config.yaml +64 -0
- demo.py +170 -0
- examples/city.jpg +0 -0
- examples/food.jpg +0 -0
- examples/highland.webp +0 -0
- fid_stats/adm_in256_stats.npz +3 -0
- gen_demo.py +137 -0
- making_cache.py +46 -0
- paintmind/__init__.py +2 -0
- paintmind/config.py +110 -0
- paintmind/engine/gpt_trainer.py +892 -0
- paintmind/engine/misc.py +260 -0
- paintmind/engine/trainer.py +695 -0
- paintmind/engine/util.py +572 -0
- paintmind/stage1/__init__.py +0 -0
- paintmind/stage1/diffuse_slot.py +808 -0
- paintmind/stage1/diffusion/__init__.py +46 -0
- paintmind/stage1/diffusion/diffusion_utils.py +88 -0
- paintmind/stage1/diffusion/gaussian_diffusion.py +886 -0
- paintmind/stage1/diffusion/respace.py +130 -0
- paintmind/stage1/diffusion/timestep_sampler.py +150 -0
- paintmind/stage1/diffusion_transfomers.py +372 -0
- paintmind/stage1/fused_attention.py +94 -0
- paintmind/stage1/pos_embed.py +102 -0
- paintmind/stage1/quantize.py +93 -0
- paintmind/stage1/transport/__init__.py +63 -0
- paintmind/stage1/transport/integrators.py +130 -0
- paintmind/stage1/transport/path.py +192 -0
- paintmind/stage1/transport/transport.py +456 -0
- paintmind/stage1/transport/utils.py +29 -0
- paintmind/stage1/vision_transformers.py +267 -0
- paintmind/stage2/__init__.py +0 -0
- paintmind/stage2/causaldit.py +422 -0
- paintmind/stage2/diffloss.py +314 -0
- paintmind/stage2/generate.py +127 -0
- paintmind/stage2/gpt.py +451 -0
- paintmind/utils/__init__.py +0 -0
- paintmind/utils/datasets.py +77 -0
- paintmind/utils/device_utils.py +20 -0
- paintmind/utils/logger.py +170 -0
- paintmind/utils/lr_scheduler.py +15 -0
- paintmind/utils/transform.py +35 -0
- paintmind/version.py +1 -0
- requirements.txt +27 -0
- submitit_test.py +290 -0
- submitit_train.py +148 -0
Dockerfile
ADDED
@@ -0,0 +1,58 @@
FROM nvidia/cuda:12.4.0-runtime-ubuntu22.04
LABEL maintainer="Bingchen Zhao"
LABEL repository="Semanticist"

ENV DEBIAN_FRONTEND=noninteractive

RUN apt-get -y update \
    && apt-get install -y software-properties-common \
    && add-apt-repository ppa:deadsnakes/ppa

RUN apt install -y bash \
    build-essential \
    git \
    git-lfs \
    curl \
    ca-certificates \
    libsndfile1-dev \
    libgl1 \
    python3.10 \
    python3.10-dev \
    python3-pip \
    python3.10-venv rsync sudo tmux && \
    rm -rf /var/lib/apt/lists

# make sure to use venv
RUN python3.10 -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"

# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
    python3.10 -m uv pip install --no-cache-dir \
        torch \
        torchvision \
        torchaudio \
        invisible_watermark && \
    python3.10 -m pip install --no-cache-dir \
        accelerate \
        datasets \
        hf-doc-builder \
        huggingface-hub \
        hf_transfer \
        Jinja2 \
        librosa \
        numpy==1.26.4 \
        scipy \
        tensorboard \
        transformers \
        pytorch-lightning matplotlib \
        hf_transfer

# start Semanticist part
COPY . /work/Semanticist
WORKDIR /work/Semanticist
RUN ls && python3.10 -m pip install -r req_min.txt && \
    python3.10 -m pip install git+https://github.com/cocodataset/panopticapi.git

CMD ["/bin/bash"]
# docker run -it --rm --runtime=nvidia --gpus all xx/xx:xx
README.md
CHANGED
@@ -5,7 +5,7 @@ colorFrom: gray
 colorTo: red
 sdk: gradio
 sdk_version: 5.20.1
-app_file:
+app_file: demo.py
 pinned: false
 license: mit
 ---
configs/autoregressive_config.yaml
ADDED
@@ -0,0 +1,77 @@
trainer:
  target: paintmind.engine.gpt_trainer.GPTTrainer
  params:
    num_epoch: 400
    blr: 1e-4
    cosine_lr: False
    warmup_epochs: 100
    batch_size: 32
    num_workers: 8
    pin_memory: True
    grad_accum_steps: 1
    precision: 'bf16'
    max_grad_norm: 1.0
    enable_ema: True
    save_every: 10000
    sample_every: 5000
    fid_every: 50000
    eval_fid: False
    result_folder: "./output/autoregressive"
    log_dir: "./output/autoregressive/logs"
    ae_cfg: 1.5
    cfg: 1.5
    cfg_schedule: "constant"
    train_num_slots: 32
    test_num_slots: 32
    compile: True
    enable_cache_latents: True
    ae_model:
      target: paintmind.stage1.diffuse_slot.DiffuseSlot
      params:
        encoder: 'vit_base_patch16'
        enc_img_size: 256
        enc_causal: True
        enc_use_mlp: False
        num_slots: 256
        slot_dim: 16
        norm_slots: True
        dit_mask_type: 'replace'
        cond_method: 'token'
        dit_model: 'DiT-XL-2'
        vae: 'xwen99/mar-vae-kl16'
        num_sampling_steps: '250'
        ckpt_path: ./output/tokenizer/models/step250000/custom_checkpoint_1.pkl

    gpt_model:
      target: GPT-L
      params:
        num_slots: 32
        slot_dim: 16
        num_classes: 1000
        cls_token_num: 1
        resid_dropout_p: 0.1
        ffn_dropout_p: 0.1
        diffloss_d: 12
        diffloss_w: 1536
        num_sampling_steps: '100'
        diffusion_batch_mul: 4
        token_drop_prob: 0
        use_si: True
        cond_method: 'concat'
        decoupled_cfg: False
        ckpt_path: None

    dataset:
      target: paintmind.utils.datasets.ImageNet
      params:
        root: ./dataset/imagenet/
        split: train
        aug: tencrop_cached
        img_size: 256

    test_dataset:
      target: paintmind.utils.datasets.ImageNet
      params:
        root: ./dataset/imagenet/
        split: val
        img_size: 256
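A note on `blr` above: `GPTTrainer` (see `paintmind/engine/gpt_trainer.py` later in this diff) derives the actual learning rate from the base rate by linear scaling with the effective batch size. A minimal sketch of that rule; `num_processes = 8` mirrors `configs/onenode_config.yaml` and is an example value:

```python
# Linear LR scaling as used in GPTTrainer: lr = blr * effective_bs / 256
blr = 1e-4            # base LR from configs/autoregressive_config.yaml
batch_size = 32       # per-process batch size
grad_accum_steps = 1
num_processes = 8     # assumed world size (one node of 8 GPUs)

effective_bs = batch_size * grad_accum_steps * num_processes  # 256
lr = blr * effective_bs / 256                                 # 1e-4 in this setup
print(f"effective_bs={effective_bs}, lr={lr:g}")
```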
configs/onenode_config.yaml
ADDED
@@ -0,0 +1,12 @@
compute_environment: LOCAL_MACHINE
deepspeed_config: {}
distributed_type: MULTI_GPU
fsdp_config: {}
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
use_cpu: false
configs/tokenizer_config.yaml
ADDED
@@ -0,0 +1,64 @@
trainer:
  target: paintmind.engine.trainer.DiffusionTrainer
  params:
    num_epoch: 400
    valid_size: 64
    blr: 2.5e-5
    cosine_lr: True
    warmup_epochs: 100
    batch_size: 32
    num_workers: 16
    pin_memory: True
    grad_accum_steps: 1
    precision: 'bf16'
    max_grad_norm: 3.0
    enable_ema: True
    save_every: 10000
    sample_every: 5000
    fid_every: 50000
    result_folder: "./output/tokenizer"
    log_dir: "./output/tokenizer/logs"
    cfg: 3.0
    compile: True
    model:
      target: paintmind.stage1.diffuse_slot.DiffuseSlot
      params:
        encoder: 'vit_base_patch16'
        enc_img_size: 256
        enc_causal: True
        enc_use_mlp: False
        num_slots: 256
        slot_dim: 16
        norm_slots: True
        dit_mask_type: 'replace'
        cond_method: 'token'
        dit_model: 'DiT-XL-2'
        vae: 'xwen99/mar-vae-kl16'
        enable_nest: False
        enable_nest_after: 50
        nest_rho: 0.03
        nest_dist: uniform
        nest_null_prob: 0
        nest_allow_zero: False
        use_repa: True
        repa_encoder: dinov2_vitb
        repa_encoder_depth: 8
        repa_loss_weight: 1.0
        eval_fid: True
        fid_stats: 'fid_stats/adm_in256_stats.npz'
        num_sampling_steps: '250'
        ckpt_path: None

    dataset:
      target: paintmind.utils.datasets.ImageNet
      params:
        root: ./dataset/imagenet/
        split: train
        img_size: 256

    test_dataset:
      target: paintmind.utils.datasets.ImageNet
      params:
        root: ./dataset/imagenet/
        split: val
        img_size: 256
demo.py
ADDED
@@ -0,0 +1,170 @@
import gradio as gr
import numpy as np
from PIL import Image
import torch
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download

from paintmind.stage1.diffuse_slot import DiffuseSlot
from paintmind.utils.datasets import vae_transforms

device = "cuda" if torch.cuda.is_available() else "cpu"
ckpt_path = hf_hub_download(repo_id='tennant/semanticist', filename='semanticist_tok_XL.pkl')
config_path = 'configs/tokenizer_config.yaml'
cfg = OmegaConf.load(config_path)
ckpt = torch.load(ckpt_path, map_location='cpu')

transform = vae_transforms('test')


def norm_ip(img, low, high):
    img.clamp_(min=low, max=high)
    img.sub_(low).div_(max(high - low, 1e-5))

def norm_range(t, value_range):
    if value_range is not None:
        norm_ip(t, value_range[0], value_range[1])
    else:
        norm_ip(t, float(t.min()), float(t.max()))

def convert_np(img):
    ndarr = img.mul(255).add_(0.5).clamp_(0, 255)\
        .permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    return ndarr

def convert_PIL(img):
    ndarr = img.mul(255).add_(0.5).clamp_(0, 255)\
        .permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    img = Image.fromarray(ndarr)
    return img

ckpt = {k.replace('._orig_mod', ''): v for k, v in ckpt.items()}

model = DiffuseSlot(**cfg['trainer']['params']['model']['params'])
msg = model.load_state_dict(ckpt, strict=False)
model = model.to(device)
model = model.eval()
model.enable_nest = True

def viz_diff_slots(model, img, nums, cfg=1.0, return_img=False):
    n_slots_inf = []
    for num_slots_to_inference in nums:
        recon_n = model(
            img, None, sample=True, cfg=cfg,
            inference_with_n_slots=num_slots_to_inference,
        )
        n_slots_inf.append(recon_n)
    return [convert_np(n_slots_inf[i][0]) for i in range(len(n_slots_inf))]

# Removed process_image function as its functionality is now in the update_outputs function

with gr.Blocks() as demo:
    with gr.Row():
        # First column - input and configs
        with gr.Column(scale=1):
            gr.Markdown("## Input")
            input_image = gr.Image(label="Upload an image", type="numpy")

            with gr.Group():
                gr.Markdown("### Configuration")
                show_gallery = gr.Checkbox(label="Show Gallery", value=False)
                # More config options can be added here, e.g.:
                # slider = gr.Slider(minimum=0, maximum=10, value=5, label="Processing Intensity")
                slider = gr.Slider(minimum=0.1, maximum=20.0, value=4.0, label="CFG value")
                labels_input = gr.Textbox(
                    label="Gallery Labels (comma-separated)",
                    value="1, 4, 16, 64, 256",
                    placeholder="Enter comma-separated numbers for the number of slots to use"
                )

        # Second column - output (conditionally rendered)
        with gr.Column(scale=1):
            gr.Markdown("## Output")

            # Container for conditional rendering
            with gr.Group(visible=False) as gallery_container:
                gallery = gr.Gallery(label="Result Gallery", columns=3, height="auto", show_label=True)

            # Always-visible output image
            output_image = gr.Image(label="Processed Image", type="numpy")

    # Handle form submission
    submit_btn = gr.Button("Process")

    # Define the processing logic
    def update_outputs(image, show_gallery_value, slider_value, labels_text):
        # Update the visibility of the gallery container
        visibility = gr.update(visible=show_gallery_value)

        try:
            # Parse the labels from the text input
            if labels_text and "," in labels_text:
                labels = [int(label.strip()) for label in labels_text.split(",")]
            else:
                # Default labels if none were provided or the format is wrong
                labels = [1, 4, 16, 64, 256]
        except ValueError:
            labels = [1, 4, 16, 64, 256]
        while len(labels) < 3:
            labels.append(256)

        # Process the image based on the configuration
        if image is None:
            # Return a placeholder if no image is uploaded
            placeholder = np.zeros((300, 300, 3), dtype=np.uint8)
            return visibility, [], placeholder
        image = Image.fromarray(image)
        img = transform(image)
        img = img.unsqueeze(0).to(device)
        recon = viz_diff_slots(model, img, [256], cfg=slider_value)[0]

        if not show_gallery_value:
            # If only the image should be shown, return just the processed image
            return visibility, [], recon
        else:
            model_decompose = viz_diff_slots(model, img, labels, cfg=slider_value)
            # Pair the reconstructions with their labels
            gallery_images = [
                (image, 'GT'),
                # (np.array(Image.fromarray(image).convert("L").convert("RGB")), labels[1]),
                # (np.array(Image.fromarray(image).rotate(180)), labels[2])
            ] + [(img, 'Recon. with ' + str(label) + ' tokens') for img, label in zip(model_decompose, labels)]
            return visibility, gallery_images, image

    # Connect the inputs and outputs
    submit_btn.click(
        fn=update_outputs,
        inputs=[input_image, show_gallery, slider, labels_input],
        outputs=[gallery_container, gallery, output_image]
    )

    # Also update when the checkbox changes
    show_gallery.change(
        fn=lambda value: gr.update(visible=value),
        inputs=[show_gallery],
        outputs=[gallery_container]
    )

    # Add examples
    examples = [
        ["examples/city.jpg", False, 4.0, "1,4,16,64,256"],
        ["examples/food.jpg", True, 4.0, "1,4,16,64,256"],
        ["examples/highland.webp", True, 4.0, "1,4,16,64,256"],
    ]

    gr.Examples(
        examples=examples,
        inputs=[input_image, show_gallery, slider, labels_input],
        outputs=[gallery_container, gallery, output_image],
        fn=update_outputs,
        cache_examples=True
    )

# Launch the demo
if __name__ == "__main__":
    demo.launch()
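The same model and helpers can also be used outside Gradio to reconstruct an image at several token budgets. A minimal sketch, assuming it runs from the repo root next to `demo.py` (importing `demo` builds the model and downloads the checkpoint; `demo.launch()` stays behind the `__main__` guard):

```python
# Reuse the demo's model and helpers for offline reconstruction.
from PIL import Image
from demo import transform, model, device, viz_diff_slots

img = transform(Image.open('examples/city.jpg').convert('RGB')).unsqueeze(0).to(device)
# Reconstruct at increasing token budgets; cfg=4.0 mirrors the demo's default slider value.
recons = viz_diff_slots(model, img, [1, 16, 256], cfg=4.0)
for n, arr in zip([1, 16, 256], recons):
    Image.fromarray(arr).save(f'recon_{n}_tokens.png')
```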
examples/city.jpg
ADDED
examples/food.jpg
ADDED
examples/highland.webp
ADDED
fid_stats/adm_in256_stats.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e842c756177806d210e891893914620bee2cf5b779a5613ba5af5145d7c85289
size 33563124
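The `.npz` itself lives in Git LFS; only the pointer appears in the diff. To inspect the reference statistics locally, something like the sketch below should work; the `'mu'`/`'sigma'` key names are an assumption (ADM-style stats files typically store an Inception feature mean and covariance), so check `stats.files` first:

```python
import numpy as np

stats = np.load('fid_stats/adm_in256_stats.npz')
print(stats.files)  # inspect the actual keys; 'mu'/'sigma' below are assumed
mu, sigma = stats['mu'], stats['sigma']
print(mu.shape, sigma.shape)  # typically (2048,) and (2048, 2048) for Inception features
```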
gen_demo.py
ADDED
@@ -0,0 +1,137 @@
import gradio as gr
import numpy as np
import cv2
import time
from PIL import Image

def apply_transformations(image, numbers_str):
    """
    Apply a series of transformations to an image based on a list of numbers.
    Shows the progressive changes as each transformation is applied.
    Yields both the current image and the full gallery of transformations.
    """
    try:
        # Parse the input numbers
        numbers = [float(n.strip()) for n in numbers_str.split(',') if n.strip()]
        if not numbers:
            yield image, [(image, "Original Image")]
            return

        # Convert PIL Image to numpy array for OpenCV operations
        img = np.array(image)

        # Initialize the result list with the original image
        results = [(image, "Original Image")]
        current_image = image

        # Apply transformations based on each number
        for i, value in enumerate(numbers):
            # Make a copy of the current numpy image
            if i == 0:
                current_img = img.copy()
            else:
                current_img = np.array(current_image)

            # Apply different transformations based on the value
            transformation_type = ""
            if i % 5 == 0:  # Brightness adjustment
                # Scale value to a reasonable brightness adjustment
                brightness = max(min(value, 100), -100)  # Limit between -100 and 100
                current_img = cv2.addWeighted(current_img, 1, np.zeros_like(current_img), 0, brightness)
                transformation_type = f"Brightness: {brightness:.1f}"

            elif i % 5 == 1:  # Rotation
                # Scale value to a reasonable rotation angle
                angle = value % 360
                h, w = current_img.shape[:2]
                center = (w // 2, h // 2)
                rotation_matrix = cv2.getRotationMatrix2D(center, angle, 1.0)
                current_img = cv2.warpAffine(current_img, rotation_matrix, (w, h))
                transformation_type = f"Rotation: {angle:.1f}°"

            elif i % 5 == 2:  # Contrast adjustment
                # Scale value to a reasonable contrast adjustment
                contrast = max(min(value / 10, 3), 0.5)  # Limit between 0.5 and 3
                current_img = cv2.convertScaleAbs(current_img, alpha=contrast, beta=0)
                transformation_type = f"Contrast: {contrast:.1f}x"

            elif i % 5 == 3:  # Blur
                # Scale value to a reasonable blur kernel size
                blur_amount = max(int(abs(value) % 20), 1)
                if blur_amount % 2 == 0:  # Ensure kernel size is odd
                    blur_amount += 1
                current_img = cv2.GaussianBlur(current_img, (blur_amount, blur_amount), 0)
                transformation_type = f"Blur: {blur_amount}px"

            elif i % 5 == 4:  # Hue shift (for color images)
                if current_img.shape[-1] == 3:  # Only for color images
                    # Convert to HSV
                    hsv_img = cv2.cvtColor(current_img, cv2.COLOR_RGB2HSV)
                    # Shift hue
                    hue_shift = int(value) % 180
                    hsv_img[:, :, 0] = (hsv_img[:, :, 0] + hue_shift) % 180
                    # Convert back to RGB
                    current_img = cv2.cvtColor(hsv_img, cv2.COLOR_HSV2RGB)
                    transformation_type = f"Hue Shift: {hue_shift}°"

            # Convert back to PIL Image and add to results
            current_image = Image.fromarray(current_img)

            # Add to results with a label for the gallery
            results.append((current_image, f"Step {i+1}: {transformation_type}"))

            # Pause so the progressive changes are visible
            time.sleep(4)

            # Yield intermediate results for real-time updates
            if i < len(numbers) - 1:
                yield current_image, results

        # Yield (not return) the final state so Gradio displays it;
        # a generator's return value is discarded by Gradio
        yield current_image, results

    except Exception as e:
        yield image, [(image, f"Error: {e}")]

# Create Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# Image Transformation Demo")
    gr.Markdown("Upload an image and provide a comma-separated list of numbers. The demo will apply a series of transformations to the image based on these numbers.")

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(label="Input Image", type="pil")
            numbers_input = gr.Textbox(label="Transformation Values (comma-separated numbers)",
                                       placeholder="e.g., 50, -30, 1.5, 5, 90, 20",
                                       value="30, 45, 1.5, 3, 60, -20, 90, 1.8, 7, 120")
            transform_btn = gr.Button("Apply Transformations")

            explanation = gr.Markdown("""
            ## How the transformations work:

            The numbers you input will be used to apply these transformations in sequence:
            1. First number: Brightness adjustment (-100 to 100)
            2. Second number: Rotation (degrees)
            3. Third number: Contrast adjustment (0.5 to 3)
            4. Fourth number: Blur (kernel size)
            5. Fifth number: Hue shift (color images only)

            And the pattern repeats for longer lists of numbers.
            """)

        with gr.Column(scale=2):
            with gr.Row():
                current_image = gr.Image(label="Current Transformation", type="pil")
            with gr.Row():
                gallery = gr.Gallery(label="Transformation History", show_label=True, columns=4, rows=2, height="auto")

    transform_btn.click(
        fn=apply_transformations,
        inputs=[input_image, numbers_input],
        outputs=[current_image, gallery]
    )

# Launch the app
demo.launch()
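`apply_transformations` is a generator, which is how Gradio streams intermediate results: each `yield` pushes an update to the bound output components. A stripped-down sketch of the same pattern, independent of the image pipeline:

```python
import time
import gradio as gr

def count_up(n):
    # Gradio consumes generator functions step by step:
    # each yield replaces the Textbox contents.
    total = 0
    for i in range(1, int(n) + 1):
        total += i
        time.sleep(0.2)
        yield f"partial sum after {i} terms: {total}"

with gr.Blocks() as mini:
    steps = gr.Number(value=5, label="n")
    out = gr.Textbox(label="running sum")
    gr.Button("Run").click(fn=count_up, inputs=steps, outputs=out)

if __name__ == "__main__":
    mini.launch()
```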
making_cache.py
ADDED
@@ -0,0 +1,46 @@
import argparse
import os.path as osp
import torch
import tqdm
from omegaconf import OmegaConf
from paintmind.engine.util import instantiate_from_config


@torch.no_grad()
def caching():
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', type=str, default='configs/vit_vqgan.yaml')
    args = parser.parse_args()

    cfg_file = args.cfg
    assert osp.exists(cfg_file)
    config = OmegaConf.load(cfg_file)
    dataset = instantiate_from_config(config.trainer.params.dataset)
    model = instantiate_from_config(config.trainer.params.model)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=config.trainer.params.batch_size,
        shuffle=False,
        num_workers=config.trainer.params.num_workers,
    )
    # Each batch gives us a (N, C, H, W) tensor of images.
    # Encode them to latents, accumulate, and save the cache to a .pth file.
    cache_save_file = config.trainer.params.latent_cache_file
    cache = []
    model.cuda()
    model.eval()
    for idx, batch in enumerate(tqdm.tqdm(dataloader)):
        batch = batch[0].cuda()
        latent = model.vae_encode(batch)
        cache.append(latent.cpu())
    cache = torch.cat(cache, dim=0)
    torch.save(cache, cache_save_file)

if __name__ == '__main__':
    caching()
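The resulting cache is a single tensor saved with `torch.save`, so consuming it later is a one-liner. A sketch; the actual path comes from `trainer.params.latent_cache_file` in whichever config was used, and `'latents.pth'` below is a placeholder name:

```python
import torch

# Load the latent cache written by making_cache.py.
cache = torch.load('latents.pth', map_location='cpu')  # placeholder path
print(cache.shape)  # (num_images, *latent_dims), in dataset order since shuffle=False
```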
paintmind/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .config import Config
from .version import __version__
paintmind/config.py
ADDED
@@ -0,0 +1,110 @@
import json
from copy import deepcopy

class Config:
    def __init__(self, config=None):
        if config is not None:
            self.from_dict(config)

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        return deepcopy(self.__dict__)

    def to_json(self, path):
        with open(path, 'w') as f:
            json.dump(self.to_dict(), f, indent=2)

    def to_json_string(self):
        return json.dumps(self.to_dict(), indent=2)

    def from_dict(self, dct):
        self.clear()
        for key, value in dct.items():
            self.__dict__[key] = value

        return self.to_dict()

    def from_json(self, json_path):
        with open(json_path, 'r') as f:
            config = json.load(f)
        self.from_dict(config)

        return self.to_dict()

    def clear(self):
        # Deleting __dict__ resets the instance to an empty attribute dict
        del self.__dict__


vit_s_vqgan_config = {
    'n_embed': 8192,
    'embed_dim': 16,
    'beta': 0.25,
    'enc': {
        'image_size': 320,
        'patch_size': 8,
        'dim': 512,
        'depth': 8,
        'num_head': 8,
        'mlp_dim': 2048,
        'in_channels': 3,
        'dim_head': 64,
        'dropout': 0.0,
    },
    'dec': {
        'image_size': 320,
        'patch_size': 8,
        'dim': 512,
        'depth': 8,
        'num_head': 8,
        'mlp_dim': 2048,
        'out_channels': 3,
        'dim_head': 64,
        'dropout': 0.0,
    },
}
vit_m_vqgan_config = {
    'n_embed': 8192,
    'embed_dim': 32,
    'beta': 0.25,
    'enc': {
        'image_size': 256,
        'patch_size': 8,
        'dim': 1024,
        'depth': 16,
        'num_head': 16,
        'mlp_dim': 2048,
        'in_channels': 3,
        'dim_head': 64,
        'dropout': 0.0,
    },
    'dec': {
        'image_size': 256,
        'patch_size': 8,
        'dim': 1024,
        'depth': 16,
        'num_head': 16,
        'mlp_dim': 2048,
        'out_channels': 3,
        'dim_head': 64,
        'dropout': 0.0,
    },
}

pipeline_v1_config = {
    'stage1': 'vit-s-vqgan',
    't5': 't5-l',
    'dim': 1024,
    'dim_head': 64,
    'mlp_dim': 4096,
    'num_head': 16,
    'depth': 12,
    'dropout': 0.1,
}

ver2cfg = {
    'vit-s-vqgan': vit_s_vqgan_config,
    'vit-m-vqgan': vit_m_vqgan_config,
    'paintmindv1': pipeline_v1_config,
}
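`Config` is a thin attribute bag with JSON round-tripping. A short usage sketch against the bundled ViT-S VQGAN preset (the `/tmp` path is illustrative):

```python
from paintmind.config import Config, ver2cfg

# Build a config from a preset and round-trip it through JSON.
cfg = Config(ver2cfg['vit-s-vqgan'])
print(cfg.n_embed, cfg.enc['patch_size'])  # attribute access over the wrapped dict

cfg.to_json('/tmp/vit_s.json')
restored = Config()
restored.from_json('/tmp/vit_s.json')
assert restored.to_dict() == cfg.to_dict()
```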
paintmind/engine/gpt_trainer.py
ADDED
@@ -0,0 +1,892 @@
import os, torch
import os.path as osp
import cv2
import shutil
import numpy as np
import copy
import torch_fidelity
import torch.nn as nn
from tqdm.auto import tqdm
from collections import OrderedDict
from einops import rearrange
from accelerate import Accelerator
from .util import instantiate_from_config
from torchvision.utils import make_grid, save_image
from torch.utils.data import DataLoader, random_split, DistributedSampler, Sampler
from paintmind.utils.lr_scheduler import build_scheduler
from paintmind.utils.logger import SmoothedValue, MetricLogger, synchronize_processes, empty_cache
from paintmind.engine.misc import is_main_process, all_reduce_mean, concat_all_gather
from accelerate.utils import DistributedDataParallelKwargs, AutocastKwargs
from torch.optim import AdamW
from concurrent.futures import ThreadPoolExecutor
from paintmind.stage2.gpt import GPT_models
from paintmind.stage2.causaldit import CausalDiT_models
from paintmind.stage2.generate import generate, generate_causal_dit
from pathlib import Path
import time


def requires_grad(model, flag=True):
    for p in model.parameters():
        p.requires_grad = flag


def save_img(img, save_path):
    img = np.clip(img.float().numpy().transpose([1, 2, 0]) * 255, 0, 255)
    img = img.astype(np.uint8)[:, :, ::-1]
    cv2.imwrite(save_path, img)

def save_img_batch(imgs, save_paths):
    """Process and save multiple images at once using a thread pool."""
    # Convert to numpy and prepare all images in one go
    imgs = np.clip(imgs.float().numpy().transpose(0, 2, 3, 1) * 255, 0, 255).astype(np.uint8)
    imgs = imgs[:, :, :, ::-1]  # RGB to BGR for all images at once

    # ThreadPoolExecutor fits here because file saving is I/O-bound
    # (ProcessPoolExecutor would only pay off for CPU-bound work)
    with ThreadPoolExecutor(max_workers=32) as pool:
        # Submit all tasks at once
        futures = [pool.submit(cv2.imwrite, path, img)
                   for path, img in zip(save_paths, imgs)]
        # Wait for all tasks to complete
        for future in futures:
            future.result()  # This will raise any exceptions that occurred

def get_fid_stats(real_dir, rec_dir, fid_stats):
    stats = torch_fidelity.calculate_metrics(
        input1=real_dir,
        input2=rec_dir,
        fid_statistics_file=fid_stats,
        cuda=True,
        isc=True,
        fid=True,
        kid=False,
        prc=False,
        verbose=False,
    )
    return stats


class EMAModel:
    """Model Exponential Moving Average."""
    def __init__(self, model, device, decay=0.999):
        self.device = device
        self.decay = decay
        self.ema_params = OrderedDict(
            (name, param.clone().detach().to(device))
            for name, param in model.named_parameters()
            if param.requires_grad
        )

    @torch.no_grad()
    def update(self, model):
        for name, param in model.named_parameters():
            if param.requires_grad:
                if name in self.ema_params:
                    self.ema_params[name].lerp_(param.data, 1 - self.decay)
                else:
                    self.ema_params[name] = param.data.clone().detach()

    def state_dict(self):
        return self.ema_params

    def load_state_dict(self, params):
        self.ema_params = OrderedDict(
            (name, param.clone().detach().to(self.device))
            for name, param in params.items()
        )

class CacheDataLoader:
    """DataLoader-like interface for cached data with epoch-based shuffling."""
    def __init__(self, slots, targets=None, batch_size=32, num_augs=1, seed=None):
        self.slots = slots
        self.targets = targets
        self.batch_size = batch_size
        self.num_augs = num_augs
        self.seed = seed
        self.epoch = 0
        # Original dataset size (before augmentations)
        self.num_samples = len(slots) // num_augs

    def set_epoch(self, epoch):
        """Set epoch for deterministic shuffling."""
        self.epoch = epoch

    def __len__(self):
        """Return number of batches based on original dataset size."""
        return self.num_samples // self.batch_size

    def __iter__(self):
        """Yield batches of randomly sampled cached slots and targets."""
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch if self.seed is not None else self.epoch)

        # Randomly sample indices from the entire augmented dataset
        indices = torch.randint(
            0, len(self.slots),
            (self.num_samples,),
            generator=g
        ).numpy()

        # Yield batches of indices
        for start in range(0, self.num_samples, self.batch_size):
            end = min(start + self.batch_size, self.num_samples)
            batch_indices = indices[start:end]
            yield (
                torch.from_numpy(self.slots[batch_indices]),
                torch.from_numpy(self.targets[batch_indices])
            )

class GPTTrainer(nn.Module):
    def __init__(
        self,
        ae_model,
        gpt_model,
        dataset,
        test_dataset=None,
        test_only=False,
        num_test_images=50000,
        num_epoch=400,
        eval_classes=[1, 7, 282, 604, 724, 207, 250, 751, 404, 850],  # goldfish, cock, tiger cat, hourglass, ship, golden retriever, husky, race car, airliner, teddy bear
        lr=None,
        blr=1e-4,
        cosine_lr=False,
        lr_min=0,
        warmup_epochs=100,
        warmup_steps=None,
        warmup_lr_init=0,
        decay_steps=None,
        batch_size=32,
        cache_bs=8,
        test_bs=100,
        num_workers=0,
        pin_memory=False,
        max_grad_norm=None,
        grad_accum_steps=1,
        precision="bf16",
        save_every=10000,
        sample_every=1000,
        fid_every=50000,
        result_folder=None,
        log_dir="./log",
        steps=0,
        cfg=1.75,
        ae_cfg=1.5,
        diff_cfg=2.0,
        temperature=1.0,
        cfg_schedule="constant",
        diff_cfg_schedule="inv_linear",
        train_num_slots=None,
        test_num_slots=None,
        eval_fid=False,
        fid_stats=None,
        enable_ema=False,
        compile=False,
        enable_cache_latents=True,
        cache_dir='/dev/shm/slot_cache',
        seed=42
    ):
        super().__init__()
        kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
        self.accelerator = Accelerator(
            kwargs_handlers=[kwargs],
            mixed_precision="bf16",
            gradient_accumulation_steps=grad_accum_steps,
            log_with="tensorboard",
            project_dir=log_dir,
        )

        self.ae_model = instantiate_from_config(ae_model)
        if hasattr(ae_model.params, "ema_path") and ae_model.params.ema_path is not None:
            ae_model_path = ae_model.params.ema_path
        else:
            ae_model_path = ae_model.params.ckpt_path
        assert ae_model_path.endswith(".safetensors") or ae_model_path.endswith(".pt") or ae_model_path.endswith(".pth") or ae_model_path.endswith(".pkl")
        assert osp.exists(ae_model_path), f"AE model checkpoint {ae_model_path} does not exist"
        self._load_checkpoint(ae_model_path, self.ae_model)

        self.ae_model.to(self.device)
        for param in self.ae_model.parameters():
            param.requires_grad = False
        self.ae_model.eval()

        self.model_name = gpt_model.target
        if 'GPT' in gpt_model.target:
            self.gpt_model = GPT_models[gpt_model.target](**gpt_model.params)
        elif 'CausalDiT' in gpt_model.target:
            self.gpt_model = CausalDiT_models[gpt_model.target](**gpt_model.params)
        else:
            raise ValueError(f"Unknown model type: {gpt_model.target}")
        self.num_slots = ae_model.params.num_slots
        self.slot_dim = ae_model.params.slot_dim

        assert precision in ["bf16", "fp32"]
        precision = "fp32"
        if self.accelerator.is_main_process:
            print("Overriding the specified precision and using autocast bf16...")
        self.precision = precision

        self.test_only = test_only
        self.test_bs = test_bs
        self.num_test_images = num_test_images
        self.num_classes = gpt_model.params.num_classes

        self.batch_size = batch_size
        if not test_only:
            self.train_ds = instantiate_from_config(dataset)
            train_size = len(self.train_ds)
            if self.accelerator.is_main_process:
                print(f"train dataset size: {train_size}")

            sampler = DistributedSampler(
                self.train_ds,
                num_replicas=self.accelerator.num_processes,
                rank=self.accelerator.process_index,
                shuffle=True,
            )
            self.train_dl = DataLoader(
                self.train_ds,
                batch_size=batch_size if not enable_cache_latents else cache_bs,
                sampler=sampler,
                num_workers=num_workers,
                pin_memory=pin_memory,
                drop_last=True,
            )

            effective_bs = batch_size * grad_accum_steps * self.accelerator.num_processes
            if lr is None:
                lr = blr * effective_bs / 256
            if self.accelerator.is_main_process:
                print(f"Effective batch size is {effective_bs}")

            self.g_optim = self._create_optimizer(weight_decay=0.05, learning_rate=lr, betas=(0.9, 0.95))
            self.g_sched = self._create_scheduler(
                cosine_lr, warmup_epochs, warmup_steps, num_epoch,
                lr_min, warmup_lr_init, decay_steps
            )
            self.accelerator.register_for_checkpointing(self.g_sched)

        self.steps = steps
        self.loaded_steps = -1

        # Prepare everything together
        if not test_only:
            self.gpt_model, self.g_optim, self.g_sched = self.accelerator.prepare(
                self.gpt_model, self.g_optim, self.g_sched
            )
        else:
            self.gpt_model = self.accelerator.prepare(self.gpt_model)

        # assume _orig_mod does not exist in checkpoints
        if compile:
            _model = self.accelerator.unwrap_model(self.gpt_model)
            self.ae_model = torch.compile(self.ae_model, mode="reduce-overhead")
            _model = torch.compile(_model, mode="reduce-overhead")

        self.enable_ema = enable_ema
        if self.enable_ema and not self.test_only:  # when testing, we directly load the ema dict and skip here
            self.ema_model = EMAModel(self.accelerator.unwrap_model(self.gpt_model), self.device)
            self.accelerator.register_for_checkpointing(self.ema_model)

        self._load_checkpoint(gpt_model.params.ckpt_path)
        if self.test_only:
            self.steps = self.loaded_steps

        self.num_epoch = num_epoch
        self.save_every = save_every
        self.samp_every = sample_every
        self.fid_every = fid_every
        self.max_grad_norm = max_grad_norm

        self.eval_classes = eval_classes
        self.cfg = cfg
        self.ae_cfg = ae_cfg
        self.diff_cfg = diff_cfg
        self.cfg_schedule = cfg_schedule
        self.diff_cfg_schedule = diff_cfg_schedule
        self.temperature = temperature
        self.train_num_slots = train_num_slots
        self.test_num_slots = test_num_slots
        if self.train_num_slots is not None:
            self.train_num_slots = min(self.train_num_slots, self.num_slots)
        else:
            self.train_num_slots = self.num_slots
        if self.test_num_slots is not None:
            self.num_slots_to_gen = min(self.test_num_slots, self.train_num_slots)
        else:
            self.num_slots_to_gen = self.train_num_slots
        self.eval_fid = eval_fid
        if eval_fid:
            assert fid_stats is not None
        self.fid_stats = fid_stats

        self.result_folder = result_folder
        self.model_saved_dir = os.path.join(result_folder, "models")
        os.makedirs(self.model_saved_dir, exist_ok=True)

        self.image_saved_dir = os.path.join(result_folder, "images")
        os.makedirs(self.image_saved_dir, exist_ok=True)

        self.cache_dir = Path(cache_dir)
        self.enable_cache_latents = enable_cache_latents
        self.seed = seed
        self.cache_loader = None

    @property
    def device(self):
        return self.accelerator.device

    def _create_optimizer(self, weight_decay, learning_rate, betas):
        # start with all of the candidate parameters
        param_dict = {pn: p for pn, p in self.gpt_model.named_parameters()}
        # filter out those that do not require grad
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameter that is 2D will be weight-decayed, otherwise not.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        if self.accelerator.is_main_process:
            print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
            print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        optimizer = AdamW(optim_groups, lr=learning_rate, betas=betas)
        return optimizer

    def _create_scheduler(self, cosine_lr, warmup_epochs, warmup_steps, num_epoch, lr_min, warmup_lr_init, decay_steps):
        if warmup_epochs is not None:
            warmup_steps = warmup_epochs * len(self.train_dl)
        else:
            assert warmup_steps is not None

        scheduler = build_scheduler(
            self.g_optim,
            num_epoch,
            len(self.train_dl),
            lr_min,
            warmup_steps,
            warmup_lr_init,
            decay_steps,
            cosine_lr,  # if not cosine_lr, then use step_lr (warmup, then fix)
        )
        return scheduler

    def _load_state_dict(self, state_dict, model):
        """Helper to load a state dict with proper prefix handling."""
        if 'state_dict' in state_dict:
            state_dict = state_dict['state_dict']
        # Remove '_orig_mod' prefix if present
        state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
        missing, unexpected = model.load_state_dict(
            state_dict, strict=False
        )
        if self.accelerator.is_main_process:
            print(f"Loaded model. Missing: {missing}, Unexpected: {unexpected}")

    def _load_safetensors(self, path, model):
        """Helper to load a safetensors checkpoint."""
        from safetensors.torch import safe_open
        with safe_open(path, framework="pt", device="cpu") as f:
            state_dict = {k: f.get_tensor(k) for k in f.keys()}
        self._load_state_dict(state_dict, model)

    def _load_checkpoint(self, ckpt_path=None, model=None):
        if ckpt_path is None or not osp.exists(ckpt_path):
            return

        if model is None:
            model = self.accelerator.unwrap_model(self.gpt_model)

        if osp.isdir(ckpt_path):
            # ckpt_path is something like 'path/to/models/step10/'
            self.loaded_steps = int(
                ckpt_path.split("step")[-1].split("/")[0]
            )
            if not self.test_only:
                self.accelerator.load_state(ckpt_path)
            else:
                if self.enable_ema:
                    model_path = osp.join(ckpt_path, "custom_checkpoint_1.pkl")
                    if osp.exists(model_path):
                        state_dict = torch.load(model_path, map_location="cpu")
                        self._load_state_dict(state_dict, model)
                        if self.accelerator.is_main_process:
                            print(f"Loaded ema model from {model_path}")
                else:
                    model_path = osp.join(ckpt_path, "model.safetensors")
                    if osp.exists(model_path):
                        self._load_safetensors(model_path, model)
        else:
            # ckpt_path is something like 'path/to/models/step10.pt'
            if ckpt_path.endswith(".safetensors"):
                self._load_safetensors(ckpt_path, model)
            else:
                state_dict = torch.load(ckpt_path, map_location="cpu")
                self._load_state_dict(state_dict, model)

        if self.accelerator.is_main_process:
            print(f"Loaded checkpoint from {ckpt_path}")

    def _build_cache(self):
        """Build cache for slots and targets."""
        rank = self.accelerator.process_index
        world_size = self.accelerator.num_processes

        # Clean up any existing cache files first
        slots_file = self.cache_dir / f"slots_rank{rank}_of_{world_size}.mmap"
        targets_file = self.cache_dir / f"targets_rank{rank}_of_{world_size}.mmap"

        if slots_file.exists():
            os.remove(slots_file)
        if targets_file.exists():
            os.remove(targets_file)

        dataset_size = len(self.train_dl.dataset)
        shard_size = dataset_size // world_size

        # Detect number of augmentations from first batch
        with torch.no_grad():
            sample_batch = next(iter(self.train_dl))
            img, _ = sample_batch
            num_augs = img.shape[1] if len(img.shape) == 5 else 1

        print(f"Rank {rank}: Creating new cache with {num_augs} augmentations per image...")
        os.makedirs(self.cache_dir, exist_ok=True)

        # Create memory-mapped files
        slots_mmap = np.memmap(
            slots_file,
            dtype='float32',
            mode='w+',
            shape=(shard_size * num_augs, self.train_num_slots, self.slot_dim)
        )

        targets_mmap = np.memmap(
            targets_file,
            dtype='int64',
            mode='w+',
            shape=(shard_size * num_augs,)
        )

        # Cache data
        with torch.no_grad():
            for i, batch in enumerate(tqdm(
                self.train_dl,
                desc=f"Rank {rank}: Caching data",
                disable=not self.accelerator.is_local_main_process
            )):
                imgs, targets = batch
                if len(imgs.shape) == 5:  # [B, num_augs, C, H, W]
                    B, A, C, H, W = imgs.shape
                    imgs = imgs.view(-1, C, H, W)  # [B*num_augs, C, H, W]
                    targets = targets.unsqueeze(1).expand(-1, A).reshape(-1)  # [B*num_augs]

                # Split imgs into n chunks
                num_splits = num_augs
                split_size = imgs.shape[0] // num_splits
                imgs_splits = torch.split(imgs, split_size)
                targets_splits = torch.split(targets, split_size)

                start_idx = i * self.train_dl.batch_size * num_augs

                for split_idx, (img_split, targets_split) in enumerate(zip(imgs_splits, targets_splits)):
                    img_split = img_split.to(self.device, non_blocking=True)
                    slots_split = self.ae_model.encode_slots(img_split)[:, :self.train_num_slots, :]

                    split_start = start_idx + (split_idx * split_size)
                    split_end = split_start + img_split.shape[0]

                    # Write directly to mmap files
                    slots_mmap[split_start:split_end] = slots_split.cpu().numpy()
                    targets_mmap[split_start:split_end] = targets_split.numpy()

        # Close the mmap files
        del slots_mmap
        del targets_mmap

        # Reopen in read mode
        self.cached_latents = np.memmap(
            slots_file,
            dtype='float32',
            mode='r',
            shape=(shard_size * num_augs, self.train_num_slots, self.slot_dim)
        )

        self.cached_targets = np.memmap(
            targets_file,
            dtype='int64',
            mode='r',
            shape=(shard_size * num_augs,)
        )

        # Store the number of augmentations for the cache loader
        self.num_augs = num_augs

    def _setup_cache(self):
        """Setup cache if enabled."""
        self._build_cache()
        self.accelerator.wait_for_everyone()

        # Initialize cache loader if cache exists
        if self.cached_latents is not None:
            self.cache_loader = CacheDataLoader(
                slots=self.cached_latents,
                targets=self.cached_targets,
                batch_size=self.batch_size,
                num_augs=self.num_augs,
                seed=self.seed + self.accelerator.process_index
            )

    def __del__(self):
        """Cleanup cache files."""
        if self.enable_cache_latents:
            rank = self.accelerator.process_index
            world_size = self.accelerator.num_processes

            # Clean up slots cache
            slots_file = self.cache_dir / f"slots_rank{rank}_of_{world_size}.mmap"
            if slots_file.exists():
                os.remove(slots_file)

            # Clean up targets cache
            targets_file = self.cache_dir / f"targets_rank{rank}_of_{world_size}.mmap"
            if targets_file.exists():
                os.remove(targets_file)

    def _train_step(self, slots, targets=None):
        """Execute a single training step."""

        with self.accelerator.accumulate(self.gpt_model):
            with self.accelerator.autocast():
                loss = self.gpt_model(slots, targets)

            self.accelerator.backward(loss)
            if self.accelerator.sync_gradients and self.max_grad_norm is not None:
                self.accelerator.clip_grad_norm_(self.gpt_model.parameters(), self.max_grad_norm)
            self.g_optim.step()
            if self.g_sched is not None:
                self.g_sched.step_update(self.steps)
            self.g_optim.zero_grad()

        # Update EMA model if enabled
        if self.enable_ema:
            self.ema_model.update(self.accelerator.unwrap_model(self.gpt_model))

        return loss

    def _train_epoch_cached(self, epoch, logger):
        """Train one epoch using cached data."""
        self.cache_loader.set_epoch(epoch)
        header = f'Epoch: [{epoch}/{self.num_epoch}]'

        for batch in logger.log_every(self.cache_loader, 20, header):
            slots, targets = (b.to(self.device, non_blocking=True) for b in batch)

            self.steps += 1

            if self.steps == 1:
                print(f"Training batch size: {len(slots)}")
                print(f"Hello from index {self.accelerator.local_process_index}")

            loss = self._train_step(slots, targets)
            self._handle_periodic_ops(loss, logger)

    def _train_epoch_uncached(self, epoch, logger):
        """Train one epoch using raw data."""
        header = f'Epoch: [{epoch}/{self.num_epoch}]'

        for batch in logger.log_every(self.train_dl, 20, header):
            img, targets = (b.to(self.device, non_blocking=True) for b in batch)

            self.steps += 1

            if self.steps == 1:
                print(f"Training batch size: {img.size(0)}")
                print(f"Hello from index {self.accelerator.local_process_index}")

            slots = self.ae_model.encode_slots(img)[:, :self.train_num_slots, :]
            loss = self._train_step(slots, targets)
            self._handle_periodic_ops(loss, logger)

    def _handle_periodic_ops(self, loss, logger):
        """Handle periodic operations and logging."""
        logger.update(loss=loss.item())
        logger.update(lr=self.g_optim.param_groups[0]["lr"])

        if self.steps % self.save_every == 0:
            self.save()

        if (self.steps % self.samp_every == 0) or (self.eval_fid and self.steps % self.fid_every == 0):
            empty_cache()
            self.evaluate()
            self.accelerator.wait_for_everyone()
            empty_cache()

    def _save_config(self, config):
        """Save configuration file."""
        if config is not None and self.accelerator.is_main_process:
            import shutil
            from omegaconf import OmegaConf

            if isinstance(config, str) and osp.exists(config):
                shutil.copy(config, osp.join(self.result_folder, "config.yaml"))
            else:
                config_save_path = osp.join(self.result_folder, "config.yaml")
                OmegaConf.save(config, config_save_path)

    def _should_skip_epoch(self, epoch):
        """Check if epoch should be skipped due to loaded checkpoint."""
        loader = self.train_dl if not self.enable_cache_latents else self.cache_loader
        if ((epoch + 1) * len(loader)) <= self.loaded_steps:
            if self.accelerator.is_main_process:
                print(f"Epoch {epoch} is skipped because it is loaded from ckpt")
            self.steps += len(loader)
            return True

        if self.steps < self.loaded_steps:
            for _ in loader:
                self.steps += 1
                if self.steps >= self.loaded_steps:
                    break
        return False

    def train(self, config=None):
        """Main training loop."""
        # Initial setup
        n_parameters = sum(p.numel() for p in self.parameters() if p.requires_grad)
        if self.accelerator.is_main_process:
            print(f"number of learnable parameters: {n_parameters / 1e6:.2f}M")

        self._save_config(config)
        self.accelerator.init_trackers("gpt")

        # Handle test-only mode
        if self.test_only:
            empty_cache()
            self.evaluate()
            self.accelerator.wait_for_everyone()
            empty_cache()
            return

        # Setup cache if enabled
        if self.enable_cache_latents:
            self._setup_cache()

        # Training loop
        for epoch in range(self.num_epoch):
            if self._should_skip_epoch(epoch):
                continue

            self.gpt_model.train()
            logger = MetricLogger(delimiter="  ")
            logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))

            # Choose training path based on cache availability
            if self.enable_cache_latents:
                self._train_epoch_cached(epoch, logger)
            else:
                self._train_epoch_uncached(epoch, logger)

            # Synchronize and log epoch stats
            # logger.synchronize_between_processes()
|
698 |
+
# if self.accelerator.is_main_process:
|
699 |
+
# print("Averaged stats:", logger)
|
700 |
+
|
701 |
+
# Finish training
|
702 |
+
self.accelerator.end_training()
|
703 |
+
self.save()
|
704 |
+
if self.accelerator.is_main_process:
|
705 |
+
print("Train finished!")
|
706 |
+
|
707 |
+
def save(self):
|
708 |
+
self.accelerator.wait_for_everyone()
|
709 |
+
self.accelerator.save_state(
|
710 |
+
os.path.join(self.model_saved_dir, f"step{self.steps}")
|
711 |
+
)
|
712 |
+
|
713 |
+
@torch.no_grad()
|
714 |
+
def evaluate(self, use_ema=True):
|
715 |
+
self.gpt_model.eval()
|
716 |
+
unwraped_gpt_model = self.accelerator.unwrap_model(self.gpt_model)
|
717 |
+
# switch to ema params, only when eval_fid is True
|
718 |
+
use_ema = use_ema and self.enable_ema and self.eval_fid and not self.test_only
|
719 |
+
if use_ema:
|
720 |
+
if hasattr(self, "ema_model"):
|
721 |
+
model_without_ddp = self.accelerator.unwrap_model(self.gpt_model)
|
722 |
+
model_state_dict = copy.deepcopy(model_without_ddp.state_dict())
|
723 |
+
ema_state_dict = copy.deepcopy(model_without_ddp.state_dict())
|
724 |
+
for i, (name, _value) in enumerate(model_without_ddp.named_parameters()):
|
725 |
+
if "nested_sampler" in name:
|
726 |
+
continue
|
727 |
+
ema_state_dict[name] = self.ema_model.state_dict()[name]
|
728 |
+
if self.accelerator.is_main_process:
|
729 |
+
print("Switch to ema")
|
730 |
+
model_without_ddp.load_state_dict(ema_state_dict)
|
731 |
+
else:
|
732 |
+
print("EMA model not found, using original model")
|
733 |
+
use_ema = False
|
734 |
+
|
735 |
+
generate_fn = generate if 'GPT' in self.model_name else generate_causal_dit
|
736 |
+
if not self.test_only:
|
737 |
+
classes = torch.tensor(self.eval_classes, device=self.device)
|
738 |
+
with self.accelerator.autocast():
|
739 |
+
slots = generate_fn(unwraped_gpt_model, classes, self.num_slots_to_gen, cfg_scale=self.cfg, diff_cfg=self.diff_cfg, cfg_schedule=self.cfg_schedule, diff_cfg_schedule=self.diff_cfg_schedule, temperature=self.temperature)
|
740 |
+
if self.num_slots_to_gen < self.num_slots:
|
741 |
+
null_slots = self.ae_model.dit.null_cond.expand(slots.shape[0], -1, -1)
|
742 |
+
null_slots = null_slots[:, self.num_slots_to_gen:, :]
|
743 |
+
slots = torch.cat([slots, null_slots], dim=1)
|
744 |
+
imgs = self.ae_model.sample(slots, targets=classes, cfg=self.ae_cfg) # targets are not used for now
|
745 |
+
|
746 |
+
imgs = concat_all_gather(imgs)
|
747 |
+
if self.accelerator.num_processes > 16:
|
748 |
+
imgs = imgs[:16*len(self.eval_classes)]
|
749 |
+
imgs = imgs.detach().cpu()
|
750 |
+
grid = make_grid(
|
751 |
+
imgs, nrow=len(self.eval_classes), normalize=True, value_range=(0, 1)
|
752 |
+
)
|
753 |
+
if self.accelerator.is_main_process:
|
754 |
+
save_image(
|
755 |
+
grid,
|
756 |
+
os.path.join(
|
757 |
+
self.image_saved_dir, f"step{self.steps}_aecfg-{self.ae_cfg}_cfg-{self.cfg_schedule}-{self.cfg}_diffcfg-{self.diff_cfg_schedule}-{self.diff_cfg}_slots{self.num_slots_to_gen}_temp{self.temperature}.jpg"
|
758 |
+
),
|
759 |
+
)
|
760 |
+
if self.eval_fid and (self.test_only or (self.steps % self.fid_every == 0)):
|
761 |
+
# Create output directory (only on main process)
|
762 |
+
save_folder = os.path.join(self.image_saved_dir, f"gen_step{self.steps}_aecfg-{self.ae_cfg}_cfg-{self.cfg_schedule}-{self.cfg}_diffcfg-{self.diff_cfg_schedule}-{self.diff_cfg}_slots{self.num_slots_to_gen}_temp{self.temperature}")
|
763 |
+
if self.accelerator.is_main_process:
|
764 |
+
os.makedirs(save_folder, exist_ok=True)
|
765 |
+
|
766 |
+
# Setup for distributed generation
|
767 |
+
world_size = self.accelerator.num_processes
|
768 |
+
local_rank = self.accelerator.process_index
|
769 |
+
batch_size = self.test_bs
|
770 |
+
|
771 |
+
# Create balanced class distribution
|
772 |
+
num_classes = self.num_classes
|
773 |
+
images_per_class = self.num_test_images // num_classes
|
774 |
+
class_labels = np.repeat(np.arange(num_classes), images_per_class)
|
775 |
+
|
776 |
+
# Shuffle the class labels to ensure random ordering
|
777 |
+
np.random.shuffle(class_labels)
|
778 |
+
|
779 |
+
total_images = len(class_labels)
|
780 |
+
|
781 |
+
padding_size = world_size * batch_size - (total_images % (world_size * batch_size))
|
782 |
+
class_labels = np.pad(class_labels, (0, padding_size), 'constant')
|
783 |
+
padded_total_images = len(class_labels)
|
784 |
+
|
785 |
+
# Distribute workload across GPUs
|
786 |
+
images_per_gpu = padded_total_images // world_size
|
787 |
+
start_idx = local_rank * images_per_gpu
|
788 |
+
end_idx = min(start_idx + images_per_gpu, padded_total_images)
|
789 |
+
local_class_labels = class_labels[start_idx:end_idx]
|
790 |
+
local_num_steps = len(local_class_labels) // batch_size
|
791 |
+
|
792 |
+
if self.accelerator.is_main_process:
|
793 |
+
print(f"Generating {total_images} images ({images_per_class} per class) across {world_size} GPUs")
|
794 |
+
|
795 |
+
used_time = 0
|
796 |
+
gen_img_cnt = 0
|
797 |
+
|
798 |
+
for i in range(local_num_steps):
|
799 |
+
if self.accelerator.is_main_process and i % 10 == 0:
|
800 |
+
print(f"Generation step {i}/{local_num_steps}")
|
801 |
+
|
802 |
+
# Get and pad labels for current batch
|
803 |
+
batch_start = i * batch_size
|
804 |
+
batch_end = batch_start + batch_size
|
805 |
+
labels = local_class_labels[batch_start:batch_end]
|
806 |
+
|
807 |
+
# Convert to tensors and track real vs padding
|
808 |
+
labels = torch.tensor(labels, device=self.device)
|
809 |
+
|
810 |
+
# Generate images
|
811 |
+
self.accelerator.wait_for_everyone()
|
812 |
+
start_time = time.time()
|
813 |
+
with torch.no_grad():
|
814 |
+
with self.accelerator.autocast():
|
815 |
+
slots = generate_fn(unwraped_gpt_model, labels, self.num_slots_to_gen,
|
816 |
+
cfg_scale=self.cfg, diff_cfg=self.diff_cfg,
|
817 |
+
cfg_schedule=self.cfg_schedule, diff_cfg_schedule=self.diff_cfg_schedule,
|
818 |
+
temperature=self.temperature)
|
819 |
+
if self.num_slots_to_gen < self.num_slots:
|
820 |
+
null_slots = self.ae_model.dit.null_cond.expand(slots.shape[0], -1, -1)
|
821 |
+
null_slots = null_slots[:, self.num_slots_to_gen:, :]
|
822 |
+
slots = torch.cat([slots, null_slots], dim=1)
|
823 |
+
imgs = self.ae_model.sample(slots, targets=labels, cfg=self.ae_cfg)
|
824 |
+
|
825 |
+
samples_in_batch = min(batch_size * world_size, total_images - gen_img_cnt)
|
826 |
+
|
827 |
+
# Update timing stats
|
828 |
+
used_time += time.time() - start_time
|
829 |
+
gen_img_cnt += samples_in_batch
|
830 |
+
if self.accelerator.is_main_process and i % 10 == 0:
|
831 |
+
print(f"Avg generation time: {used_time/gen_img_cnt:.5f} sec/image")
|
832 |
+
|
833 |
+
gathered_imgs = concat_all_gather(imgs)
|
834 |
+
gathered_imgs = gathered_imgs[:samples_in_batch]
|
835 |
+
|
836 |
+
# Save images (only on main process)
|
837 |
+
if self.accelerator.is_main_process:
|
838 |
+
real_imgs = gathered_imgs.detach().cpu()
|
839 |
+
|
840 |
+
save_paths = [
|
841 |
+
os.path.join(save_folder, f"{str(idx).zfill(5)}.png")
|
842 |
+
for idx in range(gen_img_cnt - samples_in_batch, gen_img_cnt)
|
843 |
+
]
|
844 |
+
save_img_batch(real_imgs, save_paths)
|
845 |
+
|
846 |
+
# Calculate metrics (only on main process)
|
847 |
+
self.accelerator.wait_for_everyone()
|
848 |
+
if self.accelerator.is_main_process:
|
849 |
+
generated_files = len(os.listdir(save_folder))
|
850 |
+
print(f"Generated {generated_files} images out of {total_images} expected")
|
851 |
+
|
852 |
+
metrics_dict = get_fid_stats(save_folder, None, self.fid_stats)
|
853 |
+
fid = metrics_dict["frechet_inception_distance"]
|
854 |
+
inception_score = metrics_dict["inception_score_mean"]
|
855 |
+
|
856 |
+
metric_prefix = "fid_ema" if use_ema else "fid"
|
857 |
+
isc_prefix = "isc_ema" if use_ema else "isc"
|
858 |
+
|
859 |
+
self.accelerator.log({
|
860 |
+
metric_prefix: fid,
|
861 |
+
isc_prefix: inception_score,
|
862 |
+
"gpt_cfg": self.cfg,
|
863 |
+
"ae_cfg": self.ae_cfg,
|
864 |
+
"diff_cfg": self.diff_cfg,
|
865 |
+
"cfg_schedule": self.cfg_schedule,
|
866 |
+
"diff_cfg_schedule": self.diff_cfg_schedule,
|
867 |
+
"temperature": self.temperature,
|
868 |
+
"num_slots": self.test_num_slots if self.test_num_slots is not None else self.train_num_slots
|
869 |
+
}, step=self.steps)
|
870 |
+
|
871 |
+
# Print comprehensive CFG information
|
872 |
+
cfg_info = (
|
873 |
+
f"{'EMA ' if use_ema else ''}CFG params: "
|
874 |
+
f"gpt_cfg={self.cfg}, ae_cfg={self.ae_cfg}, diff_cfg={self.diff_cfg}, "
|
875 |
+
f"cfg_schedule={self.cfg_schedule}, diff_cfg_schedule={self.diff_cfg_schedule}, "
|
876 |
+
f"num_slots={self.test_num_slots if self.test_num_slots is not None else self.train_num_slots}, "
|
877 |
+
f"temperature={self.temperature}"
|
878 |
+
)
|
879 |
+
print(cfg_info)
|
880 |
+
print(f"FID: {fid:.2f}, ISC: {inception_score:.2f}")
|
881 |
+
|
882 |
+
# Cleanup
|
883 |
+
shutil.rmtree(save_folder)
|
884 |
+
|
885 |
+
# back to no ema
|
886 |
+
if use_ema:
|
887 |
+
if self.accelerator.is_main_process:
|
888 |
+
print("Switch back from ema")
|
889 |
+
model_without_ddp.load_state_dict(model_state_dict)
|
890 |
+
|
891 |
+
self.gpt_model.train()
|
892 |
+
|
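The latent cache above round-trips through np.memmap: each rank writes its shard in write mode, drops the handle so pages flush to disk, then reopens the same file read-only so training epochs can index slots and targets without holding them in RAM. A minimal standalone sketch of that write-then-reopen pattern (toy shapes and a hypothetical file name, not the trainer's actual cache layout):

import numpy as np

shard_size, num_augs, num_slots, slot_dim = 4, 2, 16, 8
path = "slots_demo.mmap"  # hypothetical file name, for illustration only

# Write phase: allocate the backing file and fill it in place.
w = np.memmap(path, dtype='float32', mode='w+',
              shape=(shard_size * num_augs, num_slots, slot_dim))
w[:] = np.random.rand(*w.shape).astype('float32')
del w  # dropping the handle flushes the pages to disk

# Read phase: reopen read-only; dtype and shape must be passed again,
# since memmap stores no metadata in the file itself.
r = np.memmap(path, dtype='float32', mode='r',
              shape=(shard_size * num_augs, num_slots, slot_dim))
print(r[3].mean())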
paintmind/engine/misc.py
ADDED
@@ -0,0 +1,260 @@
import builtins, copy, datetime, socket, time
import os, sys, pdb
import os.path as osp
from pathlib import Path
from collections import defaultdict, deque

import torch
import torch.distributed as dist
from torch import inf


def print_available_port():
    return _find_free_port()


def _find_free_port():
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # Binding to port 0 will cause the OS to find an available port for us
    sock.bind(("", 0))
    port = sock.getsockname()[1]
    sock.close()
    # NOTE: there is still a chance the port could be taken by other processes.
    return port


def ensure_dir(dirpath):
    if not osp.exists(dirpath):
        os.makedirs(dirpath, exist_ok=True)


def setup_for_distributed(is_master):
    """
    This function disables printing when not in the master process.
    """
    builtin_print = builtins.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        force = force or (get_world_size() > 8)
        if is_master or force:
            now = datetime.datetime.now().time()
            builtin_print('[{}] '.format(now), end='')  # print with time stamp
            builtin_print(*args, **kwargs)

    builtins.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [torch.ones_like(tensor)
                      for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)


def init_distributed_mode(args):
    if args.dist_on_itp:
        args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
        assert isinstance(args.port, int) and 0 < args.port < (1 << 30)
        port = _find_free_port()
        # args.dist_url = "tcp://%s:%s" % (port, os.environ['MASTER_PORT'])
        args.dist_url = f'tcp://127.0.0.1:{port}'
        os.environ['LOCAL_RANK'] = str(args.gpu)
        os.environ['RANK'] = str(args.rank)
        os.environ['WORLD_SIZE'] = str(args.world_size)
        # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        setup_for_distributed(is_master=True)  # hack
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}, gpu {}'.format(
        args.rank, args.dist_url, args.gpu), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)


class NativeScalerWithGradNormCount:
    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if update_grad:
            if clip_grad is not None:
                assert parameters is not None
                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            else:
                self._scaler.unscale_(optimizer)
                norm = get_grad_norm_(parameters)
            self._scaler.step(optimizer)
            self._scaler.update()
        else:
            norm = None
        return norm

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)


def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if len(parameters) == 0:
        return torch.tensor(0.)
    device = parameters[0].grad.device
    if norm_type == inf:
        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
    else:
        total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
    return total_norm


def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, ema_params=None):
    output_dir = Path(args.output_dir)
    epoch_name = str(epoch)
    if loss_scaler is not None:
        checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]

        # ema
        if ema_params is not None:
            ema_state_dict = copy.deepcopy(model_without_ddp.state_dict())
            for i, (name, _value) in enumerate(model_without_ddp.named_parameters()):
                assert name in ema_state_dict
                ema_state_dict[name] = ema_params[i]
        else:
            ema_state_dict = None

        for checkpoint_path in checkpoint_paths:
            to_save = {
                'model': model_without_ddp.state_dict(),
                'model_ema': ema_state_dict,
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'scaler': loss_scaler.state_dict(),
                'args': args,
            }

            save_on_master(to_save, checkpoint_path)
    else:
        client_state = {'epoch': epoch}
        model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)


def save_model_last(args, epoch, model, model_without_ddp, optimizer, loss_scaler, ema_params=None):
    output_dir = Path(args.output_dir)
    epoch_name = 'last'
    if loss_scaler is not None:
        checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]

        # ema
        if ema_params is not None:
            ema_state_dict = copy.deepcopy(model_without_ddp.state_dict())
            for i, (name, _value) in enumerate(model_without_ddp.named_parameters()):
                assert name in ema_state_dict
                ema_state_dict[name] = ema_params[i]
        else:
            ema_state_dict = None

        for checkpoint_path in checkpoint_paths:
            to_save = {
                'model': model_without_ddp.state_dict(),
                'model_ema': ema_state_dict,
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'scaler': loss_scaler.state_dict(),
                'args': args,
            }

            save_on_master(to_save, checkpoint_path)
    else:
        client_state = {'epoch': epoch}
        model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)


def load_model(args, model_without_ddp, optimizer, loss_scaler):
    if osp.exists(osp.join(args.resume, "checkpoint-last.pth")):
        resume_path = osp.join(args.resume, "checkpoint-last.pth")
    else:
        resume_path = args.resume
    if args.resume:
        checkpoint = torch.load(resume_path, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        print("Resume checkpoint %s" % resume_path)
        if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'evaluate') and args.evaluate):
            optimizer.load_state_dict(checkpoint['optimizer'])
            args.start_epoch = checkpoint['epoch'] + 1
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])
            print("With optim & sched!")


def all_reduce_mean(x):
    world_size = get_world_size()
    if world_size > 1:
        x_reduce = torch.tensor(x).cuda()
        dist.all_reduce(x_reduce)
        x_reduce /= world_size
        return x_reduce.item()
    else:
        return x
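NativeScalerWithGradNormCount bundles the usual fp16 loss-scaling sequence (scale, backward, unscale, optional clip, step, update) into one call that also returns the gradient norm. A rough usage sketch, assuming a CUDA device and the class as defined above (the toy model and loss are illustrative, not code from this repo):

import torch
import torch.nn as nn

model = nn.Linear(10, 1).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_scaler = NativeScalerWithGradNormCount()  # class defined in misc.py above

x = torch.randn(4, 10, device='cuda')
with torch.autocast('cuda', dtype=torch.float16):
    loss = model(x).pow(2).mean()

optimizer.zero_grad()
# One call: scaled backward, unscale, clip to max-norm 1.0, optimizer step.
grad_norm = loss_scaler(loss, optimizer, clip_grad=1.0,
                        parameters=model.parameters())
print(f"grad norm: {grad_norm:.4f}")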
paintmind/engine/trainer.py
ADDED
@@ -0,0 +1,695 @@
import os, torch
import os.path as osp
import cv2
import shutil
import numpy as np
import copy
import torch_fidelity
import torch.nn as nn
from tqdm.auto import tqdm
from collections import OrderedDict
from einops import rearrange
from accelerate import Accelerator
from .util import instantiate_from_config
from torchvision.utils import make_grid, save_image
from torch.utils.data import DataLoader, random_split, DistributedSampler
from paintmind.utils.lr_scheduler import build_scheduler
from paintmind.utils.logger import SmoothedValue, MetricLogger, synchronize_processes, empty_cache
from paintmind.engine.misc import is_main_process, all_reduce_mean, concat_all_gather
from accelerate.utils import DistributedDataParallelKwargs, AutocastKwargs
from torch.optim import AdamW
from concurrent.futures import ThreadPoolExecutor
from torchmetrics.functional.image import (
    peak_signal_noise_ratio as psnr,
    structural_similarity_index_measure as ssim
)


def requires_grad(model, flag=True):
    for p in model.parameters():
        p.requires_grad = flag


def save_img(img, save_path):
    img = np.clip(img.numpy().transpose([1, 2, 0]) * 255, 0, 255)
    img = img.astype(np.uint8)[:, :, ::-1]
    cv2.imwrite(save_path, img)


def save_img_batch(imgs, save_paths):
    """Process and save multiple images at once using a thread pool."""
    # Convert to numpy and prepare all images in one go
    imgs = np.clip(imgs.numpy().transpose(0, 2, 3, 1) * 255, 0, 255).astype(np.uint8)
    imgs = imgs[:, :, :, ::-1]  # RGB to BGR for all images at once

    # ThreadPoolExecutor suits I/O-bound tasks like file saving
    # (ProcessPoolExecutor is generally better for CPU-bound work).
    with ThreadPoolExecutor(max_workers=32) as pool:
        # Submit all tasks at once
        futures = [pool.submit(cv2.imwrite, path, img)
                   for path, img in zip(save_paths, imgs)]
        # Wait for all tasks to complete
        for future in futures:
            future.result()  # This will raise any exceptions that occurred


def get_fid_stats(real_dir, rec_dir, fid_stats):
    stats = torch_fidelity.calculate_metrics(
        input1=rec_dir,
        input2=real_dir,
        fid_statistics_file=fid_stats,
        cuda=True,
        isc=True,
        fid=True,
        kid=False,
        prc=False,
        verbose=False,
    )
    return stats


class EMAModel:
    """Model Exponential Moving Average."""
    def __init__(self, model, device, decay=0.999):
        self.device = device
        self.decay = decay
        self.ema_params = OrderedDict(
            (name, param.clone().detach().to(device))
            for name, param in model.named_parameters()
            if param.requires_grad
        )

    @torch.no_grad()
    def update(self, model):
        for name, param in model.named_parameters():
            if param.requires_grad:
                if name in self.ema_params:
                    self.ema_params[name].lerp_(param.data, 1 - self.decay)
                else:
                    self.ema_params[name] = param.data.clone().detach()

    def state_dict(self):
        return self.ema_params

    def load_state_dict(self, params):
        self.ema_params = OrderedDict(
            (name, param.clone().detach().to(self.device))
            for name, param in params.items()
        )


class DiffusionTrainer(nn.Module):
    def __init__(
        self,
        model,
        dataset,
        test_dataset=None,
        test_only=False,
        num_epoch=400,
        valid_size=32,
        lr=None,
        blr=1e-4,
        cosine_lr=True,
        lr_min=0,
        warmup_epochs=100,
        warmup_steps=None,
        warmup_lr_init=0,
        decay_steps=None,
        batch_size=32,
        eval_bs=32,
        test_bs=64,
        num_workers=0,
        pin_memory=False,
        max_grad_norm=None,
        grad_accum_steps=1,
        precision="bf16",
        save_every=10000,
        sample_every=1000,
        fid_every=50000,
        result_folder=None,
        log_dir="./log",
        steps=0,
        cfg=1.0,
        test_num_slots=None,
        eval_fid=False,
        fid_stats=None,
        enable_ema=False,
        use_multi_epochs_dataloader=False,
        compile=False,
        overfit=False,
        making_cache=False,
        cache_mode=False,
        latent_cache_file=None,
    ):
        super().__init__()
        kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
        self.accelerator = Accelerator(
            kwargs_handlers=[kwargs],
            mixed_precision="bf16",
            gradient_accumulation_steps=grad_accum_steps,
            log_with="tensorboard",
            project_dir=log_dir,
        )

        self.model = instantiate_from_config(model)
        self.num_slots = model.params.num_slots

        assert precision in ["bf16", "fp32"]
        precision = "fp32"
        if self.accelerator.is_main_process:
            print("Overriding specified precision and using autocast bf16...")
        self.precision = precision

        if test_dataset is not None:
            test_dataset = instantiate_from_config(test_dataset)
            self.test_ds = test_dataset

            # Calculate padded dataset size to ensure even distribution
            total_size = len(test_dataset)
            world_size = self.accelerator.num_processes
            padding_size = world_size * test_bs - (total_size % (world_size * test_bs))
            self.test_dataset_size = total_size

            # Create a padded dataset wrapper
            class PaddedDataset(torch.utils.data.Dataset):
                def __init__(self, dataset, padding_size):
                    self.dataset = dataset
                    self.padding_size = padding_size

                def __len__(self):
                    return len(self.dataset) + self.padding_size

                def __getitem__(self, idx):
                    if idx < len(self.dataset):
                        return self.dataset[idx]
                    return self.dataset[0]

            self.test_ds = PaddedDataset(self.test_ds, padding_size)
            self.test_dl = DataLoader(
                self.test_ds,
                batch_size=test_bs,
                num_workers=num_workers,
                pin_memory=pin_memory,
                shuffle=False,
                drop_last=True,
            )
            if self.accelerator.is_main_process:
                print(f"test dataset size: {len(test_dataset)}, test batch size: {test_bs}")
        else:
            self.test_dl = None
        self.test_only = test_only

        if not test_only:
            dataset = instantiate_from_config(dataset)
            train_size = len(dataset) - valid_size
            self.train_ds, self.valid_ds = random_split(
                dataset,
                [train_size, valid_size],
                generator=torch.Generator().manual_seed(42),
            )
            if self.accelerator.is_main_process:
                print(f"train dataset size: {train_size}, valid dataset size: {valid_size}")

            sampler = DistributedSampler(
                self.train_ds,
                num_replicas=self.accelerator.num_processes,
                rank=self.accelerator.process_index,
                shuffle=True,
            )
            self.train_dl = DataLoader(
                self.train_ds,
                batch_size=batch_size,
                sampler=sampler,
                num_workers=num_workers,
                pin_memory=pin_memory,
                drop_last=True,
            )
            self.valid_dl = DataLoader(
                self.valid_ds,
                batch_size=eval_bs,
                shuffle=False,
                num_workers=num_workers,
                pin_memory=pin_memory,
            )

            effective_bs = batch_size * grad_accum_steps * self.accelerator.num_processes
            if lr is None:
                lr = blr * effective_bs / 256
            if self.accelerator.is_main_process:
                print(f"Effective batch size is {effective_bs}")

            params = filter(lambda p: p.requires_grad, self.model.parameters())
            self.g_optim = AdamW(params, lr=lr, betas=(0.9, 0.95), weight_decay=0)
            self.g_sched = self._create_scheduler(
                cosine_lr, warmup_epochs, warmup_steps, num_epoch,
                lr_min, warmup_lr_init, decay_steps
            )
            if self.g_sched is not None:
                self.accelerator.register_for_checkpointing(self.g_sched)

        self.steps = steps
        self.loaded_steps = -1

        # Prepare everything together
        if not test_only:
            self.model, self.g_optim, self.g_sched = self.accelerator.prepare(
                self.model, self.g_optim, self.g_sched
            )
        else:
            self.model, self.test_dl = self.accelerator.prepare(self.model, self.test_dl)

        if compile:
            _model = self.accelerator.unwrap_model(self.model)
            _model.vae = torch.compile(_model.vae, mode="reduce-overhead")
            _model.dit = torch.compile(_model.dit, mode="reduce-overhead")
            # _model.encoder = torch.compile(_model.encoder, mode="reduce-overhead")  # nan loss when compiled together with dit, no idea why
            _model.encoder2slot = torch.compile(_model.encoder2slot, mode="reduce-overhead")

        self.enable_ema = enable_ema
        if self.enable_ema and not self.test_only:  # when testing, we directly load the ema dict and skip here
            self.ema_model = EMAModel(self.accelerator.unwrap_model(self.model), self.device)
            self.accelerator.register_for_checkpointing(self.ema_model)

        self._load_checkpoint(model.params.ckpt_path)
        if self.test_only:
            self.steps = self.loaded_steps

        self.num_epoch = num_epoch
        self.save_every = save_every
        self.samp_every = sample_every
        self.fid_every = fid_every
        self.max_grad_norm = max_grad_norm

        self.cache_mode = cache_mode

        self.cfg = cfg
        self.test_num_slots = test_num_slots
        if self.test_num_slots is not None:
            self.test_num_slots = min(self.test_num_slots, self.num_slots)
        else:
            self.test_num_slots = self.num_slots
        eval_fid = eval_fid or model.params.eval_fid  # legacy
        self.eval_fid = eval_fid
        if eval_fid:
            if fid_stats is None:
                fid_stats = model.params.fid_stats  # legacy
            assert fid_stats is not None
            assert test_dataset is not None
        self.fid_stats = fid_stats

        self.use_vq = model.params.use_vq if hasattr(model.params, "use_vq") else False
        self.vq_beta = model.params.code_beta if hasattr(model.params, "code_beta") else 0.25

        self.result_folder = result_folder
        self.model_saved_dir = os.path.join(result_folder, "models")
        os.makedirs(self.model_saved_dir, exist_ok=True)

        self.image_saved_dir = os.path.join(result_folder, "images")
        os.makedirs(self.image_saved_dir, exist_ok=True)

    @property
    def device(self):
        return self.accelerator.device

    def _create_scheduler(self, cosine_lr, warmup_epochs, warmup_steps, num_epoch, lr_min, warmup_lr_init, decay_steps):
        if warmup_epochs is not None:
            warmup_steps = warmup_epochs * len(self.train_dl)
        else:
            assert warmup_steps is not None

        scheduler = build_scheduler(
            self.g_optim,
            num_epoch,
            len(self.train_dl),
            lr_min,
            warmup_steps,
            warmup_lr_init,
            decay_steps,
            cosine_lr,  # if not cosine_lr, then use step_lr (warmup, then fix)
        )
        return scheduler

    def _load_state_dict(self, state_dict):
        """Helper to load a state dict with proper prefix handling."""
        if 'state_dict' in state_dict:
            state_dict = state_dict['state_dict']
        # Remove '_orig_mod' prefix if present
        state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
        missing, unexpected = self.accelerator.unwrap_model(self.model).load_state_dict(
            state_dict, strict=False
        )
        if self.accelerator.is_main_process:
            print(f"Loaded model. Missing: {missing}, Unexpected: {unexpected}")

    def _load_safetensors(self, path):
        """Helper to load a safetensors checkpoint."""
        from safetensors.torch import safe_open
        with safe_open(path, framework="pt", device="cpu") as f:
            state_dict = {k: f.get_tensor(k) for k in f.keys()}
        self._load_state_dict(state_dict)

    def _load_checkpoint(self, ckpt_path=None):
        if ckpt_path is None or not osp.exists(ckpt_path):
            return

        if osp.isdir(ckpt_path):
            # ckpt_path is something like 'path/to/models/step10/'
            self.loaded_steps = int(
                ckpt_path.split("step")[-1].split("/")[0]
            )
            if not self.test_only:
                self.accelerator.load_state(ckpt_path)
            else:
                if self.enable_ema:
                    model_path = osp.join(ckpt_path, "custom_checkpoint_1.pkl")
                    if osp.exists(model_path):
                        state_dict = torch.load(model_path, map_location="cpu")
                        self._load_state_dict(state_dict)
                        if self.accelerator.is_main_process:
                            print(f"Loaded ema model from {model_path}")
                else:
                    model_path = osp.join(ckpt_path, "model.safetensors")
                    if osp.exists(model_path):
                        self._load_safetensors(model_path)
        else:
            # ckpt_path is something like 'path/to/models/step10.pt'
            if ckpt_path.endswith(".safetensors"):
                self._load_safetensors(ckpt_path)
            else:
                state_dict = torch.load(ckpt_path)
                self._load_state_dict(state_dict)
        if self.accelerator.is_main_process:
            print(f"Loaded checkpoint from {ckpt_path}")

    def train(self, config=None):
        n_parameters = sum(p.numel() for p in self.parameters() if p.requires_grad)
        if self.accelerator.is_main_process:
            print(f"number of learnable parameters: {n_parameters/1e6:.2f}M")
        if config is not None:
            # save the config
            import shutil
            from omegaconf import OmegaConf

            if isinstance(config, str) and osp.exists(config):
                # If it's a path, copy the file to config.yaml
                shutil.copy(config, osp.join(self.result_folder, "config.yaml"))
            else:
                # If it's an OmegaConf object, dump it
                config_save_path = osp.join(self.result_folder, "config.yaml")
                OmegaConf.save(config, config_save_path)

        self.accelerator.init_trackers("vqgan")

        if self.test_only:
            empty_cache()
            self.evaluate()
            self.accelerator.wait_for_everyone()
            empty_cache()
            return

        for epoch in range(self.num_epoch):
            if ((epoch + 1) * len(self.train_dl)) <= self.loaded_steps:
                if self.accelerator.is_main_process:
                    print(f"Epoch {epoch} is skipped because it is loaded from ckpt")
                self.steps += len(self.train_dl)
                continue

            if self.steps < self.loaded_steps:
                for _ in self.train_dl:
                    self.steps += 1
                    if self.steps >= self.loaded_steps:
                        break

            self.accelerator.unwrap_model(self.model).current_epoch = epoch
            self.model.train()  # Set model to training mode

            logger = MetricLogger(delimiter="  ")
            logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
            header = 'Epoch: [{}/{}]'.format(epoch, self.num_epoch)
            print_freq = 20
            for data_iter_step, batch in enumerate(logger.log_every(self.train_dl, print_freq, header)):
                # Move batch to device once
                if isinstance(batch, tuple) or isinstance(batch, list):
                    batch = tuple(b.to(self.device, non_blocking=True) for b in batch)
                    if self.cache_mode:
                        img, latent, targets = batch[0], batch[1], batch[2]
                        img = img.to(self.device, non_blocking=True)
                        latent = latent.to(self.device, non_blocking=True)
                        targets = targets.to(self.device, non_blocking=True)
                    else:
                        latent = None
                        img, targets = batch[0], batch[1]
                        img = img.to(self.device, non_blocking=True)
                        targets = targets.to(self.device, non_blocking=True)
                else:
                    img = batch
                    latent = None

                self.steps += 1

                with self.accelerator.accumulate(self.model):
                    with self.accelerator.autocast():
                        if self.steps == 1:
                            print(f"Training batch size: {img.size(0)}")
                            print(f"Hello from index {self.accelerator.local_process_index}")
                        losses = self.model(img, targets, latents=latent, epoch=epoch)
                        # combine
                        loss = sum([v for _, v in losses.items()])
                        diff_loss = losses["diff_loss"]
                        if self.use_vq:
                            vq_loss = losses["vq_loss"]

                    self.accelerator.backward(loss)
                    if self.accelerator.sync_gradients and self.max_grad_norm is not None:
                        self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
                    self.accelerator.unwrap_model(self.model).cancel_gradients_encoder(epoch)
                    self.g_optim.step()
                    if self.g_sched is not None:
                        self.g_sched.step_update(self.steps)
                    self.g_optim.zero_grad()

                # synchronize_processes()

                # update ema with state dict
                if self.enable_ema:
                    self.ema_model.update(self.accelerator.unwrap_model(self.model))

                logger.update(diff_loss=diff_loss.item())
                if self.use_vq:
                    logger.update(vq_loss=vq_loss.item() / self.vq_beta)
                if 'kl_loss' in losses:
                    logger.update(kl_loss=losses["kl_loss"].item())
                if 'repa_loss' in losses:
                    logger.update(repa_loss=losses["repa_loss"].item())
                logger.update(lr=self.g_optim.param_groups[0]["lr"])

                if self.steps % self.save_every == 0:
                    self.save()

                if (self.steps % self.samp_every == 0) or (self.steps % self.fid_every == 0):
                    empty_cache()
                    self.evaluate()
                    self.accelerator.wait_for_everyone()
                    empty_cache()

                # omitted all_gather here
                # write_dict = dict(epoch=epoch)
                # write_dict.update(diff_loss=diff_loss.item())
                # if "kl_loss" in losses:
                #     write_dict.update(kl_loss=losses["kl_loss"].item())
                # if self.use_vq:
                #     write_dict.update(vq_loss=vq_loss.item() / self.vq_beta)
                # write_dict.update(lr=self.g_optim.param_groups[0]["lr"])
                # self.accelerator.log(write_dict, step=self.steps)

            logger.synchronize_between_processes()
            if self.accelerator.is_main_process:
                print("Averaged stats:", logger)

        self.accelerator.end_training()
        self.save()
        if self.accelerator.is_main_process:
            print("Train finished!")

    def save(self):
        self.accelerator.wait_for_everyone()
        self.accelerator.save_state(
            os.path.join(self.model_saved_dir, f"step{self.steps}")
        )

    @torch.no_grad()
    def evaluate(self, use_ema=True):
        self.model.eval()
        # switch to ema params, only when eval_fid is True
        use_ema = use_ema and self.enable_ema and self.eval_fid and not self.test_only
        # use_ema = False
        if use_ema:
            if hasattr(self, "ema_model"):
                model_without_ddp = self.accelerator.unwrap_model(self.model)
                model_state_dict = copy.deepcopy(model_without_ddp.state_dict())
                ema_state_dict = copy.deepcopy(model_without_ddp.state_dict())
                for i, (name, _value) in enumerate(model_without_ddp.named_parameters()):
                    if "nested_sampler" in name:
                        continue
                    if name in self.ema_model.state_dict():
                        ema_state_dict[name] = self.ema_model.state_dict()[name]
                if self.accelerator.is_main_process:
                    print("Switch to ema")
                msg = model_without_ddp.load_state_dict(ema_state_dict)
                if self.accelerator.is_main_process:
                    print(msg)
            else:
                print("EMA model not found, using original model")
                use_ema = False

        if not self.test_only:
            with tqdm(
                self.valid_dl,
                dynamic_ncols=True,
                disable=not self.accelerator.is_main_process,
            ) as valid_dl:
                for batch_i, batch in enumerate(valid_dl):
                    if isinstance(batch, tuple) or isinstance(batch, list):
                        img, targets = batch[0], batch[1]
                    else:
                        img = batch

                    with self.accelerator.autocast():
                        rec = self.model(img, targets, sample=True, inference_with_n_slots=self.test_num_slots, cfg=1.0)
                    imgs_and_recs = torch.stack((img.to(rec.device), rec), dim=0)
                    imgs_and_recs = rearrange(imgs_and_recs, "r b ... -> (b r) ...")
                    imgs_and_recs = imgs_and_recs.detach().cpu().float()

                    grid = make_grid(
                        imgs_and_recs, nrow=6, normalize=True, value_range=(0, 1)
                    )
                    if self.accelerator.is_main_process:
                        save_image(
                            grid,
                            os.path.join(
                                self.image_saved_dir, f"step_{self.steps}_slots{self.test_num_slots}_{batch_i}.jpg"
                            ),
                        )

                    if self.cfg != 1.0:
                        with self.accelerator.autocast():
                            rec = self.model(img, targets, sample=True, inference_with_n_slots=self.test_num_slots, cfg=self.cfg)

                        imgs_and_recs = torch.stack((img.to(rec.device), rec), dim=0)
                        imgs_and_recs = rearrange(imgs_and_recs, "r b ... -> (b r) ...")
                        imgs_and_recs = imgs_and_recs.detach().cpu().float()

                        grid = make_grid(
                            imgs_and_recs, nrow=6, normalize=True, value_range=(0, 1)
                        )
                        if self.accelerator.is_main_process:
                            save_image(
                                grid,
                                os.path.join(
                                    self.image_saved_dir, f"step_{self.steps}_cfg_{self.cfg}_slots{self.test_num_slots}_{batch_i}.jpg"
                                ),
                            )
        if (self.eval_fid and self.test_dl is not None) and (self.test_only or (self.steps % self.fid_every == 0)):
            # Create output directories
            if self.test_dataset_size > 10000:
                real_dir = "./dataset/imagenet/val256"
            else:
                real_dir = "./dataset/coco/val2017_256"
            rec_dir = os.path.join(self.image_saved_dir, f"rec_step{self.steps}_slots{self.test_num_slots}")
            os.makedirs(rec_dir, exist_ok=True)

            if self.cfg != 1.0:
                rec_cfg_dir = os.path.join(self.image_saved_dir, f"rec_step{self.steps}_cfg_{self.cfg}_slots{self.test_num_slots}")
                os.makedirs(rec_cfg_dir, exist_ok=True)

            def process_batch(cfg_value, save_dir, header):
                logger = MetricLogger(delimiter="  ")
                print_freq = 5
                psnr_values = []
                ssim_values = []
                total_processed = 0

                for batch_i, batch in enumerate(logger.log_every(self.test_dl, print_freq, header)):
                    imgs, targets = (batch[0], batch[1]) if isinstance(batch, (tuple, list)) else (batch, None)

                    # Skip processing if we've already processed all real samples
                    if total_processed >= self.test_dataset_size:
                        break

                    imgs = imgs.to(self.device, non_blocking=True)
                    if targets is not None:
                        targets = targets.to(self.device, non_blocking=True)

                    with self.accelerator.autocast():
                        recs = self.model(imgs, targets, sample=True, inference_with_n_slots=self.test_num_slots, cfg=cfg_value)

                    psnr_val = psnr(recs, imgs, data_range=1.0)
                    ssim_val = ssim(recs, imgs, data_range=1.0)

                    recs = concat_all_gather(recs).detach()
                    psnr_val = concat_all_gather(psnr_val.view(1))
                    ssim_val = concat_all_gather(ssim_val.view(1))

                    # Remove padding after gathering from all GPUs
                    samples_in_batch = min(
                        recs.size(0),  # Always use the gathered size
                        self.test_dataset_size - total_processed
                    )
                    recs = recs[:samples_in_batch]
                    psnr_val = psnr_val[:samples_in_batch]
                    ssim_val = ssim_val[:samples_in_batch]
                    psnr_values.append(psnr_val)
                    ssim_values.append(ssim_val)

                    if self.accelerator.is_main_process:
                        rec_paths = [os.path.join(save_dir, f"step_{self.steps}_slots{self.test_num_slots}_{batch_i}_{j}_rec_cfg_{cfg_value}_slots{self.test_num_slots}.png")
                                     for j in range(recs.size(0))]
                        save_img_batch(recs.cpu(), rec_paths)

                    total_processed += samples_in_batch

                    self.accelerator.wait_for_everyone()

                return torch.cat(psnr_values).mean(), torch.cat(ssim_values).mean()

            # Helper function to calculate and log metrics
            def calculate_and_log_metrics(real_dir, rec_dir, cfg_value, psnr_val, ssim_val):
                if self.accelerator.is_main_process:
                    metrics_dict = get_fid_stats(real_dir, rec_dir, self.fid_stats)
                    fid = metrics_dict["frechet_inception_distance"]
                    inception_score = metrics_dict["inception_score_mean"]

                    metric_prefix = "fid_ema" if use_ema else "fid"
                    isc_prefix = "isc_ema" if use_ema else "isc"
                    self.accelerator.log({
                        metric_prefix: fid,
                        isc_prefix: inception_score,
                        f"psnr_{'ema' if use_ema else 'test'}": psnr_val,
                        f"ssim_{'ema' if use_ema else 'test'}": ssim_val,
                        "cfg": cfg_value
                    }, step=self.steps)

                    print(f"{'EMA ' if use_ema else ''}{f'CFG: {cfg_value}'} "
                          f"FID: {fid:.2f}, ISC: {inception_score:.2f}, "
                          f"PSNR: {psnr_val:.2f}, SSIM: {ssim_val:.4f}")

            # Process without CFG
            if self.cfg == 1.0 or not self.test_only:
                psnr_val, ssim_val = process_batch(1.0, rec_dir, 'Testing: w/o CFG')
                calculate_and_log_metrics(real_dir, rec_dir, 1.0, psnr_val, ssim_val)

            # Process with CFG if needed
            if self.cfg != 1.0:
                psnr_val, ssim_val = process_batch(self.cfg, rec_cfg_dir, 'Testing: w/ CFG')
                calculate_and_log_metrics(real_dir, rec_cfg_dir, self.cfg, psnr_val, ssim_val)

            # Cleanup
            if self.accelerator.is_main_process:
                shutil.rmtree(rec_dir)
                if self.cfg != 1.0:
                    shutil.rmtree(rec_cfg_dir)

        # back to no ema
        if use_ema:
            if self.accelerator.is_main_process:
                print("Switch back from ema")
            model_without_ddp.load_state_dict(model_state_dict)

        self.model.train()
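EMAModel above maintains shadow weights with an in-place lerp, i.e. ema = decay * ema + (1 - decay) * param after every optimizer step; evaluate() then deep-copies the live state dict, overwrites matching entries with the EMA values, and restores the originals afterwards. A small sketch of the update loop, assuming the EMAModel class defined above and a toy CPU model (illustrative only):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
ema = EMAModel(model, device='cpu', decay=0.999)  # class defined above
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for _ in range(10):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update(model)  # shadow weights drift slowly toward the live ones

# ema.state_dict() now holds the smoothed parameters, keyed by name.
print(next(iter(ema.state_dict().keys())))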
paintmind/engine/util.py
ADDED
@@ -0,0 +1,572 @@
import numpy as np
import os.path as osp
import torch_fidelity
from PIL import Image
from tqdm import tqdm
import pickle as pkl
import os, hashlib, pdb
from pathlib import Path
from torch import Tensor
import torch, torchvision
from einops import rearrange
from omegaconf import OmegaConf
import torch.distributed as dist
from typing import List, Optional
from torchvision import transforms
from io import BytesIO as Bytes2Data
from smart_open import open
from .misc import is_main_process, get_rank
import importlib, datetime, requests, time, shutil
from collections import defaultdict, deque, OrderedDict
from dotwiz import DotWiz

URL_MAP = {
    "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
}

CKPT_MAP = {
    "vgg_lpips": "vgg.pth"
}

MD5_MAP = {
    "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
}

def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self

def customized_collate_fn(batch):

    collate_fn = {}
    if len(batch) < 2:
        for key, value in batch[0].items():
            collate_fn[key] = [value]
    else:

        for i, dd in enumerate(batch):
            if i < 1:
                for key, value in dd.items():
                    collate_fn[key] = [value]
            else:
                for key, value in dd.items():
                    collate_fn[key].append(value)

    return collate_fn
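
# A quick illustration of the behaviour above, on a hypothetical
# two-sample batch of dicts:
#   batch = [{"image": "img_0", "label": 3}, {"image": "img_1", "label": 7}]
#   customized_collate_fn(batch)
#   -> {"image": ["img_0", "img_1"], "label": [3, 7]}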


def trivial_batch_collator(batch):
    """
    A batch collator that does nothing.
    """
    return batch

class NestedTensor(object):
    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask

    def to(self, device):
        # type: (Device) -> NestedTensor # noqa
        cast_tensor = self.tensors.to(device)
        mask = self.mask
        if mask is not None:
            assert mask is not None
            cast_mask = mask.to(device)
        else:
            cast_mask = None
        return NestedTensor(cast_tensor, cast_mask)

    def decompose(self):
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)


def _max_by_axis(the_list):
    # type: (List[List[int]]) -> List[int]
    maxes = the_list[0]
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)
    return maxes


def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    # TODO make this more general
    if tensor_list[0].ndim == 3:
        if torchvision._is_tracing():
            # nested_tensor_from_tensor_list() does not export well to ONNX
            # call _onnx_nested_tensor_from_tensor_list() instead
            return _onnx_nested_tensor_from_tensor_list(tensor_list)

        # TODO make it support different-sized images
        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
        batch_shape = [len(tensor_list)] + max_size
        b, c, h, w = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
        for img, pad_img, m in zip(tensor_list, tensor, mask):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
            m[: img.shape[1], : img.shape[2]] = False
    else:
        raise ValueError("not supported")
    return NestedTensor(tensor, mask)


# _onnx_nested_tensor_from_tensor_list() is an implementation of
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
    max_size = []
    for i in range(tensor_list[0].dim()):
        max_size_i = torch.max(
            torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)
        ).to(torch.int64)
        max_size.append(max_size_i)
    max_size = tuple(max_size)

    # work around for
    # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
    # m[: img.shape[1], :img.shape[2]] = False
    # which is not yet supported in onnx
    padded_imgs = []
    padded_masks = []
    for img in tensor_list:
        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
        padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
        padded_imgs.append(padded_img)

        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
        padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
        padded_masks.append(padded_mask.to(torch.bool))

    tensor = torch.stack(padded_imgs)
    mask = torch.stack(padded_masks)

    return NestedTensor(tensor, mask=mask)


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)


class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))
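
# Minimal usage sketch for MetricLogger (train_loader and compute_loss are
# hypothetical stand-ins for a real loader and training step):
#   metric_logger = MetricLogger(delimiter="  ")
#   for batch in metric_logger.log_every(train_loader, 50, header='Epoch: [0]'):
#       loss = compute_loss(batch)
#       metric_logger.update(loss=loss)   # tracked as a SmoothedValue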


def all_reduce_mean(x):
    world_size = dist.get_world_size()
    if world_size > 1:
        x_reduce = torch.tensor(x).cuda()
        dist.all_reduce(x_reduce)
        x_reduce /= world_size
        return x_reduce.item()
    else:
        return x


class NativeScaler:
    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=3., parameters=None, create_graph=False, update_grad=True):
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if update_grad:
            if clip_grad is not None:
                assert parameters is not None
                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            else:
                self._scaler.unscale_(optimizer)
                norm = get_grad_norm_(parameters)
            self._scaler.step(optimizer)
            self._scaler.update()
        else:
            norm = None
        return norm

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)


def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None and p.requires_grad]
    norm_type = float(norm_type)
    if len(parameters) == 0:
        return torch.tensor(0.)
    device = parameters[0].grad.device
    total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]),
                            norm_type)
    return total_norm


def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def instantiate_from_config(config):
    if not "target" in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


def save_on_master(*args, **kwargs):
    if dist.get_rank() == 0:
        torch.save(*args, **kwargs)


def save_model(args, epoch, model, model_without_ddp, optimizer_g, optimizer_d, loss_scaler):
    output_dir = Path(args.output_dir)
    epoch_name = str(epoch)
    if loss_scaler is not None:
        checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
        for checkpoint_path in checkpoint_paths:
            to_save = {
                'model': model_without_ddp.state_dict(),
                'optimizer_g': optimizer_g.state_dict(),
                'optimizer_d': optimizer_d.state_dict(),
                'epoch': epoch,
                'scaler': loss_scaler.state_dict(),
                'args': args,
            }

            save_on_master(to_save, checkpoint_path)
    else:
        client_state = {'epoch': epoch}
        model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)


def load_model(args, model_without_ddp, optimizer_g, optimizer_d, loss_scaler):
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        print("Resume checkpoint %s" % args.resume)
        if 'optimizer_g' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
            optimizer_g.load_state_dict(checkpoint['optimizer_g'])
            optimizer_d.load_state_dict(checkpoint['optimizer_d'])
            args.start_epoch = checkpoint['epoch'] + 1
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])
            print("With optim & sched!")


def download(url, local_path, chunk_size=1024):
    os.makedirs(os.path.split(local_path)[0], exist_ok=True)
    with requests.get(url, stream=True) as r:
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        pbar.update(chunk_size)


def md5_hash(path):
    with open(path, "rb") as f:
        content = f.read()
    return hashlib.md5(content).hexdigest()


def get_ckpt_path(name, root, check=False):
    assert name in URL_MAP
    path = os.path.join(root, CKPT_MAP[name])
    if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        assert md5 == MD5_MAP[name], md5
    return path


class KeyNotFoundError(Exception):
    def __init__(self, cause, keys=None, visited=None):
        self.cause = cause
        self.keys = keys
        self.visited = visited
        messages = list()
        if keys is not None:
            messages.append("Key not found: {}".format(keys))
        if visited is not None:
            messages.append("Visited: {}".format(visited))
        messages.append("Cause:\n{}".format(cause))
        message = "\n".join(messages)
        super().__init__(message)


def retrieve(
    list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
):
    """Given a nested list or dict return the desired value at key expanding
    callable nodes if necessary and :attr:`expand` is ``True``. The expansion
    is done in-place.

    Parameters
    ----------
    list_or_dict : list or dict
        Possibly nested list or dictionary.
    key : str
        key/to/value, path like string describing all keys necessary to
        consider to get to the desired value. List indices can also be
        passed here.
    splitval : str
        String that defines the delimiter between keys of the
        different depth levels in `key`.
    default : obj
        Value returned if :attr:`key` is not found.
    expand : bool
        Whether to expand callable nodes on the path or not.

    Returns
    -------
    The desired value or if :attr:`default` is not ``None`` and the
    :attr:`key` is not found returns ``default``.

    Raises
    ------
    Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
    ``None``.
    """

    keys = key.split(splitval)

    success = True
    try:
        visited = []
        parent = None
        last_key = None
        for key in keys:
            if callable(list_or_dict):
                if not expand:
                    raise KeyNotFoundError(
                        ValueError(
                            "Trying to get past callable node with expand=False."
                        ),
                        keys=keys,
                        visited=visited,
                    )
                list_or_dict = list_or_dict()
                parent[last_key] = list_or_dict

            last_key = key
            parent = list_or_dict

            try:
                if isinstance(list_or_dict, dict):
                    list_or_dict = list_or_dict[key]
                else:
                    list_or_dict = list_or_dict[int(key)]
            except (KeyError, IndexError, ValueError) as e:
                raise KeyNotFoundError(e, keys=keys, visited=visited)

            visited += [key]
        # final expansion of retrieved value
        if expand and callable(list_or_dict):
            list_or_dict = list_or_dict()
            parent[last_key] = list_or_dict
    except KeyNotFoundError as e:
        if default is None:
            raise e
        else:
            list_or_dict = default
            success = False

    if not pass_success:
        return list_or_dict
    else:
        return list_or_dict, success


if __name__ == "__main__":
    config = {"keya": "a",
              "keyb": "b",
              "keyc":
                  {"cc1": 1,
                   "cc2": 2,
                   }
              }
    from omegaconf import OmegaConf

    config = OmegaConf.create(config)
    print(config)
    retrieve(config, "keya")


def instantiate_from_config(config):

    if not "target" in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))

def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)
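A short sketch of how retrieve and instantiate_from_config combine with an OmegaConf config; the target path and params here are illustrative, not a prescribed entry point:

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "model": {
        "target": "paintmind.stage1.diffuse_slot.DiffuseSlot",  # illustrative target
        "params": {"num_slots": 16, "slot_dim": 256},
    }
})
model_cfg = retrieve(cfg, "model")          # path-like lookup, "/" separates levels
model = instantiate_from_config(model_cfg)  # imports the class, calls it with params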
paintmind/stage1/__init__.py
ADDED
File without changes
paintmind/stage1/diffuse_slot.py
ADDED
@@ -0,0 +1,808 @@
import math
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from paintmind.stage1.diffusion import create_diffusion
from paintmind.stage1.diffusion_transfomers import DiT
from paintmind.stage1.quantize import DiagonalGaussianDistribution
from paintmind.stage1.transport import create_transport, Sampler

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from transformers import SiglipVisionModel, CLIPVisionModel

CLIP_DEFAULT_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_DEFAULT_STD = (0.26862954, 0.26130258, 0.27577711)
SIGLIP_DEFAULT_MEAN = (0.5, 0.5, 0.5)
SIGLIP_DEFAULT_STD = (0.5, 0.5, 0.5)

def build_mlp(hidden_size, projector_dim, z_dim):
    return nn.Sequential(
        nn.Linear(hidden_size, projector_dim),
        nn.SiLU(),
        nn.Linear(projector_dim, projector_dim),
        nn.SiLU(),
        nn.Linear(projector_dim, z_dim),
    )

class DiT_with_autoenc_cond(DiT):
    def __init__(
        self,
        *args,
        num_autoenc=32,
        autoenc_dim=4,
        cond_method="adaln",
        mask_type="simple",
        class_cond=False,
        use_repa=False,
        z_dim=768,
        encoder_depth=8,
        projector_dim=2048,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.autoenc_dim = autoenc_dim
        self.class_cond = class_cond
        self.mask_type = mask_type
        self.hidden_size = kwargs["hidden_size"]
        self.cond_drop_prob = self.y_embedder.dropout_prob  # 0.1 without cond guidance
        self.null_cond = nn.Parameter(torch.zeros(1, num_autoenc, autoenc_dim))
        torch.nn.init.normal_(self.null_cond, std=.02)
        # NOTE: "adaln" injects the condition through adaptive layer normalization;
        # "token" appends the condition as extra tokens to the attention sequence.
        assert cond_method in [
            "adaln",
            "token",
            "token+adaln",
        ], f"Invalid cond_method: {cond_method}"
        self.cond_method = cond_method
        if "token" in cond_method:
            self.autoenc_cond_embedder = nn.Linear(autoenc_dim, self.hidden_size)
        else:
            self.autoenc_cond_embedder = nn.Linear(
                num_autoenc * autoenc_dim, self.hidden_size
            )

        if cond_method == "token+adaln":
            self.autoenc_proj_ln = nn.Linear(self.hidden_size, self.hidden_size)

        if not class_cond:
            self.y_embedder = nn.Identity()

        self.use_repa = use_repa
        self._repa_hook = None
        self.encoder_depth = encoder_depth
        if use_repa:
            self.projector = build_mlp(self.hidden_size, projector_dim, z_dim)

    def embed_cond(self, autoenc_cond, drop_mask=None):
        # autoenc_cond: (N, K, D)
        # drop_ids: (N)
        # self.null_cond: (1, K, D)
        # NOTE: this dropout replaces some of the autoencoder conditions with the
        # null condition; this is what enables classifier-free guidance.
        batch_size = autoenc_cond.shape[0]
        if drop_mask is None:
            # randomly drop all conditions, for classifier-free guidance
            if self.training:
                drop_ids = (
                    torch.rand(batch_size, 1, 1, device=autoenc_cond.device)
                    < self.cond_drop_prob
                )
                autoenc_cond_drop = torch.where(drop_ids, self.null_cond, autoenc_cond)
            else:
                autoenc_cond_drop = autoenc_cond
        else:
            # randomly drop some conditions according to the drop_mask (N, K)
            # True means keep
            autoenc_cond_drop = torch.where(drop_mask[:, :, None], autoenc_cond, self.null_cond)
        if "token" in self.cond_method:
            return self.autoenc_cond_embedder(autoenc_cond_drop)
        return self.autoenc_cond_embedder(autoenc_cond_drop.reshape(batch_size, -1))

    # def forward(self, x, t, y, autoenc_cond):
    def forward(self, x, t, autoenc_cond, drop_mask=None, y=None):
        """
        Forward pass of DiT.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N,) tensor of class labels
        autoenc_cond: (N, K, D) tensor of autoencoder conditions (slots)
        """
        x = (
            self.x_embedder(x) + self.pos_embed
        )  # (N, T, D), where T = H * W / patch_size ** 2
        N, T, D = x.shape
        # number of patch tokens, recorded before any condition tokens are
        # appended (also needed on the "adaln" path for the repa hook below)
        num_tokens = T

        c = self.t_embedder(t)  # (N, D)
        if y is not None and self.class_cond:
            y = self.y_embedder(y, self.training)  # (N, D)
            c = c + y  # (N, D)

        if self.mask_type == "replace":
            autoenc = self.embed_cond(autoenc_cond, drop_mask)
        else:
            autoenc = self.embed_cond(autoenc_cond)

        if self.cond_method == "adaln":
            c = c + autoenc  # add the encoder condition to adaln
        elif self.cond_method == "token":
            # append the autoencoder condition to the token sequence
            x = torch.cat((x, autoenc), dim=1)
        elif self.cond_method == "token+adaln":
            c = c + self.autoenc_proj_ln(autoenc.mean(dim=1))
            x = torch.cat((x, autoenc), dim=1)
        else:
            raise ValueError(f"Invalid cond_method: {self.cond_method}")

        for i, block in enumerate(self.blocks):
            if self.mask_type == "replace":
                x = block(x, c)  # (N, T, D)
            else:
                x = block(x, c, drop_mask)  # (N, T, D)
            if (i + 1) == self.encoder_depth and self.use_repa:
                projected = self.projector(x)
                self._repa_hook = projected[:, :num_tokens]

        if "token" in self.cond_method:
            x = x[:, :num_tokens]
        x = self.final_layer(x, c)  # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x)  # (N, out_channels, H, W)
        return x

    def forward_with_cfg(self, x, t, autoenc_cond, drop_mask, y=None, cfg_scale=1.0):
        """
        Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
        """
        # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.forward(combined, t, autoenc_cond, drop_mask, y)
        # For exact reproducibility reasons, we apply classifier-free guidance on only
        # three channels by default. The standard approach to cfg applies it to all channels.
        # This can be done by uncommenting the following line and commenting-out the line following that.
        eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
        # eps, rest = model_out[:, :3], model_out[:, 3:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)

#################################################################################
#                                 DiT Configs                                   #
#################################################################################


def DiT_with_autoenc_cond_XL_2(**kwargs):
    return DiT_with_autoenc_cond(
        depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs
    )


def DiT_with_autoenc_cond_XL_4(**kwargs):
    return DiT_with_autoenc_cond(
        depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs
    )


def DiT_with_autoenc_cond_XL_8(**kwargs):
    return DiT_with_autoenc_cond(
        depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs
    )


def DiT_with_autoenc_cond_L_2(**kwargs):
    return DiT_with_autoenc_cond(
        depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs
    )


def DiT_with_autoenc_cond_L_4(**kwargs):
    return DiT_with_autoenc_cond(
        depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs
    )


def DiT_with_autoenc_cond_L_8(**kwargs):
    return DiT_with_autoenc_cond(
        depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs
    )


def DiT_with_autoenc_cond_B_2(**kwargs):
    return DiT_with_autoenc_cond(
        depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs
    )


def DiT_with_autoenc_cond_B_4(**kwargs):
    return DiT_with_autoenc_cond(
        depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs
    )


def DiT_with_autoenc_cond_B_8(**kwargs):
    return DiT_with_autoenc_cond(
        depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs
    )


def DiT_with_autoenc_cond_S_2(**kwargs):
    return DiT_with_autoenc_cond(
        depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs
    )


def DiT_with_autoenc_cond_S_4(**kwargs):
    return DiT_with_autoenc_cond(
        depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs
    )


def DiT_with_autoenc_cond_S_8(**kwargs):
    return DiT_with_autoenc_cond(
        depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs
    )


DiT_with_autoenc_cond_models = {
    "DiT-XL-2": DiT_with_autoenc_cond_XL_2,
    "DiT-XL-4": DiT_with_autoenc_cond_XL_4,
    "DiT-XL-8": DiT_with_autoenc_cond_XL_8,
    "DiT-L-2": DiT_with_autoenc_cond_L_2,
    "DiT-L-4": DiT_with_autoenc_cond_L_4,
    "DiT-L-8": DiT_with_autoenc_cond_L_8,
    "DiT-B-2": DiT_with_autoenc_cond_B_2,
    "DiT-B-4": DiT_with_autoenc_cond_B_4,
    "DiT-B-8": DiT_with_autoenc_cond_B_8,
    "DiT-S-2": DiT_with_autoenc_cond_S_2,
    "DiT-S-4": DiT_with_autoenc_cond_S_4,
    "DiT-S-8": DiT_with_autoenc_cond_S_8,
}
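
# Illustrative pick from the registry above; DiffuseSlot below passes the same
# kinds of arguments (input_size and in_channels come from the DiT base class):
#   dit = DiT_with_autoenc_cond_models["DiT-B-4"](
#       input_size=32, in_channels=4,        # 32x32x4 VAE latents
#       num_autoenc=32, autoenc_dim=4,       # 32 slot vectors of dim 4
#       cond_method="adaln",
#   )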

from torch.distributions import Geometric

class NestedSampler(nn.Module):
    def __init__(
        self,
        num_slots,
        rho=0.03,
        nest_dist="geometric",
        mask_type="simple",
        null_prob=0.1,
        allow_zero=False,
        one_slot_before=0,  # Use only one slot before this epoch
    ):
        super().__init__()
        self.num_slots = num_slots
        self.mask_type = mask_type
        self.rho = rho
        self.geometric = Geometric(rho)
        self.nest_dist = nest_dist
        self.null_prob = null_prob
        self.allow_zero = allow_zero
        self.one_slot_before = one_slot_before
        self.register_buffer("arange", torch.arange(num_slots))

    def _apply_epoch_constraint(self, samples, epoch=None):
        if epoch is not None and epoch < self.one_slot_before:
            return torch.ones_like(samples)
        return samples

    def _apply_null_prob(self, samples, num):
        # First determine which samples will be 0 based on null_prob
        null_mask = torch.rand(num, device=samples.device) < self.null_prob
        # Replace with 0 where null_mask is True
        return torch.where(null_mask, torch.zeros_like(samples), samples)

    def geometric_sample(self, num):
        return self.geometric.sample([num]) + int(not self.allow_zero)

    def uniform_sample(self, num):
        return torch.randint(int(not self.allow_zero), self.num_slots + 1, (num,))

    def power2_uniform_sample(self, num):
        # Get powers of 2 up to num_slots and add num_slots
        choices = [2**i for i in range(int(math.log2(self.num_slots)) + 1)]
        if self.num_slots not in choices:
            choices.append(self.num_slots)
        return torch.tensor(choices)[torch.randint(0, len(choices), (num,))]

    def sample(self, num, epoch=None):
        if self.nest_dist == "geometric":
            samples = self.geometric_sample(num)
        elif self.nest_dist == "uniform":
            samples = self.uniform_sample(num)
        elif self.nest_dist == "power2_uniform":
            samples = self.power2_uniform_sample(num)
        else:
            raise ValueError(f"Invalid nest_dist: {self.nest_dist}")
        samples = self._apply_epoch_constraint(samples, epoch)
        return self._apply_null_prob(samples, num)

    def forward(self, batch_size, num_patches, device, inference_with_n_slots=-1, coupled_value=None, epoch=None):
        if self.training:
            if coupled_value is None:
                b = self.sample(batch_size, epoch).to(device)
            else:
                b = coupled_value.long().to(device)
        else:
            if inference_with_n_slots != -1:
                b = torch.full((batch_size,), inference_with_n_slots, device=device)
            else:
                b = torch.full((batch_size,), self.num_slots, device=device)
        b = torch.clamp(b, max=self.num_slots)

        slot_mask = self.arange[None, :] < b[:, None]  # (batch_size, num_slots)
        if self.mask_type == "replace":
            return slot_mask
        else:
            return self.get_cond_attn_mask(slot_mask.unsqueeze(1), num_patches, self.num_slots, device)

    def get_cond_attn_mask(self, slot_mask, num_patches, num_slots, device):
        num_tokens = num_patches + num_slots
        batch_size = slot_mask.shape[0]
        if self.mask_type == "simple":
            attn_mask = torch.ones((batch_size, num_tokens, num_tokens), dtype=torch.bool, device=device)
            attn_mask[:, :, num_patches:] = slot_mask.expand(-1, num_tokens, -1)
        elif self.mask_type == "causal":
            attn_mask = torch.zeros((batch_size, num_tokens, num_tokens), dtype=torch.bool, device=device)
            # 1) patches can see each other
            attn_mask[:, :num_patches, :num_patches] = True
            # 2) patches cannot see the dropped slots
            slot_mask = slot_mask.expand(-1, num_patches, -1)
            attn_mask[:, :num_patches, num_patches:] = slot_mask
            # 3) remaining slots are causal to each other
            causal_mask = torch.ones((num_slots, num_slots), dtype=torch.bool, device=device).tril(diagonal=0)
            attn_mask[:, num_patches:, num_patches:] = causal_mask
            # 4) only the first slot can see the patches
            attn_mask[:, num_patches + 1:, :num_patches] = False
        else:
            raise NotImplementedError(f"Invalid mask_type: {self.mask_type}")
        return attn_mask.unsqueeze(1)  # (batch_size, 1, num_tokens, num_tokens)

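# Sketch of the keep-mask NestedSampler builds for num_slots=4 and a sampled
# count b=2: torch.arange(4)[None, :] < 2 -> [[True, True, False, False]].
# With mask_type="simple" tokens attend only to the first b slots; with
# mask_type="replace" the dropped slots are swapped for the null condition.
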
369 |
+
class DiffuseSlot(nn.Module):
|
370 |
+
def __init__(
|
371 |
+
self,
|
372 |
+
encoder="vit_base_patch16",
|
373 |
+
drop_path_rate=0.1,
|
374 |
+
enc_img_size=256,
|
375 |
+
enc_causal=True,
|
376 |
+
enc_use_mlp=False,
|
377 |
+
enc_hidden_dim=4096,
|
378 |
+
num_slots=16,
|
379 |
+
slot_dim=256,
|
380 |
+
slot_through=True,
|
381 |
+
norm_slots=False,
|
382 |
+
use_kl_loss=False,
|
383 |
+
kl_loss_weight=1e-6,
|
384 |
+
enable_nest=False,
|
385 |
+
enable_nest_after=-1,
|
386 |
+
nest_dist="geometric",
|
387 |
+
nest_rho=0.03,
|
388 |
+
nest_null_prob=0.1,
|
389 |
+
nest_allow_zero=False,
|
390 |
+
nest_one_slot_before=0,
|
391 |
+
coupled_sampling=False,
|
392 |
+
coupled_rho=-0.8,
|
393 |
+
dit_class_cond=False,
|
394 |
+
dit_mask_type="simple",
|
395 |
+
cond_method="adaln",
|
396 |
+
dit_model="DiT-B-4",
|
397 |
+
vae="stabilityai/sd-vae-ft-ema",
|
398 |
+
vae_path="pretrained_models/kl16.ckpt",
|
399 |
+
pretrained_dit=None,
|
400 |
+
pretrained_encoder=None,
|
401 |
+
freeze_dit=False,
|
402 |
+
freeze_vit_after=-1,
|
403 |
+
num_sampling_steps="ddim25",
|
404 |
+
ckpt_path=None,
|
405 |
+
ema_path=None,
|
406 |
+
use_repa=False,
|
407 |
+
repa_encoder="dinov2_vitb14",
|
408 |
+
repa_encoder_depth=8,
|
409 |
+
repa_loss_weight=1.0,
|
410 |
+
use_sit=False,
|
411 |
+
**kwargs,
|
412 |
+
):
|
413 |
+
super().__init__()
|
414 |
+
|
415 |
+
z_dim = 0
|
416 |
+
self.use_repa = use_repa
|
417 |
+
self.repa_encoder_name = repa_encoder
|
418 |
+
self.repa_loss_weight = repa_loss_weight
|
419 |
+
self.use_sit = use_sit
|
420 |
+
if use_repa:
|
421 |
+
if "dinov2" in repa_encoder:
|
422 |
+
if "vitb" in repa_encoder or "vit_b" in repa_encoder or "vit-b" in repa_encoder:
|
423 |
+
self.repa_encoder = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
|
424 |
+
elif "vitl" in repa_encoder or "vit_l" in repa_encoder or "vit-l" in repa_encoder:
|
425 |
+
self.repa_encoder = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
|
426 |
+
else:
|
427 |
+
raise ValueError(f"Invalid repa_encoder: {repa_encoder}")
|
428 |
+
self.repa_encoder.image_size = 224
|
429 |
+
elif "clip" in repa_encoder:
|
430 |
+
if "vitb" in repa_encoder or "vit_b" in repa_encoder or "vit-b" in repa_encoder or "vit-base" in repa_encoder:
|
431 |
+
self.repa_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch16")
|
432 |
+
elif "vitl" in repa_encoder or "vit_l" in repa_encoder or "vit-l" in repa_encoder or "vit-large" in repa_encoder:
|
433 |
+
self.repa_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14")
|
434 |
+
else:
|
435 |
+
raise ValueError(f"Invalid repa_encoder: {repa_encoder}")
|
436 |
+
self.repa_encoder.embed_dim = self.repa_encoder.config.hidden_size
|
437 |
+
self.repa_encoder.image_size = self.repa_encoder.config.image_size
|
438 |
+
elif "siglip2" in repa_encoder:
|
439 |
+
if "vitb" in repa_encoder or "vit_b" in repa_encoder or "vit-b" in repa_encoder or "vit-base" in repa_encoder:
|
440 |
+
self.repa_encoder = SiglipVisionModel.from_pretrained("google/siglip2-base-patch16-256")
|
441 |
+
elif "vitl" in repa_encoder or "vit_l" in repa_encoder or "vit-l" in repa_encoder or "vit-large" in repa_encoder:
|
442 |
+
self.repa_encoder = SiglipVisionModel.from_pretrained("google/siglip2-large-patch16-256")
|
443 |
+
else:
|
444 |
+
raise ValueError(f"Invalid repa_encoder: {repa_encoder}")
|
445 |
+
self.repa_encoder.embed_dim = self.repa_encoder.config.hidden_size
|
446 |
+
self.repa_encoder.image_size = self.repa_encoder.config.image_size
|
447 |
+
elif "siglip" in repa_encoder:
|
448 |
+
if "vitb" in repa_encoder or "vit_b" in repa_encoder or "vit-b" in repa_encoder or "vit-base" in repa_encoder:
|
449 |
+
self.repa_encoder = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-256")
|
450 |
+
elif "vitl" in repa_encoder or "vit_l" in repa_encoder or "vit-l" in repa_encoder or "vit-large" in repa_encoder:
|
451 |
+
self.repa_encoder = SiglipVisionModel.from_pretrained("google/siglip-large-patch16-256")
|
452 |
+
else:
|
453 |
+
raise ValueError(f"Invalid repa_encoder: {repa_encoder}")
|
454 |
+
self.repa_encoder.embed_dim = self.repa_encoder.config.hidden_size
|
455 |
+
self.repa_encoder.image_size = self.repa_encoder.config.image_size
|
456 |
+
else:
|
457 |
+
raise ValueError(f"Invalid repa_encoder: {repa_encoder}")
|
458 |
+
for param in self.repa_encoder.parameters():
|
459 |
+
param.requires_grad = False
|
460 |
+
self.repa_encoder.eval()
|
461 |
+
z_dim = self.repa_encoder.embed_dim
|
462 |
+
|
463 |
+
# DiT part
|
464 |
+
if not use_sit:
|
465 |
+
self.diffusion = create_diffusion(timestep_respacing="")
|
466 |
+
self.gen_diffusion = create_diffusion(timestep_respacing=num_sampling_steps)
|
467 |
+
else:
|
468 |
+
self.transport = create_transport()
|
469 |
+
self.sampler = Sampler(self.transport)
|
470 |
+
self.dit_input_size = enc_img_size // 8 if not "mar" in vae else enc_img_size // 16
|
471 |
+
self.dit_in_channels = 4 if not "mar" in vae else 16
|
472 |
+
self.dit = DiT_with_autoenc_cond_models[dit_model](
|
473 |
+
input_size=self.dit_input_size,
|
474 |
+
in_channels=self.dit_in_channels,
|
475 |
+
num_autoenc=num_slots,
|
476 |
+
autoenc_dim=slot_dim,
|
477 |
+
cond_method=cond_method,
|
478 |
+
class_cond=dit_class_cond,
|
479 |
+
mask_type=dit_mask_type,
|
480 |
+
use_repa=use_repa,
|
481 |
+
encoder_depth=repa_encoder_depth,
|
482 |
+
z_dim=z_dim,
|
483 |
+
learn_sigma=not use_sit,
|
484 |
+
)
|
485 |
+
self.dit_patch_size = self.dit.x_embedder.patch_size[0]
|
486 |
+
self.dit_num_patches = (self.dit_input_size // self.dit_patch_size) ** 2
|
487 |
+
self.pretrained_dit = pretrained_dit
|
488 |
+
if pretrained_dit is not None:
|
489 |
+
# now we load some pretrained model
|
490 |
+
dit_ckpt = torch.load(pretrained_dit, map_location="cpu")
|
491 |
+
msg = self.dit.load_state_dict(dit_ckpt, strict=False)
|
492 |
+
print("Load DiT from ckpt")
|
493 |
+
print(msg)
|
494 |
+
self.freeze_dit = freeze_dit
|
495 |
+
if freeze_dit:
|
496 |
+
assert pretrained_dit is not None, "pretrained_dit must be provided"
|
497 |
+
for param in self.dit.parameters():
|
498 |
+
param.requires_grad = False
|
499 |
+
|
500 |
+
if "mar" in vae:
|
501 |
+
from diffusers import AutoencoderKL
|
502 |
+
self.vae = AutoencoderKL.from_pretrained("xwen99/mar-vae-kl16")
|
503 |
+
self.scaling_factor = 0.2325
|
504 |
+
elif vae == "openai/consistency-decoder":
|
505 |
+
from diffusers import ConsistencyDecoderVAE
|
506 |
+
self.vae = ConsistencyDecoderVAE.from_pretrained(vae)
|
507 |
+
self.scaling_factor = 0.18215
|
508 |
+
else: # eg, "stabilityai/sd-vae-ft-ema"
|
509 |
+
from diffusers import AutoencoderKL
|
510 |
+
self.vae = AutoencoderKL.from_pretrained(vae)
|
511 |
+
self.scaling_factor = 0.18215
|
512 |
+
|
513 |
+
self.vae.eval().requires_grad_(False)
|
514 |
+
|
515 |
+
# image encoder part
|
516 |
+
import paintmind.stage1.vision_transformers as vision_transformer
|
517 |
+
|
518 |
+
self.enc_img_size = enc_img_size
|
519 |
+
self.enc_causal = enc_causal
|
520 |
+
encoder_fn = vision_transformer.__dict__[encoder]
|
521 |
+
|
522 |
+
self.encoder = encoder_fn(
|
523 |
+
img_size=[enc_img_size],
|
524 |
+
num_slots=num_slots,
|
525 |
+
slot_through=slot_through,
|
526 |
+
drop_path_rate=drop_path_rate,
|
527 |
+
)
|
528 |
+
self.num_slots = num_slots
|
529 |
+
self.norm_slots = norm_slots
|
530 |
+
self.use_kl_loss = use_kl_loss
|
531 |
+
self.kl_loss_weight = kl_loss_weight
|
532 |
+
self.num_channels = self.encoder.num_features
|
533 |
+
self.pretrained_encoder = pretrained_encoder
|
534 |
+
if pretrained_encoder is not None:
|
535 |
+
# __import__("ipdb").set_trace()
|
536 |
+
encoder_ckpt = torch.load(pretrained_encoder, map_location="cpu")
|
537 |
+
# drop pos_embed from ckpt
|
538 |
+
encoder_ckpt = {
|
539 |
+
k.replace("blocks.", "blocks.0."): v
|
540 |
+
for k, v in encoder_ckpt.items()
|
541 |
+
if not k.startswith("pos_embed")
|
542 |
+
}
|
543 |
+
msg = self.encoder.load_state_dict(encoder_ckpt, strict=False)
|
544 |
+
print("Load encoder from ckpt")
|
545 |
+
print(msg)
|
546 |
+
|
547 |
+
if not enc_use_mlp:
|
548 |
+
self.encoder2slot = nn.Linear(self.num_channels, slot_dim * 2 if self.use_kl_loss else slot_dim)
|
549 |
+
else:
|
550 |
+
self.encoder2slot = nn.Sequential(
|
551 |
+
nn.Linear(self.num_channels, enc_hidden_dim),
|
552 |
+
nn.GELU(),
|
553 |
+
nn.Linear(enc_hidden_dim, slot_dim * 2 if self.use_kl_loss else slot_dim),
|
554 |
+
)
|
555 |
+
|
556 |
+
self.nested_sampler = NestedSampler(
|
557 |
+
num_slots,
|
558 |
+
rho=nest_rho,
|
559 |
+
nest_dist=nest_dist,
|
560 |
+
mask_type=dit_mask_type,
|
561 |
+
null_prob=nest_null_prob,
|
562 |
+
allow_zero=nest_allow_zero,
|
563 |
+
one_slot_before=nest_one_slot_before,
|
564 |
+
)
|
565 |
+
self.nest_allow_zero = nest_allow_zero
|
566 |
+
self.nest_rho = nest_rho
|
567 |
+
self.use_coupled_sampling = coupled_sampling
|
568 |
+
self.couple_sampling_rho = coupled_rho
|
569 |
+
self.enable_nest = enable_nest
|
570 |
+
self.enable_nest_after = enable_nest_after
|
571 |
+
self.freeze_vit_after = freeze_vit_after
|
572 |
+
self.current_epoch = 0
|
573 |
+
|
574 |
+
def coupled_sampling(self, timestamps):
|
575 |
+
"""
|
576 |
+
Convert timestamps to coupled num_slots values where higher timestamps
|
577 |
+
tend to produce lower num_slots values.
|
578 |
+
|
579 |
+
Args:
|
580 |
+
timestamps: Tensor of shape (batch_size,) with values in [0, 1000)
|
581 |
+
|
582 |
+
Returns:
|
583 |
+
Tensor of shape (batch_size,) with values in [1, num_slots + 1)
|
584 |
+
"""
|
585 |
+
# Normalize timestamps to [0, 1]
|
586 |
+
t_normalized = 1 - (timestamps.float() / timestamps.max())
|
587 |
+
# Scale to [1, num_slots + 1) and round to integers
|
588 |
+
adder = int(not self.nest_allow_zero)
|
589 |
+
scaled = adder + t_normalized * (self.num_slots + 1 - adder)
|
590 |
+
num_slots2use = scaled.long().clamp(adder, self.num_slots)
|
591 |
+
return num_slots2use
|
592 |
+
|
593 |
+
@torch.no_grad()
|
594 |
+
def vae_encode(self, x):
|
595 |
+
x = x * 2 - 1
|
596 |
+
x = self.vae.encode(x)
|
597 |
+
if hasattr(x, 'latent_dist'):
|
598 |
+
x = x.latent_dist
|
599 |
+
return x.sample().mul_(self.scaling_factor)
|
600 |
+
|
601 |
+
@torch.no_grad()
|
602 |
+
def vae_decode(self, z):
|
603 |
+
z = self.vae.decode(z / self.scaling_factor)
|
604 |
+
if hasattr(z, 'sample'):
|
605 |
+
z = z.sample
|
606 |
+
return (z + 1) / 2
|
607 |
+
|
608 |
+
@torch.no_grad()
|
609 |
+
def repa_encode(self, x):
|
610 |
+
if "dinov2" in self.repa_encoder_name:
|
611 |
+
mean = torch.Tensor(IMAGENET_DEFAULT_MEAN).to(x.device).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
612 |
+
std = torch.Tensor(IMAGENET_DEFAULT_STD).to(x.device).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
613 |
+
elif "clip" in self.repa_encoder_name:
|
614 |
+
mean = torch.Tensor(CLIP_DEFAULT_MEAN).to(x.device).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
615 |
+
std = torch.Tensor(CLIP_DEFAULT_STD).to(x.device).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
616 |
+
elif "siglip" in self.repa_encoder_name:
|
617 |
+
mean = torch.Tensor(SIGLIP_DEFAULT_MEAN).to(x.device).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
618 |
+
std = torch.Tensor(SIGLIP_DEFAULT_STD).to(x.device).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
619 |
+
else:
|
620 |
+
raise ValueError(f"Invalid repa_encoder: {self.repa_encoder_name}")
|
621 |
+
x = (x - mean) / std
|
622 |
+
if self.repa_encoder.image_size != self.enc_img_size:
|
623 |
+
x = torch.nn.functional.interpolate(x, self.repa_encoder.image_size, mode='bicubic')
|
624 |
+
if "dinov2" in self.repa_encoder_name:
|
625 |
+
x = self.repa_encoder.forward_features(x)['x_norm_patchtokens']
|
626 |
+
else:
|
627 |
+
x = self.repa_encoder(x)["last_hidden_state"]
|
628 |
+
return x
|
629 |
+
|
630 |
+
def encode_slots(self, x):
|
631 |
+
if self.pretrained_encoder is not None:
|
632 |
+
x = F.interpolate(x, size=224, mode='bicubic')
|
633 |
+
slots = self.encoder(x, is_causal=self.enc_causal)
|
634 |
+
slots = self.encoder2slot(slots)
|
635 |
+
if self.norm_slots:
|
636 |
+
if not self.use_kl_loss:
|
637 |
+
slots_std = torch.std(slots, dim=-1, keepdim=True)
|
638 |
+
slots_mean = torch.mean(slots, dim=-1, keepdim=True)
|
639 |
+
slots = (slots - slots_mean) / slots_std # this works better than kl loss
|
640 |
+
else:
|
641 |
+
slots = DiagonalGaussianDistribution(slots)
|
642 |
+
return slots
|
643 |
+
|
644 |
+
def forward_with_latents(self,
|
645 |
+
x_vae,
|
646 |
+
slots,
|
647 |
+
z,
|
648 |
+
targets=None,
|
649 |
+
sample=False,
|
650 |
+
epoch=None,
|
651 |
+
inference_with_n_slots=-1,
|
652 |
+
cfg=1.0):
|
653 |
+
losses = {}
|
654 |
+
batch_size = x_vae.shape[0]
|
655 |
+
num_patches = self.dit_num_patches
|
656 |
+
device = x_vae.device
|
657 |
+
|
658 |
+
if (
|
659 |
+
epoch is not None
|
660 |
+
and epoch >= self.enable_nest_after
|
661 |
+
and self.enable_nest_after != -1
|
662 |
+
):
|
663 |
+
self.enable_nest = True
|
664 |
+
|
665 |
+
t = torch.randint(0, 1000, (x_vae.shape[0],), device=device)
|
666 |
+
|
667 |
+
if self.enable_nest or inference_with_n_slots != -1:
|
668 |
+
if self.use_coupled_sampling:
|
669 |
+
num_slots2use = self.coupled_sampling(t)
|
670 |
+
else:
|
671 |
+
num_slots2use = None
|
672 |
+
drop_mask = self.nested_sampler(
|
673 |
+
batch_size, num_patches, device,
|
674 |
+
inference_with_n_slots=inference_with_n_slots,
|
675 |
+
coupled_value=num_slots2use,
|
676 |
+
epoch=epoch
|
677 |
+
)
|
678 |
+
else:
|
679 |
+
drop_mask = None
|
680 |
+
|
681 |
+
if sample:
|
682 |
+
return self.sample(slots if not self.use_kl_loss else slots.sample(), drop_mask=drop_mask, targets=targets, cfg=cfg)
|
683 |
+
|
684 |
+
model_kwargs = dict(autoenc_cond=slots if not self.use_kl_loss else slots.sample(), drop_mask=drop_mask, y=targets)
|
685 |
+
if not self.use_sit:
|
686 |
+
loss_dict = self.diffusion.training_losses(self.dit, x_vae, t, model_kwargs)
|
687 |
        else:
            loss_dict = self.transport.training_losses(self.dit, x_vae, model_kwargs)
            diff_loss = loss_dict["loss"].mean()
        losses["diff_loss"] = diff_loss

        if self.use_kl_loss:
            kl_loss = slots.kl()
            losses["kl_loss"] = kl_loss.mean() * self.kl_loss_weight

        if self.use_repa:
            assert self.dit._repa_hook is not None and z is not None
            z_tilde = self.dit._repa_hook

            if z_tilde.shape[1] != z.shape[1]:
                z_tilde = interpolate_features(z_tilde, z.shape[1])

            assert z_tilde.shape[-1] == z.shape[-1], f"Feature dimensions don't match: {z_tilde.shape} vs {z.shape}"

            z_tilde = F.normalize(z_tilde, dim=-1)
            z = F.normalize(z, dim=-1)
            repa_loss = -torch.sum(z_tilde * z, dim=-1)
            losses["repa_loss"] = repa_loss.mean() * self.repa_loss_weight

        return losses

    def forward(self,
                x,
                targets=None,
                latents=None,
                sample=False,
                epoch=None,
                inference_with_n_slots=-1,
                cfg=1.0):

        # it will be used in train() and decide whether to set the encoder to eval mode
        if epoch is not None:
            self.current_epoch = epoch

        if latents is None:
            x_vae = self.vae_encode(x)  # (N, C, H, W)
        else:
            x_vae = latents

        if self.use_repa:
            z = self.repa_encode(x)
        else:
            z = None

        slots = self.encode_slots(x)
        return self.forward_with_latents(x_vae, slots, z, targets, sample, epoch, inference_with_n_slots, cfg)

    @torch.no_grad()
    def sample(self, slots, drop_mask=None, targets=None, cfg=1.0):
        batch_size = slots.shape[0]
        device = slots.device
        z = torch.randn(batch_size, self.dit_in_channels, self.dit_input_size, self.dit_input_size, device=device)
        if cfg != 1.0:
            z = torch.cat([z, z], 0)
            null_slots = self.dit.null_cond.expand(batch_size, -1, -1)
            slots = torch.cat([slots, null_slots], 0)
            if drop_mask is not None:
                null_cond_mask = torch.ones_like(drop_mask)
                drop_mask = torch.cat([drop_mask, null_cond_mask], 0)
            if targets is not None:
                targets = torch.cat([targets, targets], 0)
            model_kwargs = dict(autoenc_cond=slots, drop_mask=drop_mask, y=targets, cfg_scale=cfg)
            sample_fn = self.dit.forward_with_cfg
        else:
            model_kwargs = dict(autoenc_cond=slots, drop_mask=drop_mask, y=targets)
            sample_fn = self.dit.forward
        # Sample images:
        if not self.use_sit:
            samples = self.gen_diffusion.p_sample_loop(
                sample_fn,
                z.shape,
                z,
                clip_denoised=False,
                model_kwargs=model_kwargs,
                progress=False,
                device=device,
            )
        else:
            sde_sample_fn = self.sampler.sample_sde(diffusion_form="sigma")
            samples = sde_sample_fn(z, sample_fn, **model_kwargs)[-1]
        if cfg != 1.0:
            samples, _ = samples.chunk(2, dim=0)  # Remove null class samples
        samples = self.vae_decode(samples)
        return samples

    def cancel_gradients_encoder(self, epoch):
        """Cancel gradients for encoder components after backward pass"""
        if (epoch is not None
                and epoch >= self.freeze_vit_after
                and self.freeze_vit_after != -1):
            # Directly access parameters from the modules
            for p in self.encoder.parameters():
                if p.grad is not None:
                    p.grad = None
            for p in self.encoder2slot.parameters():
                if p.grad is not None:
                    p.grad = None

    def train(self, mode=True):
        """Override train() to keep certain components in eval mode"""
        super().train(mode)
        # VAE should always be in eval mode
        self.vae.eval()

        # Keep encoder in eval mode if frozen
        if (self.freeze_vit_after != -1 and
                hasattr(self, 'current_epoch') and
                self.current_epoch >= self.freeze_vit_after):
            self.encoder.eval()
            self.encoder2slot.eval()

        # Keep DiT in eval mode if frozen
        if self.freeze_dit:
            self.dit.eval()

        return self
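For orientation, a minimal sketch of how the two-branch sampling path above is exercised. Everything below is illustrative: the constructor call, image size, and device are assumptions, not taken from the repo's demo scripts; only encode_slots() and sample() are methods shown in the file itself.

    import torch

    # Hypothetical instance of the autoencoder module defined in diffuse_slot.py above.
    model = build_diffuse_slot_model().eval().cuda()  # build_diffuse_slot_model is a placeholder
    images = torch.randn(4, 3, 256, 256).cuda()       # dummy batch; real inputs come from a dataloader

    with torch.no_grad():
        slots = model.encode_slots(images)    # slot latents that condition the DiT decoder
        recon = model.sample(slots, cfg=3.0)  # cfg != 1.0 takes the forward_with_cfg branch above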
paintmind/stage1/diffusion/__init__.py
ADDED
@@ -0,0 +1,46 @@
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py

from . import gaussian_diffusion as gd
from .respace import SpacedDiffusion, space_timesteps


def create_diffusion(
    timestep_respacing,
    noise_schedule="linear",
    use_kl=False,
    sigma_small=False,
    predict_xstart=False,
    learn_sigma=True,
    rescale_learned_sigmas=False,
    diffusion_steps=1000
):
    betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
    if use_kl:
        loss_type = gd.LossType.RESCALED_KL
    elif rescale_learned_sigmas:
        loss_type = gd.LossType.RESCALED_MSE
    else:
        loss_type = gd.LossType.MSE
    if timestep_respacing is None or timestep_respacing == "":
        timestep_respacing = [diffusion_steps]
    return SpacedDiffusion(
        use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
        betas=betas,
        model_mean_type=(
            gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
        ),
        model_var_type=(
            (
                gd.ModelVarType.FIXED_LARGE
                if not sigma_small
                else gd.ModelVarType.FIXED_SMALL
            )
            if not learn_sigma
            else gd.ModelVarType.LEARNED_RANGE
        ),
        loss_type=loss_type
        # rescale_timesteps=rescale_timesteps,
    )
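create_diffusion() is the package's single entry point; a short usage sketch (the import path follows the repo layout above):

    from paintmind.stage1.diffusion import create_diffusion

    # Full 1000-step chain, e.g. for training_losses() during training.
    train_diffusion = create_diffusion(timestep_respacing="")
    # Respaced 250-step chain for cheaper ancestral sampling at inference.
    gen_diffusion = create_diffusion(timestep_respacing="250")
    print(train_diffusion.num_timesteps, gen_diffusion.num_timesteps)  # 1000 250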
paintmind/stage1/diffusion/diffusion_utils.py
ADDED
@@ -0,0 +1,88 @@
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py

import torch as th
import numpy as np


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, th.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for th.exp().
    logvar1, logvar2 = [
        x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + th.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
    )


def approx_standard_normal_cdf(x):
    """
    A fast approximation of the cumulative distribution function of the
    standard normal.
    """
    return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))


def continuous_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Compute the log-likelihood of a continuous Gaussian distribution.
    :param x: the targets
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    centered_x = x - means
    inv_stdv = th.exp(-log_scales)
    normalized_x = centered_x * inv_stdv
    log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
    return log_probs


def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Compute the log-likelihood of a Gaussian distribution discretizing to a
    given image.
    :param x: the target images. It is assumed that this was uint8 values,
              rescaled to the range [-1, 1].
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    assert x.shape == means.shape == log_scales.shape
    centered_x = x - means
    inv_stdv = th.exp(-log_scales)
    plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
    cdf_plus = approx_standard_normal_cdf(plus_in)
    min_in = inv_stdv * (centered_x - 1.0 / 255.0)
    cdf_min = approx_standard_normal_cdf(min_in)
    log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
    log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
    cdf_delta = cdf_plus - cdf_min
    log_probs = th.where(
        x < -0.999,
        log_cdf_plus,
        th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
    )
    assert log_probs.shape == x.shape
    return log_probs
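As a sanity check on normal_kl(): with unit variances the KL between two Gaussians reduces to half the squared mean difference, and the broadcasting form above reproduces that. A small sketch:

    import torch as th
    from paintmind.stage1.diffusion.diffusion_utils import normal_kl

    # KL(N(0,1) || N(0,1)) = 0; KL(N(1,1) || N(0,1)) = 0.5 nats.
    kl = normal_kl(th.tensor([0.0, 1.0]), th.zeros(2), th.zeros(2), th.zeros(2))
    print(kl)  # tensor([0.0000, 0.5000])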
paintmind/stage1/diffusion/gaussian_diffusion.py
ADDED
@@ -0,0 +1,886 @@
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py


import math

import numpy as np
import torch as th
import enum

from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl


def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


class ModelMeanType(enum.Enum):
    """
    Which type of output the model predicts.
    """

    PREVIOUS_X = enum.auto()  # the model predicts x_{t-1}
    START_X = enum.auto()  # the model predicts x_0
    EPSILON = enum.auto()  # the model predicts epsilon


class ModelVarType(enum.Enum):
    """
    What is used as the model's output variance.
    The LEARNED_RANGE option has been added to allow the model to predict
    values between FIXED_SMALL and FIXED_LARGE, making its job easier.
    """

    LEARNED = enum.auto()
    FIXED_SMALL = enum.auto()
    FIXED_LARGE = enum.auto()
    LEARNED_RANGE = enum.auto()


class LossType(enum.Enum):
    MSE = enum.auto()  # use raw MSE loss (and KL when learning variances)
    RESCALED_MSE = (
        enum.auto()
    )  # use raw MSE loss (with RESCALED_KL when learning variances)
    KL = enum.auto()  # use the variational lower-bound
    RESCALED_KL = enum.auto()  # like KL, but rescale to estimate the full VLB

    def is_vb(self):
        return self == LossType.KL or self == LossType.RESCALED_KL


def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
    betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
    warmup_time = int(num_diffusion_timesteps * warmup_frac)
    betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
    return betas


def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
    """
    This is the deprecated API for creating beta schedules.
    See get_named_beta_schedule() for the new library of schedules.
    """
    if beta_schedule == "quad":
        betas = (
            np.linspace(
                beta_start ** 0.5,
                beta_end ** 0.5,
                num_diffusion_timesteps,
                dtype=np.float64,
            )
            ** 2
        )
    elif beta_schedule == "linear":
        betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
    elif beta_schedule == "warmup10":
        betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
    elif beta_schedule == "warmup50":
        betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
    elif beta_schedule == "const":
        betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
    elif beta_schedule == "jsd":  # 1/T, 1/(T-1), 1/(T-2), ..., 1
        betas = 1.0 / np.linspace(
            num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
        )
    else:
        raise NotImplementedError(beta_schedule)
    assert betas.shape == (num_diffusion_timesteps,)
    return betas


def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
    """
    Get a pre-defined beta schedule for the given name.
    The beta schedule library consists of beta schedules which remain similar
    in the limit of num_diffusion_timesteps.
    Beta schedules may be added, but should not be removed or changed once
    they are committed to maintain backwards compatibility.
    """
    if schedule_name == "linear":
        # Linear schedule from Ho et al, extended to work for any number of
        # diffusion steps.
        scale = 1000 / num_diffusion_timesteps
        return get_beta_schedule(
            "linear",
            beta_start=scale * 0.0001,
            beta_end=scale * 0.02,
            num_diffusion_timesteps=num_diffusion_timesteps,
        )
    elif schedule_name == "cosine":
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
        )
    else:
        raise NotImplementedError(f"unknown beta schedule: {schedule_name}")


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].
    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)


class GaussianDiffusion:
    """
    Utilities for training and sampling diffusion models.
    Original ported from this codebase:
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
    :param betas: a 1-D numpy array of betas for each diffusion timestep,
                  starting at T and going to 1.
    """

    def __init__(
        self,
        *,
        betas,
        model_mean_type,
        model_var_type,
        loss_type
    ):

        self.model_mean_type = model_mean_type
        self.model_var_type = model_var_type
        self.loss_type = loss_type

        # Use float64 for accuracy.
        betas = np.array(betas, dtype=np.float64)
        self.betas = betas
        assert len(betas.shape) == 1, "betas must be 1-D"
        assert (betas > 0).all() and (betas <= 1).all()

        self.num_timesteps = int(betas.shape[0])

        alphas = 1.0 - betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
        self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
        assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = (
            betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.posterior_log_variance_clipped = np.log(
            np.append(self.posterior_variance[1], self.posterior_variance[1:])
        ) if len(self.posterior_variance) > 1 else np.array([])

        self.posterior_mean_coef1 = (
            betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
        )

    def q_mean_variance(self, x_start, t):
        """
        Get the distribution q(x_t | x_0).
        :param x_start: the [N x C x ...] tensor of noiseless inputs.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :return: A tuple (mean, variance, log_variance), all of x_start's shape.
        """
        mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    def q_sample(self, x_start, t, noise=None):
        """
        Diffuse the data for a given number of diffusion steps.
        In other words, sample from q(x_t | x_0).
        :param x_start: the initial data batch.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :param noise: if specified, the split-out normal noise.
        :return: A noisy version of x_start.
        """
        if noise is None:
            noise = th.randn_like(x_start)
        assert noise.shape == x_start.shape
        return (
            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior:
            q(x_{t-1} | x_t, x_0)
        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
            _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = _extract_into_tensor(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        assert (
            posterior_mean.shape[0]
            == posterior_variance.shape[0]
            == posterior_log_variance_clipped.shape[0]
            == x_start.shape[0]
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
        """
        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
        the initial x, x_0.
        :param model: the model, which takes a signal and a batch of timesteps
                      as input.
        :param x: the [N x C x ...] tensor at time t.
        :param t: a 1-D Tensor of timesteps.
        :param clip_denoised: if True, clip the denoised signal into [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample. Applies before
            clip_denoised.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict with the following keys:
                 - 'mean': the model mean output.
                 - 'variance': the model variance output.
                 - 'log_variance': the log of 'variance'.
                 - 'pred_xstart': the prediction for x_0.
        """
        if model_kwargs is None:
            model_kwargs = {}

        B, C = x.shape[:2]
        assert t.shape == (B,)
        model_output = model(x, t, **model_kwargs)
        if isinstance(model_output, tuple):
            model_output, extra = model_output
        else:
            extra = None

        if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
            assert model_output.shape == (B, C * 2, *x.shape[2:])
            model_output, model_var_values = th.split(model_output, C, dim=1)
            min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
            max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
            # The model_var_values is [-1, 1] for [min_var, max_var].
            frac = (model_var_values + 1) / 2
            model_log_variance = frac * max_log + (1 - frac) * min_log
            model_variance = th.exp(model_log_variance)
        else:
            model_variance, model_log_variance = {
                # for fixedlarge, we set the initial (log-)variance like so
                # to get a better decoder log likelihood.
                ModelVarType.FIXED_LARGE: (
                    np.append(self.posterior_variance[1], self.betas[1:]),
                    np.log(np.append(self.posterior_variance[1], self.betas[1:])),
                ),
                ModelVarType.FIXED_SMALL: (
                    self.posterior_variance,
                    self.posterior_log_variance_clipped,
                ),
            }[self.model_var_type]
            model_variance = _extract_into_tensor(model_variance, t, x.shape)
            model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)

        def process_xstart(x):
            if denoised_fn is not None:
                x = denoised_fn(x)
            if clip_denoised:
                return x.clamp(-1, 1)
            return x

        if self.model_mean_type == ModelMeanType.START_X:
            pred_xstart = process_xstart(model_output)
        else:
            pred_xstart = process_xstart(
                self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
            )
        model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)

        assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
        return {
            "mean": model_mean,
            "variance": model_variance,
            "log_variance": model_log_variance,
            "pred_xstart": pred_xstart,
            "extra": extra,
        }

    def _predict_xstart_from_eps(self, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
        )

    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute the mean for the previous step, given a function cond_fn that
        computes the gradient of a conditional log probability with respect to
        x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
        condition on y.
        This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
        """
        gradient = cond_fn(x, t, **model_kwargs)
        new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
        return new_mean

    def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute what the p_mean_variance output would have been, should the
        model's score function be conditioned by cond_fn.
        See condition_mean() for details on cond_fn.
        Unlike condition_mean(), this instead uses the conditioning strategy
        from Song et al (2020).
        """
        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)

        eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
        eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)

        out = p_mean_var.copy()
        out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
        out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
        return out

    def p_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        temperature=1.0
    ):
        """
        Sample x_{t-1} from the model at the given timestep.
        :param model: the model to sample from.
        :param x: the current tensor at x_{t-1}.
        :param t: the value of t, starting at 0 for the first diffusion step.
        :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param cond_fn: if not None, this is a gradient function that acts
                        similarly to the model.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param temperature: temperature scaling during Diff Loss sampling.
        :return: a dict containing the following keys:
                 - 'sample': a random sample from the model.
                 - 'pred_xstart': a prediction of x_0.
        """
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        noise = th.randn_like(x)
        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )  # no noise when t == 0
        if cond_fn is not None:
            out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
        # scale the noise by temperature
        sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise * temperature
        return {"sample": sample, "pred_xstart": out["pred_xstart"]}

    def p_sample_loop(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        temperature=1.0,
    ):
        """
        Generate samples from the model.
        :param model: the model module.
        :param shape: the shape of the samples, (N, C, H, W).
        :param noise: if specified, the noise from the encoder to sample.
                      Should be of the same shape as `shape`.
        :param clip_denoised: if True, clip x_start predictions to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param cond_fn: if not None, this is a gradient function that acts
                        similarly to the model.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param device: if specified, the device to create the samples on.
                       If not specified, use a model parameter's device.
        :param progress: if True, show a tqdm progress bar.
        :param temperature: temperature scaling during Diff Loss sampling.
        :return: a non-differentiable batch of samples.
        """
        final = None
        for sample in self.p_sample_loop_progressive(
            model,
            shape,
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            cond_fn=cond_fn,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
            temperature=temperature,
        ):
            final = sample
        return final["sample"]

    def p_sample_loop_progressive(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        temperature=1.0,
    ):
        """
        Generate samples from the model and yield intermediate samples from
        each timestep of diffusion.
        Arguments are the same as p_sample_loop().
        Returns a generator over dicts, where each dict is the return value of
        p_sample().
        """
        if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list))
        if noise is not None:
            img = noise
        else:
            img = th.randn(*shape, device=device)
        indices = list(range(self.num_timesteps))[::-1]

        if progress:
            # Lazy import so that we don't depend on tqdm.
            from tqdm.auto import tqdm

            indices = tqdm(indices)

        for i in indices:
            t = th.tensor([i] * shape[0], device=device)
            with th.no_grad():
                out = self.p_sample(
                    model,
                    img,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    cond_fn=cond_fn,
                    model_kwargs=model_kwargs,
                    temperature=temperature,
                )
                yield out
                img = out["sample"]

    def ddim_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        eta=0.0,
    ):
        """
        Sample x_{t-1} from the model using DDIM.
        Same usage as p_sample().
        """
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        if cond_fn is not None:
            out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)

        # Usually our model outputs epsilon, but we re-derive it
        # in case we used x_start or x_prev prediction.
        eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])

        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
        alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
        sigma = (
            eta
            * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
            * th.sqrt(1 - alpha_bar / alpha_bar_prev)
        )
        # Equation 12.
        noise = th.randn_like(x)
        mean_pred = (
            out["pred_xstart"] * th.sqrt(alpha_bar_prev)
            + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
        )
        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )  # no noise when t == 0
        sample = mean_pred + nonzero_mask * sigma * noise
        return {"sample": sample, "pred_xstart": out["pred_xstart"]}

    def ddim_reverse_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        eta=0.0,
    ):
        """
        Sample x_{t+1} from the model using DDIM reverse ODE.
        """
        assert eta == 0.0, "Reverse ODE only for deterministic path"
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        if cond_fn is not None:
            out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
        # Usually our model outputs epsilon, but we re-derive it
        # in case we used x_start or x_prev prediction.
        eps = (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
            - out["pred_xstart"]
        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
        alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)

        # Equation 12. reversed
        mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps

        return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}

    def ddim_sample_loop(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        eta=0.0,
    ):
        """
        Generate samples from the model using DDIM.
        Same usage as p_sample_loop().
        """
        final = None
        for sample in self.ddim_sample_loop_progressive(
            model,
            shape,
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            cond_fn=cond_fn,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
            eta=eta,
        ):
            final = sample
        return final["sample"]

    def ddim_sample_loop_progressive(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        eta=0.0,
    ):
        """
        Use DDIM to sample from the model and yield intermediate samples from
        each timestep of DDIM.
        Same usage as p_sample_loop_progressive().
        """
        if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list))
        if noise is not None:
            img = noise
        else:
            img = th.randn(*shape, device=device)
        indices = list(range(self.num_timesteps))[::-1]

        if progress:
            # Lazy import so that we don't depend on tqdm.
            from tqdm.auto import tqdm

            indices = tqdm(indices)

        for i in indices:
            t = th.tensor([i] * shape[0], device=device)
            with th.no_grad():
                out = self.ddim_sample(
                    model,
                    img,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    cond_fn=cond_fn,
                    model_kwargs=model_kwargs,
                    eta=eta,
                )
                yield out
                img = out["sample"]

    def _vb_terms_bpd(
        self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
    ):
        """
        Get a term for the variational lower-bound.
        The resulting units are bits (rather than nats, as one might expect).
        This allows for comparison to other papers.
        :return: a dict with the following keys:
                 - 'output': a shape [N] tensor of NLLs or KLs.
                 - 'pred_xstart': the x_0 predictions.
        """
        true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
            x_start=x_start, x_t=x_t, t=t
        )
        out = self.p_mean_variance(
            model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
        )
        kl = normal_kl(
            true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
        )
        kl = mean_flat(kl) / np.log(2.0)

        decoder_nll = -discretized_gaussian_log_likelihood(
            x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
        )
        assert decoder_nll.shape == x_start.shape
        decoder_nll = mean_flat(decoder_nll) / np.log(2.0)

        # At the first timestep return the decoder NLL,
        # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
        output = th.where((t == 0), decoder_nll, kl)
        return {"output": output, "pred_xstart": out["pred_xstart"]}

    def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
        """
        Compute training losses for a single timestep.
        :param model: the model to evaluate loss on.
        :param x_start: the [N x C x ...] tensor of inputs.
        :param t: a batch of timestep indices.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param noise: if specified, the specific Gaussian noise to try to remove.
        :return: a dict with the key "loss" containing a tensor of shape [N].
                 Some mean or variance settings may also have other keys.
        """
        if model_kwargs is None:
            model_kwargs = {}
        if noise is None:
            noise = th.randn_like(x_start)
        x_t = self.q_sample(x_start, t, noise=noise)

        terms = {}

        if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
            terms["loss"] = self._vb_terms_bpd(
                model=model,
                x_start=x_start,
                x_t=x_t,
                t=t,
                clip_denoised=False,
                model_kwargs=model_kwargs,
            )["output"]
            if self.loss_type == LossType.RESCALED_KL:
                terms["loss"] *= self.num_timesteps
        elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
            model_output = model(x_t, t, **model_kwargs)

            if self.model_var_type in [
                ModelVarType.LEARNED,
                ModelVarType.LEARNED_RANGE,
            ]:
                B, C = x_t.shape[:2]
                if len(model_output.shape) == len(x_t.shape) + 1:
                    x_t = x_t.unsqueeze(-1).expand(*([-1] * (len(x_t.shape))), model_output.shape[-1])
                    x_start = x_start.unsqueeze(-1).expand(*([-1] * (len(x_start.shape))), model_output.shape[-1])
                assert model_output.shape == (B, C * 2, *x_t.shape[2:])
                model_output, model_var_values = th.split(model_output, C, dim=1)
                # Learn the variance using the variational bound, but don't let
                # it affect our mean prediction.
                frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
                terms["vb"] = self._vb_terms_bpd(
                    model=lambda *args, r=frozen_out: r,
                    x_start=x_start,
                    x_t=x_t,
                    t=t,
                    clip_denoised=False,
                )["output"]
                if self.loss_type == LossType.RESCALED_MSE:
                    # Divide by 1000 for equivalence with initial implementation.
                    # Without a factor of 1/1000, the VB term hurts the MSE term.
                    terms["vb"] *= self.num_timesteps / 1000.0

            target = {
                ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
                    x_start=x_start, x_t=x_t, t=t
                )[0],
                ModelMeanType.START_X: x_start,
                ModelMeanType.EPSILON: noise,
            }[self.model_mean_type]
            if len(model_output.shape) == len(target.shape) + 1:
                target = target.unsqueeze(-1).expand(*([-1] * (len(target.shape))), model_output.shape[-1])
            assert model_output.shape == target.shape == x_start.shape
            terms["mse"] = mean_flat((target - model_output) ** 2)
            if "vb" in terms:
                terms["loss"] = terms["mse"] + terms["vb"]
            else:
                terms["loss"] = terms["mse"]
        else:
            raise NotImplementedError(self.loss_type)

        return terms

    def _prior_bpd(self, x_start):
        """
        Get the prior KL term for the variational lower-bound, measured in
        bits-per-dim.
        This term can't be optimized, as it only depends on the encoder.
        :param x_start: the [N x C x ...] tensor of inputs.
        :return: a batch of [N] KL values (in bits), one per batch element.
        """
        batch_size = x_start.shape[0]
        t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
        qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
        kl_prior = normal_kl(
            mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
        )
        return mean_flat(kl_prior) / np.log(2.0)

    def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
        """
        Compute the entire variational lower-bound, measured in bits-per-dim,
        as well as other related quantities.
        :param model: the model to evaluate loss on.
        :param x_start: the [N x C x ...] tensor of inputs.
        :param clip_denoised: if True, clip denoised samples.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict containing the following keys:
                 - total_bpd: the total variational lower-bound, per batch element.
                 - prior_bpd: the prior term in the lower-bound.
                 - vb: an [N x T] tensor of terms in the lower-bound.
                 - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
                 - mse: an [N x T] tensor of epsilon MSEs for each timestep.
        """
        device = x_start.device
        batch_size = x_start.shape[0]

        vb = []
        xstart_mse = []
        mse = []
        for t in list(range(self.num_timesteps))[::-1]:
            t_batch = th.tensor([t] * batch_size, device=device)
            noise = th.randn_like(x_start)
            x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
            # Calculate VLB term at the current timestep
            with th.no_grad():
                out = self._vb_terms_bpd(
                    model,
                    x_start=x_start,
                    x_t=x_t,
                    t=t_batch,
                    clip_denoised=clip_denoised,
                    model_kwargs=model_kwargs,
                )
            vb.append(out["output"])
            xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
            eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
            mse.append(mean_flat((eps - noise) ** 2))

        vb = th.stack(vb, dim=1)
        xstart_mse = th.stack(xstart_mse, dim=1)
        mse = th.stack(mse, dim=1)

        prior_bpd = self._prior_bpd(x_start)
        total_bpd = vb.sum(dim=1) + prior_bpd
        return {
            "total_bpd": total_bpd,
            "prior_bpd": prior_bpd,
            "vb": vb,
            "xstart_mse": xstart_mse,
            "mse": mse,
        }


def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Extract values from a 1-D numpy array for a batch of indices.
    :param arr: the 1-D numpy array.
    :param timesteps: a tensor of indices into the array to extract.
    :param broadcast_shape: a larger shape of K dimensions with the batch
                            dimension equal to the length of timesteps.
    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
    """
    res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res + th.zeros(broadcast_shape, device=timesteps.device)
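A minimal sketch of driving GaussianDiffusion directly with a toy epsilon predictor. FIXED_SMALL variance is chosen so the model outputs C channels rather than 2C; the model, shapes, and step count below are illustrative only:

    import torch as th
    from paintmind.stage1.diffusion import gaussian_diffusion as gd

    diffusion = gd.GaussianDiffusion(
        betas=gd.get_named_beta_schedule("linear", 100),
        model_mean_type=gd.ModelMeanType.EPSILON,
        model_var_type=gd.ModelVarType.FIXED_SMALL,
        loss_type=gd.LossType.MSE,
    )

    toy_model = lambda x, t: th.zeros_like(x)  # "predicts" epsilon = 0 everywhere

    x_start = th.randn(8, 3, 16, 16)
    t = th.randint(0, diffusion.num_timesteps, (8,))
    losses = diffusion.training_losses(toy_model, x_start, t)
    print(losses["loss"].shape)  # torch.Size([8]); mean ~ 1.0 since the true eps ~ N(0, I)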
paintmind/stage1/diffusion/respace.py
ADDED
@@ -0,0 +1,130 @@
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py

import numpy as np
import torch as th

from .gaussian_diffusion import GaussianDiffusion


def space_timesteps(num_timesteps, section_counts):
    """
    Create a list of timesteps to use from an original diffusion process,
    given the number of timesteps we want to take from equally-sized portions
    of the original process.
    For example, if there's 300 timesteps and the section counts are [10,15,20]
    then the first 100 timesteps are strided to be 10 timesteps, the second 100
    are strided to be 15 timesteps, and the final 100 are strided to be 20.
    If the stride is a string starting with "ddim", then the fixed striding
    from the DDIM paper is used, and only one section is allowed.
    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, or a string containing
                           comma-separated numbers, indicating the step count
                           per section. As a special case, use "ddimN" where N
                           is a number of steps to use the striding from the
                           DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim") :])
            for i in range(1, num_timesteps):
                if len(range(0, num_timesteps, i)) == desired_count:
                    return set(range(0, num_timesteps, i))
            raise ValueError(
                f"cannot create exactly {num_timesteps} steps with an integer stride"
            )
        section_counts = [int(x) for x in section_counts.split(",")]
    size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(
                f"cannot divide section of {size} steps into {section_count}"
            )
        if section_count <= 1:
            frac_stride = 1
        else:
            frac_stride = (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(section_count):
            taken_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        all_steps += taken_steps
        start_idx += size
    return set(all_steps)


class SpacedDiffusion(GaussianDiffusion):
    """
    A diffusion process which can skip steps in a base diffusion process.
    :param use_timesteps: a collection (sequence or set) of timesteps from the
                          original diffusion process to retain.
    :param kwargs: the kwargs to create the base diffusion process.
    """

    def __init__(self, use_timesteps, **kwargs):
        self.use_timesteps = set(use_timesteps)
        self.timestep_map = []
        self.original_num_steps = len(kwargs["betas"])

        base_diffusion = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa
        last_alpha_cumprod = 1.0
        new_betas = []
        for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
            if i in self.use_timesteps:
                new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
                last_alpha_cumprod = alpha_cumprod
                self.timestep_map.append(i)
        kwargs["betas"] = np.array(new_betas)
        super().__init__(**kwargs)

    def p_mean_variance(
        self, model, *args, **kwargs
    ):  # pylint: disable=signature-differs
        return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)

    def training_losses(
        self, model, *args, **kwargs
    ):  # pylint: disable=signature-differs
        return super().training_losses(self._wrap_model(model), *args, **kwargs)

    def condition_mean(self, cond_fn, *args, **kwargs):
        return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)

    def condition_score(self, cond_fn, *args, **kwargs):
        return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)

    def _wrap_model(self, model):
        if isinstance(model, _WrappedModel):
            return model
        return _WrappedModel(
            model, self.timestep_map, self.original_num_steps
        )

    def _scale_timesteps(self, t):
        # Scaling is done by the wrapped model.
        return t


class _WrappedModel:
    def __init__(self, model, timestep_map, original_num_steps):
        self.model = model
        self.timestep_map = timestep_map
        # self.rescale_timesteps = rescale_timesteps
        self.original_num_steps = original_num_steps

    def __call__(self, x, ts, **kwargs):
        map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
        new_ts = map_tensor[ts]
        # if self.rescale_timesteps:
        #     new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
        return self.model(x, new_ts, **kwargs)
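space_timesteps() is easiest to understand from its outputs; a quick sketch:

    from paintmind.stage1.diffusion.respace import space_timesteps

    # One section of 10 steps out of a 1000-step chain: fractional stride 999/9 = 111.
    print(sorted(space_timesteps(1000, "10")))
    # [0, 111, 222, 333, 444, 555, 666, 777, 888, 999]

    # DDIM-style fixed integer stride: 1000/50 = 20, i.e. steps 0, 20, ..., 980.
    print(len(space_timesteps(1000, "ddim50")))  # 50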
paintmind/stage1/diffusion/timestep_sampler.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Modified from OpenAI's diffusion repos
#     GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
#     ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
#     IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py

from abc import ABC, abstractmethod

import numpy as np
import torch as th
import torch.distributed as dist


def create_named_schedule_sampler(name, diffusion):
    """
    Create a ScheduleSampler from a library of pre-defined samplers.
    :param name: the name of the sampler.
    :param diffusion: the diffusion object to sample for.
    """
    if name == "uniform":
        return UniformSampler(diffusion)
    elif name == "loss-second-moment":
        return LossSecondMomentResampler(diffusion)
    else:
        raise NotImplementedError(f"unknown schedule sampler: {name}")


class ScheduleSampler(ABC):
    """
    A distribution over timesteps in the diffusion process, intended to reduce
    variance of the objective.
    By default, samplers perform unbiased importance sampling, in which the
    objective's mean is unchanged.
    However, subclasses may override sample() to change how the resampled
    terms are reweighted, allowing for actual changes in the objective.
    """

    @abstractmethod
    def weights(self):
        """
        Get a numpy array of weights, one per diffusion step.
        The weights needn't be normalized, but must be positive.
        """

    def sample(self, batch_size, device):
        """
        Importance-sample timesteps for a batch.
        :param batch_size: the number of timesteps.
        :param device: the torch device to save to.
        :return: a tuple (timesteps, weights):
                 - timesteps: a tensor of timestep indices.
                 - weights: a tensor of weights to scale the resulting losses.
        """
        w = self.weights()
        p = w / np.sum(w)
        indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
        indices = th.from_numpy(indices_np).long().to(device)
        weights_np = 1 / (len(p) * p[indices_np])
        weights = th.from_numpy(weights_np).float().to(device)
        return indices, weights


class UniformSampler(ScheduleSampler):
    def __init__(self, diffusion):
        self.diffusion = diffusion
        self._weights = np.ones([diffusion.num_timesteps])

    def weights(self):
        return self._weights


class LossAwareSampler(ScheduleSampler):
    def update_with_local_losses(self, local_ts, local_losses):
        """
        Update the reweighting using losses from a model.
        Call this method from each rank with a batch of timesteps and the
        corresponding losses for each of those timesteps.
        This method will perform synchronization to make sure all of the ranks
        maintain the exact same reweighting.
        :param local_ts: an integer Tensor of timesteps.
        :param local_losses: a 1D Tensor of losses.
        """
        batch_sizes = [
            th.tensor([0], dtype=th.int32, device=local_ts.device)
            for _ in range(dist.get_world_size())
        ]
        dist.all_gather(
            batch_sizes,
            th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
        )

        # Pad all_gather batches to be the maximum batch size.
        batch_sizes = [x.item() for x in batch_sizes]
        max_bs = max(batch_sizes)

        timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
        loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
        dist.all_gather(timestep_batches, local_ts)
        dist.all_gather(loss_batches, local_losses)
        timesteps = [
            x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
        ]
        losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
        self.update_with_all_losses(timesteps, losses)

    @abstractmethod
    def update_with_all_losses(self, ts, losses):
        """
        Update the reweighting using losses from a model.
        Sub-classes should override this method to update the reweighting
        using losses from the model.
        This method directly updates the reweighting without synchronizing
        between workers. It is called by update_with_local_losses from all
        ranks with identical arguments. Thus, it should have deterministic
        behavior to maintain state across workers.
        :param ts: a list of int timesteps.
        :param losses: a list of float losses, one per timestep.
        """


class LossSecondMomentResampler(LossAwareSampler):
    def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
        self.diffusion = diffusion
        self.history_per_term = history_per_term
        self.uniform_prob = uniform_prob
        self._loss_history = np.zeros(
            [diffusion.num_timesteps, history_per_term], dtype=np.float64
        )
        self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)

    def weights(self):
        if not self._warmed_up():
            return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
        weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
        weights /= np.sum(weights)
        weights *= 1 - self.uniform_prob
        weights += self.uniform_prob / len(weights)
        return weights

    def update_with_all_losses(self, ts, losses):
        for t, loss in zip(ts, losses):
            if self._loss_counts[t] == self.history_per_term:
                # Shift out the oldest loss term.
                self._loss_history[t, :-1] = self._loss_history[t, 1:]
                self._loss_history[t, -1] = loss
            else:
                self._loss_history[t, self._loss_counts[t]] = loss
                self._loss_counts[t] += 1

    def _warmed_up(self):
        return (self._loss_counts == self.history_per_term).all()
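A minimal usage sketch for the samplers above (not part of the file): the only attribute these samplers read from the diffusion object is num_timesteps, so the DummyDiffusion stand-in below is hypothetical and purely for illustration.

import torch as th

class DummyDiffusion:  # hypothetical stand-in; any object exposing num_timesteps works
    num_timesteps = 1000

sampler = create_named_schedule_sampler("loss-second-moment", DummyDiffusion())
t, weights = sampler.sample(batch_size=8, device="cpu")  # timestep indices + importance weights
# After computing per-timestep losses, feed them back so the resampler adapts:
# sampler.update_with_local_losses(t, losses)  # requires torch.distributed to be initialized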
paintmind/stage1/diffusion_transfomers.py
ADDED
@@ -0,0 +1,372 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# --------------------------------------------------------

import torch
import torch.nn as nn
import numpy as np
import math
from timm.models.vision_transformer import PatchEmbed, Mlp
from paintmind.stage1.fused_attention import Attention


def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


#################################################################################
#               Embedding Layers for Timesteps and Class Labels                 #
#################################################################################

class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
        return t_emb


class LabelEmbedder(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
    """
    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        else:
            drop_ids = force_drop_ids == 1
        labels = torch.where(drop_ids, self.num_classes, labels)
        return labels

    def forward(self, labels, train, force_drop_ids=None):
        use_dropout = self.dropout_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        embeddings = self.embedding_table(labels)
        return embeddings


#################################################################################
#                                 Core DiT Model                                #
#################################################################################

class DiTBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c, mask=None):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), mask)
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x


class FinalLayer(nn.Module):
    """
    The final layer of DiT.
    """
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


class DiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """
    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        num_classes=1000,
        learn_sigma=True,
    ):
        super().__init__()
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads

        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
        num_patches = self.x_embedder.num_patches
        # Will use fixed sin-cos embedding:
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)

        self.blocks = nn.ModuleList([
            DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
        ])
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize (and freeze) pos_embed by sin-cos embedding:
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        # Initialize label embedding table:
        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def unpatchify(self, x):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, C, H, W)
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
        return imgs

    def forward(self, x, t, y):
        """
        Forward pass of DiT.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N,) tensor of class labels
        """
        x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
        t = self.t_embedder(t)                   # (N, D)
        y = self.y_embedder(y, self.training)    # (N, D)
        c = t + y                                # (N, D)
        for block in self.blocks:
            x = block(x, c)                      # (N, T, D)
        x = self.final_layer(x, c)               # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x)                   # (N, out_channels, H, W)
        return x

    def forward_with_cfg(self, x, t, y, cfg_scale):
        """
        Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
        """
        # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.forward(combined, t, y)
        # The original DiT applies classifier-free guidance to only the first three
        # channels for exact reproducibility; here guidance is applied to all
        # self.in_channels channels. Swap the two lines below for the 3-channel variant.
        eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
        # eps, rest = model_out[:, :3], model_out[:, 3:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)


#################################################################################
#                   Sine/Cosine Positional Embedding Functions                  #
#################################################################################
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py

def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


#################################################################################
#                                   DiT Configs                                 #
#################################################################################

def DiT_XL_2(**kwargs):
    return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)

def DiT_XL_4(**kwargs):
    return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)

def DiT_XL_8(**kwargs):
    return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)

def DiT_L_2(**kwargs):
    return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)

def DiT_L_4(**kwargs):
    return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)

def DiT_L_8(**kwargs):
    return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)

def DiT_B_2(**kwargs):
    return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)

def DiT_B_4(**kwargs):
    return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)

def DiT_B_8(**kwargs):
    return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)

def DiT_S_2(**kwargs):
    return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)

def DiT_S_4(**kwargs):
    return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)

def DiT_S_8(**kwargs):
    return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)


DiT_models = {
    'DiT-XL/2': DiT_XL_2,  'DiT-XL/4': DiT_XL_4,  'DiT-XL/8': DiT_XL_8,
    'DiT-L/2':  DiT_L_2,   'DiT-L/4':  DiT_L_4,   'DiT-L/8':  DiT_L_8,
    'DiT-B/2':  DiT_B_2,   'DiT-B/4':  DiT_B_4,   'DiT-B/8':  DiT_B_8,
    'DiT-S/2':  DiT_S_2,   'DiT-S/4':  DiT_S_4,   'DiT-S/8':  DiT_S_8,
}
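As a quick shape sanity check for the configs above, a minimal sketch using the module defaults (32x32 latents with 4 channels):

import torch

model = DiT_B_2(input_size=32, in_channels=4, num_classes=1000)
x = torch.randn(2, 4, 32, 32)        # latent inputs
t = torch.randint(0, 1000, (2,))     # diffusion timesteps
y = torch.randint(0, 1000, (2,))     # class labels
out = model(x, t, y)                  # (2, 8, 32, 32): channels doubled since learn_sigma=True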
paintmind/stage1/fused_attention.py
ADDED
@@ -0,0 +1,94 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Type

class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.,
        proj_drop: float = 0.,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, attn_mask=None):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        x = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,  # in this API, True marks positions that are allowed to attend
            dropout_p=self.attn_drop.p if self.training else 0.,
        )

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class MultiHeadCrossAttention(nn.Module):
    def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        if d_model % num_heads != 0:
            raise AssertionError(
                "d_model (%d) must be divisible by num_heads (%d)"
                % (d_model, num_heads)
            )

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.q_linear = nn.Linear(d_model, d_model)
        self.kv_linear = nn.Linear(d_model, d_model * 2)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(d_model, d_model)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, cond, mask=None):
        # queries come from the image tokens; keys/values come from the condition;
        # mask gives the number of valid (non-padding) condition tokens per sample
        B, N, C = x.shape

        q = self.q_linear(x).view(-1, self.num_heads, self.head_dim)
        kv = self.kv_linear(cond).view(-1, 2, self.num_heads, self.head_dim)
        k, v = kv.unbind(1)

        q = q.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
        if mask is not None:
            temp_mask = torch.ones(B, 1, q.size(-2), k.size(-2), dtype=torch.bool, device=q.device)
            for i in range(B):
                temp_mask[i, :, :, mask[i]:] = False
            mask = temp_mask
        x = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=mask,
            dropout_p=self.attn_drop.p if self.training else 0.,
        ).transpose(1, 2)

        x = x.view(B, -1, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
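A short sketch of driving the fused Attention module, including the boolean mask convention noted above (True = attend); the sizes are arbitrary illustration values:

import torch

attn = Attention(dim=384, num_heads=6, qkv_bias=True)
x = torch.randn(2, 16, 384)                           # (batch, tokens, dim)
full = attn(x)                                        # unmasked self-attention
mask = torch.ones(16, 16, dtype=torch.bool).tril()    # e.g. a causal mask, broadcast over batch/heads
causal = attn(x, attn_mask=mask)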
paintmind/stage1/pos_embed.py
ADDED
@@ -0,0 +1,102 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Position embedding utils
# --------------------------------------------------------

import numpy as np

import torch

# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_1d_sincos_pos_embed(embed_dim, grid_size):
    grid = np.arange(grid_size, dtype=np.float32)
    pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
    if 'pos_embed' in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model['pos_embed'] = new_pos_embed
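For reference, a small sketch of the shapes these helpers return (values here are illustrative):

pos = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14)
print(pos.shape)       # (196, 768): one row per patch of a 14x14 grid

pos_cls = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14, cls_token=True)
print(pos_cls.shape)   # (197, 768): a zero row is prepended for the cls token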
paintmind/stage1/quantize.py
ADDED
@@ -0,0 +1,93 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

def l2norm(t):
    return F.normalize(t, p=2, dim=-1)

class VectorQuantizer(nn.Module):
    def __init__(self, n_e, e_dim, beta=0.25, use_norm=True):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta
        self.use_norm = use_norm

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.normal_()

    def forward(self, z):
        if self.use_norm:
            z = l2norm(z)
        z_flattened = z.view(-1, self.e_dim)
        if self.use_norm:
            embedd_norm = l2norm(self.embedding.weight)
        # distances from z to embeddings e_j: (z - e)^2 = z^2 + e^2 - 2 e * z

        if self.use_norm:
            d = 2 - 2 * torch.einsum('bc, nc -> bn', z_flattened, embedd_norm)
        else:
            d = torch.sum(z_flattened**2, dim=1, keepdim=True) + \
                torch.sum(self.embedding.weight**2, dim=1) - 2 * \
                torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight)

        encoding_indices = torch.argmin(d, dim=1).view(*z.shape[:-1])
        z_q = self.embedding(encoding_indices)
        if self.use_norm:
            z_q = l2norm(z_q)

        # compute loss for embedding
        loss = self.beta * torch.mean((z_q.detach()-z)**2) + torch.mean((z_q-z.detach())**2)

        # preserve gradients (straight-through estimator)
        z_q = z + (z_q - z).detach()

        return z_q, loss, encoding_indices

    def decode_from_indice(self, indices):
        z_q = self.embedding(indices)
        if self.use_norm:
            z_q = l2norm(z_q)

        return z_q

class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters  # [B, L, 2C], not [B, 2C, H, W]
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=2)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

    def sample(self):
        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.])
        else:
            if other is None:
                return 0.5 * torch.sum(torch.pow(self.mean, 2)
                                       + self.var - 1.0 - self.logvar,
                                       dim=[1, 2])
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
                    dim=[1, 2])

    def nll(self, sample, dims=[1, 2]):
        if self.deterministic:
            return torch.Tensor([0.])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        return self.mean
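A minimal sketch of the quantizer interface, with hypothetical sizes; note z_q keeps z's shape and carries gradients through the straight-through estimator:

import torch

vq = VectorQuantizer(n_e=1024, e_dim=16)
z = torch.randn(2, 32, 16)            # (batch, tokens, e_dim)
z_q, loss, indices = vq(z)            # indices is (2, 32)
decoded = vq.decode_from_indice(indices)   # look codes back up from indices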
paintmind/stage1/transport/__init__.py
ADDED
@@ -0,0 +1,63 @@
from .transport import Transport, ModelType, WeightType, PathType, Sampler

def create_transport(
    path_type='Linear',
    prediction="velocity",
    loss_weight=None,
    train_eps=None,
    sample_eps=None,
):
    """function for creating Transport object
    **Note**: model prediction defaults to velocity
    Args:
    - path_type: type of path to use; defaults to linear
    - prediction: what the model predicts ("velocity", "noise", or "score")
    - loss_weight: loss weighting scheme ("velocity", "likelihood", or None)
    - train_eps: small epsilon for avoiding instability during training
    - sample_eps: small epsilon for avoiding instability during sampling
    """

    if prediction == "noise":
        model_type = ModelType.NOISE
    elif prediction == "score":
        model_type = ModelType.SCORE
    else:
        model_type = ModelType.VELOCITY

    if loss_weight == "velocity":
        loss_type = WeightType.VELOCITY
    elif loss_weight == "likelihood":
        loss_type = WeightType.LIKELIHOOD
    else:
        loss_type = WeightType.NONE

    path_choice = {
        "Linear": PathType.LINEAR,
        "GVP": PathType.GVP,
        "VP": PathType.VP,
    }

    path_type = path_choice[path_type]

    if path_type in [PathType.VP]:
        train_eps = 1e-5 if train_eps is None else train_eps
        sample_eps = 1e-3 if sample_eps is None else sample_eps
    elif path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY:
        train_eps = 1e-3 if train_eps is None else train_eps
        sample_eps = 1e-3 if sample_eps is None else sample_eps
    else:  # velocity & [GVP, LINEAR] is stable everywhere
        train_eps = 0
        sample_eps = 0

    # create flow state
    state = Transport(
        model_type=model_type,
        path_type=path_type,
        loss_type=loss_type,
        train_eps=train_eps,
        sample_eps=sample_eps,
    )

    return state
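Putting the factory together with the Transport class defined below, a sketch of one training step; the lambda stands in for a real velocity network and is purely hypothetical:

import torch

transport = create_transport(path_type='Linear', prediction='velocity')
model = lambda x, t: torch.zeros_like(x)   # hypothetical stand-in for a velocity model
x1 = torch.randn(4, 8)                      # a batch of data points
terms = transport.training_losses(model, x1)
loss = terms['loss'].mean()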
paintmind/stage1/transport/integrators.py
ADDED
@@ -0,0 +1,130 @@
import numpy as np
import torch as th
import torch.nn as nn
from torchdiffeq import odeint
from functools import partial
from tqdm import tqdm

class sde:
    """SDE solver class"""
    def __init__(
        self,
        drift,
        diffusion,
        *,
        t0,
        t1,
        num_steps,
        sampler_type,
        temperature=1.0,
    ):
        assert t0 < t1, "SDE sampler has to be in forward time"

        self.num_timesteps = num_steps
        self.t = th.linspace(t0, t1, num_steps)
        self.dt = self.t[1] - self.t[0]
        self.drift = drift
        self.diffusion = diffusion
        self.sampler_type = sampler_type
        self.temperature = temperature

    def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs):
        w_cur = th.randn(x.size()).to(x)
        t = th.ones(x.size(0)).to(x) * t
        dw = w_cur * th.sqrt(self.dt)
        drift = self.drift(x, t, model, **model_kwargs)
        diffusion = self.diffusion(x, t)
        mean_x = x + drift * self.dt
        x = mean_x + th.sqrt(2 * diffusion) * dw * self.temperature
        return x, mean_x

    def __Heun_step(self, x, _, t, model, **model_kwargs):
        w_cur = th.randn(x.size()).to(x)
        dw = w_cur * th.sqrt(self.dt) * self.temperature
        t_cur = th.ones(x.size(0)).to(x) * t
        diffusion = self.diffusion(x, t_cur)
        xhat = x + th.sqrt(2 * diffusion) * dw
        K1 = self.drift(xhat, t_cur, model, **model_kwargs)
        xp = xhat + self.dt * K1
        K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs)
        return xhat + 0.5 * self.dt * (K1 + K2), xhat  # at the last time point we do not perform the Heun step

    def __forward_fn(self):
        """TODO: generalize here by adding all private functions ending with steps to it"""
        sampler_dict = {
            "Euler": self.__Euler_Maruyama_step,
            "Heun": self.__Heun_step,
        }

        try:
            sampler = sampler_dict[self.sampler_type]
        except KeyError:
            raise NotImplementedError("Sampler type not implemented.")

        return sampler

    def sample(self, init, model, **model_kwargs):
        """forward loop of sde"""
        x = init
        mean_x = init
        samples = []
        sampler = self.__forward_fn()
        for ti in self.t[:-1]:
            with th.no_grad():
                x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs)
                samples.append(x)

        return samples

class ode:
    """ODE solver class"""
    def __init__(
        self,
        drift,
        *,
        t0,
        t1,
        sampler_type,
        num_steps,
        atol,
        rtol,
        temperature=1.0,
    ):
        assert t0 < t1, "ODE sampler has to be in forward time"

        self.drift = drift
        self.t = th.linspace(t0, t1, num_steps)
        self.atol = atol
        self.rtol = rtol
        self.sampler_type = sampler_type
        self.temperature = temperature

    def sample(self, x, model, **model_kwargs):

        device = x[0].device if isinstance(x, tuple) else x.device
        def _fn(t, x):
            t = th.ones(x[0].size(0)).to(device) * t if isinstance(x, tuple) else th.ones(x.size(0)).to(device) * t
            # For ODE, we scale the drift by the temperature.
            # This is equivalent to scaling time by 1/temperature.
            model_output = self.drift(x, t, model, **model_kwargs)
            if self.temperature != 1.0:
                # If it's a tuple (for likelihood calculation), only scale the first element
                if isinstance(model_output, tuple):
                    scaled_output = (model_output[0] / self.temperature, model_output[1])
                    return scaled_output
                else:
                    return model_output / self.temperature
            return model_output

        t = self.t.to(device)
        atol = [self.atol] * len(x) if isinstance(x, tuple) else [self.atol]
        rtol = [self.rtol] * len(x) if isinstance(x, tuple) else [self.rtol]
        samples = odeint(
            _fn,
            x,
            t,
            method=self.sampler_type,
            atol=atol,
            rtol=rtol
        )
        return samples
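A sketch of driving the ode solver directly; in this repo the drift normally comes from Transport.get_drift(), so the standalone drift and toy model below are hypothetical placeholders:

import torch as th

def drift(x, t, model, **kwargs):
    return model(x, t)                     # hypothetical: the model already predicts a velocity

solver = ode(drift=drift, t0=0.0, t1=1.0, sampler_type="dopri5",
             num_steps=50, atol=1e-6, rtol=1e-3)
model = lambda x, t: -x                    # toy vector field for illustration
xs = solver.sample(th.randn(4, 8), model)  # states at the 50 requested time points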
paintmind/stage1/transport/path.py
ADDED
@@ -0,0 +1,192 @@
import torch as th
import numpy as np
from functools import partial

def expand_t_like_x(t, x):
    """Function to reshape time t to a dimension broadcastable with x
    Args:
      t: [batch_dim,], time vector
      x: [batch_dim,...], data point
    """
    dims = [1] * (len(x.size()) - 1)
    t = t.view(t.size(0), *dims)
    return t


#################### Coupling Plans ####################

class ICPlan:
    """Linear Coupling Plan"""
    def __init__(self, sigma=0.0):
        self.sigma = sigma

    def compute_alpha_t(self, t):
        """Compute the data coefficient along the path"""
        return t, 1

    def compute_sigma_t(self, t):
        """Compute the noise coefficient along the path"""
        return 1 - t, -1

    def compute_d_alpha_alpha_ratio_t(self, t):
        """Compute the ratio between d_alpha and alpha"""
        return 1 / t

    def compute_drift(self, x, t):
        """We always output the sde according to the score parametrization"""
        t = expand_t_like_x(t, x)
        alpha_ratio = self.compute_d_alpha_alpha_ratio_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        drift = alpha_ratio * x
        diffusion = alpha_ratio * (sigma_t ** 2) - sigma_t * d_sigma_t

        return -drift, diffusion

    def compute_diffusion(self, x, t, form="constant", norm=1.0):
        """Compute the diffusion term of the SDE
        Args:
          x: [batch_dim, ...], data point
          t: [batch_dim,], time vector
          form: str, form of the diffusion term
          norm: float, norm of the diffusion term
        """
        t = expand_t_like_x(t, x)
        choices = {
            "constant": norm,
            "SBDM": norm * self.compute_drift(x, t)[1],
            "sigma": norm * self.compute_sigma_t(t)[0],
            "linear": norm * (1 - t),
            "decreasing": 0.25 * (norm * th.cos(np.pi * t) + 1) ** 2,
            "inccreasing-decreasing": norm * th.sin(np.pi * t) ** 2,  # key spelling kept as-is for compatibility
        }

        try:
            diffusion = choices[form]
        except KeyError:
            raise NotImplementedError(f"Diffusion form {form} not implemented")

        return diffusion

    def get_score_from_velocity(self, velocity, x, t):
        """Wrapper function: transform a velocity prediction model to a score
        Args:
          velocity: [batch_dim, ...] shaped tensor; velocity model output
          x: [batch_dim, ...] shaped tensor; x_t data point
          t: [batch_dim,] time tensor
        """
        t = expand_t_like_x(t, x)
        alpha_t, d_alpha_t = self.compute_alpha_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        mean = x
        reverse_alpha_ratio = alpha_t / d_alpha_t
        var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
        score = (reverse_alpha_ratio * velocity - mean) / var
        return score

    def get_noise_from_velocity(self, velocity, x, t):
        """Wrapper function: transform a velocity prediction model to a denoiser
        Args:
          velocity: [batch_dim, ...] shaped tensor; velocity model output
          x: [batch_dim, ...] shaped tensor; x_t data point
          t: [batch_dim,] time tensor
        """
        t = expand_t_like_x(t, x)
        alpha_t, d_alpha_t = self.compute_alpha_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        mean = x
        reverse_alpha_ratio = alpha_t / d_alpha_t
        var = reverse_alpha_ratio * d_sigma_t - sigma_t
        noise = (reverse_alpha_ratio * velocity - mean) / var
        return noise

    def get_velocity_from_score(self, score, x, t):
        """Wrapper function: transform a score prediction model to a velocity
        Args:
          score: [batch_dim, ...] shaped tensor; score model output
          x: [batch_dim, ...] shaped tensor; x_t data point
          t: [batch_dim,] time tensor
        """
        t = expand_t_like_x(t, x)
        drift, var = self.compute_drift(x, t)
        velocity = var * score - drift
        return velocity

    def compute_mu_t(self, t, x0, x1):
        """Compute the mean of the time-dependent density p_t"""
        t = expand_t_like_x(t, x1)
        alpha_t, _ = self.compute_alpha_t(t)
        sigma_t, _ = self.compute_sigma_t(t)
        return alpha_t * x1 + sigma_t * x0

    def compute_xt(self, t, x0, x1):
        """Sample xt from the time-dependent density p_t; rng is required"""
        xt = self.compute_mu_t(t, x0, x1)
        return xt

    def compute_ut(self, t, x0, x1, xt):
        """Compute the vector field corresponding to p_t"""
        t = expand_t_like_x(t, x1)
        _, d_alpha_t = self.compute_alpha_t(t)
        _, d_sigma_t = self.compute_sigma_t(t)
        return d_alpha_t * x1 + d_sigma_t * x0

    def plan(self, t, x0, x1):
        xt = self.compute_xt(t, x0, x1)
        ut = self.compute_ut(t, x0, x1, xt)
        return t, xt, ut


class VPCPlan(ICPlan):
    """class for VP path flow matching"""

    def __init__(self, sigma_min=0.1, sigma_max=20.0):
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.log_mean_coeff = lambda t: -0.25 * ((1 - t) ** 2) * (self.sigma_max - self.sigma_min) - 0.5 * (1 - t) * self.sigma_min
        self.d_log_mean_coeff = lambda t: 0.5 * (1 - t) * (self.sigma_max - self.sigma_min) + 0.5 * self.sigma_min

    def compute_alpha_t(self, t):
        """Compute coefficient of x1"""
        alpha_t = self.log_mean_coeff(t)
        alpha_t = th.exp(alpha_t)
        d_alpha_t = alpha_t * self.d_log_mean_coeff(t)
        return alpha_t, d_alpha_t

    def compute_sigma_t(self, t):
        """Compute coefficient of x0"""
        p_sigma_t = 2 * self.log_mean_coeff(t)
        sigma_t = th.sqrt(1 - th.exp(p_sigma_t))
        d_sigma_t = th.exp(p_sigma_t) * (2 * self.d_log_mean_coeff(t)) / (-2 * sigma_t)
        return sigma_t, d_sigma_t

    def compute_d_alpha_alpha_ratio_t(self, t):
        """Special-purpose function for computing a numerically stable d_alpha_t / alpha_t"""
        return self.d_log_mean_coeff(t)

    def compute_drift(self, x, t):
        """Compute the drift term of the SDE"""
        t = expand_t_like_x(t, x)
        beta_t = self.sigma_min + (1 - t) * (self.sigma_max - self.sigma_min)
        return -0.5 * beta_t * x, beta_t / 2


class GVPCPlan(ICPlan):
    def __init__(self, sigma=0.0):
        super().__init__(sigma)

    def compute_alpha_t(self, t):
        """Compute coefficient of x1"""
        alpha_t = th.sin(t * np.pi / 2)
        d_alpha_t = np.pi / 2 * th.cos(t * np.pi / 2)
        return alpha_t, d_alpha_t

    def compute_sigma_t(self, t):
        """Compute coefficient of x0"""
        sigma_t = th.cos(t * np.pi / 2)
        d_sigma_t = -np.pi / 2 * th.sin(t * np.pi / 2)
        return sigma_t, d_sigma_t

    def compute_d_alpha_alpha_ratio_t(self, t):
        """Special-purpose function for computing a numerically stable d_alpha_t / alpha_t"""
        return np.pi / (2 * th.tan(t * np.pi / 2))
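A quick numerical check of the linear plan above: for ICPlan, x_t = t * x1 + (1 - t) * x0 and the target vector field is u_t = x1 - x0 (since d_alpha_t = 1 and d_sigma_t = -1), which the sketch below verifies:

import torch as th

plan = ICPlan()
x0, x1 = th.randn(4, 8), th.randn(4, 8)   # noise and data
t = th.rand(4)
t, xt, ut = plan.plan(t, x0, x1)
assert th.allclose(ut, x1 - x0)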
paintmind/stage1/transport/transport.py
ADDED
@@ -0,0 +1,456 @@
1 |
+
import torch as th
|
2 |
+
import numpy as np
|
3 |
+
import logging
|
4 |
+
|
5 |
+
import enum
|
6 |
+
|
7 |
+
from . import path
|
8 |
+
from .utils import EasyDict, log_state, mean_flat
|
9 |
+
from .integrators import ode, sde
|
10 |
+
|
11 |
+
class ModelType(enum.Enum):
|
12 |
+
"""
|
13 |
+
Which type of output the model predicts.
|
14 |
+
"""
|
15 |
+
|
16 |
+
NOISE = enum.auto() # the model predicts epsilon
|
17 |
+
SCORE = enum.auto() # the model predicts \nabla \log p(x)
|
18 |
+
VELOCITY = enum.auto() # the model predicts v(x)
|
19 |
+
|
20 |
+
class PathType(enum.Enum):
|
21 |
+
"""
|
22 |
+
Which type of path to use.
|
23 |
+
"""
|
24 |
+
|
25 |
+
LINEAR = enum.auto()
|
26 |
+
GVP = enum.auto()
|
27 |
+
VP = enum.auto()
|
28 |
+
|
29 |
+
class WeightType(enum.Enum):
|
30 |
+
"""
|
31 |
+
Which type of weighting to use.
|
32 |
+
"""
|
33 |
+
|
34 |
+
NONE = enum.auto()
|
35 |
+
VELOCITY = enum.auto()
|
36 |
+
LIKELIHOOD = enum.auto()
|
37 |
+
|
38 |
+
|
39 |
+
class Transport:
|
40 |
+
|
41 |
+
def __init__(
|
42 |
+
self,
|
43 |
+
*,
|
44 |
+
model_type,
|
45 |
+
path_type,
|
46 |
+
loss_type,
|
47 |
+
train_eps,
|
48 |
+
sample_eps,
|
49 |
+
):
|
50 |
+
path_options = {
|
51 |
+
PathType.LINEAR: path.ICPlan,
|
52 |
+
PathType.GVP: path.GVPCPlan,
|
53 |
+
PathType.VP: path.VPCPlan,
|
54 |
+
}
|
55 |
+
|
56 |
+
self.loss_type = loss_type
|
57 |
+
self.model_type = model_type
|
58 |
+
self.path_sampler = path_options[path_type]()
|
59 |
+
self.train_eps = train_eps
|
60 |
+
self.sample_eps = sample_eps
|
61 |
+
|
62 |
+
def prior_logp(self, z):
|
63 |
+
'''
|
64 |
+
Standard multivariate normal prior
|
65 |
+
Assume z is batched
|
66 |
+
'''
|
67 |
+
shape = th.tensor(z.size())
|
68 |
+
N = th.prod(shape[1:])
|
69 |
+
_fn = lambda x: -N / 2. * np.log(2 * np.pi) - th.sum(x ** 2) / 2.
|
70 |
+
return th.vmap(_fn)(z)
|
71 |
+
|
72 |
+
|
73 |
+
def check_interval(
|
74 |
+
self,
|
75 |
+
train_eps,
|
76 |
+
sample_eps,
|
77 |
+
*,
|
78 |
+
diffusion_form="SBDM",
|
79 |
+
sde=False,
|
80 |
+
reverse=False,
|
81 |
+
eval=False,
|
82 |
+
last_step_size=0.0,
|
83 |
+
):
|
84 |
+
t0 = 0
|
85 |
+
t1 = 1
|
86 |
+
eps = train_eps if not eval else sample_eps
|
87 |
+
if (type(self.path_sampler) in [path.VPCPlan]):
|
88 |
+
|
89 |
+
t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size
|
90 |
+
|
91 |
+
elif (type(self.path_sampler) in [path.ICPlan, path.GVPCPlan]) \
|
92 |
+
and (self.model_type != ModelType.VELOCITY or sde): # avoid numerical issue by taking a first semi-implicit step
|
93 |
+
|
94 |
+
t0 = eps if (diffusion_form == "SBDM" and sde) or self.model_type != ModelType.VELOCITY else 0
|
95 |
+
t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size
|
96 |
+
|
97 |
+
if reverse:
|
98 |
+
t0, t1 = 1 - t0, 1 - t1
|
99 |
+
|
100 |
+
return t0, t1
|
101 |
+
|
102 |
+
|
103 |
+
def sample(self, x1):
|
104 |
+
"""Sampling x0 & t based on shape of x1 (if needed)
|
105 |
+
Args:
|
106 |
+
x1 - data point; [batch, *dim]
|
107 |
+
"""
|
108 |
+
|
109 |
+
x0 = th.randn_like(x1)
|
110 |
+
t0, t1 = self.check_interval(self.train_eps, self.sample_eps)
|
111 |
+
t = th.rand((x1.shape[0],)) * (t1 - t0) + t0
|
112 |
+
t = t.to(x1)
|
113 |
+
return t, x0, x1
|
114 |
+
|
115 |
+
|
    def training_losses(
        self,
        model,
        x1,
        model_kwargs=None
    ):
        """Loss for training the score model
        Args:
        - model: backbone model; could be score, noise, or velocity
        - x1: datapoint
        - model_kwargs: additional arguments for the model
        """
        if model_kwargs is None:
            model_kwargs = {}

        t, x0, x1 = self.sample(x1)
        t, xt, ut = self.path_sampler.plan(t, x0, x1)
        model_output = model(xt, t, **model_kwargs)
        if len(model_output.shape) == len(xt.shape) + 1:
            x0 = x0.unsqueeze(-1).expand(*([-1] * (len(x0.shape))), model_output.shape[-1])
            xt = xt.unsqueeze(-1).expand(*([-1] * (len(xt.shape))), model_output.shape[-1])
            ut = ut.unsqueeze(-1).expand(*([-1] * (len(ut.shape))), model_output.shape[-1])
        B, C = xt.shape[:2]
        assert model_output.shape == (B, C, *xt.shape[2:])

        terms = {}
        terms['pred'] = model_output
        if self.model_type == ModelType.VELOCITY:
            terms['loss'] = mean_flat(((model_output - ut) ** 2))
        else:
            _, drift_var = self.path_sampler.compute_drift(xt, t)
            sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, xt))
            if self.loss_type in [WeightType.VELOCITY]:
                weight = (drift_var / sigma_t) ** 2
            elif self.loss_type in [WeightType.LIKELIHOOD]:
                weight = drift_var / (sigma_t ** 2)
            elif self.loss_type in [WeightType.NONE]:
                weight = 1
            else:
                raise NotImplementedError()

            if self.model_type == ModelType.NOISE:
                terms['loss'] = mean_flat(weight * ((model_output - x0) ** 2))
            else:
                terms['loss'] = mean_flat(weight * ((model_output * sigma_t + x0) ** 2))

        return terms

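    # For reference: a VELOCITY model regresses the interpolant velocity u_t directly,
    # while NOISE/SCORE models reuse the same interpolant through the reweightings
    # above; the conversion score(x_t) = model_output / -sigma_t used below follows
    # from x_t = alpha_t * x1 + sigma_t * x0 with x0 the noise endpoint.
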
    def get_drift(
        self
    ):
        """member function for obtaining the drift of the probability flow ODE"""
        def score_ode(x, t, model, **model_kwargs):
            drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
            model_output = model(x, t, **model_kwargs)
            return (-drift_mean + drift_var * model_output)  # by change of variable

        def noise_ode(x, t, model, **model_kwargs):
            drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
            sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))
            model_output = model(x, t, **model_kwargs)
            score = model_output / -sigma_t
            return (-drift_mean + drift_var * score)

        def velocity_ode(x, t, model, **model_kwargs):
            model_output = model(x, t, **model_kwargs)
            return model_output

        if self.model_type == ModelType.NOISE:
            drift_fn = noise_ode
        elif self.model_type == ModelType.SCORE:
            drift_fn = score_ode
        else:
            drift_fn = velocity_ode

        def body_fn(x, t, model, **model_kwargs):
            model_output = drift_fn(x, t, model, **model_kwargs)
            assert model_output.shape == x.shape, "Output shape from ODE solver must match input shape"
            return model_output

        return body_fn

    def get_score(
        self,
    ):
        """member function for obtaining score of
            x_t = alpha_t * x + sigma_t * eps"""
        if self.model_type == ModelType.NOISE:
            score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs) / -self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0]
        elif self.model_type == ModelType.SCORE:
            score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs)
        elif self.model_type == ModelType.VELOCITY:
            score_fn = lambda x, t, model, **kwargs: self.path_sampler.get_score_from_velocity(model(x, t, **kwargs), x, t)
        else:
            raise NotImplementedError()

        return score_fn

+
class Sampler:
|
218 |
+
"""Sampler class for the transport model"""
|
219 |
+
def __init__(
|
220 |
+
self,
|
221 |
+
transport,
|
222 |
+
):
|
223 |
+
"""Constructor for a general sampler; supporting different sampling methods
|
224 |
+
Args:
|
225 |
+
- transport: an tranport object specify model prediction & interpolant type
|
226 |
+
"""
|
227 |
+
|
228 |
+
self.transport = transport
|
229 |
+
self.drift = self.transport.get_drift()
|
230 |
+
self.score = self.transport.get_score()
|
231 |
+
|
232 |
+
def __get_sde_diffusion_and_drift(
|
233 |
+
self,
|
234 |
+
*,
|
235 |
+
diffusion_form="SBDM",
|
236 |
+
diffusion_norm=1.0,
|
237 |
+
):
|
238 |
+
|
239 |
+
def diffusion_fn(x, t):
|
240 |
+
diffusion = self.transport.path_sampler.compute_diffusion(x, t, form=diffusion_form, norm=diffusion_norm)
|
241 |
+
return diffusion
|
242 |
+
|
243 |
+
sde_drift = \
|
244 |
+
lambda x, t, model, **kwargs: \
|
245 |
+
self.drift(x, t, model, **kwargs) + diffusion_fn(x, t) * self.score(x, t, model, **kwargs)
|
246 |
+
|
247 |
+
sde_diffusion = diffusion_fn
|
248 |
+
|
249 |
+
return sde_drift, sde_diffusion
|
250 |
+
|
251 |
+
def __get_last_step(
|
252 |
+
self,
|
253 |
+
sde_drift,
|
254 |
+
*,
|
255 |
+
last_step,
|
256 |
+
last_step_size,
|
257 |
+
):
|
258 |
+
"""Get the last step function of the SDE solver"""
|
259 |
+
|
260 |
+
if last_step is None:
|
261 |
+
last_step_fn = \
|
262 |
+
lambda x, t, model, **model_kwargs: \
|
263 |
+
x
|
264 |
+
elif last_step == "Mean":
|
265 |
+
last_step_fn = \
|
266 |
+
lambda x, t, model, **model_kwargs: \
|
267 |
+
x + sde_drift(x, t, model, **model_kwargs) * last_step_size
|
268 |
+
elif last_step == "Tweedie":
|
269 |
+
alpha = self.transport.path_sampler.compute_alpha_t # simple aliasing; the original name was too long
|
270 |
+
sigma = self.transport.path_sampler.compute_sigma_t
|
271 |
+
last_step_fn = \
|
272 |
+
lambda x, t, model, **model_kwargs: \
|
273 |
+
x / alpha(t)[0][0] + (sigma(t)[0][0] ** 2) / alpha(t)[0][0] * self.score(x, t, model, **model_kwargs)
|
274 |
+
elif last_step == "Euler":
|
275 |
+
last_step_fn = \
|
276 |
+
lambda x, t, model, **model_kwargs: \
|
277 |
+
x + self.drift(x, t, model, **model_kwargs) * last_step_size
|
278 |
+
else:
|
279 |
+
raise NotImplementedError()
|
280 |
+
|
281 |
+
return last_step_fn
|
282 |
+
|
283 |
+
def sample_sde(
|
284 |
+
self,
|
285 |
+
*,
|
286 |
+
sampling_method="Euler",
|
287 |
+
diffusion_form="SBDM",
|
288 |
+
diffusion_norm=1.0,
|
289 |
+
last_step="Mean",
|
290 |
+
last_step_size=0.04,
|
291 |
+
num_steps=250,
|
292 |
+
temperature=1.0,
|
293 |
+
):
|
294 |
+
"""returns a sampling function with given SDE settings
|
295 |
+
Args:
|
296 |
+
- sampling_method: type of sampler used in solving the SDE; default to be Euler-Maruyama
|
297 |
+
- diffusion_form: function form of diffusion coefficient; default to be matching SBDM
|
298 |
+
- diffusion_norm: function magnitude of diffusion coefficient; default to 1
|
299 |
+
- last_step: type of the last step; default to identity
|
300 |
+
- last_step_size: size of the last step; default to match the stride of 250 steps over [0,1]
|
301 |
+
- num_steps: total integration step of SDE
|
302 |
+
- temperature: temperature scaling for the noise during sampling; default to 1.0
|
303 |
+
"""
|
304 |
+
|
305 |
+
if last_step is None:
|
306 |
+
last_step_size = 0.0
|
307 |
+
|
308 |
+
sde_drift, sde_diffusion = self.__get_sde_diffusion_and_drift(
|
309 |
+
diffusion_form=diffusion_form,
|
310 |
+
diffusion_norm=diffusion_norm,
|
311 |
+
)
|
312 |
+
|
313 |
+
t0, t1 = self.transport.check_interval(
|
314 |
+
self.transport.train_eps,
|
315 |
+
self.transport.sample_eps,
|
316 |
+
diffusion_form=diffusion_form,
|
317 |
+
sde=True,
|
318 |
+
eval=True,
|
319 |
+
reverse=False,
|
320 |
+
last_step_size=last_step_size,
|
321 |
+
)
|
322 |
+
|
323 |
+
_sde = sde(
|
324 |
+
sde_drift,
|
325 |
+
sde_diffusion,
|
326 |
+
t0=t0,
|
327 |
+
t1=t1,
|
328 |
+
num_steps=num_steps,
|
329 |
+
sampler_type=sampling_method,
|
330 |
+
temperature=temperature
|
331 |
+
)
|
332 |
+
|
333 |
+
last_step_fn = self.__get_last_step(sde_drift, last_step=last_step, last_step_size=last_step_size)
|
334 |
+
|
335 |
+
|
336 |
+
def _sample(init, model, **model_kwargs):
|
337 |
+
xs = _sde.sample(init, model, **model_kwargs)
|
338 |
+
ts = th.ones(init.size(0), device=init.device) * t1
|
339 |
+
x = last_step_fn(xs[-1], ts, model, **model_kwargs)
|
340 |
+
xs.append(x)
|
341 |
+
|
342 |
+
assert len(xs) == num_steps, "Samples does not match the number of steps"
|
343 |
+
|
344 |
+
return xs
|
345 |
+
|
346 |
+
return _sample
|
347 |
+
|
348 |
+
def sample_ode(
|
349 |
+
self,
|
350 |
+
*,
|
351 |
+
sampling_method="dopri5",
|
352 |
+
num_steps=50,
|
353 |
+
atol=1e-6,
|
354 |
+
rtol=1e-3,
|
355 |
+
reverse=False,
|
356 |
+
temperature=1.0,
|
357 |
+
):
|
358 |
+
"""returns a sampling function with given ODE settings
|
359 |
+
Args:
|
360 |
+
- sampling_method: type of sampler used in solving the ODE; default to be Dopri5
|
361 |
+
- num_steps:
|
362 |
+
- fixed solver (Euler, Heun): the actual number of integration steps performed
|
363 |
+
- adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
|
364 |
+
- atol: absolute error tolerance for the solver
|
365 |
+
- rtol: relative error tolerance for the solver
|
366 |
+
- reverse: whether solving the ODE in reverse (data to noise); default to False
|
367 |
+
- temperature: temperature scaling for the drift during sampling; default to 1.0
|
368 |
+
"""
|
369 |
+
if reverse:
|
370 |
+
drift = lambda x, t, model, **kwargs: self.drift(x, th.ones_like(t) * (1 - t), model, **kwargs)
|
371 |
+
else:
|
372 |
+
drift = self.drift
|
373 |
+
|
374 |
+
t0, t1 = self.transport.check_interval(
|
375 |
+
self.transport.train_eps,
|
376 |
+
self.transport.sample_eps,
|
377 |
+
sde=False,
|
378 |
+
eval=True,
|
379 |
+
reverse=reverse,
|
380 |
+
last_step_size=0.0,
|
381 |
+
)
|
382 |
+
|
383 |
+
_ode = ode(
|
384 |
+
drift=drift,
|
385 |
+
t0=t0,
|
386 |
+
t1=t1,
|
387 |
+
sampler_type=sampling_method,
|
388 |
+
num_steps=num_steps,
|
389 |
+
atol=atol,
|
390 |
+
rtol=rtol,
|
391 |
+
temperature=temperature,
|
392 |
+
)
|
393 |
+
|
394 |
+
return _ode.sample
|
395 |
+
|
396 |
+
def sample_ode_likelihood(
|
397 |
+
self,
|
398 |
+
*,
|
399 |
+
sampling_method="dopri5",
|
400 |
+
num_steps=50,
|
401 |
+
atol=1e-6,
|
402 |
+
rtol=1e-3,
|
403 |
+
temperature=1.0,
|
404 |
+
):
|
405 |
+
|
406 |
+
"""returns a sampling function for calculating likelihood with given ODE settings
|
407 |
+
Args:
|
408 |
+
- sampling_method: type of sampler used in solving the ODE; default to be Dopri5
|
409 |
+
- num_steps:
|
410 |
+
- fixed solver (Euler, Heun): the actual number of integration steps performed
|
411 |
+
- adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
|
412 |
+
- atol: absolute error tolerance for the solver
|
413 |
+
- rtol: relative error tolerance for the solver
|
414 |
+
- temperature: temperature scaling for the drift during sampling; default to 1.0
|
415 |
+
"""
|
416 |
+
def _likelihood_drift(x, t, model, **model_kwargs):
|
417 |
+
x, _ = x
|
418 |
+
eps = th.randint(2, x.size(), dtype=th.float, device=x.device) * 2 - 1
|
419 |
+
t = th.ones_like(t) * (1 - t)
|
420 |
+
with th.enable_grad():
|
421 |
+
x.requires_grad = True
|
422 |
+
grad = th.autograd.grad(th.sum(self.drift(x, t, model, **model_kwargs) * eps), x)[0]
|
423 |
+
logp_grad = th.sum(grad * eps, dim=tuple(range(1, len(x.size()))))
|
424 |
+
drift = self.drift(x, t, model, **model_kwargs)
|
425 |
+
return (-drift, logp_grad)
|
426 |
+
|
427 |
+
t0, t1 = self.transport.check_interval(
|
428 |
+
self.transport.train_eps,
|
429 |
+
self.transport.sample_eps,
|
430 |
+
sde=False,
|
431 |
+
eval=True,
|
432 |
+
reverse=False,
|
433 |
+
last_step_size=0.0,
|
434 |
+
)
|
435 |
+
|
436 |
+
_ode = ode(
|
437 |
+
drift=_likelihood_drift,
|
438 |
+
t0=t0,
|
439 |
+
t1=t1,
|
440 |
+
sampler_type=sampling_method,
|
441 |
+
num_steps=num_steps,
|
442 |
+
atol=atol,
|
443 |
+
rtol=rtol,
|
444 |
+
temperature=temperature,
|
445 |
+
)
|
446 |
+
|
447 |
+
def _sample_fn(x, model, **model_kwargs):
|
448 |
+
init_logp = th.zeros(x.size(0)).to(x)
|
449 |
+
input = (x, init_logp)
|
450 |
+
drift, delta_logp = _ode.sample(input, model, **model_kwargs)
|
451 |
+
drift, delta_logp = drift[-1], delta_logp[-1]
|
452 |
+
prior_logp = self.transport.prior_logp(drift)
|
453 |
+
logp = prior_logp - delta_logp
|
454 |
+
return logp, drift
|
455 |
+
|
456 |
+
return _sample_fn
|
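For orientation, a minimal usage sketch of the pieces above. It is hedged: `create_transport` is the factory exported by this package's `__init__.py` (imported the same way by `stage2/causaldit.py`), and the zero `net` below is a placeholder, not a real velocity network:

import torch as th
from paintmind.stage1.transport import create_transport, Sampler

transport = create_transport()                    # defaults: linear path, velocity prediction
net = lambda x, t: th.zeros_like(x)               # placeholder backbone, illustration only

x1 = th.randn(8, 4, 16, 16)                       # clean data batch
loss = transport.training_losses(net, x1)["loss"].mean()

sampler = Sampler(transport)
sample_fn = sampler.sample_ode(num_steps=50)      # adaptive dopri5 by default
samples = sample_fn(th.randn_like(x1), net)[-1]   # last entry of the trajectory is the sample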
paintmind/stage1/transport/utils.py
ADDED
@@ -0,0 +1,29 @@
import torch as th

class EasyDict:

    def __init__(self, sub_dict):
        for k, v in sub_dict.items():
            setattr(self, k, v)

    def __getitem__(self, key):
        return getattr(self, key)

def mean_flat(x):
    """
    Take the mean over all non-batch dimensions.
    """
    return th.mean(x, dim=list(range(1, len(x.size()))))

def log_state(state):
    result = []

    sorted_state = dict(sorted(state.items()))
    for key, value in sorted_state.items():
        # Check if the value is an instance of a class
        if "<object" in str(value) or "object at" in str(value):
            result.append(f"{key}: [{value.__class__.__name__}]")
        else:
            result.append(f"{key}: {value}")

    return '\n'.join(result)
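Quick usage of these helpers (a sketch; nothing beyond the functions defined above is assumed):

import torch as th
from paintmind.stage1.transport.utils import EasyDict, mean_flat

opts = EasyDict({"lr": 1e-4, "steps": 100})
assert opts.lr == opts["lr"] == 1e-4            # attribute and item access are aliases

x = th.randn(8, 3, 4, 4)
assert mean_flat(x).shape == (8,)               # one scalar per batch element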
paintmind/stage1/vision_transformers.py
ADDED
@@ -0,0 +1,267 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Mostly copy-paste from DINO and timm library:
https://github.com/facebookresearch/dino
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from functools import partial
from paintmind.stage1.fused_attention import Attention

__all__ = ['VisionTransformer', 'vit_tiny_patch16', 'vit_small_patch16',
           'vit_base_patch16', 'vit_large_patch16', 'vit_huge_patch14']


def drop_path(x, drop_prob: float = 0., training: bool = False):
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        random_tensor.div_(keep_prob)
    return x * random_tensor


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0.,
                 attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, init_values=0):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if init_values > 0:
            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x, attn_mask=None):
        y = self.attn(self.norm1(x), attn_mask=attn_mask)
        if self.gamma_1 is None:
            x = x + self.drop_path(y)
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path(self.gamma_1 * y)
            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        num_patches = (img_size // patch_size) * (img_size // patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        return self.proj(x)


class VisionTransformer(nn.Module):
    """ Vision Transformer """

    def __init__(self, img_size=[224], patch_size=16, in_chans=3, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),
                 init_values=0, num_slots=16, slot_through=True):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim

        self.patch_embed = PatchEmbed(
            img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.num_slots = num_slots if slot_through else 0
        self.slot_through = slot_through

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1 + self.num_slots, embed_dim))
        if self.slot_through:
            self.slot_embed = nn.Parameter(torch.zeros(1, num_slots, embed_dim))

        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                init_values=init_values)
            for i in range(depth)])

        self.norm = norm_layer(embed_dim)

        nn.init.trunc_normal_(self.pos_embed, std=.02)
        nn.init.trunc_normal_(self.cls_token, std=.02)
        if self.slot_through:
            nn.init.trunc_normal_(self.slot_embed, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def interpolate_pos_encoding(self, x, w, h):
        npatch = x.shape[1] - 1 - self.num_slots
        N = self.pos_embed.shape[1] - 1 - self.num_slots
        if npatch == N and w == h:
            return self.pos_embed
        class_pos_embed = self.pos_embed[:, 0]
        patch_pos_embed = self.pos_embed[:, 1:1+npatch]
        dim = x.shape[-1]
        # patch_size is stored as an int in this PatchEmbed, so no [0]/[1] indexing
        w0 = w // self.patch_embed.patch_size
        h0 = h // self.patch_embed.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode='bicubic',
        )
        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        if self.slot_through:
            slots_pos_embed = self.pos_embed[:, 1+npatch:]
            slots_pos_embed = slots_pos_embed.view(1, -1, dim)  # (1, num_slots, dim)
            return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed, slots_pos_embed), dim=1)

        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def prepare_tokens(self, x):
        B, nc, w, h = x.shape
        x = self.patch_embed(x)
        x = x.flatten(2).transpose(1, 2)
        if self.slot_through:
            x = torch.cat((self.cls_token.expand(B, -1, -1), x, self.slot_embed.expand(B, -1, -1)), dim=1)
        else:
            x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1)
        x = x + self.interpolate_pos_encoding(x, w, h)
        return self.pos_drop(x)

    def forward(self, x, is_causal=True):
        x = self.prepare_tokens(x)
        if is_causal and self.slot_through:
            attn_mask = torch.ones(x.shape[1], x.shape[1], device=x.device, dtype=torch.bool)
            # slots are causal to each other
            causal_mask = torch.ones(self.num_slots, self.num_slots, device=x.device, dtype=torch.bool).tril(diagonal=0)
            attn_mask[-self.num_slots:, -self.num_slots:] = causal_mask
            # cls token and patches should not see slots
            attn_mask[:-self.num_slots, -self.num_slots:] = False
        else:
            attn_mask = None

        for blk in self.blocks:
            x = blk(x, attn_mask=attn_mask)

        x = self.norm(x)
        if self.slot_through:
            outcome = x[:, -self.num_slots:]  # return the slots
        else:
            outcome = x[:, 1:]  # return the patches
        return outcome

    def get_intermediate_layers(self, x, n=1):
        x = self.prepare_tokens(x)
        # we return the output tokens from the `n` last blocks
        output = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if len(self.blocks) - i <= n:
                output.append(self.norm(x))
        return output


def vit_tiny_patch16(**kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def vit_small_patch16(**kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def vit_base_patch16(**kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def vit_large_patch16(**kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def vit_huge_patch14(**kwargs):
    model = VisionTransformer(
        patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model
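A minimal sketch of the slot encoder above (shapes assume the defaults; it relies on the repo's fused `Attention` honoring the boolean `attn_mask`, and the batch size and 224 input are illustrative only):

import torch
from paintmind.stage1.vision_transformers import vit_base_patch16

encoder = vit_base_patch16(img_size=[224], num_slots=16, slot_through=True)
images = torch.randn(2, 3, 224, 224)
slots = encoder(images, is_causal=True)   # (2, 16, 768): an ordered slot sequence per image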
paintmind/stage2/__init__.py
ADDED
File without changes
paintmind/stage2/causaldit.py
ADDED
@@ -0,0 +1,422 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.models.vision_transformer import Mlp
from paintmind.stage1.diffusion_transfomers import TimestepEmbedder, LabelEmbedder, FinalLayer, modulate
from paintmind.stage1.diffusion import create_diffusion
from paintmind.stage1.transport import create_transport, Sampler


class GeneralizedCausalAttention(nn.Module):
    def __init__(self, dim, num_heads, norm_layer=nn.LayerNorm):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.proj = nn.Linear(dim, dim)
        self.q_norm = norm_layer(self.head_dim)
        self.k_norm = norm_layer(self.head_dim)

    def _forward_kv_cache(
        self,
        x: torch.Tensor,
        layer_index: int,
        kv_cache: dict,
        update_kv_cache: bool = False,
    ):
        N, Lq = x.shape[:2]
        qkv = self.qkv(x).reshape(N, Lq, 3, self.num_heads, self.head_dim)
        q, curr_k, curr_v = qkv.permute(2, 0, 3, 1, 4).unbind(0)  # N, nhead, Lq, dhead
        q = self.q_norm(q)
        curr_k = self.k_norm(curr_k)

        if kv_cache[layer_index]["k"] is not None:
            k = kv_cache[layer_index]["k"]
            v = kv_cache[layer_index]["v"]
            k = torch.cat((k, curr_k), dim=2)
            v = torch.cat((v, curr_v), dim=2)
        else:
            k = curr_k
            v = curr_v

        if update_kv_cache:
            kv_cache[layer_index]["k"] = k
            kv_cache[layer_index]["v"] = v

        return self._forward_sdpa(q, k, v, attn_mask=None)

    def _forward(self, x, attn_mask):
        N, L = x.shape[:2]
        qkv = self.qkv(x).reshape(N, L, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)  # N, nhead, L, dhead
        q = self.q_norm(q)
        k = self.k_norm(k)
        return self._forward_sdpa(q, k, v, attn_mask)

    def _forward_sdpa(self, q, k, v, attn_mask):
        N, _, Lq, _ = q.shape
        x = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
        )

        x = x.transpose(1, 2).reshape(N, Lq, -1)
        x = self.proj(x)
        return x

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[dict] = None,
        layer_index: Optional[int] = None,
        update_kv_cache: Optional[bool] = None,
    ) -> torch.Tensor:
        if kv_cache is not None:
            return self._forward_kv_cache(
                x,
                kv_cache=kv_cache,
                layer_index=layer_index,
                update_kv_cache=update_kv_cache
            )
        else:
            return self._forward(x, attn_mask)


class DiTBlock(nn.Module):
    def __init__(
        self,
        hidden_size,
        num_heads,
        mlp_ratio=4.0,
        norm_layer=nn.LayerNorm,
        causal_fusion=False,
        deep_supervision=False,
        output_dim=None,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = GeneralizedCausalAttention(hidden_size, num_heads=num_heads, norm_layer=norm_layer)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
        self.causal_fusion = causal_fusion
        if not causal_fusion:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(hidden_size, 6 * hidden_size, bias=True)
            )
        self.deep_supervision = deep_supervision
        if deep_supervision:
            if not causal_fusion:
                self.final_layer = FinalLayer(hidden_size, 1, output_dim)
            else:
                self.final_layer = nn.Sequential(
                    nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6),
                    nn.Linear(hidden_size, output_dim)
                )

    def forward_causal_dit(self, x, c, **kwargs):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), **kwargs)
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        if self.deep_supervision and self.training:
            return x, self.final_layer(x, c)
        else:
            return x

    def forward_causal_fusion(self, x, **kwargs):
        x = x + self.attn(self.norm1(x), **kwargs)
        x = x + self.mlp(self.norm2(x))
        if self.deep_supervision and self.training:
            return x, self.final_layer(x)
        else:
            return x

    def forward(self, x, c=None, **kwargs):
        if self.causal_fusion:
            return self.forward_causal_fusion(x, **kwargs)
        else:
            return self.forward_causal_dit(x, c, **kwargs)

class CausalDiT(nn.Module):
    def __init__(
        self,
        num_slots=16,
        slot_dim=256,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        num_classes=1000,
        num_sampling_steps='250',
        use_si=False,
        predict_xstart=False,
        causal_fusion=False,
        deep_supervision=False,
        cls_token_num=0,
        **kwargs
    ):
        super().__init__()
        self.num_slots = num_slots
        self.slot_dim = slot_dim
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.num_classes = num_classes
        self.output_dim = slot_dim * 2 if not use_si else slot_dim
        self.x_embedder = nn.Linear(slot_dim, hidden_size)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_slots, hidden_size))
        blocks = [
            DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, causal_fusion=causal_fusion,
                     deep_supervision=deep_supervision, output_dim=self.output_dim) for _ in range(depth - 1)
        ]
        blocks.append(DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, causal_fusion=causal_fusion))
        self.blocks = nn.ModuleList(blocks)
        self.cls_token_num = cls_token_num
        self.causal_fusion = causal_fusion
        self.deep_supervision = deep_supervision
        if not causal_fusion:
            self.final_layer = FinalLayer(hidden_size, 1, self.output_dim)
        else:
            self.cond_pos_embed = nn.Parameter(torch.zeros(1, cls_token_num, hidden_size))
            self.final_layer = nn.Sequential(
                nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6),
                nn.Linear(hidden_size, self.output_dim)
            )

        self.initialize_weights()

        self.use_si = use_si
        if not use_si:
            self.train_diffusion = create_diffusion(timestep_respacing="", predict_xstart=predict_xstart)
            self.gen_diffusion = create_diffusion(timestep_respacing=num_sampling_steps, predict_xstart=predict_xstart)
        else:
            self.transport = create_transport()
            self.sampler = Sampler(self.transport)


    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            if self.causal_fusion and isinstance(module, nn.LayerNorm):
                if module.weight is not None:
                    nn.init.constant_(module.weight, 1.0)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize pos_embed:
        nn.init.normal_(self.pos_embed, std=0.02)
        # Initialize label embedding table:
        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
        if self.causal_fusion:
            nn.init.normal_(self.cond_pos_embed, std=0.02)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        for i, block in enumerate(self.blocks):
            if not self.causal_fusion:
                nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
                nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
            if self.deep_supervision and i < len(self.blocks) - 1:
                if not self.causal_fusion:
                    nn.init.constant_(block.final_layer.adaLN_modulation[-1].weight, 0)
                    nn.init.constant_(block.final_layer.adaLN_modulation[-1].bias, 0)
                    nn.init.constant_(block.final_layer.linear.weight, 0)
                    nn.init.constant_(block.final_layer.linear.bias, 0)
                else:
                    nn.init.constant_(block.final_layer[1].weight, 0)
                    nn.init.constant_(block.final_layer[1].bias, 0)

        # Zero-out output layers:
        if not self.causal_fusion:
            nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
            nn.init.constant_(self.final_layer.linear.weight, 0)
            nn.init.constant_(self.final_layer.linear.bias, 0)
        else:
            nn.init.constant_(self.final_layer[1].weight, 0)
            nn.init.constant_(self.final_layer[1].bias, 0)

    def forward_cache_update(self, x, kv_cache):
        for idx, block in enumerate(self.blocks):
            x = block(x, layer_index=idx, kv_cache=kv_cache, update_kv_cache=True)
        return None

    def _forward_inference(self, xn, t, y, kv_cache, pos_embed, attn_mask, context):
        xn = xn.transpose(1, 2)
        if context is not None:
            xn = torch.cat([context, xn], dim=1)
        xn = self.x_embedder(xn) + pos_embed
        y = self.y_embedder(y, self.training)
        t = self.t_embedder(t)
        if not self.causal_fusion:
            c = t + y
        else:
            y = y.unsqueeze(1).expand(-1, self.cls_token_num, -1) + self.cond_pos_embed
            xn = torch.cat([y, xn], dim=1)
            t = t.unsqueeze(1)
            xn = xn + t
            c = None

        for idx, block in enumerate(self.blocks):
            xn = block(xn, c, attn_mask=attn_mask, layer_index=idx, kv_cache=kv_cache, update_kv_cache=False)

        xn = xn[:, -1].unsqueeze(1)
        xn = self._forward_final_layer(xn, c)
        return xn.transpose(1, 2)

    def _forward_inference_with_cfg(self, xn, t, y, kv_cache, pos_embed, attn_mask, context, cfg_scale):
        half = xn[: len(xn) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self._forward_inference(combined, t, y, kv_cache, pos_embed, attn_mask, context)
        eps, rest = model_out[:, :self.slot_dim], model_out[:, self.slot_dim:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)

    def sample(self, y, pos_embed, attn_mask, context=None, cfg=1.0):
        # diffusion loss sampling
        device = y.device
        if not cfg == 1.0:
            noise = torch.randn(y.shape[0] // 2, self.slot_dim, 1, device=device)
            noise = torch.cat([noise, noise], dim=0)
            model_kwargs = dict(y=y, kv_cache=None, pos_embed=pos_embed, attn_mask=attn_mask, context=context, cfg_scale=cfg)
            sample_fn = self._forward_inference_with_cfg
        else:
            noise = torch.randn(y.shape[0], self.slot_dim, 1, device=device)
            model_kwargs = dict(y=y, kv_cache=None, pos_embed=pos_embed, attn_mask=attn_mask, context=context)
            sample_fn = self._forward_inference

        if not self.use_si:
            sampled_token_latent = self.gen_diffusion.p_sample_loop(
                sample_fn, noise.shape, noise, clip_denoised=False, model_kwargs=model_kwargs, progress=False,
                device=device
            )
        else:
            sde_sample_fn = self.sampler.sample_sde(diffusion_form="sigma")
            sampled_token_latent = sde_sample_fn(noise, sample_fn, **model_kwargs)[-1]

        return sampled_token_latent.transpose(1, 2)

    def _forward_final_layer(self, x, c):
        if not self.causal_fusion:
            return self.final_layer(x, c)
        else:
            return self.final_layer(x)

    def forward_train(self, xn, t, y, xc):
        """
        Args:
            xn: noised latent
            t: time step
            y: condition
            xc: clean latent
        """

        xc = self.x_embedder(xc.transpose(1, 2)) + self.pos_embed
        xn = self.x_embedder(xn.transpose(1, 2)) + self.pos_embed

        t = self.t_embedder(t)
        y = self.y_embedder(y, self.training)
        if not self.causal_fusion:
            c = t + y
        else:
            y = y.unsqueeze(1).expand(-1, self.cls_token_num, -1) + self.cond_pos_embed
            t = t.unsqueeze(1)
            xc = torch.cat((y, xc), dim=1)
            xn = xn + t
            c = None

        # forward transformer
        x = torch.cat((xc, xn), dim=1)
        attn_mask = get_attn_mask(self.cls_token_num, self.num_slots).to(x.device)

        if self.deep_supervision and self.training:
            xs = []
            for block in self.blocks[:-1]:
                x, x_hat = block(x, c=c, attn_mask=attn_mask)
                xs.append(x_hat[:, -self.num_slots:])
            x = self.blocks[-1](x, c=c, attn_mask=attn_mask)
            x = self._forward_final_layer(x[:, -self.num_slots:], c)
            xs.append(x)
            # N, B, L, C -> B, C, L, N
            return torch.stack(xs, dim=0).permute(1, 3, 2, 0)
        else:
            for block in self.blocks:
                x = block(x, c=c, attn_mask=attn_mask)
            x = x[:, -self.num_slots:]
            return self._forward_final_layer(x, c).transpose(1, 2)

    def forward(self, slots, targets):
        slots = slots.transpose(1, 2)
        model_kwargs = dict(y=targets, xc=slots)
        if not self.use_si:
            t = torch.randint(0, self.train_diffusion.num_timesteps, (slots.shape[0],), device=slots.device)
            loss_dict = self.train_diffusion.training_losses(self.forward_train, slots, t, model_kwargs)
        else:
            loss_dict = self.transport.training_losses(self.forward_train, slots, model_kwargs)
        loss = loss_dict["loss"]
        return loss.mean()


def CausalDiT_L(**kwargs):
    return CausalDiT(depth=24, hidden_size=1024, num_heads=16, **kwargs)


def CausalDiT_XL(**kwargs):
    return CausalDiT(depth=32, hidden_size=1280, num_heads=20, **kwargs)


def CausalDiT_H(**kwargs):
    return CausalDiT(depth=48, hidden_size=1408, num_heads=22, **kwargs)


CausalDiT_models = {
    "CausalDiT-L": CausalDiT_L,
    "CausalDiT-XL": CausalDiT_XL,
    "CausalDiT-H": CausalDiT_H
}

def get_attn_mask(context_len, sample_len):
    padder_len = 1 if context_len <= 0 else context_len
    seq_len = padder_len + sample_len * 2
    attn_mask = torch.eye(seq_len, dtype=bool)
    triangle = torch.ones(sample_len, sample_len, dtype=bool).tril()
    attn_mask[-sample_len:, :sample_len] = triangle
    if padder_len == 1:
        return attn_mask[1:, 1:]
    else:
        return attn_mask

# if context_len == 0
# 100000
# 010000
# 001000
# 000100
# 100010
# 110001
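To see the generalized-causal structure concretely, the mask can be printed directly (illustrative only; `sample_len=4` is an arbitrary small choice):

import torch
from paintmind.stage2.causaldit import get_attn_mask

mask = get_attn_mask(context_len=0, sample_len=4)
print(mask.int())
# Rows/columns are [clean slots | noised slots]: in this context_len=0 case the
# clean slots attend only to themselves, and each noised slot attends to a causal
# prefix of the clean slots plus its own position.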
paintmind/stage2/diffloss.py
ADDED
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torch.utils.checkpoint import checkpoint
|
4 |
+
import math
|
5 |
+
|
6 |
+
from paintmind.stage1.diffusion import create_diffusion
|
7 |
+
from paintmind.stage1.transport import create_transport, Sampler
|
8 |
+
|
9 |
+
|
10 |
+
class DiffLoss(nn.Module):
|
11 |
+
"""Diffusion Loss"""
|
12 |
+
def __init__(self, target_channels, z_channels, depth, width, num_sampling_steps, predict_xstart=False, use_si=False, deep_supervision=False, token_drop_prob=0.0, cond_method="adaln", decoupled_cfg=True):
|
13 |
+
super(DiffLoss, self).__init__()
|
14 |
+
self.in_channels = target_channels
|
15 |
+
self.net = SimpleMLPAdaLN(
|
16 |
+
in_channels=target_channels,
|
17 |
+
model_channels=width,
|
18 |
+
out_channels=target_channels * 2 if not use_si else target_channels, # for vlb loss
|
19 |
+
z_channels=z_channels,
|
20 |
+
num_res_blocks=depth,
|
21 |
+
deep_supervision=deep_supervision,
|
22 |
+
token_drop_prob=token_drop_prob,
|
23 |
+
cond_method=cond_method,
|
24 |
+
decoupled_cfg=decoupled_cfg,
|
25 |
+
)
|
26 |
+
self.use_si = use_si
|
27 |
+
if not use_si:
|
28 |
+
self.train_diffusion = create_diffusion(timestep_respacing="", noise_schedule="cosine", predict_xstart=predict_xstart)
|
29 |
+
self.gen_diffusion = create_diffusion(timestep_respacing=num_sampling_steps, noise_schedule="cosine", predict_xstart=predict_xstart)
|
30 |
+
else:
|
31 |
+
self.transport = create_transport()
|
32 |
+
self.sampler = Sampler(self.transport)
|
33 |
+
|
34 |
+
def forward(self, target, z, mask=None):
|
35 |
+
model_kwargs = dict(c=z)
|
36 |
+
if not self.use_si:
|
37 |
+
t = torch.randint(0, self.train_diffusion.num_timesteps, (target.shape[0],), device=target.device)
|
38 |
+
loss_dict = self.train_diffusion.training_losses(self.net, target, t, model_kwargs)
|
39 |
+
else:
|
40 |
+
loss_dict = self.transport.training_losses(self.net, target, model_kwargs)
|
41 |
+
loss = loss_dict["loss"]
|
42 |
+
if mask is not None:
|
43 |
+
loss = (loss * mask).sum() / mask.sum()
|
44 |
+
return loss.mean()
|
45 |
+
|
46 |
+
def sample(self, z, temperature=1.0, cfg=1.0):
|
47 |
+
# diffusion loss sampling
|
48 |
+
device = z.device
|
49 |
+
if not cfg == 1.0:
|
50 |
+
noise = torch.randn(z.shape[0] // 2, self.in_channels, device=device)
|
51 |
+
noise = torch.cat([noise, noise], dim=0)
|
52 |
+
model_kwargs = dict(c=z, cfg_scale=cfg)
|
53 |
+
sample_fn = self.net.forward_with_cfg
|
54 |
+
else:
|
55 |
+
noise = torch.randn(z.shape[0], self.in_channels, device=device)
|
56 |
+
model_kwargs = dict(c=z)
|
57 |
+
sample_fn = self.net.forward
|
58 |
+
|
59 |
+
if not self.use_si:
|
60 |
+
sampled_token_latent = self.gen_diffusion.p_sample_loop(
|
61 |
+
sample_fn, noise.shape, noise, clip_denoised=False, model_kwargs=model_kwargs, progress=False,
|
62 |
+
temperature=temperature, device=device
|
63 |
+
)
|
64 |
+
else:
|
65 |
+
sde_sample_fn = self.sampler.sample_sde(diffusion_form="sigma", temperature=temperature)
|
66 |
+
sampled_token_latent = sde_sample_fn(noise, sample_fn, **model_kwargs)[-1]
|
67 |
+
if cfg != 1.0:
|
68 |
+
sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0)
|
69 |
+
return sampled_token_latent
|
70 |
+
|
71 |
+
|
72 |
+
def modulate(x, shift, scale):
|
73 |
+
return x * (1 + scale) + shift
|
74 |
+
|
75 |
+
|
76 |
+
class TimestepEmbedder(nn.Module):
|
77 |
+
"""
|
78 |
+
Embeds scalar timesteps into vector representations.
|
79 |
+
"""
|
80 |
+
def __init__(self, hidden_size, frequency_embedding_size=256):
|
81 |
+
super().__init__()
|
82 |
+
self.mlp = nn.Sequential(
|
83 |
+
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
84 |
+
nn.SiLU(),
|
85 |
+
nn.Linear(hidden_size, hidden_size, bias=True),
|
86 |
+
)
|
87 |
+
self.frequency_embedding_size = frequency_embedding_size
|
88 |
+
|
89 |
+
@staticmethod
|
90 |
+
def timestep_embedding(t, dim, max_period=10000):
|
91 |
+
"""
|
92 |
+
Create sinusoidal timestep embeddings.
|
93 |
+
:param t: a 1-D Tensor of N indices, one per batch element.
|
94 |
+
These may be fractional.
|
95 |
+
:param dim: the dimension of the output.
|
96 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
97 |
+
:return: an (N, D) Tensor of positional embeddings.
|
98 |
+
"""
|
99 |
+
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
100 |
+
half = dim // 2
|
101 |
+
freqs = torch.exp(
|
102 |
+
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
103 |
+
).to(device=t.device)
|
104 |
+
args = t[:, None].float() * freqs[None]
|
105 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
106 |
+
if dim % 2:
|
107 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
108 |
+
return embedding
|
109 |
+
|
110 |
+
def forward(self, t):
|
111 |
+
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
112 |
+
t_emb = self.mlp(t_freq)
|
113 |
+
return t_emb
|
114 |
+
|
115 |
+
|
116 |
+
class ResBlock(nn.Module):
|
117 |
+
"""
|
118 |
+
A residual block with AdaLN for timestep and optional concatenation for condition.
|
119 |
+
"""
|
120 |
+
def __init__(
|
121 |
+
self,
|
122 |
+
channels,
|
123 |
+
out_channels=None,
|
124 |
+
deep_supervision=False,
|
125 |
+
cond_method="adaln",
|
126 |
+
):
|
127 |
+
super().__init__()
|
128 |
+
self.channels = channels
|
129 |
+
self.deep_supervision = deep_supervision
|
130 |
+
self.cond_method = cond_method
|
131 |
+
|
132 |
+
self.in_ln = nn.LayerNorm(channels, eps=1e-6)
|
133 |
+
self.adaLN_modulation = nn.Sequential(
|
134 |
+
nn.SiLU(),
|
135 |
+
nn.Linear(channels, 3 * channels, bias=True)
|
136 |
+
)
|
137 |
+
|
138 |
+
# Input dimension depends on conditioning method
|
139 |
+
mlp_in_dim = channels * 2 if cond_method == "concat" else channels
|
140 |
+
self.mlp = nn.Sequential(
|
141 |
+
nn.Linear(mlp_in_dim, channels, bias=True),
|
142 |
+
nn.SiLU(),
|
143 |
+
nn.Linear(channels, channels, bias=True),
|
144 |
+
)
|
145 |
+
|
146 |
+
if self.deep_supervision:
|
147 |
+
self.final_layer = FinalLayer(channels, out_channels, cond_method=cond_method)
|
148 |
+
|
149 |
+
def forward(self, x, t, c=None):
|
150 |
+
# Apply timestep embedding via AdaLN
|
151 |
+
shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(t).chunk(3, dim=-1)
|
152 |
+
h = modulate(self.in_ln(x), shift_mlp, scale_mlp)
|
153 |
+
|
154 |
+
# Concatenate condition if using concat method
|
155 |
+
if self.cond_method == "concat" and c is not None:
|
156 |
+
h = torch.cat([h, c], dim=-1)
|
157 |
+
|
158 |
+
h = self.mlp(h)
|
159 |
+
x = x + gate_mlp * h
|
160 |
+
|
161 |
+
if self.deep_supervision and self.training:
|
162 |
+
return x, self.final_layer(x, t, c)
|
163 |
+
return x
|
164 |
+
|
165 |
+
|
166 |
+
class FinalLayer(nn.Module):
|
167 |
+
"""
|
168 |
+
Final layer with AdaLN for timestep and optional concatenation for condition.
|
169 |
+
"""
|
170 |
+
def __init__(self, model_channels, out_channels, cond_method="adaln"):
|
171 |
+
super().__init__()
|
172 |
+
self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
|
173 |
+
self.cond_method = cond_method
|
174 |
+
|
175 |
+
self.adaLN_modulation = nn.Sequential(
|
176 |
+
nn.SiLU(),
|
177 |
+
nn.Linear(model_channels, 2 * model_channels, bias=True)
|
178 |
+
)
|
179 |
+
|
180 |
+
# Output dimension depends on conditioning method
|
181 |
+
linear_in_dim = model_channels * 2 if cond_method == "concat" else model_channels
|
182 |
+
self.linear = nn.Linear(linear_in_dim, out_channels, bias=True)
|
183 |
+
|
184 |
+
def forward(self, x, t, c=None):
|
185 |
+
# Apply timestep embedding via AdaLN
|
186 |
+
shift, scale = self.adaLN_modulation(t).chunk(2, dim=-1)
|
187 |
+
x = modulate(self.norm_final(x), shift, scale)
|
188 |
+
|
189 |
+
# Concatenate condition if using concat method
|
190 |
+
if self.cond_method == "concat" and c is not None:
|
191 |
+
x = torch.cat([x, c], dim=-1)
|
192 |
+
|
193 |
+
return self.linear(x)
|
194 |
+
|
195 |
+
|
196 |
+
class SimpleMLPAdaLN(nn.Module):
    """
    MLP for Diffusion Loss with AdaLN for timestep and optional concatenation for condition.
    """
    def __init__(
        self,
        in_channels,
        model_channels,
        out_channels,
        z_channels,
        num_res_blocks,
        deep_supervision=False,
        token_drop_prob=0.0,
        cond_method="adaln",
        decoupled_cfg=True,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.deep_supervision = deep_supervision
        self.token_drop_prob = token_drop_prob
        self.cond_method = cond_method
        self.decoupled_cfg = decoupled_cfg
        if decoupled_cfg and token_drop_prob > 0.0:
            self.null_token = nn.Parameter(torch.zeros(1, z_channels))

        self.time_embed = TimestepEmbedder(model_channels)
        self.cond_embed = nn.Linear(z_channels, model_channels)
        self.input_proj = nn.Linear(in_channels, model_channels)

        # Create residual blocks
        res_blocks = []
        for i in range(num_res_blocks - 1):
            res_blocks.append(ResBlock(model_channels, out_channels, deep_supervision, cond_method))
        res_blocks.append(ResBlock(model_channels, cond_method=cond_method))
        self.res_blocks = nn.ModuleList(res_blocks)

        self.final_layer = FinalLayer(model_channels, out_channels, cond_method=cond_method)
        self.initialize_weights()

    def initialize_weights(self):
        # Basic initialization for all linear layers
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize timestep embedding MLP
        nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
        nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)

        if self.token_drop_prob > 0.0:
            nn.init.normal_(self.null_token, std=0.02)

        # Zero-out adaLN modulation layers (always used for timestep)
        for i, block in enumerate(self.res_blocks):
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
            if self.deep_supervision and i < len(self.res_blocks) - 1:
                nn.init.constant_(block.final_layer.adaLN_modulation[-1].weight, 0)
                nn.init.constant_(block.final_layer.adaLN_modulation[-1].bias, 0)
                nn.init.constant_(block.final_layer.linear.weight, 0)
                nn.init.constant_(block.final_layer.linear.bias, 0)

        # Zero-out output layers
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def forward(self, x, t, c):
        """
        Apply the model to an input batch.
        :param x: an [N x C] Tensor of inputs.
        :param t: a 1-D batch of timesteps.
        :param c: conditioning from AR transformer.
        :return: an [N x C] Tensor of outputs.
        """
        x = self.input_proj(x)
        t_emb = self.time_embed(t)

        # Apply token dropout if needed
        if self.token_drop_prob > 0.0 and self.training:
            drop_ids = torch.rand(c.shape[0], 1, device=c.device) < self.token_drop_prob
            c = torch.where(drop_ids, self.null_token, c)
        c_emb = self.cond_embed(c)

        # Prepare conditioning based on method
        if self.cond_method == "adaln":
            t_combined, c_for_concat = t_emb + c_emb, None
        else:  # concat
            t_combined, c_for_concat = t_emb, c_emb

        if self.deep_supervision and self.training:
            xs = []
            for block in self.res_blocks[:-1]:
                x, x_hat = block(x, t_combined, c_for_concat)
                xs.append(x_hat)
            x = self.res_blocks[-1](x, t_combined, c_for_concat)
            x = self.final_layer(x, t_combined, c_for_concat)
            xs.append(x)
            return torch.stack(xs, dim=-1)
        else:
            for block in self.res_blocks:
                x = block(x, t_combined, c_for_concat)
            return self.final_layer(x, t_combined, c_for_concat)

    def forward_with_cfg(self, x, t, c, cfg_scale):
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.forward(combined, t, c)
        eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)
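Note that forward_with_cfg assumes the batch is laid out as [conditional half; unconditional half] and that the conditioning c already carries null tokens in its unconditional rows. A minimal sketch of that convention with made-up sizes (TimestepEmbedder, ResBlock, and FinalLayer are defined earlier in this file):

# Sketch only: CFG batch layout expected by SimpleMLPAdaLN.forward_with_cfg.
# All sizes here are illustrative, not taken from the training configs.
import torch

net = SimpleMLPAdaLN(in_channels=16, model_channels=256, out_channels=32,
                     z_channels=768, num_res_blocks=3)
x = torch.randn(8, 16)            # rows 0-3 conditional, rows 4-7 unconditional copies
t = torch.randint(0, 1000, (8,))  # shared timesteps
c = torch.randn(8, 768)           # conditioning; the second half would hold null tokens
out = net.forward_with_cfg(x, t, c, cfg_scale=3.0)  # guided eps + rest, shape [8, 32]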
paintmind/stage2/generate.py
ADDED
@@ -0,0 +1,127 @@
# Modified from:
# gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
# DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
import torch

def prefill(model, cond_idx: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, diff_cfg: float = 1.0, temperature: float = 1.0):
    tokens = model(None, cond_idx, input_pos, cfg=cfg_scale, diff_cfg=diff_cfg, temperature=temperature)
    return tokens.unsqueeze(1)


def decode_one_token(model, x: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, diff_cfg: float = 1.0, temperature: float = 1.0):
    assert input_pos.shape[-1] == 1
    if cfg_scale > 1.0:
        x = torch.cat([x, x])
    tokens = model(x, cond_idx=None, input_pos=input_pos, cfg=cfg_scale, diff_cfg=diff_cfg, temperature=temperature)
    return tokens


def decode_n_tokens(
        model, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int,
        cfg_scale: float, diff_cfg: float = 1.0, temperature: float = 1.0, cfg_schedule="constant", diff_cfg_schedule="constant"):
    new_tokens = []
    for i in range(num_new_tokens):
        cfg_iter = get_cfg(cfg_scale, i + 1, num_new_tokens + 1, cfg_schedule)
        diff_cfg_iter = get_cfg(diff_cfg, i + 1, num_new_tokens + 1, diff_cfg_schedule)
        next_token = decode_one_token(model, cur_token, input_pos, cfg_iter, diff_cfg=diff_cfg_iter, temperature=temperature).unsqueeze(1)
        input_pos += 1
        new_tokens.append(next_token.clone())
        cur_token = next_token

    return new_tokens


@torch.no_grad()
def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, diff_cfg: float = 1.0, temperature: float = 1.0, cfg_schedule="constant", diff_cfg_schedule="constant"):
    if cfg_scale > 1.0:
        cond_null = torch.ones_like(cond) * model.num_classes
        cond_combined = torch.cat([cond, cond_null])
    else:
        cond_combined = cond
    T = model.cls_token_num

    T_new = T + max_new_tokens
    max_seq_length = T_new
    max_batch_size = cond.shape[0]

    device = cond.device
    dtype = model.z_proj.weight.dtype
    if torch.is_autocast_enabled():
        dtype = torch.get_autocast_dtype(device_type=device.type)
    with torch.device(device):
        max_batch_size_cfg = max_batch_size * 2 if cfg_scale > 1.0 else max_batch_size
        model.setup_caches(max_batch_size=max_batch_size_cfg, max_seq_length=max_seq_length, dtype=dtype)

    if emb_masks is not None:
        assert emb_masks.shape[0] == max_batch_size
        assert emb_masks.shape[-1] == T
        if cfg_scale > 1.0:
            model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * torch.cat([emb_masks, emb_masks]).unsqueeze(1)
        else:
            model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * emb_masks.unsqueeze(1)

        eye_matrix = torch.eye(model.causal_mask.size(1), model.causal_mask.size(2), device=device)
        model.causal_mask[:] = model.causal_mask * (1 - eye_matrix) + eye_matrix

    # create an empty tensor of the expected final shape and fill in the current tokens
    seq = torch.empty((max_batch_size, T_new, model.slot_dim), dtype=dtype, device=device)

    input_pos = torch.arange(0, T, device=device)
    cfg_iter = get_cfg(cfg_scale, 0, max_new_tokens, cfg_schedule)
    diff_cfg_iter = get_cfg(diff_cfg, 0, max_new_tokens, diff_cfg_schedule)
    next_token = prefill(model, cond_combined, input_pos, cfg_iter, diff_cfg=diff_cfg_iter, temperature=temperature)
    seq[:, T:T+1] = next_token

    if max_new_tokens > 1:
        input_pos = torch.tensor([T], device=device, dtype=torch.int)
        generated_tokens = decode_n_tokens(model, next_token, input_pos, max_new_tokens - 1, cfg_scale, diff_cfg=diff_cfg, temperature=temperature, cfg_schedule=cfg_schedule, diff_cfg_schedule=diff_cfg_schedule)
        seq[:, T+1:] = torch.cat(generated_tokens, dim=1)

    model.reset_caches()
    return seq[:, T:]


def get_cfg(cfg, cur_step, total_step, cfg_schedule="constant"):
    if cfg_schedule == "linear":
        return 1 + (cfg - 1) * (cur_step + 1) / total_step
    elif cfg_schedule == "inv_linear":
        return 1 + (cfg - 1) * (total_step - cur_step - 1) / total_step
    elif cfg_schedule == "constant":
        return cfg
    else:
        raise NotImplementedError


@torch.no_grad()
def generate_causal_dit(model, cond, max_new_tokens, cfg_scale=1.0):
    assert max_new_tokens == model.num_slots

    batch_size = cond.shape[0]
    device = cond.device

    if cfg_scale > 1.0:
        cond_null = torch.ones_like(cond) * model.num_classes
        cond_combined = torch.cat([cond, cond_null])
    else:
        cond_combined = cond

    cur_tokens = []
    for i in range(max_new_tokens):
        pos_embed = model.pos_embed[:, :i + 1].view(1, -1, model.hidden_size).expand(batch_size, -1, -1)
        if cfg_scale > 1.0:
            pos_embed = torch.cat([pos_embed, pos_embed], dim=0)

        attn_mask = torch.ones(model.cls_token_num + i + 1, model.cls_token_num + i + 1, dtype=torch.bool).tril(diagonal=0).to(device)

        context = torch.cat(cur_tokens, dim=1) if len(cur_tokens) > 0 else None
        if cfg_scale > 1.0 and context is not None:
            context = torch.cat([context, context], dim=0)

        next_token = model.sample(cond_combined, pos_embed, attn_mask, context, cfg_scale)
        if cfg_scale > 1.0:
            next_token, _ = next_token.chunk(2, dim=0)
        cur_tokens.append(next_token.clone())

    seq = torch.cat(cur_tokens, dim=1)

    return seq
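A minimal sampling sketch for the decoder above, assuming a trained class-conditional model (the GPT_B constructor is defined in gpt.py below; the label values and guidance settings are illustrative):

# Sketch only: KV-cached slot sampling with classifier-free guidance.
import torch
from paintmind.stage2.gpt import GPT_B
from paintmind.stage2.generate import generate

model = GPT_B(num_slots=16, slot_dim=256, num_classes=1000).cuda().eval()
# ... load a trained checkpoint into `model` here ...
labels = torch.tensor([207, 360], device='cuda')
slots = generate(model, labels, max_new_tokens=16,
                 cfg_scale=4.0, cfg_schedule="linear")   # -> [2, 16, 256] slot tokens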
paintmind/stage2/gpt.py
ADDED
@@ -0,0 +1,451 @@
# Modified from:
# VQGAN: https://github.com/CompVis/taming-transformers/blob/master/taming/modules/transformer/mingpt.py
# DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
# nanoGPT: https://github.com/karpathy/nanoGPT/blob/master/model.py
# llama: https://github.com/facebookresearch/llama/blob/main/llama/model.py
# gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/model.py
# PixArt: https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
from dataclasses import dataclass
from typing import Optional, List, Union

import math
import torch
import torch.nn as nn
from torch.nn import functional as F

from paintmind.stage1.vision_transformers import DropPath
from paintmind.stage2.diffloss import DiffLoss

def find_multiple(n: int, k: int):
    if n % k == 0:
        return n
    return n + k - (n % k)


#################################################################################
#                      Embedding Layers for Class Labels                       #
#################################################################################
class LabelEmbedder(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
    """
    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        else:
            drop_ids = force_drop_ids == 1
        labels = torch.where(drop_ids, self.num_classes, labels)
        return labels

    def forward(self, labels, train, force_drop_ids=None):
        use_dropout = self.dropout_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        embeddings = self.embedding_table(labels).unsqueeze(1)
        return embeddings


class MLP(nn.Module):
    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=False)
        self.act = nn.GELU(approximate='tanh')
        self.fc2 = nn.Linear(hidden_features, out_features, bias=False)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x


#################################################################################
#                                 GPT Model                                     #
#################################################################################
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        multiple_of: int = 256,
        ffn_dropout_p: float = 0.0,
    ):
        super().__init__()
        hidden_dim = 4 * dim
        hidden_dim = int(2 * hidden_dim / 3)
        hidden_dim = find_multiple(hidden_dim, multiple_of)

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.ffn_dropout = nn.Dropout(ffn_dropout_p)

    def forward(self, x):
        return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))


class KVCache(nn.Module):
    def __init__(self, max_batch_size, max_seq_length, n_head, head_dim, dtype):
        super().__init__()
        cache_shape = (max_batch_size, n_head, max_seq_length, head_dim)
        self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: [S], k_val: [B, H, S, D]
        assert input_pos.shape[0] == k_val.shape[2]
        k_out = self.k_cache
        v_out = self.v_cache
        k_out[:, :, input_pos] = k_val
        v_out[:, :, input_pos] = v_val

        return k_out, v_out


class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        n_head: int,
        attn_dropout_p: float = 0.0,
        resid_dropout_p: float = 0.1,
    ):
        super().__init__()
        assert dim % n_head == 0
        self.dim = dim
        self.head_dim = dim // n_head
        self.n_head = n_head

        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(dim, dim * 3, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)
        self.kv_cache = None

        # regularization
        self.attn_dropout_p = attn_dropout_p
        self.resid_dropout = nn.Dropout(resid_dropout_p)

    def forward(
        self, x: torch.Tensor,
        input_pos: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None
    ):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wqkv(x).split([self.dim, self.dim, self.dim], dim=-1)

        xq = xq.view(bsz, seqlen, self.n_head, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_head, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_head, self.head_dim)

        xq, xk, xv = map(lambda x: x.transpose(1, 2), (xq, xk, xv))

        if self.kv_cache is not None:
            keys, values = self.kv_cache.update(input_pos, xk, xv)
        else:
            keys, values = xk, xv

        output = F.scaled_dot_product_attention(
            xq, keys, values,
            attn_mask=mask,
            is_causal=True if mask is None else False,  # is_causal=False is for KV cache
            dropout_p=self.attn_dropout_p if self.training else 0)

        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

        output = self.resid_dropout(self.wo(output))
        return output


class TransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        n_head: int,
        multiple_of: int = 256,
        norm_eps: float = 1e-5,
        attn_dropout_p: float = 0.0,
        ffn_dropout_p: float = 0.1,
        resid_dropout_p: float = 0.1,
        drop_path: float = 0.0,
    ):
        super().__init__()
        self.attention = Attention(
            dim=dim,
            n_head=n_head,
            attn_dropout_p=attn_dropout_p,
            resid_dropout_p=resid_dropout_p,
        )
        self.feed_forward = FeedForward(
            dim=dim,
            multiple_of=multiple_of,
            ffn_dropout_p=ffn_dropout_p,
        )
        self.attention_norm = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm = RMSNorm(dim, eps=norm_eps)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x: torch.Tensor, start_pos: int, mask: Optional[torch.Tensor] = None):
        h = x + self.drop_path(self.attention(self.attention_norm(x), start_pos, mask))
        out = h + self.drop_path(self.feed_forward(self.ffn_norm(h)))
        return out


class Transformer(nn.Module):
    def __init__(
        self,
        dim: int = 4096,
        n_layer: int = 32,
        n_head: int = 32,
        attn_dropout_p: float = 0.0,
        resid_dropout_p: float = 0.1,
        ffn_dropout_p: float = 0.1,
        drop_path_rate: float = 0.0,
        num_classes: Union[int, List[int]] = 1000,
        class_dropout_prob: float = 0.1,

        cls_token_num: int = 1,
        num_slots: int = 16,
        slot_dim: int = 256,

        diffloss_d: int = 3,
        diffloss_w: int = 1024,
        num_sampling_steps: str = '100',
        diffusion_batch_mul: int = 4,
        predict_xstart: bool = False,
        use_si: bool = False,
        deep_supervision: bool = False,
        token_drop_prob: float = 0.0,
        cond_method: str = "adaln",
        decoupled_cfg: bool = True,
        **kwargs,
    ):
        super().__init__()

        # Store configuration
        self.dim = dim
        self.n_layer = n_layer
        self.n_head = n_head
        self.num_slots = num_slots
        self.slot_dim = slot_dim
        self.num_classes = num_classes
        self.cls_token_num = cls_token_num

        # Initialize embeddings
        self.cls_embedding = LabelEmbedder(num_classes, dim, class_dropout_prob)
        self.z_proj = nn.Linear(slot_dim, dim, bias=True)
        self.z_proj_ln = RMSNorm(dim)
        self.pos_embed_learned = nn.Parameter(torch.zeros(1, num_slots + cls_token_num, dim))

        # transformer blocks
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layer)]
        self.layers = torch.nn.ModuleList()
        for layer_id in range(n_layer):
            self.layers.append(TransformerBlock(
                dim=dim,
                n_head=n_head,
                ffn_dropout_p=ffn_dropout_p,
                attn_dropout_p=attn_dropout_p,
                resid_dropout_p=resid_dropout_p,
                drop_path=dpr[layer_id],
            ))

        # output layer
        self.norm = RMSNorm(dim)

        self.diffusion_pos_embed_learned = nn.Parameter(torch.zeros(1, num_slots, dim))

        # KVCache
        self.max_batch_size = -1
        self.max_seq_length = -1

        self.initialize_weights()

        # Diffusion Loss
        self.diffloss = DiffLoss(
            target_channels=slot_dim,
            z_channels=dim,
            width=diffloss_w,
            depth=diffloss_d,
            num_sampling_steps=num_sampling_steps,
            predict_xstart=predict_xstart,
            use_si=use_si,
            deep_supervision=deep_supervision,
            token_drop_prob=token_drop_prob,
            cond_method=cond_method,
            decoupled_cfg=decoupled_cfg,
        )
        self.decoupled_cfg = decoupled_cfg
        self.diffusion_batch_mul = diffusion_batch_mul

    def initialize_weights(self):
        nn.init.normal_(self.pos_embed_learned, std=0.02)
        nn.init.normal_(self.diffusion_pos_embed_learned, std=0.02)
        # Initialize nn.Linear and nn.Embedding
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(std=0.02)

    def setup_caches(self, max_batch_size, max_seq_length, dtype):
        # if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
        #     return
        head_dim = self.dim // self.n_head
        max_seq_length = find_multiple(max_seq_length, 8)
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        for b in self.layers:
            b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.n_head, head_dim, dtype)

        causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
        self.causal_mask = causal_mask.unsqueeze(0).repeat(self.max_batch_size, 1, 1)

    def reset_caches(self):
        self.max_seq_length = -1
        self.max_batch_size = -1
        for b in self.layers:
            b.attention.kv_cache = None

    def forward_loss(self, z, target):
        bsz, seq_len, _ = target.shape
        target = target.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
        z = z.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
        loss = self.diffloss(z=z, target=target)
        return loss

    def forward_cfg(self, h, cfg):
        if cfg > 1.0:
            h_cond, h_uncond = h.chunk(2, dim=0)
            h = h_uncond + cfg * (h_cond - h_uncond)
        return h

    def forward(
        self,
        slots: torch.Tensor,
        cond_idx: torch.Tensor,  # cond_idx_or_embed
        input_pos: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        cfg: float = 1.0,
        diff_cfg: float = 1.0,
        temperature: float = 1.0
    ):
        if slots is not None and cond_idx is not None:  # training or naive inference
            cond_embeddings = self.cls_embedding(cond_idx, train=self.training)
            cond_embeddings = cond_embeddings.expand(-1, self.cls_token_num, -1)
            token_embeddings = self.z_proj(slots)
            token_embeddings = torch.cat((cond_embeddings, token_embeddings), dim=1)
        else:
            if cond_idx is not None:  # prefill in inference
                token_embeddings = self.cls_embedding(cond_idx, train=self.training)
                token_embeddings = token_embeddings.expand(-1, self.cls_token_num, -1)
            else:  # decode_n_tokens(kv cache) in inference
                token_embeddings = self.z_proj(slots)

        bs = token_embeddings.shape[0]
        mask = self.causal_mask[:bs, None, input_pos]

        h = token_embeddings
        if self.training:
            h = h + self.pos_embed_learned
        else:
            h = h + self.pos_embed_learned[:, input_pos].view(1, -1, self.dim)

        h = self.z_proj_ln(h)  # not sure if this is needed

        # transformer blocks
        for layer in self.layers:
            h = layer(h, input_pos, mask)

        h = self.norm(h)

        if self.training:
            h = h[:, self.cls_token_num - 1 : -1].contiguous()
            h = h + self.diffusion_pos_embed_learned
            loss = self.forward_loss(h, slots.detach())
            return loss
        else:
            if self.decoupled_cfg:
                h = self.forward_cfg(h[:, -1], cfg)
                h = h + self.diffusion_pos_embed_learned[:, input_pos[-1] - self.cls_token_num + 1]
                if diff_cfg > 1.0 and hasattr(self.diffloss.net, 'null_token'):
                    null_token = self.diffloss.net.null_token.expand(h.shape[0], -1)
                    h = torch.cat((h, null_token), dim=0)
                else:
                    diff_cfg = 1.0
                next_tokens = self.diffloss.sample(h, temperature=temperature, cfg=diff_cfg)
            else:
                h = h[:, -1]
                h = h + self.diffusion_pos_embed_learned[:, input_pos[-1] - self.cls_token_num + 1]
                next_tokens = self.diffloss.sample(h, temperature=temperature, cfg=cfg)
            return next_tokens


    def get_fsdp_wrap_module_list(self) -> List[nn.Module]:
        return list(self.layers)


#################################################################################
#                                 GPT Configs                                   #
#################################################################################
### text-conditional
def GPT_7B(**kwargs):
    return Transformer(n_layer=32, n_head=32, dim=4096, **kwargs)  # 6.6B

def GPT_3B(**kwargs):
    return Transformer(n_layer=24, n_head=32, dim=3200, **kwargs)  # 3.1B

def GPT_1B(**kwargs):
    return Transformer(n_layer=22, n_head=32, dim=2048, **kwargs)  # 1.2B

### class-conditional
def GPT_XXXL(**kwargs):
    return Transformer(n_layer=48, n_head=40, dim=2560, **kwargs)  # 3.9B

def GPT_XXL(**kwargs):
    return Transformer(n_layer=48, n_head=24, dim=1536, **kwargs)  # 1.4B

def GPT_XL(**kwargs):
    return Transformer(n_layer=36, n_head=20, dim=1280, **kwargs)  # 775M

def GPT_L(**kwargs):
    return Transformer(n_layer=24, n_head=16, dim=1024, **kwargs)  # 343M

def GPT_B(**kwargs):
    return Transformer(n_layer=12, n_head=12, dim=768, **kwargs)  # 111M


GPT_models = {
    'GPT-B': GPT_B, 'GPT-L': GPT_L, 'GPT-XL': GPT_XL, 'GPT-XXL': GPT_XXL, 'GPT-XXXL': GPT_XXXL,
    'GPT-1B': GPT_1B, 'GPT-3B': GPT_3B, 'GPT-7B': GPT_7B,
}
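A quick instantiation sketch for the configs above (parameter counting and cache sizing only; it assumes DiffLoss imports cleanly from diffloss.py):

# Sketch only: build a GPT-B and prepare its inference caches.
import torch
from paintmind.stage2.gpt import GPT_models

model = GPT_models['GPT-B'](num_slots=16, slot_dim=256, num_classes=1000)
n_params = sum(p.numel() for p in model.parameters())
print(f"parameters: {n_params / 1e6:.1f}M")   # on the order of the 111M noted above
# caches cover cls token + slots; find_multiple rounds the length up to a multiple of 8
model.setup_caches(max_batch_size=2, max_seq_length=17, dtype=torch.float32)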
paintmind/utils/__init__.py
ADDED
File without changes
paintmind/utils/datasets.py
ADDED
@@ -0,0 +1,77 @@
import os
import torch
import torchvision
import numpy as np
import os.path as osp
from glob import glob
from PIL import Image
import torchvision.transforms as TF

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

def center_crop_arr(pil_image, image_size):
    """
    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
    """
    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )

    scale = image_size / min(*pil_image.size)
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
    )

    arr = np.array(pil_image)
    crop_y = (arr.shape[0] - image_size) // 2
    crop_x = (arr.shape[1] - image_size) // 2
    return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])

def vae_transforms(image_set, aug='randcrop', img_size=256):
    t = []
    if image_set == 'train':
        if aug == 'randcrop':
            t.append(TF.Resize(img_size, interpolation=TF.InterpolationMode.BICUBIC, antialias=True))
            t.append(TF.RandomCrop(img_size))
        elif aug == 'centercrop':
            t.append(TF.Lambda(lambda x: center_crop_arr(x, img_size)))
        else:
            raise ValueError(f"Invalid augmentation: {aug}")
        t.append(TF.RandomHorizontalFlip(p=0.5))
    else:
        t.append(TF.Resize(img_size, interpolation=TF.InterpolationMode.BICUBIC, antialias=True))
        t.append(TF.CenterCrop(img_size))

    t.append(TF.ToTensor())

    return TF.Compose(t)


def cached_transforms(aug='tencrop', img_size=256, crop_ranges=[1.05, 1.10]):
    t = []
    if 'centercrop' in aug:
        t.append(TF.Lambda(lambda x: center_crop_arr(x, img_size)))
        t.append(TF.Lambda(lambda x: torch.stack([TF.ToTensor()(x), TF.ToTensor()(TF.functional.hflip(x))])))
    elif 'tencrop' in aug:
        crop_sizes = [int(img_size * crop_range) for crop_range in crop_ranges]
        t.append(TF.Lambda(lambda x: [center_crop_arr(x, crop_size) for crop_size in crop_sizes]))
        t.append(TF.Lambda(lambda crops: [crop for crop_tuple in [TF.TenCrop(img_size)(crop) for crop in crops] for crop in crop_tuple]))
        t.append(TF.Lambda(lambda crops: torch.stack([TF.ToTensor()(crop) for crop in crops])))
    else:
        raise ValueError(f"Invalid augmentation: {aug}")

    return TF.Compose(t)


class ImageNet(torchvision.datasets.ImageFolder):
    def __init__(self, root, split='train', aug='randcrop', img_size=256):
        super().__init__(osp.join(root, split))
        if 'cache' not in aug:
            self.transform = vae_transforms(split, aug=aug, img_size=img_size)
        else:
            self.transform = cached_transforms(aug=aug, img_size=img_size)
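A minimal loading sketch for the wrapper above (the dataset root is a placeholder; split='train' selects the random-crop pipeline):

# Sketch only: standard DataLoader over the ImageNet wrapper.
from torch.utils.data import DataLoader
from paintmind.utils.datasets import ImageNet

dataset = ImageNet(root='./dataset/imagenet', split='train', aug='randcrop', img_size=256)
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=8)
images, labels = next(iter(loader))   # images: [32, 3, 256, 256], values in [0, 1]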
paintmind/utils/device_utils.py
ADDED
@@ -0,0 +1,20 @@
import gc
import torch
import importlib.util

def configure_compute_backend():
    """Configure PyTorch compute backend settings for CUDA."""
    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False
    else:
        raise ValueError("No CUDA available")

def get_device():
    """Get the device to use for training."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    else:
        raise ValueError("No CUDA available")
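Both helpers are meant to run once at process start; a sketch:

# Sketch only: typical startup on a CUDA machine.
from paintmind.utils.device_utils import configure_compute_backend, get_device

configure_compute_backend()   # enable TF32 matmuls and cuDNN autotuning
device = get_device()         # torch.device("cuda"); raises if CUDA is absent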
paintmind/utils/logger.py
ADDED
@@ -0,0 +1,170 @@
from collections import defaultdict, deque
import datetime
import time
import torch
import torch.distributed as dist
from paintmind.engine.misc import is_dist_avail_and_initialized, is_main_process
from paintmind.utils.device_utils import get_device

def synchronize_processes():
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    else:  # do nothing
        pass

def empty_cache():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    else:  # do nothing
        pass

class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float32, device=get_device())
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)


class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('mem: {memory:.0f}')
            log_msg.append("util: {util:.1f}%")
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    if is_main_process():
                        memory = torch.cuda.max_memory_allocated()
                        util = torch.cuda.utilization()
                        print(log_msg.format(
                            i, len(iterable), eta=eta_string,
                            meters=str(self),
                            time=str(iter_time), data=str(data_time),
                            memory=memory / MB, util=util))
                else:
                    if is_main_process():
                        print(log_msg.format(
                            i, len(iterable), eta=eta_string,
                            meters=str(self),
                            time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        if is_main_process():
            print('{} Total time: {} ({:.4f} s / it)'.format(
                header, total_time_str, total_time / len(iterable)))
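A sketch of how a training loop would drive these meters (the loop body is a stand-in):

# Sketch only: per-iteration logging with MetricLogger.
import torch
from paintmind.utils.logger import MetricLogger

logger = MetricLogger(delimiter="  ")
for batch in logger.log_every(range(1000), print_freq=100, header='Epoch [0]'):
    loss = torch.rand(())                   # stand-in for a real training step
    logger.update(loss=loss)
logger.synchronize_between_processes()      # all-reduce meter totals across ranks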
paintmind/utils/lr_scheduler.py
ADDED
@@ -0,0 +1,15 @@
from timm.scheduler.cosine_lr import CosineLRScheduler
from timm.scheduler.step_lr import StepLRScheduler

def build_scheduler(optimizer, n_epoch, n_iter_per_epoch, lr_min=0, warmup_steps=0, warmup_lr_init=0, decay_steps=None, cosine_lr=True):
    if decay_steps is None:
        decay_steps = n_epoch * n_iter_per_epoch

    if cosine_lr:
        scheduler = CosineLRScheduler(optimizer, t_initial=decay_steps, lr_min=lr_min, warmup_t=warmup_steps, warmup_lr_init=warmup_lr_init,
                                      cycle_limit=1, t_in_epochs=False, warmup_prefix=True)
    else:
        scheduler = StepLRScheduler(optimizer, decay_t=decay_steps, warmup_t=warmup_steps, warmup_lr_init=warmup_lr_init,
                                    t_in_epochs=False, warmup_prefix=True)

    return scheduler
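Because t_in_epochs=False, the returned scheduler expects the global step rather than the epoch; a sketch with illustrative hyperparameters:

# Sketch only: step-based LR scheduling with the builder above.
import torch
from paintmind.utils.lr_scheduler import build_scheduler

model = torch.nn.Linear(8, 8)
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
sched = build_scheduler(opt, n_epoch=100, n_iter_per_epoch=500,
                        lr_min=1e-6, warmup_steps=2500, warmup_lr_init=1e-7)
for step in range(100 * 500):
    # ... forward / backward / opt.step() ...
    sched.step_update(step)   # timm's per-iteration update hook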
paintmind/utils/transform.py
ADDED
@@ -0,0 +1,35 @@
import PIL
import torchvision.transforms as T

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

def stage1_transform(img_size=256, is_train=True, scale=0.8):
    resize = pair(int(img_size / scale))
    t = []
    t.append(T.Resize(resize, interpolation=PIL.Image.BICUBIC))
    if is_train:
        t.append(T.RandomCrop(img_size))
        t.append(T.RandomHorizontalFlip(p=0.5))
    else:
        t.append(T.CenterCrop(img_size))

    t.append(T.ToTensor())
    t.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))

    return T.Compose(t)

def stage2_transform(img_size=256, is_train=True, scale=0.8):
    resize = pair(int(img_size / scale))
    t = []
    t.append(T.Resize(resize, interpolation=PIL.Image.BICUBIC))
    if is_train:
        t.append(T.RandomCrop(img_size))
    else:
        t.append(T.CenterCrop(img_size))

    t.append(T.ToTensor())
    t.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))

    return T.Compose(t)
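A sketch of the stage-1 pipeline on a single image (the image path is a placeholder):

# Sketch only: map an image into the [-1, 1] range the stage-1 model expects.
from PIL import Image
from paintmind.utils.transform import stage1_transform

transform = stage1_transform(img_size=256, is_train=False)
img = Image.open('examples/city.jpg').convert('RGB')
x = transform(img)   # tensor [3, 256, 256], normalized to [-1, 1]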
paintmind/version.py
ADDED
@@ -0,0 +1 @@
__version__ = '0.0.0'
requirements.txt
ADDED
@@ -0,0 +1,27 @@
numpy==1.26.4
sympy>=1.10
accelerate
datasets
diffusers[torch]
transformers
safetensors
smart_open
dotwiz
omegaconf
tensorboard
huggingface-hub
einops
lpips
timm
scipy
scikit-learn
scikit-image
kornia
torchtyping
git+https://github.com/xwen99/torch-fidelity.git@master#egg=torch-fidelity
open_clip_torch
opencv-python-headless
torchmetrics
torchdiffeq
lmdb
triton==3.0.0
submitit_test.py
ADDED
@@ -0,0 +1,290 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# A script to run multinode training with submitit.
# --------------------------------------------------------

import argparse
import os.path as osp
import submitit
import itertools

from omegaconf import OmegaConf
from paintmind.engine.util import instantiate_from_config
from paintmind.utils.device_utils import configure_compute_backend


def parse_args():
    parser = argparse.ArgumentParser("Submitit for accelerator training")
    # Slurm configuration
    parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node")
    parser.add_argument("--nodes", default=1, type=int, help="Number of nodes to request")
    parser.add_argument("--timeout", default=7000, type=int, help="Duration of the job in minutes (default is about 5 days)")
    parser.add_argument("--qos", default="normal", type=str, help="QOS to request")
    parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.")
    parser.add_argument("--partition", default="h100-camera-train", type=str, help="Partition where to submit")
    parser.add_argument("--exclude", default="", type=str, help="Exclude nodes from the partition")
    parser.add_argument("--nodelist", default="", type=str, help="Nodelist to request")
    parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler")

    # Model and testing configuration
    parser.add_argument('--model', type=str, nargs='+', default=[None], help="Path to model(s)")
    parser.add_argument('--step', type=int, nargs='+', default=[250000], help="Step number(s)")
    parser.add_argument('--cfg', type=str, default=None, help="Path to config file")
    parser.add_argument('--dataset', type=str, default='imagenet', help="Dataset to use")

    # Legacy parameter (preserved for backward compatibility)
    parser.add_argument('--cfg_value', type=float, nargs='+', default=[None],
                        help='Legacy parameter for GPT classifier-free guidance scale')

    # CFG-related parameters - all with nargs='+' to support multiple values
    parser.add_argument('--ae_cfg', type=float, nargs='+', default=[None],
                        help="Autoencoder classifier-free guidance scale")
    parser.add_argument('--diff_cfg', type=float, nargs='+', default=[None],
                        help="Diffusion classifier-free guidance scale")
    parser.add_argument('--cfg_schedule', type=str, nargs='+', default=[None],
                        help="CFG schedule type (e.g., constant, linear)")
    parser.add_argument('--diff_cfg_schedule', type=str, nargs='+', default=[None],
                        help="Diffusion CFG schedule type (e.g., constant, inv_linear)")
    parser.add_argument('--test_num_slots', type=int, nargs='+', default=[None],
                        help="Number of slots to use for inference")
    parser.add_argument('--temperature', type=float, nargs='+', default=[None],
                        help="Temperature for sampling")

    return parser.parse_args()


def load_config(model_path, cfg_path=None):
    """Load configuration from file or model directory."""
    if cfg_path is not None and osp.exists(cfg_path):
        config_path = cfg_path
    elif model_path and osp.exists(osp.join(model_path, 'config.yaml')):
        config_path = osp.join(model_path, 'config.yaml')
    else:
        raise ValueError(f"No config file found at {model_path} or {cfg_path}")

    return OmegaConf.load(config_path)


def setup_checkpoint_path(model_path, step, config):
    """Set up the checkpoint path based on model and step."""
    if model_path:
        ckpt_path = osp.join(model_path, 'models', f'step{step}')
        if not osp.exists(ckpt_path):
            print(f"Skipping non-existent checkpoint: {ckpt_path}")
            return None
        if hasattr(config.trainer.params, 'model'):
            config.trainer.params.model.params.ckpt_path = ckpt_path
        else:
            config.trainer.params.gpt_model.params.ckpt_path = ckpt_path
    else:
        result_folder = config.trainer.params.result_folder
        ckpt_path = osp.join(result_folder, 'models', f'step{step}')
        if hasattr(config.trainer.params, 'model'):
            config.trainer.params.model.params.ckpt_path = ckpt_path
        else:
            config.trainer.params.gpt_model.params.ckpt_path = ckpt_path

    return ckpt_path


def setup_test_config(config, use_coco=False):
    """Set up common test configuration parameters."""
    config.trainer.params.test_dataset = config.trainer.params.dataset
    if not use_coco:
        config.trainer.params.test_dataset.params.split = 'val'
    else:
        config.trainer.params.test_dataset.target = 'paintmind.utils.datasets.COCO'
        config.trainer.params.test_dataset.params.root = './dataset/coco'
        config.trainer.params.test_dataset.params.split = 'val2017'
    config.trainer.params.test_only = True
    config.trainer.params.compile = False
    config.trainer.params.eval_fid = True
    config.trainer.params.fid_stats = 'fid_stats/adm_in256_stats.npz'
    if hasattr(config.trainer.params, 'model'):
        config.trainer.params.model.params.num_sampling_steps = '250'
    else:
        config.trainer.params.ae_model.params.num_sampling_steps = '250'


def apply_cfg_params(config, param_dict):
    """Apply CFG-related parameters to the config."""
    # Apply each parameter if it's not None
    if param_dict.get('cfg_value') is not None:
        config.trainer.params.cfg = param_dict['cfg_value']
        print(f"Setting cfg to {param_dict['cfg_value']}")

    if param_dict.get('ae_cfg') is not None:
        config.trainer.params.ae_cfg = param_dict['ae_cfg']
        print(f"Setting ae_cfg to {param_dict['ae_cfg']}")

    if param_dict.get('diff_cfg') is not None:
        config.trainer.params.diff_cfg = param_dict['diff_cfg']
        print(f"Setting diff_cfg to {param_dict['diff_cfg']}")

    if param_dict.get('cfg_schedule') is not None:
        config.trainer.params.cfg_schedule = param_dict['cfg_schedule']
        print(f"Setting cfg_schedule to {param_dict['cfg_schedule']}")

    if param_dict.get('diff_cfg_schedule') is not None:
        config.trainer.params.diff_cfg_schedule = param_dict['diff_cfg_schedule']
        print(f"Setting diff_cfg_schedule to {param_dict['diff_cfg_schedule']}")

    if param_dict.get('test_num_slots') is not None:
        config.trainer.params.test_num_slots = param_dict['test_num_slots']
        print(f"Setting test_num_slots to {param_dict['test_num_slots']}")

    if param_dict.get('temperature') is not None:
        config.trainer.params.temperature = param_dict['temperature']
        print(f"Setting temperature to {param_dict['temperature']}")


def run_test(config):
    """Instantiate trainer and run test."""
    trainer = instantiate_from_config(config.trainer)
    trainer.train()


def generate_param_combinations(args):
    """Generate all combinations of parameters from the provided arguments."""
    # Create parameter grid for all combinations
    param_grid = {
        'cfg_value': [None] if args.cfg_value == [None] else args.cfg_value,
        'ae_cfg': [None] if args.ae_cfg == [None] else args.ae_cfg,
        'diff_cfg': [None] if args.diff_cfg == [None] else args.diff_cfg,
        'cfg_schedule': [None] if args.cfg_schedule == [None] else args.cfg_schedule,
        'diff_cfg_schedule': [None] if args.diff_cfg_schedule == [None] else args.diff_cfg_schedule,
        'test_num_slots': [None] if args.test_num_slots == [None] else args.test_num_slots,
        'temperature': [None] if args.temperature == [None] else args.temperature
    }

    # Get all parameter names that have non-None values
    active_params = [k for k, v in param_grid.items() if v != [None]]

    if not active_params:
        # If no parameters are specified, yield a dict with all None values
        yield {k: None for k in param_grid.keys()}
        return

    # Generate all combinations of active parameters
    active_values = [param_grid[k] for k in active_params]
    for combination in itertools.product(*active_values):
        param_dict = {k: None for k in param_grid.keys()}  # Start with all None
        for i, param_name in enumerate(active_params):
            param_dict[param_name] = combination[i]
        yield param_dict


class Trainer(object):
    def __init__(self, args):
        self.args = args

    def __call__(self):
        """Main entry point for the submitit job."""
        self._setup_gpu_args()
        configure_compute_backend()
        self._run_tests()

    def _run_tests(self):
        """Run tests for all specified models and steps."""
        for step in self.args.step:
            for model in self.args.model:
                print(f"Testing model: {model} at step: {step}")

                # Load configuration
                config = load_config(model, self.args.cfg)

                # Setup checkpoint path
                ckpt_path = setup_checkpoint_path(model, step, config)
                if ckpt_path is None:
                    continue

                use_coco = self.args.dataset == 'coco' or self.args.dataset == 'COCO'
                # Setup test configuration
                setup_test_config(config, use_coco)

                # Generate and apply all parameter combinations
                for param_dict in generate_param_combinations(self.args):
                    # Create a copy of the config for each parameter combination
                    current_config = OmegaConf.create(OmegaConf.to_container(config, resolve=True))

                    # Print parameter combination
                    param_str = ", ".join([f"{k}={v}" for k, v in param_dict.items() if v is not None])
                    print(f"Testing with parameters: {param_str}")

                    # Apply parameters and run test
                    apply_cfg_params(current_config, param_dict)
                    run_test(current_config)

    def _setup_gpu_args(self):
        """Set up GPU and distributed environment variables."""
        import submitit

        print("Exporting PyTorch distributed environment variables")
        dist_env = submitit.helpers.TorchDistributedEnvironment().export(set_cuda_visible_devices=False)
        print(f"Master: {dist_env.master_addr}:{dist_env.master_port}")
        print(f"Rank: {dist_env.rank}")
        print(f"World size: {dist_env.world_size}")
        print(f"Local rank: {dist_env.local_rank}")
        print(f"Local world size: {dist_env.local_world_size}")

        job_env = submitit.JobEnvironment()
        self.args.output_dir = str(self.args.output_dir).replace("%j", str(job_env.job_id))
        self.args.log_dir = self.args.output_dir
        print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")


def main():
    """Main function to set up and submit the job."""
    args = parse_args()

    # Determine job directory
    if args.cfg is not None and osp.exists(args.cfg):
        config = OmegaConf.load(args.cfg)
    elif osp.exists(osp.join(args.model[0], 'config.yaml')):
        config = OmegaConf.load(osp.join(args.model[0], 'config.yaml'))
    else:
        raise ValueError(f"No config file found at {args.model[0]} or {args.cfg}")

    args.job_dir = config.trainer.params.result_folder

    # Set up the executor
    executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30)

    # Configure slurm parameters
    slurm_kwargs = {
        'slurm_signal_delay_s': 120,
        'slurm_qos': args.qos
    }

    if args.comment:
        slurm_kwargs['slurm_comment'] = args.comment
    if args.exclude:
        slurm_kwargs['slurm_exclude'] = args.exclude
    if args.nodelist:
        slurm_kwargs['slurm_nodelist'] = args.nodelist

    # Update executor parameters
    executor.update_parameters(
        gpus_per_node=args.ngpus,
        tasks_per_node=args.ngpus,  # one task per GPU
        nodes=args.nodes,
        timeout_min=args.timeout,
        slurm_partition=args.partition,
        name="fid",
        **slurm_kwargs
    )

    args.output_dir = args.job_dir

    # Submit the job
    trainer = Trainer(args)
    job = executor.submit(trainer)

    print("Submitted job_id:", job.job_id)


if __name__ == "__main__":
    main()
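A sketch of how generate_param_combinations expands a sweep, assuming the module-level imports resolve (the values are illustrative):

# Sketch only: the cartesian sweep behind the multi-value CLI flags.
import argparse
from submitit_test import generate_param_combinations

args = argparse.Namespace(
    cfg_value=[None], ae_cfg=[1.5], diff_cfg=[1.0, 3.0],
    cfg_schedule=['linear'], diff_cfg_schedule=[None],
    test_num_slots=[None], temperature=[None])
for combo in generate_param_combinations(args):
    print({k: v for k, v in combo.items() if v is not None})
# -> {'ae_cfg': 1.5, 'diff_cfg': 1.0, 'cfg_schedule': 'linear'}
# -> {'ae_cfg': 1.5, 'diff_cfg': 3.0, 'cfg_schedule': 'linear'}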
submitit_train.py
ADDED
@@ -0,0 +1,148 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# A script to run multinode training with submitit.
# --------------------------------------------------------

import argparse
import os
import submitit

from omegaconf import OmegaConf
from paintmind.engine.util import instantiate_from_config
from paintmind.utils.device_utils import configure_compute_backend

def parse_args():
    parser = argparse.ArgumentParser("Submitit for accelerator training")
    parser.add_argument("--ngpus", default=8, type=int, help="Number of GPUs to request on each node")
    parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request")
    parser.add_argument("--timeout", default=7000, type=int, help="Duration of the job in minutes (default 7000, about 5 days)")
    parser.add_argument("--qos", default="normal", type=str, help="QOS to request")
    parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.")

    parser.add_argument("--partition", default="h100-camera-train", type=str, help="Partition where to submit")
    parser.add_argument("--exclude", default="", type=str, help="Exclude nodes from the partition")
    parser.add_argument("--nodelist", default="", type=str, help="Nodelist to request")
    parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler")
    parser.add_argument('--cfg', type=str, default='configs/dit_imagenet_400ep.yaml', help='accelerator configs')
    return parser.parse_args()


class Trainer(object):
    def __init__(self, args, config):
        self.args = args
        self.config = config

    def __call__(self):
        self._setup_gpu_args()
        configure_compute_backend()
        trainer = instantiate_from_config(self.config.trainer)
        trainer.train(self.config)

    def checkpoint(self):
        # Called by submitit on preemption or timeout: point ckpt_path at the
        # newest step folder, then requeue so training resumes from there.
        import os
        import submitit

        model_dir = os.path.join(self.args.output_dir, "models")
        if os.path.exists(model_dir):
            # Get all step folders
            step_folders = [d for d in os.listdir(model_dir) if d.startswith("step")]
            if step_folders:
                # Extract step numbers and find max
                steps = [int(f.replace("step", "")) for f in step_folders]
                max_step = max(steps)
                # Set ckpt path to the latest step folder
                self.config.trainer.params.model.params.ckpt_path = os.path.join(model_dir, f"step{max_step}")
        print("Requeuing ", self.args, self.config)
        empty_trainer = type(self)(self.args, self.config)
        return submitit.helpers.DelayedSubmission(empty_trainer)

    def _setup_gpu_args(self):
        import submitit

        print("exporting PyTorch distributed environment variables")
        dist_env = submitit.helpers.TorchDistributedEnvironment().export(set_cuda_visible_devices=False)
        print(f"master: {dist_env.master_addr}:{dist_env.master_port}")
        print(f"rank: {dist_env.rank}")
        print(f"world size: {dist_env.world_size}")
        print(f"local rank: {dist_env.local_rank}")
        print(f"local world size: {dist_env.local_world_size}")

        # os.environ["NCCL_DEBUG"] = "INFO"
        # Keep NCCL peer-to-peer and InfiniBand transports enabled ("0" means not disabled).
        os.environ["NCCL_P2P_DISABLE"] = "0"
        os.environ["NCCL_IB_DISABLE"] = "0"

        job_env = submitit.JobEnvironment()
        self.args.output_dir = str(self.args.output_dir).replace("%j", str(job_env.job_id))
        self.args.log_dir = self.args.output_dir
        self.config.trainer.params.result_folder = self.args.output_dir
        self.config.trainer.params.log_dir = os.path.join(self.args.output_dir, "logs")
        print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")


def main():
    args = parse_args()
    cfg_file = args.cfg
    assert os.path.exists(cfg_file)
    config = OmegaConf.load(cfg_file)

    if config.trainer.params.result_folder is None:
        if args.job_dir == "":
            args.job_dir = "./output/%j"

        config.trainer.params.result_folder = args.job_dir
        config.trainer.params.log_dir = os.path.join(args.job_dir, "logs")
    else:
        args.job_dir = config.trainer.params.result_folder

    # Note that the folder will depend on the job_id, to easily track experiments
    executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30)

    num_gpus_per_node = args.ngpus
    nodes = args.nodes
    timeout_min = args.timeout
    qos = args.qos

    partition = args.partition
    kwargs = {}
    if args.comment:
        kwargs['slurm_comment'] = args.comment
    if args.exclude:
        kwargs["slurm_exclude"] = args.exclude
    if args.nodelist:
        kwargs["slurm_nodelist"] = args.nodelist

    executor.update_parameters(
        mem_gb=40 * num_gpus_per_node,
        gpus_per_node=num_gpus_per_node,
        tasks_per_node=num_gpus_per_node,  # one task per GPU
        # cpus_per_task=16,
        nodes=nodes,
        timeout_min=timeout_min,  # max is 60 * 72
        # Below are cluster dependent parameters
        slurm_partition=partition,
        slurm_signal_delay_s=120,
        slurm_qos=qos,
        **kwargs
    )

    executor.update_parameters(name="sar")

    args.output_dir = args.job_dir

    trainer = Trainer(args, config)
    job = executor.submit(trainer)

    print("Submitted job_id:", job.job_id)


if __name__ == "__main__":
    main()
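A typical launch of this training launcher, using the script's own defaults and one of the configs added in this commit (the exact config choice is an assumption):

python submitit_train.py --cfg configs/autoregressive_config.yaml --nodes 2 --ngpus 8 --partition h100-camera-train

If the config leaves result_folder unset and no --job_dir is given, outputs go to ./output/%j; the %j placeholder is replaced with the SLURM job id at runtime, so each submission writes to its own result directory.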