LTT commited on
Commit
b78fa8b
·
verified ·
1 Parent(s): 6422901

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -94,12 +94,11 @@ isomer_color_weights = torch.from_numpy(np.array([1, 0.5, 1, 0.5])).float().to(d
94
 
95
  # model initialization and loading
96
  # flux
97
- flux_pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, token=access_token).to(device=device, dtype=torch.bfloat16)
98
  flux_lora_ckpt_path = hf_hub_download(repo_id="LTT/xxx-ckpt", filename="rgb_normal_large.safetensors", repo_type="model")
99
  flux_pipe.load_lora_weights(flux_lora_ckpt_path)
100
 
101
- flux_pipe.to(device=device, dtype=torch.bfloat16)
102
- generator = torch.Generator(device=device).manual_seed(10)
103
 
104
  # lrm
105
  config = OmegaConf.load("./models/lrm/config/PRM_inference.yaml")
@@ -111,9 +110,6 @@ state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
111
  state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
112
  model.load_state_dict(state_dict, strict=True)
113
 
114
- model = model.to(device)
115
- model.init_flexicubes_geometry(device, fovy=50.0)
116
- model = model.eval()
117
 
118
  @spaces.GPU
119
  def lrm_reconstructions(image, input_cameras, save_path=None, name="temp", export_texmap=False, if_save_video=False):
@@ -284,6 +280,10 @@ def reconstruct_3d_model(images, prompt):
284
 
285
  # Gradio 接口函数
286
  def gradio_pipeline(prompt, seed):
 
 
 
 
287
  # 生成多视图图像
288
  rgb_normal_grid = generate_multi_view_images(prompt, seed)
289
  image_preview = Image.fromarray((rgb_normal_grid * 255).astype(np.uint8))
 
94
 
95
  # model initialization and loading
96
  # flux
97
+ flux_pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, token=access_token).to(dtype=torch.bfloat16)
98
  flux_lora_ckpt_path = hf_hub_download(repo_id="LTT/xxx-ckpt", filename="rgb_normal_large.safetensors", repo_type="model")
99
  flux_pipe.load_lora_weights(flux_lora_ckpt_path)
100
 
101
+
 
102
 
103
  # lrm
104
  config = OmegaConf.load("./models/lrm/config/PRM_inference.yaml")
 
110
  state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
111
  model.load_state_dict(state_dict, strict=True)
112
 
 
 
 
113
 
114
  @spaces.GPU
115
  def lrm_reconstructions(image, input_cameras, save_path=None, name="temp", export_texmap=False, if_save_video=False):
 
280
 
281
  # Gradio 接口函数
282
  def gradio_pipeline(prompt, seed):
283
+ flux_pipe.to(device=device, dtype=torch.bfloat16)
284
+ model = model.to(device)
285
+ model.init_flexicubes_geometry(device, fovy=50.0)
286
+ model = model.eval()
287
  # 生成多视图图像
288
  rgb_normal_grid = generate_multi_view_images(prompt, seed)
289
  image_preview = Image.fromarray((rgb_normal_grid * 255).astype(np.uint8))