LTT commited on
Commit
cf64ae4
·
verified ·
1 Parent(s): 299b7d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -11
app.py CHANGED
@@ -108,16 +108,16 @@ isomer_color_weights = torch.from_numpy(np.array([1, 0.5, 1, 0.5])).float().to(d
108
 
109
 
110
  # lrm
111
- # config = OmegaConf.load("./models/lrm/config/PRM_inference.yaml")
112
- # model_config = config.model_config
113
- # infer_config = config.infer_config
114
- # model = instantiate_from_config(model_config)
115
- # model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt", repo_type="model")
116
- # state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
117
- # state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
118
- # model.load_state_dict(state_dict, strict=True)
119
- # model = model.to(device_1)
120
- # torch.cuda.empty_cache()
121
  @spaces.GPU
122
  def lrm_reconstructions(image, input_cameras, save_path=None, name="temp", export_texmap=False, if_save_video=False):
123
  images = image.unsqueeze(0).to(device_1)
@@ -307,7 +307,7 @@ def reconstruct_3d_model(images, prompt):
307
  def gradio_pipeline(prompt, seed):
308
  # 生成多视图图像
309
  # rgb_normal_grid = generate_multi_view_images(prompt, seed)
310
- # rgb_normal_grid = np.load("rgb_normal_grid.npy")
311
  image_preview = Image.fromarray((rgb_normal_grid[0] * 255).astype(np.uint8))
312
 
313
  # 3d reconstruction
 
108
 
109
 
110
  # lrm
111
+ config = OmegaConf.load("./models/lrm/config/PRM_inference.yaml")
112
+ model_config = config.model_config
113
+ infer_config = config.infer_config
114
+ model = instantiate_from_config(model_config)
115
+ model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt", repo_type="model")
116
+ state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
117
+ state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
118
+ model.load_state_dict(state_dict, strict=True)
119
+ model = model.to(device_1)
120
+ torch.cuda.empty_cache()
121
  @spaces.GPU
122
  def lrm_reconstructions(image, input_cameras, save_path=None, name="temp", export_texmap=False, if_save_video=False):
123
  images = image.unsqueeze(0).to(device_1)
 
307
  def gradio_pipeline(prompt, seed):
308
  # 生成多视图图像
309
  # rgb_normal_grid = generate_multi_view_images(prompt, seed)
310
+ rgb_normal_grid = np.load("rgb_normal_grid.npy")
311
  image_preview = Image.fromarray((rgb_normal_grid[0] * 255).astype(np.uint8))
312
 
313
  # 3d reconstruction