Update app.py
Browse files
app.py
CHANGED
@@ -108,16 +108,16 @@ isomer_color_weights = torch.from_numpy(np.array([1, 0.5, 1, 0.5])).float().to(d
|
|
108 |
|
109 |
|
110 |
# lrm
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
@spaces.GPU
|
122 |
def lrm_reconstructions(image, input_cameras, save_path=None, name="temp", export_texmap=False, if_save_video=False):
|
123 |
images = image.unsqueeze(0).to(device_1)
|
@@ -307,7 +307,7 @@ def reconstruct_3d_model(images, prompt):
|
|
307 |
def gradio_pipeline(prompt, seed):
|
308 |
# 生成多视图图像
|
309 |
# rgb_normal_grid = generate_multi_view_images(prompt, seed)
|
310 |
-
|
311 |
image_preview = Image.fromarray((rgb_normal_grid[0] * 255).astype(np.uint8))
|
312 |
|
313 |
# 3d reconstruction
|
|
|
108 |
|
109 |
|
110 |
# lrm
|
111 |
+
config = OmegaConf.load("./models/lrm/config/PRM_inference.yaml")
|
112 |
+
model_config = config.model_config
|
113 |
+
infer_config = config.infer_config
|
114 |
+
model = instantiate_from_config(model_config)
|
115 |
+
model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt", repo_type="model")
|
116 |
+
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
|
117 |
+
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
|
118 |
+
model.load_state_dict(state_dict, strict=True)
|
119 |
+
model = model.to(device_1)
|
120 |
+
torch.cuda.empty_cache()
|
121 |
@spaces.GPU
|
122 |
def lrm_reconstructions(image, input_cameras, save_path=None, name="temp", export_texmap=False, if_save_video=False):
|
123 |
images = image.unsqueeze(0).to(device_1)
|
|
|
307 |
def gradio_pipeline(prompt, seed):
|
308 |
# 生成多视图图像
|
309 |
# rgb_normal_grid = generate_multi_view_images(prompt, seed)
|
310 |
+
rgb_normal_grid = np.load("rgb_normal_grid.npy")
|
311 |
image_preview = Image.fromarray((rgb_normal_grid[0] * 255).astype(np.uint8))
|
312 |
|
313 |
# 3d reconstruction
|