from model.unet import ScaleAt |
from model.latentnet import * |
from diffusion.resample import UniformSampler |
from diffusion.diffusion import space_timesteps |
from typing import Tuple |
from torch.utils.data import DataLoader |
from config_base import BaseConfig |
from diffusion import * |
from diffusion.base import GenerativeType, LossType, ModelMeanType, ModelVarType, get_named_beta_schedule |
from model import * |
from choices import * |
from multiprocessing import get_context |
import os |
from dataset_util import * |
from torch.utils.data.distributed import DistributedSampler |
from dataset import LatentDataLoader |
@dataclass
class PretrainConfig(BaseConfig):
    """Reference to a pretrained checkpoint: a display ``name`` and its ``path``."""

    # Identifier for the pretrained run.
    name: str
    # Location of the checkpoint to load.
    path: str
@dataclass
class TrainConfig(BaseConfig):
    """Training, evaluation and sampling configuration.

    Besides plain hyper-parameters, this config builds derived objects:
    diffusion configs (``make_*diffusion_conf``), the model config
    (``make_model_conf``), the timestep sampler (``make_T_sampler``), the
    dataset (``make_dataset``) and data loaders (``make_loader``).
    """

    # --- general / training mode ---
    seed: int = 0
    train_mode: TrainMode = TrainMode.diffusion
    train_cond0_prob: float = 0
    # Forwarded to the latent diffusion config (_make_latent_diffusion_conf).
    train_pred_xstart_detach: bool = True
    train_interpolate_prob: float = 0
    train_interpolate_img: bool = False
    # --- manipulation ---
    manipulate_mode: ManipulateMode = ManipulateMode.celebahq_all
    manipulate_cls: str = None
    manipulate_shots: int = None
    manipulate_loss: ManipulateLossType = ManipulateLossType.bce
    manipulate_znormalize: bool = False
    manipulate_seed: int = 0
    # --- batching ---
    # Multiplied with batch_size in batch_size_effective
    # (presumably gradient-accumulation steps).
    accum_batches: int = 1
    autoenc_mid_attn: bool = True
    batch_size: int = 16
    # Falls back to batch_size in __post_init__ when left as None.
    batch_size_eval: int = None
    # --- image diffusion (read by _make_diffusion_conf) ---
    beatgans_gen_type: GenerativeType = GenerativeType.ddim
    beatgans_loss_type: LossType = LossType.mse
    beatgans_model_mean_type: ModelMeanType = ModelMeanType.eps
    beatgans_model_var_type: ModelVarType = ModelVarType.fixed_large
    beatgans_rescale_timesteps: bool = False
    # --- latent diffusion (read by _make_latent_diffusion_conf) ---
    latent_infer_path: str = None
    latent_znormalize: bool = False
    latent_gen_type: GenerativeType = GenerativeType.ddim
    latent_loss_type: LossType = LossType.mse
    latent_model_mean_type: ModelMeanType = ModelMeanType.eps
    latent_model_var_type: ModelVarType = ModelVarType.fixed_large
    latent_rescale_timesteps: bool = False
    # Sampling steps used by make_latent_eval_diffusion_conf.
    latent_T_eval: int = 1_000
    latent_clip_sample: bool = False
    latent_beta_scheduler: str = 'linear'
    beta_scheduler: str = 'linear'
    # --- data ---
    data_name: str = ''
    # Falls back to data_name in __post_init__ when left as None.
    data_val_name: str = None
    # Only 'beatgans' is supported by _make_diffusion_conf.
    diffusion_type: str = None
    # --- optimisation / evaluation ---
    dropout: float = 0.1
    ema_decay: float = 0.9999
    eval_num_images: int = 5_000
    eval_every_samples: int = 200_000
    eval_ema_every_samples: int = 200_000
    fid_use_torch: bool = True
    fp16: bool = False
    grad_clip: float = 1
    img_size: int = 64
    lr: float = 0.0001
    optimizer: OptimizerType = OptimizerType.adam
    weight_decay: float = 0
    # --- model ---
    # Populated by make_model_conf.
    model_conf: ModelConfig = None
    model_name: ModelName = None
    # Set by make_model_conf according to model_name.
    model_type: ModelType = None
    net_attn: Tuple[int, ...] = None
    net_beatgans_attn_head: int = 1
    net_beatgans_embed_channels: int = 512
    net_resblock_updown: bool = True
    net_enc_use_time: bool = False
    net_enc_pool: str = 'adaptivenonzero'
    net_beatgans_gradient_checkpoint: bool = False
    net_beatgans_resnet_two_cond: bool = False
    net_beatgans_resnet_use_zero_module: bool = True
    net_beatgans_resnet_scale_at: ScaleAt = ScaleAt.after_norm
    net_beatgans_resnet_cond_channels: int = None
    net_ch_mult: Tuple[int, ...] = None
    net_ch: int = 64
    net_enc_attn: Tuple[int, ...] = None
    net_enc_k: int = None
    net_enc_num_res_blocks: int = 2
    net_enc_channel_mult: Tuple[int, ...] = None
    net_enc_grad_checkpoint: bool = False
    net_autoenc_stochastic: bool = False
    # --- latent MLP network (read by _make_latent_net_conf) ---
    net_latent_activation: Activation = Activation.silu
    net_latent_channel_mult: Tuple[int, ...] = (1, 2, 4)
    net_latent_condition_bias: float = 0
    net_latent_dropout: float = 0
    net_latent_layers: int = None
    net_latent_net_last_act: Activation = Activation.none
    net_latent_net_type: LatentNetType = LatentNetType.none
    net_latent_num_hid_channels: int = 1024
    net_latent_num_time_layers: int = 2
    net_latent_skip_layers: Tuple[int, ...] = None
    net_latent_time_emb_channels: int = 64
    net_latent_use_norm: bool = False
    net_latent_time_last_act: bool = False
    net_num_res_blocks: int = 2
    net_num_input_res_blocks: int = None
    net_enc_num_cls: int = None
    # --- runtime / bookkeeping ---
    num_workers: int = 4
    parallel: bool = False
    postfix: str = ''
    sample_size: int = 64
    sample_every_samples: int = 20_000
    save_every_samples: int = 100_000
    # Encoder output channels / latent MLP channels.
    style_ch: int = 512
    # Sampling steps used by make_eval_diffusion_conf.
    T_eval: int = 1_000
    # Only 'uniform' is supported by make_T_sampler.
    T_sampler: str = 'uniform'
    # Total diffusion timesteps; also the beta-schedule length.
    T: int = 1_000
    total_samples: int = 10_000_000
    warmup: int = 0
    pretrain: PretrainConfig = None
    continue_from: PretrainConfig = None
    eval_programs: Tuple[str, ...] = None
    eval_path: str = None
    base_dir: str = 'checkpoints'
    use_cache_dataset: bool = False
    data_cache_dir: str = os.path.expanduser('~/cache')
    work_cache_dir: str = os.path.expanduser('~/mycache')
    name: str = ''

    def __post_init__(self):
        """Derive evaluation defaults from the training settings."""
        self.batch_size_eval = self.batch_size_eval or self.batch_size
        self.data_val_name = self.data_val_name or self.data_name

    def scale_up_gpus(self, num_gpus, num_nodes=1):
        """Scale batch sizes and sample-count intervals by the world size.

        Multiplies the eval/sample intervals and both batch sizes by
        ``num_gpus * num_nodes`` in place. ``save_every_samples`` is not
        scaled.

        Returns:
            ``self``, for chaining.
        """
        world_size = num_gpus * num_nodes
        self.eval_ema_every_samples *= world_size
        self.eval_every_samples *= world_size
        self.sample_every_samples *= world_size
        self.batch_size *= world_size
        self.batch_size_eval *= world_size
        return self

    @property
    def batch_size_effective(self):
        """Batch size multiplied by ``accum_batches``."""
        return self.batch_size * self.accum_batches

    @property
    def fid_cache(self):
        """Directory for cached evaluation images of the current dataset."""
        return f'{self.work_cache_dir}/eval_images/{self.data_name}_size{self.img_size}_{self.eval_num_images}'

    @property
    def data_path(self):
        """Dataset path for ``data_name``, optionally redirected to a cache copy.

        ``data_paths`` and ``use_cached_dataset_path`` come from a star import
        (presumably ``dataset_util``).
        """
        path = data_paths[self.data_name]
        if self.use_cache_dataset and path is not None:
            path = use_cached_dataset_path(
                path, f'{self.data_cache_dir}/{self.data_name}')
        return path

    @property
    def logdir(self):
        """Checkpoint/log directory for this run."""
        return f'{self.base_dir}/{self.name}'

    @property
    def generate_dir(self):
        """Directory for generated images of this run."""
        return f'{self.work_cache_dir}/gen_images/{self.name}'

    @staticmethod
    def _section_counts(gen_type, T):
        """Return the ``space_timesteps`` section spec for ``T`` sampling steps.

        Raises:
            NotImplementedError: if ``gen_type`` is neither ddpm nor ddim.
        """
        if gen_type == GenerativeType.ddpm:
            return [T]
        if gen_type == GenerativeType.ddim:
            return f'ddim{T}'
        raise NotImplementedError()

    def _make_diffusion_conf(self, T=None):
        """Build the image-diffusion config that samples with ``T`` steps.

        Args:
            T: number of respaced sampling steps; defaults to ``self.T``.
                The beta schedule always spans the full ``self.T`` steps.

        Raises:
            NotImplementedError: for an unsupported ``diffusion_type`` or
                ``beatgans_gen_type``.
        """
        if self.diffusion_type != 'beatgans':
            raise NotImplementedError()
        if T is None:
            # Previously produced [None] / 'ddimNone' section counts.
            T = self.T
        section_counts = self._section_counts(self.beatgans_gen_type, T)
        return SpacedDiffusionBeatGansConfig(
            gen_type=self.beatgans_gen_type,
            model_type=self.model_type,
            betas=get_named_beta_schedule(self.beta_scheduler, self.T),
            model_mean_type=self.beatgans_model_mean_type,
            model_var_type=self.beatgans_model_var_type,
            loss_type=self.beatgans_loss_type,
            rescale_timesteps=self.beatgans_rescale_timesteps,
            use_timesteps=space_timesteps(num_timesteps=self.T,
                                          section_counts=section_counts),
            fp16=self.fp16,
        )

    def _make_latent_diffusion_conf(self, T=None):
        """Build the latent-diffusion config that samples with ``T`` steps.

        Args:
            T: number of respaced sampling steps; defaults to ``self.T``.

        Raises:
            NotImplementedError: for an unsupported ``latent_gen_type``.
        """
        if T is None:
            T = self.T
        section_counts = self._section_counts(self.latent_gen_type, T)
        return SpacedDiffusionBeatGansConfig(
            train_pred_xstart_detach=self.train_pred_xstart_detach,
            gen_type=self.latent_gen_type,
            model_type=ModelType.ddpm,
            betas=get_named_beta_schedule(self.latent_beta_scheduler, self.T),
            model_mean_type=self.latent_model_mean_type,
            model_var_type=self.latent_model_var_type,
            loss_type=self.latent_loss_type,
            rescale_timesteps=self.latent_rescale_timesteps,
            use_timesteps=space_timesteps(num_timesteps=self.T,
                                          section_counts=section_counts),
            fp16=self.fp16,
        )

    @property
    def model_out_channels(self):
        """Number of UNet output channels (fixed at 3)."""
        return 3

    def make_T_sampler(self):
        """Return the training timestep sampler selected by ``T_sampler``.

        Raises:
            NotImplementedError: if ``T_sampler`` is not 'uniform'.
        """
        if self.T_sampler == 'uniform':
            return UniformSampler(self.T)
        raise NotImplementedError()

    def make_diffusion_conf(self):
        """Image-diffusion config sampling with all ``T`` steps."""
        return self._make_diffusion_conf(self.T)

    def make_eval_diffusion_conf(self):
        """Image-diffusion config sampling with ``T_eval`` steps."""
        return self._make_diffusion_conf(T=self.T_eval)

    def make_latent_diffusion_conf(self):
        """Latent-diffusion config sampling with all ``T`` steps."""
        return self._make_latent_diffusion_conf(T=self.T)

    def make_latent_eval_diffusion_conf(self):
        """Latent-diffusion config sampling with ``latent_T_eval`` steps."""
        return self._make_latent_diffusion_conf(T=self.latent_T_eval)

    def make_dataset(self, path=None, **kwargs):
        """Build the training ``LatentDataLoader`` dataset.

        NOTE(review): ``path`` and ``kwargs`` are ignored. The attributes read
        here (``window_size``, ``frame_jpgs``, ``lmd_feats_prefix``,
        ``audio_prefix``, ``raw_audio_prefix``, ``motion_latents_prefix``,
        ``pose_prefix``, ``db_name``, ``audio_hz``) are not declared as fields
        of this dataclass. Presumably the caller assigns them onto the config
        beforehand; otherwise this raises AttributeError. TODO confirm.
        """
        return LatentDataLoader(self.window_size,
                                self.frame_jpgs,
                                self.lmd_feats_prefix,
                                self.audio_prefix,
                                self.raw_audio_prefix,
                                self.motion_latents_prefix,
                                self.pose_prefix,
                                self.db_name,
                                audio_hz=self.audio_hz)

    def make_loader(self,
                    dataset,
                    shuffle: bool,
                    num_worker: int = None,
                    drop_last: bool = True,
                    batch_size: int = None,
                    parallel: bool = False):
        """Wrap ``dataset`` in a ``DataLoader``.

        Args:
            dataset: dataset to load from.
            shuffle: whether to shuffle; delegated to the ``DistributedSampler``
                when one is used.
            num_worker: number of worker processes; ``None`` means
                ``self.num_workers``. An explicit ``0`` loads in the main process.
            drop_last: drop the last incomplete batch.
            batch_size: batch size; falsy means ``self.batch_size``.
            parallel: use a ``DistributedSampler`` if a process group is
                initialized.
        """
        if parallel and distributed.is_initialized():
            # Distributed sampler always drops the tail so ranks stay in sync.
            sampler = DistributedSampler(dataset,
                                         shuffle=shuffle,
                                         drop_last=True)
        else:
            sampler = None
        num_workers = self.num_workers if num_worker is None else num_worker
        return DataLoader(
            dataset,
            batch_size=batch_size or self.batch_size,
            sampler=sampler,
            # DataLoader forbids shuffle together with a sampler. Test identity
            # rather than truthiness, which would call the sampler's __len__.
            shuffle=False if sampler is not None else shuffle,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=drop_last,
            # DataLoader raises ValueError if a context is given with 0 workers.
            multiprocessing_context=(get_context('fork')
                                     if num_workers > 0 else None),
        )

    def _beatgans_unet_kwargs(self):
        """Keyword arguments shared by the BeatGANs UNet and autoencoder configs."""
        return dict(
            attention_resolutions=self.net_attn,
            channel_mult=self.net_ch_mult,
            conv_resample=True,
            dims=2,
            dropout=self.dropout,
            embed_channels=self.net_beatgans_embed_channels,
            image_size=self.img_size,
            in_channels=3,
            model_channels=self.net_ch,
            num_classes=None,
            num_head_channels=-1,
            num_heads_upsample=-1,
            num_heads=self.net_beatgans_attn_head,
            num_res_blocks=self.net_num_res_blocks,
            num_input_res_blocks=self.net_num_input_res_blocks,
            out_channels=self.model_out_channels,
            resblock_updown=self.net_resblock_updown,
            use_checkpoint=self.net_beatgans_gradient_checkpoint,
            use_new_attention_order=False,
            resnet_two_cond=self.net_beatgans_resnet_two_cond,
            resnet_use_zero_module=self.net_beatgans_resnet_use_zero_module,
        )

    def _make_latent_net_conf(self):
        """Build the latent MLP config selected by ``net_latent_net_type``.

        Returns:
            ``None`` for ``LatentNetType.none``, otherwise an ``MLPSkipNetConfig``.

        Raises:
            NotImplementedError: for any other latent net type.
        """
        if self.net_latent_net_type == LatentNetType.none:
            return None
        if self.net_latent_net_type == LatentNetType.skip:
            return MLPSkipNetConfig(
                num_channels=self.style_ch,
                skip_layers=self.net_latent_skip_layers,
                num_hid_channels=self.net_latent_num_hid_channels,
                num_layers=self.net_latent_layers,
                num_time_emb_channels=self.net_latent_time_emb_channels,
                activation=self.net_latent_activation,
                use_norm=self.net_latent_use_norm,
                condition_bias=self.net_latent_condition_bias,
                dropout=self.net_latent_dropout,
                last_act=self.net_latent_net_last_act,
                num_time_layers=self.net_latent_num_time_layers,
                time_last_act=self.net_latent_time_last_act,
            )
        raise NotImplementedError()

    def make_model_conf(self):
        """Build the model config for ``model_name``.

        Also sets ``self.model_type`` and stores the result in
        ``self.model_conf``.

        Returns:
            The new model config.

        Raises:
            NotImplementedError: for an unsupported ``model_name`` or latent
                net type.
        """
        if self.model_name == ModelName.beatgans_ddpm:
            self.model_type = ModelType.ddpm
            self.model_conf = BeatGANsUNetConfig(**self._beatgans_unet_kwargs())
        elif self.model_name == ModelName.beatgans_autoenc:
            self.model_type = ModelType.autoencoder
            self.model_conf = BeatGANsAutoencConfig(
                **self._beatgans_unet_kwargs(),
                enc_out_channels=self.style_ch,
                enc_pool=self.net_enc_pool,
                enc_num_res_block=self.net_enc_num_res_blocks,
                enc_channel_mult=self.net_enc_channel_mult,
                enc_grad_checkpoint=self.net_enc_grad_checkpoint,
                enc_attn_resolutions=self.net_enc_attn,
                latent_net_conf=self._make_latent_net_conf(),
                resnet_cond_channels=self.net_beatgans_resnet_cond_channels,
            )
        else:
            raise NotImplementedError(self.model_name)
        return self.model_conf