import torch
from torch import nn
from nsr.triplane import Triplane_fg_bg_plane
# import timm
from vit.vit_triplane import Triplane, ViTTriplaneDecomposed
import argparse
import inspect
import dnnlib
from guided_diffusion import dist_util

from pdb import set_trace as st

import vit.vision_transformer as vits
from guided_diffusion import logger
from .confnet import ConfNet

from ldm.modules.diffusionmodules.model import Encoder, MVEncoder, MVEncoderGS, MVEncoderGSDynamicInp
from ldm.modules.diffusionmodules.mv_unet import MVUNet, LGM_MVEncoder

# from ldm.modules.diffusionmodules.openaimodel import MultiViewUNetModel_Encoder

# * create pre-trained encoder & triplane / other nsr decoder


class AE(torch.nn.Module):

    def __init__(self,
                 encoder,
                 decoder,
                 img_size,
                 encoder_cls_token,
                 decoder_cls_token,
                 preprocess,
                 use_clip,
                 dino_version='v1',
                 clip_dtype=None,
                 no_dim_up_mlp=False,
                 dim_up_mlp_as_func=False,
                 uvit_skip_encoder=False,
                 confnet=None) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.img_size = img_size
        self.encoder_cls_token = encoder_cls_token
        self.decoder_cls_token = decoder_cls_token
        self.use_clip = use_clip
        self.dino_version = dino_version
        self.confnet = confnet
        self.preprocess = preprocess

        # Define the encoder->decoder width up-projection attributes before the
        # branches below so they always exist (dim_up_mlp stays None when the
        # widths already match, e.g. CLIP ViT-B/16).
        self.dim_up_mlp = None
        self.dim_up_mlp_as_func = dim_up_mlp_as_func

        if self.dino_version == 'v2':
            self.encoder.mask_token = None
            self.decoder.vit_decoder.mask_token = None

        if 'sd' not in self.dino_version:

            self.uvit_skip_encoder = uvit_skip_encoder
            if uvit_skip_encoder:
                logger.log(
                    f'uvit enabled; vit_encoder has {len(self.encoder.blocks)} blocks'
                )
                # add long-skip fusion layers to the second half of the encoder blocks
                for blk in self.encoder.blocks[len(self.encoder.blocks) // 2:]:
                    blk.skip_linear = nn.Linear(2 * self.encoder.embed_dim,
                                                self.encoder.embed_dim)

                    # zero-init so each long skip starts as an identity mapping
                    nn.init.constant_(blk.skip_linear.weight, 0)
                    if blk.skip_linear.bias is not None:
                        nn.init.constant_(blk.skip_linear.bias, 0)
            else:
                logger.log('uvit disabled')
        else:
            if 'dit' not in self.dino_version:  # dino vit, not dit
                self.decoder.vit_decoder.cls_token = None
                self.decoder.vit_decoder.patch_embed.proj = nn.Identity()
                self.decoder.triplane_decoder.planes = None
                self.decoder.vit_decoder.mask_token = None

            if self.use_clip:
                self.clip_dtype = clip_dtype  # torch.float16

            else:

                if not no_dim_up_mlp and self.encoder.embed_dim != self.decoder.vit_decoder.embed_dim:
                    self.dim_up_mlp = nn.Linear(
                        self.encoder.embed_dim,
                        self.decoder.vit_decoder.embed_dim)
                    logger.log(
                        f"dim_up_mlp: {self.encoder.embed_dim} -> {self.decoder.vit_decoder.embed_dim}, as_func: {self.dim_up_mlp_as_func}"
                    )
                else:
                    logger.log(
                        f'skip dim_up_mlp (no_dim_up_mlp={no_dim_up_mlp})')

        # * remove certain components to make sure no unused parameters during DDP
        # self.decoder.vit_decoder.cls_token = nn.Identity()
        torch.cuda.empty_cache()
        # self.decoder.vit_decoder.patch_embed.proj.bias = nn.Identity()
        # self.decoder.vit_decoder.patch_embed.proj.weight = nn.Identity()

    def encode(self, *args, **kwargs):
        if not self.use_clip:
            if self.dino_version == 'v1':
                latent = self.encode_dinov1(*args, **kwargs)
            elif self.dino_version == 'v2':
                if self.uvit_skip_encoder:
                    latent = self.encode_dinov2_uvit(*args, **kwargs)
                else:
                    latent = self.encode_dinov2(*args, **kwargs)
            else:
                latent = self.encoder(*args)

        else:
            latent = self.encode_clip(*args, **kwargs)

        return latent

    def encode_dinov1(self, x):
        # return self.encoder(img)
        x = self.encoder.prepare_tokens(x)
        for blk in self.encoder.blocks:
            x = blk(x)
        x = self.encoder.norm(x)
        if not self.encoder_cls_token:
            return x[:, 1:]

        return x

    def encode_dinov2(self, x):
        # return self.encoder(img)
        x = self.encoder.prepare_tokens_with_masks(x, masks=None)
        for blk in self.encoder.blocks:
            x = blk(x)
        x_norm = self.encoder.norm(x)

        if not self.encoder_cls_token:
            return x_norm[:, 1:]

        return x_norm

    def encode_dinov2_uvit(self, x):
        # return self.encoder(img)
        x = self.encoder.prepare_tokens_with_masks(x, masks=None)

        # for blk in self.encoder.blocks:
        #     x = blk(x)

        skips = [x]

        # in blks
        for blk in self.encoder.blocks[0:len(self.encoder.blocks) // 2 - 1]:
            x = blk(x)  # B 3 N C
            skips.append(x)

        # mid blks
        for blk in self.encoder.blocks[len(self.encoder.blocks) // 2 -
                                       1:len(self.encoder.blocks) // 2]:
            x = blk(x)  # B 3 N C

        # out blks
        for blk in self.encoder.blocks[len(self.encoder.blocks) // 2:]:
            x = x + blk.skip_linear(torch.cat(
                [x, skips.pop()], dim=-1))  # long skip connections in uvit
            x = blk(x)  # B 3 N C

        x_norm = self.encoder.norm(x)

        # note: gated on decoder_cls_token here, unlike the other encoders,
        # which check encoder_cls_token
        if not self.decoder_cls_token:
            return x_norm[:, 1:]

        return x_norm
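    # U-ViT long-skip bookkeeping (describing the flow above): `skips` starts
    # with the patch embedding; the first B//2 - 1 blocks each push their
    # output; the middle block runs plain; each of the last B//2 blocks pops
    # one entry and fuses it through its zero-initialized skip_linear before
    # running. The stack balances exactly, and training starts from a plain
    # ViT forward pass.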

    def encode_clip(self, x):
        # * CLIP ViT visual encoding pipeline; returns the spatial tokens
        # without the CLS token or the final projection.
        x = self.encoder.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1],
                      -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([
            self.encoder.class_embedding.to(x.dtype) + torch.zeros(
                x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x
        ],
                      dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.encoder.positional_embedding.to(x.dtype)
        x = self.encoder.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.encoder.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.encoder.ln_post(x[:, 1:, :])  # * return the spatial tokens

        return x

    def decode_wo_triplane(self, latent, c=None, img_size=None):
        if img_size is None:
            img_size = self.img_size

        if self.dim_up_mlp is not None:
            if not self.dim_up_mlp_as_func:
                latent = self.dim_up_mlp(latent)
                # return self.decoder.vit_decode(latent, img_size)
            else:
                return self.decoder.vit_decode(
                    latent, img_size,
                    dim_up_mlp=self.dim_up_mlp)  # used in vae-ldm

        return self.decoder.vit_decode(latent, img_size, c=c)

    def decode(self, latent, c, img_size=None, return_raw_only=False):
        latent = self.decode_wo_triplane(latent, img_size=img_size, c=c)
        return self.decoder.triplane_decode(latent, c,
                                            return_raw_only=return_raw_only)

    def decode_after_vae_no_render(
        self,
        ret_dict,
        img_size=None,
    ):

        if img_size is None:
            img_size = self.img_size

        assert self.dim_up_mlp is None

        latent = self.decoder.vit_decode_backbone(ret_dict, img_size)
        ret_dict = self.decoder.vit_decode_postprocess(latent, ret_dict)
        return ret_dict

    def decode_after_vae(
            self,
            ret_dict,  # vae_dict
            c,
            img_size=None,
            return_raw_only=False):
        ret_dict = self.decode_after_vae_no_render(ret_dict, img_size)
        return self.decoder.triplane_decode(ret_dict, c)

    def decode_confmap(self, img):
        assert self.confnet is not None
        # https://github.com/elliottwu/unsup3d/blob/dc961410d61684561f19525c2f7e9ee6f4dacb91/unsup3d/model.py#L152
        # conf_sigma_l1 = self.confnet(img)  # Bx2xHxW
        return self.confnet(img)  # Bx1xHxW

    def encode_decode(self, img, c, return_raw_only=False):
        latent = self.encode(img)
        pred = self.decode(latent, c, return_raw_only=return_raw_only)
        if self.confnet is not None:
            pred.update({
                'conf_sigma': self.decode_confmap(img)  # 224x224
            })

        return pred

    def forward(self,
                img=None,
                c=None,
                latent=None,
                behaviour='enc_dec',
                coordinates=None,
                directions=None,
                return_raw_only=False,
                *args,
                **kwargs):
        """wrap all operations inside forward() for DDP use.
        """

        if behaviour == 'enc_dec':
            pred = self.encode_decode(img, c, return_raw_only=return_raw_only)
            return pred

        elif behaviour == 'enc':
            latent = self.encode(img)
            return latent

        elif behaviour == 'dec':
            assert latent is not None
            pred: dict = self.decode(latent,
                                     c,
                                     self.img_size,
                                     return_raw_only=return_raw_only)
            return pred

        elif behaviour == 'dec_wo_triplane':
            assert latent is not None
            pred: dict = self.decode_wo_triplane(latent, self.img_size)
            return pred

        elif behaviour == 'enc_dec_wo_triplane':
            latent = self.encode(img)
            pred: dict = self.decode_wo_triplane(latent, img_size=self.img_size, c=c)
            return pred

        elif behaviour == 'encoder_vae':
            latent = self.encode(img)
            ret_dict = self.decoder.vae_reparameterization(latent, True)
            return ret_dict

        elif behaviour == 'decode_after_vae_no_render':
            pred: dict = self.decode_after_vae_no_render(latent, self.img_size)
            return pred

        elif behaviour == 'decode_after_vae':
            pred: dict = self.decode_after_vae(latent, c, self.img_size)
            return pred

        elif behaviour == 'triplane_dec':
            assert latent is not None
            pred: dict = self.decoder.triplane_decode(
                latent, c, return_raw_only=return_raw_only, **kwargs)

        elif behaviour == 'triplane_decode_grid':
            assert latent is not None
            pred: dict = self.decoder.triplane_decode_grid(
                latent, **kwargs)

        elif behaviour == 'vit_postprocess_triplane_dec':
            assert latent is not None
            latent = self.decoder.vit_decode_postprocess(
                latent)  # translate spatial token from vit-decoder into 2D
            pred: dict = self.decoder.triplane_decode(
                latent, c)  # render with triplane

        elif behaviour == 'triplane_renderer':
            assert latent is not None
            pred: dict = self.decoder.triplane_renderer(
                latent, coordinates, directions)

        elif behaviour == 'get_rendering_kwargs':
            pred = self.decoder.triplane_decoder.rendering_kwargs

        else:
            # fail loudly instead of hitting NameError on an undefined `pred`
            raise NotImplementedError(f'unknown behaviour: {behaviour}')

        return pred
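    # Illustrative behaviours (a sketch; `ae` is a constructed AE instance):
    #   latent = ae(img, behaviour='enc')               # image -> encoder tokens
    #   pred = ae(latent=latent, c=c, behaviour='dec')  # tokens -> rendered views
    #   pred = ae(img, c, behaviour='enc_dec')          # full round trip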


class AE_CLIPEncoder(AE):

    def __init__(self, encoder, decoder, img_size, cls_token) -> None:
        # AE takes separate encoder/decoder cls-token flags plus
        # preprocess/use_clip; reuse cls_token for both and mark as CLIP.
        super().__init__(encoder,
                         decoder,
                         img_size,
                         encoder_cls_token=cls_token,
                         decoder_cls_token=cls_token,
                         preprocess=None,
                         use_clip=True)


class AE_with_Diffusion(torch.nn.Module):

    def __init__(self, auto_encoder, denoise_model) -> None:
        super().__init__()
        self.auto_encoder = auto_encoder
        self.denoise_model = denoise_model  # simply for easy MPTrainer manipulation

    def forward(self,
                img,
                c,
                behaviour='enc_dec',
                latent=None,
                *args,
                **kwargs):
        # wrap the auto-encoder and the denoising model in a single forward()
        # so DDP (which only hooks forward) and MPTrainer (single model) are
        # easier to use
        if behaviour == 'enc_dec':
            pred = self.auto_encoder(img, c)
            return pred
        elif behaviour == 'enc':
            latent = self.auto_encoder.encode(img)
            if self.auto_encoder.dim_up_mlp is not None:
                latent = self.auto_encoder.dim_up_mlp(latent)
            return latent
        elif behaviour == 'dec':
            assert latent is not None
            pred: dict = self.auto_encoder.decode(latent, c,
                                                  self.auto_encoder.img_size)
            return pred
        elif behaviour == 'denoise':
            assert latent is not None
            pred: dict = self.denoise_model(*args, **kwargs)
            return pred
        else:
            raise NotImplementedError(f'unknown behaviour: {behaviour}')


def eg3d_options_default():

    opts = dnnlib.EasyDict(
        dict(
            cbase=32768,
            cmax=512,
            map_depth=2,
            g_class_name='nsr.triplane.TriPlaneGenerator',  # TODO
            g_num_fp16_res=0,
        ))
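    # cbase/cmax set the StyleGAN2 backbone capacity, map_depth the mapping
    # network depth, and g_num_fp16_res the number of fp16-enabled resolutions,
    # mirroring EG3D's generator options.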

    return opts


def rendering_options_defaults(opts):

    rendering_options = {
        # 'image_resolution': c.training_set_kwargs.resolution,
        'image_resolution': 256,
        'disparity_space_sampling': False,
        'clamp_mode': 'softplus',
        'c_gen_conditioning_zero':
        True,  # if true, fill generator pose conditioning label with dummy zero vector
        # 'gpc_reg_prob': opts.gpc_reg_prob if opts.gen_pose_cond else None,
        'c_scale':
        opts.c_scale,  # multiplier for generator pose conditioning label
        'superresolution_noise_mode': 'none',
        'density_reg': opts.density_reg,  # strength of density regularization
        'density_reg_p_dist': opts.
        density_reg_p_dist,  # distance at which to sample perturbed points for density regularization
        'reg_type': opts.
        reg_type,  # for experimenting with variations on density regularization
        'decoder_lr_mul': 1,
        # opts.decoder_lr_mul,  # learning rate multiplier for decoder
        'decoder_activation': 'sigmoid',
        'sr_antialias': True,
        'return_triplane_features': False,  # for DDF supervision
        'return_sampling_details_flag': False,

        # * shape default sr

        # 'superresolution_module': 'nsr.superresolution.SuperresolutionHybrid4X',
        # 'superresolution_module':
        # 'utils.torch_utils.components.PixelUnshuffleUpsample',
        'superresolution_module': 'utils.torch_utils.components.NearestConvSR',
    }

    if opts.cfg == 'ffhq':
        rendering_options.update({
            'superresolution_module':
            'nsr.superresolution.SuperresolutionHybrid8XDC',
            'focal': 2985.29 / 700,
            'depth_resolution':
            48 - 0,  # number of uniform samples to take per ray.
            'depth_resolution_importance':
            48 - 0,  # number of importance samples to take per ray.
            'bg_depth_resolution':
            16,  # 4/14 in stylenerf, https://github.com/facebookresearch/StyleNeRF/blob/7f5610a058f27fcc360c6b972181983d7df794cb/conf/model/stylenerf_ffhq.yaml#L48
            'ray_start':
            2.25,  # near point along each ray to start taking samples.
            'ray_end':
            3.3,  # far point along each ray to stop taking samples. 
            'box_warp':
            1,  # the side-length of the bounding box spanned by the tri-planes; box_warp=1 means [-0.5, -0.5, -0.5] -> [0.5, 0.5, 0.5].
            'avg_camera_radius':
            2.7,  # used only in the visualizer to specify camera orbit radius.
            'avg_camera_pivot': [
                0, 0, 0.2
            ],  # used only in the visualizer to control center of camera rotation.
            'superresolution_noise_mode': 'random',
        })
    elif opts.cfg == 'afhq':
        rendering_options.update({
            'superresolution_module':
            'nsr.superresolution.SuperresolutionHybrid8X',
            'superresolution_noise_mode': 'random',
            'focal': 4.2647,
            'depth_resolution': 48,
            'depth_resolution_importance': 48,
            'ray_start': 2.25,
            'ray_end': 3.3,
            'box_warp': 1,
            'avg_camera_radius': 2.7,
            'avg_camera_pivot': [0, 0, -0.06],
        })
    elif opts.cfg == 'shapenet':  # TODO, lies in a sphere
        rendering_options.update({
            'depth_resolution': 64,
            'depth_resolution_importance': 64,
            # * radius 1.2 setting, newly rendered images
            'ray_start': 0.2,
            'ray_end': 2.2,
            # 'ray_start': opts.ray_start,
            # 'ray_end': opts.ray_end,
            'box_warp': 2,  # TODO, how to set this value?
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })

    elif opts.cfg == 'eg3d_shapenet_aug_resolution':
        rendering_options.update({
            'depth_resolution': 80,
            'depth_resolution_importance': 80,
            'ray_start': 0.1,
            'ray_end': 1.9,  # 2.6/1.7*1.2
            'box_warp': 1.1,
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })

    elif opts.cfg == 'eg3d_shapenet_aug_resolution_chair':
        rendering_options.update({
            'depth_resolution': 96,
            'depth_resolution_importance': 96,
            'ray_start': 0.1,
            'ray_end': 1.9,  # 2.6/1.7*1.2
            'box_warp': 1.1,
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })

    elif opts.cfg == 'eg3d_shapenet_aug_resolution_chair_128':
        rendering_options.update({
            'depth_resolution': 128,
            'depth_resolution_importance': 128,
            'ray_start': 0.1,
            'ray_end': 1.9,  # 2.6/1.7*1.2
            'box_warp': 1.1,
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })

    elif opts.cfg == 'eg3d_shapenet_aug_resolution_chair_64':
        rendering_options.update({
            'depth_resolution': 64,
            'depth_resolution_importance': 64,
            'ray_start': 0.1,
            'ray_end': 1.9,  # 2.6/1.7*1.2
            'box_warp': 1.1,
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })

    elif opts.cfg == 'srn_shapenet_aug_resolution_chair_128':
        rendering_options.update({
            'depth_resolution': 128,
            'depth_resolution_importance': 128,
            'ray_start': 1.25,
            'ray_end': 2.75,
            'box_warp': 1.5,
            'white_back': True,
            'avg_camera_radius': 2,
            'avg_camera_pivot': [0, 0, 0],
        })

    elif opts.cfg == 'eg3d_shapenet_aug_resolution_chair_128_residualSR':
        rendering_options.update({
            'depth_resolution':
            128,
            'depth_resolution_importance':
            128,
            'ray_start':
            0.1,
            'ray_end':
            1.9,  # 2.6/1.7*1.2
            'box_warp':
            1.1,
            'white_back':
            True,
            'avg_camera_radius':
            1.2,
            'avg_camera_pivot': [0, 0, 0],
            'superresolution_module':
            'utils.torch_utils.components.NearestConvSR_Residual',
        })

    elif opts.cfg == 'shapenet_tuneray':  # TODO, lies in a sphere
        rendering_options.update({
            'depth_resolution': 64,
            'depth_resolution_importance': 64,
            # * radius 1.2 setting, newly rendered images
            'ray_start': opts.ray_start,
            'ray_end': opts.ray_end,
            'box_warp':
            opts.ray_end - opts.ray_start,  # TODO, how to set this value?
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })

    elif opts.cfg == 'shapenet_tuneray_aug_resolution':  # to differentiate hwc
        rendering_options.update({
            'depth_resolution': 80,
            'depth_resolution_importance': 80,
            # * radius 1.2 setting, newly rendered images
            'ray_start': opts.ray_start,
            'ray_end': opts.ray_end,
            'box_warp':
            opts.ray_end - opts.ray_start,  # TODO, how to set this value?
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })

    elif opts.cfg == 'shapenet_tuneray_aug_resolution_64':  # to differentiate hwc
        rendering_options.update({
            'depth_resolution': 128,
            'depth_resolution_importance': 128,
            # * radius 1.2 setting, newly rendered images
            'ray_start': opts.ray_start,
            'ray_end': opts.ray_end,
            'box_warp':
            opts.ray_end - opts.ray_start,  # TODO, how to set this value?
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })

    elif opts.cfg == 'shapenet_tuneray_aug_resolution_64_96':  # to differentiate hwc
        rendering_options.update({
            'depth_resolution': 96,
            'depth_resolution_importance': 96,
            # * radius 1.2 setting, newly rendered images
            'ray_start': opts.ray_start,
            'ray_end': opts.ray_end,
            'box_warp':
            opts.ray_end - opts.ray_start,  # TODO, how to set this value?
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })
    # ! default version
    elif opts.cfg == 'shapenet_tuneray_aug_resolution_64_96_nearestSR':  # to differentiate hwc
        rendering_options.update({
            'depth_resolution':
            96,
            'depth_resolution_importance':
            96,
            # * radius 1.2 setting, newly rendered images
            'ray_start':
            opts.ray_start,
            'ray_end':
            opts.ray_end,
            'box_warp':
            opts.ray_end - opts.ray_start,  # TODO, how to set this value?
            'white_back':
            True,
            'avg_camera_radius':
            1.2,
            'avg_camera_pivot': [0, 0, 0],
            'superresolution_module':
            'utils.torch_utils.components.NearestConvSR',
        })

    # ! 64+64, since ssdnerf adopts this setting
    elif opts.cfg == 'shapenet_tuneray_aug_resolution_64_64_nearestSR':  # to differentiate hwc
        rendering_options.update({
            'depth_resolution':
            64,
            'depth_resolution_importance':
            64,
            # * radius 1.2 setting, newly rendered images
            'ray_start':
            opts.ray_start,
            'ray_end':
            opts.ray_end,
            'box_warp':
            opts.ray_end - opts.ray_start,  # TODO, how to set this value?
            'white_back':
            True,
            'avg_camera_radius':
            1.2,
            'avg_camera_pivot': [0, 0, 0],
            'superresolution_module':
            'utils.torch_utils.components.NearestConvSR',
        })

    # ! 64+64+patch, since ssdnerf adopts this setting
    elif opts.cfg == 'shapenet_tuneray_aug_resolution_64_64_nearestSR_patch':  # to differentiate hwc
        rendering_options.update({
            'depth_resolution':
            64,
            'depth_resolution_importance':
            64,
            # * radius 1.2 setting, newly rendered images
            'ray_start':
            opts.ray_start,
            'ray_end':
            opts.ray_end,
            'box_warp':
            opts.ray_end - opts.ray_start,  # TODO, how to set this value?
            'white_back':
            True,
            'avg_camera_radius':
            1.2,
            'avg_camera_pivot': [0, 0, 0],
            'superresolution_module':
            'utils.torch_utils.components.NearestConvSR',
            # patch configs
            'PatchRaySampler':
            True,
            # 'patch_rendering_resolution': 32,
            # 'patch_rendering_resolution': 48,
            'patch_rendering_resolution':
            opts.patch_rendering_resolution,
        })

    elif opts.cfg == 'objverse_tuneray_aug_resolution_64_64_nearestSR':  # to differentiate hwc
        rendering_options.update({
            'depth_resolution':
            64,
            'depth_resolution_importance':
            64,
            # * radius 1.2 setting, newly rendered images
            'ray_start':
            opts.ray_start,
            # 'auto',
            'ray_end':
            opts.ray_end,
            # 'auto',
            'box_warp':
            opts.ray_end - opts.ray_start,  # TODO, how to set this value?
            # 2,
            'white_back':
            True,
            'avg_camera_radius':
            1.946,  # ?
            'avg_camera_pivot': [0, 0, 0],
            'superresolution_module':
            'utils.torch_utils.components.NearestConvSR',
            # patch configs
            # 'PatchRaySampler': False,
            # 'patch_rendering_resolution': 32,
            # 'patch_rendering_resolution': 48,
            # 'patch_rendering_resolution': opts.patch_rendering_resolution,
        })

    elif opts.cfg == 'objverse_tuneray_aug_resolution_64_64_auto':  # to differentiate hwc
        rendering_options.update({
            'depth_resolution':
            64,
            'depth_resolution_importance':
            64,
            # * radius 1.2 setting, newly rendered images
            'ray_start':
            'auto',
            'ray_end':
            'auto',
            'box_warp':
            0.9,
            'white_back':
            True,
            'radius_range': [1.5,2],
            # 'z_near': 1.5-0.45, # radius in [1.5, 2], https://github.com/modelscope/richdreamer/issues/12#issuecomment-1897734616
            # 'z_far': 2.0+0.45,
            'sampler_bbox_min':
            -0.45,
            'sampler_bbox_max':
            0.45,
            # 'avg_camera_pivot': [0, 0, 0], # not used
            'filter_out_of_bbox':
            True,
            # 'superresolution_module':
            # 'utils.torch_utils.components.NearestConvSR',
            # patch configs
            'PatchRaySampler':
            True,
            # 'patch_rendering_resolution': 32,
            # 'patch_rendering_resolution': 48,
            'patch_rendering_resolution':
            opts.patch_rendering_resolution,
        })
        rendering_options['z_near'] = rendering_options['radius_range'][0]+rendering_options['sampler_bbox_min']
        rendering_options['z_far'] = rendering_options['radius_range'][1]+rendering_options['sampler_bbox_max']


    elif opts.cfg == 'objverse_tuneray_aug_resolution_128_128_auto':  # to differentiate hwc
        rendering_options.update({
            'depth_resolution':
            128,
            'depth_resolution_importance':
            128,
            # * radius 1.2 setting, newly rendered images
            'ray_start':
            'auto',
            'ray_end':
            'auto',
            'box_warp':
            0.9,
            'white_back':
            True,
            'radius_range': [1.5,2],
            # 'z_near': 1.5-0.45, # radius in [1.5, 2], https://github.com/modelscope/richdreamer/issues/12#issuecomment-1897734616
            # 'z_far': 2.0+0.45,
            'sampler_bbox_min':
            -0.45,
            'sampler_bbox_max':
            0.45,
            # 'avg_camera_pivot': [0, 0, 0], # not used
            'filter_out_of_bbox':
            True,
            # 'superresolution_module':
            # 'utils.torch_utils.components.NearestConvSR',
            # patch configs
            'PatchRaySampler':
            True,
            # 'patch_rendering_resolution': 32,
            # 'patch_rendering_resolution': 48,
            'patch_rendering_resolution':
            opts.patch_rendering_resolution,
        })
        rendering_options['z_near'] = rendering_options['radius_range'][0]+rendering_options['sampler_bbox_min']
        rendering_options['z_far'] = rendering_options['radius_range'][1]+rendering_options['sampler_bbox_max']

    elif opts.cfg == 'objverse_tuneray_aug_resolution_96_96_auto':  # to differentiate hwc
        rendering_options.update({
            'depth_resolution':
            96,
            'depth_resolution_importance':
            96,
            # * radius 1.2 setting, newly rendered images
            'ray_start':
            'auto',
            'ray_end':
            'auto',
            'box_warp':
            0.9,
            'white_back':
            True,
            'radius_range': [1.5,2],
            'sampler_bbox_min':
            -0.45,
            'sampler_bbox_max':
            0.45,
            'filter_out_of_bbox':
            True,
            'PatchRaySampler':
            True,
            'patch_rendering_resolution':
            opts.patch_rendering_resolution,
        })
        rendering_options['z_near'] = rendering_options['radius_range'][0]+rendering_options['sampler_bbox_min']
        rendering_options['z_far'] = rendering_options['radius_range'][1]+rendering_options['sampler_bbox_max']



    elif opts.cfg == 'shapenet_tuneray_aug_resolution_64_96_nearestResidualSR':  # to differentiate hwc
        rendering_options.update({
            'depth_resolution':
            96,
            'depth_resolution_importance':
            96,
            # * radius 1.2 setting, newly rendered images
            'ray_start':
            opts.ray_start,
            'ray_end':
            opts.ray_end,
            'box_warp':
            opts.ray_end - opts.ray_start,  # TODO, how to set this value?
            'white_back':
            True,
            'avg_camera_radius':
            1.2,
            'avg_camera_pivot': [0, 0, 0],
            'superresolution_module':
            'utils.torch_utils.components.NearestConvSR_Residual',
        })

    elif opts.cfg == 'shapenet_tuneray_aug_resolution_64_64_nearestResidualSR':  # to differentiate hwc
        rendering_options.update({
            'depth_resolution':
            64,
            'depth_resolution_importance':
            64,
            # * radius 1.2 setting, newly rendered images
            'ray_start':
            opts.ray_start,
            'ray_end':
            opts.ray_end,
            'box_warp':
            opts.ray_end - opts.ray_start,  # TODO, how to set this value?
            'white_back':
            True,
            'avg_camera_radius':
            1.2,
            'avg_camera_pivot': [0, 0, 0],
            'superresolution_module':
            'utils.torch_utils.components.NearestConvSR_Residual',
        })

    elif opts.cfg == 'shapenet_tuneray_aug_resolution_64_104':  # to differentiate hwc
        rendering_options.update({
            'depth_resolution': 104,
            'depth_resolution_importance': 104,
            # * radius 1.2 setting, newly rendered images
            'ray_start': opts.ray_start,
            'ray_end': opts.ray_end,
            'box_warp':
            opts.ray_end - opts.ray_start,  # TODO, how to set this value?
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })

    rendering_options.update({'return_sampling_details_flag': True})

    return rendering_options
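
# Example (a sketch): the `opts` namespace must expose the attributes read
# above (cfg, c_scale, density_reg, density_reg_p_dist, reg_type, and for the
# tuneray configs ray_start / ray_end / patch_rendering_resolution);
# triplane_decoder_defaults() below provides all of them:
#
#   opts = dnnlib.EasyDict(triplane_decoder_defaults())
#   rendering_kwargs = rendering_options_defaults(opts)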


def model_encoder_defaults():

    return dict(
        use_clip=False,
        arch_encoder="vits",
        arch_decoder="vits",
        load_pretrain_encoder=False,
        encoder_lr=1e-5,
        encoder_weight_decay=
        0.001,  # https://github.com/google-research/vision_transformer
        no_dim_up_mlp=False,
        dim_up_mlp_as_func=False,
        decoder_load_pretrained=True,
        uvit_skip_encoder=False,
        # vae ldm
        vae_p=1,
        ldm_z_channels=4,
        ldm_embed_dim=4,
        use_conf_map=False,
        # sd E, lite version by default
        sd_E_ch=64,
        z_channels=3*4,
        sd_E_num_res_blocks=1,
        num_frames=4,
        # vit_decoder
        arch_dit_decoder='DiT2-B/2',
        return_all_dit_layers=False,
        # sd D
        # sd_D_ch=32,
        # sd_D_res_blocks=1,
        lrm_decoder=False,
        plane_n=3,
        gs_rendering=False,
    )


def triplane_decoder_defaults():
    opts = dict(
        triplane_fg_bg=False,
        cfg='shapenet',
        density_reg=0.25,
        density_reg_p_dist=0.004,
        reg_type='l1',
        triplane_decoder_lr=0.0025,  # follow eg3d G lr
        super_resolution_lr=0.0025,
        # triplane_decoder_wd=0.1,
        c_scale=1,
        nsr_lr=0.02,
        triplane_size=224,
        decoder_in_chans=32,
        triplane_in_chans=-1,
        decoder_output_dim=3,
        out_chans=96,
        c_dim=25,  # Conditioning label (C) dimensionality.
        # ray_start=0.2,
        # ray_end=2.2,
        ray_start=0.6,  # shapenet default
        ray_end=1.8,
        rendering_kwargs={},
        sr_training=False,
        bcg_synthesis=False,  # from panohead
        bcg_synthesis_kwargs={},  # G_kwargs.copy()
        #
        image_size=128,  # raw 3D rendering output resolution.
        patch_rendering_resolution=45,
    )

    return opts


def vit_decoder_defaults():
    res = dict(
        vit_decoder_lr=1e-5,  # follow eg3d G lr
        vit_decoder_wd=0.001,
    )
    return res


def nsr_decoder_defaults():
    res = {
        'decomposed': False,
    }  # TODO, add defaults for all nsr
    res.update(triplane_decoder_defaults())  # triplane by default now
    res.update(vit_decoder_defaults())  # type: ignore
    return res


def loss_defaults():
    opt = dict(
        color_criterion='mse',
        l2_lambda=1.0,
        lpips_lambda=0.,
        lpips_delay_iter=0,
        sr_delay_iter=0,
        kl_anneal=False,
        latent_lambda=0.,
        latent_criterion='mse',
        kl_lambda=0.0,
        ssim_lambda=0.,
        l1_lambda=0.,
        id_lambda=0.0,
        depth_lambda=0.0,  # TODO
        alpha_lambda=0.0,  # TODO
        fg_mse=False,
        bg_lamdba=0.0,
        density_reg=0.0,  # tvloss in eg3d
        density_reg_p_dist=0.004,  # 'density regularization strength.'
        density_reg_every=4,  # lazy density reg

        # 3D supervision, ffhq/afhq eg3d warm up
        shape_uniform_lambda=0.005,
        shape_importance_lambda=0.01,
        shape_depth_lambda=0.,

        # gan loss
        rec_cvD_lambda=0.01,
        nvs_cvD_lambda=0.025,
        patchgan_disc_factor=0.01,
        patchgan_disc_g_weight=0.2,
        r1_gamma=1.0,  # ffhq default value for eg3d
        sds_lamdba=1.0,
        nvs_D_lr_mul=1,  # compared with 1e-4
        cano_D_lr_mul=1,  # compared with 1e-4

        # lsgm loss
        ce_balanced_kl=1.,
        p_eps_lambda=1,
        # symmetric loss
        symmetry_loss=False,
        depth_smoothness_lambda=0.0,
        ce_lambda=1.0,
        negative_entropy_lambda=1.0,
        grad_clip=False,
        online_mask=False,  # in unsup3d
    )
    return opt


def dataset_defaults():
    res = dict(
        use_lmdb=False,
        use_wds=False,
        use_lmdb_compressed=True,
        compile=False,
        interval=1,
        objv_dataset=False,
        decode_encode_img_only=False,
        load_wds_diff=False,
        load_wds_latent=False,
        eval_load_wds_instance=True,
        shards_lst="",
        eval_shards_lst="",
        mv_input=False,
        duplicate_sample=True,
        orthog_duplicate=False,
        split_chunk_input=False, # split=8 per chunk
        load_real=False,
        four_view_for_latent=False,
        single_view_for_i23d=False,
        shuffle_across_cls=False,
        load_extra_36_view=False,
        mv_latent_dir='',
        append_depth=False,
        plucker_embedding=False,
        gs_cam_format=False,
        split_chunk_size=8,
    )
    return res


def encoder_and_nsr_defaults():
    """
    Defaults for image training.
    """
    # ViT configs
    res = dict(
        dino_version='v1',
        encoder_in_channels=3,
        img_size=[224],
        patch_size=16,  # ViT-S/16
        in_chans=384,
        num_classes=0,
        embed_dim=384,  # Check ViT encoder dim
        depth=6,
        num_heads=16,
        mlp_ratio=4.,
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.1,
        attn_drop_rate=0.,
        drop_path_rate=0.,
        norm_layer='nn.LayerNorm',
        # img_resolution=128,  # Output resolution.
        cls_token=False,
        # image_size=128,  # rendered output resolution.
        # img_channels=3,  # Number of output color channels.
        encoder_cls_token=False,
        decoder_cls_token=False,
        sr_kwargs={},
        sr_ratio=2,
        # sd configs
    )
    # Triplane configs
    res.update(model_encoder_defaults())
    res.update(nsr_decoder_defaults())
    res.update(
        ae_classname='vit.vit_triplane.ViTTriplaneDecomposed')  # if add SR
    return res
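
# Typical wiring (a sketch): the defaults above already include everything
# create_3DAE_model() needs except rendering_kwargs, which
# rendering_options_defaults() derives from the same namespace:
#
#   cfg = encoder_and_nsr_defaults()
#   cfg['rendering_kwargs'] = rendering_options_defaults(dnnlib.EasyDict(cfg))
#   auto_encoder = create_3DAE_model(**cfg)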


def create_3DAE_model(
        arch_encoder,
        arch_decoder,
        dino_version='v1',
        img_size=[224],
        patch_size=16,
        in_chans=384,
        num_classes=0,
        embed_dim=1024,  # Check ViT encoder dim
        depth=6,
        num_heads=16,
        mlp_ratio=4.,
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.1,
        attn_drop_rate=0.,
        drop_path_rate=0.,
        # norm_layer=nn.LayerNorm,
        norm_layer='nn.LayerNorm',
        out_chans=96,
        decoder_in_chans=32,
        triplane_in_chans=-1,
        decoder_output_dim=32,
        encoder_cls_token=False,
        decoder_cls_token=False,
        c_dim=25,  # Conditioning label (C) dimensionality.
        image_size=128,  # Output resolution.
        img_channels=3,  # Number of output color channels.
        rendering_kwargs={},
        load_pretrain_encoder=False,
        decomposed=True,
        triplane_size=224,
        ae_classname='ViTTriplaneDecomposed',
        use_clip=False,
        sr_kwargs={},
        sr_ratio=2,
        no_dim_up_mlp=False,
        dim_up_mlp_as_func=False,
        decoder_load_pretrained=True,
        uvit_skip_encoder=False,
        bcg_synthesis_kwargs={},
        # decoder params
        vae_p=1,
        ldm_z_channels=4,
        ldm_embed_dim=4,
        use_conf_map=False,
        triplane_fg_bg=False,
        encoder_in_channels=3,
        sd_E_ch=64,
        z_channels=3*4,
        sd_E_num_res_blocks=1,
        num_frames=6,
        arch_dit_decoder='DiT2-B/2',
        lrm_decoder=False,
        gs_rendering=False,
        return_all_dit_layers=False,
        *args,
        **kwargs):

    # TODO, check pre-trained ViT encoder cfgs

    preprocess = None
    clip_dtype = None
    if load_pretrain_encoder:
        if not use_clip:
            if dino_version == 'v1':
                encoder = torch.hub.load(
                    'facebookresearch/dino:main',
                    'dino_{}{}'.format(arch_encoder, patch_size))
                logger.log(
                    f'loaded pre-trained dino v1 {arch_encoder}{patch_size} encoder ckpt'
                )
            elif dino_version == 'v2':
                encoder = torch.hub.load(
                    'facebookresearch/dinov2',
                    'dinov2_{}{}'.format(arch_encoder, patch_size))
                logger.log(
                    f'loaded pre-trained dino v2 {arch_encoder}{patch_size} encoder ckpt'
                )
            elif 'sd' in dino_version:  # just for compat

                if 'mv' in dino_version:
                    if 'lgm' in dino_version:
                        encoder = MVUNet(
                            input_size=256,
                            up_channels=(1024, 1024, 512, 256,
                                         128),  # one more decoder
                            up_attention=(True, True, True, False, False),
                            splat_size=128,
                            output_size=
                            512,  # render & supervise Gaussians at a higher resolution.
                            batch_size=8,
                            num_views=8,
                            gradient_accumulation_steps=1,
                            # mixed_precision='bf16',
                        )
                    else:
                        # both the 'gs' and plain multi-view paths use MVEncoder
                        encoder_cls = MVEncoder

                else:
                    encoder_cls = Encoder

                if 'lgm' not in dino_version:  # the LGM branch built `encoder` above
                    encoder = encoder_cls(  # mono input
                        double_z=True,
                        resolution=256,
                        in_channels=encoder_in_channels,
                        ch=64,  # ! fit in the memory
                        ch_mult=[1, 2, 4, 4],
                        num_res_blocks=1,
                        dropout=0.0,
                        attn_resolutions=[],
                        out_ch=3,  # unused
                        z_channels=4 * 3,
                    )  # stable diffusion encoder
            else:
                raise NotImplementedError()

        else:
            import clip
            model, preprocess = clip.load("ViT-B/16", device=dist_util.dev())
            model.float()  # convert weight to float32
            clip_dtype = model.dtype
            encoder = model.visual  # only use the CLIP visual encoder here
            encoder.requires_grad_(False)
            logger.log(
                f'loaded pre-trained CLIP ViT-B{patch_size} encoder, fixed.')

    elif 'sd' in dino_version:
        attn_kwargs = {}
        if 'mv' in dino_version:
            if 'lgm' in dino_version:
                encoder = LGM_MVEncoder(
                    in_channels=9,
                    # input_size=256,
                    up_channels=(1024, 1024, 512, 256,
                                 128),  # one more decoder
                    up_attention=(True, True, True, False, False),
                )

            else:
                if 'dynaInp' in dino_version:
                    encoder_cls = MVEncoderGSDynamicInp
                else:
                    encoder_cls = MVEncoder
                attn_kwargs = {
                    'n_heads': 8,
                    'd_head': 64,
                }

        else:
            encoder_cls = Encoder

        if 'lgm' not in dino_version:  # TODO, for compat now
            encoder = encoder_cls(
                double_z=True,
                resolution=256,
                in_channels=encoder_in_channels,
                # ch=128,
                # ch=64, # ! fit in the memory
                ch=sd_E_ch,
                # ch_mult=[1,2,4,4],
                # num_res_blocks=2,
                ch_mult=[1, 2, 4, 4],
                # num_res_blocks=1,
                num_res_blocks=sd_E_num_res_blocks,
                num_frames=num_frames,
                dropout=0.0,
                attn_resolutions=[],
                out_ch=3,  # unused
                z_channels=z_channels, # 4 * 3
                attn_kwargs=attn_kwargs,
            )  # stable diffusion encoder

    else:
        encoder = vits.__dict__[arch_encoder](
            patch_size=patch_size,
            drop_path_rate=drop_path_rate,  # stochastic depth
            img_size=img_size)

    # assert decomposed
    # if decomposed:
    if triplane_in_chans == -1:
        triplane_in_chans = decoder_in_chans

    # if triplane_fg_bg:
    #     triplane_renderer_cls = Triplane_fg_bg_plane
    # else:
    triplane_renderer_cls = Triplane

    triplane_decoder = triplane_renderer_cls(
        c_dim,  # Conditioning label (C) dimensionality.
        image_size,  # Output resolution.
        img_channels,  # Number of output color channels.
        rendering_kwargs=rendering_kwargs,
        out_chans=out_chans,
        # create_triplane=True,  # compatability, remove later
        triplane_size=triplane_size,
        decoder_in_chans=triplane_in_chans,
        decoder_output_dim=decoder_output_dim,
        sr_kwargs=sr_kwargs,
        bcg_synthesis_kwargs=bcg_synthesis_kwargs,
        lrm_decoder=lrm_decoder)

    if load_pretrain_encoder:

        if dino_version == 'v1':
            vit_decoder = torch.hub.load(
                'facebookresearch/dino:main',
                'dino_{}{}'.format(arch_decoder, patch_size))
            logger.log(
                'loaded pre-trained decoder',
                'facebookresearch/dino:main dino_{}{}'.format(
                    arch_decoder, patch_size))
        else:

            vit_decoder = torch.hub.load(
                'facebookresearch/dinov2',
                # 'dinov2_{}{}'.format(arch_decoder, patch_size))
                'dinov2_{}{}'.format(arch_decoder, patch_size),
                pretrained=decoder_load_pretrained)
            logger.log(
                'loaded pre-trained decoder',
                'facebookresearch/dinov2 dinov2_{}{}'.format(
                    arch_decoder, patch_size), 'pretrained=',
                decoder_load_pretrained)

    elif 'dit' in dino_version:
        from dit.dit_decoder import DiT2_models

        vit_decoder = DiT2_models[arch_dit_decoder](
            input_size=16,
            num_classes=0,
            learn_sigma=False,
            in_channels=embed_dim,
            mixed_prediction=False,
            context_dim=None,  # add CLIP text embedding
            roll_out=True,
            plane_n=4 if 'gs' in dino_version else 3,
            return_all_layers=return_all_dit_layers,
        )

    else:  # has bug on global token, to fix
        vit_decoder = vits.__dict__[arch_decoder](
            patch_size=patch_size,
            drop_path_rate=drop_path_rate,  # stochastic depth
            img_size=img_size)

    decoder_kwargs = dict(
        class_name=ae_classname,
        vit_decoder=vit_decoder,
        triplane_decoder=triplane_decoder,
        # encoder_cls_token=encoder_cls_token,
        cls_token=decoder_cls_token,
        sr_ratio=sr_ratio,
        vae_p=vae_p,
        ldm_z_channels=ldm_z_channels,
        ldm_embed_dim=ldm_embed_dim,
    )
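    # dnnlib.util.construct_class_by_name resolves the dotted `ae_classname`
    # (e.g. 'vit.vit_triplane.ViTTriplaneDecomposed') and instantiates it with
    # the kwargs above.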
    decoder = dnnlib.util.construct_class_by_name(**decoder_kwargs)


    if use_conf_map:
        confnet = ConfNet(cin=3, cout=1, nf=64, zdim=128)
    else:
        confnet = None

    auto_encoder = AE(
        encoder,
        decoder,
        img_size[0],
        encoder_cls_token,
        decoder_cls_token,
        preprocess,
        use_clip,
        dino_version,
        clip_dtype,
        no_dim_up_mlp=no_dim_up_mlp,
        dim_up_mlp_as_func=dim_up_mlp_as_func,
        uvit_skip_encoder=uvit_skip_encoder,
        confnet=confnet,
    )

    logger.log(auto_encoder)
    torch.cuda.empty_cache()

    return auto_encoder


def create_Triplane(
        c_dim=25,  # Conditioning label (C) dimensionality.
        img_resolution=128,  # Output resolution.
        img_channels=3,  # Number of output color channels.
        rendering_kwargs={},
        decoder_output_dim=32,
        *args,
        **kwargs):

    decoder = Triplane(
        c_dim,  # Conditioning label (C) dimensionality.
        img_resolution,  # Output resolution.
        img_channels,  # Number of output color channels.
        # TODO, replace with c
        rendering_kwargs=rendering_kwargs,
        create_triplane=True,
        decoder_output_dim=decoder_output_dim)
    return decoder
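
# Note (EG3D convention): the 25-dim conditioning label `c_dim` packs a
# flattened 4x4 cam2world extrinsics matrix (16 values) plus a flattened 3x3
# intrinsics matrix (9 values).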


def DiT_defaults():
    return {
        'dit_model': "DiT-B/16",
        'vae': "ema",
        # dit_model="DiT-XL/2",
        # dit_patch_size=8,
    }
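
# 'vae' picks which released Stable Diffusion VAE decoder weights to pair with
# DiT ('ema' or 'mse'), matching the choice exposed in the official DiT repo.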