Update app.py
app.py
CHANGED
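This change introduces explicit device handles, device_0 = "cuda:0" and device_1 = "cuda:1", and splits the app across the two GPUs: the FLUX pipeline is pinned to device_0, while the LRM model, the ISOMER camera and weight tensors, and the reconstruction inputs (input_cameras, vertices, faces) are placed on device_1.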
@@ -82,22 +82,23 @@ from huggingface_hub import hf_hub_download
 
 from utils.tool import NormalTransfer, get_background, get_render_cameras_video, load_mipmap, render_frames
 
-
+device_0 = "cuda:0"
+device_1 = "cuda:1"
 resolution = 512
 save_dir = "./outputs"
 normal_transfer = NormalTransfer()
-isomer_azimuths = torch.from_numpy(np.array([0, 90, 180, 270])).float().to(
-isomer_elevations = torch.from_numpy(np.array([5, 5, 5, 5])).float().to(
+isomer_azimuths = torch.from_numpy(np.array([0, 90, 180, 270])).float().to(device_1)
+isomer_elevations = torch.from_numpy(np.array([5, 5, 5, 5])).float().to(device_1)
 isomer_radius = 4.5
-isomer_geo_weights = torch.from_numpy(np.array([1, 0.9, 1, 0.9])).float().to(
-isomer_color_weights = torch.from_numpy(np.array([1, 0.5, 1, 0.5])).float().to(
+isomer_geo_weights = torch.from_numpy(np.array([1, 0.9, 1, 0.9])).float().to(device_1)
+isomer_color_weights = torch.from_numpy(np.array([1, 0.5, 1, 0.5])).float().to(device_1)
 
 # model initialization and loading
 # flux
 flux_pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, token=access_token).to(dtype=torch.bfloat16)
 flux_lora_ckpt_path = hf_hub_download(repo_id="LTT/xxx-ckpt", filename="rgb_normal_large.safetensors", repo_type="model")
 flux_pipe.load_lora_weights(flux_lora_ckpt_path)
-flux_pipe.to(device=
+flux_pipe.to(device=device_0, dtype=torch.bfloat16)
 
 
 # lrm
@@ -109,11 +110,11 @@ model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt",
 state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
 state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
 model.load_state_dict(state_dict, strict=True)
-model = model.to(
+model = model.to(device_1)
 
 @spaces.GPU
 def lrm_reconstructions(image, input_cameras, save_path=None, name="temp", export_texmap=False, if_save_video=False):
-    images = image.unsqueeze(0).to(
+    images = image.unsqueeze(0).to(device_1)
     images = v2.functional.resize(images, 512, interpolation=3, antialias=True).clamp(0, 1)
     # breakpoint()
     with torch.no_grad():
@@ -225,7 +226,7 @@ def reconstruct_3d_model(images, prompt):
     normal_multi_view = images[4:, :3, :, :]
     multi_view_mask = get_background(normal_multi_view)
     rgb_multi_view = rgb_multi_view * rgb_multi_view + (1-multi_view_mask)
-    input_cameras = get_flux_input_cameras(batch_size=1, radius=4.2, fov=30).to(
+    input_cameras = get_flux_input_cameras(batch_size=1, radius=4.2, fov=30).to(device_1)
     vertices, faces = lrm_reconstructions(rgb_multi_view, input_cameras, save_path=save_dir_path, name='lrm', export_texmap=False, if_save_video=False)
     # local normal to global normal
 
@@ -235,8 +236,8 @@ def reconstruct_3d_model(images, prompt):
     global_normal = global_normal.permute(0,2,3,1)
     rgb_multi_view = rgb_multi_view.permute(0,2,3,1)
     multi_view_mask = multi_view_mask.permute(0,2,3,1).squeeze(-1)
-    vertices = torch.from_numpy(vertices).to(
-    faces = torch.from_numpy(faces).to(
+    vertices = torch.from_numpy(vertices).to(device_1)
+    faces = torch.from_numpy(faces).to(device_1)
     vertices = vertices @ rotate_x(np.pi / 2, device=vertices.device)[:3, :3]
     vertices = vertices @ rotate_y(np.pi / 2, device=vertices.device)[:3, :3]
 
@@ -283,7 +284,7 @@ def reconstruct_3d_model(images, prompt):
 @spaces.GPU
 def gradio_pipeline(prompt, seed):
     global model
-    model.init_flexicubes_geometry(
+    model.init_flexicubes_geometry(device_1, fovy=50.0)
     model = model.eval()
     # generate multi-view images
     rgb_normal_grid = generate_multi_view_images(prompt, seed)
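The pattern behind these edits: once one stage of the pipeline lives on cuda:0 and the next on cuda:1, every tensor that crosses the stage boundary needs an explicit .to() transfer, which is exactly what the diff does for input_cameras, vertices, and faces. Below is a minimal sketch of that split; stage_a and stage_b are hypothetical stand-ins for the FLUX and LRM stages, not code from this app, and the sketch falls back to CPU so it also runs on machines with fewer than two GPUs.

import torch

# Pick two devices, degrading gracefully when fewer GPUs are available.
device_0 = torch.device("cuda:0") if torch.cuda.device_count() >= 1 else torch.device("cpu")
device_1 = torch.device("cuda:1") if torch.cuda.device_count() >= 2 else device_0

# Hypothetical stand-ins for the two pipeline stages.
stage_a = torch.nn.Linear(8, 8).to(device_0)  # e.g. image generation on device_0
stage_b = torch.nn.Linear(8, 8).to(device_1)  # e.g. reconstruction on device_1

x = torch.randn(1, 8, device=device_0)
h = stage_a(x)       # runs on device_0
h = h.to(device_1)   # explicit hop between GPUs; omitting it raises a device-mismatch error
y = stage_b(h)       # runs on device_1

The call model.init_flexicubes_geometry(device_1, fovy=50.0) follows the same logic: the FlexiCubes geometry is built on the device where the LRM weights already live.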