LTT committed on
Commit 3b39c63 · verified · 1 Parent(s): 048c43e

Update app.py

Files changed (1): app.py (+13 −12)
app.py CHANGED
@@ -82,22 +82,23 @@ from huggingface_hub import hf_hub_download
 
 from utils.tool import NormalTransfer, get_background, get_render_cameras_video, load_mipmap, render_frames
 
-device = "cuda"
+device_0 = "cuda:0"
+device_1 = "cuda:1"
 resolution = 512
 save_dir = "./outputs"
 normal_transfer = NormalTransfer()
-isomer_azimuths = torch.from_numpy(np.array([0, 90, 180, 270])).float().to(device)
-isomer_elevations = torch.from_numpy(np.array([5, 5, 5, 5])).float().to(device)
+isomer_azimuths = torch.from_numpy(np.array([0, 90, 180, 270])).float().to(device_1)
+isomer_elevations = torch.from_numpy(np.array([5, 5, 5, 5])).float().to(device_1)
 isomer_radius = 4.5
-isomer_geo_weights = torch.from_numpy(np.array([1, 0.9, 1, 0.9])).float().to(device)
-isomer_color_weights = torch.from_numpy(np.array([1, 0.5, 1, 0.5])).float().to(device)
+isomer_geo_weights = torch.from_numpy(np.array([1, 0.9, 1, 0.9])).float().to(device_1)
+isomer_color_weights = torch.from_numpy(np.array([1, 0.5, 1, 0.5])).float().to(device_1)
 
 # model initialization and loading
 # flux
 flux_pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, token=access_token).to(dtype=torch.bfloat16)
 flux_lora_ckpt_path = hf_hub_download(repo_id="LTT/xxx-ckpt", filename="rgb_normal_large.safetensors", repo_type="model")
 flux_pipe.load_lora_weights(flux_lora_ckpt_path)
-flux_pipe.to(device=device, dtype=torch.bfloat16)
+flux_pipe.to(device=device_0, dtype=torch.bfloat16)
 
 
 # lrm
@@ -109,11 +110,11 @@ model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt",
 state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
 state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
 model.load_state_dict(state_dict, strict=True)
-model = model.to(device)
+model = model.to(device_1)
 
 @spaces.GPU
 def lrm_reconstructions(image, input_cameras, save_path=None, name="temp", export_texmap=False, if_save_video=False):
-    images = image.unsqueeze(0).to(device)
+    images = image.unsqueeze(0).to(device_1)
     images = v2.functional.resize(images, 512, interpolation=3, antialias=True).clamp(0, 1)
     # breakpoint()
     with torch.no_grad():
@@ -225,7 +226,7 @@ def reconstruct_3d_model(images, prompt):
     normal_multi_view = images[4:, :3, :, :]
     multi_view_mask = get_background(normal_multi_view)
     rgb_multi_view = rgb_multi_view * rgb_multi_view + (1-multi_view_mask)
-    input_cameras = get_flux_input_cameras(batch_size=1, radius=4.2, fov=30).to(device)
+    input_cameras = get_flux_input_cameras(batch_size=1, radius=4.2, fov=30).to(device_1)
     vertices, faces = lrm_reconstructions(rgb_multi_view, input_cameras, save_path=save_dir_path, name='lrm', export_texmap=False, if_save_video=False)
     # local normal to global normal
 
@@ -235,8 +236,8 @@ def reconstruct_3d_model(images, prompt):
     global_normal = global_normal.permute(0,2,3,1)
     rgb_multi_view = rgb_multi_view.permute(0,2,3,1)
     multi_view_mask = multi_view_mask.permute(0,2,3,1).squeeze(-1)
-    vertices = torch.from_numpy(vertices).to(device)
-    faces = torch.from_numpy(faces).to(device)
+    vertices = torch.from_numpy(vertices).to(device_1)
+    faces = torch.from_numpy(faces).to(device_1)
     vertices = vertices @ rotate_x(np.pi / 2, device=vertices.device)[:3, :3]
     vertices = vertices @ rotate_y(np.pi / 2, device=vertices.device)[:3, :3]
 
@@ -283,7 +284,7 @@ def reconstruct_3d_model(images, prompt):
 @spaces.GPU
 def gradio_pipeline(prompt, seed):
     global model
-    model.init_flexicubes_geometry(device, fovy=50.0)
+    model.init_flexicubes_geometry(device_1, fovy=50.0)
     model = model.eval()
     # generate multi-view images
     rgb_normal_grid = generate_multi_view_images(prompt, seed)
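
In sum, this commit splits the pipeline across two GPUs: the FLUX multi-view generator runs on cuda:0, while the LRM reconstructor, the ISOMER weight tensors, and the flexicubes geometry live on cuda:1. Since PyTorch requires all operands of an operation to sit on the same device, every tensor crossing the stage boundary needs an explicit transfer; in the committed code this happens through the `.to(device_1)` calls inside `lrm_reconstructions` and `reconstruct_3d_model`. Below is a minimal sketch of that handoff, assuming a two-GPU machine; the helper `to_lrm_device` and the random image batch are illustrative, not code from this repo.

import torch

device_0 = "cuda:0"  # FLUX: multi-view RGB/normal generation
device_1 = "cuda:1"  # LRM reconstruction + ISOMER tensors

def to_lrm_device(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: move a stage-0 output onto the LRM GPU.
    # A plain .to() copies across devices and is a no-op if the
    # tensor already lives on device_1.
    return t.to(device_1)

# Illustrative handoff: an image batch produced by the FLUX stage
# on cuda:0 must be moved before the LRM stage consumes it.
images = torch.randn(8, 3, 512, 512, device=device_0)
images = to_lrm_device(images)
assert images.device == torch.device(device_1)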