JiantaoLin committed
Commit c0dbb78 · 1 Parent(s): 2b11b58
Files changed (1)
  1. pipeline/kiss3d_wrapper.py +16 -16
pipeline/kiss3d_wrapper.py CHANGED
@@ -67,7 +67,7 @@ def init_wrapper_from_config(config_path):
     flux_dtype = config_['flux'].get('dtype', 'bf16')
     flux_controlnet_pth = config_['flux'].get('controlnet', None)
     # flux_lora_pth = config_['flux'].get('lora', None)
-    flux_lora_pth = hf_hub_download(repo_id="LTT/Kiss3DGen", filename="rgb_normal_large.safetensors", repo_type="model", token=access_token)
+    flux_lora_pth = hf_hub_download(repo_id="LTT/Kiss3DGen", filename="rgb_normal.safetensors", repo_type="model", token=access_token)
     flux_redux_pth = config_['flux'].get('redux', None)
     # taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype_[flux_dtype]).to(flux_device)
     if flux_base_model_pth.endswith('safetensors'):
@@ -102,23 +102,23 @@ def init_wrapper_from_config(config_path):
     # logger.warning(f"GPU memory allocated after load flux model on {flux_device}: {torch.cuda.memory_allocated(device=flux_device) / 1024**3} GB")

     # init multiview model
-    # logger.info('==> Loading multiview diffusion model ...')
-    # multiview_device = config_['multiview'].get('device', 'cpu')
-    # multiview_pipeline = DiffusionPipeline.from_pretrained(
-    #     config_['multiview']['base_model'],
-    #     custom_pipeline=config_['multiview']['custom_pipeline'],
-    #     torch_dtype=torch.float16,
-    # )
-    # multiview_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
-    #     multiview_pipeline.scheduler.config, timestep_spacing='trailing'
-    # )
+    logger.info('==> Loading multiview diffusion model ...')
+    multiview_device = config_['multiview'].get('device', 'cpu')
+    multiview_pipeline = DiffusionPipeline.from_pretrained(
+        config_['multiview']['base_model'],
+        custom_pipeline=config_['multiview']['custom_pipeline'],
+        torch_dtype=torch.float16,
+    )
+    multiview_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
+        multiview_pipeline.scheduler.config, timestep_spacing='trailing'
+    )

-    # unet_ckpt_path = hf_hub_download(repo_id="LTT/Kiss3DGen", filename="flexgen.ckpt", repo_type="model", token=access_token)
-    # if unet_ckpt_path is not None:
-    #     state_dict = torch.load(unet_ckpt_path, map_location='cpu')
-    #     multiview_pipeline.unet.load_state_dict(state_dict, strict=True)
+    unet_ckpt_path = hf_hub_download(repo_id="LTT/Kiss3DGen", filename="flexgen.ckpt", repo_type="model", token=access_token)
+    if unet_ckpt_path is not None:
+        state_dict = torch.load(unet_ckpt_path, map_location='cpu')
+        multiview_pipeline.unet.load_state_dict(state_dict, strict=True)

-    # multiview_pipeline.to(multiview_device)
+    multiview_pipeline.to(multiview_device)
     # logger.warning(f"GPU memory allocated after load multiview model on {multiview_device}: {torch.cuda.memory_allocated(device=multiview_device) / 1024**3} GB")
     multiview_pipeline = None
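
For reference, the first hunk only swaps which file is fetched: hf_hub_download resolves a single file from the Hub into the local cache and returns its path, which the wrapper then hands to the Flux LoRA loader. A minimal sketch of the new call, assuming the token is supplied via the HF_TOKEN environment variable rather than the wrapper's access_token variable:

import os
from huggingface_hub import hf_hub_download

# Download (or reuse from cache) the LoRA weights this commit switches to:
# rgb_normal.safetensors instead of the previous rgb_normal_large.safetensors.
flux_lora_pth = hf_hub_download(
    repo_id="LTT/Kiss3DGen",
    filename="rgb_normal.safetensors",
    repo_type="model",
    token=os.environ.get("HF_TOKEN"),  # assumption: read token supplied via env
)
print(flux_lora_pth)  # local cache path, passed on to the Flux LoRA loader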
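
The second hunk re-enables the multiview diffusion model: it builds a diffusers DiffusionPipeline from the configured base model and custom pipeline, switches to EulerAncestralDiscreteScheduler with 'trailing' timestep spacing, and loads the flexgen.ckpt UNet weights from the same repo. Note that the unchanged context line multiview_pipeline = None still follows the re-enabled block, so as committed the freshly loaded pipeline appears to be discarded. A self-contained sketch of the same loading pattern, with hypothetical placeholders standing in for the config_['multiview'] entries:

import torch
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
from huggingface_hub import hf_hub_download

# Hypothetical stand-ins for config_['multiview']['base_model'] and
# config_['multiview']['custom_pipeline']; the real values come from the
# wrapper's config file.
base_model = "sudo-ai/zero123plus-v1.2"
custom_pipeline_path = "./custom_pipelines/multiview"

multiview_pipeline = DiffusionPipeline.from_pretrained(
    base_model,
    custom_pipeline=custom_pipeline_path,
    torch_dtype=torch.float16,
)
# Same scheduler swap as the diff: ancestral Euler with 'trailing' spacing.
multiview_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
    multiview_pipeline.scheduler.config, timestep_spacing='trailing'
)

# Overwrite the UNet with the fine-tuned checkpoint from the Kiss3DGen repo.
unet_ckpt_path = hf_hub_download(
    repo_id="LTT/Kiss3DGen", filename="flexgen.ckpt", repo_type="model"
)
state_dict = torch.load(unet_ckpt_path, map_location='cpu')
multiview_pipeline.unet.load_state_dict(state_dict, strict=True)

multiview_pipeline.to('cuda')  # the wrapper reads the device from its config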