import copy
import functools
import json
import os
import traceback
from pathlib import Path
from pdb import set_trace as st

import blobfile as bf
import imageio
import numpy as np
import torch as th
import torch.distributed as dist
import torchvision
import webdataset as wds
from einops import rearrange
from PIL import Image
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.optim import AdamW
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from guided_diffusion import dist_util, logger
from guided_diffusion.fp16_util import MixedPrecisionTrainer
from guided_diffusion.nn import update_ema
from guided_diffusion.resample import LossAwareSampler, UniformSampler
from guided_diffusion.train_util import (calc_average_loss,
                                         find_ema_checkpoint,
                                         find_resume_checkpoint,
                                         get_blob_logdir, log_rec3d_loss_dict,
                                         parse_resume_step_from_filename)

from .camera_utils import LookAtPoseSampler, FOV_to_intrinsics
from .train_util import TrainLoop3DRec
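
# This module stacks a hierarchy of reconstruction training loops on top of
# TrainLoop3DRec, each subclass adding one capability: novel-view supervision
# (TrainLoop3DRecNV), patch-based ray rendering (TrainLoop3DRecNVPatch), a
# merged single-forward variant with view rolling
# (TrainLoop3DRecNVPatchSingleForward), a joint novel+canonical multi-view
# variant (TrainLoop3DRecNVPatchSingleForwardMV), and an adversarial variant
# that trains a discriminator alongside the reconstructor
# (TrainLoop3DRecNVPatchSingleForwardMVAdvLoss).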


class TrainLoop3DRecNV(TrainLoop3DRec):
    # supervise the novel-view reconstruction during training
    def __init__(self,
                 *,
                 rec_model,
                 loss_class,
                 data,
                 eval_data,
                 batch_size,
                 microbatch,
                 lr,
                 ema_rate,
                 log_interval,
                 eval_interval,
                 save_interval,
                 resume_checkpoint,
                 use_fp16=False,
                 fp16_scale_growth=0.001,
                 weight_decay=0,
                 lr_anneal_steps=0,
                 iterations=10001,
                 load_submodule_name='',
                 ignore_resume_opt=False,
                 model_name='rec',
                 use_amp=False,
                 **kwargs):
        super().__init__(rec_model=rec_model,
                         loss_class=loss_class,
                         data=data,
                         eval_data=eval_data,
                         batch_size=batch_size,
                         microbatch=microbatch,
                         lr=lr,
                         ema_rate=ema_rate,
                         log_interval=log_interval,
                         eval_interval=eval_interval,
                         save_interval=save_interval,
                         resume_checkpoint=resume_checkpoint,
                         use_fp16=use_fp16,
                         fp16_scale_growth=fp16_scale_growth,
                         weight_decay=weight_decay,
                         lr_anneal_steps=lr_anneal_steps,
                         iterations=iterations,
                         load_submodule_name=load_submodule_name,
                         ignore_resume_opt=ignore_resume_opt,
                         model_name=model_name,
                         use_amp=use_amp,
                         **kwargs)
        self.rec_cano = True  # also supervise the canonical (input) view
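
    # Batch convention assumed throughout this file: every canonical-view
    # tensor ('img', 'depth', 'depth_mask', camera 'c', ...) has a novel-view
    # counterpart stored under the same key with an 'nv_' prefix ('nv_img',
    # 'nv_c', ...). forward_backward below splits the targets by checking
    # k[:2] == 'nv'.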

    def forward_backward(self, batch, *args, **kwargs):
        self.mp_trainer_rec.zero_grad()
        batch_size = batch['img_to_encoder'].shape[0]

        for i in range(0, batch_size, self.microbatch):

            micro = {
                k: v[i:i + self.microbatch].to(dist_util.dev())
                for k, v in batch.items()
            }

            # TODO: concat novel views; add self-reconstruction and a
            # patch-based loss in a later version. Verify novel-view
            # prediction first.

            # wrap the forward pass in amp autocast
            with th.autocast(device_type='cuda',
                             dtype=th.float16,
                             enabled=self.mp_trainer_rec.use_amp):

                target_nvs = {}
                target_cano = {}

                latent = self.rec_model(img=micro['img_to_encoder'],
                                        behaviour='enc_dec_wo_triplane')

                pred = self.rec_model(
                    latent=latent,
                    c=micro['nv_c'],  # render the novel view here
                    behaviour='triplane_dec')

                for k, v in micro.items():
                    if k[:2] == 'nv':
                        orig_key = k.replace('nv_', '')
                        target_nvs[orig_key] = v
                        target_cano[orig_key] = micro[orig_key]

                with self.rec_model.no_sync():  # type: ignore
                    loss, loss_dict, fg_mask = self.loss_class(
                        pred,
                        target_nvs,
                        step=self.step + self.resume_step,
                        test_mode=False,
                        return_fg_mask=True,
                        conf_sigma_l1=None,
                        conf_sigma_percl=None)
                    log_rec3d_loss_dict(loss_dict)

                if self.rec_cano:

                    pred_cano = self.rec_model(latent=latent,
                                               c=micro['c'],
                                               behaviour='triplane_dec')

                    with self.rec_model.no_sync():  # type: ignore

                        fg_mask = target_cano['depth_mask'].unsqueeze(
                            1).repeat_interleave(3, 1).float()

                        loss_cano, loss_cano_dict = self.loss_class.calc_2d_rec_loss(
                            pred_cano['image_raw'],
                            target_cano['img'],
                            fg_mask,
                            step=self.step + self.resume_step,
                            test_mode=False,
                        )

                        loss = loss + loss_cano

                        # prefix the canonical-view losses to avoid key clashes
                        log_rec3d_loss_dict({
                            f'cano_{k}': v
                            for k, v in loss_cano_dict.items()
                        })

            self.mp_trainer_rec.backward(loss)

            if dist_util.get_rank() == 0 and self.step % 500 == 0:
                if self.rec_cano:
                    self.log_img(micro, pred, pred_cano)
                else:
                    self.log_img(micro, pred, None)

    def log_img(self, micro, pred, pred_cano):

        def norm_depth(pred_depth):  # map depth to [-1, 1] for visualisation
            pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() -
                                                            pred_depth.min())
            return -(pred_depth * 2 - 1)  # negate so near surfaces are bright

        pred_img = pred['image_raw']
        gt_img = micro['img']

        gt_depth = micro['depth']
        if gt_depth.ndim == 3:
            gt_depth = gt_depth.unsqueeze(1)
        gt_depth = norm_depth(gt_depth)

        fg_mask = pred['image_mask'] * 2 - 1  # [0, 1] -> [-1, 1]
        input_fg_mask = pred_cano['image_mask'] * 2 - 1  # [0, 1] -> [-1, 1]
        if 'image_depth' in pred:
            pred_depth = norm_depth(pred['image_depth'])
            pred_nv_depth = norm_depth(pred_cano['image_depth'])
        else:
            pred_depth = th.zeros_like(gt_depth)
            pred_nv_depth = th.zeros_like(gt_depth)

        if 'image_sr' in pred:

            if pred['image_sr'].shape[-1] == 512:
                pred_img = th.cat([self.pool_512(pred_img), pred['image_sr']],
                                  dim=-1)
                gt_img = th.cat([self.pool_512(micro['img']), micro['img_sr']],
                                dim=-1)
                pred_depth = self.pool_512(pred_depth)
                gt_depth = self.pool_512(gt_depth)

            elif pred['image_sr'].shape[-1] == 256:
                pred_img = th.cat([self.pool_256(pred_img), pred['image_sr']],
                                  dim=-1)
                gt_img = th.cat([self.pool_256(micro['img']), micro['img_sr']],
                                dim=-1)
                pred_depth = self.pool_256(pred_depth)
                gt_depth = self.pool_256(gt_depth)

            else:
                pred_img = th.cat([self.pool_128(pred_img), pred['image_sr']],
                                  dim=-1)
                gt_img = th.cat([self.pool_128(micro['img']), micro['img_sr']],
                                dim=-1)
                gt_depth = self.pool_128(gt_depth)
                pred_depth = self.pool_128(pred_depth)

        else:
            gt_img = self.pool_64(gt_img)
            gt_depth = self.pool_64(gt_depth)

        pred_vis = th.cat([
            pred_img,
            pred_depth.repeat_interleave(3, dim=1),
            fg_mask.repeat_interleave(3, dim=1),
        ],
                          dim=-1)  # B, 3, H, W

        pred_vis_nv = th.cat([
            pred_cano['image_raw'],
            pred_nv_depth.repeat_interleave(3, dim=1),
            input_fg_mask.repeat_interleave(3, dim=1),
        ],
                             dim=-1)  # B, 3, H, W
        pred_vis = th.cat([pred_vis, pred_vis_nv], dim=-2)  # cat in H dim

        gt_vis = th.cat([
            gt_img,
            gt_depth.repeat_interleave(3, dim=1),
            th.zeros_like(gt_img)  # placeholder column
        ],
                        dim=-1)  # TODO, depth fails to load; range [0, 1]

        if 'conf_sigma' in pred:
            gt_vis = th.cat([gt_vis, fg_mask], dim=-1)  # placeholder

        vis = th.cat([gt_vis, pred_vis], dim=-2)
        vis_tensor = torchvision.utils.make_grid(vis, nrow=vis.shape[-1] //
                                                 64)  # C, H, W

        torchvision.utils.save_image(
            vis_tensor,
            f'{logger.get_dir()}/{self.step+self.resume_step}.jpg',
            value_range=(-1, 1),
            normalize=True)

        logger.log('log vis to: ',
                   f'{logger.get_dir()}/{self.step+self.resume_step}.jpg')
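
# Illustrative only: a minimal sketch (hypothetical shapes, not called by the
# training loops) of how patch supervision below keeps rendered patches and
# ground truth aligned. The eg3d_model ray sampler returns one bbox
# (top, left, height, width) per sample; cropping the full-resolution GT with
# the same bbox makes the rendered patch and its target cover identical
# pixels.
def _demo_patch_crop():
    gt = th.rand(1, 3, 128, 128)  # stand-in full-resolution GT image
    top, left, height, width = 32, 48, 64, 64  # a sampled ray bbox
    patch = torchvision.transforms.functional.crop(gt, top, left, height,
                                                   width)
    assert patch.shape == (1, 3, 64, 64)
    return patch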


class TrainLoop3DRecNVPatch(TrainLoop3DRecNV):
    # add patch-based rendering supervision
    def __init__(self,
                 *,
                 rec_model,
                 loss_class,
                 data,
                 eval_data,
                 batch_size,
                 microbatch,
                 lr,
                 ema_rate,
                 log_interval,
                 eval_interval,
                 save_interval,
                 resume_checkpoint,
                 use_fp16=False,
                 fp16_scale_growth=0.001,
                 weight_decay=0,
                 lr_anneal_steps=0,
                 iterations=10001,
                 load_submodule_name='',
                 ignore_resume_opt=False,
                 model_name='rec',
                 use_amp=False,
                 **kwargs):
        super().__init__(rec_model=rec_model,
                         loss_class=loss_class,
                         data=data,
                         eval_data=eval_data,
                         batch_size=batch_size,
                         microbatch=microbatch,
                         lr=lr,
                         ema_rate=ema_rate,
                         log_interval=log_interval,
                         eval_interval=eval_interval,
                         save_interval=save_interval,
                         resume_checkpoint=resume_checkpoint,
                         use_fp16=use_fp16,
                         fp16_scale_growth=fp16_scale_growth,
                         weight_decay=weight_decay,
                         lr_anneal_steps=lr_anneal_steps,
                         iterations=iterations,
                         load_submodule_name=load_submodule_name,
                         ignore_resume_opt=ignore_resume_opt,
                         model_name=model_name,
                         use_amp=use_amp,
                         **kwargs)
        # the renderer used to sample ray patches
        self.eg3d_model = self.rec_model.module.decoder.triplane_decoder  # type: ignore
        self.rec_cano = True

    def forward_backward(self, batch, *args, **kwargs):
        # patch-based ray sampling version of forward_backward
        self.mp_trainer_rec.zero_grad()
        batch_size = batch['img_to_encoder'].shape[0]

        for i in range(0, batch_size, self.microbatch):

            micro = {
                k: v[i:i + self.microbatch].to(dist_util.dev())
                for k, v in batch.items()
            }

            # ! sample a rendering patch (rays o / dir) within the fg bbox
            target = {
                **self.eg3d_model(
                    c=micro['nv_c'],  # type: ignore
                    ws=None,
                    planes=None,
                    sample_ray_only=True,
                    fg_bbox=micro['nv_bbox']),
            }

            patch_rendering_resolution = self.eg3d_model.rendering_kwargs[
                'patch_rendering_resolution']  # type: ignore
            cropped_target = {
                k:
                th.empty_like(v)
                [..., :patch_rendering_resolution, :patch_rendering_resolution]
                if k not in [
                    'ins_idx', 'img_to_encoder', 'img_sr', 'nv_img_to_encoder',
                    'nv_img_sr', 'c'
                ] else v
                for k, v in micro.items()
            }

            # crop the novel-view GT according to the sampled uv bbox
            for j in range(micro['img'].shape[0]):
                top, left, height, width = target['ray_bboxes'][
                    j]  # list of tuples
                for key in ('img', 'depth_mask', 'depth'):  # type: ignore
                    cropped_target[f'{key}'][  # ! no nv_ prefix here
                        j:j + 1] = torchvision.transforms.functional.crop(
                            micro[f'nv_{key}'][j:j + 1], top, left, height,
                            width)

            # wrap the forward pass in amp autocast
            with th.autocast(device_type='cuda',
                             dtype=th.float16,
                             enabled=self.mp_trainer_rec.use_amp):

                latent = self.rec_model(img=micro['img_to_encoder'],
                                        behaviour='enc_dec_wo_triplane')

                pred_nv = self.rec_model(
                    latent=latent,
                    c=micro['nv_c'],  # render the novel view here
                    behaviour='triplane_dec',
                    ray_origins=target['ray_origins'],
                    ray_directions=target['ray_directions'],
                )

                with self.rec_model.no_sync():  # type: ignore
                    loss, loss_dict, _ = self.loss_class(pred_nv,
                                                         cropped_target,
                                                         step=self.step +
                                                         self.resume_step,
                                                         test_mode=False,
                                                         return_fg_mask=True,
                                                         conf_sigma_l1=None,
                                                         conf_sigma_percl=None)
                    log_rec3d_loss_dict(loss_dict)

                if self.rec_cano:

                    cano_target = {
                        **self.eg3d_model(
                            c=micro['c'],  # type: ignore
                            ws=None,
                            planes=None,
                            sample_ray_only=True,
                            fg_bbox=micro['bbox']),  # rays o / dir
                    }

                    cano_cropped_target = {
                        k: th.empty_like(v)
                        for k, v in cropped_target.items()
                    }

                    for j in range(micro['img'].shape[0]):
                        top, left, height, width = cano_target['ray_bboxes'][
                            j]  # list of tuples
                        for key in ('img', 'depth_mask',
                                    'depth'):  # type: ignore
                            cano_cropped_target[key][
                                j:j +
                                1] = torchvision.transforms.functional.crop(
                                    micro[key][j:j + 1], top, left, height,
                                    width)

                    pred_cano = self.rec_model(
                        latent=latent,
                        c=micro['c'],
                        behaviour='triplane_dec',
                        ray_origins=cano_target['ray_origins'],
                        ray_directions=cano_target['ray_directions'],
                    )

                    with self.rec_model.no_sync():  # type: ignore

                        fg_mask = cano_cropped_target['depth_mask'].unsqueeze(
                            1).repeat_interleave(3, 1).float()

                        loss_cano, loss_cano_dict = self.loss_class.calc_2d_rec_loss(
                            pred_cano['image_raw'],
                            cano_cropped_target['img'],
                            fg_mask,
                            step=self.step + self.resume_step,
                            test_mode=False,
                        )

                        loss = loss + loss_cano

                        # prefix canonical-view losses to avoid key clashes
                        log_rec3d_loss_dict({
                            f'cano_{k}': v
                            for k, v in loss_cano_dict.items()
                        })

            self.mp_trainer_rec.backward(loss)

            if dist_util.get_rank() == 0 and self.step % 500 == 0:
                self.log_patch_img(cropped_target, pred_nv, pred_cano)

    def log_patch_img(self, micro, pred, pred_cano):

        def norm_depth(pred_depth):  # map depth to [-1, 1] for visualisation
            pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() -
                                                            pred_depth.min())
            return -(pred_depth * 2 - 1)  # negate so near surfaces are bright

        pred_img = pred['image_raw']
        gt_img = micro['img']

        gt_depth = micro['depth']
        if gt_depth.ndim == 3:
            gt_depth = gt_depth.unsqueeze(1)
        gt_depth = norm_depth(gt_depth)

        fg_mask = pred['image_mask'] * 2 - 1  # [0, 1] -> [-1, 1]
        input_fg_mask = pred_cano['image_mask'] * 2 - 1  # [0, 1] -> [-1, 1]
        if 'image_depth' in pred:
            pred_depth = norm_depth(pred['image_depth'])
            pred_cano_depth = norm_depth(pred_cano['image_depth'])
        else:
            pred_depth = th.zeros_like(gt_depth)
            pred_cano_depth = th.zeros_like(gt_depth)

        pred_vis = th.cat([
            pred_img,
            pred_depth.repeat_interleave(3, dim=1),
            fg_mask.repeat_interleave(3, dim=1),
        ],
                          dim=-1)  # B, 3, H, W

        pred_vis_nv = th.cat([
            pred_cano['image_raw'],
            pred_cano_depth.repeat_interleave(3, dim=1),
            input_fg_mask.repeat_interleave(3, dim=1),
        ],
                             dim=-1)  # B, 3, H, W
        pred_vis = th.cat([pred_vis, pred_vis_nv], dim=-2)  # cat in H dim

        gt_vis = th.cat([
            gt_img,
            gt_depth.repeat_interleave(3, dim=1),
            th.zeros_like(gt_img)  # placeholder column
        ],
                        dim=-1)  # TODO, depth fails to load; range [0, 1]

        vis = th.cat([gt_vis, pred_vis], dim=-2)
        vis_tensor = torchvision.utils.make_grid(vis, nrow=vis.shape[-1] //
                                                 64)  # C, H, W

        torchvision.utils.save_image(
            vis_tensor,
            f'{logger.get_dir()}/{self.step+self.resume_step}.jpg',
            value_range=(-1, 1),
            normalize=True)

        logger.log('log vis to: ',
                   f'{logger.get_dir()}/{self.step+self.resume_step}.jpg')
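
# Illustrative only: a minimal sketch (hypothetical input, not called by the
# training loops) of the depth visualisation in log_img/log_patch_img above:
# min-max normalise to [0, 1], rescale to [-1, 1], then negate so that nearer
# surfaces appear brighter when saved with value_range=(-1, 1).
def _demo_norm_depth():
    depth = th.rand(1, 1, 64, 64) * 5  # stand-in metric depth
    depth = (depth - depth.min()) / (depth.max() - depth.min())
    vis = -(depth * 2 - 1)
    assert vis.min() >= -1 and vis.max() <= 1
    return vis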


class TrainLoop3DRecNVPatchSingleForward(TrainLoop3DRecNVPatch):
    # render all supervision views in a single decoder forward
    def __init__(self,
                 *,
                 rec_model,
                 loss_class,
                 data,
                 eval_data,
                 batch_size,
                 microbatch,
                 lr,
                 ema_rate,
                 log_interval,
                 eval_interval,
                 save_interval,
                 resume_checkpoint,
                 use_fp16=False,
                 fp16_scale_growth=0.001,
                 weight_decay=0,
                 lr_anneal_steps=0,
                 iterations=10001,
                 load_submodule_name='',
                 ignore_resume_opt=False,
                 model_name='rec',
                 use_amp=False,
                 **kwargs):
        super().__init__(rec_model=rec_model,
                         loss_class=loss_class,
                         data=data,
                         eval_data=eval_data,
                         batch_size=batch_size,
                         microbatch=microbatch,
                         lr=lr,
                         ema_rate=ema_rate,
                         log_interval=log_interval,
                         eval_interval=eval_interval,
                         save_interval=save_interval,
                         resume_checkpoint=resume_checkpoint,
                         use_fp16=use_fp16,
                         fp16_scale_growth=fp16_scale_growth,
                         weight_decay=weight_decay,
                         lr_anneal_steps=lr_anneal_steps,
                         iterations=iterations,
                         load_submodule_name=load_submodule_name,
                         ignore_resume_opt=ignore_resume_opt,
                         model_name=model_name,
                         use_amp=use_amp,
                         **kwargs)

    def forward_backward(self, batch, *args, **kwargs):
        # patch-based sampling with a single decoder forward for all views
        self.mp_trainer_rec.zero_grad()
        batch_size = batch['img_to_encoder'].shape[0]
        batch.pop('caption')  # not required
        batch.pop('ins')  # not required

        for i in range(0, batch_size, self.microbatch):

            micro = {
                k:
                v[i:i + self.microbatch].to(dist_util.dev()) if isinstance(
                    v, th.Tensor) else v[i:i + self.microbatch]
                for k, v in batch.items()
            }

            # ! sample a rendering patch (rays o / dir) within the fg bbox
            target = {
                **self.eg3d_model(
                    c=micro['nv_c'],  # type: ignore
                    ws=None,
                    planes=None,
                    sample_ray_only=True,
                    fg_bbox=micro['nv_bbox']),
            }

            patch_rendering_resolution = self.eg3d_model.rendering_kwargs[
                'patch_rendering_resolution']  # type: ignore
            cropped_target = {
                k:
                th.empty_like(v)
                [..., :patch_rendering_resolution, :patch_rendering_resolution]
                if k not in [
                    'ins_idx', 'img_to_encoder', 'img_sr', 'nv_img_to_encoder',
                    'nv_img_sr', 'c', 'caption', 'nv_caption'
                ] else v
                for k, v in micro.items()
            }

            # crop the novel-view GT according to the sampled uv bbox
            for j in range(micro['img'].shape[0]):
                top, left, height, width = target['ray_bboxes'][
                    j]  # list of tuples
                for key in ('img', 'depth_mask', 'depth'):  # type: ignore
                    cropped_target[f'{key}'][  # ! no nv_ prefix here
                        j:j + 1] = torchvision.transforms.functional.crop(
                            micro[f'nv_{key}'][j:j + 1], top, left, height,
                            width)

            # ! canonical-view rays (currently unused; the canonical crop loop
            # is disabled in this variant)
            cano_target = {
                **self.eg3d_model(
                    c=micro['c'],  # type: ignore
                    ws=None,
                    planes=None,
                    sample_ray_only=True,
                    fg_bbox=micro['bbox']),  # rays o / dir
            }

            # ! run the ViT encoder outside of amp
            latent = self.rec_model(img=micro['img_to_encoder'],
                                    behaviour='enc_dec_wo_triplane')

            # wrap the rendering forward pass in amp autocast
            with th.autocast(device_type='cuda',
                             dtype=th.float16,
                             enabled=self.mp_trainer_rec.use_amp):

                instance_mv_num = batch_size // 4  # 4 pairs by default
                # ! roll the camera/ray batches for multi-view supervision:
                # the latent batch stays fixed while its rendering targets
                # cycle through other views in the batch
                c = th.cat([
                    micro['nv_c'].roll(instance_mv_num * i, dims=0)
                    for i in range(1, 4)
                ])  # render the novel views here

                ray_origins = th.cat([
                    target['ray_origins'].roll(instance_mv_num * i, dims=0)
                    for i in range(1, 4)
                ], 0)

                ray_directions = th.cat([
                    target['ray_directions'].roll(instance_mv_num * i, dims=0)
                    for i in range(1, 4)
                ])

                pred_nv_cano = self.rec_model(
                    latent={
                        'latent_after_vit':  # ! triplane for rendering
                        latent['latent_after_vit'].repeat(3, 1, 1, 1)
                    },
                    c=c,
                    behaviour='triplane_dec',
                    ray_origins=ray_origins,
                    ray_directions=ray_directions,
                )
                pred_nv_cano.update(
                    latent
                )  # torchvision.utils.save_image(pred_nv_cano['image_raw'], 'pred.png', normalize=True)

                # roll the GT identically so predictions and targets stay
                # aligned view-for-view
                gt = {
                    k:
                    th.cat([
                        v.roll(instance_mv_num * i, dims=0)
                        for i in range(1, 4)
                    ], 0)
                    for k, v in cropped_target.items()
                }  # torchvision.utils.save_image(gt['img'], 'gt.png', normalize=True)

                with self.rec_model.no_sync():  # type: ignore
                    loss, loss_dict, _ = self.loss_class(
                        pred_nv_cano,
                        gt,  # merged multi-view targets
                        step=self.step + self.resume_step,
                        test_mode=False,
                        return_fg_mask=True,
                        conf_sigma_l1=None,
                        conf_sigma_percl=None)
                    log_rec3d_loss_dict(loss_dict)

            self.mp_trainer_rec.backward(loss)

            if dist_util.get_rank() == 0 and self.step % 500 == 0:
                micro_bs = micro['img_to_encoder'].shape[0]
                self.log_patch_img(  # record one cano view and one novel view
                    cropped_target,
                    {
                        k: pred_nv_cano[k][-micro_bs:]
                        for k in ['image_raw', 'image_depth', 'image_mask']
                    },
                    {
                        k: pred_nv_cano[k][:micro_bs]
                        for k in ['image_raw', 'image_depth', 'image_mask']
                    },
                )

    def eval_loop(self):
        return super().eval_loop()

    def eval_novelview_loop_old(self, camera=None):
        # novel-view synthesis along an evaluation camera trajectory
        all_loss_dict = []
        novel_view_micro = {}

        # ! randomly inference an instance
        export_mesh = True
        if export_mesh:
            Path(f'{logger.get_dir()}/FID_Cals/').mkdir(parents=True,
                                                        exist_ok=True)

        batch = {}

        for eval_idx, render_reference in enumerate(tqdm(self.eval_data)):
            if eval_idx > 500:
                break

            video_out = imageio.get_writer(
                f'{logger.get_dir()}/video_novelview_{self.step+self.resume_step}_{eval_idx}.mp4',
                mode='I',
                fps=25,
                codec='libx264')

            with open(
                    f'{logger.get_dir()}/triplane_{self.step+self.resume_step}_{eval_idx}_caption.txt',
                    'w') as f:
                f.write(render_reference['caption'])

            for key in ['ins', 'bbox', 'caption']:
                if key in render_reference:
                    render_reference.pop(key)

            real_flag = False
            mv_flag = False  # TODO, use the full instance for evaluation and calculate the metrics
            if render_reference['c'].shape[:2] == (1, 40):
                real_flag = True
                # real-image monocular reconstruction;
                # unpack into a list so frames can be enumerated
                render_reference = [{
                    k: v[0][idx:idx + 1]
                    for k, v in render_reference.items()
                } for idx in range(40)]

            elif render_reference['c'].shape[0] == 8:
                mv_flag = True

                render_reference = {
                    k: v[:4]
                    for k, v in render_reference.items()
                }

                # save gt
                torchvision.utils.save_image(
                    render_reference['img'][0:4],
                    logger.get_dir() + '/FID_Cals/{}_inp.png'.format(eval_idx),
                    padding=0,
                    normalize=True,
                    value_range=(-1, 1),
                )

            else:
                # unpack into a list so frames can be enumerated
                render_reference = [{
                    k: v[idx:idx + 1]
                    for k, v in render_reference.items()
                } for idx in range(40)]

                # ! single-view version: encode a side view
                render_reference[0]['img_to_encoder'] = render_reference[14][
                    'img_to_encoder']
                render_reference[0]['img'] = render_reference[14]['img']

                # save gt
                torchvision.utils.save_image(
                    render_reference[0]['img'],
                    logger.get_dir() + '/FID_Cals/{}_gt.png'.format(eval_idx),
                    padding=0,
                    normalize=True,
                    value_range=(-1, 1))

            # ! TODO, merge with render_video_given_triplane later
            for i, batch in enumerate(render_reference):
                micro = {k: v.to(dist_util.dev()) for k, v in batch.items()}

                if i == 0:
                    if mv_flag:
                        novel_view_micro = None
                    else:
                        novel_view_micro = {
                            k:
                            v[0:1].to(dist_util.dev()).repeat_interleave(
                                micro['img'].shape[0],
                                0) if isinstance(v, th.Tensor) else v[0:1]
                            for k, v in batch.items()
                        }
                else:
                    if i == 1:
                        # ! export a mesh from the predicted triplane
                        if export_mesh:
                            mesh_size = 384
                            mesh_thres = 10  # TODO, requires tuning

                            import mcubes
                            import trimesh
                            dump_path = f'{logger.get_dir()}/mesh/'
                            os.makedirs(dump_path, exist_ok=True)

                            grid_out = self.rec_model(
                                # 'pred' is the prediction from the previous
                                # (i == 0) iteration
                                latent=pred,
                                grid_size=mesh_size,
                                behaviour='triplane_decode_grid',
                            )

                            vtx, faces = mcubes.marching_cubes(
                                grid_out['sigma'].squeeze(0).squeeze(
                                    -1).cpu().numpy(), mesh_thres)
                            # rescale grid indices to the [-1, 1] cube
                            vtx = vtx / (mesh_size - 1) * 2 - 1

                            mesh = trimesh.Trimesh(
                                vertices=vtx,
                                faces=faces,
                            )

                            mesh_dump_path = os.path.join(
                                dump_path, f'{eval_idx}.ply')
                            mesh.export(mesh_dump_path, 'ply')

                            print(f"Mesh dumped to {dump_path}")
                            del grid_out, mesh
                            th.cuda.empty_cache()

                    novel_view_micro = {
                        k:
                        v[0:1].to(dist_util.dev()).repeat_interleave(
                            micro['img'].shape[0], 0)
                        for k, v in novel_view_micro.items()
                    }

                pred = self.rec_model(img=novel_view_micro['img_to_encoder'],
                                      c=micro['c'])  # pred: (B, 3, 64, 64)

                if not real_flag:
                    _, loss_dict = self.loss_class(pred, micro, test_mode=True)
                    all_loss_dict.append(loss_dict)

                # normalize depth for visualisation
                pred_depth = pred['image_depth']
                pred_depth = (pred_depth - pred_depth.min()) / (
                    pred_depth.max() - pred_depth.min())
                if 'image_sr' in pred:

                    if pred['image_sr'].shape[-1] == 512:
                        pred_vis = th.cat([
                            micro['img_sr'],
                            self.pool_512(pred['image_raw']), pred['image_sr'],
                            self.pool_512(pred_depth).repeat_interleave(3,
                                                                        dim=1)
                        ],
                                          dim=-1)

                    elif pred['image_sr'].shape[-1] == 256:
                        pred_vis = th.cat([
                            micro['img_sr'],
                            self.pool_256(pred['image_raw']), pred['image_sr'],
                            self.pool_256(pred_depth).repeat_interleave(3,
                                                                        dim=1)
                        ],
                                          dim=-1)

                    else:
                        pred_vis = th.cat([
                            micro['img_sr'],
                            self.pool_128(pred['image_raw']),
                            self.pool_128(pred['image_sr']),
                            self.pool_128(pred_depth).repeat_interleave(3,
                                                                        dim=1)
                        ],
                                          dim=-1)

                else:
                    pooled_depth = self.pool_128(pred_depth).repeat_interleave(
                        3, dim=1)
                    pred_vis = th.cat(
                        [
                            self.pool_128(novel_view_micro['img']
                                          ),  # use the input here
                            self.pool_128(pred['image_raw']),
                            pooled_depth,
                        ],
                        dim=-1)  # B, 3, H, W

                vis = pred_vis.permute(0, 2, 3, 1).cpu().numpy()
                vis = vis * 127.5 + 127.5
                vis = vis.clip(0, 255).astype(np.uint8)

                if export_mesh:
                    # save the per-view image and depth; note that
                    # pooled_depth is only defined on the non-SR path above
                    torchvision.utils.save_image(
                        pred['image_raw'],
                        logger.get_dir() +
                        '/FID_Cals/{}_{}.png'.format(eval_idx, i),
                        padding=0,
                        normalize=True,
                        value_range=(-1, 1))

                    torchvision.utils.save_image(
                        pooled_depth,
                        logger.get_dir() +
                        '/FID_Cals/{}_{}_depth.png'.format(eval_idx, i),
                        padding=0,
                        normalize=True,
                        value_range=(0, 1))

                for j in range(vis.shape[0]):
                    video_out.append_data(vis[j])

            video_out.close()

            if not real_flag or mv_flag:
                val_scores_for_logging = calc_average_loss(all_loss_dict)
                with open(
                        os.path.join(logger.get_dir(),
                                     'scores_novelview.json'), 'a') as f:
                    json.dump({'step': self.step, **val_scores_for_logging}, f)

                # * log to tensorboard
                for k, v in val_scores_for_logging.items():
                    self.writer.add_scalar(f'Eval/NovelView/{k}', v,
                                           self.step + self.resume_step)

        del video_out
        th.cuda.empty_cache()
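
    # Mesh export note (eval_novelview_loop_old above): 'triplane_decode_grid'
    # evaluates the predicted triplane density on a mesh_size^3 grid,
    # mcubes.marching_cubes extracts the iso-surface at mesh_thres, and the
    # vertices are rescaled from grid indices to the renderer's [-1, 1] cube
    # via vtx / (mesh_size - 1) * 2 - 1 before being exported as .ply.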

    def eval_novelview_loop(self, camera=None, save_latent=False):
        # novel-view synthesis along an evaluation camera trajectory
        if save_latent:  # for diffusion learning
            latent_dir = Path(f'{logger.get_dir()}/latent_dir')
            latent_dir.mkdir(exist_ok=True, parents=True)

        eval_batch_size = 40  # ! for i23d

        for eval_idx, micro in enumerate(tqdm(self.eval_data)):

            latent = self.rec_model(
                img=micro['img_to_encoder'],
                behaviour='encoder_vae')  # pred: (B, 3, 64, 64)
            if save_latent:
                latent_save_dir = f'{logger.get_dir()}/latent_dir/{micro["ins"][0]}'
                Path(latent_save_dir).mkdir(parents=True, exist_ok=True)
                np.save(f'{latent_save_dir}/latent.npy',
                        latent[self.latent_name][0].cpu().numpy())

                assert all([
                    micro['ins'][0] == micro['ins'][i]
                    for i in range(micro['c'].shape[0])
                ])  # ! assert the batch comes from a single instance

            if eval_idx < 50:
                self.render_video_given_triplane(
                    latent[self.latent_name],  # B 12 32 32
                    self.rec_model,  # compatible with join_model
                    name_prefix=f'{self.step + self.resume_step}_{eval_idx}',
                    save_img=False,
                    render_reference={'c': camera},
                    save_mesh=True)
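
# Illustrative only: a minimal sketch (hypothetical batch layout, not called
# by the training loops) of the view-rolling trick in
# TrainLoop3DRecNVPatchSingleForward.forward_backward. With a batch of B = 8
# views arranged as 4 instance pairs, rolling the camera and target batches
# by the same multiples of B // 4 re-pairs each (fixed) latent with further
# views from the batch -- the data loader is assumed to arrange views so that
# these offsets index other views of the same instance -- yielding three
# rendering targets from a single encoder forward.
def _demo_view_rolling():
    B = 8
    instance_mv_num = B // 4  # 4 pairs by default
    c = th.arange(B)  # stand-in for the per-view camera batch
    rolled = th.cat(
        [c.roll(instance_mv_num * i, dims=0) for i in range(1, 4)])
    assert rolled.shape[0] == 3 * B  # three rolled copies, as in the loop
    return rolled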


class TrainLoop3DRecNVPatchSingleForwardMV(TrainLoop3DRecNVPatchSingleForward):
    # render novel and canonical views jointly in one decoder forward
    def __init__(self,
                 *,
                 rec_model,
                 loss_class,
                 data,
                 eval_data,
                 batch_size,
                 microbatch,
                 lr,
                 ema_rate,
                 log_interval,
                 eval_interval,
                 save_interval,
                 resume_checkpoint,
                 use_fp16=False,
                 fp16_scale_growth=0.001,
                 weight_decay=0,
                 lr_anneal_steps=0,
                 iterations=10001,
                 load_submodule_name='',
                 ignore_resume_opt=False,
                 model_name='rec',
                 use_amp=False,
                 **kwargs):
        super().__init__(rec_model=rec_model,
                         loss_class=loss_class,
                         data=data,
                         eval_data=eval_data,
                         batch_size=batch_size,
                         microbatch=microbatch,
                         lr=lr,
                         ema_rate=ema_rate,
                         log_interval=log_interval,
                         eval_interval=eval_interval,
                         save_interval=save_interval,
                         resume_checkpoint=resume_checkpoint,
                         use_fp16=use_fp16,
                         fp16_scale_growth=fp16_scale_growth,
                         weight_decay=weight_decay,
                         lr_anneal_steps=lr_anneal_steps,
                         iterations=iterations,
                         load_submodule_name=load_submodule_name,
                         ignore_resume_opt=ignore_resume_opt,
                         model_name=model_name,
                         use_amp=use_amp,
                         **kwargs)

    def forward_backward(self, batch, behaviour='g_step', *args, **kwargs):
        # patch-based sampling; render novel and canonical views together
        self.mp_trainer_rec.zero_grad()
        batch_size = batch['img_to_encoder'].shape[0]

        batch.pop('caption')  # not required
        batch.pop('nv_caption')  # not required
        batch.pop('ins')  # not required
        batch.pop('nv_ins')  # not required

        if '__key__' in batch.keys():
            batch.pop('__key__')

        for i in range(0, batch_size, self.microbatch):

            micro = {
                k:
                v[i:i + self.microbatch].to(dist_util.dev()) if isinstance(
                    v, th.Tensor) else v[i:i + self.microbatch]
                for k, v in batch.items()
            }

            # ! sample rendering patches for novel and canonical views at once
            nv_c = th.cat([micro['nv_c'], micro['c']])
            target = {
                **self.eg3d_model(
                    c=nv_c,  # type: ignore
                    ws=None,
                    planes=None,
                    sample_ray_only=True,
                    fg_bbox=th.cat([micro['nv_bbox'],
                                    micro['bbox']])),  # rays o / dir
            }

            patch_rendering_resolution = self.eg3d_model.rendering_kwargs[
                'patch_rendering_resolution']  # type: ignore
            cropped_target = {
                k:
                th.empty_like(v).repeat_interleave(2, 0)
                [..., :patch_rendering_resolution, :patch_rendering_resolution]
                if k not in [
                    'ins_idx', 'img_to_encoder', 'img_sr', 'nv_img_to_encoder',
                    'nv_img_sr', 'c', 'caption', 'nv_caption'
                ] else v
                for k, v in micro.items()
            }

            # crop GT according to the sampled uv bbox; the first half of the
            # doubled batch holds novel views, the second half canonical views
            for j in range(2 * self.microbatch):
                top, left, height, width = target['ray_bboxes'][
                    j]  # list of tuples
                for key in ('img', 'depth_mask', 'depth'):  # type: ignore
                    if j < self.microbatch:
                        cropped_target[f'{key}'][  # ! no nv_ prefix here
                            j:j + 1] = torchvision.transforms.functional.crop(
                                micro[f'nv_{key}'][j:j + 1], top, left, height,
                                width)
                    else:
                        cropped_target[f'{key}'][
                            j:j + 1] = torchvision.transforms.functional.crop(
                                micro[f'{key}'][j - self.microbatch:j -
                                                self.microbatch + 1], top,
                                left, height, width)

            # ! run the ViT encoder outside of amp
            latent = self.rec_model(img=micro['img_to_encoder'],
                                    behaviour='enc_dec_wo_triplane')

            # wrap the rendering forward pass in amp autocast
            with th.autocast(device_type='cuda',
                             dtype=th.float16,
                             enabled=self.mp_trainer_rec.use_amp):

                ray_origins = target['ray_origins']
                ray_directions = target['ray_directions']

                pred_nv_cano = self.rec_model(
                    latent={
                        'latent_after_vit':  # ! triplane for rendering
                        latent['latent_after_vit'].repeat_interleave(
                            4, dim=0).repeat(2, 1, 1, 1)  # NV=4
                    },
                    c=nv_c,
                    behaviour='triplane_dec',
                    ray_origins=ray_origins,
                    ray_directions=ray_directions,
                )
                pred_nv_cano.update(
                    latent
                )  # torchvision.utils.save_image(pred_nv_cano['image_raw'], 'pred.png', normalize=True)
                gt = cropped_target

                with self.rec_model.no_sync():  # type: ignore
                    loss, loss_dict, _ = self.loss_class(
                        pred_nv_cano,
                        gt,  # merged novel-view + canonical targets
                        step=self.step + self.resume_step,
                        test_mode=False,
                        return_fg_mask=True,
                        behaviour=behaviour,
                        conf_sigma_l1=None,
                        conf_sigma_percl=None)
                    log_rec3d_loss_dict(loss_dict)

            self.mp_trainer_rec.backward(loss)

            if dist_util.get_rank() == 0 and self.step % 500 == 0 and i == 0:
                try:
                    torchvision.utils.save_image(
                        th.cat(
                            [cropped_target['img'],
                             pred_nv_cano['image_raw']], ),
                        f'{logger.get_dir()}/{self.step+self.resume_step}.jpg',
                        normalize=True)
                    logger.log(
                        'log vis to: ',
                        f'{logger.get_dir()}/{self.step+self.resume_step}.jpg')
                except Exception as e:
                    logger.log(e)
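
# The adversarial variant below reuses the multi-view training loop and adds
# a discriminator owned by the loss class: the discriminator gets its own
# MixedPrecisionTrainer and AdamW optimizer (optionally DDP-wrapped), and
# run_loop alternates generator ('g_step') and discriminator ('d_step')
# updates, each on a fresh batch.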


class TrainLoop3DRecNVPatchSingleForwardMVAdvLoss(
        TrainLoop3DRecNVPatchSingleForwardMV):

    def __init__(self,
                 *,
                 rec_model,
                 loss_class,
                 data,
                 eval_data,
                 batch_size,
                 microbatch,
                 lr,
                 ema_rate,
                 log_interval,
                 eval_interval,
                 save_interval,
                 resume_checkpoint,
                 use_fp16=False,
                 fp16_scale_growth=0.001,
                 weight_decay=0,
                 lr_anneal_steps=0,
                 iterations=10001,
                 load_submodule_name='',
                 ignore_resume_opt=False,
                 model_name='rec',
                 use_amp=False,
                 **kwargs):
        super().__init__(rec_model=rec_model,
                         loss_class=loss_class,
                         data=data,
                         eval_data=eval_data,
                         batch_size=batch_size,
                         microbatch=microbatch,
                         lr=lr,
                         ema_rate=ema_rate,
                         log_interval=log_interval,
                         eval_interval=eval_interval,
                         save_interval=save_interval,
                         resume_checkpoint=resume_checkpoint,
                         use_fp16=use_fp16,
                         fp16_scale_growth=fp16_scale_growth,
                         weight_decay=weight_decay,
                         lr_anneal_steps=lr_anneal_steps,
                         iterations=iterations,
                         load_submodule_name=load_submodule_name,
                         ignore_resume_opt=ignore_resume_opt,
                         model_name=model_name,
                         use_amp=use_amp,
                         **kwargs)

        # create the discriminator and its own mixed-precision trainer
        disc_params = self.loss_class.get_trainable_parameters()
        self.mp_trainer_disc = MixedPrecisionTrainer(
            model=self.loss_class.discriminator,
            use_fp16=self.use_fp16,
            fp16_scale_growth=fp16_scale_growth,
            model_name='disc',
            use_amp=use_amp,
            model_params=disc_params)

        self.opt_disc = AdamW(
            self.mp_trainer_disc.master_params,
            lr=self.lr,  # follow the sd code base
            betas=(0, 0.999),
            eps=1e-8)

        # TODO, is the loss class already wrapped in DDP?
        if self.use_ddp:
            self.ddp_disc = DDP(
                self.loss_class.discriminator,
                device_ids=[dist_util.dev()],
                output_device=dist_util.dev(),
                broadcast_buffers=False,
                bucket_cap_mb=128,
                find_unused_parameters=False,
            )
        else:
            self.ddp_disc = self.loss_class.discriminator

    def save(self, mp_trainer=None, model_name='rec'):
        if mp_trainer is None:
            mp_trainer = self.mp_trainer_rec

        def save_checkpoint(rate, params):
            state_dict = mp_trainer.master_params_to_state_dict(params)
            if dist_util.get_rank() == 0:
                logger.log(f"saving model {model_name} {rate}...")
                if not rate:
                    filename = f"model_{model_name}{(self.step+self.resume_step):07d}.pt"
                else:
                    filename = f"ema_{model_name}_{rate}_{(self.step+self.resume_step):07d}.pt"
                with bf.BlobFile(bf.join(get_blob_logdir(), filename),
                                 "wb") as f:
                    th.save(state_dict, f)

        save_checkpoint(0, mp_trainer.master_params)

        dist.barrier()

    def run_step(self, batch, step='g_step'):

        if step == 'g_step':
            self.forward_backward(batch, behaviour='g_step')
            took_step_g_rec = self.mp_trainer_rec.optimize(self.opt)

            if took_step_g_rec:
                self._update_ema()  # g_ema

        elif step == 'd_step':
            self.forward_backward(batch, behaviour='d_step')
            _ = self.mp_trainer_disc.optimize(self.opt_disc)

        self._anneal_lr()
        self.log_step()

    def run_loop(self, batch=None):
        while (not self.lr_anneal_steps
               or self.step + self.resume_step < self.lr_anneal_steps):

            # alternate generator and discriminator steps on fresh batches
            batch = next(self.data)
            self.run_step(batch, 'g_step')

            batch = next(self.data)
            self.run_step(batch, 'd_step')

            if self.step % 1000 == 0:
                dist_util.synchronize()
                if self.step % 10000 == 0:
                    th.cuda.empty_cache()  # avoid memory leak

            if self.step % self.log_interval == 0 and dist_util.get_rank(
            ) == 0:
                out = logger.dumpkvs()
                # * log to tensorboard
                for k, v in out.items():
                    self.writer.add_scalar(f'Loss/{k}', v,
                                           self.step + self.resume_step)

            if self.step % self.eval_interval == 0 and self.step != 0:
                if dist_util.get_rank() == 0:
                    try:
                        self.eval_loop()
                    except Exception as e:
                        logger.log(e)
                dist_util.synchronize()

            if self.step % self.save_interval == 0:
                self.save()
                self.save(self.mp_trainer_disc,
                          self.mp_trainer_disc.model_name)
                dist_util.synchronize()

            # Run for a finite amount of time in integration tests.
            if os.environ.get("DIFFUSION_TRAINING_TEST",
                              "") and self.step > 0:
                return

            self.step += 1

            if self.step > self.iterations:
                logger.log('reached maximum iterations, exiting')

                # Save the last checkpoint if it wasn't already saved.
                if (self.step -
                        1) % self.save_interval != 0 and self.step != 1:
                    self.save()

                exit()

        # Save the last checkpoint if it wasn't already saved.
        if (self.step - 1) % self.save_interval != 0:
            self.save()  # save rec
            self.save(self.mp_trainer_disc, self.mp_trainer_disc.model_name)
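

# Illustrative only: the alternation pattern implemented by
# TrainLoop3DRecNVPatchSingleForwardMVAdvLoss.run_loop above, reduced to a
# minimal sketch. The step closures and data iterator are hypothetical
# stand-ins; no distributed setup, EMA, or checkpointing is modelled.
def _demo_gan_alternation(data_iter, g_step, d_step, iterations=4):
    for _ in range(iterations):
        g_step(next(data_iter))  # reconstruction + adversarial G loss
        d_step(next(data_iter))  # discriminator update on a fresh batch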