import argparse
import datetime
import logging
import inspect
import math
import os
from typing import Dict, Optional, Tuple, List
from omegaconf import OmegaConf
from PIL import Image
import cv2
import numpy as np
from dataclasses import dataclass
from packaging import version
import shutil
from collections import defaultdict

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import torchvision.transforms.functional as TF
from torchvision.utils import make_grid, save_image

import transformers
import accelerate
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed

import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, deprecate, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available

from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from mvdiffusion.models.unet_mv2d_condition import UNetMV2DConditionModel
from mvdiffusion.data.single_image_dataset import SingleImageDataset as MVDiffusionDataset
from mvdiffusion.pipelines.pipeline_mvdiffusion_image import MVDiffusionImagePipeline

from einops import rearrange
from rembg import remove
import pdb

weight_dtype = torch.float16
@dataclass
class TestConfig:
    pretrained_model_name_or_path: str
    pretrained_unet_path: str
    revision: Optional[str]
    validation_dataset: Dict
    save_dir: str
    seed: Optional[int]
    validation_batch_size: int
    dataloader_num_workers: int
    local_rank: int

    pipe_kwargs: Dict
    pipe_validation_kwargs: Dict
    unet_from_pretrained_kwargs: Dict
    validation_guidance_scales: List[float]
    validation_grid_nrow: int
    camera_embedding_lr_mult: float

    num_views: int
    camera_embedding_type: str

    pred_type: str  # joint, or ablation

    enable_xformers_memory_efficient_attention: bool

    cond_on_normals: bool
    cond_on_colors: bool
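
# Minimal sketch of the YAML this script expects. Only the top-level field names
# come from TestConfig above; all values, and the keys nested under
# validation_dataset, are illustrative assumptions, not the repository's config.
#
#   pretrained_model_name_or_path: path/to/wonder3d-checkpoint   # hypothetical
#   pretrained_unet_path: ""
#   revision: null
#   validation_dataset:
#     root_dir: ./example_images    # hypothetical dataset kwarg
#     crop_size: 192
#   save_dir: ./outputs
#   seed: 42
#   validation_batch_size: 1
#   dataloader_num_workers: 4
#   local_rank: 0
#   pipe_kwargs: {}
#   pipe_validation_kwargs: {}
#   unet_from_pretrained_kwargs: {}
#   validation_guidance_scales: [3.0]
#   validation_grid_nrow: 6
#   camera_embedding_lr_mult: 1.0
#   num_views: 6
#   camera_embedding_type: e_de_da_sincos   # hypothetical value
#   pred_type: joint
#   enable_xformers_memory_efficient_attention: true
#   cond_on_normals: false
#   cond_on_colors: false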

def log_validation(dataloader, pipeline, cfg: TestConfig, weight_dtype, name, save_dir):
    pipeline.set_progress_bar_config(disable=True)

    if cfg.seed is None:
        generator = None
    else:
        generator = torch.Generator(device=pipeline.device).manual_seed(cfg.seed)

    images_cond, images_pred = [], defaultdict(list)
    for i, batch in tqdm(enumerate(dataloader)):
        # (B, Nv, 3, H, W)
        imgs_in = batch['imgs_in']
        alphas = batch['alphas']
        # (B, Nv, Nce)
        camera_embeddings = batch['camera_embeddings']
        filename = batch['filename']

        bsz, num_views = imgs_in.shape[0], imgs_in.shape[1]

        # (B*Nv, 3, H, W)
        imgs_in = rearrange(imgs_in, "B Nv C H W -> (B Nv) C H W")
        alphas = rearrange(alphas, "B Nv C H W -> (B Nv) C H W")
        # (B*Nv, Nce)
        camera_embeddings = rearrange(camera_embeddings, "B Nv Nce -> (B Nv) Nce")

        images_cond.append(imgs_in)

        with torch.autocast("cuda"):
            # B*Nv images
            for guidance_scale in cfg.validation_guidance_scales:
                out = pipeline(
                    imgs_in, camera_embeddings, generator=generator, guidance_scale=guidance_scale,
                    output_type='pt', num_images_per_prompt=1, **cfg.pipe_validation_kwargs
                ).images

                images_pred[f"{name}-sample_cfg{guidance_scale:.1f}"].append(out)

                cur_dir = os.path.join(save_dir, f"cropsize-{cfg.validation_dataset.crop_size}-cfg{guidance_scale:.1f}")

                for i in range(bsz):
                    scene = os.path.basename(filename[i])
                    print(scene)
                    scene_dir = os.path.join(cur_dir, scene)
                    outs_dir = os.path.join(scene_dir, "outs")
                    masked_outs_dir = os.path.join(scene_dir, "masked_outs")
                    os.makedirs(outs_dir, exist_ok=True)
                    os.makedirs(masked_outs_dir, exist_ok=True)

                    img_in = imgs_in[i*num_views]
                    alpha = alphas[i*num_views]
                    img_in = torch.cat([img_in, alpha], dim=0)
                    save_image(img_in, os.path.join(scene_dir, scene + ".png"))

                    for j in range(num_views):
                        view = VIEWS[j]
                        idx = i*num_views + j
                        pred = out[idx]

                        out_filename = f"{cfg.pred_type}_000_{view}.png"
                        pred = save_image(pred, os.path.join(outs_dir, out_filename))

                        rm_pred = remove(pred)
                        save_image_numpy(rm_pred, os.path.join(scene_dir, out_filename))
    torch.cuda.empty_cache()

def save_image(tensor, fp):
    # Note: this intentionally shadows torchvision.utils.save_image imported above.
    # Converts a (C, H, W) float tensor in [0, 1] to a uint8 (H, W, C) array,
    # writes it to fp, and returns the array for further processing (e.g. rembg).
    ndarr = tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    im = Image.fromarray(ndarr)
    im.save(fp)
    return ndarr

def save_image_numpy(ndarr, fp):
    im = Image.fromarray(ndarr)
    im.save(fp)

def log_validation_joint(dataloader, pipeline, cfg: TestConfig, weight_dtype, name, save_dir):
    pipeline.set_progress_bar_config(disable=True)

    if cfg.seed is None:
        generator = None
    else:
        generator = torch.Generator(device=pipeline.device).manual_seed(cfg.seed)

    images_cond, normals_pred, images_pred = [], defaultdict(list), defaultdict(list)
    for i, batch in tqdm(enumerate(dataloader)):
        # repeat (2B, Nv, 3, H, W)
        imgs_in = torch.cat([batch['imgs_in']]*2, dim=0)
        filename = batch['filename']

        # (2B, Nv, Nce)
        camera_embeddings = torch.cat([batch['camera_embeddings']]*2, dim=0)
        task_embeddings = torch.cat([batch['normal_task_embeddings'], batch['color_task_embeddings']], dim=0)
        camera_embeddings = torch.cat([camera_embeddings, task_embeddings], dim=-1)

        # (B*Nv, 3, H, W)
        imgs_in = rearrange(imgs_in, "B Nv C H W -> (B Nv) C H W")
        # (B*Nv, Nce)
        camera_embeddings = rearrange(camera_embeddings, "B Nv Nce -> (B Nv) Nce")

        images_cond.append(imgs_in)
        num_views = len(VIEWS)

        with torch.autocast("cuda"):
            # B*Nv images
            for guidance_scale in cfg.validation_guidance_scales:
                out = pipeline(
                    imgs_in, camera_embeddings, generator=generator, guidance_scale=guidance_scale,
                    output_type='pt', num_images_per_prompt=1, **cfg.pipe_validation_kwargs
                ).images

                bsz = out.shape[0] // 2
                normals_pred = out[:bsz]
                images_pred = out[bsz:]

                cur_dir = os.path.join(save_dir, f"cropsize-{cfg.validation_dataset.crop_size}-cfg{guidance_scale:.1f}")

                for i in range(bsz // num_views):
                    scene = filename[i]
                    scene_dir = os.path.join(cur_dir, scene)
                    normal_dir = os.path.join(scene_dir, "normals")
                    masked_colors_dir = os.path.join(scene_dir, "masked_colors")
                    os.makedirs(normal_dir, exist_ok=True)
                    os.makedirs(masked_colors_dir, exist_ok=True)

                    for j in range(num_views):
                        view = VIEWS[j]
                        idx = i*num_views + j
                        normal = normals_pred[idx]
                        color = images_pred[idx]

                        normal_filename = f"normals_000_{view}.png"
                        rgb_filename = f"rgb_000_{view}.png"
                        normal = save_image(normal, os.path.join(normal_dir, normal_filename))
                        color = save_image(color, os.path.join(scene_dir, rgb_filename))

                        rm_normal = remove(normal)
                        rm_color = remove(color)

                        save_image_numpy(rm_normal, os.path.join(scene_dir, normal_filename))
                        save_image_numpy(rm_color, os.path.join(masked_colors_dir, rgb_filename))
    torch.cuda.empty_cache()

def load_wonder3d_pipeline(cfg):
    pipeline = MVDiffusionImagePipeline.from_pretrained(
        cfg.pretrained_model_name_or_path,
        torch_dtype=weight_dtype
    )

    # Enable memory-efficient attention only when xformers is actually installed.
    if is_xformers_available():
        pipeline.unet.enable_xformers_memory_efficient_attention()

    if torch.cuda.is_available():
        pipeline.to('cuda:0')
    return pipeline

def main(cfg: TestConfig):
    # If passed along, set the seed now.
    if cfg.seed is not None:
        set_seed(cfg.seed)

    pipeline = load_wonder3d_pipeline(cfg)

    if cfg.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers
            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                print(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during "
                    "training, please update xFormers to at least 0.0.17. "
                    "See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            pipeline.unet.enable_xformers_memory_efficient_attention()
            print("use xformers.")
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # Get the dataset
    validation_dataset = MVDiffusionDataset(
        **cfg.validation_dataset
    )

    # DataLoaders creation:
    validation_dataloader = torch.utils.data.DataLoader(
        validation_dataset, batch_size=cfg.validation_batch_size,
        shuffle=False, num_workers=cfg.dataloader_num_workers
    )

    os.makedirs(cfg.save_dir, exist_ok=True)

    if cfg.pred_type == 'joint':
        log_validation_joint(
            validation_dataloader,
            pipeline,
            cfg,
            weight_dtype,
            'validation',
            cfg.save_dir
        )
    else:
        log_validation(
            validation_dataloader,
            pipeline,
            cfg,
            weight_dtype,
            'validation',
            cfg.save_dir
        )

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True)
    args, extras = parser.parse_known_args()

    from utils.misc import load_config

    # parse YAML config to OmegaConf
    cfg = load_config(args.config, cli_args=extras)
    print(cfg)
    schema = OmegaConf.structured(TestConfig)
    cfg = OmegaConf.merge(schema, cfg)

    if cfg.num_views == 6:
        VIEWS = ['front', 'front_right', 'right', 'back', 'left', 'front_left']
    elif cfg.num_views == 4:
        VIEWS = ['front', 'right', 'back', 'left']
    else:
        raise ValueError(f"Unsupported num_views: {cfg.num_views}; expected 4 or 6.")

    main(cfg)
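
# Example invocation (a sketch; the script and config filenames below are
# assumptions, not paths confirmed by this file -- any YAML matching the
# TestConfig schema above should work, with extra arguments forwarded to
# load_config as command-line overrides):
#
#   python test_mvdiffusion_seq.py --config configs/mvdiffusion-joint-ortho-6views.yaml save_dir=./outputs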