Diffsplat / src /train_gsvae.py
paulpanwang's picture
Upload folder using huggingface_hub
476e0f0 verified
import warnings
warnings.filterwarnings("ignore") # ignore all warnings
import diffusers.utils.logging as diffusion_logging
diffusion_logging.set_verbosity_error() # ignore diffusers warnings
from typing import *
from torch.nn.parallel import DistributedDataParallel
from accelerate.optimizer import AcceleratedOptimizer
from accelerate.scheduler import AcceleratedScheduler
from accelerate.data_loader import DataLoaderShard
import os
import argparse
import types
import logging
import math
import gc
from tqdm import tqdm
import wandb
import torch
import accelerate
from accelerate import Accelerator
from accelerate.logging import get_logger as get_accelerate_logger
from accelerate import DataLoaderConfiguration, DeepSpeedPlugin
from lpips import LPIPS
from src.options import opt_dict
from src.data import GObjaverseParquetDataset, ParquetChunkDataSource, MultiEpochsChunkedDataLoader, yield_forever
from src.models import GSAutoencoderKL, GSRecon, get_optimizer, get_lr_scheduler
import src.utils.util as util
import src.utils.vis_util as vis_util
from extensions.diffusers_diffsplat import MyEMAModel
def main():
PROJECT_NAME = "GSVAE"
parser = argparse.ArgumentParser(
description="Train a AutoencoderKL for 3DGS properties"
)
parser.add_argument(
"--config_file",
type=str,
required=True,
help="Path to the config file"
)
parser.add_argument(
"--tag",
type=str,
required=True,
help="Tag that refers to the current experiment"
)
parser.add_argument(
"--output_dir",
type=str,
default="out",
help="Path to the output directory"
)
parser.add_argument(
"--hdfs_dir",
type=str,
default=None,
help="Path to the HDFS directory to save checkpoints"
)
parser.add_argument(
"--wandb_token_path",
type=str,
default="wandb/token",
help="Path to the WandB login token"
)
parser.add_argument(
"--resume_from_iter",
type=int,
default=None,
help="The iteration to load the checkpoint from"
)
parser.add_argument(
"--seed",
type=int,
default=0,
help="Seed for the PRNG"
)
parser.add_argument(
"--offline_wandb",
action="store_true",
help="Use offline WandB for experiment tracking"
)
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="The max iteration step for training"
)
parser.add_argument(
"--max_val_steps",
type=int,
default=5,
help="The max iteration step for validation"
)
parser.add_argument(
"--num_workers",
type=int,
default=32,
help="The number of processed spawned by the batch provider"
)
parser.add_argument(
"--pin_memory",
action="store_true",
help="Pin memory for the data loader"
)
parser.add_argument(
"--use_ema",
action="store_true",
help="Use EMA model for training"
)
parser.add_argument(
"--scale_lr",
action="store_true",
help="Scale lr with total batch size (base batch size: 256)"
)
parser.add_argument(
"--max_grad_norm",
type=float,
default=1.,
help="Max gradient norm for gradient clipping"
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass"
)
parser.add_argument(
"--mixed_precision",
type=str,
default="fp16",
choices=["no", "fp16", "bf16"],
help="Type of mixed precision training"
)
parser.add_argument(
"--allow_tf32",
action="store_true",
help="Enable TF32 for faster training on Ampere GPUs"
)
parser.add_argument(
"--use_deepspeed",
action="store_true",
help="Use DeepSpeed for training"
)
parser.add_argument(
"--zero_stage",
type=int,
default=1,
choices=[1, 2, 3], # https://huggingface.co/docs/accelerate/usage_guides/deepspeed
help="ZeRO stage type for DeepSpeed"
)
parser.add_argument(
"--load_pretrained_gsrecon",
type=str,
default="gsrecon_gobj265k_cnp_even4",
help="Tag of a pretrained GSRecon in this project"
)
parser.add_argument(
"--load_pretrained_gsrecon_ckpt",
type=int,
default=-1,
help="Iteration of the pretrained GSRecon checkpoint"
)
parser.add_argument(
"--load_pretrained_model",
type=str,
default=None,
help="Tag of the model pretrained in this project"
)
parser.add_argument(
"--load_pretrained_model_ckpt",
type=int,
default=-1,
help="Iteration of the pretrained model checkpoint"
)
# Parse the arguments
args, extras = parser.parse_known_args()
# Parse the config file
configs = util.get_configs(args.config_file, extras) # change yaml configs by `extras`
# Parse the option dict
opt = opt_dict[configs["opt_type"]]
if "opt" in configs:
for k, v in configs["opt"].items():
setattr(opt, k, v)
opt.__post_init__()
# Create an experiment directory using the `tag`
exp_dir = os.path.join(args.output_dir, args.tag)
ckpt_dir = os.path.join(exp_dir, "checkpoints")
os.makedirs(ckpt_dir, exist_ok=True)
if args.hdfs_dir is not None:
args.project_hdfs_dir = args.hdfs_dir
args.hdfs_dir = os.path.join(args.hdfs_dir, args.tag)
os.system(f"hdfs dfs -mkdir -p {args.hdfs_dir}")
# Initialize the logger
logging.basicConfig(
format="%(asctime)s - %(message)s",
datefmt="%Y/%m/%d %H:%M:%S",
level=logging.INFO
)
logger = get_accelerate_logger(__name__, log_level="INFO")
file_handler = logging.FileHandler(os.path.join(exp_dir, "log.txt")) # output to file
file_handler.setFormatter(logging.Formatter(
fmt="%(asctime)s - %(message)s",
datefmt="%Y/%m/%d %H:%M:%S"
))
logger.logger.addHandler(file_handler)
logger.logger.propagate = True # propagate to the root logger (console)
# Set DeepSpeed config
if args.use_deepspeed:
deepspeed_plugin = DeepSpeedPlugin(
gradient_accumulation_steps=args.gradient_accumulation_steps,
gradient_clipping=args.max_grad_norm,
zero_stage=int(args.zero_stage),
offload_optimizer_device="cpu", # hard-coded here, TODO: make it configurable
)
else:
deepspeed_plugin = None
# Initialize the accelerator
accelerator = Accelerator(
project_dir=exp_dir,
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
split_batches=False, # batch size per GPU
dataloader_config=DataLoaderConfiguration(non_blocking=args.pin_memory),
deepspeed_plugin=deepspeed_plugin,
)
logger.info(f"Accelerator state:\n{accelerator.state}\n")
# Set the random seed
if args.seed >= 0:
accelerate.utils.set_seed(args.seed)
logger.info(f"You have chosen to seed([{args.seed}]) the experiment [{args.tag}]\n")
# Enable TF32 for faster training on Ampere GPUs
if args.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
# Prepare dataset
if accelerator.is_local_main_process:
if not os.path.exists("/tmp/test_dataset"):
os.system(opt.dataset_setup_script)
accelerator.wait_for_everyone() # other processes wait for the main process
# Load the training and validation dataset
assert opt.file_dir_train is not None and opt.file_name_train is not None and \
opt.file_dir_test is not None and opt.file_name_test is not None
train_dataset = GObjaverseParquetDataset(
data_source=ParquetChunkDataSource(opt.file_dir_train, opt.file_name_train),
shuffle=True,
shuffle_buffer_size=-1, # `-1`: not shuffle actually
chunks_queue_max_size=1, # number of preloading chunks
# GObjaverse
opt=opt,
training=True,
)
val_dataset = GObjaverseParquetDataset(
data_source=ParquetChunkDataSource(opt.file_dir_test, opt.file_name_test),
shuffle=True, # shuffle for various visualization
shuffle_buffer_size=-1, # `-1`: not shuffle actually
chunks_queue_max_size=1, # number of preloading chunks
# GObjaverse
opt=opt,
training=False,
)
train_loader = MultiEpochsChunkedDataLoader(
train_dataset,
batch_size=configs["train"]["batch_size_per_gpu"],
num_workers=args.num_workers,
drop_last=True,
pin_memory=args.pin_memory,
)
val_loader = MultiEpochsChunkedDataLoader(
val_dataset,
batch_size=configs["val"]["batch_size_per_gpu"],
num_workers=args.num_workers,
drop_last=True,
pin_memory=args.pin_memory,
)
logger.info(f"Load [{len(train_dataset)}] training samples and [{len(val_dataset)}] validation samples\n")
# Compute the effective batch size and scale learning rate
total_batch_size = configs["train"]["batch_size_per_gpu"] * \
accelerator.num_processes * args.gradient_accumulation_steps
configs["train"]["total_batch_size"] = total_batch_size
if args.scale_lr:
configs["optimizer"]["lr"] *= (total_batch_size / 256)
configs["lr_scheduler"]["max_lr"] = configs["optimizer"]["lr"]
# LPIPS loss
if accelerator.is_main_process:
_ = LPIPS(net="vgg")
del _
accelerator.wait_for_everyone() # wait for pretrained backbone weights to be downloaded
lpips_loss = LPIPS(net="vgg").to(accelerator.device)
lpips_loss = lpips_loss.requires_grad_(False)
lpips_loss.eval()
# GSRecon
gsrecon = GSRecon(opt).to(accelerator.device)
gsrecon = gsrecon.requires_grad_(False)
gsrecon = gsrecon.eval()
# For DeepSpeed bug: model inputs could be `torch.nn.Module` (e.g., `lpips_loss`)
def is_floating_point(self):
return True
lpips_loss.is_floating_point = types.MethodType(is_floating_point, lpips_loss)
gsrecon.is_floating_point = types.MethodType(is_floating_point, gsrecon)
# Initialize the model, optimizer and lr scheduler
model = GSAutoencoderKL(opt)
if opt.use_tiny_decoder:
model.vae.requires_grad_(False)
else:
model.tiny_decoder.requires_grad_(False)
params_to_optimize = filter(lambda p: p.requires_grad, model.parameters())
optimizer = get_optimizer(params=params_to_optimize, **configs["optimizer"])
configs["lr_scheduler"]["total_steps"] = configs["train"]["epochs"] * math.ceil(
len(train_loader) // accelerator.num_processes / args.gradient_accumulation_steps) # only account updated steps
configs["lr_scheduler"]["total_steps"] *= accelerator.num_processes # for lr scheduler setting
if "num_warmup_steps" in configs["lr_scheduler"]:
configs["lr_scheduler"]["num_warmup_steps"] *= accelerator.num_processes # for lr scheduler setting
lr_scheduler = get_lr_scheduler(optimizer=optimizer, **configs["lr_scheduler"])
configs["lr_scheduler"]["total_steps"] //= accelerator.num_processes # reset for multi-gpu
if "num_warmup_steps" in configs["lr_scheduler"]:
configs["lr_scheduler"]["num_warmup_steps"] //= accelerator.num_processes # reset for multi-gpu
# Load pretrained reconstruction and vae models
if args.load_pretrained_model is not None:
logger.info(f"Load GSVAE checkpoint [{args.load_pretrained_model}] from iteration [{args.load_pretrained_model_ckpt:06d}]\n")
model = util.load_ckpt(
os.path.join(args.output_dir, args.load_pretrained_model, "checkpoints"),
args.load_pretrained_model_ckpt,
None if args.hdfs_dir is None else os.path.join(args.project_hdfs_dir, args.load_pretrained_model),
model, accelerator
)
logger.info(f"Load GSRecon checkpoint [{args.load_pretrained_gsrecon}] from iteration [{args.load_pretrained_gsrecon_ckpt:06d}]\n")
gsrecon = util.load_ckpt(
os.path.join(args.output_dir, args.load_pretrained_gsrecon, "checkpoints"),
args.load_pretrained_gsrecon_ckpt,
None if args.hdfs_dir is None else os.path.join(args.project_hdfs_dir, args.load_pretrained_gsrecon),
gsrecon, accelerator
)
# Initialize the EMA model to save moving average states
if args.use_ema:
logger.info("Use exponential moving average (EMA) for model parameters\n")
ema_states = MyEMAModel(
model.parameters(),
**configs["train"]["ema_kwargs"]
)
ema_states.to(accelerator.device)
# Prepare everything with `accelerator`
model, optimizer, lr_scheduler, train_loader, val_loader = accelerator.prepare(
model, optimizer, lr_scheduler, train_loader, val_loader
)
# Set classes explicitly for everything
model: DistributedDataParallel
optimizer: AcceleratedOptimizer
lr_scheduler: AcceleratedScheduler
train_loader: DataLoaderShard
val_loader: DataLoaderShard
# Cast input dataset to the appropriate dtype
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Training configs after distribution and accumulation setup
updated_steps_per_epoch = math.ceil(len(train_loader) / args.gradient_accumulation_steps)
total_updated_steps = configs["lr_scheduler"]["total_steps"]
if args.max_train_steps is None:
args.max_train_steps = total_updated_steps
assert configs["train"]["epochs"] * updated_steps_per_epoch == total_updated_steps
logger.info(f"Total batch size: [{total_batch_size}]")
logger.info(f"Learning rate: [{configs['optimizer']['lr']}]")
logger.info(f"Gradient Accumulation steps: [{args.gradient_accumulation_steps}]")
logger.info(f"Total epochs: [{configs['train']['epochs']}]")
logger.info(f"Total steps: [{total_updated_steps}]")
logger.info(f"Steps for updating per epoch: [{updated_steps_per_epoch}]")
logger.info(f"Steps for validation: [{len(val_loader)}]\n")
# (Optional) Load checkpoint
global_update_step = 0
if args.resume_from_iter is not None:
logger.info(f"Load checkpoint from iteration [{args.resume_from_iter}]\n")
# Download from HDFS
if not os.path.exists(os.path.join(ckpt_dir, f'{args.resume_from_iter:06d}')):
args.resume_from_iter = util.load_ckpt(
ckpt_dir,
args.resume_from_iter,
args.hdfs_dir,
None, # `None`: not load model ckpt here
accelerator, # manage the process states
)
# Load everything
accelerator.load_state(os.path.join(ckpt_dir, f"{args.resume_from_iter:06d}")) # torch < 2.4.0 here for `weights_only=False`
if args.use_ema:
ema_states.load_state_dict(torch.load(
os.path.join(ckpt_dir, f"{args.resume_from_iter:06d}", "ema_states.pth"),
map_location=accelerator.device
))
global_update_step = int(args.resume_from_iter)
# Save all experimental parameters and model architecture of this run to a file (args and configs)
if accelerator.is_main_process:
exp_params = util.save_experiment_params(args, configs, opt, exp_dir)
util.save_model_architecture(accelerator.unwrap_model(model), exp_dir)
# WandB logger
if accelerator.is_main_process:
if args.offline_wandb:
os.environ["WANDB_MODE"] = "offline"
with open(args.wandb_token_path, "r") as f:
os.environ["WANDB_API_KEY"] = f.read().strip()
wandb.init(
project=PROJECT_NAME, name=args.tag,
config=exp_params, dir=exp_dir,
resume=True
)
# Wandb artifact for logging experiment information
arti_exp_info = wandb.Artifact(args.tag, type="exp_info")
arti_exp_info.add_file(os.path.join(exp_dir, "params.yaml"))
arti_exp_info.add_file(os.path.join(exp_dir, "model.txt"))
arti_exp_info.add_file(os.path.join(exp_dir, "log.txt")) # only save the log before training
wandb.log_artifact(arti_exp_info)
# Start training
logger.logger.propagate = False # not propagate to the root logger (console)
progress_bar = tqdm(
range(total_updated_steps),
initial=global_update_step,
desc="Training",
ncols=125,
disable=not accelerator.is_main_process
)
for batch in yield_forever(train_loader):
if global_update_step == args.max_train_steps:
progress_bar.close()
logger.logger.propagate = True # propagate to the root logger (console)
if accelerator.is_main_process:
wandb.finish()
logger.info("Training finished!\n")
return
model.train()
with accelerator.accumulate(model):
outputs = model(batch, lpips_loss, gsrecon, step=global_update_step+1, dtype=weight_dtype,
use_tiny_decoder=opt.use_tiny_decoder) # `step` starts from 1
psnr = outputs["psnr"]
ssim = outputs["ssim"]
lpips = outputs["lpips"]
kl = outputs["kl"]
loss = outputs["loss"]
if "coord_mse" in outputs:
coord_mse = outputs["coord_mse"]
else:
coord_mse = None
if "normal_cosim" in outputs:
normal_cosim = outputs["normal_cosim"]
else:
normal_cosim = None
# Backpropagate
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
# Gather the losses across all processes for logging (if we use distributed training)
psnr = accelerator.gather(psnr.detach()).mean()
ssim = accelerator.gather(ssim.detach()).mean()
lpips = accelerator.gather(lpips.detach()).mean()
kl = accelerator.gather(kl.detach()).mean()
loss = accelerator.gather(loss.detach()).mean()
if coord_mse is not None:
coord_mse = accelerator.gather(coord_mse.detach()).mean()
if normal_cosim is not None:
normal_cosim = accelerator.gather(normal_cosim.detach()).mean()
logs = {
"psnr": psnr.item(),
"ssim": ssim.item(),
"lpips": lpips.item(),
"loss": loss.item(),
"lr": lr_scheduler.get_last_lr()[0]
}
if args.use_ema:
ema_states.step(model.parameters())
logs.update({"ema": ema_states.cur_decay_value})
progress_bar.set_postfix(**logs)
progress_bar.update(1)
global_update_step += 1
logger.info(
f"[{global_update_step:06d} / {total_updated_steps:06d}] " +
f"psnr: {logs['psnr']:.4f}, ssim: {logs['ssim']:.4f}, lpips: {logs['lpips']:.4f}, " +
f"loss: {logs['loss']:.4f}, lr: {logs['lr']:.2e}" +
f", ema: {logs['ema']:.4f}" if args.use_ema else ""
)
# Log the training progress
if global_update_step % configs["train"]["log_freq"] == 0 or global_update_step == 1 \
or global_update_step % updated_steps_per_epoch == 0: # last step of an epoch
if accelerator.is_main_process:
wandb.log({
"training/psnr": psnr.item(),
"training/ssim": ssim.item(),
"training/lpips": lpips.item(),
"training/kl": kl.item(),
"training/loss": loss.item(),
"training/lr": lr_scheduler.get_last_lr()[0]
}, step=global_update_step)
if coord_mse is not None:
wandb.log({
"training/coord_mse": coord_mse.item()
}, step=global_update_step)
if normal_cosim is not None:
wandb.log({
"training/normal_cosim": normal_cosim.item()
}, step=global_update_step)
if args.use_ema:
wandb.log({
"training/ema": ema_states.cur_decay_value
}, step=global_update_step)
# Save checkpoint
if (global_update_step % configs["train"]["save_freq"] == 0 # 1. every `save_freq` steps
or global_update_step % (configs["train"]["save_freq_epoch"] * updated_steps_per_epoch) == 0 # 2. every `save_freq_epoch` epochs
or global_update_step == total_updated_steps): # 3. last step of an epoch
gc.collect()
if accelerator.distributed_type == accelerate.utils.DistributedType.DEEPSPEED:
# DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues
accelerator.save_state(os.path.join(ckpt_dir, f"{global_update_step:06d}"))
elif accelerator.is_main_process:
accelerator.save_state(os.path.join(ckpt_dir, f"{global_update_step:06d}"))
accelerator.wait_for_everyone() # ensure all processes have finished saving
if accelerator.is_main_process:
if args.use_ema:
torch.save(ema_states.state_dict(),
os.path.join(ckpt_dir, f"{global_update_step:06d}", "ema_states.pth"))
if args.hdfs_dir is not None:
util.save_ckpt(ckpt_dir, global_update_step, args.hdfs_dir)
gc.collect()
# Evaluate on the validation set
if (global_update_step == 1
or (global_update_step % configs["train"]["early_eval_freq"] == 0 and
global_update_step < configs["train"]["early_eval"]) # 1. more frequently at the beginning
or global_update_step % configs["train"]["eval_freq"] == 0 # 2. every `eval_freq` steps
or global_update_step % (configs["train"]["eval_freq_epoch"] * updated_steps_per_epoch) == 0 # 3. every `eval_freq_epoch` epochs
or global_update_step == total_updated_steps): # 4. last step of an epoch
torch.cuda.empty_cache()
gc.collect()
# Use EMA parameters for evaluation
if args.use_ema:
# Store the UNet parameters temporarily and load the EMA parameters to perform inference
ema_states.store(model.parameters())
ema_states.copy_to(model.parameters())
with torch.no_grad():
with torch.autocast("cuda", torch.bfloat16):
model.eval()
all_val_matrics, val_steps = {}, 0
val_progress_bar = tqdm(
range(len(val_loader)) if args.max_val_steps is None \
else range(args.max_val_steps),
desc="Validation",
ncols=125,
disable=not accelerator.is_main_process
)
for val_batch in val_loader:
val_outputs = model(val_batch, lpips_loss, gsrecon, step=global_update_step, dtype=weight_dtype,
use_tiny_decoder=opt.use_tiny_decoder)
val_psnr = val_outputs["psnr"]
val_ssim = val_outputs["ssim"]
val_lpips = val_outputs["lpips"]
val_kl = val_outputs["kl"]
val_loss = val_outputs["loss"]
if "coord_mse" in val_outputs:
val_coord_mse = val_outputs["coord_mse"]
else:
val_coord_mse = None
if "normal_cosim" in val_outputs:
val_normal_cosim = val_outputs["normal_cosim"]
else:
val_normal_cosim = None
val_psnr = accelerator.gather_for_metrics(val_psnr).mean()
val_ssim = accelerator.gather_for_metrics(val_ssim).mean()
val_lpips = accelerator.gather_for_metrics(val_lpips).mean()
val_kl = accelerator.gather_for_metrics(val_kl).mean()
val_loss = accelerator.gather_for_metrics(val_loss).mean()
if val_coord_mse is not None:
val_coord_mse = accelerator.gather_for_metrics(val_coord_mse).mean()
if val_normal_cosim is not None:
val_normal_cosim = accelerator.gather_for_metrics(val_normal_cosim).mean()
val_logs = {
"psnr": val_psnr.item(),
"ssim": val_ssim.item(),
"lpips": val_lpips.item(),
"loss": val_loss.item()
}
val_progress_bar.set_postfix(**val_logs)
val_progress_bar.update(1)
val_steps += 1
all_val_matrics.setdefault("psnr", []).append(val_psnr)
all_val_matrics.setdefault("ssim", []).append(val_ssim)
all_val_matrics.setdefault("lpips", []).append(val_lpips)
all_val_matrics.setdefault("kl", []).append(val_kl)
all_val_matrics.setdefault("loss", []).append(val_loss)
if val_coord_mse is not None:
all_val_matrics.setdefault("coord_mse", []).append(val_coord_mse)
if val_normal_cosim is not None:
all_val_matrics.setdefault("normal_cosim", []).append(val_normal_cosim)
if args.max_val_steps is not None and val_steps == args.max_val_steps:
break
torch.cuda.empty_cache()
gc.collect()
val_progress_bar.close()
if args.use_ema:
# Switch back to the original model parameters
ema_states.restore(model.parameters())
for k, v in all_val_matrics.items():
all_val_matrics[k] = torch.tensor(v).mean()
logger.info(
f"Eval [{global_update_step:06d} / {total_updated_steps:06d}] " +
f"psnr: {all_val_matrics['psnr'].item():.4f}, " +
f"ssim: {all_val_matrics['ssim'].item():.4f}, " +
f"lpips: {all_val_matrics['lpips'].item():.4f}, " +
f"kl: {all_val_matrics['kl'].item():.4f}, " +
f"loss: {all_val_matrics['loss'].item():.4f}\n"
)
if accelerator.is_main_process:
wandb.log({
"validation/psnr": all_val_matrics["psnr"].item(),
"validation/ssim": all_val_matrics["ssim"].item(),
"validation/lpips": all_val_matrics["lpips"].item(),
"validation/kl": all_val_matrics["kl"].item(),
"validation/loss": all_val_matrics["loss"].item()
}, step=global_update_step)
if "coord_mse" in all_val_matrics:
wandb.log({
"validation/coord_mse": all_val_matrics["coord_mse"].item()
}, step=global_update_step)
if "normal_cosim" in all_val_matrics:
wandb.log({
"validation/normal_cosim": all_val_matrics["normal_cosim"].item()
}, step=global_update_step)
# Visualize rendering
wandb.log({
"images/training": vis_util.wandb_mvimage_log(outputs)
}, step=global_update_step)
wandb.log({
"images/validation": vis_util.wandb_mvimage_log(val_outputs)
}, step=global_update_step)
torch.cuda.empty_cache()
gc.collect()
if __name__ == "__main__":
main()