Spaces:
Running
Running
upload
Browse files- README.md +1 -1
- configs/autoregressive_l.yaml +68 -0
- configs/autoregressive_xl.yaml +66 -0
- configs/onenode_config.yaml +11 -0
- configs/tokenizer_l.yaml +55 -0
- configs/tokenizer_xl.yaml +55 -0
- examples/city.jpg +0 -0
- examples/food.jpg +0 -0
- examples/highland.webp +0 -0
- gen_demo.py +262 -0
- requirements.txt +16 -0
- semanticist/engine/diffusion_trainer.py +488 -0
- semanticist/engine/gpt_trainer.py +694 -0
- semanticist/engine/trainer_utils.py +251 -0
- semanticist/stage1/diffuse_slot.py +452 -0
- semanticist/stage1/diffusion/__init__.py +46 -0
- semanticist/stage1/diffusion/diffusion_utils.py +88 -0
- semanticist/stage1/diffusion/gaussian_diffusion.py +886 -0
- semanticist/stage1/diffusion/respace.py +130 -0
- semanticist/stage1/diffusion/timestep_sampler.py +150 -0
- semanticist/stage1/diffusion_transfomer.py +372 -0
- semanticist/stage1/fused_attention.py +45 -0
- semanticist/stage1/pos_embed.py +102 -0
- semanticist/stage1/transport/__init__.py +63 -0
- semanticist/stage1/transport/integrators.py +130 -0
- semanticist/stage1/transport/path.py +192 -0
- semanticist/stage1/transport/transport.py +456 -0
- semanticist/stage1/transport/utils.py +29 -0
- semanticist/stage1/vision_transformer.py +259 -0
- semanticist/stage2/diffloss.py +267 -0
- semanticist/stage2/generate.py +88 -0
- semanticist/stage2/gpt.py +431 -0
- semanticist/utils/datasets.py +72 -0
- semanticist/utils/device_utils.py +18 -0
- semanticist/utils/logger.py +170 -0
- semanticist/utils/lr_scheduler.py +15 -0
- semanticist/utils/transform.py +35 -0
README.md
CHANGED
@@ -5,7 +5,7 @@ colorFrom: purple
|
|
5 |
colorTo: blue
|
6 |
sdk: gradio
|
7 |
sdk_version: 5.20.1
|
8 |
-
app_file:
|
9 |
pinned: false
|
10 |
license: mit
|
11 |
---
|
|
|
5 |
colorTo: blue
|
6 |
sdk: gradio
|
7 |
sdk_version: 5.20.1
|
8 |
+
app_file: gen_demo.py
|
9 |
pinned: false
|
10 |
license: mit
|
11 |
---
|
configs/autoregressive_l.yaml
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
trainer:
|
2 |
+
target: semanticist.engine.gpt_trainer.GPTTrainer
|
3 |
+
params:
|
4 |
+
num_epoch: 400
|
5 |
+
blr: 1e-4
|
6 |
+
cosine_lr: False
|
7 |
+
warmup_epochs: 100
|
8 |
+
batch_size: 16
|
9 |
+
num_workers: 8
|
10 |
+
pin_memory: True
|
11 |
+
grad_accum_steps: 1
|
12 |
+
precision: 'bf16'
|
13 |
+
max_grad_norm: 1.0
|
14 |
+
enable_ema: True
|
15 |
+
save_every: 10000
|
16 |
+
sample_every: 5000
|
17 |
+
fid_every: 50000
|
18 |
+
eval_fid: False
|
19 |
+
result_folder: "./output/autoregressive"
|
20 |
+
log_dir: "./output/autoregressive/logs"
|
21 |
+
ae_cfg: 1.0
|
22 |
+
cfg: 6.0
|
23 |
+
cfg_schedule: "linear"
|
24 |
+
train_num_slots: 32
|
25 |
+
test_num_slots: 32
|
26 |
+
compile: True
|
27 |
+
enable_cache_latents: False
|
28 |
+
ae_model:
|
29 |
+
target: semanticist.stage1.diffuse_slot.DiffuseSlot
|
30 |
+
params:
|
31 |
+
encoder: 'vit_base_patch16'
|
32 |
+
enc_img_size: 256
|
33 |
+
enc_causal: True
|
34 |
+
num_slots: 256
|
35 |
+
slot_dim: 16
|
36 |
+
norm_slots: True
|
37 |
+
cond_method: 'token'
|
38 |
+
dit_model: 'DiT-L-2'
|
39 |
+
vae: 'xwen99/mar-vae-kl16'
|
40 |
+
num_sampling_steps: '250'
|
41 |
+
# ckpt_path: ./output/tokenizer/models_l/step250000/custom_checkpoint_1.pkl
|
42 |
+
ckpt_path: /mnt/ceph_rbd/mnt_pvc_vid_data/zbc/cache/semanticist_tok_L.pkl
|
43 |
+
|
44 |
+
gpt_model:
|
45 |
+
target: GPT-L
|
46 |
+
params:
|
47 |
+
num_slots: 32
|
48 |
+
slot_dim: 16
|
49 |
+
num_classes: 1000
|
50 |
+
cls_token_num: 1
|
51 |
+
resid_dropout_p: 0.1
|
52 |
+
ffn_dropout_p: 0.1
|
53 |
+
diffloss_d: 12
|
54 |
+
diffloss_w: 1536
|
55 |
+
num_sampling_steps: '100'
|
56 |
+
diffusion_batch_mul: 4
|
57 |
+
use_si: True
|
58 |
+
cond_method: 'concat'
|
59 |
+
ckpt_path: None
|
60 |
+
|
61 |
+
dataset:
|
62 |
+
target: semanticist.utils.datasets.ImageNet
|
63 |
+
params:
|
64 |
+
root: ./dataset/imagenet/
|
65 |
+
split: train
|
66 |
+
# aug: tencrop_cached # or centercrop_cached
|
67 |
+
aug: randcrop
|
68 |
+
img_size: 256
|
configs/autoregressive_xl.yaml
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
trainer:
|
2 |
+
target: semanticist.engine.gpt_trainer.GPTTrainer
|
3 |
+
params:
|
4 |
+
num_epoch: 400
|
5 |
+
blr: 1e-4
|
6 |
+
cosine_lr: False
|
7 |
+
warmup_epochs: 100
|
8 |
+
batch_size: 256
|
9 |
+
num_workers: 8
|
10 |
+
pin_memory: True
|
11 |
+
grad_accum_steps: 1
|
12 |
+
precision: 'bf16'
|
13 |
+
max_grad_norm: 1.0
|
14 |
+
enable_ema: True
|
15 |
+
save_every: 10000
|
16 |
+
sample_every: 5000
|
17 |
+
fid_every: 50000
|
18 |
+
eval_fid: False
|
19 |
+
result_folder: "./output/autoregressive"
|
20 |
+
log_dir: "./output/autoregressive/logs"
|
21 |
+
ae_cfg: 1.0
|
22 |
+
cfg: 5.0
|
23 |
+
cfg_schedule: "linear"
|
24 |
+
train_num_slots: 32
|
25 |
+
test_num_slots: 32
|
26 |
+
compile: True
|
27 |
+
enable_cache_latents: True
|
28 |
+
ae_model:
|
29 |
+
target: semanticist.stage1.diffuse_slot.DiffuseSlot
|
30 |
+
params:
|
31 |
+
encoder: 'vit_base_patch16'
|
32 |
+
enc_img_size: 256
|
33 |
+
enc_causal: True
|
34 |
+
num_slots: 256
|
35 |
+
slot_dim: 16
|
36 |
+
norm_slots: True
|
37 |
+
cond_method: 'token'
|
38 |
+
dit_model: 'DiT-XL-2'
|
39 |
+
vae: 'xwen99/mar-vae-kl16'
|
40 |
+
num_sampling_steps: '250'
|
41 |
+
ckpt_path: ./output/tokenizer/models_xl/step250000/custom_checkpoint_1.pkl
|
42 |
+
|
43 |
+
gpt_model:
|
44 |
+
target: GPT-L
|
45 |
+
params:
|
46 |
+
num_slots: 32
|
47 |
+
slot_dim: 16
|
48 |
+
num_classes: 1000
|
49 |
+
cls_token_num: 1
|
50 |
+
resid_dropout_p: 0.1
|
51 |
+
ffn_dropout_p: 0.1
|
52 |
+
diffloss_d: 12
|
53 |
+
diffloss_w: 1536
|
54 |
+
num_sampling_steps: '100'
|
55 |
+
diffusion_batch_mul: 4
|
56 |
+
use_si: True
|
57 |
+
cond_method: 'concat'
|
58 |
+
ckpt_path: None
|
59 |
+
|
60 |
+
dataset:
|
61 |
+
target: semanticist.utils.datasets.ImageNet
|
62 |
+
params:
|
63 |
+
root: ./dataset/imagenet/
|
64 |
+
split: train
|
65 |
+
aug: tencrop_cached # or centercrop_cached
|
66 |
+
img_size: 256
|
configs/onenode_config.yaml
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
compute_environment: LOCAL_MACHINE
|
2 |
+
deepspeed_config: {}
|
3 |
+
distributed_type: MULTI_GPU
|
4 |
+
fsdp_config: {}
|
5 |
+
machine_rank: 0
|
6 |
+
main_process_ip: null
|
7 |
+
main_process_port: null
|
8 |
+
main_training_function: main
|
9 |
+
num_machines: 1
|
10 |
+
num_processes: 1
|
11 |
+
use_cpu: false
|
configs/tokenizer_l.yaml
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
trainer:
|
2 |
+
target: semanticist.engine.diffusion_trainer.DiffusionTrainer
|
3 |
+
params:
|
4 |
+
num_epoch: 400
|
5 |
+
valid_size: 64
|
6 |
+
blr: 2.5e-5
|
7 |
+
cosine_lr: True
|
8 |
+
warmup_epochs: 1
|
9 |
+
batch_size: 64
|
10 |
+
num_workers: 16
|
11 |
+
pin_memory: True
|
12 |
+
grad_accum_steps: 1
|
13 |
+
precision: 'bf16'
|
14 |
+
max_grad_norm: 3.0
|
15 |
+
enable_ema: True
|
16 |
+
save_every: 10000
|
17 |
+
sample_every: 5000
|
18 |
+
fid_every: 50000
|
19 |
+
result_folder: "./output/tokenizer/models_l"
|
20 |
+
log_dir: "./output/tokenizer/models_l/logs"
|
21 |
+
cfg: 3.0
|
22 |
+
compile: True
|
23 |
+
model:
|
24 |
+
target: semanticist.stage1.diffuse_slot.DiffuseSlot
|
25 |
+
params:
|
26 |
+
encoder: 'vit_base_patch16'
|
27 |
+
enc_img_size: 256
|
28 |
+
enc_causal: True
|
29 |
+
enc_use_mlp: False
|
30 |
+
num_slots: 256
|
31 |
+
slot_dim: 16
|
32 |
+
norm_slots: True
|
33 |
+
dit_model: 'DiT-L-2'
|
34 |
+
vae: 'xwen99/mar-vae-kl16'
|
35 |
+
enable_nest: False
|
36 |
+
enable_nest_after: 50
|
37 |
+
use_repa: True
|
38 |
+
eval_fid: True
|
39 |
+
fid_stats: 'fid_stats/adm_in256_stats.npz'
|
40 |
+
num_sampling_steps: '250'
|
41 |
+
ckpt_path: None
|
42 |
+
|
43 |
+
dataset:
|
44 |
+
target: semanticist.utils.datasets.ImageNet
|
45 |
+
params:
|
46 |
+
root: ./dataset/imagenet/
|
47 |
+
split: train
|
48 |
+
img_size: 256
|
49 |
+
|
50 |
+
test_dataset:
|
51 |
+
target: semanticist.utils.datasets.ImageNet
|
52 |
+
params:
|
53 |
+
root: ./dataset/imagenet/
|
54 |
+
split: val
|
55 |
+
img_size: 256
|
configs/tokenizer_xl.yaml
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
trainer:
|
2 |
+
target: semanticist.engine.diffusion_trainer.DiffusionTrainer
|
3 |
+
params:
|
4 |
+
num_epoch: 400
|
5 |
+
valid_size: 64
|
6 |
+
blr: 2.5e-5
|
7 |
+
cosine_lr: True
|
8 |
+
warmup_epochs: 100
|
9 |
+
batch_size: 256
|
10 |
+
num_workers: 16
|
11 |
+
pin_memory: True
|
12 |
+
grad_accum_steps: 1
|
13 |
+
precision: 'bf16'
|
14 |
+
max_grad_norm: 3.0
|
15 |
+
enable_ema: True
|
16 |
+
save_every: 10000
|
17 |
+
sample_every: 5000
|
18 |
+
fid_every: 50000
|
19 |
+
result_folder: "./output/tokenizer/models_xl"
|
20 |
+
log_dir: "./output/tokenizer/models_xl/logs"
|
21 |
+
cfg: 3.0
|
22 |
+
compile: True
|
23 |
+
model:
|
24 |
+
target: semanticist.stage1.diffuse_slot.DiffuseSlot
|
25 |
+
params:
|
26 |
+
encoder: 'vit_base_patch16'
|
27 |
+
enc_img_size: 256
|
28 |
+
enc_causal: True
|
29 |
+
enc_use_mlp: False
|
30 |
+
num_slots: 256
|
31 |
+
slot_dim: 16
|
32 |
+
norm_slots: True
|
33 |
+
dit_model: 'DiT-XL-2'
|
34 |
+
vae: 'xwen99/mar-vae-kl16'
|
35 |
+
enable_nest: False
|
36 |
+
enable_nest_after: 50
|
37 |
+
use_repa: True
|
38 |
+
eval_fid: True
|
39 |
+
fid_stats: 'fid_stats/adm_in256_stats.npz'
|
40 |
+
num_sampling_steps: '250'
|
41 |
+
ckpt_path: None
|
42 |
+
|
43 |
+
dataset:
|
44 |
+
target: semanticist.utils.datasets.ImageNet
|
45 |
+
params:
|
46 |
+
root: ./dataset/imagenet/
|
47 |
+
split: train
|
48 |
+
img_size: 256
|
49 |
+
|
50 |
+
test_dataset:
|
51 |
+
target: semanticist.utils.datasets.ImageNet
|
52 |
+
params:
|
53 |
+
root: ./dataset/imagenet/
|
54 |
+
split: val
|
55 |
+
img_size: 256
|
examples/city.jpg
ADDED
![]() |
examples/food.jpg
ADDED
![]() |
examples/highland.webp
ADDED
![]() |
gen_demo.py
ADDED
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image
|
4 |
+
import os.path as osp
|
5 |
+
import torch
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
from omegaconf import OmegaConf
|
8 |
+
from tqdm import tqdm
|
9 |
+
from huggingface_hub import hf_hub_download
|
10 |
+
from semanticist.engine.trainer_utils import instantiate_from_config
|
11 |
+
from semanticist.stage1.diffuse_slot import DiffuseSlot
|
12 |
+
from semanticist.stage2.gpt import GPT_models
|
13 |
+
from semanticist.stage2.generate import generate
|
14 |
+
from safetensors import safe_open
|
15 |
+
from semanticist.utils.datasets import vae_transforms
|
16 |
+
from PIL import Image
|
17 |
+
from imagenet_classes import imagenet_classes
|
18 |
+
|
19 |
+
transform = vae_transforms('test')
|
20 |
+
|
21 |
+
|
22 |
+
def norm_ip(img, low, high):
|
23 |
+
img.clamp_(min=low, max=high)
|
24 |
+
img.sub_(low).div_(max(high - low, 1e-5))
|
25 |
+
|
26 |
+
def norm_range(t, value_range):
|
27 |
+
if value_range is not None:
|
28 |
+
norm_ip(t, value_range[0], value_range[1])
|
29 |
+
else:
|
30 |
+
norm_ip(t, float(t.min()), float(t.max()))
|
31 |
+
|
32 |
+
from PIL import Image
|
33 |
+
def convert_np(img):
|
34 |
+
ndarr = img.mul(255).add_(0.5).clamp_(0, 255)\
|
35 |
+
.permute(1, 2, 0).to("cpu", torch.uint8).numpy()
|
36 |
+
return ndarr
|
37 |
+
def convert_PIL(img):
|
38 |
+
ndarr = img.mul(255).add_(0.5).clamp_(0, 255)\
|
39 |
+
.permute(1, 2, 0).to("cpu", torch.uint8).numpy()
|
40 |
+
img = Image.fromarray(ndarr)
|
41 |
+
return img
|
42 |
+
|
43 |
+
def norm_slots(slots):
|
44 |
+
mean = torch.mean(slots, dim=-1, keepdim=True)
|
45 |
+
std = torch.std(slots, dim=-1, keepdim=True)
|
46 |
+
return (slots - mean) / std
|
47 |
+
|
48 |
+
def load_state_dict(state_dict, model):
|
49 |
+
"""Helper to load a state dict with proper prefix handling."""
|
50 |
+
if 'state_dict' in state_dict:
|
51 |
+
state_dict = state_dict['state_dict']
|
52 |
+
# Remove '_orig_mod' prefix if present
|
53 |
+
state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
|
54 |
+
missing, unexpected = model.load_state_dict(
|
55 |
+
state_dict, strict=False
|
56 |
+
)
|
57 |
+
# print(f"Loaded model. Missing: {missing}, Unexpected: {unexpected}")
|
58 |
+
|
59 |
+
def load_safetensors(path, model):
|
60 |
+
"""Helper to load a safetensors checkpoint."""
|
61 |
+
from safetensors.torch import safe_open
|
62 |
+
with safe_open(path, framework="pt", device="cpu") as f:
|
63 |
+
state_dict = {k: f.get_tensor(k) for k in f.keys()}
|
64 |
+
load_state_dict(state_dict, model)
|
65 |
+
|
66 |
+
def load_checkpoint(ckpt_path, model):
|
67 |
+
if ckpt_path is None or not osp.exists(ckpt_path):
|
68 |
+
return
|
69 |
+
|
70 |
+
if osp.isdir(ckpt_path):
|
71 |
+
# ckpt_path is something like 'path/to/models/step10/'
|
72 |
+
model_path = osp.join(ckpt_path, "model.safetensors")
|
73 |
+
if osp.exists(model_path):
|
74 |
+
load_safetensors(model_path, model)
|
75 |
+
else:
|
76 |
+
# ckpt_path is something like 'path/to/models/step10.pt'
|
77 |
+
if ckpt_path.endswith(".safetensors"):
|
78 |
+
load_safetensors(ckpt_path, model)
|
79 |
+
else:
|
80 |
+
state_dict = torch.load(ckpt_path, map_location="cpu")
|
81 |
+
load_state_dict(state_dict, model)
|
82 |
+
|
83 |
+
print(f"Loaded checkpoint from {ckpt_path}")
|
84 |
+
|
85 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
86 |
+
print(f"Is CUDA available: {torch.cuda.is_available()}")
|
87 |
+
if device == 'cuda':
|
88 |
+
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
|
89 |
+
|
90 |
+
ckpt_path = hf_hub_download(repo_id='tennant/semanticist', filename="semanticist_ar_gen_L.pkl", cache_dir='/mnt/ceph_rbd/mnt_pvc_vid_data/zbc/cache/')
|
91 |
+
config_path = 'configs/autoregressive_xl.yaml'
|
92 |
+
|
93 |
+
cfg = OmegaConf.load(config_path)
|
94 |
+
params = cfg.trainer.params
|
95 |
+
|
96 |
+
ae_model = instantiate_from_config(params.ae_model).to(device)
|
97 |
+
ae_model_path = hf_hub_download(repo_id='tennant/semanticist', filename="semanticist_tok_XL.pkl", cache_dir='/mnt/ceph_rbd/mnt_pvc_vid_data/zbc/cache/')
|
98 |
+
load_checkpoint(ae_model_path, ae_model)
|
99 |
+
ae_model.eval()
|
100 |
+
|
101 |
+
gpt_model = GPT_models[params.gpt_model.target](**params.gpt_model.params).to(device)
|
102 |
+
load_checkpoint(ckpt_path, gpt_model)
|
103 |
+
gpt_model.eval();
|
104 |
+
|
105 |
+
def viz_diff_slots(model, slots, nums, cfg=1.0, return_figs=False):
|
106 |
+
n_slots_inf = []
|
107 |
+
for num_slots_to_inference in nums:
|
108 |
+
drop_mask = model.nested_sampler(slots.shape[0], device, num_slots_to_inference)
|
109 |
+
recon_n = model.sample(slots, drop_mask=drop_mask, cfg=cfg)
|
110 |
+
n_slots_inf.append(recon_n)
|
111 |
+
return [convert_np(n_slots_inf[i][0]) for i in range(len(n_slots_inf))]
|
112 |
+
|
113 |
+
num_slots = params.ae_model.params.num_slots
|
114 |
+
slot_dim = params.ae_model.params.slot_dim
|
115 |
+
dtype = torch.bfloat16
|
116 |
+
# the model is trained with only 32 tokens.
|
117 |
+
num_slots_to_gen = 32
|
118 |
+
|
119 |
+
# Function to generate image from class
|
120 |
+
def generate_from_class(class_id, cfg_scale):
|
121 |
+
with torch.no_grad():
|
122 |
+
dtype = torch.bfloat16
|
123 |
+
num_slots_to_gen = 32
|
124 |
+
with torch.autocast(device, dtype=dtype):
|
125 |
+
slots_gen = generate(
|
126 |
+
gpt_model,
|
127 |
+
torch.tensor([class_id]).to(device),
|
128 |
+
num_slots_to_gen,
|
129 |
+
cfg_scale=cfg_scale,
|
130 |
+
cfg_schedule="linear"
|
131 |
+
)
|
132 |
+
if num_slots_to_gen < num_slots:
|
133 |
+
null_slots = ae_model.dit.null_cond.expand(slots_gen.shape[0], -1, -1)
|
134 |
+
null_slots = null_slots[:, num_slots_to_gen:, :]
|
135 |
+
slots_gen = torch.cat([slots_gen, null_slots], dim=1)
|
136 |
+
return slots_gen
|
137 |
+
|
138 |
+
with gr.Blocks() as demo:
|
139 |
+
with gr.Row():
|
140 |
+
# First column - Input and configs
|
141 |
+
with gr.Column(scale=1):
|
142 |
+
gr.Markdown("## Input")
|
143 |
+
|
144 |
+
# Replace image input with ImageNet class selection
|
145 |
+
imagenet_classes = {k: v for k, v in enumerate(imagenet_classes)}
|
146 |
+
class_choices = [f"{id}: {name}" for id, name in imagenet_classes.items()]
|
147 |
+
|
148 |
+
# Dropdown for class selection
|
149 |
+
class_dropdown = gr.Dropdown(
|
150 |
+
choices=class_choices[:20], # Limit for demonstration
|
151 |
+
label="Select ImageNet Class",
|
152 |
+
value=class_choices[0] if class_choices else None
|
153 |
+
)
|
154 |
+
|
155 |
+
# Option to enter class ID directly
|
156 |
+
class_id_input = gr.Number(
|
157 |
+
label="Or enter class ID directly (0-999)",
|
158 |
+
value=0,
|
159 |
+
minimum=0,
|
160 |
+
maximum=999,
|
161 |
+
step=1
|
162 |
+
)
|
163 |
+
|
164 |
+
with gr.Group():
|
165 |
+
gr.Markdown("### Configuration")
|
166 |
+
show_gallery = gr.Checkbox(label="Show Gallery", value=True)
|
167 |
+
slider = gr.Slider(minimum=0.1, maximum=20.0, value=4.0, label="CFG value")
|
168 |
+
labels_input = gr.Textbox(
|
169 |
+
label="Number of tokens to reconstruct (comma-separated)",
|
170 |
+
value="1, 2, 4, 8, 16",
|
171 |
+
placeholder="Enter comma-separated numbers for the number of slots to use"
|
172 |
+
)
|
173 |
+
|
174 |
+
# Second column - Output (conditionally rendered)
|
175 |
+
with gr.Column(scale=1):
|
176 |
+
gr.Markdown("## Output")
|
177 |
+
|
178 |
+
# Container for conditional rendering
|
179 |
+
with gr.Group(visible=True) as gallery_container:
|
180 |
+
gallery = gr.Gallery(label="Result Gallery", columns=3, height="auto", show_label=True)
|
181 |
+
|
182 |
+
# Always visible output image
|
183 |
+
output_image = gr.Image(label="Generated Image", type="numpy")
|
184 |
+
|
185 |
+
# Handle form submission
|
186 |
+
submit_btn = gr.Button("Generate")
|
187 |
+
|
188 |
+
# Define the processing logic
|
189 |
+
def update_outputs(class_selection, class_id, show_gallery_value, slider_value, labels_text):
|
190 |
+
# Determine which class to use - either from dropdown or direct input
|
191 |
+
if class_selection:
|
192 |
+
# Extract class ID from the dropdown selection
|
193 |
+
selected_class_id = int(class_selection.split(":")[0])
|
194 |
+
else:
|
195 |
+
selected_class_id = int(class_id)
|
196 |
+
|
197 |
+
# Update the visibility of the gallery container
|
198 |
+
gallery_container.visible = show_gallery_value
|
199 |
+
|
200 |
+
try:
|
201 |
+
# Parse the labels from the text input
|
202 |
+
if labels_text and "," in labels_text:
|
203 |
+
labels = [int(label.strip()) for label in labels_text.split(",")]
|
204 |
+
else:
|
205 |
+
# Default labels if none provided or in wrong format
|
206 |
+
labels = [1, 4, 16, 64, 256]
|
207 |
+
except:
|
208 |
+
labels = [1, 4, 16, 64, 256]
|
209 |
+
|
210 |
+
while len(labels) < 3:
|
211 |
+
labels.append(256)
|
212 |
+
|
213 |
+
# Generate the image based on the selected class
|
214 |
+
slots_gen = generate_from_class(selected_class_id, cfg_scale=slider_value)
|
215 |
+
|
216 |
+
recon = viz_diff_slots(ae_model, slots_gen, [32], cfg=slider_value)[0]
|
217 |
+
|
218 |
+
# Always generate the model decomposition for potential gallery display
|
219 |
+
model_decompose = viz_diff_slots(ae_model, slots_gen, labels, cfg=slider_value)
|
220 |
+
|
221 |
+
if not show_gallery_value:
|
222 |
+
# If only the image should be shown, return just the processed image
|
223 |
+
return gallery_container, [], recon
|
224 |
+
else:
|
225 |
+
# Create image variations and pair them with labels
|
226 |
+
gallery_images = [
|
227 |
+
(recon, f'Generated from class {selected_class_id}'),
|
228 |
+
] + [(img, 'Gen. with ' + str(label) + ' tokens') for img, label in zip(model_decompose, labels)]
|
229 |
+
return gallery_container, gallery_images, recon
|
230 |
+
|
231 |
+
# Connect the inputs and outputs
|
232 |
+
submit_btn.click(
|
233 |
+
fn=update_outputs,
|
234 |
+
inputs=[class_dropdown, class_id_input, show_gallery, slider, labels_input],
|
235 |
+
outputs=[gallery_container, gallery, output_image]
|
236 |
+
)
|
237 |
+
|
238 |
+
# Also update when checkbox changes
|
239 |
+
show_gallery.change(
|
240 |
+
fn=lambda value: gr.update(visible=value),
|
241 |
+
inputs=[show_gallery],
|
242 |
+
outputs=[gallery_container]
|
243 |
+
)
|
244 |
+
|
245 |
+
# Add examples
|
246 |
+
examples = [
|
247 |
+
# ["0: tench, Tinca tinca", 0, True, 4.0, "1,2,4,8,16"],
|
248 |
+
["1: goldfish, Carassius auratus", 1, True, 4.0, "1,2,4,8,16"],
|
249 |
+
# ["2: great white shark, white shark", 2, True, 4.0, "1,2,4,8,16"],
|
250 |
+
]
|
251 |
+
|
252 |
+
gr.Examples(
|
253 |
+
examples=examples,
|
254 |
+
inputs=[class_dropdown, class_id_input, show_gallery, slider, labels_input],
|
255 |
+
outputs=[gallery_container, gallery, output_image],
|
256 |
+
fn=update_outputs,
|
257 |
+
cache_examples=False
|
258 |
+
)
|
259 |
+
|
260 |
+
# Launch the demo
|
261 |
+
if __name__ == "__main__":
|
262 |
+
demo.launch()
|
requirements.txt
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
numpy==1.26.4
|
2 |
+
accelerate
|
3 |
+
diffusers[torch]
|
4 |
+
transformers
|
5 |
+
safetensors
|
6 |
+
omegaconf
|
7 |
+
tensorboard
|
8 |
+
huggingface-hub
|
9 |
+
einops
|
10 |
+
timm
|
11 |
+
scipy
|
12 |
+
scikit-learn
|
13 |
+
scikit-image
|
14 |
+
git+https://github.com/LTH14/torch-fidelity.git@master#egg=torch-fidelity
|
15 |
+
opencv-python-headless
|
16 |
+
torchmetrics
|
semanticist/engine/diffusion_trainer.py
ADDED
@@ -0,0 +1,488 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, torch
|
2 |
+
import os.path as osp
|
3 |
+
import shutil
|
4 |
+
from tqdm.auto import tqdm
|
5 |
+
from einops import rearrange
|
6 |
+
from accelerate import Accelerator
|
7 |
+
from torchvision.utils import make_grid, save_image
|
8 |
+
from torch.utils.data import DataLoader, random_split, DistributedSampler
|
9 |
+
from semanticist.utils.logger import SmoothedValue, MetricLogger, empty_cache
|
10 |
+
from accelerate.utils import DistributedDataParallelKwargs
|
11 |
+
from torchmetrics.functional.image import (
|
12 |
+
peak_signal_noise_ratio as psnr,
|
13 |
+
structural_similarity_index_measure as ssim
|
14 |
+
)
|
15 |
+
from semanticist.engine.trainer_utils import (
|
16 |
+
instantiate_from_config, concat_all_gather,
|
17 |
+
save_img_batch, get_fid_stats,
|
18 |
+
EMAModel, PaddedDataset, create_scheduler, load_state_dict,
|
19 |
+
load_safetensors, setup_result_folders, create_optimizer
|
20 |
+
)
|
21 |
+
|
22 |
+
class DiffusionTrainer:
|
23 |
+
def __init__(
|
24 |
+
self,
|
25 |
+
model,
|
26 |
+
dataset,
|
27 |
+
test_dataset=None,
|
28 |
+
test_only=False,
|
29 |
+
num_epoch=400,
|
30 |
+
valid_size=32,
|
31 |
+
blr=1e-4,
|
32 |
+
cosine_lr=True,
|
33 |
+
lr_min=0,
|
34 |
+
warmup_epochs=100,
|
35 |
+
warmup_steps=None,
|
36 |
+
warmup_lr_init=0,
|
37 |
+
decay_steps=None,
|
38 |
+
batch_size=32,
|
39 |
+
eval_bs=32,
|
40 |
+
test_bs=64,
|
41 |
+
num_workers=8,
|
42 |
+
pin_memory=False,
|
43 |
+
max_grad_norm=None,
|
44 |
+
grad_accum_steps=1,
|
45 |
+
precision='bf16',
|
46 |
+
save_every=10000,
|
47 |
+
sample_every=1000,
|
48 |
+
fid_every=50000,
|
49 |
+
result_folder=None,
|
50 |
+
log_dir="./log",
|
51 |
+
cfg=3.0,
|
52 |
+
test_num_slots=None,
|
53 |
+
eval_fid=False,
|
54 |
+
fid_stats=None,
|
55 |
+
enable_ema=False,
|
56 |
+
compile=False,
|
57 |
+
):
|
58 |
+
kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
|
59 |
+
self.accelerator = Accelerator(
|
60 |
+
kwargs_handlers=[kwargs],
|
61 |
+
mixed_precision=precision,
|
62 |
+
gradient_accumulation_steps=grad_accum_steps,
|
63 |
+
log_with="tensorboard",
|
64 |
+
project_dir=log_dir,
|
65 |
+
)
|
66 |
+
|
67 |
+
self.model = instantiate_from_config(model)
|
68 |
+
self.num_slots = model.params.num_slots
|
69 |
+
|
70 |
+
if test_dataset is not None:
|
71 |
+
test_dataset = instantiate_from_config(test_dataset)
|
72 |
+
self.test_ds = test_dataset
|
73 |
+
|
74 |
+
# Calculate padded dataset size to ensure even distribution
|
75 |
+
total_size = len(test_dataset)
|
76 |
+
world_size = self.accelerator.num_processes
|
77 |
+
padding_size = world_size * test_bs - (total_size % (world_size * test_bs))
|
78 |
+
self.test_dataset_size = total_size
|
79 |
+
|
80 |
+
self.test_ds = PaddedDataset(self.test_ds, padding_size)
|
81 |
+
self.test_dl = DataLoader(
|
82 |
+
self.test_ds,
|
83 |
+
batch_size=test_bs,
|
84 |
+
num_workers=num_workers,
|
85 |
+
pin_memory=pin_memory,
|
86 |
+
shuffle=False,
|
87 |
+
drop_last=True,
|
88 |
+
)
|
89 |
+
if self.accelerator.is_main_process:
|
90 |
+
print(f"test dataset size: {len(test_dataset)}, test batch size: {test_bs}")
|
91 |
+
else:
|
92 |
+
self.test_dl = None
|
93 |
+
self.test_only = test_only
|
94 |
+
|
95 |
+
if not test_only:
|
96 |
+
dataset = instantiate_from_config(dataset)
|
97 |
+
train_size = len(dataset) - valid_size
|
98 |
+
self.train_ds, self.valid_ds = random_split(
|
99 |
+
dataset,
|
100 |
+
[train_size, valid_size],
|
101 |
+
generator=torch.Generator().manual_seed(42),
|
102 |
+
)
|
103 |
+
if self.accelerator.is_main_process:
|
104 |
+
print(f"train dataset size: {train_size}, valid dataset size: {valid_size}")
|
105 |
+
|
106 |
+
sampler = DistributedSampler(
|
107 |
+
self.train_ds,
|
108 |
+
num_replicas=self.accelerator.num_processes,
|
109 |
+
rank=self.accelerator.process_index,
|
110 |
+
shuffle=True,
|
111 |
+
)
|
112 |
+
self.train_dl = DataLoader(
|
113 |
+
self.train_ds,
|
114 |
+
batch_size=batch_size,
|
115 |
+
sampler=sampler,
|
116 |
+
num_workers=num_workers,
|
117 |
+
pin_memory=pin_memory,
|
118 |
+
drop_last=True,
|
119 |
+
)
|
120 |
+
self.valid_dl = DataLoader(
|
121 |
+
self.valid_ds,
|
122 |
+
batch_size=eval_bs,
|
123 |
+
shuffle=False,
|
124 |
+
num_workers=num_workers,
|
125 |
+
pin_memory=pin_memory,
|
126 |
+
drop_last=False,
|
127 |
+
)
|
128 |
+
|
129 |
+
effective_bs = batch_size * grad_accum_steps * self.accelerator.num_processes
|
130 |
+
lr = blr * effective_bs / 256
|
131 |
+
if self.accelerator.is_main_process:
|
132 |
+
print(f"Effective batch size is {effective_bs}")
|
133 |
+
|
134 |
+
self.g_optim = create_optimizer(self.model, weight_decay=0.05, learning_rate=lr,) # accelerator=self.accelerator)
|
135 |
+
|
136 |
+
if warmup_epochs is not None:
|
137 |
+
warmup_steps = warmup_epochs * len(self.train_dl)
|
138 |
+
|
139 |
+
self.g_sched = create_scheduler(
|
140 |
+
self.g_optim,
|
141 |
+
num_epoch,
|
142 |
+
len(self.train_dl),
|
143 |
+
lr_min,
|
144 |
+
warmup_steps,
|
145 |
+
warmup_lr_init,
|
146 |
+
decay_steps,
|
147 |
+
cosine_lr
|
148 |
+
)
|
149 |
+
self.accelerator.register_for_checkpointing(self.g_sched)
|
150 |
+
self.model, self.g_optim, self.g_sched = self.accelerator.prepare(self.model, self.g_optim, self.g_sched)
|
151 |
+
else:
|
152 |
+
self.model, self.test_dl = self.accelerator.prepare(self.model, self.test_dl)
|
153 |
+
|
154 |
+
self.steps = 0
|
155 |
+
self.loaded_steps = -1
|
156 |
+
|
157 |
+
if compile:
|
158 |
+
_model = self.accelerator.unwrap_model(self.model)
|
159 |
+
_model.vae = torch.compile(_model.vae, mode="reduce-overhead")
|
160 |
+
_model.dit = torch.compile(_model.dit, mode="reduce-overhead")
|
161 |
+
# _model.encoder = torch.compile(_model.encoder, mode="reduce-overhead") # nan loss when compiled together with dit, no idea why
|
162 |
+
_model.encoder2slot = torch.compile(_model.encoder2slot, mode="reduce-overhead")
|
163 |
+
|
164 |
+
self.enable_ema = enable_ema
|
165 |
+
if self.enable_ema and not self.test_only: # when testing, we directly load the ema dict and skip here
|
166 |
+
self.ema_model = EMAModel(self.accelerator.unwrap_model(self.model), self.device)
|
167 |
+
self.accelerator.register_for_checkpointing(self.ema_model)
|
168 |
+
|
169 |
+
self._load_checkpoint(model.params.ckpt_path)
|
170 |
+
if self.test_only:
|
171 |
+
self.steps = self.loaded_steps
|
172 |
+
|
173 |
+
self.num_epoch = num_epoch
|
174 |
+
self.save_every = save_every
|
175 |
+
self.sample_every = sample_every
|
176 |
+
self.fid_every = fid_every
|
177 |
+
self.max_grad_norm = max_grad_norm
|
178 |
+
|
179 |
+
self.cfg = cfg
|
180 |
+
self.test_num_slots = test_num_slots
|
181 |
+
if self.test_num_slots is not None:
|
182 |
+
self.test_num_slots = min(self.test_num_slots, self.num_slots)
|
183 |
+
else:
|
184 |
+
self.test_num_slots = self.num_slots
|
185 |
+
eval_fid = eval_fid or model.params.eval_fid # legacy
|
186 |
+
self.eval_fid = eval_fid
|
187 |
+
if eval_fid:
|
188 |
+
if fid_stats is None:
|
189 |
+
fid_stats = model.params.fid_stats # legacy
|
190 |
+
assert fid_stats is not None
|
191 |
+
assert test_dataset is not None
|
192 |
+
self.fid_stats = fid_stats
|
193 |
+
|
194 |
+
self.result_folder = result_folder
|
195 |
+
self.model_saved_dir, self.image_saved_dir = setup_result_folders(result_folder)
|
196 |
+
|
197 |
+
@property
|
198 |
+
def device(self):
|
199 |
+
return self.accelerator.device
|
200 |
+
|
201 |
+
def _load_checkpoint(self, ckpt_path=None):
|
202 |
+
if ckpt_path is None or not osp.exists(ckpt_path):
|
203 |
+
return
|
204 |
+
|
205 |
+
model = self.accelerator.unwrap_model(self.model)
|
206 |
+
|
207 |
+
if osp.isdir(ckpt_path):
|
208 |
+
# ckpt_path is something like 'path/to/models/step10/'
|
209 |
+
self.loaded_steps = int(
|
210 |
+
ckpt_path.split("step")[-1].split("/")[0]
|
211 |
+
)
|
212 |
+
if not self.test_only:
|
213 |
+
self.accelerator.load_state(ckpt_path)
|
214 |
+
else:
|
215 |
+
if self.enable_ema:
|
216 |
+
model_path = osp.join(ckpt_path, "custom_checkpoint_1.pkl")
|
217 |
+
if osp.exists(model_path):
|
218 |
+
state_dict = torch.load(model_path, map_location="cpu")
|
219 |
+
load_state_dict(state_dict, model)
|
220 |
+
if self.accelerator.is_main_process:
|
221 |
+
print(f"Loaded ema model from {model_path}")
|
222 |
+
else:
|
223 |
+
model_path = osp.join(ckpt_path, "model.safetensors")
|
224 |
+
if osp.exists(model_path):
|
225 |
+
load_safetensors(model_path, model)
|
226 |
+
else:
|
227 |
+
# ckpt_path is something like 'path/to/models/step10.pt'
|
228 |
+
if ckpt_path.endswith(".safetensors"):
|
229 |
+
load_safetensors(ckpt_path, model)
|
230 |
+
else:
|
231 |
+
state_dict = torch.load(ckpt_path, map_location="cpu")
|
232 |
+
load_state_dict(state_dict, model)
|
233 |
+
if self.accelerator.is_main_process:
|
234 |
+
print(f"Loaded checkpoint from {ckpt_path}")
|
235 |
+
|
236 |
+
def train(self, config=None):
|
237 |
+
n_parameters = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
|
238 |
+
if self.accelerator.is_main_process:
|
239 |
+
print(f"number of learnable parameters: {n_parameters//1e6}M")
|
240 |
+
if config is not None:
|
241 |
+
# save the config
|
242 |
+
from omegaconf import OmegaConf
|
243 |
+
if isinstance(config, str) and osp.exists(config):
|
244 |
+
# If it's a path, copy the file to config.yaml
|
245 |
+
shutil.copy(config, osp.join(self.result_folder, "config.yaml"))
|
246 |
+
else:
|
247 |
+
# If it's an OmegaConf object, dump it
|
248 |
+
config_save_path = osp.join(self.result_folder, "config.yaml")
|
249 |
+
OmegaConf.save(config, config_save_path)
|
250 |
+
|
251 |
+
self.accelerator.init_trackers("semanticist")
|
252 |
+
|
253 |
+
if self.test_only:
|
254 |
+
empty_cache()
|
255 |
+
self.evaluate()
|
256 |
+
self.accelerator.wait_for_everyone()
|
257 |
+
empty_cache()
|
258 |
+
return
|
259 |
+
|
260 |
+
for epoch in range(self.num_epoch):
|
261 |
+
if ((epoch + 1) * len(self.train_dl)) <= self.loaded_steps:
|
262 |
+
if self.accelerator.is_main_process:
|
263 |
+
print(f"Epoch {epoch} is skipped because it is loaded from ckpt")
|
264 |
+
self.steps += len(self.train_dl)
|
265 |
+
continue
|
266 |
+
|
267 |
+
if self.steps < self.loaded_steps:
|
268 |
+
for _ in self.train_dl:
|
269 |
+
self.steps += 1
|
270 |
+
if self.steps >= self.loaded_steps:
|
271 |
+
break
|
272 |
+
|
273 |
+
|
274 |
+
self.accelerator.unwrap_model(self.model).current_epoch = epoch
|
275 |
+
self.model.train() # Set model to training mode
|
276 |
+
|
277 |
+
logger = MetricLogger(delimiter=" ")
|
278 |
+
logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
|
279 |
+
header = 'Epoch: [{}/{}]'.format(epoch, self.num_epoch)
|
280 |
+
print_freq = 20
|
281 |
+
for data_iter_step, batch in enumerate(logger.log_every(self.train_dl, print_freq, header)):
|
282 |
+
img, _ = batch
|
283 |
+
img = img.to(self.device, non_blocking=True)
|
284 |
+
self.steps += 1
|
285 |
+
|
286 |
+
with self.accelerator.accumulate(self.model):
|
287 |
+
with self.accelerator.autocast():
|
288 |
+
if self.steps == 1:
|
289 |
+
print(f"Training batch size: {img.size(0)}")
|
290 |
+
print(f"Hello from index {self.accelerator.local_process_index}")
|
291 |
+
losses = self.model(img, epoch=epoch)
|
292 |
+
# combine
|
293 |
+
loss = sum([v for _, v in losses.items()])
|
294 |
+
|
295 |
+
self.accelerator.backward(loss)
|
296 |
+
if self.accelerator.sync_gradients and self.max_grad_norm is not None:
|
297 |
+
self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
|
298 |
+
self.g_optim.step()
|
299 |
+
if self.g_sched is not None:
|
300 |
+
self.g_sched.step_update(self.steps)
|
301 |
+
self.g_optim.zero_grad()
|
302 |
+
|
303 |
+
self.accelerator.wait_for_everyone()
|
304 |
+
|
305 |
+
# update ema with state dict
|
306 |
+
if self.enable_ema:
|
307 |
+
self.ema_model.update(self.accelerator.unwrap_model(self.model))
|
308 |
+
|
309 |
+
for key, value in losses.items():
|
310 |
+
logger.update(**{key: value.item()})
|
311 |
+
logger.update(lr=self.g_optim.param_groups[0]["lr"])
|
312 |
+
|
313 |
+
if self.steps % self.save_every == 0:
|
314 |
+
self.save()
|
315 |
+
|
316 |
+
if (self.steps % self.sample_every == 0) or (self.steps % self.fid_every == 0):
|
317 |
+
empty_cache()
|
318 |
+
self.evaluate()
|
319 |
+
self.accelerator.wait_for_everyone()
|
320 |
+
empty_cache()
|
321 |
+
|
322 |
+
write_dict = dict(epoch=epoch)
|
323 |
+
for key, value in losses.items(): # omitted all_gather here
|
324 |
+
write_dict.update(**{key: value.item()})
|
325 |
+
write_dict.update(lr=self.g_optim.param_groups[0]["lr"])
|
326 |
+
self.accelerator.log(write_dict, step=self.steps)
|
327 |
+
|
328 |
+
logger.synchronize_between_processes()
|
329 |
+
if self.accelerator.is_main_process:
|
330 |
+
print("Averaged stats:", logger)
|
331 |
+
|
332 |
+
self.accelerator.end_training()
|
333 |
+
self.save()
|
334 |
+
if self.accelerator.is_main_process:
|
335 |
+
print("Train finished!")
|
336 |
+
|
337 |
+
def save(self):
|
338 |
+
self.accelerator.wait_for_everyone()
|
339 |
+
self.accelerator.save_state(
|
340 |
+
os.path.join(self.model_saved_dir, f"step{self.steps}")
|
341 |
+
)
|
342 |
+
|
343 |
+
@torch.no_grad()
|
344 |
+
def evaluate(self):
|
345 |
+
self.model.eval()
|
346 |
+
if not self.test_only:
|
347 |
+
with tqdm(
|
348 |
+
self.valid_dl,
|
349 |
+
dynamic_ncols=True,
|
350 |
+
disable=not self.accelerator.is_main_process,
|
351 |
+
) as valid_dl:
|
352 |
+
for batch_i, batch in enumerate(valid_dl):
|
353 |
+
if isinstance(batch, tuple) or isinstance(batch, list):
|
354 |
+
img, targets = batch[0], batch[1]
|
355 |
+
else:
|
356 |
+
img = batch
|
357 |
+
|
358 |
+
with self.accelerator.autocast():
|
359 |
+
rec = self.model(img, sample=True, inference_with_n_slots=self.test_num_slots, cfg=1.0)
|
360 |
+
imgs_and_recs = torch.stack((img.to(rec.device), rec), dim=0)
|
361 |
+
imgs_and_recs = rearrange(imgs_and_recs, "r b ... -> (b r) ...")
|
362 |
+
imgs_and_recs = imgs_and_recs.detach().cpu().float()
|
363 |
+
|
364 |
+
grid = make_grid(
|
365 |
+
imgs_and_recs, nrow=6, normalize=True, value_range=(0, 1)
|
366 |
+
)
|
367 |
+
if self.accelerator.is_main_process:
|
368 |
+
save_image(
|
369 |
+
grid,
|
370 |
+
os.path.join(
|
371 |
+
self.image_saved_dir, f"step_{self.steps}_slots{self.test_num_slots}_{batch_i}.jpg"
|
372 |
+
),
|
373 |
+
)
|
374 |
+
|
375 |
+
if self.cfg != 1.0:
|
376 |
+
with self.accelerator.autocast():
|
377 |
+
rec = self.model(img, sample=True, inference_with_n_slots=self.test_num_slots, cfg=self.cfg)
|
378 |
+
|
379 |
+
imgs_and_recs = torch.stack((img.to(rec.device), rec), dim=0)
|
380 |
+
imgs_and_recs = rearrange(imgs_and_recs, "r b ... -> (b r) ...")
|
381 |
+
imgs_and_recs = imgs_and_recs.detach().cpu().float()
|
382 |
+
|
383 |
+
grid = make_grid(
|
384 |
+
imgs_and_recs, nrow=6, normalize=True, value_range=(0, 1)
|
385 |
+
)
|
386 |
+
if self.accelerator.is_main_process:
|
387 |
+
save_image(
|
388 |
+
grid,
|
389 |
+
os.path.join(
|
390 |
+
self.image_saved_dir, f"step_{self.steps}_cfg_{self.cfg}_slots{self.test_num_slots}_{batch_i}.jpg"
|
391 |
+
),
|
392 |
+
)
|
393 |
+
if (self.eval_fid and self.test_dl is not None) and (self.test_only or (self.steps % self.fid_every == 0)):
|
394 |
+
real_dir = "./dataset/imagenet/val256"
|
395 |
+
rec_dir = os.path.join(self.image_saved_dir, f"rec_step{self.steps}_slots{self.test_num_slots}")
|
396 |
+
os.makedirs(rec_dir, exist_ok=True)
|
397 |
+
|
398 |
+
if self.cfg != 1.0:
|
399 |
+
rec_cfg_dir = os.path.join(self.image_saved_dir, f"rec_step{self.steps}_cfg_{self.cfg}_slots{self.test_num_slots}")
|
400 |
+
os.makedirs(rec_cfg_dir, exist_ok=True)
|
401 |
+
|
402 |
+
def process_batch(cfg_value, save_dir, header):
|
403 |
+
logger = MetricLogger(delimiter=" ")
|
404 |
+
print_freq = 5
|
405 |
+
psnr_values = []
|
406 |
+
ssim_values = []
|
407 |
+
total_processed = 0
|
408 |
+
|
409 |
+
for batch_i, batch in enumerate(logger.log_every(self.test_dl, print_freq, header)):
|
410 |
+
imgs, targets = (batch[0], batch[1]) if isinstance(batch, (tuple, list)) else (batch, None)
|
411 |
+
|
412 |
+
# Skip processing if we've already processed all real samples
|
413 |
+
if total_processed >= self.test_dataset_size:
|
414 |
+
break
|
415 |
+
|
416 |
+
imgs = imgs.to(self.device, non_blocking=True)
|
417 |
+
if targets is not None:
|
418 |
+
targets = targets.to(self.device, non_blocking=True)
|
419 |
+
|
420 |
+
with self.accelerator.autocast():
|
421 |
+
recs = self.model(imgs, sample=True, inference_with_n_slots=self.test_num_slots, cfg=cfg_value)
|
422 |
+
|
423 |
+
psnr_val = psnr(recs, imgs, data_range=1.0)
|
424 |
+
ssim_val = ssim(recs, imgs, data_range=1.0)
|
425 |
+
|
426 |
+
recs = concat_all_gather(recs).detach()
|
427 |
+
psnr_val = concat_all_gather(psnr_val.view(1))
|
428 |
+
ssim_val = concat_all_gather(ssim_val.view(1))
|
429 |
+
|
430 |
+
# Remove padding after gathering from all GPUs
|
431 |
+
samples_in_batch = min(
|
432 |
+
recs.size(0), # Always use the gathered size
|
433 |
+
self.test_dataset_size - total_processed
|
434 |
+
)
|
435 |
+
recs = recs[:samples_in_batch]
|
436 |
+
psnr_val = psnr_val[:samples_in_batch]
|
437 |
+
ssim_val = ssim_val[:samples_in_batch]
|
438 |
+
psnr_values.append(psnr_val)
|
439 |
+
ssim_values.append(ssim_val)
|
440 |
+
|
441 |
+
if self.accelerator.is_main_process:
|
442 |
+
rec_paths = [os.path.join(save_dir, f"step_{self.steps}_slots{self.test_num_slots}_{batch_i}_{j}_rec_cfg_{cfg_value}_slots{self.test_num_slots}.png")
|
443 |
+
for j in range(recs.size(0))]
|
444 |
+
save_img_batch(recs.cpu(), rec_paths)
|
445 |
+
|
446 |
+
total_processed += samples_in_batch
|
447 |
+
|
448 |
+
self.accelerator.wait_for_everyone()
|
449 |
+
|
450 |
+
return torch.cat(psnr_values).mean(), torch.cat(ssim_values).mean()
|
451 |
+
|
452 |
+
# Helper function to calculate and log metrics
|
453 |
+
def calculate_and_log_metrics(real_dir, rec_dir, cfg_value, psnr_val, ssim_val):
|
454 |
+
if self.accelerator.is_main_process:
|
455 |
+
metrics_dict = get_fid_stats(real_dir, rec_dir, self.fid_stats)
|
456 |
+
fid = metrics_dict["frechet_inception_distance"]
|
457 |
+
inception_score = metrics_dict["inception_score_mean"]
|
458 |
+
|
459 |
+
metric_prefix = "fid"
|
460 |
+
isc_prefix = "isc"
|
461 |
+
self.accelerator.log({
|
462 |
+
metric_prefix: fid,
|
463 |
+
isc_prefix: inception_score,
|
464 |
+
f"psnr": psnr_val,
|
465 |
+
f"ssim": ssim_val,
|
466 |
+
"cfg": cfg_value
|
467 |
+
}, step=self.steps)
|
468 |
+
|
469 |
+
print(f"{'CFG: {cfg_value}'} "
|
470 |
+
f"FID: {fid:.2f}, ISC: {inception_score:.2f}, "
|
471 |
+
f"PSNR: {psnr_val:.2f}, SSIM: {ssim_val:.4f}")
|
472 |
+
|
473 |
+
# Process without CFG
|
474 |
+
if self.cfg == 1.0 or not self.test_only:
|
475 |
+
psnr_val, ssim_val = process_batch(1.0, rec_dir, 'Testing: w/o CFG')
|
476 |
+
calculate_and_log_metrics(real_dir, rec_dir, 1.0, psnr_val, ssim_val)
|
477 |
+
|
478 |
+
# Process with CFG if needed
|
479 |
+
if self.cfg != 1.0:
|
480 |
+
psnr_val, ssim_val = process_batch(self.cfg, rec_cfg_dir, 'Testing: w/ CFG')
|
481 |
+
calculate_and_log_metrics(real_dir, rec_cfg_dir, self.cfg, psnr_val, ssim_val)
|
482 |
+
|
483 |
+
# Cleanup
|
484 |
+
if self.accelerator.is_main_process:
|
485 |
+
shutil.rmtree(rec_dir)
|
486 |
+
if self.cfg != 1.0:
|
487 |
+
shutil.rmtree(rec_cfg_dir)
|
488 |
+
self.model.train()
|
semanticist/engine/gpt_trainer.py
ADDED
@@ -0,0 +1,694 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, torch
|
2 |
+
import os.path as osp
|
3 |
+
import shutil
|
4 |
+
import numpy as np
|
5 |
+
import copy
|
6 |
+
import torch.nn as nn
|
7 |
+
from tqdm.auto import tqdm
|
8 |
+
from accelerate import Accelerator
|
9 |
+
from torchvision.utils import make_grid, save_image
|
10 |
+
from torch.utils.data import DataLoader, DistributedSampler
|
11 |
+
from semanticist.utils.logger import SmoothedValue, MetricLogger, empty_cache
|
12 |
+
from accelerate.utils import DistributedDataParallelKwargs
|
13 |
+
from semanticist.stage2.gpt import GPT_models
|
14 |
+
from semanticist.stage2.generate import generate
|
15 |
+
from pathlib import Path
|
16 |
+
import time
|
17 |
+
|
18 |
+
from semanticist.engine.trainer_utils import (
|
19 |
+
instantiate_from_config, concat_all_gather,
|
20 |
+
save_img_batch, get_fid_stats,
|
21 |
+
EMAModel, create_scheduler, load_state_dict, load_safetensors,
|
22 |
+
setup_result_folders, create_optimizer,
|
23 |
+
CacheDataLoader
|
24 |
+
)
|
25 |
+
|
26 |
+
class GPTTrainer(nn.Module):
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
ae_model,
|
30 |
+
gpt_model,
|
31 |
+
dataset,
|
32 |
+
test_only=False,
|
33 |
+
num_test_images=50000,
|
34 |
+
num_epoch=400,
|
35 |
+
eval_classes=[1, 7, 282, 604, 724, 207, 250, 751, 404, 850], # goldfish, cock, tiger cat, hourglass, ship, golden retriever, husky, race car, airliner, teddy bear
|
36 |
+
blr=1e-4,
|
37 |
+
cosine_lr=False,
|
38 |
+
lr_min=0,
|
39 |
+
warmup_epochs=100,
|
40 |
+
warmup_steps=None,
|
41 |
+
warmup_lr_init=0,
|
42 |
+
decay_steps=None,
|
43 |
+
batch_size=32,
|
44 |
+
cache_bs=8,
|
45 |
+
test_bs=100,
|
46 |
+
num_workers=8,
|
47 |
+
pin_memory=False,
|
48 |
+
max_grad_norm=None,
|
49 |
+
grad_accum_steps=1,
|
50 |
+
precision='bf16',
|
51 |
+
save_every=10000,
|
52 |
+
sample_every=1000,
|
53 |
+
fid_every=50000,
|
54 |
+
result_folder=None,
|
55 |
+
log_dir="./log",
|
56 |
+
ae_cfg=1.0,
|
57 |
+
cfg=6.0,
|
58 |
+
cfg_schedule="linear",
|
59 |
+
temperature=1.0,
|
60 |
+
train_num_slots=None,
|
61 |
+
test_num_slots=None,
|
62 |
+
eval_fid=False,
|
63 |
+
fid_stats=None,
|
64 |
+
enable_ema=False,
|
65 |
+
compile=False,
|
66 |
+
enable_cache_latents=True,
|
67 |
+
cache_dir='/dev/shm/slot_cache'
|
68 |
+
):
|
69 |
+
super().__init__()
|
70 |
+
kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
|
71 |
+
self.accelerator = Accelerator(
|
72 |
+
kwargs_handlers=[kwargs],
|
73 |
+
mixed_precision=precision,
|
74 |
+
gradient_accumulation_steps=grad_accum_steps,
|
75 |
+
log_with="tensorboard",
|
76 |
+
project_dir=log_dir,
|
77 |
+
)
|
78 |
+
|
79 |
+
self.ae_model = instantiate_from_config(ae_model)
|
80 |
+
ae_model_path = ae_model.params.ckpt_path
|
81 |
+
assert ae_model_path.endswith(".safetensors") or ae_model_path.endswith(".pt") or ae_model_path.endswith(".pth") or ae_model_path.endswith(".pkl")
|
82 |
+
assert osp.exists(ae_model_path), f"AE model checkpoint {ae_model_path} does not exist"
|
83 |
+
self._load_checkpoint(ae_model_path, self.ae_model)
|
84 |
+
|
85 |
+
self.ae_model.to(self.device)
|
86 |
+
for param in self.ae_model.parameters():
|
87 |
+
param.requires_grad = False
|
88 |
+
self.ae_model.eval()
|
89 |
+
|
90 |
+
self.model_name = gpt_model.target
|
91 |
+
if 'GPT' in gpt_model.target:
|
92 |
+
self.gpt_model = GPT_models[gpt_model.target](**gpt_model.params)
|
93 |
+
else:
|
94 |
+
raise ValueError(f"Unknown model type: {gpt_model.target}")
|
95 |
+
self.num_slots = ae_model.params.num_slots
|
96 |
+
self.slot_dim = ae_model.params.slot_dim
|
97 |
+
|
98 |
+
self.test_only = test_only
|
99 |
+
self.test_bs = test_bs
|
100 |
+
self.num_test_images = num_test_images
|
101 |
+
self.num_classes = gpt_model.params.num_classes
|
102 |
+
self.batch_size = batch_size
|
103 |
+
if not test_only:
|
104 |
+
self.train_ds = instantiate_from_config(dataset)
|
105 |
+
train_size = len(self.train_ds)
|
106 |
+
if self.accelerator.is_main_process:
|
107 |
+
print(f"train dataset size: {train_size}")
|
108 |
+
|
109 |
+
sampler = DistributedSampler(
|
110 |
+
self.train_ds,
|
111 |
+
num_replicas=self.accelerator.num_processes,
|
112 |
+
rank=self.accelerator.process_index,
|
113 |
+
shuffle=True,
|
114 |
+
)
|
115 |
+
self.train_dl = DataLoader(
|
116 |
+
self.train_ds,
|
117 |
+
batch_size=batch_size if not enable_cache_latents else cache_bs,
|
118 |
+
sampler=sampler,
|
119 |
+
num_workers=num_workers,
|
120 |
+
pin_memory=pin_memory,
|
121 |
+
drop_last=True,
|
122 |
+
)
|
123 |
+
|
124 |
+
effective_bs = batch_size * grad_accum_steps * self.accelerator.num_processes
|
125 |
+
lr = blr * effective_bs / 256
|
126 |
+
if self.accelerator.is_main_process:
|
127 |
+
print(f"Effective batch size is {effective_bs}")
|
128 |
+
|
129 |
+
self.g_optim = create_optimizer(self.gpt_model, weight_decay=0.05, learning_rate=lr)
|
130 |
+
|
131 |
+
if warmup_epochs is not None:
|
132 |
+
warmup_steps = warmup_epochs * len(self.train_dl)
|
133 |
+
|
134 |
+
self.g_sched = create_scheduler(
|
135 |
+
self.g_optim,
|
136 |
+
num_epoch,
|
137 |
+
len(self.train_dl),
|
138 |
+
lr_min,
|
139 |
+
warmup_steps,
|
140 |
+
warmup_lr_init,
|
141 |
+
decay_steps,
|
142 |
+
cosine_lr
|
143 |
+
)
|
144 |
+
self.accelerator.register_for_checkpointing(self.g_sched)
|
145 |
+
self.gpt_model, self.g_optim, self.g_sched = self.accelerator.prepare(self.gpt_model, self.g_optim, self.g_sched)
|
146 |
+
else:
|
147 |
+
self.gpt_model = self.accelerator.prepare(self.gpt_model)
|
148 |
+
|
149 |
+
self.steps = 0
|
150 |
+
self.loaded_steps = -1
|
151 |
+
|
152 |
+
if compile:
|
153 |
+
self.ae_model = torch.compile(self.ae_model, mode="reduce-overhead")
|
154 |
+
_model = self.accelerator.unwrap_model(self.gpt_model)
|
155 |
+
_model = torch.compile(_model, mode="reduce-overhead")
|
156 |
+
|
157 |
+
self.enable_ema = enable_ema
|
158 |
+
if self.enable_ema and not self.test_only: # when testing, we directly load the ema dict and skip here
|
159 |
+
self.ema_model = EMAModel(self.accelerator.unwrap_model(self.gpt_model), self.device)
|
160 |
+
self.accelerator.register_for_checkpointing(self.ema_model)
|
161 |
+
|
162 |
+
self._load_checkpoint(gpt_model.params.ckpt_path)
|
163 |
+
if self.test_only:
|
164 |
+
self.steps = self.loaded_steps
|
165 |
+
|
166 |
+
self.num_epoch = num_epoch
|
167 |
+
self.save_every = save_every
|
168 |
+
self.sample_every = sample_every
|
169 |
+
self.fid_every = fid_every
|
170 |
+
self.max_grad_norm = max_grad_norm
|
171 |
+
self.eval_classes = eval_classes
|
172 |
+
self.cfg = cfg
|
173 |
+
self.ae_cfg = ae_cfg
|
174 |
+
self.cfg_schedule = cfg_schedule
|
175 |
+
self.temperature = temperature
|
176 |
+
self.train_num_slots = train_num_slots
|
177 |
+
self.test_num_slots = test_num_slots
|
178 |
+
if self.train_num_slots is not None:
|
179 |
+
self.train_num_slots = min(self.train_num_slots, self.num_slots)
|
180 |
+
else:
|
181 |
+
self.train_num_slots = self.num_slots
|
182 |
+
if self.test_num_slots is not None:
|
183 |
+
self.num_slots_to_gen = min(self.test_num_slots, self.train_num_slots)
|
184 |
+
else:
|
185 |
+
self.num_slots_to_gen = self.train_num_slots
|
186 |
+
self.eval_fid = eval_fid
|
187 |
+
if eval_fid:
|
188 |
+
assert fid_stats is not None
|
189 |
+
self.fid_stats = fid_stats
|
190 |
+
|
191 |
+
# Setup result folders
|
192 |
+
self.result_folder = result_folder
|
193 |
+
self.model_saved_dir, self.image_saved_dir = setup_result_folders(result_folder)
|
194 |
+
|
195 |
+
# Setup cache
|
196 |
+
self.cache_dir = Path(cache_dir)
|
197 |
+
self.enable_cache_latents = enable_cache_latents
|
198 |
+
self.cache_loader = None
|
199 |
+
|
200 |
+
@property
|
201 |
+
def device(self):
|
202 |
+
return self.accelerator.device
|
203 |
+
|
204 |
+
def _load_checkpoint(self, ckpt_path=None, model=None):
|
205 |
+
if ckpt_path is None or not osp.exists(ckpt_path):
|
206 |
+
return
|
207 |
+
|
208 |
+
if model is None:
|
209 |
+
model = self.accelerator.unwrap_model(self.gpt_model)
|
210 |
+
|
211 |
+
if osp.isdir(ckpt_path):
|
212 |
+
self.loaded_steps = int(
|
213 |
+
ckpt_path.split("step")[-1].split("/")[0]
|
214 |
+
)
|
215 |
+
if not self.test_only:
|
216 |
+
self.accelerator.load_state(ckpt_path)
|
217 |
+
else:
|
218 |
+
if self.enable_ema:
|
219 |
+
model_path = osp.join(ckpt_path, "custom_checkpoint_1.pkl")
|
220 |
+
if osp.exists(model_path):
|
221 |
+
state_dict = torch.load(model_path, map_location="cpu")
|
222 |
+
load_state_dict(state_dict, model)
|
223 |
+
if self.accelerator.is_main_process:
|
224 |
+
print(f"Loaded ema model from {model_path}")
|
225 |
+
else:
|
226 |
+
model_path = osp.join(ckpt_path, "model.safetensors")
|
227 |
+
if osp.exists(model_path):
|
228 |
+
load_safetensors(model_path, model)
|
229 |
+
else:
|
230 |
+
if ckpt_path.endswith(".safetensors"):
|
231 |
+
load_safetensors(ckpt_path, model)
|
232 |
+
else:
|
233 |
+
state_dict = torch.load(ckpt_path, map_location="cpu")
|
234 |
+
load_state_dict(state_dict, model)
|
235 |
+
if self.accelerator.is_main_process:
|
236 |
+
print(f"Loaded checkpoint from {ckpt_path}")
|
237 |
+
|
238 |
+
def _build_cache(self):
|
239 |
+
"""Build cache for slots and targets."""
|
240 |
+
rank = self.accelerator.process_index
|
241 |
+
world_size = self.accelerator.num_processes
|
242 |
+
|
243 |
+
# Clean up any existing cache files first
|
244 |
+
slots_file = self.cache_dir / f"slots_rank{rank}_of_{world_size}.mmap"
|
245 |
+
targets_file = self.cache_dir / f"targets_rank{rank}_of_{world_size}.mmap"
|
246 |
+
|
247 |
+
if slots_file.exists():
|
248 |
+
os.remove(slots_file)
|
249 |
+
if targets_file.exists():
|
250 |
+
os.remove(targets_file)
|
251 |
+
|
252 |
+
dataset_size = len(self.train_dl.dataset)
|
253 |
+
shard_size = dataset_size // world_size
|
254 |
+
|
255 |
+
# Detect number of augmentations from first batch
|
256 |
+
with torch.no_grad():
|
257 |
+
sample_batch = next(iter(self.train_dl))
|
258 |
+
img, _ = sample_batch
|
259 |
+
num_augs = img.shape[1] if len(img.shape) == 5 else 1
|
260 |
+
|
261 |
+
print(f"Rank {rank}: Creating new cache with {num_augs} augmentations per image...")
|
262 |
+
os.makedirs(self.cache_dir, exist_ok=True)
|
263 |
+
slots_file = self.cache_dir / f"slots_rank{rank}_of_{world_size}.mmap"
|
264 |
+
targets_file = self.cache_dir / f"targets_rank{rank}_of_{world_size}.mmap"
|
265 |
+
|
266 |
+
# Create memory-mapped files
|
267 |
+
slots_mmap = np.memmap(
|
268 |
+
slots_file,
|
269 |
+
dtype='float32',
|
270 |
+
mode='w+',
|
271 |
+
shape=(shard_size * num_augs, self.train_num_slots, self.slot_dim)
|
272 |
+
)
|
273 |
+
|
274 |
+
targets_mmap = np.memmap(
|
275 |
+
targets_file,
|
276 |
+
dtype='int64',
|
277 |
+
mode='w+',
|
278 |
+
shape=(shard_size * num_augs,)
|
279 |
+
)
|
280 |
+
|
281 |
+
# Cache data
|
282 |
+
with torch.no_grad():
|
283 |
+
for i, batch in enumerate(tqdm(
|
284 |
+
self.train_dl,
|
285 |
+
desc=f"Rank {rank}: Caching data",
|
286 |
+
disable=not self.accelerator.is_local_main_process
|
287 |
+
)):
|
288 |
+
imgs, targets = batch
|
289 |
+
if len(imgs.shape) == 5: # [B, num_augs, C, H, W]
|
290 |
+
B, A, C, H, W = imgs.shape
|
291 |
+
imgs = imgs.view(-1, C, H, W) # [B*num_augs, C, H, W]
|
292 |
+
targets = targets.unsqueeze(1).expand(-1, A).reshape(-1) # [B*num_augs]
|
293 |
+
|
294 |
+
# Split imgs into n chunks
|
295 |
+
num_splits = num_augs
|
296 |
+
split_size = imgs.shape[0] // num_splits
|
297 |
+
imgs_splits = torch.split(imgs, split_size)
|
298 |
+
targets_splits = torch.split(targets, split_size)
|
299 |
+
|
300 |
+
start_idx = i * self.train_dl.batch_size * num_augs
|
301 |
+
|
302 |
+
for split_idx, (img_split, targets_split) in enumerate(zip(imgs_splits, targets_splits)):
|
303 |
+
img_split = img_split.to(self.device, non_blocking=True)
|
304 |
+
slots_split = self.ae_model.encode_slots(img_split)[:, :self.train_num_slots, :]
|
305 |
+
|
306 |
+
split_start = start_idx + (split_idx * split_size)
|
307 |
+
split_end = split_start + img_split.shape[0]
|
308 |
+
|
309 |
+
# Write directly to mmap files
|
310 |
+
slots_mmap[split_start:split_end] = slots_split.cpu().numpy()
|
311 |
+
targets_mmap[split_start:split_end] = targets_split.numpy()
|
312 |
+
|
313 |
+
# Close the mmap files
|
314 |
+
del slots_mmap
|
315 |
+
del targets_mmap
|
316 |
+
|
317 |
+
# Reopen in read mode
|
318 |
+
self.cached_latents = np.memmap(
|
319 |
+
slots_file,
|
320 |
+
dtype='float32',
|
321 |
+
mode='r',
|
322 |
+
shape=(shard_size * num_augs, self.train_num_slots, self.slot_dim)
|
323 |
+
)
|
324 |
+
|
325 |
+
self.cached_targets = np.memmap(
|
326 |
+
targets_file,
|
327 |
+
dtype='int64',
|
328 |
+
mode='r',
|
329 |
+
shape=(shard_size * num_augs,)
|
330 |
+
)
|
331 |
+
|
332 |
+
# Store the number of augmentations for the cache loader
|
333 |
+
self.num_augs = num_augs
|
334 |
+
|
335 |
+
def _setup_cache(self):
|
336 |
+
"""Setup cache if enabled."""
|
337 |
+
self._build_cache()
|
338 |
+
self.accelerator.wait_for_everyone()
|
339 |
+
|
340 |
+
# Initialize cache loader if cache exists
|
341 |
+
if self.cached_latents is not None:
|
342 |
+
self.cache_loader = CacheDataLoader(
|
343 |
+
slots=self.cached_latents,
|
344 |
+
targets=self.cached_targets,
|
345 |
+
batch_size=self.batch_size,
|
346 |
+
num_augs=self.num_augs,
|
347 |
+
seed=42 + self.accelerator.process_index
|
348 |
+
)
|
349 |
+
|
350 |
+
def __del__(self):
|
351 |
+
"""Cleanup cache files."""
|
352 |
+
if self.enable_cache_latents:
|
353 |
+
rank = self.accelerator.process_index
|
354 |
+
world_size = self.accelerator.num_processes
|
355 |
+
|
356 |
+
# Clean up slots cache
|
357 |
+
slots_file = self.cache_dir / f"slots_rank{rank}_of_{world_size}.mmap"
|
358 |
+
if slots_file.exists():
|
359 |
+
os.remove(slots_file)
|
360 |
+
|
361 |
+
# Clean up targets cache
|
362 |
+
targets_file = self.cache_dir / f"targets_rank{rank}_of_{world_size}.mmap"
|
363 |
+
if targets_file.exists():
|
364 |
+
os.remove(targets_file)
|
365 |
+
|
366 |
+
def _train_step(self, slots, targets=None):
|
367 |
+
"""Execute single training step."""
|
368 |
+
|
369 |
+
with self.accelerator.accumulate(self.gpt_model):
|
370 |
+
with self.accelerator.autocast():
|
371 |
+
loss = self.gpt_model(slots, targets)
|
372 |
+
|
373 |
+
self.accelerator.backward(loss)
|
374 |
+
if self.accelerator.sync_gradients and self.max_grad_norm is not None:
|
375 |
+
self.accelerator.clip_grad_norm_(self.gpt_model.parameters(), self.max_grad_norm)
|
376 |
+
self.g_optim.step()
|
377 |
+
if self.g_sched is not None:
|
378 |
+
self.g_sched.step_update(self.steps)
|
379 |
+
self.g_optim.zero_grad()
|
380 |
+
|
381 |
+
# Update EMA model if enabled
|
382 |
+
if self.enable_ema:
|
383 |
+
self.ema_model.update(self.accelerator.unwrap_model(self.gpt_model))
|
384 |
+
|
385 |
+
return loss
|
386 |
+
|
387 |
+
def _train_epoch_cached(self, epoch, logger):
|
388 |
+
"""Train one epoch using cached data."""
|
389 |
+
self.cache_loader.set_epoch(epoch)
|
390 |
+
header = f'Epoch: [{epoch}/{self.num_epoch}]'
|
391 |
+
|
392 |
+
for batch in logger.log_every(self.cache_loader, 20, header):
|
393 |
+
slots, targets = (b.to(self.device, non_blocking=True) for b in batch)
|
394 |
+
|
395 |
+
self.steps += 1
|
396 |
+
|
397 |
+
if self.steps == 1:
|
398 |
+
print(f"Training batch size: {len(slots)}")
|
399 |
+
print(f"Hello from index {self.accelerator.local_process_index}")
|
400 |
+
|
401 |
+
loss = self._train_step(slots, targets)
|
402 |
+
self._handle_periodic_ops(loss, logger)
|
403 |
+
|
404 |
+
def _train_epoch_uncached(self, epoch, logger):
|
405 |
+
"""Train one epoch using raw data."""
|
406 |
+
header = f'Epoch: [{epoch}/{self.num_epoch}]'
|
407 |
+
|
408 |
+
for batch in logger.log_every(self.train_dl, 20, header):
|
409 |
+
img, targets = (b.to(self.device, non_blocking=True) for b in batch)
|
410 |
+
|
411 |
+
self.steps += 1
|
412 |
+
|
413 |
+
if self.steps == 1:
|
414 |
+
print(f"Training batch size: {img.size(0)}")
|
415 |
+
print(f"Hello from index {self.accelerator.local_process_index}")
|
416 |
+
|
417 |
+
slots = self.ae_model.encode_slots(img)[:, :self.train_num_slots, :]
|
418 |
+
loss = self._train_step(slots, targets)
|
419 |
+
self._handle_periodic_ops(loss, logger)
|
420 |
+
|
421 |
+
def _handle_periodic_ops(self, loss, logger):
|
422 |
+
"""Handle periodic operations and logging."""
|
423 |
+
logger.update(loss=loss.item())
|
424 |
+
logger.update(lr=self.g_optim.param_groups[0]["lr"])
|
425 |
+
|
426 |
+
if self.steps % self.save_every == 0:
|
427 |
+
self.save()
|
428 |
+
|
429 |
+
if (self.steps % self.sample_every == 0) or (self.eval_fid and self.steps % self.fid_every == 0):
|
430 |
+
empty_cache()
|
431 |
+
self.evaluate()
|
432 |
+
self.accelerator.wait_for_everyone()
|
433 |
+
empty_cache()
|
434 |
+
|
435 |
+
def _save_config(self, config):
|
436 |
+
"""Save configuration file."""
|
437 |
+
if config is not None and self.accelerator.is_main_process:
|
438 |
+
import shutil
|
439 |
+
from omegaconf import OmegaConf
|
440 |
+
|
441 |
+
if isinstance(config, str) and osp.exists(config):
|
442 |
+
shutil.copy(config, osp.join(self.result_folder, "config.yaml"))
|
443 |
+
else:
|
444 |
+
config_save_path = osp.join(self.result_folder, "config.yaml")
|
445 |
+
OmegaConf.save(config, config_save_path)
|
446 |
+
|
447 |
+
def _should_skip_epoch(self, epoch):
|
448 |
+
"""Check if epoch should be skipped due to loaded checkpoint."""
|
449 |
+
loader = self.train_dl if not self.enable_cache_latents else self.cache_loader
|
450 |
+
if ((epoch + 1) * len(loader)) <= self.loaded_steps:
|
451 |
+
if self.accelerator.is_main_process:
|
452 |
+
print(f"Epoch {epoch} is skipped because it is loaded from ckpt")
|
453 |
+
self.steps += len(loader)
|
454 |
+
return True
|
455 |
+
|
456 |
+
if self.steps < self.loaded_steps:
|
457 |
+
for _ in loader:
|
458 |
+
self.steps += 1
|
459 |
+
if self.steps >= self.loaded_steps:
|
460 |
+
break
|
461 |
+
return False
|
462 |
+
|
463 |
+
def train(self, config=None):
|
464 |
+
"""Main training loop."""
|
465 |
+
# Initial setup
|
466 |
+
n_parameters = sum(p.numel() for p in self.parameters() if p.requires_grad)
|
467 |
+
if self.accelerator.is_main_process:
|
468 |
+
print(f"number of learnable parameters: {n_parameters//1e6}M")
|
469 |
+
|
470 |
+
self._save_config(config)
|
471 |
+
self.accelerator.init_trackers("gpt")
|
472 |
+
|
473 |
+
# Handle test-only mode
|
474 |
+
if self.test_only:
|
475 |
+
empty_cache()
|
476 |
+
self.evaluate()
|
477 |
+
self.accelerator.wait_for_everyone()
|
478 |
+
empty_cache()
|
479 |
+
return
|
480 |
+
|
481 |
+
# Setup cache if enabled
|
482 |
+
if self.enable_cache_latents:
|
483 |
+
self._setup_cache()
|
484 |
+
|
485 |
+
# Training loop
|
486 |
+
for epoch in range(self.num_epoch):
|
487 |
+
if self._should_skip_epoch(epoch):
|
488 |
+
continue
|
489 |
+
|
490 |
+
self.gpt_model.train()
|
491 |
+
logger = MetricLogger(delimiter=" ")
|
492 |
+
logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
|
493 |
+
|
494 |
+
# Choose training path based on cache availability
|
495 |
+
if self.enable_cache_latents:
|
496 |
+
self._train_epoch_cached(epoch, logger)
|
497 |
+
else:
|
498 |
+
self._train_epoch_uncached(epoch, logger)
|
499 |
+
|
500 |
+
# Synchronize and log epoch stats
|
501 |
+
logger.synchronize_between_processes()
|
502 |
+
if self.accelerator.is_main_process:
|
503 |
+
print("Averaged stats:", logger)
|
504 |
+
|
505 |
+
# Finish training
|
506 |
+
self.accelerator.end_training()
|
507 |
+
self.save()
|
508 |
+
if self.accelerator.is_main_process:
|
509 |
+
print("Train finished!")
|
510 |
+
|
511 |
+
def save(self):
|
512 |
+
self.accelerator.wait_for_everyone()
|
513 |
+
self.accelerator.save_state(
|
514 |
+
os.path.join(self.model_saved_dir, f"step{self.steps}")
|
515 |
+
)
|
516 |
+
|
517 |
+
@torch.no_grad()
|
518 |
+
def evaluate(self, use_ema=True):
|
519 |
+
self.gpt_model.eval()
|
520 |
+
unwraped_gpt_model = self.accelerator.unwrap_model(self.gpt_model)
|
521 |
+
# switch to ema params, only when eval_fid is True
|
522 |
+
# if test_only, we directly load the ema dict and skip here
|
523 |
+
use_ema = use_ema and self.enable_ema and self.eval_fid and not self.test_only
|
524 |
+
if use_ema:
|
525 |
+
if hasattr(self, "ema_model"):
|
526 |
+
model_without_ddp = self.accelerator.unwrap_model(self.gpt_model)
|
527 |
+
model_state_dict = copy.deepcopy(model_without_ddp.state_dict())
|
528 |
+
ema_state_dict = copy.deepcopy(model_without_ddp.state_dict())
|
529 |
+
for i, (name, _value) in enumerate(model_without_ddp.named_parameters()):
|
530 |
+
if "nested_sampler" in name:
|
531 |
+
continue
|
532 |
+
ema_state_dict[name] = self.ema_model.state_dict()[name]
|
533 |
+
if self.accelerator.is_main_process:
|
534 |
+
print("Switch to ema")
|
535 |
+
model_without_ddp.load_state_dict(ema_state_dict)
|
536 |
+
else:
|
537 |
+
print("EMA model not found, using original model")
|
538 |
+
use_ema = False
|
539 |
+
|
540 |
+
if not self.test_only:
|
541 |
+
classes = torch.tensor(self.eval_classes, device=self.device)
|
542 |
+
with self.accelerator.autocast():
|
543 |
+
slots = generate(unwraped_gpt_model, classes, self.num_slots_to_gen, cfg_scale=self.cfg, cfg_schedule=self.cfg_schedule, temperature=self.temperature)
|
544 |
+
if self.num_slots_to_gen < self.num_slots:
|
545 |
+
null_slots = self.ae_model.dit.null_cond.expand(slots.shape[0], -1, -1)
|
546 |
+
null_slots = null_slots[:, self.num_slots_to_gen:, :]
|
547 |
+
slots = torch.cat([slots, null_slots], dim=1)
|
548 |
+
imgs = self.ae_model.sample(slots, targets=classes, cfg=self.ae_cfg) # targets are not used for now
|
549 |
+
|
550 |
+
imgs = concat_all_gather(imgs)
|
551 |
+
if self.accelerator.num_processes > 16:
|
552 |
+
imgs = imgs[:16*len(self.eval_classes)]
|
553 |
+
imgs = imgs.detach().cpu()
|
554 |
+
grid = make_grid(
|
555 |
+
imgs, nrow=len(self.eval_classes), normalize=True, value_range=(0, 1)
|
556 |
+
)
|
557 |
+
if self.accelerator.is_main_process:
|
558 |
+
save_image(
|
559 |
+
grid,
|
560 |
+
os.path.join(
|
561 |
+
self.image_saved_dir, f"step{self.steps}_aecfg-{self.ae_cfg}_cfg-{self.cfg_schedule}-{self.cfg}_slots{self.num_slots_to_gen}_temp{self.temperature}.jpg"
|
562 |
+
),
|
563 |
+
)
|
564 |
+
if self.eval_fid and (self.test_only or (self.steps % self.fid_every == 0)):
|
565 |
+
# Create output directory (only on main process)
|
566 |
+
save_folder = os.path.join(self.image_saved_dir, f"gen_step{self.steps}_aecfg-{self.ae_cfg}_cfg-{self.cfg_schedule}-{self.cfg}_slots{self.num_slots_to_gen}_temp{self.temperature}")
|
567 |
+
if self.accelerator.is_main_process:
|
568 |
+
os.makedirs(save_folder, exist_ok=True)
|
569 |
+
|
570 |
+
# Setup for distributed generation
|
571 |
+
world_size = self.accelerator.num_processes
|
572 |
+
local_rank = self.accelerator.process_index
|
573 |
+
batch_size = self.test_bs
|
574 |
+
|
575 |
+
# Create balanced class distribution
|
576 |
+
num_classes = self.num_classes
|
577 |
+
images_per_class = self.num_test_images // num_classes
|
578 |
+
class_labels = np.repeat(np.arange(num_classes), images_per_class)
|
579 |
+
|
580 |
+
# Shuffle the class labels to ensure random ordering
|
581 |
+
np.random.shuffle(class_labels)
|
582 |
+
|
583 |
+
total_images = len(class_labels)
|
584 |
+
|
585 |
+
padding_size = world_size * batch_size - (total_images % (world_size * batch_size))
|
586 |
+
class_labels = np.pad(class_labels, (0, padding_size), 'constant')
|
587 |
+
padded_total_images = len(class_labels)
|
588 |
+
|
589 |
+
# Distribute workload across GPUs
|
590 |
+
images_per_gpu = padded_total_images // world_size
|
591 |
+
start_idx = local_rank * images_per_gpu
|
592 |
+
end_idx = min(start_idx + images_per_gpu, padded_total_images)
|
593 |
+
local_class_labels = class_labels[start_idx:end_idx]
|
594 |
+
local_num_steps = len(local_class_labels) // batch_size
|
595 |
+
|
596 |
+
if self.accelerator.is_main_process:
|
597 |
+
print(f"Generating {total_images} images ({images_per_class} per class) across {world_size} GPUs")
|
598 |
+
|
599 |
+
used_time = 0
|
600 |
+
gen_img_cnt = 0
|
601 |
+
|
602 |
+
for i in range(local_num_steps):
|
603 |
+
if self.accelerator.is_main_process and i % 10 == 0:
|
604 |
+
print(f"Generation step {i}/{local_num_steps}")
|
605 |
+
|
606 |
+
# Get and pad labels for current batch
|
607 |
+
batch_start = i * batch_size
|
608 |
+
batch_end = batch_start + batch_size
|
609 |
+
labels = local_class_labels[batch_start:batch_end]
|
610 |
+
|
611 |
+
# Convert to tensors and track real vs padding
|
612 |
+
labels = torch.tensor(labels, device=self.device)
|
613 |
+
|
614 |
+
# Generate images
|
615 |
+
self.accelerator.wait_for_everyone()
|
616 |
+
start_time = time.time()
|
617 |
+
with torch.no_grad():
|
618 |
+
with self.accelerator.autocast():
|
619 |
+
slots = generate(unwraped_gpt_model, labels, self.num_slots_to_gen,
|
620 |
+
cfg_scale=self.cfg,
|
621 |
+
cfg_schedule=self.cfg_schedule,
|
622 |
+
temperature=self.temperature)
|
623 |
+
if self.num_slots_to_gen < self.num_slots:
|
624 |
+
null_slots = self.ae_model.dit.null_cond.expand(slots.shape[0], -1, -1)
|
625 |
+
null_slots = null_slots[:, self.num_slots_to_gen:, :]
|
626 |
+
slots = torch.cat([slots, null_slots], dim=1)
|
627 |
+
imgs = self.ae_model.sample(slots, targets=labels, cfg=self.ae_cfg)
|
628 |
+
|
629 |
+
samples_in_batch = min(batch_size * world_size, total_images - gen_img_cnt)
|
630 |
+
|
631 |
+
# Update timing stats
|
632 |
+
used_time += time.time() - start_time
|
633 |
+
gen_img_cnt += samples_in_batch
|
634 |
+
if self.accelerator.is_main_process and i % 10 == 0:
|
635 |
+
print(f"Avg generation time: {used_time/gen_img_cnt:.5f} sec/image")
|
636 |
+
|
637 |
+
gathered_imgs = concat_all_gather(imgs)
|
638 |
+
gathered_imgs = gathered_imgs[:samples_in_batch]
|
639 |
+
|
640 |
+
# Save images (only on main process)
|
641 |
+
if self.accelerator.is_main_process:
|
642 |
+
real_imgs = gathered_imgs.detach().cpu()
|
643 |
+
|
644 |
+
save_paths = [
|
645 |
+
os.path.join(save_folder, f"{str(idx).zfill(5)}.png")
|
646 |
+
for idx in range(gen_img_cnt - samples_in_batch, gen_img_cnt)
|
647 |
+
]
|
648 |
+
save_img_batch(real_imgs, save_paths)
|
649 |
+
|
650 |
+
# Calculate metrics (only on main process)
|
651 |
+
self.accelerator.wait_for_everyone()
|
652 |
+
if self.accelerator.is_main_process:
|
653 |
+
generated_files = len(os.listdir(save_folder))
|
654 |
+
print(f"Generated {generated_files} images out of {total_images} expected")
|
655 |
+
|
656 |
+
metrics_dict = get_fid_stats(save_folder, None, self.fid_stats)
|
657 |
+
fid = metrics_dict["frechet_inception_distance"]
|
658 |
+
inception_score = metrics_dict["inception_score_mean"]
|
659 |
+
|
660 |
+
metric_prefix = "fid_ema" if use_ema else "fid"
|
661 |
+
isc_prefix = "isc_ema" if use_ema else "isc"
|
662 |
+
|
663 |
+
self.accelerator.log({
|
664 |
+
metric_prefix: fid,
|
665 |
+
isc_prefix: inception_score,
|
666 |
+
"gpt_cfg": self.cfg,
|
667 |
+
"ae_cfg": self.ae_cfg,
|
668 |
+
"cfg_schedule": self.cfg_schedule,
|
669 |
+
"temperature": self.temperature,
|
670 |
+
"num_slots": self.test_num_slots if self.test_num_slots is not None else self.train_num_slots
|
671 |
+
}, step=self.steps)
|
672 |
+
|
673 |
+
# Print comprehensive CFG information
|
674 |
+
cfg_info = (
|
675 |
+
f"{'EMA ' if use_ema else ''}CFG params: "
|
676 |
+
f"gpt_cfg={self.cfg}, ae_cfg={self.ae_cfg}, "
|
677 |
+
f"cfg_schedule={self.cfg_schedule}, "
|
678 |
+
f"num_slots={self.test_num_slots if self.test_num_slots is not None else self.train_num_slots}, "
|
679 |
+
f"temperature={self.temperature}"
|
680 |
+
)
|
681 |
+
print(cfg_info)
|
682 |
+
print(f"FID: {fid:.2f}, ISC: {inception_score:.2f}")
|
683 |
+
|
684 |
+
# Cleanup
|
685 |
+
shutil.rmtree(save_folder)
|
686 |
+
|
687 |
+
# back to no ema
|
688 |
+
if use_ema:
|
689 |
+
if self.accelerator.is_main_process:
|
690 |
+
print("Switch back from ema")
|
691 |
+
model_without_ddp.load_state_dict(model_state_dict)
|
692 |
+
|
693 |
+
self.gpt_model.train()
|
694 |
+
|
semanticist/engine/trainer_utils.py
ADDED
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, torch
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import torch_fidelity
|
5 |
+
from collections import OrderedDict
|
6 |
+
from concurrent.futures import ThreadPoolExecutor
|
7 |
+
import importlib
|
8 |
+
from torch.optim import AdamW
|
9 |
+
from semanticist.utils.lr_scheduler import build_scheduler
|
10 |
+
|
11 |
+
|
12 |
+
def get_obj_from_str(string, reload=False):
|
13 |
+
"""Get object from string path."""
|
14 |
+
module, cls = string.rsplit(".", 1)
|
15 |
+
if reload:
|
16 |
+
module_imp = importlib.import_module(module)
|
17 |
+
importlib.reload(module_imp)
|
18 |
+
return getattr(importlib.import_module(module, package=None), cls)
|
19 |
+
|
20 |
+
|
21 |
+
def instantiate_from_config(config):
|
22 |
+
"""Instantiate an object from a config dictionary."""
|
23 |
+
if not "target" in config:
|
24 |
+
raise KeyError("Expected key `target` to instantiate.")
|
25 |
+
return get_obj_from_str(config["target"])(**config.get("params", dict()))
|
26 |
+
|
27 |
+
|
28 |
+
def is_dist_avail_and_initialized():
|
29 |
+
"""Check if distributed training is available and initialized."""
|
30 |
+
if not torch.distributed.is_initialized():
|
31 |
+
return False
|
32 |
+
return True
|
33 |
+
|
34 |
+
|
35 |
+
def is_main_process():
|
36 |
+
"""Check if the current process is the main process."""
|
37 |
+
return not is_dist_avail_and_initialized() or torch.distributed.get_rank() == 0
|
38 |
+
|
39 |
+
|
40 |
+
def concat_all_gather(tensor):
|
41 |
+
"""
|
42 |
+
Performs all_gather operation on the provided tensors.
|
43 |
+
*** Warning ***: torch.distributed.all_gather has no gradient.
|
44 |
+
"""
|
45 |
+
tensors_gather = [torch.ones_like(tensor)
|
46 |
+
for _ in range(torch.distributed.get_world_size())]
|
47 |
+
torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
|
48 |
+
|
49 |
+
output = torch.cat(tensors_gather, dim=0)
|
50 |
+
return output
|
51 |
+
|
52 |
+
|
53 |
+
def requires_grad(model, flag=True):
|
54 |
+
"""Set requires_grad flag for all model parameters."""
|
55 |
+
for p in model.parameters():
|
56 |
+
p.requires_grad = flag
|
57 |
+
|
58 |
+
|
59 |
+
def save_img(img, save_path):
|
60 |
+
"""Save a single image to disk."""
|
61 |
+
img = np.clip(img.float().numpy().transpose([1, 2, 0]) * 255, 0, 255)
|
62 |
+
img = img.astype(np.uint8)[:, :, ::-1]
|
63 |
+
cv2.imwrite(save_path, img)
|
64 |
+
|
65 |
+
|
66 |
+
def save_img_batch(imgs, save_paths):
|
67 |
+
"""Process and save multiple images at once using a thread pool."""
|
68 |
+
# Convert to numpy and prepare all images in one go
|
69 |
+
imgs = np.clip(imgs.float().numpy().transpose(0, 2, 3, 1) * 255, 0, 255).astype(np.uint8)
|
70 |
+
imgs = imgs[:, :, :, ::-1] # RGB to BGR for all images at once
|
71 |
+
|
72 |
+
with ThreadPoolExecutor(max_workers=32) as pool:
|
73 |
+
# Submit all tasks at once
|
74 |
+
futures = [pool.submit(cv2.imwrite, path, img)
|
75 |
+
for path, img in zip(save_paths, imgs)]
|
76 |
+
# Wait for all tasks to complete
|
77 |
+
for future in futures:
|
78 |
+
future.result() # This will raise any exceptions that occurred
|
79 |
+
|
80 |
+
|
81 |
+
def get_fid_stats(real_dir, rec_dir, fid_stats):
|
82 |
+
"""Calculate FID statistics between real and reconstructed images."""
|
83 |
+
stats = torch_fidelity.calculate_metrics(
|
84 |
+
input1=rec_dir,
|
85 |
+
input2=real_dir,
|
86 |
+
fid_statistics_file=fid_stats,
|
87 |
+
cuda=True,
|
88 |
+
isc=True,
|
89 |
+
fid=True,
|
90 |
+
kid=False,
|
91 |
+
prc=False,
|
92 |
+
verbose=False,
|
93 |
+
)
|
94 |
+
return stats
|
95 |
+
|
96 |
+
|
97 |
+
def create_scheduler(optimizer, num_epoch, steps_per_epoch, lr_min, warmup_steps,
|
98 |
+
warmup_lr_init, decay_steps, cosine_lr):
|
99 |
+
"""Create a learning rate scheduler."""
|
100 |
+
scheduler = build_scheduler(
|
101 |
+
optimizer,
|
102 |
+
num_epoch,
|
103 |
+
steps_per_epoch,
|
104 |
+
lr_min,
|
105 |
+
warmup_steps,
|
106 |
+
warmup_lr_init,
|
107 |
+
decay_steps,
|
108 |
+
cosine_lr,
|
109 |
+
)
|
110 |
+
return scheduler
|
111 |
+
|
112 |
+
|
113 |
+
def load_state_dict(state_dict, model):
|
114 |
+
"""Helper to load a state dict with proper prefix handling."""
|
115 |
+
if 'state_dict' in state_dict:
|
116 |
+
state_dict = state_dict['state_dict']
|
117 |
+
# Remove '_orig_mod' prefix if present
|
118 |
+
state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
|
119 |
+
missing, unexpected = model.load_state_dict(
|
120 |
+
state_dict, strict=False
|
121 |
+
)
|
122 |
+
if is_main_process():
|
123 |
+
print(f"Loaded model. Missing: {missing}, Unexpected: {unexpected}")
|
124 |
+
|
125 |
+
|
126 |
+
def load_safetensors(path, model):
|
127 |
+
"""Helper to load a safetensors checkpoint."""
|
128 |
+
from safetensors.torch import safe_open
|
129 |
+
with safe_open(path, framework="pt", device="cpu") as f:
|
130 |
+
state_dict = {k: f.get_tensor(k) for k in f.keys()}
|
131 |
+
load_state_dict(state_dict, model)
|
132 |
+
|
133 |
+
|
134 |
+
def setup_result_folders(result_folder):
|
135 |
+
"""Setup result folders for saving models and images."""
|
136 |
+
model_saved_dir = os.path.join(result_folder, "models")
|
137 |
+
os.makedirs(model_saved_dir, exist_ok=True)
|
138 |
+
|
139 |
+
image_saved_dir = os.path.join(result_folder, "images")
|
140 |
+
os.makedirs(image_saved_dir, exist_ok=True)
|
141 |
+
|
142 |
+
return model_saved_dir, image_saved_dir
|
143 |
+
|
144 |
+
|
145 |
+
def create_optimizer(model, weight_decay, learning_rate, betas=(0.9, 0.95)):
|
146 |
+
"""Create an AdamW optimizer with weight decay for 2D parameters only."""
|
147 |
+
# start with all of the candidate parameters
|
148 |
+
param_dict = {pn: p for pn, p in model.named_parameters()}
|
149 |
+
# filter out those that do not require grad
|
150 |
+
param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
|
151 |
+
# create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
|
152 |
+
# i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
|
153 |
+
decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
|
154 |
+
nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
|
155 |
+
optim_groups = [
|
156 |
+
{'params': decay_params, 'weight_decay': weight_decay},
|
157 |
+
{'params': nodecay_params, 'weight_decay': 0.0}
|
158 |
+
]
|
159 |
+
num_decay_params = sum(p.numel() for p in decay_params)
|
160 |
+
num_nodecay_params = sum(p.numel() for p in nodecay_params)
|
161 |
+
if is_main_process():
|
162 |
+
print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
|
163 |
+
print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
|
164 |
+
optimizer = AdamW(optim_groups, lr=learning_rate, betas=betas)
|
165 |
+
return optimizer
|
166 |
+
|
167 |
+
|
168 |
+
class EMAModel:
|
169 |
+
"""Model Exponential Moving Average."""
|
170 |
+
def __init__(self, model, device, decay=0.999):
|
171 |
+
self.device = device
|
172 |
+
self.decay = decay
|
173 |
+
self.ema_params = OrderedDict(
|
174 |
+
(name, param.clone().detach().to(device))
|
175 |
+
for name, param in model.named_parameters()
|
176 |
+
if param.requires_grad
|
177 |
+
)
|
178 |
+
|
179 |
+
@torch.no_grad()
|
180 |
+
def update(self, model):
|
181 |
+
for name, param in model.named_parameters():
|
182 |
+
if param.requires_grad:
|
183 |
+
if name in self.ema_params:
|
184 |
+
self.ema_params[name].lerp_(param.data, 1 - self.decay)
|
185 |
+
else:
|
186 |
+
self.ema_params[name] = param.data.clone().detach()
|
187 |
+
|
188 |
+
def state_dict(self):
|
189 |
+
return self.ema_params
|
190 |
+
|
191 |
+
def load_state_dict(self, params):
|
192 |
+
self.ema_params = OrderedDict(
|
193 |
+
(name, param.clone().detach().to(self.device))
|
194 |
+
for name, param in params.items()
|
195 |
+
)
|
196 |
+
|
197 |
+
|
198 |
+
class PaddedDataset(torch.utils.data.Dataset):
|
199 |
+
"""Dataset wrapper that pads a dataset to ensure even distribution across processes."""
|
200 |
+
def __init__(self, dataset, padding_size):
|
201 |
+
self.dataset = dataset
|
202 |
+
self.padding_size = padding_size
|
203 |
+
|
204 |
+
def __len__(self):
|
205 |
+
return len(self.dataset) + self.padding_size
|
206 |
+
|
207 |
+
def __getitem__(self, idx):
|
208 |
+
if idx < len(self.dataset):
|
209 |
+
return self.dataset[idx]
|
210 |
+
return self.dataset[0]
|
211 |
+
|
212 |
+
class CacheDataLoader:
|
213 |
+
"""DataLoader-like interface for cached data with epoch-based shuffling."""
|
214 |
+
def __init__(self, slots, targets=None, batch_size=32, num_augs=1, seed=None):
|
215 |
+
self.slots = slots
|
216 |
+
self.targets = targets
|
217 |
+
self.batch_size = batch_size
|
218 |
+
self.num_augs = num_augs
|
219 |
+
self.seed = seed
|
220 |
+
self.epoch = 0
|
221 |
+
# Original dataset size (before augmentations)
|
222 |
+
self.num_samples = len(slots) // num_augs
|
223 |
+
|
224 |
+
def set_epoch(self, epoch):
|
225 |
+
"""Set epoch for deterministic shuffling."""
|
226 |
+
self.epoch = epoch
|
227 |
+
|
228 |
+
def __len__(self):
|
229 |
+
"""Return number of batches based on original dataset size."""
|
230 |
+
return self.num_samples // self.batch_size
|
231 |
+
|
232 |
+
def __iter__(self):
|
233 |
+
"""Return random indices for current epoch."""
|
234 |
+
g = torch.Generator()
|
235 |
+
g.manual_seed(self.seed + self.epoch if self.seed is not None else self.epoch)
|
236 |
+
|
237 |
+
# Randomly sample indices from the entire augmented dataset
|
238 |
+
indices = torch.randint(
|
239 |
+
0, len(self.slots),
|
240 |
+
(self.num_samples,),
|
241 |
+
generator=g
|
242 |
+
).numpy()
|
243 |
+
|
244 |
+
# Yield batches of indices
|
245 |
+
for start in range(0, self.num_samples, self.batch_size):
|
246 |
+
end = min(start + self.batch_size, self.num_samples)
|
247 |
+
batch_indices = indices[start:end]
|
248 |
+
yield (
|
249 |
+
torch.from_numpy(self.slots[batch_indices]),
|
250 |
+
torch.from_numpy(self.targets[batch_indices])
|
251 |
+
)
|
semanticist/stage1/diffuse_slot.py
ADDED
@@ -0,0 +1,452 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
from diffusers import AutoencoderKL
|
7 |
+
from semanticist.stage1 import vision_transformer
|
8 |
+
from semanticist.stage1.diffusion import create_diffusion
|
9 |
+
from semanticist.stage1.diffusion_transfomer import DiT
|
10 |
+
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
11 |
+
|
12 |
+
class DiT_with_autoenc_cond(DiT):
|
13 |
+
def __init__(
|
14 |
+
self,
|
15 |
+
*args,
|
16 |
+
num_autoenc=32,
|
17 |
+
autoenc_dim=4,
|
18 |
+
use_repa=False,
|
19 |
+
z_dim=768,
|
20 |
+
encoder_depth=8,
|
21 |
+
projector_dim=2048,
|
22 |
+
**kwargs,
|
23 |
+
):
|
24 |
+
super().__init__(*args, **kwargs)
|
25 |
+
self.autoenc_dim = autoenc_dim
|
26 |
+
self.hidden_size = kwargs["hidden_size"]
|
27 |
+
self.null_cond = nn.Parameter(torch.zeros(1, num_autoenc, autoenc_dim))
|
28 |
+
torch.nn.init.normal_(self.null_cond, std=.02)
|
29 |
+
self.autoenc_cond_embedder = nn.Linear(autoenc_dim, self.hidden_size)
|
30 |
+
self.y_embedder = nn.Identity()
|
31 |
+
self.cond_drop_prob = 0.1
|
32 |
+
|
33 |
+
self.use_repa = use_repa
|
34 |
+
self._repa_hook = None
|
35 |
+
self.encoder_depth = encoder_depth
|
36 |
+
if use_repa:
|
37 |
+
self.projector = build_mlp(self.hidden_size, projector_dim, z_dim)
|
38 |
+
|
39 |
+
def embed_cond(self, autoenc_cond, drop_mask=None):
|
40 |
+
# autoenc_cond: (N, K, D)
|
41 |
+
# drop_ids: (N)
|
42 |
+
# self.null_cond: (1, K, D)
|
43 |
+
batch_size = autoenc_cond.shape[0]
|
44 |
+
if drop_mask is None:
|
45 |
+
# randomly drop all conditions, for classifier-free guidance
|
46 |
+
if self.training:
|
47 |
+
drop_ids = (
|
48 |
+
torch.rand(batch_size, 1, 1, device=autoenc_cond.device)
|
49 |
+
< self.cond_drop_prob
|
50 |
+
)
|
51 |
+
autoenc_cond_drop = torch.where(drop_ids, self.null_cond, autoenc_cond)
|
52 |
+
else:
|
53 |
+
autoenc_cond_drop = autoenc_cond
|
54 |
+
else:
|
55 |
+
# randomly drop some conditions according to the drop_mask (N, K)
|
56 |
+
# True means keep
|
57 |
+
autoenc_cond_drop = torch.where(drop_mask[:, :, None], autoenc_cond, self.null_cond)
|
58 |
+
return self.autoenc_cond_embedder(autoenc_cond_drop)
|
59 |
+
|
60 |
+
def forward(self, x, t, autoenc_cond, drop_mask=None):
|
61 |
+
"""
|
62 |
+
Forward pass of DiT.
|
63 |
+
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
64 |
+
t: (N,) tensor of diffusion timesteps
|
65 |
+
autoenc_cond: (N, K, D) tensor of autoencoder conditions (slots)
|
66 |
+
"""
|
67 |
+
x = (
|
68 |
+
self.x_embedder(x) + self.pos_embed
|
69 |
+
) # (N, T, D), where T = H * W / patch_size ** 2
|
70 |
+
c = self.t_embedder(t) # (N, D)
|
71 |
+
autoenc = self.embed_cond(autoenc_cond, drop_mask)
|
72 |
+
num_tokens = x.shape[1]
|
73 |
+
x = torch.cat((x, autoenc), dim=1)
|
74 |
+
|
75 |
+
for i, block in enumerate(self.blocks):
|
76 |
+
x = block(x, c) # (N, T, D)
|
77 |
+
if (i + 1) == self.encoder_depth and self.use_repa:
|
78 |
+
projected = self.projector(x)
|
79 |
+
self._repa_hook = projected[:, :num_tokens]
|
80 |
+
|
81 |
+
x = x[:, :num_tokens]
|
82 |
+
x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
|
83 |
+
x = self.unpatchify(x) # (N, out_channels, H, W)
|
84 |
+
return x
|
85 |
+
|
86 |
+
def forward_with_cfg(self, x, t, autoenc_cond, drop_mask, y=None, cfg_scale=1.0):
|
87 |
+
"""
|
88 |
+
Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
|
89 |
+
"""
|
90 |
+
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
|
91 |
+
half = x[: len(x) // 2]
|
92 |
+
combined = torch.cat([half, half], dim=0)
|
93 |
+
model_out = self.forward(combined, t, autoenc_cond, drop_mask)
|
94 |
+
eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
|
95 |
+
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
|
96 |
+
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
|
97 |
+
eps = torch.cat([half_eps, half_eps], dim=0)
|
98 |
+
return torch.cat([eps, rest], dim=1)
|
99 |
+
|
100 |
+
#################################################################################
|
101 |
+
# DiT Configs #
|
102 |
+
#################################################################################
|
103 |
+
|
104 |
+
|
105 |
+
def DiT_with_autoenc_cond_XL_2(**kwargs):
|
106 |
+
return DiT_with_autoenc_cond(
|
107 |
+
depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs
|
108 |
+
)
|
109 |
+
|
110 |
+
|
111 |
+
def DiT_with_autoenc_cond_XL_4(**kwargs):
|
112 |
+
return DiT_with_autoenc_cond(
|
113 |
+
depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs
|
114 |
+
)
|
115 |
+
|
116 |
+
|
117 |
+
def DiT_with_autoenc_cond_XL_8(**kwargs):
|
118 |
+
return DiT_with_autoenc_cond(
|
119 |
+
depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs
|
120 |
+
)
|
121 |
+
|
122 |
+
|
123 |
+
def DiT_with_autoenc_cond_L_2(**kwargs):
|
124 |
+
return DiT_with_autoenc_cond(
|
125 |
+
depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs
|
126 |
+
)
|
127 |
+
|
128 |
+
|
129 |
+
def DiT_with_autoenc_cond_L_4(**kwargs):
|
130 |
+
return DiT_with_autoenc_cond(
|
131 |
+
depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs
|
132 |
+
)
|
133 |
+
|
134 |
+
|
135 |
+
def DiT_with_autoenc_cond_L_8(**kwargs):
|
136 |
+
return DiT_with_autoenc_cond(
|
137 |
+
depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs
|
138 |
+
)
|
139 |
+
|
140 |
+
|
141 |
+
def DiT_with_autoenc_cond_B_2(**kwargs):
|
142 |
+
return DiT_with_autoenc_cond(
|
143 |
+
depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs
|
144 |
+
)
|
145 |
+
|
146 |
+
|
147 |
+
def DiT_with_autoenc_cond_B_4(**kwargs):
|
148 |
+
return DiT_with_autoenc_cond(
|
149 |
+
depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs
|
150 |
+
)
|
151 |
+
|
152 |
+
|
153 |
+
def DiT_with_autoenc_cond_B_8(**kwargs):
|
154 |
+
return DiT_with_autoenc_cond(
|
155 |
+
depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs
|
156 |
+
)
|
157 |
+
|
158 |
+
|
159 |
+
def DiT_with_autoenc_cond_S_2(**kwargs):
|
160 |
+
return DiT_with_autoenc_cond(
|
161 |
+
depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs
|
162 |
+
)
|
163 |
+
|
164 |
+
|
165 |
+
def DiT_with_autoenc_cond_S_4(**kwargs):
|
166 |
+
return DiT_with_autoenc_cond(
|
167 |
+
depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs
|
168 |
+
)
|
169 |
+
|
170 |
+
|
171 |
+
def DiT_with_autoenc_cond_S_8(**kwargs):
|
172 |
+
return DiT_with_autoenc_cond(
|
173 |
+
depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs
|
174 |
+
)
|
175 |
+
|
176 |
+
|
177 |
+
DiT_with_autoenc_cond_models = {
|
178 |
+
"DiT-XL-2": DiT_with_autoenc_cond_XL_2,
|
179 |
+
"DiT-XL-4": DiT_with_autoenc_cond_XL_4,
|
180 |
+
"DiT-XL-8": DiT_with_autoenc_cond_XL_8,
|
181 |
+
"DiT-L-2": DiT_with_autoenc_cond_L_2,
|
182 |
+
"DiT-L-4": DiT_with_autoenc_cond_L_4,
|
183 |
+
"DiT-L-8": DiT_with_autoenc_cond_L_8,
|
184 |
+
"DiT-B-2": DiT_with_autoenc_cond_B_2,
|
185 |
+
"DiT-B-4": DiT_with_autoenc_cond_B_4,
|
186 |
+
"DiT-B-8": DiT_with_autoenc_cond_B_8,
|
187 |
+
"DiT-S-2": DiT_with_autoenc_cond_S_2,
|
188 |
+
"DiT-S-4": DiT_with_autoenc_cond_S_4,
|
189 |
+
"DiT-S-8": DiT_with_autoenc_cond_S_8,
|
190 |
+
}
|
191 |
+
|
192 |
+
class NestedSampler(nn.Module):
|
193 |
+
def __init__(
|
194 |
+
self,
|
195 |
+
num_slots,
|
196 |
+
):
|
197 |
+
super().__init__()
|
198 |
+
self.num_slots = num_slots
|
199 |
+
self.register_buffer("arange", torch.arange(num_slots))
|
200 |
+
|
201 |
+
def uniform_sample(self, num):
|
202 |
+
return torch.randint(1, self.num_slots + 1, (num,))
|
203 |
+
|
204 |
+
def sample(self, num):
|
205 |
+
samples = self.uniform_sample(num)
|
206 |
+
return samples
|
207 |
+
|
208 |
+
def forward(self, batch_size, device, inference_with_n_slots=-1):
|
209 |
+
if self.training:
|
210 |
+
b = self.sample(batch_size).to(device)
|
211 |
+
else:
|
212 |
+
if inference_with_n_slots != -1:
|
213 |
+
b = torch.full((batch_size,), inference_with_n_slots, device=device)
|
214 |
+
else:
|
215 |
+
b = torch.full((batch_size,), self.num_slots, device=device)
|
216 |
+
b = torch.clamp(b, max=self.num_slots)
|
217 |
+
|
218 |
+
slot_mask = self.arange[None, :] < b[:, None] # (batch_size, num_slots)
|
219 |
+
return slot_mask
|
220 |
+
|
221 |
+
class DiffuseSlot(nn.Module):
|
222 |
+
def __init__(
|
223 |
+
self,
|
224 |
+
encoder="vit_base_patch16",
|
225 |
+
drop_path_rate=0.1,
|
226 |
+
enc_img_size=256,
|
227 |
+
enc_causal=True,
|
228 |
+
num_slots=16,
|
229 |
+
slot_dim=256,
|
230 |
+
norm_slots=False,
|
231 |
+
enable_nest=False,
|
232 |
+
enable_nest_after=-1,
|
233 |
+
vae="stabilityai/sd-vae-ft-ema",
|
234 |
+
dit_model="DiT-B-4",
|
235 |
+
num_sampling_steps="ddim25",
|
236 |
+
use_repa=False,
|
237 |
+
repa_encoder_depth=8,
|
238 |
+
repa_loss_weight=1.0,
|
239 |
+
**kwargs,
|
240 |
+
):
|
241 |
+
super().__init__()
|
242 |
+
|
243 |
+
self.use_repa = use_repa
|
244 |
+
self.repa_loss_weight = repa_loss_weight
|
245 |
+
if use_repa:
|
246 |
+
self.repa_encoder = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
|
247 |
+
self.repa_encoder.image_size = 224
|
248 |
+
for param in self.repa_encoder.parameters():
|
249 |
+
param.requires_grad = False
|
250 |
+
self.repa_encoder.eval()
|
251 |
+
|
252 |
+
self.diffusion = create_diffusion(timestep_respacing="")
|
253 |
+
self.gen_diffusion = create_diffusion(timestep_respacing=num_sampling_steps)
|
254 |
+
self.dit_input_size = enc_img_size // 8 if not "mar" in vae else enc_img_size // 16
|
255 |
+
self.dit_in_channels = 4 if not "mar" in vae else 16
|
256 |
+
self.dit = DiT_with_autoenc_cond_models[dit_model](
|
257 |
+
input_size=self.dit_input_size,
|
258 |
+
in_channels=self.dit_in_channels,
|
259 |
+
num_autoenc=num_slots,
|
260 |
+
autoenc_dim=slot_dim,
|
261 |
+
use_repa=use_repa,
|
262 |
+
encoder_depth=repa_encoder_depth,
|
263 |
+
z_dim=768,
|
264 |
+
)
|
265 |
+
self.vae = AutoencoderKL.from_pretrained(vae)
|
266 |
+
self.scaling_factor = self.vae.config.scaling_factor
|
267 |
+
self.vae.eval().requires_grad_(False)
|
268 |
+
|
269 |
+
self.enc_img_size = enc_img_size
|
270 |
+
self.enc_causal = enc_causal
|
271 |
+
encoder_fn = vision_transformer.__dict__[encoder]
|
272 |
+
|
273 |
+
self.encoder = encoder_fn(
|
274 |
+
img_size=[enc_img_size],
|
275 |
+
num_slots=num_slots,
|
276 |
+
drop_path_rate=drop_path_rate,
|
277 |
+
)
|
278 |
+
self.num_slots = num_slots
|
279 |
+
self.norm_slots = norm_slots
|
280 |
+
self.num_channels = self.encoder.num_features
|
281 |
+
|
282 |
+
self.encoder2slot = nn.Linear(self.num_channels, slot_dim)
|
283 |
+
self.nested_sampler = NestedSampler(num_slots)
|
284 |
+
self.enable_nest = enable_nest
|
285 |
+
self.enable_nest_after = enable_nest_after
|
286 |
+
|
287 |
+
@torch.no_grad()
|
288 |
+
def vae_encode(self, x):
|
289 |
+
x = x * 2 - 1
|
290 |
+
x = self.vae.encode(x)
|
291 |
+
if hasattr(x, 'latent_dist'):
|
292 |
+
x = x.latent_dist
|
293 |
+
return x.sample().mul_(self.scaling_factor)
|
294 |
+
|
295 |
+
@torch.no_grad()
|
296 |
+
def vae_decode(self, z):
|
297 |
+
z = self.vae.decode(z / self.scaling_factor)
|
298 |
+
if hasattr(z, 'sample'):
|
299 |
+
z = z.sample
|
300 |
+
return (z + 1) / 2
|
301 |
+
|
302 |
+
@torch.no_grad()
|
303 |
+
def repa_encode(self, x):
|
304 |
+
mean = torch.Tensor(IMAGENET_DEFAULT_MEAN).to(x.device).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
305 |
+
std = torch.Tensor(IMAGENET_DEFAULT_STD).to(x.device).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
306 |
+
x = (x - mean) / std
|
307 |
+
if self.repa_encoder.image_size != self.enc_img_size:
|
308 |
+
x = torch.nn.functional.interpolate(x, self.repa_encoder.image_size, mode='bicubic')
|
309 |
+
x = self.repa_encoder.forward_features(x)['x_norm_patchtokens']
|
310 |
+
return x
|
311 |
+
|
312 |
+
def encode_slots(self, x):
|
313 |
+
slots = self.encoder(x, is_causal=self.enc_causal)
|
314 |
+
slots = self.encoder2slot(slots)
|
315 |
+
if self.norm_slots:
|
316 |
+
slots_std = torch.std(slots, dim=-1, keepdim=True)
|
317 |
+
slots_mean = torch.mean(slots, dim=-1, keepdim=True)
|
318 |
+
slots = (slots - slots_mean) / slots_std
|
319 |
+
return slots
|
320 |
+
|
321 |
+
def forward_with_latents(self,
|
322 |
+
x_vae,
|
323 |
+
slots,
|
324 |
+
z,
|
325 |
+
sample=False,
|
326 |
+
epoch=None,
|
327 |
+
inference_with_n_slots=-1,
|
328 |
+
cfg=1.0):
|
329 |
+
losses = {}
|
330 |
+
batch_size = x_vae.shape[0]
|
331 |
+
device = x_vae.device
|
332 |
+
|
333 |
+
if (
|
334 |
+
epoch is not None
|
335 |
+
and epoch >= self.enable_nest_after
|
336 |
+
and self.enable_nest_after != -1
|
337 |
+
):
|
338 |
+
self.enable_nest = True
|
339 |
+
|
340 |
+
t = torch.randint(0, 1000, (x_vae.shape[0],), device=device)
|
341 |
+
|
342 |
+
if self.enable_nest or inference_with_n_slots != -1:
|
343 |
+
drop_mask = self.nested_sampler(
|
344 |
+
batch_size, device,
|
345 |
+
inference_with_n_slots=inference_with_n_slots,
|
346 |
+
)
|
347 |
+
else:
|
348 |
+
drop_mask = None
|
349 |
+
|
350 |
+
if sample:
|
351 |
+
return self.sample(slots, drop_mask=drop_mask, cfg=cfg)
|
352 |
+
|
353 |
+
model_kwargs = dict(autoenc_cond=slots, drop_mask=drop_mask)
|
354 |
+
loss_dict = self.diffusion.training_losses(self.dit, x_vae, t, model_kwargs)
|
355 |
+
diff_loss = loss_dict["loss"].mean()
|
356 |
+
losses["diff_loss"] = diff_loss
|
357 |
+
|
358 |
+
if self.use_repa:
|
359 |
+
assert self.dit._repa_hook is not None and z is not None
|
360 |
+
z_tilde = self.dit._repa_hook
|
361 |
+
|
362 |
+
if z_tilde.shape[1] != z.shape[1]:
|
363 |
+
z_tilde = interpolate_features(z_tilde, z.shape[1])
|
364 |
+
|
365 |
+
z_tilde = F.normalize(z_tilde, dim=-1)
|
366 |
+
z = F.normalize(z, dim=-1)
|
367 |
+
repa_loss = -torch.sum(z_tilde * z, dim=-1)
|
368 |
+
losses["repa_loss"] = repa_loss.mean() * self.repa_loss_weight
|
369 |
+
|
370 |
+
return losses
|
371 |
+
|
372 |
+
|
373 |
+
def forward(self,
|
374 |
+
x,
|
375 |
+
sample=False,
|
376 |
+
epoch=None,
|
377 |
+
inference_with_n_slots=-1,
|
378 |
+
cfg=1.0):
|
379 |
+
|
380 |
+
x_vae = self.vae_encode(x)
|
381 |
+
z = self.repa_encode(x) if self.use_repa else None
|
382 |
+
slots = self.encode_slots(x)
|
383 |
+
return self.forward_with_latents(x_vae, slots, z, sample, epoch, inference_with_n_slots, cfg)
|
384 |
+
|
385 |
+
|
386 |
+
@torch.no_grad()
|
387 |
+
def sample(self, slots, drop_mask=None, cfg=1.0):
|
388 |
+
batch_size = slots.shape[0]
|
389 |
+
device = slots.device
|
390 |
+
z = torch.randn(batch_size, self.dit_in_channels, self.dit_input_size, self.dit_input_size, device=device)
|
391 |
+
if cfg != 1.0:
|
392 |
+
z = torch.cat([z, z], 0)
|
393 |
+
null_slots = self.dit.null_cond.expand(batch_size, -1, -1)
|
394 |
+
slots = torch.cat([slots, null_slots], 0)
|
395 |
+
if drop_mask is not None:
|
396 |
+
null_cond_mask = torch.ones_like(drop_mask)
|
397 |
+
drop_mask = torch.cat([drop_mask, null_cond_mask], 0)
|
398 |
+
model_kwargs = dict(autoenc_cond=slots, drop_mask=drop_mask, cfg_scale=cfg)
|
399 |
+
sample_fn = self.dit.forward_with_cfg
|
400 |
+
else:
|
401 |
+
model_kwargs = dict(autoenc_cond=slots, drop_mask=drop_mask)
|
402 |
+
sample_fn = self.dit.forward
|
403 |
+
samples = self.gen_diffusion.p_sample_loop(
|
404 |
+
sample_fn,
|
405 |
+
z.shape,
|
406 |
+
z,
|
407 |
+
clip_denoised=False,
|
408 |
+
model_kwargs=model_kwargs,
|
409 |
+
progress=False,
|
410 |
+
device=device,
|
411 |
+
)
|
412 |
+
if cfg != 1.0:
|
413 |
+
samples, _ = samples.chunk(2, dim=0) # Remove null class samples
|
414 |
+
samples = self.vae_decode(samples)
|
415 |
+
return samples
|
416 |
+
|
417 |
+
def train(self, mode=True):
|
418 |
+
"""Override train() to keep certain components in eval mode"""
|
419 |
+
super().train(mode)
|
420 |
+
self.vae.eval()
|
421 |
+
return self
|
422 |
+
|
423 |
+
|
424 |
+
def build_mlp(hidden_size, projector_dim, z_dim):
|
425 |
+
return nn.Sequential(
|
426 |
+
nn.Linear(hidden_size, projector_dim),
|
427 |
+
nn.SiLU(),
|
428 |
+
nn.Linear(projector_dim, projector_dim),
|
429 |
+
nn.SiLU(),
|
430 |
+
nn.Linear(projector_dim, z_dim),
|
431 |
+
)
|
432 |
+
|
433 |
+
def interpolate_features(x, target_len):
|
434 |
+
"""Interpolate features to match target sequence length.
|
435 |
+
Args:
|
436 |
+
x: tensor of shape (B, T1, D)
|
437 |
+
target_len: desired sequence length T2
|
438 |
+
Returns:
|
439 |
+
tensor of shape (B, T2, D)
|
440 |
+
"""
|
441 |
+
B, T1, D = x.shape
|
442 |
+
H1 = W1 = int(math.sqrt(T1))
|
443 |
+
H2 = W2 = int(math.sqrt(target_len))
|
444 |
+
|
445 |
+
# Reshape to 2D spatial dimensions and move channels to second dimension
|
446 |
+
x = x.reshape(B, H1, W1, D).permute(0, 3, 1, 2)
|
447 |
+
|
448 |
+
# Interpolate
|
449 |
+
x = F.interpolate(x, size=(H2, W2), mode='bicubic', align_corners=False)
|
450 |
+
|
451 |
+
# Reshape back to sequence
|
452 |
+
return x.permute(0, 2, 3, 1).reshape(B, target_len, D)
|
semanticist/stage1/diffusion/__init__.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from OpenAI's diffusion repos
|
2 |
+
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
|
3 |
+
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
|
4 |
+
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
5 |
+
|
6 |
+
from . import gaussian_diffusion as gd
|
7 |
+
from .respace import SpacedDiffusion, space_timesteps
|
8 |
+
|
9 |
+
|
10 |
+
def create_diffusion(
|
11 |
+
timestep_respacing,
|
12 |
+
noise_schedule="linear",
|
13 |
+
use_kl=False,
|
14 |
+
sigma_small=False,
|
15 |
+
predict_xstart=False,
|
16 |
+
learn_sigma=True,
|
17 |
+
rescale_learned_sigmas=False,
|
18 |
+
diffusion_steps=1000
|
19 |
+
):
|
20 |
+
betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
|
21 |
+
if use_kl:
|
22 |
+
loss_type = gd.LossType.RESCALED_KL
|
23 |
+
elif rescale_learned_sigmas:
|
24 |
+
loss_type = gd.LossType.RESCALED_MSE
|
25 |
+
else:
|
26 |
+
loss_type = gd.LossType.MSE
|
27 |
+
if timestep_respacing is None or timestep_respacing == "":
|
28 |
+
timestep_respacing = [diffusion_steps]
|
29 |
+
return SpacedDiffusion(
|
30 |
+
use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
|
31 |
+
betas=betas,
|
32 |
+
model_mean_type=(
|
33 |
+
gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
|
34 |
+
),
|
35 |
+
model_var_type=(
|
36 |
+
(
|
37 |
+
gd.ModelVarType.FIXED_LARGE
|
38 |
+
if not sigma_small
|
39 |
+
else gd.ModelVarType.FIXED_SMALL
|
40 |
+
)
|
41 |
+
if not learn_sigma
|
42 |
+
else gd.ModelVarType.LEARNED_RANGE
|
43 |
+
),
|
44 |
+
loss_type=loss_type
|
45 |
+
# rescale_timesteps=rescale_timesteps,
|
46 |
+
)
|
semanticist/stage1/diffusion/diffusion_utils.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from OpenAI's diffusion repos
|
2 |
+
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
|
3 |
+
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
|
4 |
+
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
5 |
+
|
6 |
+
import torch as th
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
|
10 |
+
def normal_kl(mean1, logvar1, mean2, logvar2):
|
11 |
+
"""
|
12 |
+
Compute the KL divergence between two gaussians.
|
13 |
+
Shapes are automatically broadcasted, so batches can be compared to
|
14 |
+
scalars, among other use cases.
|
15 |
+
"""
|
16 |
+
tensor = None
|
17 |
+
for obj in (mean1, logvar1, mean2, logvar2):
|
18 |
+
if isinstance(obj, th.Tensor):
|
19 |
+
tensor = obj
|
20 |
+
break
|
21 |
+
assert tensor is not None, "at least one argument must be a Tensor"
|
22 |
+
|
23 |
+
# Force variances to be Tensors. Broadcasting helps convert scalars to
|
24 |
+
# Tensors, but it does not work for th.exp().
|
25 |
+
logvar1, logvar2 = [
|
26 |
+
x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
|
27 |
+
for x in (logvar1, logvar2)
|
28 |
+
]
|
29 |
+
|
30 |
+
return 0.5 * (
|
31 |
+
-1.0
|
32 |
+
+ logvar2
|
33 |
+
- logvar1
|
34 |
+
+ th.exp(logvar1 - logvar2)
|
35 |
+
+ ((mean1 - mean2) ** 2) * th.exp(-logvar2)
|
36 |
+
)
|
37 |
+
|
38 |
+
|
39 |
+
def approx_standard_normal_cdf(x):
|
40 |
+
"""
|
41 |
+
A fast approximation of the cumulative distribution function of the
|
42 |
+
standard normal.
|
43 |
+
"""
|
44 |
+
return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
|
45 |
+
|
46 |
+
|
47 |
+
def continuous_gaussian_log_likelihood(x, *, means, log_scales):
|
48 |
+
"""
|
49 |
+
Compute the log-likelihood of a continuous Gaussian distribution.
|
50 |
+
:param x: the targets
|
51 |
+
:param means: the Gaussian mean Tensor.
|
52 |
+
:param log_scales: the Gaussian log stddev Tensor.
|
53 |
+
:return: a tensor like x of log probabilities (in nats).
|
54 |
+
"""
|
55 |
+
centered_x = x - means
|
56 |
+
inv_stdv = th.exp(-log_scales)
|
57 |
+
normalized_x = centered_x * inv_stdv
|
58 |
+
log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
|
59 |
+
return log_probs
|
60 |
+
|
61 |
+
|
62 |
+
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
|
63 |
+
"""
|
64 |
+
Compute the log-likelihood of a Gaussian distribution discretizing to a
|
65 |
+
given image.
|
66 |
+
:param x: the target images. It is assumed that this was uint8 values,
|
67 |
+
rescaled to the range [-1, 1].
|
68 |
+
:param means: the Gaussian mean Tensor.
|
69 |
+
:param log_scales: the Gaussian log stddev Tensor.
|
70 |
+
:return: a tensor like x of log probabilities (in nats).
|
71 |
+
"""
|
72 |
+
assert x.shape == means.shape == log_scales.shape
|
73 |
+
centered_x = x - means
|
74 |
+
inv_stdv = th.exp(-log_scales)
|
75 |
+
plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
|
76 |
+
cdf_plus = approx_standard_normal_cdf(plus_in)
|
77 |
+
min_in = inv_stdv * (centered_x - 1.0 / 255.0)
|
78 |
+
cdf_min = approx_standard_normal_cdf(min_in)
|
79 |
+
log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
|
80 |
+
log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
|
81 |
+
cdf_delta = cdf_plus - cdf_min
|
82 |
+
log_probs = th.where(
|
83 |
+
x < -0.999,
|
84 |
+
log_cdf_plus,
|
85 |
+
th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
|
86 |
+
)
|
87 |
+
assert log_probs.shape == x.shape
|
88 |
+
return log_probs
|
semanticist/stage1/diffusion/gaussian_diffusion.py
ADDED
@@ -0,0 +1,886 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from OpenAI's diffusion repos
|
2 |
+
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
|
3 |
+
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
|
4 |
+
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
5 |
+
|
6 |
+
|
7 |
+
import math
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import torch as th
|
11 |
+
import enum
|
12 |
+
|
13 |
+
from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
|
14 |
+
|
15 |
+
|
16 |
+
def mean_flat(tensor):
|
17 |
+
"""
|
18 |
+
Take the mean over all non-batch dimensions.
|
19 |
+
"""
|
20 |
+
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
21 |
+
|
22 |
+
|
23 |
+
class ModelMeanType(enum.Enum):
|
24 |
+
"""
|
25 |
+
Which type of output the model predicts.
|
26 |
+
"""
|
27 |
+
|
28 |
+
PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
|
29 |
+
START_X = enum.auto() # the model predicts x_0
|
30 |
+
EPSILON = enum.auto() # the model predicts epsilon
|
31 |
+
|
32 |
+
|
33 |
+
class ModelVarType(enum.Enum):
|
34 |
+
"""
|
35 |
+
What is used as the model's output variance.
|
36 |
+
The LEARNED_RANGE option has been added to allow the model to predict
|
37 |
+
values between FIXED_SMALL and FIXED_LARGE, making its job easier.
|
38 |
+
"""
|
39 |
+
|
40 |
+
LEARNED = enum.auto()
|
41 |
+
FIXED_SMALL = enum.auto()
|
42 |
+
FIXED_LARGE = enum.auto()
|
43 |
+
LEARNED_RANGE = enum.auto()
|
44 |
+
|
45 |
+
|
46 |
+
class LossType(enum.Enum):
|
47 |
+
MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
|
48 |
+
RESCALED_MSE = (
|
49 |
+
enum.auto()
|
50 |
+
) # use raw MSE loss (with RESCALED_KL when learning variances)
|
51 |
+
KL = enum.auto() # use the variational lower-bound
|
52 |
+
RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
|
53 |
+
|
54 |
+
def is_vb(self):
|
55 |
+
return self == LossType.KL or self == LossType.RESCALED_KL
|
56 |
+
|
57 |
+
|
58 |
+
def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
|
59 |
+
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
|
60 |
+
warmup_time = int(num_diffusion_timesteps * warmup_frac)
|
61 |
+
betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
|
62 |
+
return betas
|
63 |
+
|
64 |
+
|
65 |
+
def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
|
66 |
+
"""
|
67 |
+
This is the deprecated API for creating beta schedules.
|
68 |
+
See get_named_beta_schedule() for the new library of schedules.
|
69 |
+
"""
|
70 |
+
if beta_schedule == "quad":
|
71 |
+
betas = (
|
72 |
+
np.linspace(
|
73 |
+
beta_start ** 0.5,
|
74 |
+
beta_end ** 0.5,
|
75 |
+
num_diffusion_timesteps,
|
76 |
+
dtype=np.float64,
|
77 |
+
)
|
78 |
+
** 2
|
79 |
+
)
|
80 |
+
elif beta_schedule == "linear":
|
81 |
+
betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
|
82 |
+
elif beta_schedule == "warmup10":
|
83 |
+
betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
|
84 |
+
elif beta_schedule == "warmup50":
|
85 |
+
betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
|
86 |
+
elif beta_schedule == "const":
|
87 |
+
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
|
88 |
+
elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
|
89 |
+
betas = 1.0 / np.linspace(
|
90 |
+
num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
|
91 |
+
)
|
92 |
+
else:
|
93 |
+
raise NotImplementedError(beta_schedule)
|
94 |
+
assert betas.shape == (num_diffusion_timesteps,)
|
95 |
+
return betas
|
96 |
+
|
97 |
+
|
98 |
+
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
|
99 |
+
"""
|
100 |
+
Get a pre-defined beta schedule for the given name.
|
101 |
+
The beta schedule library consists of beta schedules which remain similar
|
102 |
+
in the limit of num_diffusion_timesteps.
|
103 |
+
Beta schedules may be added, but should not be removed or changed once
|
104 |
+
they are committed to maintain backwards compatibility.
|
105 |
+
"""
|
106 |
+
if schedule_name == "linear":
|
107 |
+
# Linear schedule from Ho et al, extended to work for any number of
|
108 |
+
# diffusion steps.
|
109 |
+
scale = 1000 / num_diffusion_timesteps
|
110 |
+
return get_beta_schedule(
|
111 |
+
"linear",
|
112 |
+
beta_start=scale * 0.0001,
|
113 |
+
beta_end=scale * 0.02,
|
114 |
+
num_diffusion_timesteps=num_diffusion_timesteps,
|
115 |
+
)
|
116 |
+
elif schedule_name == "cosine":
|
117 |
+
return betas_for_alpha_bar(
|
118 |
+
num_diffusion_timesteps,
|
119 |
+
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
|
120 |
+
)
|
121 |
+
else:
|
122 |
+
raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
|
123 |
+
|
124 |
+
|
125 |
+
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
|
126 |
+
"""
|
127 |
+
Create a beta schedule that discretizes the given alpha_t_bar function,
|
128 |
+
which defines the cumulative product of (1-beta) over time from t = [0,1].
|
129 |
+
:param num_diffusion_timesteps: the number of betas to produce.
|
130 |
+
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
|
131 |
+
produces the cumulative product of (1-beta) up to that
|
132 |
+
part of the diffusion process.
|
133 |
+
:param max_beta: the maximum beta to use; use values lower than 1 to
|
134 |
+
prevent singularities.
|
135 |
+
"""
|
136 |
+
betas = []
|
137 |
+
for i in range(num_diffusion_timesteps):
|
138 |
+
t1 = i / num_diffusion_timesteps
|
139 |
+
t2 = (i + 1) / num_diffusion_timesteps
|
140 |
+
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
141 |
+
return np.array(betas)
|
142 |
+
|
143 |
+
|
144 |
+
class GaussianDiffusion:
|
145 |
+
"""
|
146 |
+
Utilities for training and sampling diffusion models.
|
147 |
+
Original ported from this codebase:
|
148 |
+
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
|
149 |
+
:param betas: a 1-D numpy array of betas for each diffusion timestep,
|
150 |
+
starting at T and going to 1.
|
151 |
+
"""
|
152 |
+
|
153 |
+
def __init__(
|
154 |
+
self,
|
155 |
+
*,
|
156 |
+
betas,
|
157 |
+
model_mean_type,
|
158 |
+
model_var_type,
|
159 |
+
loss_type
|
160 |
+
):
|
161 |
+
|
162 |
+
self.model_mean_type = model_mean_type
|
163 |
+
self.model_var_type = model_var_type
|
164 |
+
self.loss_type = loss_type
|
165 |
+
|
166 |
+
# Use float64 for accuracy.
|
167 |
+
betas = np.array(betas, dtype=np.float64)
|
168 |
+
self.betas = betas
|
169 |
+
assert len(betas.shape) == 1, "betas must be 1-D"
|
170 |
+
assert (betas > 0).all() and (betas <= 1).all()
|
171 |
+
|
172 |
+
self.num_timesteps = int(betas.shape[0])
|
173 |
+
|
174 |
+
alphas = 1.0 - betas
|
175 |
+
self.alphas_cumprod = np.cumprod(alphas, axis=0)
|
176 |
+
self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
|
177 |
+
self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
|
178 |
+
assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
|
179 |
+
|
180 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
181 |
+
self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
|
182 |
+
self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
|
183 |
+
self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
|
184 |
+
self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
|
185 |
+
self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
|
186 |
+
|
187 |
+
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
188 |
+
self.posterior_variance = (
|
189 |
+
betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
190 |
+
)
|
191 |
+
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
192 |
+
self.posterior_log_variance_clipped = np.log(
|
193 |
+
np.append(self.posterior_variance[1], self.posterior_variance[1:])
|
194 |
+
) if len(self.posterior_variance) > 1 else np.array([])
|
195 |
+
|
196 |
+
self.posterior_mean_coef1 = (
|
197 |
+
betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
198 |
+
)
|
199 |
+
self.posterior_mean_coef2 = (
|
200 |
+
(1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
|
201 |
+
)
|
202 |
+
|
203 |
+
def q_mean_variance(self, x_start, t):
|
204 |
+
"""
|
205 |
+
Get the distribution q(x_t | x_0).
|
206 |
+
:param x_start: the [N x C x ...] tensor of noiseless inputs.
|
207 |
+
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
|
208 |
+
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
|
209 |
+
"""
|
210 |
+
mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
211 |
+
variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
|
212 |
+
log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
|
213 |
+
return mean, variance, log_variance
|
214 |
+
|
215 |
+
def q_sample(self, x_start, t, noise=None):
|
216 |
+
"""
|
217 |
+
Diffuse the data for a given number of diffusion steps.
|
218 |
+
In other words, sample from q(x_t | x_0).
|
219 |
+
:param x_start: the initial data batch.
|
220 |
+
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
|
221 |
+
:param noise: if specified, the split-out normal noise.
|
222 |
+
:return: A noisy version of x_start.
|
223 |
+
"""
|
224 |
+
if noise is None:
|
225 |
+
noise = th.randn_like(x_start)
|
226 |
+
assert noise.shape == x_start.shape
|
227 |
+
return (
|
228 |
+
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
229 |
+
+ _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
230 |
+
)
|
231 |
+
|
232 |
+
def q_posterior_mean_variance(self, x_start, x_t, t):
|
233 |
+
"""
|
234 |
+
Compute the mean and variance of the diffusion posterior:
|
235 |
+
q(x_{t-1} | x_t, x_0)
|
236 |
+
"""
|
237 |
+
assert x_start.shape == x_t.shape
|
238 |
+
posterior_mean = (
|
239 |
+
_extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
|
240 |
+
+ _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
241 |
+
)
|
242 |
+
posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
|
243 |
+
posterior_log_variance_clipped = _extract_into_tensor(
|
244 |
+
self.posterior_log_variance_clipped, t, x_t.shape
|
245 |
+
)
|
246 |
+
assert (
|
247 |
+
posterior_mean.shape[0]
|
248 |
+
== posterior_variance.shape[0]
|
249 |
+
== posterior_log_variance_clipped.shape[0]
|
250 |
+
== x_start.shape[0]
|
251 |
+
)
|
252 |
+
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
253 |
+
|
254 |
+
def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
|
255 |
+
"""
|
256 |
+
Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
|
257 |
+
the initial x, x_0.
|
258 |
+
:param model: the model, which takes a signal and a batch of timesteps
|
259 |
+
as input.
|
260 |
+
:param x: the [N x C x ...] tensor at time t.
|
261 |
+
:param t: a 1-D Tensor of timesteps.
|
262 |
+
:param clip_denoised: if True, clip the denoised signal into [-1, 1].
|
263 |
+
:param denoised_fn: if not None, a function which applies to the
|
264 |
+
x_start prediction before it is used to sample. Applies before
|
265 |
+
clip_denoised.
|
266 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
267 |
+
pass to the model. This can be used for conditioning.
|
268 |
+
:return: a dict with the following keys:
|
269 |
+
- 'mean': the model mean output.
|
270 |
+
- 'variance': the model variance output.
|
271 |
+
- 'log_variance': the log of 'variance'.
|
272 |
+
- 'pred_xstart': the prediction for x_0.
|
273 |
+
"""
|
274 |
+
if model_kwargs is None:
|
275 |
+
model_kwargs = {}
|
276 |
+
|
277 |
+
B, C = x.shape[:2]
|
278 |
+
assert t.shape == (B,)
|
279 |
+
model_output = model(x, t, **model_kwargs)
|
280 |
+
if isinstance(model_output, tuple):
|
281 |
+
model_output, extra = model_output
|
282 |
+
else:
|
283 |
+
extra = None
|
284 |
+
|
285 |
+
if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
|
286 |
+
assert model_output.shape == (B, C * 2, *x.shape[2:])
|
287 |
+
model_output, model_var_values = th.split(model_output, C, dim=1)
|
288 |
+
min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
|
289 |
+
max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
|
290 |
+
# The model_var_values is [-1, 1] for [min_var, max_var].
|
291 |
+
frac = (model_var_values + 1) / 2
|
292 |
+
model_log_variance = frac * max_log + (1 - frac) * min_log
|
293 |
+
model_variance = th.exp(model_log_variance)
|
294 |
+
else:
|
295 |
+
model_variance, model_log_variance = {
|
296 |
+
# for fixedlarge, we set the initial (log-)variance like so
|
297 |
+
# to get a better decoder log likelihood.
|
298 |
+
ModelVarType.FIXED_LARGE: (
|
299 |
+
np.append(self.posterior_variance[1], self.betas[1:]),
|
300 |
+
np.log(np.append(self.posterior_variance[1], self.betas[1:])),
|
301 |
+
),
|
302 |
+
ModelVarType.FIXED_SMALL: (
|
303 |
+
self.posterior_variance,
|
304 |
+
self.posterior_log_variance_clipped,
|
305 |
+
),
|
306 |
+
}[self.model_var_type]
|
307 |
+
model_variance = _extract_into_tensor(model_variance, t, x.shape)
|
308 |
+
model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
|
309 |
+
|
310 |
+
def process_xstart(x):
|
311 |
+
if denoised_fn is not None:
|
312 |
+
x = denoised_fn(x)
|
313 |
+
if clip_denoised:
|
314 |
+
return x.clamp(-1, 1)
|
315 |
+
return x
|
316 |
+
|
317 |
+
if self.model_mean_type == ModelMeanType.START_X:
|
318 |
+
pred_xstart = process_xstart(model_output)
|
319 |
+
else:
|
320 |
+
pred_xstart = process_xstart(
|
321 |
+
self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
|
322 |
+
)
|
323 |
+
model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
|
324 |
+
|
325 |
+
assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
|
326 |
+
return {
|
327 |
+
"mean": model_mean,
|
328 |
+
"variance": model_variance,
|
329 |
+
"log_variance": model_log_variance,
|
330 |
+
"pred_xstart": pred_xstart,
|
331 |
+
"extra": extra,
|
332 |
+
}
|
333 |
+
|
334 |
+
def _predict_xstart_from_eps(self, x_t, t, eps):
|
335 |
+
assert x_t.shape == eps.shape
|
336 |
+
return (
|
337 |
+
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
|
338 |
+
- _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
|
339 |
+
)
|
340 |
+
|
341 |
+
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
|
342 |
+
return (
|
343 |
+
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
|
344 |
+
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
|
345 |
+
|
346 |
+
def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
|
347 |
+
"""
|
348 |
+
Compute the mean for the previous step, given a function cond_fn that
|
349 |
+
computes the gradient of a conditional log probability with respect to
|
350 |
+
x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
|
351 |
+
condition on y.
|
352 |
+
This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
|
353 |
+
"""
|
354 |
+
gradient = cond_fn(x, t, **model_kwargs)
|
355 |
+
new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
|
356 |
+
return new_mean
|
357 |
+
|
358 |
+
def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
|
359 |
+
"""
|
360 |
+
Compute what the p_mean_variance output would have been, should the
|
361 |
+
model's score function be conditioned by cond_fn.
|
362 |
+
See condition_mean() for details on cond_fn.
|
363 |
+
Unlike condition_mean(), this instead uses the conditioning strategy
|
364 |
+
from Song et al (2020).
|
365 |
+
"""
|
366 |
+
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
|
367 |
+
|
368 |
+
eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
|
369 |
+
eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
|
370 |
+
|
371 |
+
out = p_mean_var.copy()
|
372 |
+
out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
|
373 |
+
out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
|
374 |
+
return out
|
375 |
+
|
376 |
+
def p_sample(
|
377 |
+
self,
|
378 |
+
model,
|
379 |
+
x,
|
380 |
+
t,
|
381 |
+
clip_denoised=True,
|
382 |
+
denoised_fn=None,
|
383 |
+
cond_fn=None,
|
384 |
+
model_kwargs=None,
|
385 |
+
temperature=1.0
|
386 |
+
):
|
387 |
+
"""
|
388 |
+
Sample x_{t-1} from the model at the given timestep.
|
389 |
+
:param model: the model to sample from.
|
390 |
+
:param x: the current tensor at x_{t-1}.
|
391 |
+
:param t: the value of t, starting at 0 for the first diffusion step.
|
392 |
+
:param clip_denoised: if True, clip the x_start prediction to [-1, 1].
|
393 |
+
:param denoised_fn: if not None, a function which applies to the
|
394 |
+
x_start prediction before it is used to sample.
|
395 |
+
:param cond_fn: if not None, this is a gradient function that acts
|
396 |
+
similarly to the model.
|
397 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
398 |
+
pass to the model. This can be used for conditioning.
|
399 |
+
:param temperature: temperature scaling during Diff Loss sampling.
|
400 |
+
:return: a dict containing the following keys:
|
401 |
+
- 'sample': a random sample from the model.
|
402 |
+
- 'pred_xstart': a prediction of x_0.
|
403 |
+
"""
|
404 |
+
out = self.p_mean_variance(
|
405 |
+
model,
|
406 |
+
x,
|
407 |
+
t,
|
408 |
+
clip_denoised=clip_denoised,
|
409 |
+
denoised_fn=denoised_fn,
|
410 |
+
model_kwargs=model_kwargs,
|
411 |
+
)
|
412 |
+
noise = th.randn_like(x)
|
413 |
+
nonzero_mask = (
|
414 |
+
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
|
415 |
+
) # no noise when t == 0
|
416 |
+
if cond_fn is not None:
|
417 |
+
out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
|
418 |
+
# scale the noise by temperature
|
419 |
+
sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise * temperature
|
420 |
+
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
|
421 |
+
|
422 |
+
def p_sample_loop(
|
423 |
+
self,
|
424 |
+
model,
|
425 |
+
shape,
|
426 |
+
noise=None,
|
427 |
+
clip_denoised=True,
|
428 |
+
denoised_fn=None,
|
429 |
+
cond_fn=None,
|
430 |
+
model_kwargs=None,
|
431 |
+
device=None,
|
432 |
+
progress=False,
|
433 |
+
temperature=1.0,
|
434 |
+
):
|
435 |
+
"""
|
436 |
+
Generate samples from the model.
|
437 |
+
:param model: the model module.
|
438 |
+
:param shape: the shape of the samples, (N, C, H, W).
|
439 |
+
:param noise: if specified, the noise from the encoder to sample.
|
440 |
+
Should be of the same shape as `shape`.
|
441 |
+
:param clip_denoised: if True, clip x_start predictions to [-1, 1].
|
442 |
+
:param denoised_fn: if not None, a function which applies to the
|
443 |
+
x_start prediction before it is used to sample.
|
444 |
+
:param cond_fn: if not None, this is a gradient function that acts
|
445 |
+
similarly to the model.
|
446 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
447 |
+
pass to the model. This can be used for conditioning.
|
448 |
+
:param device: if specified, the device to create the samples on.
|
449 |
+
If not specified, use a model parameter's device.
|
450 |
+
:param progress: if True, show a tqdm progress bar.
|
451 |
+
:param temperature: temperature scaling during Diff Loss sampling.
|
452 |
+
:return: a non-differentiable batch of samples.
|
453 |
+
"""
|
454 |
+
final = None
|
455 |
+
for sample in self.p_sample_loop_progressive(
|
456 |
+
model,
|
457 |
+
shape,
|
458 |
+
noise=noise,
|
459 |
+
clip_denoised=clip_denoised,
|
460 |
+
denoised_fn=denoised_fn,
|
461 |
+
cond_fn=cond_fn,
|
462 |
+
model_kwargs=model_kwargs,
|
463 |
+
device=device,
|
464 |
+
progress=progress,
|
465 |
+
temperature=temperature,
|
466 |
+
):
|
467 |
+
final = sample
|
468 |
+
return final["sample"]
|
469 |
+
|
470 |
+
def p_sample_loop_progressive(
|
471 |
+
self,
|
472 |
+
model,
|
473 |
+
shape,
|
474 |
+
noise=None,
|
475 |
+
clip_denoised=True,
|
476 |
+
denoised_fn=None,
|
477 |
+
cond_fn=None,
|
478 |
+
model_kwargs=None,
|
479 |
+
device=None,
|
480 |
+
progress=False,
|
481 |
+
temperature=1.0,
|
482 |
+
):
|
483 |
+
"""
|
484 |
+
Generate samples from the model and yield intermediate samples from
|
485 |
+
each timestep of diffusion.
|
486 |
+
Arguments are the same as p_sample_loop().
|
487 |
+
Returns a generator over dicts, where each dict is the return value of
|
488 |
+
p_sample().
|
489 |
+
"""
|
490 |
+
if device is None:
|
491 |
+
device = next(model.parameters()).device
|
492 |
+
assert isinstance(shape, (tuple, list))
|
493 |
+
if noise is not None:
|
494 |
+
img = noise
|
495 |
+
else:
|
496 |
+
img = th.randn(*shape, device=device)
|
497 |
+
indices = list(range(self.num_timesteps))[::-1]
|
498 |
+
|
499 |
+
if progress:
|
500 |
+
# Lazy import so that we don't depend on tqdm.
|
501 |
+
from tqdm.auto import tqdm
|
502 |
+
|
503 |
+
indices = tqdm(indices)
|
504 |
+
|
505 |
+
for i in indices:
|
506 |
+
t = th.tensor([i] * shape[0], device=device)
|
507 |
+
with th.no_grad():
|
508 |
+
out = self.p_sample(
|
509 |
+
model,
|
510 |
+
img,
|
511 |
+
t,
|
512 |
+
clip_denoised=clip_denoised,
|
513 |
+
denoised_fn=denoised_fn,
|
514 |
+
cond_fn=cond_fn,
|
515 |
+
model_kwargs=model_kwargs,
|
516 |
+
temperature=temperature,
|
517 |
+
)
|
518 |
+
yield out
|
519 |
+
img = out["sample"]
|
520 |
+
|
521 |
+
def ddim_sample(
|
522 |
+
self,
|
523 |
+
model,
|
524 |
+
x,
|
525 |
+
t,
|
526 |
+
clip_denoised=True,
|
527 |
+
denoised_fn=None,
|
528 |
+
cond_fn=None,
|
529 |
+
model_kwargs=None,
|
530 |
+
eta=0.0,
|
531 |
+
):
|
532 |
+
"""
|
533 |
+
Sample x_{t-1} from the model using DDIM.
|
534 |
+
Same usage as p_sample().
|
535 |
+
"""
|
536 |
+
out = self.p_mean_variance(
|
537 |
+
model,
|
538 |
+
x,
|
539 |
+
t,
|
540 |
+
clip_denoised=clip_denoised,
|
541 |
+
denoised_fn=denoised_fn,
|
542 |
+
model_kwargs=model_kwargs,
|
543 |
+
)
|
544 |
+
if cond_fn is not None:
|
545 |
+
out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
|
546 |
+
|
547 |
+
# Usually our model outputs epsilon, but we re-derive it
|
548 |
+
# in case we used x_start or x_prev prediction.
|
549 |
+
eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
|
550 |
+
|
551 |
+
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
|
552 |
+
alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
|
553 |
+
sigma = (
|
554 |
+
eta
|
555 |
+
* th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
|
556 |
+
* th.sqrt(1 - alpha_bar / alpha_bar_prev)
|
557 |
+
)
|
558 |
+
# Equation 12.
|
559 |
+
noise = th.randn_like(x)
|
560 |
+
mean_pred = (
|
561 |
+
out["pred_xstart"] * th.sqrt(alpha_bar_prev)
|
562 |
+
+ th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
|
563 |
+
)
|
564 |
+
nonzero_mask = (
|
565 |
+
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
|
566 |
+
) # no noise when t == 0
|
567 |
+
sample = mean_pred + nonzero_mask * sigma * noise
|
568 |
+
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
|
569 |
+
|
570 |
+
def ddim_reverse_sample(
|
571 |
+
self,
|
572 |
+
model,
|
573 |
+
x,
|
574 |
+
t,
|
575 |
+
clip_denoised=True,
|
576 |
+
denoised_fn=None,
|
577 |
+
cond_fn=None,
|
578 |
+
model_kwargs=None,
|
579 |
+
eta=0.0,
|
580 |
+
):
|
581 |
+
"""
|
582 |
+
Sample x_{t+1} from the model using DDIM reverse ODE.
|
583 |
+
"""
|
584 |
+
assert eta == 0.0, "Reverse ODE only for deterministic path"
|
585 |
+
out = self.p_mean_variance(
|
586 |
+
model,
|
587 |
+
x,
|
588 |
+
t,
|
589 |
+
clip_denoised=clip_denoised,
|
590 |
+
denoised_fn=denoised_fn,
|
591 |
+
model_kwargs=model_kwargs,
|
592 |
+
)
|
593 |
+
if cond_fn is not None:
|
594 |
+
out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
|
595 |
+
# Usually our model outputs epsilon, but we re-derive it
|
596 |
+
# in case we used x_start or x_prev prediction.
|
597 |
+
eps = (
|
598 |
+
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
|
599 |
+
- out["pred_xstart"]
|
600 |
+
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
|
601 |
+
alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
|
602 |
+
|
603 |
+
# Equation 12. reversed
|
604 |
+
mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
|
605 |
+
|
606 |
+
return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
|
607 |
+
|
608 |
+
def ddim_sample_loop(
|
609 |
+
self,
|
610 |
+
model,
|
611 |
+
shape,
|
612 |
+
noise=None,
|
613 |
+
clip_denoised=True,
|
614 |
+
denoised_fn=None,
|
615 |
+
cond_fn=None,
|
616 |
+
model_kwargs=None,
|
617 |
+
device=None,
|
618 |
+
progress=False,
|
619 |
+
eta=0.0,
|
620 |
+
):
|
621 |
+
"""
|
622 |
+
Generate samples from the model using DDIM.
|
623 |
+
Same usage as p_sample_loop().
|
624 |
+
"""
|
625 |
+
final = None
|
626 |
+
for sample in self.ddim_sample_loop_progressive(
|
627 |
+
model,
|
628 |
+
shape,
|
629 |
+
noise=noise,
|
630 |
+
clip_denoised=clip_denoised,
|
631 |
+
denoised_fn=denoised_fn,
|
632 |
+
cond_fn=cond_fn,
|
633 |
+
model_kwargs=model_kwargs,
|
634 |
+
device=device,
|
635 |
+
progress=progress,
|
636 |
+
eta=eta,
|
637 |
+
):
|
638 |
+
final = sample
|
639 |
+
return final["sample"]
|
640 |
+
|
641 |
+
def ddim_sample_loop_progressive(
|
642 |
+
self,
|
643 |
+
model,
|
644 |
+
shape,
|
645 |
+
noise=None,
|
646 |
+
clip_denoised=True,
|
647 |
+
denoised_fn=None,
|
648 |
+
cond_fn=None,
|
649 |
+
model_kwargs=None,
|
650 |
+
device=None,
|
651 |
+
progress=False,
|
652 |
+
eta=0.0,
|
653 |
+
):
|
654 |
+
"""
|
655 |
+
Use DDIM to sample from the model and yield intermediate samples from
|
656 |
+
each timestep of DDIM.
|
657 |
+
Same usage as p_sample_loop_progressive().
|
658 |
+
"""
|
659 |
+
if device is None:
|
660 |
+
device = next(model.parameters()).device
|
661 |
+
assert isinstance(shape, (tuple, list))
|
662 |
+
if noise is not None:
|
663 |
+
img = noise
|
664 |
+
else:
|
665 |
+
img = th.randn(*shape, device=device)
|
666 |
+
indices = list(range(self.num_timesteps))[::-1]
|
667 |
+
|
668 |
+
if progress:
|
669 |
+
# Lazy import so that we don't depend on tqdm.
|
670 |
+
from tqdm.auto import tqdm
|
671 |
+
|
672 |
+
indices = tqdm(indices)
|
673 |
+
|
674 |
+
for i in indices:
|
675 |
+
t = th.tensor([i] * shape[0], device=device)
|
676 |
+
with th.no_grad():
|
677 |
+
out = self.ddim_sample(
|
678 |
+
model,
|
679 |
+
img,
|
680 |
+
t,
|
681 |
+
clip_denoised=clip_denoised,
|
682 |
+
denoised_fn=denoised_fn,
|
683 |
+
cond_fn=cond_fn,
|
684 |
+
model_kwargs=model_kwargs,
|
685 |
+
eta=eta,
|
686 |
+
)
|
687 |
+
yield out
|
688 |
+
img = out["sample"]
|
689 |
+
|
690 |
+
def _vb_terms_bpd(
|
691 |
+
self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
|
692 |
+
):
|
693 |
+
"""
|
694 |
+
Get a term for the variational lower-bound.
|
695 |
+
The resulting units are bits (rather than nats, as one might expect).
|
696 |
+
This allows for comparison to other papers.
|
697 |
+
:return: a dict with the following keys:
|
698 |
+
- 'output': a shape [N] tensor of NLLs or KLs.
|
699 |
+
- 'pred_xstart': the x_0 predictions.
|
700 |
+
"""
|
701 |
+
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
|
702 |
+
x_start=x_start, x_t=x_t, t=t
|
703 |
+
)
|
704 |
+
out = self.p_mean_variance(
|
705 |
+
model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
|
706 |
+
)
|
707 |
+
kl = normal_kl(
|
708 |
+
true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
|
709 |
+
)
|
710 |
+
kl = mean_flat(kl) / np.log(2.0)
|
711 |
+
|
712 |
+
decoder_nll = -discretized_gaussian_log_likelihood(
|
713 |
+
x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
|
714 |
+
)
|
715 |
+
assert decoder_nll.shape == x_start.shape
|
716 |
+
decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
|
717 |
+
|
718 |
+
# At the first timestep return the decoder NLL,
|
719 |
+
# otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
|
720 |
+
output = th.where((t == 0), decoder_nll, kl)
|
721 |
+
return {"output": output, "pred_xstart": out["pred_xstart"]}
|
722 |
+
|
723 |
+
def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
|
724 |
+
"""
|
725 |
+
Compute training losses for a single timestep.
|
726 |
+
:param model: the model to evaluate loss on.
|
727 |
+
:param x_start: the [N x C x ...] tensor of inputs.
|
728 |
+
:param t: a batch of timestep indices.
|
729 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
730 |
+
pass to the model. This can be used for conditioning.
|
731 |
+
:param noise: if specified, the specific Gaussian noise to try to remove.
|
732 |
+
:return: a dict with the key "loss" containing a tensor of shape [N].
|
733 |
+
Some mean or variance settings may also have other keys.
|
734 |
+
"""
|
735 |
+
if model_kwargs is None:
|
736 |
+
model_kwargs = {}
|
737 |
+
if noise is None:
|
738 |
+
noise = th.randn_like(x_start)
|
739 |
+
x_t = self.q_sample(x_start, t, noise=noise)
|
740 |
+
|
741 |
+
terms = {}
|
742 |
+
|
743 |
+
if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
|
744 |
+
terms["loss"] = self._vb_terms_bpd(
|
745 |
+
model=model,
|
746 |
+
x_start=x_start,
|
747 |
+
x_t=x_t,
|
748 |
+
t=t,
|
749 |
+
clip_denoised=False,
|
750 |
+
model_kwargs=model_kwargs,
|
751 |
+
)["output"]
|
752 |
+
if self.loss_type == LossType.RESCALED_KL:
|
753 |
+
terms["loss"] *= self.num_timesteps
|
754 |
+
elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
|
755 |
+
model_output = model(x_t, t, **model_kwargs)
|
756 |
+
|
757 |
+
if self.model_var_type in [
|
758 |
+
ModelVarType.LEARNED,
|
759 |
+
ModelVarType.LEARNED_RANGE,
|
760 |
+
]:
|
761 |
+
B, C = x_t.shape[:2]
|
762 |
+
if len(model_output.shape) == len(x_t.shape) + 1:
|
763 |
+
x_t = x_t.unsqueeze(-1).expand(*([-1] * (len(x_t.shape))), model_output.shape[-1])
|
764 |
+
x_start = x_start.unsqueeze(-1).expand(*([-1] * (len(x_start.shape))), model_output.shape[-1])
|
765 |
+
assert model_output.shape == (B, C * 2, *x_t.shape[2:])
|
766 |
+
model_output, model_var_values = th.split(model_output, C, dim=1)
|
767 |
+
# Learn the variance using the variational bound, but don't let
|
768 |
+
# it affect our mean prediction.
|
769 |
+
frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
|
770 |
+
terms["vb"] = self._vb_terms_bpd(
|
771 |
+
model=lambda *args, r=frozen_out: r,
|
772 |
+
x_start=x_start,
|
773 |
+
x_t=x_t,
|
774 |
+
t=t,
|
775 |
+
clip_denoised=False,
|
776 |
+
)["output"]
|
777 |
+
if self.loss_type == LossType.RESCALED_MSE:
|
778 |
+
# Divide by 1000 for equivalence with initial implementation.
|
779 |
+
# Without a factor of 1/1000, the VB term hurts the MSE term.
|
780 |
+
terms["vb"] *= self.num_timesteps / 1000.0
|
781 |
+
|
782 |
+
target = {
|
783 |
+
ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
|
784 |
+
x_start=x_start, x_t=x_t, t=t
|
785 |
+
)[0],
|
786 |
+
ModelMeanType.START_X: x_start,
|
787 |
+
ModelMeanType.EPSILON: noise,
|
788 |
+
}[self.model_mean_type]
|
789 |
+
if len(model_output.shape) == len(target.shape) + 1:
|
790 |
+
target = target.unsqueeze(-1).expand(*([-1] * (len(target.shape))), model_output.shape[-1])
|
791 |
+
assert model_output.shape == target.shape == x_start.shape
|
792 |
+
terms["mse"] = mean_flat((target - model_output) ** 2)
|
793 |
+
if "vb" in terms:
|
794 |
+
terms["loss"] = terms["mse"] + terms["vb"]
|
795 |
+
else:
|
796 |
+
terms["loss"] = terms["mse"]
|
797 |
+
else:
|
798 |
+
raise NotImplementedError(self.loss_type)
|
799 |
+
|
800 |
+
return terms
|
801 |
+
|
802 |
+
def _prior_bpd(self, x_start):
|
803 |
+
"""
|
804 |
+
Get the prior KL term for the variational lower-bound, measured in
|
805 |
+
bits-per-dim.
|
806 |
+
This term can't be optimized, as it only depends on the encoder.
|
807 |
+
:param x_start: the [N x C x ...] tensor of inputs.
|
808 |
+
:return: a batch of [N] KL values (in bits), one per batch element.
|
809 |
+
"""
|
810 |
+
batch_size = x_start.shape[0]
|
811 |
+
t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
|
812 |
+
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
|
813 |
+
kl_prior = normal_kl(
|
814 |
+
mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
|
815 |
+
)
|
816 |
+
return mean_flat(kl_prior) / np.log(2.0)
|
817 |
+
|
818 |
+
def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
|
819 |
+
"""
|
820 |
+
Compute the entire variational lower-bound, measured in bits-per-dim,
|
821 |
+
as well as other related quantities.
|
822 |
+
:param model: the model to evaluate loss on.
|
823 |
+
:param x_start: the [N x C x ...] tensor of inputs.
|
824 |
+
:param clip_denoised: if True, clip denoised samples.
|
825 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
826 |
+
pass to the model. This can be used for conditioning.
|
827 |
+
:return: a dict containing the following keys:
|
828 |
+
- total_bpd: the total variational lower-bound, per batch element.
|
829 |
+
- prior_bpd: the prior term in the lower-bound.
|
830 |
+
- vb: an [N x T] tensor of terms in the lower-bound.
|
831 |
+
- xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
|
832 |
+
- mse: an [N x T] tensor of epsilon MSEs for each timestep.
|
833 |
+
"""
|
834 |
+
device = x_start.device
|
835 |
+
batch_size = x_start.shape[0]
|
836 |
+
|
837 |
+
vb = []
|
838 |
+
xstart_mse = []
|
839 |
+
mse = []
|
840 |
+
for t in list(range(self.num_timesteps))[::-1]:
|
841 |
+
t_batch = th.tensor([t] * batch_size, device=device)
|
842 |
+
noise = th.randn_like(x_start)
|
843 |
+
x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
|
844 |
+
# Calculate VLB term at the current timestep
|
845 |
+
with th.no_grad():
|
846 |
+
out = self._vb_terms_bpd(
|
847 |
+
model,
|
848 |
+
x_start=x_start,
|
849 |
+
x_t=x_t,
|
850 |
+
t=t_batch,
|
851 |
+
clip_denoised=clip_denoised,
|
852 |
+
model_kwargs=model_kwargs,
|
853 |
+
)
|
854 |
+
vb.append(out["output"])
|
855 |
+
xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
|
856 |
+
eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
|
857 |
+
mse.append(mean_flat((eps - noise) ** 2))
|
858 |
+
|
859 |
+
vb = th.stack(vb, dim=1)
|
860 |
+
xstart_mse = th.stack(xstart_mse, dim=1)
|
861 |
+
mse = th.stack(mse, dim=1)
|
862 |
+
|
863 |
+
prior_bpd = self._prior_bpd(x_start)
|
864 |
+
total_bpd = vb.sum(dim=1) + prior_bpd
|
865 |
+
return {
|
866 |
+
"total_bpd": total_bpd,
|
867 |
+
"prior_bpd": prior_bpd,
|
868 |
+
"vb": vb,
|
869 |
+
"xstart_mse": xstart_mse,
|
870 |
+
"mse": mse,
|
871 |
+
}
|
872 |
+
|
873 |
+
|
874 |
+
def _extract_into_tensor(arr, timesteps, broadcast_shape):
|
875 |
+
"""
|
876 |
+
Extract values from a 1-D numpy array for a batch of indices.
|
877 |
+
:param arr: the 1-D numpy array.
|
878 |
+
:param timesteps: a tensor of indices into the array to extract.
|
879 |
+
:param broadcast_shape: a larger shape of K dimensions with the batch
|
880 |
+
dimension equal to the length of timesteps.
|
881 |
+
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
|
882 |
+
"""
|
883 |
+
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
|
884 |
+
while len(res.shape) < len(broadcast_shape):
|
885 |
+
res = res[..., None]
|
886 |
+
return res + th.zeros(broadcast_shape, device=timesteps.device)
|
semanticist/stage1/diffusion/respace.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from OpenAI's diffusion repos
|
2 |
+
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
|
3 |
+
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
|
4 |
+
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch as th
|
8 |
+
|
9 |
+
from .gaussian_diffusion import GaussianDiffusion
|
10 |
+
|
11 |
+
|
12 |
+
def space_timesteps(num_timesteps, section_counts):
|
13 |
+
"""
|
14 |
+
Create a list of timesteps to use from an original diffusion process,
|
15 |
+
given the number of timesteps we want to take from equally-sized portions
|
16 |
+
of the original process.
|
17 |
+
For example, if there's 300 timesteps and the section counts are [10,15,20]
|
18 |
+
then the first 100 timesteps are strided to be 10 timesteps, the second 100
|
19 |
+
are strided to be 15 timesteps, and the final 100 are strided to be 20.
|
20 |
+
If the stride is a string starting with "ddim", then the fixed striding
|
21 |
+
from the DDIM paper is used, and only one section is allowed.
|
22 |
+
:param num_timesteps: the number of diffusion steps in the original
|
23 |
+
process to divide up.
|
24 |
+
:param section_counts: either a list of numbers, or a string containing
|
25 |
+
comma-separated numbers, indicating the step count
|
26 |
+
per section. As a special case, use "ddimN" where N
|
27 |
+
is a number of steps to use the striding from the
|
28 |
+
DDIM paper.
|
29 |
+
:return: a set of diffusion steps from the original process to use.
|
30 |
+
"""
|
31 |
+
if isinstance(section_counts, str):
|
32 |
+
if section_counts.startswith("ddim"):
|
33 |
+
desired_count = int(section_counts[len("ddim") :])
|
34 |
+
for i in range(1, num_timesteps):
|
35 |
+
if len(range(0, num_timesteps, i)) == desired_count:
|
36 |
+
return set(range(0, num_timesteps, i))
|
37 |
+
raise ValueError(
|
38 |
+
f"cannot create exactly {num_timesteps} steps with an integer stride"
|
39 |
+
)
|
40 |
+
section_counts = [int(x) for x in section_counts.split(",")]
|
41 |
+
size_per = num_timesteps // len(section_counts)
|
42 |
+
extra = num_timesteps % len(section_counts)
|
43 |
+
start_idx = 0
|
44 |
+
all_steps = []
|
45 |
+
for i, section_count in enumerate(section_counts):
|
46 |
+
size = size_per + (1 if i < extra else 0)
|
47 |
+
if size < section_count:
|
48 |
+
raise ValueError(
|
49 |
+
f"cannot divide section of {size} steps into {section_count}"
|
50 |
+
)
|
51 |
+
if section_count <= 1:
|
52 |
+
frac_stride = 1
|
53 |
+
else:
|
54 |
+
frac_stride = (size - 1) / (section_count - 1)
|
55 |
+
cur_idx = 0.0
|
56 |
+
taken_steps = []
|
57 |
+
for _ in range(section_count):
|
58 |
+
taken_steps.append(start_idx + round(cur_idx))
|
59 |
+
cur_idx += frac_stride
|
60 |
+
all_steps += taken_steps
|
61 |
+
start_idx += size
|
62 |
+
return set(all_steps)
|
63 |
+
|
64 |
+
|
65 |
+
class SpacedDiffusion(GaussianDiffusion):
|
66 |
+
"""
|
67 |
+
A diffusion process which can skip steps in a base diffusion process.
|
68 |
+
:param use_timesteps: a collection (sequence or set) of timesteps from the
|
69 |
+
original diffusion process to retain.
|
70 |
+
:param kwargs: the kwargs to create the base diffusion process.
|
71 |
+
"""
|
72 |
+
|
73 |
+
def __init__(self, use_timesteps, **kwargs):
|
74 |
+
self.use_timesteps = set(use_timesteps)
|
75 |
+
self.timestep_map = []
|
76 |
+
self.original_num_steps = len(kwargs["betas"])
|
77 |
+
|
78 |
+
base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
|
79 |
+
last_alpha_cumprod = 1.0
|
80 |
+
new_betas = []
|
81 |
+
for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
|
82 |
+
if i in self.use_timesteps:
|
83 |
+
new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
|
84 |
+
last_alpha_cumprod = alpha_cumprod
|
85 |
+
self.timestep_map.append(i)
|
86 |
+
kwargs["betas"] = np.array(new_betas)
|
87 |
+
super().__init__(**kwargs)
|
88 |
+
|
89 |
+
def p_mean_variance(
|
90 |
+
self, model, *args, **kwargs
|
91 |
+
): # pylint: disable=signature-differs
|
92 |
+
return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
|
93 |
+
|
94 |
+
def training_losses(
|
95 |
+
self, model, *args, **kwargs
|
96 |
+
): # pylint: disable=signature-differs
|
97 |
+
return super().training_losses(self._wrap_model(model), *args, **kwargs)
|
98 |
+
|
99 |
+
def condition_mean(self, cond_fn, *args, **kwargs):
|
100 |
+
return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
|
101 |
+
|
102 |
+
def condition_score(self, cond_fn, *args, **kwargs):
|
103 |
+
return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
|
104 |
+
|
105 |
+
def _wrap_model(self, model):
|
106 |
+
if isinstance(model, _WrappedModel):
|
107 |
+
return model
|
108 |
+
return _WrappedModel(
|
109 |
+
model, self.timestep_map, self.original_num_steps
|
110 |
+
)
|
111 |
+
|
112 |
+
def _scale_timesteps(self, t):
|
113 |
+
# Scaling is done by the wrapped model.
|
114 |
+
return t
|
115 |
+
|
116 |
+
|
117 |
+
class _WrappedModel:
|
118 |
+
def __init__(self, model, timestep_map, original_num_steps):
|
119 |
+
self.model = model
|
120 |
+
self.timestep_map = timestep_map
|
121 |
+
# self.rescale_timesteps = rescale_timesteps
|
122 |
+
self.original_num_steps = original_num_steps
|
123 |
+
|
124 |
+
def __call__(self, x, ts, **kwargs):
|
125 |
+
map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
|
126 |
+
new_ts = map_tensor[ts]
|
127 |
+
# if self.rescale_timesteps:
|
128 |
+
# new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
|
129 |
+
return self.model(x, new_ts, **kwargs)
|
130 |
+
|
semanticist/stage1/diffusion/timestep_sampler.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from OpenAI's diffusion repos
|
2 |
+
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
|
3 |
+
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
|
4 |
+
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
5 |
+
|
6 |
+
from abc import ABC, abstractmethod
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch as th
|
10 |
+
import torch.distributed as dist
|
11 |
+
|
12 |
+
|
13 |
+
def create_named_schedule_sampler(name, diffusion):
|
14 |
+
"""
|
15 |
+
Create a ScheduleSampler from a library of pre-defined samplers.
|
16 |
+
:param name: the name of the sampler.
|
17 |
+
:param diffusion: the diffusion object to sample for.
|
18 |
+
"""
|
19 |
+
if name == "uniform":
|
20 |
+
return UniformSampler(diffusion)
|
21 |
+
elif name == "loss-second-moment":
|
22 |
+
return LossSecondMomentResampler(diffusion)
|
23 |
+
else:
|
24 |
+
raise NotImplementedError(f"unknown schedule sampler: {name}")
|
25 |
+
|
26 |
+
|
27 |
+
class ScheduleSampler(ABC):
|
28 |
+
"""
|
29 |
+
A distribution over timesteps in the diffusion process, intended to reduce
|
30 |
+
variance of the objective.
|
31 |
+
By default, samplers perform unbiased importance sampling, in which the
|
32 |
+
objective's mean is unchanged.
|
33 |
+
However, subclasses may override sample() to change how the resampled
|
34 |
+
terms are reweighted, allowing for actual changes in the objective.
|
35 |
+
"""
|
36 |
+
|
37 |
+
@abstractmethod
|
38 |
+
def weights(self):
|
39 |
+
"""
|
40 |
+
Get a numpy array of weights, one per diffusion step.
|
41 |
+
The weights needn't be normalized, but must be positive.
|
42 |
+
"""
|
43 |
+
|
44 |
+
def sample(self, batch_size, device):
|
45 |
+
"""
|
46 |
+
Importance-sample timesteps for a batch.
|
47 |
+
:param batch_size: the number of timesteps.
|
48 |
+
:param device: the torch device to save to.
|
49 |
+
:return: a tuple (timesteps, weights):
|
50 |
+
- timesteps: a tensor of timestep indices.
|
51 |
+
- weights: a tensor of weights to scale the resulting losses.
|
52 |
+
"""
|
53 |
+
w = self.weights()
|
54 |
+
p = w / np.sum(w)
|
55 |
+
indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
|
56 |
+
indices = th.from_numpy(indices_np).long().to(device)
|
57 |
+
weights_np = 1 / (len(p) * p[indices_np])
|
58 |
+
weights = th.from_numpy(weights_np).float().to(device)
|
59 |
+
return indices, weights
|
60 |
+
|
61 |
+
|
62 |
+
class UniformSampler(ScheduleSampler):
|
63 |
+
def __init__(self, diffusion):
|
64 |
+
self.diffusion = diffusion
|
65 |
+
self._weights = np.ones([diffusion.num_timesteps])
|
66 |
+
|
67 |
+
def weights(self):
|
68 |
+
return self._weights
|
69 |
+
|
70 |
+
|
71 |
+
class LossAwareSampler(ScheduleSampler):
|
72 |
+
def update_with_local_losses(self, local_ts, local_losses):
|
73 |
+
"""
|
74 |
+
Update the reweighting using losses from a model.
|
75 |
+
Call this method from each rank with a batch of timesteps and the
|
76 |
+
corresponding losses for each of those timesteps.
|
77 |
+
This method will perform synchronization to make sure all of the ranks
|
78 |
+
maintain the exact same reweighting.
|
79 |
+
:param local_ts: an integer Tensor of timesteps.
|
80 |
+
:param local_losses: a 1D Tensor of losses.
|
81 |
+
"""
|
82 |
+
batch_sizes = [
|
83 |
+
th.tensor([0], dtype=th.int32, device=local_ts.device)
|
84 |
+
for _ in range(dist.get_world_size())
|
85 |
+
]
|
86 |
+
dist.all_gather(
|
87 |
+
batch_sizes,
|
88 |
+
th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
|
89 |
+
)
|
90 |
+
|
91 |
+
# Pad all_gather batches to be the maximum batch size.
|
92 |
+
batch_sizes = [x.item() for x in batch_sizes]
|
93 |
+
max_bs = max(batch_sizes)
|
94 |
+
|
95 |
+
timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
|
96 |
+
loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
|
97 |
+
dist.all_gather(timestep_batches, local_ts)
|
98 |
+
dist.all_gather(loss_batches, local_losses)
|
99 |
+
timesteps = [
|
100 |
+
x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
|
101 |
+
]
|
102 |
+
losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
|
103 |
+
self.update_with_all_losses(timesteps, losses)
|
104 |
+
|
105 |
+
@abstractmethod
|
106 |
+
def update_with_all_losses(self, ts, losses):
|
107 |
+
"""
|
108 |
+
Update the reweighting using losses from a model.
|
109 |
+
Sub-classes should override this method to update the reweighting
|
110 |
+
using losses from the model.
|
111 |
+
This method directly updates the reweighting without synchronizing
|
112 |
+
between workers. It is called by update_with_local_losses from all
|
113 |
+
ranks with identical arguments. Thus, it should have deterministic
|
114 |
+
behavior to maintain state across workers.
|
115 |
+
:param ts: a list of int timesteps.
|
116 |
+
:param losses: a list of float losses, one per timestep.
|
117 |
+
"""
|
118 |
+
|
119 |
+
|
120 |
+
class LossSecondMomentResampler(LossAwareSampler):
|
121 |
+
def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
|
122 |
+
self.diffusion = diffusion
|
123 |
+
self.history_per_term = history_per_term
|
124 |
+
self.uniform_prob = uniform_prob
|
125 |
+
self._loss_history = np.zeros(
|
126 |
+
[diffusion.num_timesteps, history_per_term], dtype=np.float64
|
127 |
+
)
|
128 |
+
self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int)
|
129 |
+
|
130 |
+
def weights(self):
|
131 |
+
if not self._warmed_up():
|
132 |
+
return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
|
133 |
+
weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
|
134 |
+
weights /= np.sum(weights)
|
135 |
+
weights *= 1 - self.uniform_prob
|
136 |
+
weights += self.uniform_prob / len(weights)
|
137 |
+
return weights
|
138 |
+
|
139 |
+
def update_with_all_losses(self, ts, losses):
|
140 |
+
for t, loss in zip(ts, losses):
|
141 |
+
if self._loss_counts[t] == self.history_per_term:
|
142 |
+
# Shift out the oldest loss term.
|
143 |
+
self._loss_history[t, :-1] = self._loss_history[t, 1:]
|
144 |
+
self._loss_history[t, -1] = loss
|
145 |
+
else:
|
146 |
+
self._loss_history[t, self._loss_counts[t]] = loss
|
147 |
+
self._loss_counts[t] += 1
|
148 |
+
|
149 |
+
def _warmed_up(self):
|
150 |
+
return (self._loss_counts == self.history_per_term).all()
|
semanticist/stage1/diffusion_transfomer.py
ADDED
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# --------------------------------------------------------
|
7 |
+
# References:
|
8 |
+
# GLIDE: https://github.com/openai/glide-text2im
|
9 |
+
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
|
10 |
+
# --------------------------------------------------------
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
import numpy as np
|
15 |
+
import math
|
16 |
+
from timm.models.vision_transformer import PatchEmbed, Mlp
|
17 |
+
from semanticist.stage1.fused_attention import Attention
|
18 |
+
|
19 |
+
|
20 |
+
|
21 |
+
def modulate(x, shift, scale):
|
22 |
+
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
23 |
+
|
24 |
+
|
25 |
+
#################################################################################
|
26 |
+
# Embedding Layers for Timesteps and Class Labels #
|
27 |
+
#################################################################################
|
28 |
+
|
29 |
+
class TimestepEmbedder(nn.Module):
|
30 |
+
"""
|
31 |
+
Embeds scalar timesteps into vector representations.
|
32 |
+
"""
|
33 |
+
def __init__(self, hidden_size, frequency_embedding_size=256):
|
34 |
+
super().__init__()
|
35 |
+
self.mlp = nn.Sequential(
|
36 |
+
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
37 |
+
nn.SiLU(),
|
38 |
+
nn.Linear(hidden_size, hidden_size, bias=True),
|
39 |
+
)
|
40 |
+
self.frequency_embedding_size = frequency_embedding_size
|
41 |
+
|
42 |
+
@staticmethod
|
43 |
+
def timestep_embedding(t, dim, max_period=10000):
|
44 |
+
"""
|
45 |
+
Create sinusoidal timestep embeddings.
|
46 |
+
:param t: a 1-D Tensor of N indices, one per batch element.
|
47 |
+
These may be fractional.
|
48 |
+
:param dim: the dimension of the output.
|
49 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
50 |
+
:return: an (N, D) Tensor of positional embeddings.
|
51 |
+
"""
|
52 |
+
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
53 |
+
half = dim // 2
|
54 |
+
freqs = torch.exp(
|
55 |
+
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
56 |
+
).to(device=t.device)
|
57 |
+
args = t[:, None].float() * freqs[None]
|
58 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
59 |
+
if dim % 2:
|
60 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
61 |
+
return embedding
|
62 |
+
|
63 |
+
def forward(self, t):
|
64 |
+
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
65 |
+
t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
|
66 |
+
return t_emb
|
67 |
+
|
68 |
+
|
69 |
+
class LabelEmbedder(nn.Module):
|
70 |
+
"""
|
71 |
+
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
|
72 |
+
"""
|
73 |
+
def __init__(self, num_classes, hidden_size, dropout_prob):
|
74 |
+
super().__init__()
|
75 |
+
use_cfg_embedding = dropout_prob > 0
|
76 |
+
self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
|
77 |
+
self.num_classes = num_classes
|
78 |
+
self.dropout_prob = dropout_prob
|
79 |
+
|
80 |
+
def token_drop(self, labels, force_drop_ids=None):
|
81 |
+
"""
|
82 |
+
Drops labels to enable classifier-free guidance.
|
83 |
+
"""
|
84 |
+
if force_drop_ids is None:
|
85 |
+
drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
|
86 |
+
else:
|
87 |
+
drop_ids = force_drop_ids == 1
|
88 |
+
labels = torch.where(drop_ids, self.num_classes, labels)
|
89 |
+
return labels
|
90 |
+
|
91 |
+
def forward(self, labels, train, force_drop_ids=None):
|
92 |
+
use_dropout = self.dropout_prob > 0
|
93 |
+
if (train and use_dropout) or (force_drop_ids is not None):
|
94 |
+
labels = self.token_drop(labels, force_drop_ids)
|
95 |
+
embeddings = self.embedding_table(labels)
|
96 |
+
return embeddings
|
97 |
+
|
98 |
+
|
99 |
+
#################################################################################
|
100 |
+
# Core DiT Model #
|
101 |
+
#################################################################################
|
102 |
+
|
103 |
+
class DiTBlock(nn.Module):
|
104 |
+
"""
|
105 |
+
A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
|
106 |
+
"""
|
107 |
+
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
|
108 |
+
super().__init__()
|
109 |
+
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
110 |
+
self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
|
111 |
+
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
112 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
113 |
+
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
114 |
+
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
|
115 |
+
self.adaLN_modulation = nn.Sequential(
|
116 |
+
nn.SiLU(),
|
117 |
+
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
|
118 |
+
)
|
119 |
+
|
120 |
+
def forward(self, x, c, mask=None):
|
121 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
|
122 |
+
x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), mask)
|
123 |
+
x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
|
124 |
+
return x
|
125 |
+
|
126 |
+
|
127 |
+
class FinalLayer(nn.Module):
|
128 |
+
"""
|
129 |
+
The final layer of DiT.
|
130 |
+
"""
|
131 |
+
def __init__(self, hidden_size, patch_size, out_channels):
|
132 |
+
super().__init__()
|
133 |
+
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
134 |
+
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
135 |
+
self.adaLN_modulation = nn.Sequential(
|
136 |
+
nn.SiLU(),
|
137 |
+
nn.Linear(hidden_size, 2 * hidden_size, bias=True)
|
138 |
+
)
|
139 |
+
|
140 |
+
def forward(self, x, c):
|
141 |
+
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
142 |
+
x = modulate(self.norm_final(x), shift, scale)
|
143 |
+
x = self.linear(x)
|
144 |
+
return x
|
145 |
+
|
146 |
+
|
147 |
+
class DiT(nn.Module):
|
148 |
+
"""
|
149 |
+
Diffusion model with a Transformer backbone.
|
150 |
+
"""
|
151 |
+
def __init__(
|
152 |
+
self,
|
153 |
+
input_size=32,
|
154 |
+
patch_size=2,
|
155 |
+
in_channels=4,
|
156 |
+
hidden_size=1152,
|
157 |
+
depth=28,
|
158 |
+
num_heads=16,
|
159 |
+
mlp_ratio=4.0,
|
160 |
+
class_dropout_prob=0.1,
|
161 |
+
num_classes=1000,
|
162 |
+
learn_sigma=True,
|
163 |
+
):
|
164 |
+
super().__init__()
|
165 |
+
self.learn_sigma = learn_sigma
|
166 |
+
self.in_channels = in_channels
|
167 |
+
self.out_channels = in_channels * 2 if learn_sigma else in_channels
|
168 |
+
self.patch_size = patch_size
|
169 |
+
self.num_heads = num_heads
|
170 |
+
|
171 |
+
self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
|
172 |
+
self.t_embedder = TimestepEmbedder(hidden_size)
|
173 |
+
self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
|
174 |
+
num_patches = self.x_embedder.num_patches
|
175 |
+
# Will use fixed sin-cos embedding:
|
176 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
|
177 |
+
|
178 |
+
self.blocks = nn.ModuleList([
|
179 |
+
DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
|
180 |
+
])
|
181 |
+
self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
|
182 |
+
self.initialize_weights()
|
183 |
+
|
184 |
+
def initialize_weights(self):
|
185 |
+
# Initialize transformer layers:
|
186 |
+
def _basic_init(module):
|
187 |
+
if isinstance(module, nn.Linear):
|
188 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
189 |
+
if module.bias is not None:
|
190 |
+
nn.init.constant_(module.bias, 0)
|
191 |
+
self.apply(_basic_init)
|
192 |
+
|
193 |
+
# Initialize (and freeze) pos_embed by sin-cos embedding:
|
194 |
+
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
|
195 |
+
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
|
196 |
+
|
197 |
+
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
198 |
+
w = self.x_embedder.proj.weight.data
|
199 |
+
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
200 |
+
nn.init.constant_(self.x_embedder.proj.bias, 0)
|
201 |
+
|
202 |
+
# Initialize label embedding table:
|
203 |
+
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
|
204 |
+
|
205 |
+
# Initialize timestep embedding MLP:
|
206 |
+
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
207 |
+
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
208 |
+
|
209 |
+
# Zero-out adaLN modulation layers in DiT blocks:
|
210 |
+
for block in self.blocks:
|
211 |
+
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
|
212 |
+
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
|
213 |
+
|
214 |
+
# Zero-out output layers:
|
215 |
+
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
216 |
+
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
|
217 |
+
nn.init.constant_(self.final_layer.linear.weight, 0)
|
218 |
+
nn.init.constant_(self.final_layer.linear.bias, 0)
|
219 |
+
|
220 |
+
def unpatchify(self, x):
|
221 |
+
"""
|
222 |
+
x: (N, T, patch_size**2 * C)
|
223 |
+
imgs: (N, H, W, C)
|
224 |
+
"""
|
225 |
+
c = self.out_channels
|
226 |
+
p = self.x_embedder.patch_size[0]
|
227 |
+
h = w = int(x.shape[1] ** 0.5)
|
228 |
+
assert h * w == x.shape[1]
|
229 |
+
|
230 |
+
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
|
231 |
+
x = torch.einsum('nhwpqc->nchpwq', x)
|
232 |
+
imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
|
233 |
+
return imgs
|
234 |
+
|
235 |
+
def forward(self, x, t, y):
|
236 |
+
"""
|
237 |
+
Forward pass of DiT.
|
238 |
+
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
239 |
+
t: (N,) tensor of diffusion timesteps
|
240 |
+
y: (N,) tensor of class labels
|
241 |
+
"""
|
242 |
+
x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
|
243 |
+
t = self.t_embedder(t) # (N, D)
|
244 |
+
y = self.y_embedder(y, self.training) # (N, D)
|
245 |
+
c = t + y # (N, D)
|
246 |
+
for block in self.blocks:
|
247 |
+
x = block(x, c) # (N, T, D)
|
248 |
+
x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
|
249 |
+
x = self.unpatchify(x) # (N, out_channels, H, W)
|
250 |
+
return x
|
251 |
+
|
252 |
+
def forward_with_cfg(self, x, t, y, cfg_scale):
|
253 |
+
"""
|
254 |
+
Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
|
255 |
+
"""
|
256 |
+
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
|
257 |
+
half = x[: len(x) // 2]
|
258 |
+
combined = torch.cat([half, half], dim=0)
|
259 |
+
model_out = self.forward(combined, t, y)
|
260 |
+
# For exact reproducibility reasons, we apply classifier-free guidance on only
|
261 |
+
# three channels by default. The standard approach to cfg applies it to all channels.
|
262 |
+
# This can be done by uncommenting the following line and commenting-out the line following that.
|
263 |
+
eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
|
264 |
+
# eps, rest = model_out[:, :3], model_out[:, 3:]
|
265 |
+
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
|
266 |
+
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
|
267 |
+
eps = torch.cat([half_eps, half_eps], dim=0)
|
268 |
+
return torch.cat([eps, rest], dim=1)
|
269 |
+
|
270 |
+
|
271 |
+
#################################################################################
|
272 |
+
# Sine/Cosine Positional Embedding Functions #
|
273 |
+
#################################################################################
|
274 |
+
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
|
275 |
+
|
276 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
|
277 |
+
"""
|
278 |
+
grid_size: int of the grid height and width
|
279 |
+
return:
|
280 |
+
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
281 |
+
"""
|
282 |
+
grid_h = np.arange(grid_size, dtype=np.float32)
|
283 |
+
grid_w = np.arange(grid_size, dtype=np.float32)
|
284 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
285 |
+
grid = np.stack(grid, axis=0)
|
286 |
+
|
287 |
+
grid = grid.reshape([2, 1, grid_size, grid_size])
|
288 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
289 |
+
if cls_token and extra_tokens > 0:
|
290 |
+
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
291 |
+
return pos_embed
|
292 |
+
|
293 |
+
|
294 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
295 |
+
assert embed_dim % 2 == 0
|
296 |
+
|
297 |
+
# use half of dimensions to encode grid_h
|
298 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
299 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
300 |
+
|
301 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
302 |
+
return emb
|
303 |
+
|
304 |
+
|
305 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
306 |
+
"""
|
307 |
+
embed_dim: output dimension for each position
|
308 |
+
pos: a list of positions to be encoded: size (M,)
|
309 |
+
out: (M, D)
|
310 |
+
"""
|
311 |
+
assert embed_dim % 2 == 0
|
312 |
+
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
313 |
+
omega /= embed_dim / 2.
|
314 |
+
omega = 1. / 10000**omega # (D/2,)
|
315 |
+
|
316 |
+
pos = pos.reshape(-1) # (M,)
|
317 |
+
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
318 |
+
|
319 |
+
emb_sin = np.sin(out) # (M, D/2)
|
320 |
+
emb_cos = np.cos(out) # (M, D/2)
|
321 |
+
|
322 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
323 |
+
return emb
|
324 |
+
|
325 |
+
|
326 |
+
#################################################################################
|
327 |
+
# DiT Configs #
|
328 |
+
#################################################################################
|
329 |
+
|
330 |
+
def DiT_XL_2(**kwargs):
|
331 |
+
return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
|
332 |
+
|
333 |
+
def DiT_XL_4(**kwargs):
|
334 |
+
return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
|
335 |
+
|
336 |
+
def DiT_XL_8(**kwargs):
|
337 |
+
return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)
|
338 |
+
|
339 |
+
def DiT_L_2(**kwargs):
|
340 |
+
return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
|
341 |
+
|
342 |
+
def DiT_L_4(**kwargs):
|
343 |
+
return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)
|
344 |
+
|
345 |
+
def DiT_L_8(**kwargs):
|
346 |
+
return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)
|
347 |
+
|
348 |
+
def DiT_B_2(**kwargs):
|
349 |
+
return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
|
350 |
+
|
351 |
+
def DiT_B_4(**kwargs):
|
352 |
+
return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
|
353 |
+
|
354 |
+
def DiT_B_8(**kwargs):
|
355 |
+
return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
|
356 |
+
|
357 |
+
def DiT_S_2(**kwargs):
|
358 |
+
return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
|
359 |
+
|
360 |
+
def DiT_S_4(**kwargs):
|
361 |
+
return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
|
362 |
+
|
363 |
+
def DiT_S_8(**kwargs):
|
364 |
+
return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
|
365 |
+
|
366 |
+
|
367 |
+
DiT_models = {
|
368 |
+
'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8,
|
369 |
+
'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8,
|
370 |
+
'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8,
|
371 |
+
'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8,
|
372 |
+
}
|
semanticist/stage1/fused_attention.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from typing import Type
|
4 |
+
|
5 |
+
class Attention(nn.Module):
|
6 |
+
def __init__(
|
7 |
+
self,
|
8 |
+
dim: int,
|
9 |
+
num_heads: int = 8,
|
10 |
+
qkv_bias: bool = False,
|
11 |
+
qk_norm: bool = False,
|
12 |
+
proj_bias: bool = True,
|
13 |
+
attn_drop: float = 0.,
|
14 |
+
proj_drop: float = 0.,
|
15 |
+
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
16 |
+
) -> None:
|
17 |
+
super().__init__()
|
18 |
+
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
19 |
+
self.num_heads = num_heads
|
20 |
+
self.head_dim = dim // num_heads
|
21 |
+
self.scale = self.head_dim ** -0.5
|
22 |
+
|
23 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
24 |
+
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
25 |
+
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
26 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
27 |
+
self.proj = nn.Linear(dim, dim, bias=proj_bias)
|
28 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
29 |
+
|
30 |
+
def forward(self, x, attn_mask=None):
|
31 |
+
B, N, C = x.shape
|
32 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
33 |
+
q, k, v = qkv.unbind(0)
|
34 |
+
q, k = self.q_norm(q), self.k_norm(k)
|
35 |
+
|
36 |
+
x = F.scaled_dot_product_attention(
|
37 |
+
q, k, v,
|
38 |
+
attn_mask=attn_mask,
|
39 |
+
dropout_p=self.attn_drop.p if self.training else 0.,
|
40 |
+
)
|
41 |
+
|
42 |
+
x = x.transpose(1, 2).reshape(B, N, C)
|
43 |
+
x = self.proj(x)
|
44 |
+
x = self.proj_drop(x)
|
45 |
+
return x
|
semanticist/stage1/pos_embed.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# --------------------------------------------------------
|
7 |
+
# Position embedding utils
|
8 |
+
# --------------------------------------------------------
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
import torch
|
13 |
+
|
14 |
+
# --------------------------------------------------------
|
15 |
+
# 2D sine-cosine position embedding
|
16 |
+
# References:
|
17 |
+
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
|
18 |
+
# MoCo v3: https://github.com/facebookresearch/moco-v3
|
19 |
+
# --------------------------------------------------------
|
20 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
|
21 |
+
"""
|
22 |
+
grid_size: int of the grid height and width
|
23 |
+
return:
|
24 |
+
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
25 |
+
"""
|
26 |
+
grid_h = np.arange(grid_size, dtype=np.float32)
|
27 |
+
grid_w = np.arange(grid_size, dtype=np.float32)
|
28 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
29 |
+
grid = np.stack(grid, axis=0)
|
30 |
+
|
31 |
+
grid = grid.reshape([2, 1, grid_size, grid_size])
|
32 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
33 |
+
if cls_token:
|
34 |
+
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
|
35 |
+
return pos_embed
|
36 |
+
|
37 |
+
|
38 |
+
def get_1d_sincos_pos_embed(embed_dim, grid_size):
|
39 |
+
grid = np.arange(grid_size, dtype=np.float32)
|
40 |
+
pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid)
|
41 |
+
return pos_embed
|
42 |
+
|
43 |
+
|
44 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
45 |
+
assert embed_dim % 2 == 0
|
46 |
+
|
47 |
+
# use half of dimensions to encode grid_h
|
48 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
49 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
50 |
+
|
51 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
52 |
+
return emb
|
53 |
+
|
54 |
+
|
55 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
56 |
+
"""
|
57 |
+
embed_dim: output dimension for each position
|
58 |
+
pos: a list of positions to be encoded: size (M,)
|
59 |
+
out: (M, D)
|
60 |
+
"""
|
61 |
+
assert embed_dim % 2 == 0
|
62 |
+
omega = np.arange(embed_dim // 2, dtype=float)
|
63 |
+
omega /= embed_dim / 2.
|
64 |
+
omega = 1. / 10000**omega # (D/2,)
|
65 |
+
|
66 |
+
pos = pos.reshape(-1) # (M,)
|
67 |
+
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
68 |
+
|
69 |
+
emb_sin = np.sin(out) # (M, D/2)
|
70 |
+
emb_cos = np.cos(out) # (M, D/2)
|
71 |
+
|
72 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
73 |
+
return emb
|
74 |
+
|
75 |
+
|
76 |
+
# --------------------------------------------------------
|
77 |
+
# Interpolate position embeddings for high-resolution
|
78 |
+
# References:
|
79 |
+
# DeiT: https://github.com/facebookresearch/deit
|
80 |
+
# --------------------------------------------------------
|
81 |
+
def interpolate_pos_embed(model, checkpoint_model):
|
82 |
+
if 'pos_embed' in checkpoint_model:
|
83 |
+
pos_embed_checkpoint = checkpoint_model['pos_embed']
|
84 |
+
embedding_size = pos_embed_checkpoint.shape[-1]
|
85 |
+
num_patches = model.patch_embed.num_patches
|
86 |
+
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
|
87 |
+
# height (== width) for the checkpoint position embedding
|
88 |
+
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
89 |
+
# height (== width) for the new position embedding
|
90 |
+
new_size = int(num_patches ** 0.5)
|
91 |
+
# class_token and dist_token are kept unchanged
|
92 |
+
if orig_size != new_size:
|
93 |
+
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
|
94 |
+
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
95 |
+
# only the position tokens are interpolated
|
96 |
+
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
97 |
+
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
98 |
+
pos_tokens = torch.nn.functional.interpolate(
|
99 |
+
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
|
100 |
+
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
101 |
+
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
102 |
+
checkpoint_model['pos_embed'] = new_pos_embed
|
semanticist/stage1/transport/__init__.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .transport import Transport, ModelType, WeightType, PathType, Sampler
|
2 |
+
|
3 |
+
def create_transport(
|
4 |
+
path_type='Linear',
|
5 |
+
prediction="velocity",
|
6 |
+
loss_weight=None,
|
7 |
+
train_eps=None,
|
8 |
+
sample_eps=None,
|
9 |
+
):
|
10 |
+
"""function for creating Transport object
|
11 |
+
**Note**: model prediction defaults to velocity
|
12 |
+
Args:
|
13 |
+
- path_type: type of path to use; default to linear
|
14 |
+
- learn_score: set model prediction to score
|
15 |
+
- learn_noise: set model prediction to noise
|
16 |
+
- velocity_weighted: weight loss by velocity weight
|
17 |
+
- likelihood_weighted: weight loss by likelihood weight
|
18 |
+
- train_eps: small epsilon for avoiding instability during training
|
19 |
+
- sample_eps: small epsilon for avoiding instability during sampling
|
20 |
+
"""
|
21 |
+
|
22 |
+
if prediction == "noise":
|
23 |
+
model_type = ModelType.NOISE
|
24 |
+
elif prediction == "score":
|
25 |
+
model_type = ModelType.SCORE
|
26 |
+
else:
|
27 |
+
model_type = ModelType.VELOCITY
|
28 |
+
|
29 |
+
if loss_weight == "velocity":
|
30 |
+
loss_type = WeightType.VELOCITY
|
31 |
+
elif loss_weight == "likelihood":
|
32 |
+
loss_type = WeightType.LIKELIHOOD
|
33 |
+
else:
|
34 |
+
loss_type = WeightType.NONE
|
35 |
+
|
36 |
+
path_choice = {
|
37 |
+
"Linear": PathType.LINEAR,
|
38 |
+
"GVP": PathType.GVP,
|
39 |
+
"VP": PathType.VP,
|
40 |
+
}
|
41 |
+
|
42 |
+
path_type = path_choice[path_type]
|
43 |
+
|
44 |
+
if (path_type in [PathType.VP]):
|
45 |
+
train_eps = 1e-5 if train_eps is None else train_eps
|
46 |
+
sample_eps = 1e-3 if train_eps is None else sample_eps
|
47 |
+
elif (path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY):
|
48 |
+
train_eps = 1e-3 if train_eps is None else train_eps
|
49 |
+
sample_eps = 1e-3 if train_eps is None else sample_eps
|
50 |
+
else: # velocity & [GVP, LINEAR] is stable everywhere
|
51 |
+
train_eps = 0
|
52 |
+
sample_eps = 0
|
53 |
+
|
54 |
+
# create flow state
|
55 |
+
state = Transport(
|
56 |
+
model_type=model_type,
|
57 |
+
path_type=path_type,
|
58 |
+
loss_type=loss_type,
|
59 |
+
train_eps=train_eps,
|
60 |
+
sample_eps=sample_eps,
|
61 |
+
)
|
62 |
+
|
63 |
+
return state
|
semanticist/stage1/transport/integrators.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch as th
|
3 |
+
import torch.nn as nn
|
4 |
+
from torchdiffeq import odeint
|
5 |
+
from functools import partial
|
6 |
+
from tqdm import tqdm
|
7 |
+
|
8 |
+
class sde:
|
9 |
+
"""SDE solver class"""
|
10 |
+
def __init__(
|
11 |
+
self,
|
12 |
+
drift,
|
13 |
+
diffusion,
|
14 |
+
*,
|
15 |
+
t0,
|
16 |
+
t1,
|
17 |
+
num_steps,
|
18 |
+
sampler_type,
|
19 |
+
temperature=1.0,
|
20 |
+
):
|
21 |
+
assert t0 < t1, "SDE sampler has to be in forward time"
|
22 |
+
|
23 |
+
self.num_timesteps = num_steps
|
24 |
+
self.t = th.linspace(t0, t1, num_steps)
|
25 |
+
self.dt = self.t[1] - self.t[0]
|
26 |
+
self.drift = drift
|
27 |
+
self.diffusion = diffusion
|
28 |
+
self.sampler_type = sampler_type
|
29 |
+
self.temperature = temperature
|
30 |
+
|
31 |
+
def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs):
|
32 |
+
w_cur = th.randn(x.size()).to(x)
|
33 |
+
t = th.ones(x.size(0)).to(x) * t
|
34 |
+
dw = w_cur * th.sqrt(self.dt)
|
35 |
+
drift = self.drift(x, t, model, **model_kwargs)
|
36 |
+
diffusion = self.diffusion(x, t)
|
37 |
+
mean_x = x + drift * self.dt
|
38 |
+
x = mean_x + th.sqrt(2 * diffusion) * dw * self.temperature
|
39 |
+
return x, mean_x
|
40 |
+
|
41 |
+
def __Heun_step(self, x, _, t, model, **model_kwargs):
|
42 |
+
w_cur = th.randn(x.size()).to(x)
|
43 |
+
dw = w_cur * th.sqrt(self.dt) * self.temperature
|
44 |
+
t_cur = th.ones(x.size(0)).to(x) * t
|
45 |
+
diffusion = self.diffusion(x, t_cur)
|
46 |
+
xhat = x + th.sqrt(2 * diffusion) * dw
|
47 |
+
K1 = self.drift(xhat, t_cur, model, **model_kwargs)
|
48 |
+
xp = xhat + self.dt * K1
|
49 |
+
K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs)
|
50 |
+
return xhat + 0.5 * self.dt * (K1 + K2), xhat # at last time point we do not perform the heun step
|
51 |
+
|
52 |
+
def __forward_fn(self):
|
53 |
+
"""TODO: generalize here by adding all private functions ending with steps to it"""
|
54 |
+
sampler_dict = {
|
55 |
+
"Euler": self.__Euler_Maruyama_step,
|
56 |
+
"Heun": self.__Heun_step,
|
57 |
+
}
|
58 |
+
|
59 |
+
try:
|
60 |
+
sampler = sampler_dict[self.sampler_type]
|
61 |
+
except:
|
62 |
+
raise NotImplementedError("Smapler type not implemented.")
|
63 |
+
|
64 |
+
return sampler
|
65 |
+
|
66 |
+
def sample(self, init, model, **model_kwargs):
|
67 |
+
"""forward loop of sde"""
|
68 |
+
x = init
|
69 |
+
mean_x = init
|
70 |
+
samples = []
|
71 |
+
sampler = self.__forward_fn()
|
72 |
+
for ti in self.t[:-1]:
|
73 |
+
with th.no_grad():
|
74 |
+
x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs)
|
75 |
+
samples.append(x)
|
76 |
+
|
77 |
+
return samples
|
78 |
+
|
79 |
+
class ode:
|
80 |
+
"""ODE solver class"""
|
81 |
+
def __init__(
|
82 |
+
self,
|
83 |
+
drift,
|
84 |
+
*,
|
85 |
+
t0,
|
86 |
+
t1,
|
87 |
+
sampler_type,
|
88 |
+
num_steps,
|
89 |
+
atol,
|
90 |
+
rtol,
|
91 |
+
temperature=1.0,
|
92 |
+
):
|
93 |
+
assert t0 < t1, "ODE sampler has to be in forward time"
|
94 |
+
|
95 |
+
self.drift = drift
|
96 |
+
self.t = th.linspace(t0, t1, num_steps)
|
97 |
+
self.atol = atol
|
98 |
+
self.rtol = rtol
|
99 |
+
self.sampler_type = sampler_type
|
100 |
+
self.temperature = temperature
|
101 |
+
|
102 |
+
def sample(self, x, model, **model_kwargs):
|
103 |
+
|
104 |
+
device = x[0].device if isinstance(x, tuple) else x.device
|
105 |
+
def _fn(t, x):
|
106 |
+
t = th.ones(x[0].size(0)).to(device) * t if isinstance(x, tuple) else th.ones(x.size(0)).to(device) * t
|
107 |
+
# For ODE, we scale the drift by the temperature
|
108 |
+
# This is equivalent to scaling time by 1/temperature
|
109 |
+
model_output = self.drift(x, t, model, **model_kwargs)
|
110 |
+
if self.temperature != 1.0:
|
111 |
+
# If it's a tuple (for likelihood calculation), only scale the first element
|
112 |
+
if isinstance(model_output, tuple):
|
113 |
+
scaled_output = (model_output[0] / self.temperature, model_output[1])
|
114 |
+
return scaled_output
|
115 |
+
else:
|
116 |
+
return model_output / self.temperature
|
117 |
+
return model_output
|
118 |
+
|
119 |
+
t = self.t.to(device)
|
120 |
+
atol = [self.atol] * len(x) if isinstance(x, tuple) else [self.atol]
|
121 |
+
rtol = [self.rtol] * len(x) if isinstance(x, tuple) else [self.rtol]
|
122 |
+
samples = odeint(
|
123 |
+
_fn,
|
124 |
+
x,
|
125 |
+
t,
|
126 |
+
method=self.sampler_type,
|
127 |
+
atol=atol,
|
128 |
+
rtol=rtol
|
129 |
+
)
|
130 |
+
return samples
|
semanticist/stage1/transport/path.py
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch as th
|
2 |
+
import numpy as np
|
3 |
+
from functools import partial
|
4 |
+
|
5 |
+
def expand_t_like_x(t, x):
|
6 |
+
"""Function to reshape time t to broadcastable dimension of x
|
7 |
+
Args:
|
8 |
+
t: [batch_dim,], time vector
|
9 |
+
x: [batch_dim,...], data point
|
10 |
+
"""
|
11 |
+
dims = [1] * (len(x.size()) - 1)
|
12 |
+
t = t.view(t.size(0), *dims)
|
13 |
+
return t
|
14 |
+
|
15 |
+
|
16 |
+
#################### Coupling Plans ####################
|
17 |
+
|
18 |
+
class ICPlan:
|
19 |
+
"""Linear Coupling Plan"""
|
20 |
+
def __init__(self, sigma=0.0):
|
21 |
+
self.sigma = sigma
|
22 |
+
|
23 |
+
def compute_alpha_t(self, t):
|
24 |
+
"""Compute the data coefficient along the path"""
|
25 |
+
return t, 1
|
26 |
+
|
27 |
+
def compute_sigma_t(self, t):
|
28 |
+
"""Compute the noise coefficient along the path"""
|
29 |
+
return 1 - t, -1
|
30 |
+
|
31 |
+
def compute_d_alpha_alpha_ratio_t(self, t):
|
32 |
+
"""Compute the ratio between d_alpha and alpha"""
|
33 |
+
return 1 / t
|
34 |
+
|
35 |
+
def compute_drift(self, x, t):
|
36 |
+
"""We always output sde according to score parametrization; """
|
37 |
+
t = expand_t_like_x(t, x)
|
38 |
+
alpha_ratio = self.compute_d_alpha_alpha_ratio_t(t)
|
39 |
+
sigma_t, d_sigma_t = self.compute_sigma_t(t)
|
40 |
+
drift = alpha_ratio * x
|
41 |
+
diffusion = alpha_ratio * (sigma_t ** 2) - sigma_t * d_sigma_t
|
42 |
+
|
43 |
+
return -drift, diffusion
|
44 |
+
|
45 |
+
def compute_diffusion(self, x, t, form="constant", norm=1.0):
|
46 |
+
"""Compute the diffusion term of the SDE
|
47 |
+
Args:
|
48 |
+
x: [batch_dim, ...], data point
|
49 |
+
t: [batch_dim,], time vector
|
50 |
+
form: str, form of the diffusion term
|
51 |
+
norm: float, norm of the diffusion term
|
52 |
+
"""
|
53 |
+
t = expand_t_like_x(t, x)
|
54 |
+
choices = {
|
55 |
+
"constant": norm,
|
56 |
+
"SBDM": norm * self.compute_drift(x, t)[1],
|
57 |
+
"sigma": norm * self.compute_sigma_t(t)[0],
|
58 |
+
"linear": norm * (1 - t),
|
59 |
+
"decreasing": 0.25 * (norm * th.cos(np.pi * t) + 1) ** 2,
|
60 |
+
"inccreasing-decreasing": norm * th.sin(np.pi * t) ** 2,
|
61 |
+
}
|
62 |
+
|
63 |
+
try:
|
64 |
+
diffusion = choices[form]
|
65 |
+
except KeyError:
|
66 |
+
raise NotImplementedError(f"Diffusion form {form} not implemented")
|
67 |
+
|
68 |
+
return diffusion
|
69 |
+
|
70 |
+
def get_score_from_velocity(self, velocity, x, t):
|
71 |
+
"""Wrapper function: transfrom velocity prediction model to score
|
72 |
+
Args:
|
73 |
+
velocity: [batch_dim, ...] shaped tensor; velocity model output
|
74 |
+
x: [batch_dim, ...] shaped tensor; x_t data point
|
75 |
+
t: [batch_dim,] time tensor
|
76 |
+
"""
|
77 |
+
t = expand_t_like_x(t, x)
|
78 |
+
alpha_t, d_alpha_t = self.compute_alpha_t(t)
|
79 |
+
sigma_t, d_sigma_t = self.compute_sigma_t(t)
|
80 |
+
mean = x
|
81 |
+
reverse_alpha_ratio = alpha_t / d_alpha_t
|
82 |
+
var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
|
83 |
+
score = (reverse_alpha_ratio * velocity - mean) / var
|
84 |
+
return score
|
85 |
+
|
86 |
+
def get_noise_from_velocity(self, velocity, x, t):
|
87 |
+
"""Wrapper function: transfrom velocity prediction model to denoiser
|
88 |
+
Args:
|
89 |
+
velocity: [batch_dim, ...] shaped tensor; velocity model output
|
90 |
+
x: [batch_dim, ...] shaped tensor; x_t data point
|
91 |
+
t: [batch_dim,] time tensor
|
92 |
+
"""
|
93 |
+
t = expand_t_like_x(t, x)
|
94 |
+
alpha_t, d_alpha_t = self.compute_alpha_t(t)
|
95 |
+
sigma_t, d_sigma_t = self.compute_sigma_t(t)
|
96 |
+
mean = x
|
97 |
+
reverse_alpha_ratio = alpha_t / d_alpha_t
|
98 |
+
var = reverse_alpha_ratio * d_sigma_t - sigma_t
|
99 |
+
noise = (reverse_alpha_ratio * velocity - mean) / var
|
100 |
+
return noise
|
101 |
+
|
102 |
+
def get_velocity_from_score(self, score, x, t):
|
103 |
+
"""Wrapper function: transfrom score prediction model to velocity
|
104 |
+
Args:
|
105 |
+
score: [batch_dim, ...] shaped tensor; score model output
|
106 |
+
x: [batch_dim, ...] shaped tensor; x_t data point
|
107 |
+
t: [batch_dim,] time tensor
|
108 |
+
"""
|
109 |
+
t = expand_t_like_x(t, x)
|
110 |
+
drift, var = self.compute_drift(x, t)
|
111 |
+
velocity = var * score - drift
|
112 |
+
return velocity
|
113 |
+
|
114 |
+
def compute_mu_t(self, t, x0, x1):
|
115 |
+
"""Compute the mean of time-dependent density p_t"""
|
116 |
+
t = expand_t_like_x(t, x1)
|
117 |
+
alpha_t, _ = self.compute_alpha_t(t)
|
118 |
+
sigma_t, _ = self.compute_sigma_t(t)
|
119 |
+
return alpha_t * x1 + sigma_t * x0
|
120 |
+
|
121 |
+
def compute_xt(self, t, x0, x1):
|
122 |
+
"""Sample xt from time-dependent density p_t; rng is required"""
|
123 |
+
xt = self.compute_mu_t(t, x0, x1)
|
124 |
+
return xt
|
125 |
+
|
126 |
+
def compute_ut(self, t, x0, x1, xt):
|
127 |
+
"""Compute the vector field corresponding to p_t"""
|
128 |
+
t = expand_t_like_x(t, x1)
|
129 |
+
_, d_alpha_t = self.compute_alpha_t(t)
|
130 |
+
_, d_sigma_t = self.compute_sigma_t(t)
|
131 |
+
return d_alpha_t * x1 + d_sigma_t * x0
|
132 |
+
|
133 |
+
def plan(self, t, x0, x1):
|
134 |
+
xt = self.compute_xt(t, x0, x1)
|
135 |
+
ut = self.compute_ut(t, x0, x1, xt)
|
136 |
+
return t, xt, ut
|
137 |
+
|
138 |
+
|
139 |
+
class VPCPlan(ICPlan):
|
140 |
+
"""class for VP path flow matching"""
|
141 |
+
|
142 |
+
def __init__(self, sigma_min=0.1, sigma_max=20.0):
|
143 |
+
self.sigma_min = sigma_min
|
144 |
+
self.sigma_max = sigma_max
|
145 |
+
self.log_mean_coeff = lambda t: -0.25 * ((1 - t) ** 2) * (self.sigma_max - self.sigma_min) - 0.5 * (1 - t) * self.sigma_min
|
146 |
+
self.d_log_mean_coeff = lambda t: 0.5 * (1 - t) * (self.sigma_max - self.sigma_min) + 0.5 * self.sigma_min
|
147 |
+
|
148 |
+
|
149 |
+
def compute_alpha_t(self, t):
|
150 |
+
"""Compute coefficient of x1"""
|
151 |
+
alpha_t = self.log_mean_coeff(t)
|
152 |
+
alpha_t = th.exp(alpha_t)
|
153 |
+
d_alpha_t = alpha_t * self.d_log_mean_coeff(t)
|
154 |
+
return alpha_t, d_alpha_t
|
155 |
+
|
156 |
+
def compute_sigma_t(self, t):
|
157 |
+
"""Compute coefficient of x0"""
|
158 |
+
p_sigma_t = 2 * self.log_mean_coeff(t)
|
159 |
+
sigma_t = th.sqrt(1 - th.exp(p_sigma_t))
|
160 |
+
d_sigma_t = th.exp(p_sigma_t) * (2 * self.d_log_mean_coeff(t)) / (-2 * sigma_t)
|
161 |
+
return sigma_t, d_sigma_t
|
162 |
+
|
163 |
+
def compute_d_alpha_alpha_ratio_t(self, t):
|
164 |
+
"""Special purposed function for computing numerical stabled d_alpha_t / alpha_t"""
|
165 |
+
return self.d_log_mean_coeff(t)
|
166 |
+
|
167 |
+
def compute_drift(self, x, t):
|
168 |
+
"""Compute the drift term of the SDE"""
|
169 |
+
t = expand_t_like_x(t, x)
|
170 |
+
beta_t = self.sigma_min + (1 - t) * (self.sigma_max - self.sigma_min)
|
171 |
+
return -0.5 * beta_t * x, beta_t / 2
|
172 |
+
|
173 |
+
|
174 |
+
class GVPCPlan(ICPlan):
|
175 |
+
def __init__(self, sigma=0.0):
|
176 |
+
super().__init__(sigma)
|
177 |
+
|
178 |
+
def compute_alpha_t(self, t):
|
179 |
+
"""Compute coefficient of x1"""
|
180 |
+
alpha_t = th.sin(t * np.pi / 2)
|
181 |
+
d_alpha_t = np.pi / 2 * th.cos(t * np.pi / 2)
|
182 |
+
return alpha_t, d_alpha_t
|
183 |
+
|
184 |
+
def compute_sigma_t(self, t):
|
185 |
+
"""Compute coefficient of x0"""
|
186 |
+
sigma_t = th.cos(t * np.pi / 2)
|
187 |
+
d_sigma_t = -np.pi / 2 * th.sin(t * np.pi / 2)
|
188 |
+
return sigma_t, d_sigma_t
|
189 |
+
|
190 |
+
def compute_d_alpha_alpha_ratio_t(self, t):
|
191 |
+
"""Special purposed function for computing numerical stabled d_alpha_t / alpha_t"""
|
192 |
+
return np.pi / (2 * th.tan(t * np.pi / 2))
|
semanticist/stage1/transport/transport.py
ADDED
@@ -0,0 +1,456 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch as th
|
2 |
+
import numpy as np
|
3 |
+
import logging
|
4 |
+
|
5 |
+
import enum
|
6 |
+
|
7 |
+
from . import path
|
8 |
+
from .utils import EasyDict, log_state, mean_flat
|
9 |
+
from .integrators import ode, sde
|
10 |
+
|
11 |
+
class ModelType(enum.Enum):
|
12 |
+
"""
|
13 |
+
Which type of output the model predicts.
|
14 |
+
"""
|
15 |
+
|
16 |
+
NOISE = enum.auto() # the model predicts epsilon
|
17 |
+
SCORE = enum.auto() # the model predicts \nabla \log p(x)
|
18 |
+
VELOCITY = enum.auto() # the model predicts v(x)
|
19 |
+
|
20 |
+
class PathType(enum.Enum):
|
21 |
+
"""
|
22 |
+
Which type of path to use.
|
23 |
+
"""
|
24 |
+
|
25 |
+
LINEAR = enum.auto()
|
26 |
+
GVP = enum.auto()
|
27 |
+
VP = enum.auto()
|
28 |
+
|
29 |
+
class WeightType(enum.Enum):
|
30 |
+
"""
|
31 |
+
Which type of weighting to use.
|
32 |
+
"""
|
33 |
+
|
34 |
+
NONE = enum.auto()
|
35 |
+
VELOCITY = enum.auto()
|
36 |
+
LIKELIHOOD = enum.auto()
|
37 |
+
|
38 |
+
|
39 |
+
class Transport:
|
40 |
+
|
41 |
+
def __init__(
|
42 |
+
self,
|
43 |
+
*,
|
44 |
+
model_type,
|
45 |
+
path_type,
|
46 |
+
loss_type,
|
47 |
+
train_eps,
|
48 |
+
sample_eps,
|
49 |
+
):
|
50 |
+
path_options = {
|
51 |
+
PathType.LINEAR: path.ICPlan,
|
52 |
+
PathType.GVP: path.GVPCPlan,
|
53 |
+
PathType.VP: path.VPCPlan,
|
54 |
+
}
|
55 |
+
|
56 |
+
self.loss_type = loss_type
|
57 |
+
self.model_type = model_type
|
58 |
+
self.path_sampler = path_options[path_type]()
|
59 |
+
self.train_eps = train_eps
|
60 |
+
self.sample_eps = sample_eps
|
61 |
+
|
62 |
+
def prior_logp(self, z):
|
63 |
+
'''
|
64 |
+
Standard multivariate normal prior
|
65 |
+
Assume z is batched
|
66 |
+
'''
|
67 |
+
shape = th.tensor(z.size())
|
68 |
+
N = th.prod(shape[1:])
|
69 |
+
_fn = lambda x: -N / 2. * np.log(2 * np.pi) - th.sum(x ** 2) / 2.
|
70 |
+
return th.vmap(_fn)(z)
|
71 |
+
|
72 |
+
|
73 |
+
def check_interval(
|
74 |
+
self,
|
75 |
+
train_eps,
|
76 |
+
sample_eps,
|
77 |
+
*,
|
78 |
+
diffusion_form="SBDM",
|
79 |
+
sde=False,
|
80 |
+
reverse=False,
|
81 |
+
eval=False,
|
82 |
+
last_step_size=0.0,
|
83 |
+
):
|
84 |
+
t0 = 0
|
85 |
+
t1 = 1
|
86 |
+
eps = train_eps if not eval else sample_eps
|
87 |
+
if (type(self.path_sampler) in [path.VPCPlan]):
|
88 |
+
|
89 |
+
t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size
|
90 |
+
|
91 |
+
elif (type(self.path_sampler) in [path.ICPlan, path.GVPCPlan]) \
|
92 |
+
and (self.model_type != ModelType.VELOCITY or sde): # avoid numerical issue by taking a first semi-implicit step
|
93 |
+
|
94 |
+
t0 = eps if (diffusion_form == "SBDM" and sde) or self.model_type != ModelType.VELOCITY else 0
|
95 |
+
t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size
|
96 |
+
|
97 |
+
if reverse:
|
98 |
+
t0, t1 = 1 - t0, 1 - t1
|
99 |
+
|
100 |
+
return t0, t1
|
101 |
+
|
102 |
+
|
103 |
+
def sample(self, x1):
|
104 |
+
"""Sampling x0 & t based on shape of x1 (if needed)
|
105 |
+
Args:
|
106 |
+
x1 - data point; [batch, *dim]
|
107 |
+
"""
|
108 |
+
|
109 |
+
x0 = th.randn_like(x1)
|
110 |
+
t0, t1 = self.check_interval(self.train_eps, self.sample_eps)
|
111 |
+
t = th.rand((x1.shape[0],)) * (t1 - t0) + t0
|
112 |
+
t = t.to(x1)
|
113 |
+
return t, x0, x1
|
114 |
+
|
115 |
+
|
116 |
+
def training_losses(
|
117 |
+
self,
|
118 |
+
model,
|
119 |
+
x1,
|
120 |
+
model_kwargs=None
|
121 |
+
):
|
122 |
+
"""Loss for training the score model
|
123 |
+
Args:
|
124 |
+
- model: backbone model; could be score, noise, or velocity
|
125 |
+
- x1: datapoint
|
126 |
+
- model_kwargs: additional arguments for the model
|
127 |
+
"""
|
128 |
+
if model_kwargs == None:
|
129 |
+
model_kwargs = {}
|
130 |
+
|
131 |
+
t, x0, x1 = self.sample(x1)
|
132 |
+
t, xt, ut = self.path_sampler.plan(t, x0, x1)
|
133 |
+
model_output = model(xt, t, **model_kwargs)
|
134 |
+
if len(model_output.shape) == len(xt.shape) + 1:
|
135 |
+
x0 = x0.unsqueeze(-1).expand(*([-1] * (len(x0.shape))), model_output.shape[-1])
|
136 |
+
xt = xt.unsqueeze(-1).expand(*([-1] * (len(xt.shape))), model_output.shape[-1])
|
137 |
+
ut = ut.unsqueeze(-1).expand(*([-1] * (len(ut.shape))), model_output.shape[-1])
|
138 |
+
B, C = xt.shape[:2]
|
139 |
+
assert model_output.shape == (B, C, *xt.shape[2:])
|
140 |
+
|
141 |
+
terms = {}
|
142 |
+
terms['pred'] = model_output
|
143 |
+
if self.model_type == ModelType.VELOCITY:
|
144 |
+
terms['loss'] = mean_flat(((model_output - ut) ** 2))
|
145 |
+
else:
|
146 |
+
_, drift_var = self.path_sampler.compute_drift(xt, t)
|
147 |
+
sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, xt))
|
148 |
+
if self.loss_type in [WeightType.VELOCITY]:
|
149 |
+
weight = (drift_var / sigma_t) ** 2
|
150 |
+
elif self.loss_type in [WeightType.LIKELIHOOD]:
|
151 |
+
weight = drift_var / (sigma_t ** 2)
|
152 |
+
elif self.loss_type in [WeightType.NONE]:
|
153 |
+
weight = 1
|
154 |
+
else:
|
155 |
+
raise NotImplementedError()
|
156 |
+
|
157 |
+
if self.model_type == ModelType.NOISE:
|
158 |
+
terms['loss'] = mean_flat(weight * ((model_output - x0) ** 2))
|
159 |
+
else:
|
160 |
+
terms['loss'] = mean_flat(weight * ((model_output * sigma_t + x0) ** 2))
|
161 |
+
|
162 |
+
return terms
|
163 |
+
|
164 |
+
|
165 |
+
def get_drift(
|
166 |
+
self
|
167 |
+
):
|
168 |
+
"""member function for obtaining the drift of the probability flow ODE"""
|
169 |
+
def score_ode(x, t, model, **model_kwargs):
|
170 |
+
drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
|
171 |
+
model_output = model(x, t, **model_kwargs)
|
172 |
+
return (-drift_mean + drift_var * model_output) # by change of variable
|
173 |
+
|
174 |
+
def noise_ode(x, t, model, **model_kwargs):
|
175 |
+
drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
|
176 |
+
sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))
|
177 |
+
model_output = model(x, t, **model_kwargs)
|
178 |
+
score = model_output / -sigma_t
|
179 |
+
return (-drift_mean + drift_var * score)
|
180 |
+
|
181 |
+
def velocity_ode(x, t, model, **model_kwargs):
|
182 |
+
model_output = model(x, t, **model_kwargs)
|
183 |
+
return model_output
|
184 |
+
|
185 |
+
if self.model_type == ModelType.NOISE:
|
186 |
+
drift_fn = noise_ode
|
187 |
+
elif self.model_type == ModelType.SCORE:
|
188 |
+
drift_fn = score_ode
|
189 |
+
else:
|
190 |
+
drift_fn = velocity_ode
|
191 |
+
|
192 |
+
def body_fn(x, t, model, **model_kwargs):
|
193 |
+
model_output = drift_fn(x, t, model, **model_kwargs)
|
194 |
+
assert model_output.shape == x.shape, "Output shape from ODE solver must match input shape"
|
195 |
+
return model_output
|
196 |
+
|
197 |
+
return body_fn
|
198 |
+
|
199 |
+
|
200 |
+
def get_score(
|
201 |
+
self,
|
202 |
+
):
|
203 |
+
"""member function for obtaining score of
|
204 |
+
x_t = alpha_t * x + sigma_t * eps"""
|
205 |
+
if self.model_type == ModelType.NOISE:
|
206 |
+
score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs) / -self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0]
|
207 |
+
elif self.model_type == ModelType.SCORE:
|
208 |
+
score_fn = lambda x, t, model, **kwagrs: model(x, t, **kwagrs)
|
209 |
+
elif self.model_type == ModelType.VELOCITY:
|
210 |
+
score_fn = lambda x, t, model, **kwargs: self.path_sampler.get_score_from_velocity(model(x, t, **kwargs), x, t)
|
211 |
+
else:
|
212 |
+
raise NotImplementedError()
|
213 |
+
|
214 |
+
return score_fn
|
215 |
+
|
216 |
+
|
217 |
+
class Sampler:
|
218 |
+
"""Sampler class for the transport model"""
|
219 |
+
def __init__(
|
220 |
+
self,
|
221 |
+
transport,
|
222 |
+
):
|
223 |
+
"""Constructor for a general sampler; supporting different sampling methods
|
224 |
+
Args:
|
225 |
+
- transport: an tranport object specify model prediction & interpolant type
|
226 |
+
"""
|
227 |
+
|
228 |
+
self.transport = transport
|
229 |
+
self.drift = self.transport.get_drift()
|
230 |
+
self.score = self.transport.get_score()
|
231 |
+
|
232 |
+
def __get_sde_diffusion_and_drift(
|
233 |
+
self,
|
234 |
+
*,
|
235 |
+
diffusion_form="SBDM",
|
236 |
+
diffusion_norm=1.0,
|
237 |
+
):
|
238 |
+
|
239 |
+
def diffusion_fn(x, t):
|
240 |
+
diffusion = self.transport.path_sampler.compute_diffusion(x, t, form=diffusion_form, norm=diffusion_norm)
|
241 |
+
return diffusion
|
242 |
+
|
243 |
+
sde_drift = \
|
244 |
+
lambda x, t, model, **kwargs: \
|
245 |
+
self.drift(x, t, model, **kwargs) + diffusion_fn(x, t) * self.score(x, t, model, **kwargs)
|
246 |
+
|
247 |
+
sde_diffusion = diffusion_fn
|
248 |
+
|
249 |
+
return sde_drift, sde_diffusion
|
250 |
+
|
251 |
+
def __get_last_step(
|
252 |
+
self,
|
253 |
+
sde_drift,
|
254 |
+
*,
|
255 |
+
last_step,
|
256 |
+
last_step_size,
|
257 |
+
):
|
258 |
+
"""Get the last step function of the SDE solver"""
|
259 |
+
|
260 |
+
if last_step is None:
|
261 |
+
last_step_fn = \
|
262 |
+
lambda x, t, model, **model_kwargs: \
|
263 |
+
x
|
264 |
+
elif last_step == "Mean":
|
265 |
+
last_step_fn = \
|
266 |
+
lambda x, t, model, **model_kwargs: \
|
267 |
+
x + sde_drift(x, t, model, **model_kwargs) * last_step_size
|
268 |
+
elif last_step == "Tweedie":
|
269 |
+
alpha = self.transport.path_sampler.compute_alpha_t # simple aliasing; the original name was too long
|
270 |
+
sigma = self.transport.path_sampler.compute_sigma_t
|
271 |
+
last_step_fn = \
|
272 |
+
lambda x, t, model, **model_kwargs: \
|
273 |
+
x / alpha(t)[0][0] + (sigma(t)[0][0] ** 2) / alpha(t)[0][0] * self.score(x, t, model, **model_kwargs)
|
274 |
+
elif last_step == "Euler":
|
275 |
+
last_step_fn = \
|
276 |
+
lambda x, t, model, **model_kwargs: \
|
277 |
+
x + self.drift(x, t, model, **model_kwargs) * last_step_size
|
278 |
+
else:
|
279 |
+
raise NotImplementedError()
|
280 |
+
|
281 |
+
return last_step_fn
|
282 |
+
|
283 |
+
def sample_sde(
|
284 |
+
self,
|
285 |
+
*,
|
286 |
+
sampling_method="Euler",
|
287 |
+
diffusion_form="SBDM",
|
288 |
+
diffusion_norm=1.0,
|
289 |
+
last_step="Mean",
|
290 |
+
last_step_size=0.04,
|
291 |
+
num_steps=250,
|
292 |
+
temperature=1.0,
|
293 |
+
):
|
294 |
+
"""returns a sampling function with given SDE settings
|
295 |
+
Args:
|
296 |
+
- sampling_method: type of sampler used in solving the SDE; default to be Euler-Maruyama
|
297 |
+
- diffusion_form: function form of diffusion coefficient; default to be matching SBDM
|
298 |
+
- diffusion_norm: function magnitude of diffusion coefficient; default to 1
|
299 |
+
- last_step: type of the last step; default to identity
|
300 |
+
- last_step_size: size of the last step; default to match the stride of 250 steps over [0,1]
|
301 |
+
- num_steps: total integration step of SDE
|
302 |
+
- temperature: temperature scaling for the noise during sampling; default to 1.0
|
303 |
+
"""
|
304 |
+
|
305 |
+
if last_step is None:
|
306 |
+
last_step_size = 0.0
|
307 |
+
|
308 |
+
sde_drift, sde_diffusion = self.__get_sde_diffusion_and_drift(
|
309 |
+
diffusion_form=diffusion_form,
|
310 |
+
diffusion_norm=diffusion_norm,
|
311 |
+
)
|
312 |
+
|
313 |
+
t0, t1 = self.transport.check_interval(
|
314 |
+
self.transport.train_eps,
|
315 |
+
self.transport.sample_eps,
|
316 |
+
diffusion_form=diffusion_form,
|
317 |
+
sde=True,
|
318 |
+
eval=True,
|
319 |
+
reverse=False,
|
320 |
+
last_step_size=last_step_size,
|
321 |
+
)
|
322 |
+
|
323 |
+
_sde = sde(
|
324 |
+
sde_drift,
|
325 |
+
sde_diffusion,
|
326 |
+
t0=t0,
|
327 |
+
t1=t1,
|
328 |
+
num_steps=num_steps,
|
329 |
+
sampler_type=sampling_method,
|
330 |
+
temperature=temperature
|
331 |
+
)
|
332 |
+
|
333 |
+
last_step_fn = self.__get_last_step(sde_drift, last_step=last_step, last_step_size=last_step_size)
|
334 |
+
|
335 |
+
|
336 |
+
def _sample(init, model, **model_kwargs):
|
337 |
+
xs = _sde.sample(init, model, **model_kwargs)
|
338 |
+
ts = th.ones(init.size(0), device=init.device) * t1
|
339 |
+
x = last_step_fn(xs[-1], ts, model, **model_kwargs)
|
340 |
+
xs.append(x)
|
341 |
+
|
342 |
+
assert len(xs) == num_steps, "Samples does not match the number of steps"
|
343 |
+
|
344 |
+
return xs
|
345 |
+
|
346 |
+
return _sample
|
347 |
+
|
348 |
+
def sample_ode(
|
349 |
+
self,
|
350 |
+
*,
|
351 |
+
sampling_method="dopri5",
|
352 |
+
num_steps=50,
|
353 |
+
atol=1e-6,
|
354 |
+
rtol=1e-3,
|
355 |
+
reverse=False,
|
356 |
+
temperature=1.0,
|
357 |
+
):
|
358 |
+
"""returns a sampling function with given ODE settings
|
359 |
+
Args:
|
360 |
+
- sampling_method: type of sampler used in solving the ODE; default to be Dopri5
|
361 |
+
- num_steps:
|
362 |
+
- fixed solver (Euler, Heun): the actual number of integration steps performed
|
363 |
+
- adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
|
364 |
+
- atol: absolute error tolerance for the solver
|
365 |
+
- rtol: relative error tolerance for the solver
|
366 |
+
- reverse: whether solving the ODE in reverse (data to noise); default to False
|
367 |
+
- temperature: temperature scaling for the drift during sampling; default to 1.0
|
368 |
+
"""
|
369 |
+
if reverse:
|
370 |
+
drift = lambda x, t, model, **kwargs: self.drift(x, th.ones_like(t) * (1 - t), model, **kwargs)
|
371 |
+
else:
|
372 |
+
drift = self.drift
|
373 |
+
|
374 |
+
t0, t1 = self.transport.check_interval(
|
375 |
+
self.transport.train_eps,
|
376 |
+
self.transport.sample_eps,
|
377 |
+
sde=False,
|
378 |
+
eval=True,
|
379 |
+
reverse=reverse,
|
380 |
+
last_step_size=0.0,
|
381 |
+
)
|
382 |
+
|
383 |
+
_ode = ode(
|
384 |
+
drift=drift,
|
385 |
+
t0=t0,
|
386 |
+
t1=t1,
|
387 |
+
sampler_type=sampling_method,
|
388 |
+
num_steps=num_steps,
|
389 |
+
atol=atol,
|
390 |
+
rtol=rtol,
|
391 |
+
temperature=temperature,
|
392 |
+
)
|
393 |
+
|
394 |
+
return _ode.sample
|
395 |
+
|
396 |
+
def sample_ode_likelihood(
|
397 |
+
self,
|
398 |
+
*,
|
399 |
+
sampling_method="dopri5",
|
400 |
+
num_steps=50,
|
401 |
+
atol=1e-6,
|
402 |
+
rtol=1e-3,
|
403 |
+
temperature=1.0,
|
404 |
+
):
|
405 |
+
|
406 |
+
"""returns a sampling function for calculating likelihood with given ODE settings
|
407 |
+
Args:
|
408 |
+
- sampling_method: type of sampler used in solving the ODE; default to be Dopri5
|
409 |
+
- num_steps:
|
410 |
+
- fixed solver (Euler, Heun): the actual number of integration steps performed
|
411 |
+
- adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
|
412 |
+
- atol: absolute error tolerance for the solver
|
413 |
+
- rtol: relative error tolerance for the solver
|
414 |
+
- temperature: temperature scaling for the drift during sampling; default to 1.0
|
415 |
+
"""
|
416 |
+
def _likelihood_drift(x, t, model, **model_kwargs):
|
417 |
+
x, _ = x
|
418 |
+
eps = th.randint(2, x.size(), dtype=th.float, device=x.device) * 2 - 1
|
419 |
+
t = th.ones_like(t) * (1 - t)
|
420 |
+
with th.enable_grad():
|
421 |
+
x.requires_grad = True
|
422 |
+
grad = th.autograd.grad(th.sum(self.drift(x, t, model, **model_kwargs) * eps), x)[0]
|
423 |
+
logp_grad = th.sum(grad * eps, dim=tuple(range(1, len(x.size()))))
|
424 |
+
drift = self.drift(x, t, model, **model_kwargs)
|
425 |
+
return (-drift, logp_grad)
|
426 |
+
|
427 |
+
t0, t1 = self.transport.check_interval(
|
428 |
+
self.transport.train_eps,
|
429 |
+
self.transport.sample_eps,
|
430 |
+
sde=False,
|
431 |
+
eval=True,
|
432 |
+
reverse=False,
|
433 |
+
last_step_size=0.0,
|
434 |
+
)
|
435 |
+
|
436 |
+
_ode = ode(
|
437 |
+
drift=_likelihood_drift,
|
438 |
+
t0=t0,
|
439 |
+
t1=t1,
|
440 |
+
sampler_type=sampling_method,
|
441 |
+
num_steps=num_steps,
|
442 |
+
atol=atol,
|
443 |
+
rtol=rtol,
|
444 |
+
temperature=temperature,
|
445 |
+
)
|
446 |
+
|
447 |
+
def _sample_fn(x, model, **model_kwargs):
|
448 |
+
init_logp = th.zeros(x.size(0)).to(x)
|
449 |
+
input = (x, init_logp)
|
450 |
+
drift, delta_logp = _ode.sample(input, model, **model_kwargs)
|
451 |
+
drift, delta_logp = drift[-1], delta_logp[-1]
|
452 |
+
prior_logp = self.transport.prior_logp(drift)
|
453 |
+
logp = prior_logp - delta_logp
|
454 |
+
return logp, drift
|
455 |
+
|
456 |
+
return _sample_fn
|
semanticist/stage1/transport/utils.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch as th
|
2 |
+
|
3 |
+
class EasyDict:
|
4 |
+
|
5 |
+
def __init__(self, sub_dict):
|
6 |
+
for k, v in sub_dict.items():
|
7 |
+
setattr(self, k, v)
|
8 |
+
|
9 |
+
def __getitem__(self, key):
|
10 |
+
return getattr(self, key)
|
11 |
+
|
12 |
+
def mean_flat(x):
|
13 |
+
"""
|
14 |
+
Take the mean over all non-batch dimensions.
|
15 |
+
"""
|
16 |
+
return th.mean(x, dim=list(range(1, len(x.size()))))
|
17 |
+
|
18 |
+
def log_state(state):
|
19 |
+
result = []
|
20 |
+
|
21 |
+
sorted_state = dict(sorted(state.items()))
|
22 |
+
for key, value in sorted_state.items():
|
23 |
+
# Check if the value is an instance of a class
|
24 |
+
if "<object" in str(value) or "object at" in str(value):
|
25 |
+
result.append(f"{key}: [{value.__class__.__name__}]")
|
26 |
+
else:
|
27 |
+
result.append(f"{key}: {value}")
|
28 |
+
|
29 |
+
return '\n'.join(result)
|
semanticist/stage1/vision_transformer.py
ADDED
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""
|
15 |
+
Mostly copy-paste from timm library.
|
16 |
+
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
17 |
+
"""
|
18 |
+
import math
|
19 |
+
import torch
|
20 |
+
import torch.nn as nn
|
21 |
+
|
22 |
+
from functools import partial
|
23 |
+
from semanticist.stage1.fused_attention import Attention
|
24 |
+
|
25 |
+
__all__ = ['VisionTransformer', 'vit_tiny_patch16', 'vit_small_patch16',
|
26 |
+
'vit_base_patch16', 'vit_large_patch16', 'vit_huge_patch14']
|
27 |
+
|
28 |
+
|
29 |
+
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
30 |
+
if drop_prob == 0. or not training:
|
31 |
+
return x
|
32 |
+
keep_prob = 1 - drop_prob
|
33 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
34 |
+
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
35 |
+
if keep_prob > 0.0:
|
36 |
+
random_tensor.div_(keep_prob)
|
37 |
+
return x * random_tensor
|
38 |
+
|
39 |
+
|
40 |
+
class DropPath(nn.Module):
|
41 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
42 |
+
"""
|
43 |
+
|
44 |
+
def __init__(self, drop_prob=None):
|
45 |
+
super(DropPath, self).__init__()
|
46 |
+
self.drop_prob = drop_prob
|
47 |
+
|
48 |
+
def forward(self, x):
|
49 |
+
return drop_path(x, self.drop_prob, self.training)
|
50 |
+
|
51 |
+
|
52 |
+
class Mlp(nn.Module):
|
53 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
54 |
+
super().__init__()
|
55 |
+
out_features = out_features or in_features
|
56 |
+
hidden_features = hidden_features or in_features
|
57 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
58 |
+
self.act = act_layer()
|
59 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
60 |
+
self.drop = nn.Dropout(drop)
|
61 |
+
|
62 |
+
def forward(self, x):
|
63 |
+
x = self.fc1(x)
|
64 |
+
x = self.act(x)
|
65 |
+
x = self.drop(x)
|
66 |
+
x = self.fc2(x)
|
67 |
+
x = self.drop(x)
|
68 |
+
return x
|
69 |
+
|
70 |
+
|
71 |
+
class Block(nn.Module):
|
72 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0.,
|
73 |
+
attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, init_values=0):
|
74 |
+
super().__init__()
|
75 |
+
self.norm1 = norm_layer(dim)
|
76 |
+
self.attn = Attention(
|
77 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
78 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
79 |
+
self.norm2 = norm_layer(dim)
|
80 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
81 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
82 |
+
|
83 |
+
if init_values > 0:
|
84 |
+
self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
|
85 |
+
self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
|
86 |
+
else:
|
87 |
+
self.gamma_1, self.gamma_2 = None, None
|
88 |
+
|
89 |
+
def forward(self, x, attn_mask=None):
|
90 |
+
y = self.attn(self.norm1(x), attn_mask=attn_mask)
|
91 |
+
if self.gamma_1 is None:
|
92 |
+
x = x + self.drop_path(y)
|
93 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
94 |
+
else:
|
95 |
+
x = x + self.drop_path(self.gamma_1 * y)
|
96 |
+
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
|
97 |
+
return x
|
98 |
+
|
99 |
+
|
100 |
+
class PatchEmbed(nn.Module):
|
101 |
+
""" Image to Patch Embedding
|
102 |
+
"""
|
103 |
+
|
104 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
105 |
+
super().__init__()
|
106 |
+
num_patches = (img_size // patch_size) * (img_size // patch_size)
|
107 |
+
self.img_size = img_size
|
108 |
+
self.patch_size = patch_size
|
109 |
+
self.num_patches = num_patches
|
110 |
+
|
111 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
112 |
+
|
113 |
+
def forward(self, x):
|
114 |
+
B, C, H, W = x.shape
|
115 |
+
return self.proj(x)
|
116 |
+
|
117 |
+
|
118 |
+
class VisionTransformer(nn.Module):
|
119 |
+
""" Vision Transformer """
|
120 |
+
|
121 |
+
def __init__(self, img_size=[224], patch_size=16, in_chans=3, embed_dim=768, depth=12,
|
122 |
+
num_heads=12, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0.,
|
123 |
+
drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
124 |
+
init_values=0, num_slots=16):
|
125 |
+
super().__init__()
|
126 |
+
self.num_features = self.embed_dim = embed_dim
|
127 |
+
|
128 |
+
self.patch_embed = PatchEmbed(
|
129 |
+
img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
130 |
+
num_patches = self.patch_embed.num_patches
|
131 |
+
|
132 |
+
self.num_slots = num_slots
|
133 |
+
|
134 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
135 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1 + self.num_slots, embed_dim))
|
136 |
+
self.slot_embed = nn.Parameter(torch.zeros(1, num_slots, embed_dim))
|
137 |
+
|
138 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
139 |
+
|
140 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
141 |
+
self.blocks = nn.ModuleList([
|
142 |
+
Block(
|
143 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
|
144 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
145 |
+
init_values=init_values)
|
146 |
+
for i in range(depth)])
|
147 |
+
|
148 |
+
self.norm = norm_layer(embed_dim)
|
149 |
+
|
150 |
+
nn.init.trunc_normal_(self.pos_embed, std=.02)
|
151 |
+
nn.init.trunc_normal_(self.cls_token, std=.02)
|
152 |
+
nn.init.trunc_normal_(self.slot_embed, std=.02)
|
153 |
+
self.apply(self._init_weights)
|
154 |
+
|
155 |
+
def _init_weights(self, m):
|
156 |
+
if isinstance(m, nn.Linear):
|
157 |
+
nn.init.trunc_normal_(m.weight, std=.02)
|
158 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
159 |
+
nn.init.constant_(m.bias, 0)
|
160 |
+
elif isinstance(m, nn.LayerNorm):
|
161 |
+
nn.init.constant_(m.bias, 0)
|
162 |
+
nn.init.constant_(m.weight, 1.0)
|
163 |
+
|
164 |
+
def interpolate_pos_encoding(self, x, w, h):
|
165 |
+
npatch = x.shape[1] - 1 - self.num_slots
|
166 |
+
N = self.pos_embed.shape[1] - 1 - self.num_slots
|
167 |
+
if npatch == N and w == h:
|
168 |
+
return self.pos_embed
|
169 |
+
class_pos_embed = self.pos_embed[:, 0]
|
170 |
+
patch_pos_embed = self.pos_embed[:, 1:1+npatch]
|
171 |
+
dim = x.shape[-1]
|
172 |
+
w0 = w // self.patch_embed.patch_size[0]
|
173 |
+
h0 = h // self.patch_embed.patch_size[1]
|
174 |
+
# we add a small number to avoid floating point error in the interpolation
|
175 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
176 |
+
w0, h0 = w0 + 0.1, h0 + 0.1
|
177 |
+
patch_pos_embed = nn.functional.interpolate(
|
178 |
+
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
|
179 |
+
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
|
180 |
+
mode='bicubic',
|
181 |
+
)
|
182 |
+
assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
|
183 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
184 |
+
|
185 |
+
slots_pos_embed = self.pos_embed[:, 1+npatch:]
|
186 |
+
slots_pos_embed = slots_pos_embed.view(1, -1, dim) # (1, num_slots, dim)
|
187 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed, slots_pos_embed), dim=1)
|
188 |
+
|
189 |
+
def prepare_tokens(self, x):
|
190 |
+
B, nc, w, h = x.shape
|
191 |
+
x = self.patch_embed(x)
|
192 |
+
x = x.flatten(2).transpose(1, 2)
|
193 |
+
x = torch.cat((self.cls_token.expand(B, -1, -1), x, self.slot_embed.expand(B, -1, -1)), dim=1)
|
194 |
+
x = x + self.interpolate_pos_encoding(x, w, h)
|
195 |
+
return self.pos_drop(x)
|
196 |
+
|
197 |
+
def forward(self, x, is_causal=True):
|
198 |
+
x = self.prepare_tokens(x)
|
199 |
+
if is_causal:
|
200 |
+
attn_mask = torch.ones(x.shape[1], x.shape[1], device=x.device, dtype=torch.bool)
|
201 |
+
# slots are causal to each other
|
202 |
+
causal_mask = torch.ones(self.num_slots, self.num_slots, device=x.device, dtype=torch.bool).tril(diagonal=0)
|
203 |
+
attn_mask[-self.num_slots:, -self.num_slots:] = causal_mask
|
204 |
+
# cls token and patches should not see slots
|
205 |
+
attn_mask[:-self.num_slots, -self.num_slots:] = False
|
206 |
+
else:
|
207 |
+
attn_mask = None
|
208 |
+
|
209 |
+
for blk in self.blocks:
|
210 |
+
x = blk(x, attn_mask=attn_mask)
|
211 |
+
|
212 |
+
x = self.norm(x)
|
213 |
+
outcome = x[:, -self.num_slots:] # return the slots
|
214 |
+
return outcome
|
215 |
+
|
216 |
+
def get_intermediate_layers(self, x, n=1):
|
217 |
+
x = self.prepare_tokens(x)
|
218 |
+
# we return the output tokens from the `n` last blocks
|
219 |
+
output = []
|
220 |
+
for i, blk in enumerate(self.blocks):
|
221 |
+
x = blk(x)
|
222 |
+
if len(self.blocks) - i <= n:
|
223 |
+
output.append(self.norm(x))
|
224 |
+
return output
|
225 |
+
|
226 |
+
|
227 |
+
def vit_tiny_patch16(**kwargs):
|
228 |
+
model = VisionTransformer(
|
229 |
+
patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
|
230 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
231 |
+
return model
|
232 |
+
|
233 |
+
|
234 |
+
def vit_small_patch16(**kwargs):
|
235 |
+
model = VisionTransformer(
|
236 |
+
patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
|
237 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
238 |
+
return model
|
239 |
+
|
240 |
+
|
241 |
+
def vit_base_patch16(**kwargs):
|
242 |
+
model = VisionTransformer(
|
243 |
+
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
|
244 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
245 |
+
return model
|
246 |
+
|
247 |
+
|
248 |
+
def vit_large_patch16(**kwargs):
|
249 |
+
model = VisionTransformer(
|
250 |
+
patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
|
251 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
252 |
+
return model
|
253 |
+
|
254 |
+
|
255 |
+
def vit_huge_patch14(**kwargs):
|
256 |
+
model = VisionTransformer(
|
257 |
+
patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True,
|
258 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
259 |
+
return model
|
semanticist/stage2/diffloss.py
ADDED
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
from semanticist.stage1.diffusion import create_diffusion
|
6 |
+
from semanticist.stage1.transport import create_transport, Sampler
|
7 |
+
|
8 |
+
|
9 |
+
class DiffLoss(nn.Module):
|
10 |
+
"""Diffusion Loss"""
|
11 |
+
def __init__(self, target_channels, z_channels, depth, width, num_sampling_steps, predict_xstart=False, use_si=False, cond_method="adaln"):
|
12 |
+
super(DiffLoss, self).__init__()
|
13 |
+
self.in_channels = target_channels
|
14 |
+
self.net = SimpleMLPAdaLN(
|
15 |
+
in_channels=target_channels,
|
16 |
+
model_channels=width,
|
17 |
+
out_channels=target_channels * 2 if not use_si else target_channels, # for vlb loss
|
18 |
+
z_channels=z_channels,
|
19 |
+
num_res_blocks=depth,
|
20 |
+
cond_method=cond_method,
|
21 |
+
)
|
22 |
+
self.use_si = use_si
|
23 |
+
if not use_si:
|
24 |
+
self.train_diffusion = create_diffusion(timestep_respacing="", noise_schedule="cosine", predict_xstart=predict_xstart)
|
25 |
+
self.gen_diffusion = create_diffusion(timestep_respacing=num_sampling_steps, noise_schedule="cosine", predict_xstart=predict_xstart)
|
26 |
+
else:
|
27 |
+
self.transport = create_transport()
|
28 |
+
self.sampler = Sampler(self.transport)
|
29 |
+
|
30 |
+
def forward(self, target, z, mask=None):
|
31 |
+
model_kwargs = dict(c=z)
|
32 |
+
if not self.use_si:
|
33 |
+
t = torch.randint(0, self.train_diffusion.num_timesteps, (target.shape[0],), device=target.device)
|
34 |
+
loss_dict = self.train_diffusion.training_losses(self.net, target, t, model_kwargs)
|
35 |
+
else:
|
36 |
+
loss_dict = self.transport.training_losses(self.net, target, model_kwargs)
|
37 |
+
loss = loss_dict["loss"]
|
38 |
+
if mask is not None:
|
39 |
+
loss = (loss * mask).sum() / mask.sum()
|
40 |
+
return loss.mean()
|
41 |
+
|
42 |
+
def sample(self, z, temperature=1.0, cfg=1.0):
|
43 |
+
# diffusion loss sampling
|
44 |
+
device = z.device
|
45 |
+
if not cfg == 1.0:
|
46 |
+
noise = torch.randn(z.shape[0] // 2, self.in_channels, device=device)
|
47 |
+
noise = torch.cat([noise, noise], dim=0)
|
48 |
+
model_kwargs = dict(c=z, cfg_scale=cfg)
|
49 |
+
sample_fn = self.net.forward_with_cfg
|
50 |
+
else:
|
51 |
+
noise = torch.randn(z.shape[0], self.in_channels, device=device)
|
52 |
+
model_kwargs = dict(c=z)
|
53 |
+
sample_fn = self.net.forward
|
54 |
+
|
55 |
+
if not self.use_si:
|
56 |
+
sampled_token_latent = self.gen_diffusion.p_sample_loop(
|
57 |
+
sample_fn, noise.shape, noise, clip_denoised=False, model_kwargs=model_kwargs, progress=False,
|
58 |
+
temperature=temperature, device=device
|
59 |
+
)
|
60 |
+
else:
|
61 |
+
sde_sample_fn = self.sampler.sample_sde(diffusion_form="sigma", temperature=temperature)
|
62 |
+
sampled_token_latent = sde_sample_fn(noise, sample_fn, **model_kwargs)[-1]
|
63 |
+
if cfg != 1.0:
|
64 |
+
sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0)
|
65 |
+
return sampled_token_latent
|
66 |
+
|
67 |
+
|
68 |
+
def modulate(x, shift, scale):
|
69 |
+
return x * (1 + scale) + shift
|
70 |
+
|
71 |
+
|
72 |
+
class TimestepEmbedder(nn.Module):
|
73 |
+
"""
|
74 |
+
Embeds scalar timesteps into vector representations.
|
75 |
+
"""
|
76 |
+
def __init__(self, hidden_size, frequency_embedding_size=256):
|
77 |
+
super().__init__()
|
78 |
+
self.mlp = nn.Sequential(
|
79 |
+
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
80 |
+
nn.SiLU(),
|
81 |
+
nn.Linear(hidden_size, hidden_size, bias=True),
|
82 |
+
)
|
83 |
+
self.frequency_embedding_size = frequency_embedding_size
|
84 |
+
|
85 |
+
@staticmethod
|
86 |
+
def timestep_embedding(t, dim, max_period=10000):
|
87 |
+
"""
|
88 |
+
Create sinusoidal timestep embeddings.
|
89 |
+
:param t: a 1-D Tensor of N indices, one per batch element.
|
90 |
+
These may be fractional.
|
91 |
+
:param dim: the dimension of the output.
|
92 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
93 |
+
:return: an (N, D) Tensor of positional embeddings.
|
94 |
+
"""
|
95 |
+
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
96 |
+
half = dim // 2
|
97 |
+
freqs = torch.exp(
|
98 |
+
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
99 |
+
).to(device=t.device)
|
100 |
+
args = t[:, None].float() * freqs[None]
|
101 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
102 |
+
if dim % 2:
|
103 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
104 |
+
return embedding
|
105 |
+
|
106 |
+
def forward(self, t):
|
107 |
+
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
108 |
+
t_emb = self.mlp(t_freq)
|
109 |
+
return t_emb
|
110 |
+
|
111 |
+
|
112 |
+
class ResBlock(nn.Module):
|
113 |
+
"""
|
114 |
+
A residual block with AdaLN for timestep and optional concatenation for condition.
|
115 |
+
"""
|
116 |
+
def __init__(
|
117 |
+
self,
|
118 |
+
channels,
|
119 |
+
cond_method="adaln",
|
120 |
+
):
|
121 |
+
super().__init__()
|
122 |
+
self.channels = channels
|
123 |
+
self.cond_method = cond_method
|
124 |
+
|
125 |
+
self.in_ln = nn.LayerNorm(channels, eps=1e-6)
|
126 |
+
self.adaLN_modulation = nn.Sequential(
|
127 |
+
nn.SiLU(),
|
128 |
+
nn.Linear(channels, 3 * channels, bias=True)
|
129 |
+
)
|
130 |
+
|
131 |
+
# Input dimension depends on conditioning method
|
132 |
+
mlp_in_dim = channels * 2 if cond_method == "concat" else channels
|
133 |
+
self.mlp = nn.Sequential(
|
134 |
+
nn.Linear(mlp_in_dim, channels, bias=True),
|
135 |
+
nn.SiLU(),
|
136 |
+
nn.Linear(channels, channels, bias=True),
|
137 |
+
)
|
138 |
+
|
139 |
+
def forward(self, x, t, c=None):
|
140 |
+
# Apply timestep embedding via AdaLN
|
141 |
+
shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(t).chunk(3, dim=-1)
|
142 |
+
h = modulate(self.in_ln(x), shift_mlp, scale_mlp)
|
143 |
+
|
144 |
+
# Concatenate condition if using concat method
|
145 |
+
if self.cond_method == "concat" and c is not None:
|
146 |
+
h = torch.cat([h, c], dim=-1)
|
147 |
+
|
148 |
+
h = self.mlp(h)
|
149 |
+
x = x + gate_mlp * h
|
150 |
+
return x
|
151 |
+
|
152 |
+
|
153 |
+
class FinalLayer(nn.Module):
|
154 |
+
"""
|
155 |
+
Final layer with AdaLN for timestep and optional concatenation for condition.
|
156 |
+
"""
|
157 |
+
def __init__(self, model_channels, out_channels, cond_method="adaln"):
|
158 |
+
super().__init__()
|
159 |
+
self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
|
160 |
+
self.cond_method = cond_method
|
161 |
+
|
162 |
+
self.adaLN_modulation = nn.Sequential(
|
163 |
+
nn.SiLU(),
|
164 |
+
nn.Linear(model_channels, 2 * model_channels, bias=True)
|
165 |
+
)
|
166 |
+
|
167 |
+
# Output dimension depends on conditioning method
|
168 |
+
linear_in_dim = model_channels * 2 if cond_method == "concat" else model_channels
|
169 |
+
self.linear = nn.Linear(linear_in_dim, out_channels, bias=True)
|
170 |
+
|
171 |
+
def forward(self, x, t, c=None):
|
172 |
+
# Apply timestep embedding via AdaLN
|
173 |
+
shift, scale = self.adaLN_modulation(t).chunk(2, dim=-1)
|
174 |
+
x = modulate(self.norm_final(x), shift, scale)
|
175 |
+
|
176 |
+
# Concatenate condition if using concat method
|
177 |
+
if self.cond_method == "concat" and c is not None:
|
178 |
+
x = torch.cat([x, c], dim=-1)
|
179 |
+
|
180 |
+
return self.linear(x)
|
181 |
+
|
182 |
+
|
183 |
+
class SimpleMLPAdaLN(nn.Module):
|
184 |
+
"""
|
185 |
+
MLP for Diffusion Loss with AdaLN for timestep and optional concatenation for condition.
|
186 |
+
"""
|
187 |
+
def __init__(
|
188 |
+
self,
|
189 |
+
in_channels,
|
190 |
+
model_channels,
|
191 |
+
out_channels,
|
192 |
+
z_channels,
|
193 |
+
num_res_blocks,
|
194 |
+
cond_method="adaln"
|
195 |
+
):
|
196 |
+
super().__init__()
|
197 |
+
self.in_channels = in_channels
|
198 |
+
self.model_channels = model_channels
|
199 |
+
self.out_channels = out_channels
|
200 |
+
self.cond_method = cond_method
|
201 |
+
|
202 |
+
self.time_embed = TimestepEmbedder(model_channels)
|
203 |
+
self.cond_embed = nn.Linear(z_channels, model_channels)
|
204 |
+
self.input_proj = nn.Linear(in_channels, model_channels)
|
205 |
+
|
206 |
+
# Create residual blocks
|
207 |
+
res_blocks = [ResBlock(model_channels, cond_method) for _ in range(num_res_blocks)]
|
208 |
+
self.res_blocks = nn.ModuleList(res_blocks)
|
209 |
+
|
210 |
+
self.final_layer = FinalLayer(model_channels, out_channels, cond_method=cond_method)
|
211 |
+
self.initialize_weights()
|
212 |
+
|
213 |
+
def initialize_weights(self):
|
214 |
+
# Basic initialization for all linear layers
|
215 |
+
def _basic_init(module):
|
216 |
+
if isinstance(module, nn.Linear):
|
217 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
218 |
+
if module.bias is not None:
|
219 |
+
nn.init.constant_(module.bias, 0)
|
220 |
+
self.apply(_basic_init)
|
221 |
+
|
222 |
+
# Initialize timestep embedding MLP
|
223 |
+
nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
|
224 |
+
nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)
|
225 |
+
|
226 |
+
# Zero-out adaLN modulation layers (always used for timestep)
|
227 |
+
for i, block in enumerate(self.res_blocks):
|
228 |
+
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
|
229 |
+
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
|
230 |
+
|
231 |
+
# Zero-out output layers
|
232 |
+
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
233 |
+
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
|
234 |
+
nn.init.constant_(self.final_layer.linear.weight, 0)
|
235 |
+
nn.init.constant_(self.final_layer.linear.bias, 0)
|
236 |
+
|
237 |
+
def forward(self, x, t, c):
|
238 |
+
"""
|
239 |
+
Apply the model to an input batch.
|
240 |
+
:param x: an [N x C] Tensor of inputs.
|
241 |
+
:param t: a 1-D batch of timesteps.
|
242 |
+
:param c: conditioning from AR transformer.
|
243 |
+
:return: an [N x C] Tensor of outputs.
|
244 |
+
"""
|
245 |
+
x = self.input_proj(x)
|
246 |
+
t_emb = self.time_embed(t)
|
247 |
+
c_emb = self.cond_embed(c)
|
248 |
+
|
249 |
+
# Prepare conditioning based on method
|
250 |
+
if self.cond_method == "adaln":
|
251 |
+
t_combined, c_for_concat = t_emb + c_emb, None
|
252 |
+
else: # concat
|
253 |
+
t_combined, c_for_concat = t_emb, c_emb
|
254 |
+
|
255 |
+
for block in self.res_blocks:
|
256 |
+
x = block(x, t_combined, c_for_concat)
|
257 |
+
return self.final_layer(x, t_combined, c_for_concat)
|
258 |
+
|
259 |
+
def forward_with_cfg(self, x, t, c, cfg_scale):
|
260 |
+
half = x[: len(x) // 2]
|
261 |
+
combined = torch.cat([half, half], dim=0)
|
262 |
+
model_out = self.forward(combined, t, c)
|
263 |
+
eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
|
264 |
+
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
|
265 |
+
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
|
266 |
+
eps = torch.cat([half_eps, half_eps], dim=0)
|
267 |
+
return torch.cat([eps, rest], dim=1)
|
semanticist/stage2/generate.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from:
|
2 |
+
# gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
|
3 |
+
# DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
|
4 |
+
import torch
|
5 |
+
|
6 |
+
def prefill(model, cond_idx: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, temperature: float = 1.0):
|
7 |
+
tokens = model(None, cond_idx, input_pos, cfg=cfg_scale, temperature=temperature)
|
8 |
+
return tokens.unsqueeze(1)
|
9 |
+
|
10 |
+
|
11 |
+
def decode_one_token(model, x: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, temperature: float = 1.0):
|
12 |
+
assert input_pos.shape[-1] == 1
|
13 |
+
if cfg_scale > 1.0:
|
14 |
+
x = torch.cat([x, x])
|
15 |
+
tokens = model(x, cond_idx=None, input_pos=input_pos, cfg=cfg_scale, temperature=temperature)
|
16 |
+
return tokens
|
17 |
+
|
18 |
+
|
19 |
+
def decode_n_tokens(
|
20 |
+
model, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int,
|
21 |
+
cfg_scale: float, cfg_schedule = "constant", temperature: float = 1.0):
|
22 |
+
new_tokens = []
|
23 |
+
for i in range(num_new_tokens):
|
24 |
+
cfg_iter = get_cfg(cfg_scale, i + 1, num_new_tokens + 1, cfg_schedule)
|
25 |
+
next_token = decode_one_token(model, cur_token, input_pos, cfg_iter, temperature=temperature).unsqueeze(1)
|
26 |
+
input_pos += 1
|
27 |
+
new_tokens.append(next_token.clone())
|
28 |
+
cur_token = next_token
|
29 |
+
|
30 |
+
return new_tokens
|
31 |
+
|
32 |
+
|
33 |
+
@torch.no_grad()
|
34 |
+
def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_schedule = "constant", temperature: float = 1.0):
|
35 |
+
if cfg_scale > 1.0:
|
36 |
+
cond_null = torch.ones_like(cond) * model.num_classes
|
37 |
+
cond_combined = torch.cat([cond, cond_null])
|
38 |
+
else:
|
39 |
+
cond_combined = cond
|
40 |
+
T = model.cls_token_num
|
41 |
+
|
42 |
+
T_new = T + max_new_tokens
|
43 |
+
max_seq_length = T_new
|
44 |
+
max_batch_size = cond.shape[0]
|
45 |
+
|
46 |
+
device = cond.device
|
47 |
+
dtype = model.z_proj.weight.dtype
|
48 |
+
if torch.is_autocast_enabled():
|
49 |
+
dtype = torch.get_autocast_dtype(device_type=device.type)
|
50 |
+
with torch.device(device):
|
51 |
+
max_batch_size_cfg = max_batch_size * 2 if cfg_scale > 1.0 else max_batch_size
|
52 |
+
model.setup_caches(max_batch_size=max_batch_size_cfg, max_seq_length=max_seq_length, dtype=dtype)
|
53 |
+
|
54 |
+
if emb_masks is not None:
|
55 |
+
assert emb_masks.shape[0] == max_batch_size
|
56 |
+
assert emb_masks.shape[-1] == T
|
57 |
+
if cfg_scale > 1.0:
|
58 |
+
model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * torch.cat([emb_masks, emb_masks]).unsqueeze(1)
|
59 |
+
else:
|
60 |
+
model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * emb_masks.unsqueeze(1)
|
61 |
+
|
62 |
+
eye_matrix = torch.eye(model.causal_mask.size(1), model.causal_mask.size(2), device=device)
|
63 |
+
model.causal_mask[:] = model.causal_mask * (1 - eye_matrix) + eye_matrix
|
64 |
+
|
65 |
+
# create an empty tensor of the expected final shape and fill in the current tokens
|
66 |
+
seq = torch.empty((max_batch_size, T_new, model.slot_dim), dtype=dtype, device=device)
|
67 |
+
|
68 |
+
input_pos = torch.arange(0, T, device=device)
|
69 |
+
cfg_iter = get_cfg(cfg_scale, 0, max_new_tokens, cfg_schedule)
|
70 |
+
next_token = prefill(model, cond_combined, input_pos, cfg_iter, temperature=temperature)
|
71 |
+
seq[:, T:T+1] = next_token
|
72 |
+
|
73 |
+
if max_new_tokens > 1:
|
74 |
+
input_pos = torch.tensor([T], device=device, dtype=torch.int)
|
75 |
+
generated_tokens = decode_n_tokens(model, next_token, input_pos, max_new_tokens - 1, cfg_scale, cfg_schedule=cfg_schedule, temperature=temperature)
|
76 |
+
seq[:, T+1:] = torch.cat(generated_tokens, dim=1)
|
77 |
+
|
78 |
+
model.reset_caches()
|
79 |
+
return seq[:, T:]
|
80 |
+
|
81 |
+
|
82 |
+
def get_cfg(cfg, cur_step, total_step, cfg_schedule="constant"):
|
83 |
+
if cfg_schedule == "linear":
|
84 |
+
return 1 + (cfg - 1) * (cur_step + 1) / total_step
|
85 |
+
elif cfg_schedule == "constant":
|
86 |
+
return cfg
|
87 |
+
else:
|
88 |
+
raise NotImplementedError
|
semanticist/stage2/gpt.py
ADDED
@@ -0,0 +1,431 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from:
|
2 |
+
# VQGAN: https://github.com/CompVis/taming-transformers/blob/master/taming/modules/transformer/mingpt.py
|
3 |
+
# DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
|
4 |
+
# nanoGPT: https://github.com/karpathy/nanoGPT/blob/master/model.py
|
5 |
+
# llama: https://github.com/facebookresearch/llama/blob/main/llama/model.py
|
6 |
+
# gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/model.py
|
7 |
+
# PixArt: https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
|
8 |
+
from typing import Optional, List, Union
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
from torch.nn import functional as F
|
13 |
+
|
14 |
+
from semanticist.stage1.vision_transformer import DropPath
|
15 |
+
from semanticist.stage2.diffloss import DiffLoss
|
16 |
+
|
17 |
+
def find_multiple(n: int, k: int):
|
18 |
+
if n % k == 0:
|
19 |
+
return n
|
20 |
+
return n + k - (n % k)
|
21 |
+
|
22 |
+
|
23 |
+
|
24 |
+
#################################################################################
|
25 |
+
# Embedding Layers for Class Labels #
|
26 |
+
#################################################################################
|
27 |
+
class LabelEmbedder(nn.Module):
|
28 |
+
"""
|
29 |
+
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
|
30 |
+
"""
|
31 |
+
def __init__(self, num_classes, hidden_size, dropout_prob):
|
32 |
+
super().__init__()
|
33 |
+
use_cfg_embedding = dropout_prob > 0
|
34 |
+
self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
|
35 |
+
self.num_classes = num_classes
|
36 |
+
self.dropout_prob = dropout_prob
|
37 |
+
|
38 |
+
def token_drop(self, labels, force_drop_ids=None):
|
39 |
+
"""
|
40 |
+
Drops labels to enable classifier-free guidance.
|
41 |
+
"""
|
42 |
+
if force_drop_ids is None:
|
43 |
+
drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
|
44 |
+
else:
|
45 |
+
drop_ids = force_drop_ids == 1
|
46 |
+
labels = torch.where(drop_ids, self.num_classes, labels)
|
47 |
+
return labels
|
48 |
+
|
49 |
+
def forward(self, labels, train, force_drop_ids=None):
|
50 |
+
use_dropout = self.dropout_prob > 0
|
51 |
+
if (train and use_dropout) or (force_drop_ids is not None):
|
52 |
+
labels = self.token_drop(labels, force_drop_ids)
|
53 |
+
embeddings = self.embedding_table(labels).unsqueeze(1)
|
54 |
+
return embeddings
|
55 |
+
|
56 |
+
|
57 |
+
class MLP(nn.Module):
|
58 |
+
def __init__(self, in_features, hidden_features, out_features):
|
59 |
+
super().__init__()
|
60 |
+
out_features = out_features or in_features
|
61 |
+
hidden_features = hidden_features or in_features
|
62 |
+
self.fc1 = nn.Linear(in_features, hidden_features, bias=False)
|
63 |
+
self.act = nn.GELU(approximate='tanh')
|
64 |
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=False)
|
65 |
+
|
66 |
+
def forward(self, x):
|
67 |
+
x = self.fc1(x)
|
68 |
+
x = self.act(x)
|
69 |
+
x = self.fc2(x)
|
70 |
+
return x
|
71 |
+
|
72 |
+
|
73 |
+
#################################################################################
|
74 |
+
# GPT Model #
|
75 |
+
#################################################################################
|
76 |
+
class RMSNorm(torch.nn.Module):
|
77 |
+
def __init__(self, dim: int, eps: float = 1e-5):
|
78 |
+
super().__init__()
|
79 |
+
self.eps = eps
|
80 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
81 |
+
|
82 |
+
def _norm(self, x):
|
83 |
+
return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
|
84 |
+
|
85 |
+
def forward(self, x):
|
86 |
+
output = self._norm(x.float()).type_as(x)
|
87 |
+
return output * self.weight
|
88 |
+
|
89 |
+
|
90 |
+
class FeedForward(nn.Module):
|
91 |
+
def __init__(
|
92 |
+
self,
|
93 |
+
dim: int,
|
94 |
+
multiple_of: int = 256,
|
95 |
+
ffn_dropout_p: float = 0.0,
|
96 |
+
):
|
97 |
+
super().__init__()
|
98 |
+
hidden_dim = 4 * dim
|
99 |
+
hidden_dim = int(2 * hidden_dim / 3)
|
100 |
+
hidden_dim = find_multiple(hidden_dim, multiple_of)
|
101 |
+
|
102 |
+
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
103 |
+
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
104 |
+
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
105 |
+
self.ffn_dropout = nn.Dropout(ffn_dropout_p)
|
106 |
+
|
107 |
+
def forward(self, x):
|
108 |
+
return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
|
109 |
+
|
110 |
+
|
111 |
+
class KVCache(nn.Module):
|
112 |
+
def __init__(self, max_batch_size, max_seq_length, n_head, head_dim, dtype):
|
113 |
+
super().__init__()
|
114 |
+
cache_shape = (max_batch_size, n_head, max_seq_length, head_dim)
|
115 |
+
self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
|
116 |
+
self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))
|
117 |
+
|
118 |
+
def update(self, input_pos, k_val, v_val):
|
119 |
+
# input_pos: [S], k_val: [B, H, S, D]
|
120 |
+
assert input_pos.shape[0] == k_val.shape[2]
|
121 |
+
k_out = self.k_cache
|
122 |
+
v_out = self.v_cache
|
123 |
+
k_out[:, :, input_pos] = k_val
|
124 |
+
v_out[:, :, input_pos] = v_val
|
125 |
+
|
126 |
+
return k_out, v_out
|
127 |
+
|
128 |
+
|
129 |
+
class Attention(nn.Module):
|
130 |
+
def __init__(
|
131 |
+
self,
|
132 |
+
dim: int,
|
133 |
+
n_head: int,
|
134 |
+
attn_dropout_p: float = 0.0,
|
135 |
+
resid_dropout_p: float = 0.1,
|
136 |
+
):
|
137 |
+
super().__init__()
|
138 |
+
assert dim % n_head == 0
|
139 |
+
self.dim = dim
|
140 |
+
self.head_dim = dim // n_head
|
141 |
+
self.n_head = n_head
|
142 |
+
|
143 |
+
# key, query, value projections for all heads, but in a batch
|
144 |
+
self.wqkv = nn.Linear(dim, dim * 3, bias=False)
|
145 |
+
self.wo = nn.Linear(dim, dim, bias=False)
|
146 |
+
self.kv_cache = None
|
147 |
+
|
148 |
+
# regularization
|
149 |
+
self.attn_dropout_p = attn_dropout_p
|
150 |
+
self.resid_dropout = nn.Dropout(resid_dropout_p)
|
151 |
+
|
152 |
+
def forward(
|
153 |
+
self, x: torch.Tensor,
|
154 |
+
input_pos: Optional[torch.Tensor] = None,
|
155 |
+
mask: Optional[torch.Tensor] = None
|
156 |
+
):
|
157 |
+
bsz, seqlen, _ = x.shape
|
158 |
+
xq, xk, xv = self.wqkv(x).split([self.dim, self.dim, self.dim], dim=-1)
|
159 |
+
|
160 |
+
xq = xq.view(bsz, seqlen, self.n_head, self.head_dim)
|
161 |
+
xk = xk.view(bsz, seqlen, self.n_head, self.head_dim)
|
162 |
+
xv = xv.view(bsz, seqlen, self.n_head, self.head_dim)
|
163 |
+
|
164 |
+
xq, xk, xv = map(lambda x: x.transpose(1, 2), (xq, xk, xv))
|
165 |
+
|
166 |
+
if self.kv_cache is not None:
|
167 |
+
keys, values = self.kv_cache.update(input_pos, xk, xv)
|
168 |
+
else:
|
169 |
+
keys, values = xk, xv
|
170 |
+
|
171 |
+
output = F.scaled_dot_product_attention(
|
172 |
+
xq, keys, values,
|
173 |
+
attn_mask=mask,
|
174 |
+
is_causal=True if mask is None else False, # is_causal=False is for KV cache
|
175 |
+
dropout_p=self.attn_dropout_p if self.training else 0)
|
176 |
+
|
177 |
+
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
|
178 |
+
|
179 |
+
output = self.resid_dropout(self.wo(output))
|
180 |
+
return output
|
181 |
+
|
182 |
+
|
183 |
+
class TransformerBlock(nn.Module):
|
184 |
+
def __init__(
|
185 |
+
self,
|
186 |
+
dim: int,
|
187 |
+
n_head: int,
|
188 |
+
multiple_of: int = 256,
|
189 |
+
norm_eps: float = 1e-5,
|
190 |
+
attn_dropout_p: float = 0.0,
|
191 |
+
ffn_dropout_p: float = 0.1,
|
192 |
+
resid_dropout_p: float = 0.1,
|
193 |
+
drop_path: float = 0.0,
|
194 |
+
):
|
195 |
+
super().__init__()
|
196 |
+
self.attention = Attention(
|
197 |
+
dim=dim,
|
198 |
+
n_head=n_head,
|
199 |
+
attn_dropout_p=attn_dropout_p,
|
200 |
+
resid_dropout_p=resid_dropout_p,
|
201 |
+
)
|
202 |
+
self.feed_forward = FeedForward(
|
203 |
+
dim=dim,
|
204 |
+
multiple_of=multiple_of,
|
205 |
+
ffn_dropout_p=ffn_dropout_p,
|
206 |
+
)
|
207 |
+
self.attention_norm = RMSNorm(dim, eps=norm_eps)
|
208 |
+
self.ffn_norm = RMSNorm(dim, eps=norm_eps)
|
209 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
210 |
+
|
211 |
+
def forward(self, x: torch.Tensor, start_pos: int, mask: Optional[torch.Tensor] = None):
|
212 |
+
h = x + self.drop_path(self.attention(self.attention_norm(x), start_pos, mask))
|
213 |
+
out = h + self.drop_path(self.feed_forward(self.ffn_norm(h)))
|
214 |
+
return out
|
215 |
+
|
216 |
+
|
217 |
+
class Transformer(nn.Module):
|
218 |
+
def __init__(
|
219 |
+
self,
|
220 |
+
dim: int = 4096,
|
221 |
+
n_layer: int = 32,
|
222 |
+
n_head: int = 32,
|
223 |
+
attn_dropout_p: float = 0.0,
|
224 |
+
resid_dropout_p: float = 0.1,
|
225 |
+
ffn_dropout_p: float = 0.1,
|
226 |
+
drop_path_rate: float = 0.0,
|
227 |
+
num_classes: Union[int, List[int]] = 1000,
|
228 |
+
class_dropout_prob: float = 0.1,
|
229 |
+
|
230 |
+
cls_token_num: int = 1,
|
231 |
+
num_slots: int = 16,
|
232 |
+
slot_dim: int = 256,
|
233 |
+
|
234 |
+
diffloss_d: int = 3,
|
235 |
+
diffloss_w: int = 1024,
|
236 |
+
num_sampling_steps: str = '100',
|
237 |
+
diffusion_batch_mul: int = 4,
|
238 |
+
predict_xstart: bool = False,
|
239 |
+
use_si: bool = False,
|
240 |
+
cond_method: str = "adaln",
|
241 |
+
**kwargs,
|
242 |
+
):
|
243 |
+
super().__init__()
|
244 |
+
|
245 |
+
# Store configuration
|
246 |
+
self.dim = dim
|
247 |
+
self.n_layer = n_layer
|
248 |
+
self.n_head = n_head
|
249 |
+
self.num_slots = num_slots
|
250 |
+
self.slot_dim = slot_dim
|
251 |
+
self.num_classes = num_classes
|
252 |
+
self.cls_token_num = cls_token_num
|
253 |
+
|
254 |
+
# Initialize embeddings
|
255 |
+
self.cls_embedding = LabelEmbedder(num_classes, dim, class_dropout_prob)
|
256 |
+
self.z_proj = nn.Linear(slot_dim, dim, bias=True)
|
257 |
+
self.z_proj_ln = RMSNorm(dim)
|
258 |
+
self.pos_embed_learned = nn.Parameter(torch.zeros(1, num_slots + cls_token_num, dim))
|
259 |
+
|
260 |
+
# transformer blocks
|
261 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layer)]
|
262 |
+
self.layers = torch.nn.ModuleList()
|
263 |
+
for layer_id in range(n_layer):
|
264 |
+
self.layers.append(TransformerBlock(
|
265 |
+
dim=dim,
|
266 |
+
n_head=n_head,
|
267 |
+
ffn_dropout_p=ffn_dropout_p,
|
268 |
+
attn_dropout_p=attn_dropout_p,
|
269 |
+
resid_dropout_p=resid_dropout_p,
|
270 |
+
drop_path=dpr[layer_id],
|
271 |
+
))
|
272 |
+
|
273 |
+
# output layer
|
274 |
+
self.norm = RMSNorm(dim)
|
275 |
+
|
276 |
+
self.diffusion_pos_embed_learned = nn.Parameter(torch.zeros(1, num_slots, dim))
|
277 |
+
|
278 |
+
# KVCache
|
279 |
+
self.max_batch_size = -1
|
280 |
+
self.max_seq_length = -1
|
281 |
+
|
282 |
+
self.initialize_weights()
|
283 |
+
|
284 |
+
# Diffusion Loss
|
285 |
+
self.diffloss = DiffLoss(
|
286 |
+
target_channels=slot_dim,
|
287 |
+
z_channels=dim,
|
288 |
+
width=diffloss_w,
|
289 |
+
depth=diffloss_d,
|
290 |
+
num_sampling_steps=num_sampling_steps,
|
291 |
+
predict_xstart=predict_xstart,
|
292 |
+
use_si=use_si,
|
293 |
+
cond_method=cond_method,
|
294 |
+
)
|
295 |
+
self.diffusion_batch_mul = diffusion_batch_mul
|
296 |
+
|
297 |
+
def initialize_weights(self):
|
298 |
+
nn.init.normal_(self.pos_embed_learned, std=0.02)
|
299 |
+
nn.init.normal_(self.diffusion_pos_embed_learned, std=0.02)
|
300 |
+
# Initialize nn.Linear and nn.Embedding
|
301 |
+
self.apply(self._init_weights)
|
302 |
+
|
303 |
+
def _init_weights(self, module):
|
304 |
+
if isinstance(module, nn.Linear):
|
305 |
+
module.weight.data.normal_(std=0.02)
|
306 |
+
if module.bias is not None:
|
307 |
+
module.bias.data.zero_()
|
308 |
+
elif isinstance(module, nn.Embedding):
|
309 |
+
module.weight.data.normal_(std=0.02)
|
310 |
+
|
311 |
+
def setup_caches(self, max_batch_size, max_seq_length, dtype):
|
312 |
+
# if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
|
313 |
+
# return
|
314 |
+
head_dim = self.dim // self.n_head
|
315 |
+
max_seq_length = find_multiple(max_seq_length, 8)
|
316 |
+
self.max_seq_length = max_seq_length
|
317 |
+
self.max_batch_size = max_batch_size
|
318 |
+
for b in self.layers:
|
319 |
+
b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.n_head, head_dim, dtype)
|
320 |
+
|
321 |
+
causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
|
322 |
+
self.causal_mask = causal_mask.unsqueeze(0).repeat(self.max_batch_size, 1, 1)
|
323 |
+
|
324 |
+
def reset_caches(self):
|
325 |
+
self.max_seq_length = -1
|
326 |
+
self.max_batch_size = -1
|
327 |
+
for b in self.layers:
|
328 |
+
b.attention.kv_cache = None
|
329 |
+
|
330 |
+
def forward_loss(self, z, target):
|
331 |
+
bsz, seq_len, _ = target.shape
|
332 |
+
target = target.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
|
333 |
+
z = z.reshape(bsz*seq_len, -1).repeat(self.diffusion_batch_mul, 1)
|
334 |
+
loss = self.diffloss(z=z, target=target)
|
335 |
+
return loss
|
336 |
+
|
337 |
+
def forward_cfg(self, h, cfg):
|
338 |
+
if cfg > 1.0:
|
339 |
+
h_cond, h_uncond = h.chunk(2, dim=0)
|
340 |
+
h = h_uncond + cfg * (h_cond - h_uncond)
|
341 |
+
return h
|
342 |
+
|
343 |
+
def forward(
|
344 |
+
self,
|
345 |
+
slots: torch.Tensor,
|
346 |
+
cond_idx: torch.Tensor,
|
347 |
+
input_pos: Optional[torch.Tensor] = None,
|
348 |
+
mask: Optional[torch.Tensor] = None,
|
349 |
+
cfg: float = 1.0,
|
350 |
+
temperature: float = 1.0
|
351 |
+
):
|
352 |
+
if slots is not None and cond_idx is not None: # training or naive inference
|
353 |
+
cond_embeddings = self.cls_embedding(cond_idx, train=self.training)
|
354 |
+
cond_embeddings = cond_embeddings.expand(-1, self.cls_token_num, -1)
|
355 |
+
token_embeddings = self.z_proj(slots)
|
356 |
+
token_embeddings = torch.cat((cond_embeddings, token_embeddings), dim=1)
|
357 |
+
else:
|
358 |
+
if cond_idx is not None: # prefill in inference
|
359 |
+
token_embeddings = self.cls_embedding(cond_idx, train=self.training)
|
360 |
+
token_embeddings = token_embeddings.expand(-1, self.cls_token_num, -1)
|
361 |
+
else: # decode_n_tokens(kv cache) in inference
|
362 |
+
token_embeddings = self.z_proj(slots)
|
363 |
+
|
364 |
+
bs = token_embeddings.shape[0]
|
365 |
+
mask = self.causal_mask[:bs, None, input_pos]
|
366 |
+
|
367 |
+
h = token_embeddings
|
368 |
+
if self.training:
|
369 |
+
h = h + self.pos_embed_learned
|
370 |
+
else:
|
371 |
+
h = h + self.pos_embed_learned[:, input_pos].view(1, -1, self.dim)
|
372 |
+
|
373 |
+
h = self.z_proj_ln(h) # not sure if this is needed
|
374 |
+
|
375 |
+
# transformer blocks
|
376 |
+
for layer in self.layers:
|
377 |
+
h = layer(h, input_pos, mask)
|
378 |
+
|
379 |
+
h = self.norm(h)
|
380 |
+
|
381 |
+
if self.training:
|
382 |
+
h = h[:, self.cls_token_num - 1 : -1].contiguous()
|
383 |
+
h = h + self.diffusion_pos_embed_learned
|
384 |
+
loss = self.forward_loss(h, slots.detach())
|
385 |
+
return loss
|
386 |
+
else:
|
387 |
+
h = h[:, -1]
|
388 |
+
h = h + self.diffusion_pos_embed_learned[:, input_pos[-1] - self.cls_token_num + 1]
|
389 |
+
next_tokens = self.diffloss.sample(h, temperature=temperature, cfg=cfg)
|
390 |
+
return next_tokens
|
391 |
+
|
392 |
+
|
393 |
+
def get_fsdp_wrap_module_list(self) -> List[nn.Module]:
|
394 |
+
return list(self.layers)
|
395 |
+
|
396 |
+
|
397 |
+
|
398 |
+
#################################################################################
|
399 |
+
# GPT Configs #
|
400 |
+
#################################################################################
|
401 |
+
### text-conditional
|
402 |
+
def GPT_7B(**kwargs):
|
403 |
+
return Transformer(n_layer=32, n_head=32, dim=4096, **kwargs) # 6.6B
|
404 |
+
|
405 |
+
def GPT_3B(**kwargs):
|
406 |
+
return Transformer(n_layer=24, n_head=32, dim=3200, **kwargs) # 3.1B
|
407 |
+
|
408 |
+
def GPT_1B(**kwargs):
|
409 |
+
return Transformer(n_layer=22, n_head=32, dim=2048, **kwargs) # 1.2B
|
410 |
+
|
411 |
+
### class-conditional
|
412 |
+
def GPT_XXXL(**kwargs):
|
413 |
+
return Transformer(n_layer=48, n_head=40, dim=2560, **kwargs) # 3.9B
|
414 |
+
|
415 |
+
def GPT_XXL(**kwargs):
|
416 |
+
return Transformer(n_layer=48, n_head=24, dim=1536, **kwargs) # 1.4B
|
417 |
+
|
418 |
+
def GPT_XL(**kwargs):
|
419 |
+
return Transformer(n_layer=36, n_head=20, dim=1280, **kwargs) # 775M
|
420 |
+
|
421 |
+
def GPT_L(**kwargs):
|
422 |
+
return Transformer(n_layer=24, n_head=16, dim=1024, **kwargs) # 343M
|
423 |
+
|
424 |
+
def GPT_B(**kwargs):
|
425 |
+
return Transformer(n_layer=12, n_head=12, dim=768, **kwargs) # 111M
|
426 |
+
|
427 |
+
|
428 |
+
GPT_models = {
|
429 |
+
'GPT-B': GPT_B, 'GPT-L': GPT_L, 'GPT-XL': GPT_XL, 'GPT-XXL': GPT_XXL, 'GPT-XXXL': GPT_XXXL,
|
430 |
+
'GPT-1B': GPT_1B, 'GPT-3B': GPT_3B, 'GPT-7B': GPT_7B,
|
431 |
+
}
|
semanticist/utils/datasets.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchvision
|
3 |
+
import numpy as np
|
4 |
+
import os.path as osp
|
5 |
+
from PIL import Image
|
6 |
+
import torchvision
|
7 |
+
import torchvision.transforms as TF
|
8 |
+
|
9 |
+
def pair(t):
|
10 |
+
return t if isinstance(t, tuple) else (t, t)
|
11 |
+
|
12 |
+
def center_crop_arr(pil_image, image_size):
|
13 |
+
"""
|
14 |
+
Center cropping implementation from ADM.
|
15 |
+
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
|
16 |
+
"""
|
17 |
+
while min(*pil_image.size) >= 2 * image_size:
|
18 |
+
pil_image = pil_image.resize(
|
19 |
+
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
|
20 |
+
)
|
21 |
+
|
22 |
+
scale = image_size / min(*pil_image.size)
|
23 |
+
pil_image = pil_image.resize(
|
24 |
+
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
|
25 |
+
)
|
26 |
+
|
27 |
+
arr = np.array(pil_image)
|
28 |
+
crop_y = (arr.shape[0] - image_size) // 2
|
29 |
+
crop_x = (arr.shape[1] - image_size) // 2
|
30 |
+
return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
|
31 |
+
|
32 |
+
def vae_transforms(split, aug='randcrop', img_size=256):
|
33 |
+
t = []
|
34 |
+
if split == 'train':
|
35 |
+
if aug == 'randcrop':
|
36 |
+
t.append(TF.Resize(img_size, interpolation=TF.InterpolationMode.BICUBIC, antialias=True))
|
37 |
+
t.append(TF.RandomCrop(img_size))
|
38 |
+
elif aug == 'centercrop':
|
39 |
+
t.append(TF.Lambda(lambda x: center_crop_arr(x, img_size)))
|
40 |
+
else:
|
41 |
+
raise ValueError(f"Invalid augmentation: {aug}")
|
42 |
+
t.append(TF.RandomHorizontalFlip(p=0.5))
|
43 |
+
else:
|
44 |
+
t.append(TF.Lambda(lambda x: center_crop_arr(x, img_size)))
|
45 |
+
|
46 |
+
t.append(TF.ToTensor())
|
47 |
+
|
48 |
+
return TF.Compose(t)
|
49 |
+
|
50 |
+
|
51 |
+
def cached_transforms(aug='tencrop', img_size=256, crop_ranges=[1.05, 1.10]):
|
52 |
+
t = []
|
53 |
+
if 'centercrop' in aug:
|
54 |
+
t.append(TF.Lambda(lambda x: center_crop_arr(x, img_size)))
|
55 |
+
t.append(TF.Lambda(lambda x: torch.stack([TF.ToTensor()(x), TF.ToTensor()(TF.functional.hflip(x))])))
|
56 |
+
elif 'tencrop' in aug:
|
57 |
+
crop_sizes = [int(img_size * crop_range) for crop_range in crop_ranges]
|
58 |
+
t.append(TF.Lambda(lambda x: [center_crop_arr(x, crop_size) for crop_size in crop_sizes]))
|
59 |
+
t.append(TF.Lambda(lambda crops: [crop for crop_tuple in [TF.TenCrop(img_size)(crop) for crop in crops] for crop in crop_tuple]))
|
60 |
+
t.append(TF.Lambda(lambda crops: torch.stack([TF.ToTensor()(crop) for crop in crops])))
|
61 |
+
else:
|
62 |
+
raise ValueError(f"Invalid augmentation: {aug}")
|
63 |
+
|
64 |
+
return TF.Compose(t)
|
65 |
+
|
66 |
+
class ImageNet(torchvision.datasets.ImageFolder):
|
67 |
+
def __init__(self, root, split='train', aug='randcrop', img_size=256):
|
68 |
+
super().__init__(osp.join(root, split))
|
69 |
+
if not 'cache' in aug:
|
70 |
+
self.transform = vae_transforms(split, aug=aug, img_size=img_size)
|
71 |
+
else:
|
72 |
+
self.transform = cached_transforms(aug=aug, img_size=img_size)
|
semanticist/utils/device_utils.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
def configure_compute_backend():
|
4 |
+
"""Configure PyTorch compute backend settings for CUDA."""
|
5 |
+
if torch.cuda.is_available():
|
6 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
7 |
+
torch.backends.cudnn.allow_tf32 = True
|
8 |
+
torch.backends.cudnn.benchmark = True
|
9 |
+
torch.backends.cudnn.deterministic = False
|
10 |
+
else:
|
11 |
+
raise ValueError("No CUDA available")
|
12 |
+
|
13 |
+
def get_device():
|
14 |
+
"""Get the device to use for training."""
|
15 |
+
if torch.cuda.is_available():
|
16 |
+
return torch.device("cuda")
|
17 |
+
else:
|
18 |
+
raise ValueError("No CUDA available")
|
semanticist/utils/logger.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import defaultdict, deque
|
2 |
+
import datetime
|
3 |
+
import time
|
4 |
+
import torch
|
5 |
+
import torch.distributed as dist
|
6 |
+
from semanticist.engine.trainer_utils import is_dist_avail_and_initialized, is_main_process
|
7 |
+
from semanticist.utils.device_utils import get_device
|
8 |
+
|
9 |
+
def synchronize_processes():
|
10 |
+
if torch.cuda.is_available():
|
11 |
+
torch.cuda.synchronize()
|
12 |
+
else: # do nothing
|
13 |
+
pass
|
14 |
+
|
15 |
+
def empty_cache():
|
16 |
+
if torch.cuda.is_available():
|
17 |
+
torch.cuda.empty_cache()
|
18 |
+
else: # do nothing
|
19 |
+
pass
|
20 |
+
|
21 |
+
class SmoothedValue(object):
|
22 |
+
"""Track a series of values and provide access to smoothed values over a
|
23 |
+
window or the global series average.
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __init__(self, window_size=20, fmt=None):
|
27 |
+
if fmt is None:
|
28 |
+
fmt = "{median:.4f} ({global_avg:.4f})"
|
29 |
+
self.deque = deque(maxlen=window_size)
|
30 |
+
self.total = 0.0
|
31 |
+
self.count = 0
|
32 |
+
self.fmt = fmt
|
33 |
+
|
34 |
+
def update(self, value, n=1):
|
35 |
+
self.deque.append(value)
|
36 |
+
self.count += n
|
37 |
+
self.total += value * n
|
38 |
+
|
39 |
+
def synchronize_between_processes(self):
|
40 |
+
"""
|
41 |
+
Warning: does not synchronize the deque!
|
42 |
+
"""
|
43 |
+
if not is_dist_avail_and_initialized():
|
44 |
+
return
|
45 |
+
t = torch.tensor([self.count, self.total], dtype=torch.float32, device=get_device())
|
46 |
+
dist.barrier()
|
47 |
+
dist.all_reduce(t)
|
48 |
+
t = t.tolist()
|
49 |
+
self.count = int(t[0])
|
50 |
+
self.total = t[1]
|
51 |
+
|
52 |
+
@property
|
53 |
+
def median(self):
|
54 |
+
d = torch.tensor(list(self.deque))
|
55 |
+
return d.median().item()
|
56 |
+
|
57 |
+
@property
|
58 |
+
def avg(self):
|
59 |
+
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
60 |
+
return d.mean().item()
|
61 |
+
|
62 |
+
@property
|
63 |
+
def global_avg(self):
|
64 |
+
return self.total / self.count
|
65 |
+
|
66 |
+
@property
|
67 |
+
def max(self):
|
68 |
+
return max(self.deque)
|
69 |
+
|
70 |
+
@property
|
71 |
+
def value(self):
|
72 |
+
return self.deque[-1]
|
73 |
+
|
74 |
+
def __str__(self):
|
75 |
+
return self.fmt.format(
|
76 |
+
median=self.median,
|
77 |
+
avg=self.avg,
|
78 |
+
global_avg=self.global_avg,
|
79 |
+
max=self.max,
|
80 |
+
value=self.value)
|
81 |
+
|
82 |
+
|
83 |
+
class MetricLogger(object):
|
84 |
+
def __init__(self, delimiter="\t"):
|
85 |
+
self.meters = defaultdict(SmoothedValue)
|
86 |
+
self.delimiter = delimiter
|
87 |
+
|
88 |
+
def update(self, **kwargs):
|
89 |
+
for k, v in kwargs.items():
|
90 |
+
if v is None:
|
91 |
+
continue
|
92 |
+
if isinstance(v, torch.Tensor):
|
93 |
+
v = v.item()
|
94 |
+
assert isinstance(v, (float, int))
|
95 |
+
self.meters[k].update(v)
|
96 |
+
|
97 |
+
def __getattr__(self, attr):
|
98 |
+
if attr in self.meters:
|
99 |
+
return self.meters[attr]
|
100 |
+
if attr in self.__dict__:
|
101 |
+
return self.__dict__[attr]
|
102 |
+
raise AttributeError("'{}' object has no attribute '{}'".format(
|
103 |
+
type(self).__name__, attr))
|
104 |
+
|
105 |
+
def __str__(self):
|
106 |
+
loss_str = []
|
107 |
+
for name, meter in self.meters.items():
|
108 |
+
loss_str.append(
|
109 |
+
"{}: {}".format(name, str(meter))
|
110 |
+
)
|
111 |
+
return self.delimiter.join(loss_str)
|
112 |
+
|
113 |
+
def synchronize_between_processes(self):
|
114 |
+
for meter in self.meters.values():
|
115 |
+
meter.synchronize_between_processes()
|
116 |
+
|
117 |
+
def add_meter(self, name, meter):
|
118 |
+
self.meters[name] = meter
|
119 |
+
|
120 |
+
def log_every(self, iterable, print_freq, header=None):
|
121 |
+
i = 0
|
122 |
+
if not header:
|
123 |
+
header = ''
|
124 |
+
start_time = time.time()
|
125 |
+
end = time.time()
|
126 |
+
iter_time = SmoothedValue(fmt='{avg:.4f}')
|
127 |
+
data_time = SmoothedValue(fmt='{avg:.4f}')
|
128 |
+
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
|
129 |
+
log_msg = [
|
130 |
+
header,
|
131 |
+
'[{0' + space_fmt + '}/{1}]',
|
132 |
+
'eta: {eta}',
|
133 |
+
'{meters}',
|
134 |
+
'time: {time}',
|
135 |
+
'data: {data}'
|
136 |
+
]
|
137 |
+
if torch.cuda.is_available():
|
138 |
+
log_msg.append('mem: {memory:.0f}')
|
139 |
+
log_msg.append("util: {util:.1f}%")
|
140 |
+
log_msg = self.delimiter.join(log_msg)
|
141 |
+
MB = 1024.0 * 1024.0
|
142 |
+
for obj in iterable:
|
143 |
+
data_time.update(time.time() - end)
|
144 |
+
yield obj
|
145 |
+
iter_time.update(time.time() - end)
|
146 |
+
if i % print_freq == 0 or i == len(iterable) - 1:
|
147 |
+
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
148 |
+
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
149 |
+
if torch.cuda.is_available():
|
150 |
+
if is_main_process():
|
151 |
+
memory = torch.cuda.max_memory_allocated()
|
152 |
+
util = torch.cuda.utilization()
|
153 |
+
print(log_msg.format(
|
154 |
+
i, len(iterable), eta=eta_string,
|
155 |
+
meters=str(self),
|
156 |
+
time=str(iter_time), data=str(data_time),
|
157 |
+
memory=memory / MB, util=util))
|
158 |
+
else:
|
159 |
+
if is_main_process():
|
160 |
+
print(log_msg.format(
|
161 |
+
i, len(iterable), eta=eta_string,
|
162 |
+
meters=str(self),
|
163 |
+
time=str(iter_time), data=str(data_time)))
|
164 |
+
i += 1
|
165 |
+
end = time.time()
|
166 |
+
total_time = time.time() - start_time
|
167 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
168 |
+
if is_main_process():
|
169 |
+
print('{} Total time: {} ({:.4f} s / it)'.format(
|
170 |
+
header, total_time_str, total_time / len(iterable)))
|
semanticist/utils/lr_scheduler.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from timm.scheduler.cosine_lr import CosineLRScheduler
|
2 |
+
from timm.scheduler.step_lr import StepLRScheduler
|
3 |
+
|
4 |
+
def build_scheduler(optimizer, n_epoch, n_iter_per_epoch, lr_min=0, warmup_steps=0, warmup_lr_init=0, decay_steps=None, cosine_lr=True):
|
5 |
+
if decay_steps is None:
|
6 |
+
decay_steps = n_epoch * n_iter_per_epoch
|
7 |
+
|
8 |
+
if cosine_lr:
|
9 |
+
scheduler = CosineLRScheduler(optimizer, t_initial=decay_steps, lr_min=lr_min, warmup_t=warmup_steps, warmup_lr_init=warmup_lr_init,
|
10 |
+
cycle_limit=1, t_in_epochs=False, warmup_prefix=True)
|
11 |
+
else:
|
12 |
+
scheduler = StepLRScheduler(optimizer, decay_t=decay_steps, warmup_t=warmup_steps, warmup_lr_init=warmup_lr_init,
|
13 |
+
t_in_epochs=False, warmup_prefix=True)
|
14 |
+
|
15 |
+
return scheduler
|
semanticist/utils/transform.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import PIL
|
2 |
+
import torchvision.transforms as T
|
3 |
+
|
4 |
+
def pair(t):
|
5 |
+
return t if isinstance(t, tuple) else (t, t)
|
6 |
+
|
7 |
+
def stage1_transform(img_size=256, is_train=True, scale=0.8):
|
8 |
+
|
9 |
+
resize = pair(int(img_size/scale))
|
10 |
+
t = []
|
11 |
+
t.append(T.Resize(resize, interpolation=PIL.Image.BICUBIC))
|
12 |
+
if is_train:
|
13 |
+
t.append(T.RandomCrop(img_size))
|
14 |
+
t.append(T.RandomHorizontalFlip(p=0.5))
|
15 |
+
else:
|
16 |
+
t.append(T.CenterCrop(img_size))
|
17 |
+
|
18 |
+
t.append(T.ToTensor())
|
19 |
+
t.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))),
|
20 |
+
|
21 |
+
return T.Compose(t)
|
22 |
+
|
23 |
+
def stage2_transform(img_size=256, is_train=True, scale=0.8):
|
24 |
+
resize = pair(int(img_size/scale))
|
25 |
+
t = []
|
26 |
+
t.append(T.Resize(resize, interpolation=PIL.Image.BICUBIC))
|
27 |
+
if is_train:
|
28 |
+
t.append(T.RandomCrop(img_size))
|
29 |
+
else:
|
30 |
+
t.append(T.CenterCrop(img_size))
|
31 |
+
|
32 |
+
t.append(T.ToTensor())
|
33 |
+
t.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))),
|
34 |
+
|
35 |
+
return T.Compose(t)
|