ipekoztas committed on
Commit b7c5eaf · 1 Parent(s): 7860b91

Code upload.

README.md CHANGED
@@ -1,5 +1,5 @@
 ---
-title: InstantMesh
+title: 3D Stylization LRM
 emoji: 📚
 colorFrom: indigo
 colorTo: green
app.py CHANGED
@@ -30,6 +30,7 @@ from huggingface_hub import hf_hub_download
 import gradio as gr


+
 def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False):
     """
     Get the rendering camera parameters.
@@ -90,7 +91,7 @@ if cuda_path:
 else:
     print("CUDA installation not found")

-config_path = 'configs/instant-mesh-large.yaml'
+config_path = 'configs/instant-nerf-large.yaml'
 config = OmegaConf.load(config_path)
 config_name = os.path.basename(config_path).replace('.yaml', '')
 model_config = config.model_config
@@ -120,7 +121,7 @@ pipeline = pipeline.to(device)

 # load reconstruction model
 print('Loading reconstruction model ...')
-model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="instant_mesh_large.ckpt", repo_type="model")
+model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="instant_nerf_large.ckpt", repo_type="model")
 model = instantiate_from_config(model_config)
 state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
 state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k}
@@ -134,6 +135,10 @@ print('Loading Finished!')
 def check_input_image(input_image):
     if input_image is None:
         raise gr.Error("No image uploaded!")
+
+def check_style_image(style_image):
+    if style_image is None:
+        raise gr.Error("No style image uploaded!")


 def preprocess(input_image, do_remove_background):
@@ -158,7 +163,7 @@ def generate_mvs(input_image, sample_steps, sample_seed):
         num_inference_steps=sample_steps
     ).images[0]

-    show_image = np.asarray(z123_image, dtype=np.uint8)
+    show_image = np.asarray(z123_image, dtype=np.uint8).copy()
     show_image = torch.from_numpy(show_image)     # (960, 640, 3)
     show_image = rearrange(show_image, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
     show_image = rearrange(show_image, '(n m) h w c -> (n h) (m w) c', n=2, m=3)
@@ -166,66 +171,53 @@ def generate_mvs(input_image, sample_steps, sample_seed):

     return z123_image, show_image

-
 @spaces.GPU
-def make3d(images):
-
+def make3d(mv_images, style_image, alpha, style_layers):
+    """
+    mv_images: single multi-view image (pil or numpy)
+    style_image: PIL image
+    alpha: float
+    style_layers: int
+    """
     global model
+
+    # Save the uploaded style image to a temporary file, so the model can read it from disk
+    style_path = tempfile.NamedTemporaryFile(suffix=".png", delete=False).name
+    style_image.save(style_path)
+
     if IS_FLEXICUBES:
         model.init_flexicubes_geometry(device, use_renderer=False)
-    model = model.eval()

-    images = np.asarray(images, dtype=np.float32) / 255.0
-    images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float()     # (3, 960, 640)
-    images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2)        # (6, 3, 320, 320)
+    images = np.asarray(mv_images, dtype=np.float32) / 255.0
+    images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float()     # (3, 960, 640)
+    images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2)        # (6, 3, 320, 320)

     input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to(device)
-    render_cameras = get_render_cameras(batch_size=1, radius=2.5, is_flexicubes=IS_FLEXICUBES).to(device)
+    render_cameras = get_render_cameras(
+        batch_size=1,
+        radius=2.5,
+        is_flexicubes=IS_FLEXICUBES
+    ).to(device)

     images = images.unsqueeze(0).to(device)
     images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)

-    mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name
-    print(mesh_fpath)
+    mesh_fpath = tempfile.NamedTemporaryFile(suffix=".obj", delete=False).name
     mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
     mesh_dirname = os.path.dirname(mesh_fpath)
-    video_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.mp4")
     mesh_glb_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.glb")

     with torch.no_grad():
-        # get triplane
-        planes = model.forward_planes(images, input_cameras)
-
-        # # get video
-        # chunk_size = 20 if IS_FLEXICUBES else 1
-        # render_size = 384
-
-        # frames = []
-        # for i in tqdm(range(0, render_cameras.shape[1], chunk_size)):
-        #     if IS_FLEXICUBES:
-        #         frame = model.forward_geometry(
-        #             planes,
-        #             render_cameras[:, i:i+chunk_size],
-        #             render_size=render_size,
-        #         )['img']
-        #     else:
-        #         frame = model.synthesizer(
-        #             planes,
-        #             cameras=render_cameras[:, i:i+chunk_size],
-        #             render_size=render_size,
-        #         )['images_rgb']
-        #     frames.append(frame)
-        # frames = torch.cat(frames, dim=1)
-
-        # images_to_video(
-        #     frames[0],
-        #     video_fpath,
-        #     fps=30,
-        # )
-
-        # print(f"Video saved to {video_fpath}")
-
-        # get mesh
+        # get triplane, now passing style_path, alpha, style_layers
+        planes = model.forward_planes(
+            images,
+            input_cameras,
+            style_path,
+            float(alpha),
+            int(style_layers),
+        )
+
+        # extract mesh
         mesh_out = model.extract_mesh(
             planes,
             use_texture_map=False,
@@ -234,52 +226,40 @@ def make3d(images):

     vertices, faces, vertex_colors = mesh_out
     vertices = vertices[:, [1, 2, 0]]
-
+
     save_glb(vertices, faces, vertex_colors, mesh_glb_fpath)
     save_obj(vertices, faces, vertex_colors, mesh_fpath)
-
     print(f"Mesh saved to {mesh_fpath}")
-
     return mesh_fpath, mesh_glb_fpath

-
 _HEADER_ = '''
-<h2><b>Official 🤗 Gradio Demo</b></h2><h2><a href='https://github.com/TencentARC/InstantMesh' target='_blank'><b>InstantMesh: Efficient 3D Mesh Generation from a Single Image with Sparse-view Large Reconstruction Models</b></a></h2>
-
-**InstantMesh** is a feed-forward framework for efficient 3D mesh generation from a single image based on the LRM/Instant3D architecture.
+<h2><b>3DStylizationLRM</b></h2>
+This demo lets you provide a content image, a style image, an alpha blending value, and the number of style layers to inject. It will generate 3D geometry stylized accordingly.

-Code: <a href='https://github.com/TencentARC/InstantMesh' target='_blank'>GitHub</a>. Techenical report: <a href='https://arxiv.org/abs/2404.07191' target='_blank'>ArXiv</a>.
-
-❗️❗️❗️**Important Notes:**
-- Our demo can export a .obj mesh with vertex colors or a .glb mesh now. If you prefer to export a .obj mesh with a **texture map**, please refer to our <a href='https://github.com/TencentARC/InstantMesh?tab=readme-ov-file#running-with-command-line' target='_blank'>Github Repo</a>.
-- The 3D mesh generation results highly depend on the quality of generated multi-view images. Please try a different **seed value** if the result is unsatisfying (Default: 42).
+❗️❗️❗️ **Notes:**
+- Content image background can be removed automatically.
+- Adjust the **Alpha** slider to control style blending strength.
+- Adjust **Style Layers** to choose how many layers of style to inject.
 '''

 _CITE_ = r"""
-If InstantMesh is helpful, please help to ⭐ the <a href='https://github.com/TencentARC/InstantMesh' target='_blank'>Github Repo</a>. Thanks! [![GitHub Stars](https://img.shields.io/github/stars/TencentARC/InstantMesh?style=social)](https://github.com/TencentARC/InstantMesh)
+If 3D Stylization LRM is helpful, please help to ⭐ the <a href='https://github.com/ipekoztas/3D-Stylization-LRM' target='_blank'>Github Repo</a>. Thanks!
 ---
 📝 **Citation**

 If you find our work useful for your research or applications, please cite using this bibtex:
 ```bibtex
-@article{xu2024instantmesh,
-  title={InstantMesh: Efficient 3D Mesh Generation from a Single Image with Sparse-view Large Reconstruction Models},
-  author={Xu, Jiale and Cheng, Weihao and Gao, Yiming and Wang, Xintao and Gao, Shenghua and Shan, Ying},
-  journal={arXiv preprint arXiv:2404.07191},
-  year={2024}
+@article{oztas20253dstylizationlargereconstruction,
+  title={3D Stylization via Large Reconstruction Model},
+  author={Ipek Oztas and Duygu Ceylan and Aysegul Dundar},
+  journal={https://arxiv.org/abs/2504.21836},
+  year={2025}
 }
 ```
-
-📋 **License**
-
-Apache-2.0 LICENSE. Please refer to the [LICENSE file](https://huggingface.co/spaces/TencentARC/InstantMesh/blob/main/LICENSE) for details.
-
-📧 **Contact**
-
-If you have any questions, feel free to open a discussion or contact us at <b>[email protected]</b>.
+📋 **License**
+Apache-2.0 LICENSE.
 """

-
 with gr.Blocks() as demo:
     gr.Markdown(_HEADER_)
     with gr.Row(variant="panel"):
@@ -294,6 +274,13 @@ with gr.Blocks() as demo:
                     type="pil",
                     elem_id="content_image",
                 )
+                # Style Image Upload
+                style_image = gr.Image(
+                    label="Style Image",
+                    image_mode="RGB",
+                    type="pil",
+                    elem_id="style_image",
+                )
                 processed_image = gr.Image(
                     label="Processed Image",
                     image_mode="RGBA",
@@ -317,6 +304,22 @@ with gr.Blocks() as demo:
                        step=5
                    )

+            with gr.Row():
+                alpha = gr.Slider(
+                    label="Alpha Value",
+                    minimum=0.0,
+                    maximum=1.0,
+                    value=0.7,
+                    step=0.01,
+                )
+                style_layers = gr.Slider(
+                    label="Style Layers",
+                    minimum=1,
+                    maximum=10,
+                    value=4,
+                    step=1,
+                )
+
            with gr.Row():
                submit = gr.Button("Generate", elem_id="generate", variant="primary")

@@ -330,6 +333,16 @@ with gr.Blocks() as demo:
                    cache_examples=False,
                    examples_per_page=16
                )
+            with gr.Row(variant="panel"):
+                gr.Examples(
+                    examples=[
+                        os.path.join("styles", img_name) for img_name in sorted(os.listdir("styles"))
+                    ],
+                    inputs=[input_image],
+                    label="Styles",
+                    cache_examples=False,
+                    examples_per_page=16
+                )

        with gr.Column():

@@ -372,19 +385,21 @@ with gr.Blocks() as demo:

    mv_images = gr.State()

+    # Chain of actions:
    submit.click(fn=check_input_image, inputs=[input_image]).success(
+        fn=check_style_image, inputs=[style_image]
+    ).success(
        fn=preprocess,
        inputs=[input_image, do_remove_background],
        outputs=[processed_image],
    ).success(
        fn=generate_mvs,
        inputs=[processed_image, sample_steps, sample_seed],
-        outputs=[mv_images, mv_show_images]
-
+        outputs=[mv_images, mv_show_images],
    ).success(
        fn=make3d,
-        inputs=[mv_images],
-        outputs=[output_model_obj, output_model_glb]
+        inputs=[mv_images, style_image, alpha, style_layers],
+        outputs=[output_model_obj, output_model_glb],
    )

 demo.launch()
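For readers skimming the diff: the new `make3d` consumes the zero123plus output as one 960×640 image laid out as a 3×2 grid of 320×320 views, and the `rearrange` call above unpacks it into a batch of six views. A minimal standalone sketch (random data standing in for the real multi-view image, not part of the commit):

```python
# Standalone sketch of the grid-unpacking step in make3d (random stand-in data).
import numpy as np
import torch
from einops import rearrange

mv_grid = np.random.randint(0, 256, size=(960, 640, 3), dtype=np.uint8)   # 3x2 grid of views

images = np.asarray(mv_grid, dtype=np.float32) / 255.0
images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float()   # (3, 960, 640)
images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2)      # (6, 3, 320, 320)

print(images.shape)   # torch.Size([6, 3, 320, 320])
```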
requirements.txt CHANGED
@@ -1,7 +1,12 @@
+pydantic==2.10.6
+gradio==4.44.1
+gradio-client==1.3.0
+huggingface-hub==0.25.2
 torch==2.1.0
 torchvision==0.16.0
 torchaudio==2.1.0
 pytorch-lightning==2.1.2
+
 einops
 omegaconf
 deepspeed
@@ -12,12 +17,15 @@ tensorboard
 PyMCubes
 trimesh
 rembg
-transformers==4.34.1
-diffusers==0.19.3
+transformers==4.39.3
+diffusers==0.27.0
+tokenizers==0.15.2
+
 bitsandbytes
 imageio[ffmpeg]
 xatlas
 plyfile
+
 xformers==0.0.22.post7
 git+https://github.com/NVlabs/nvdiffrast/
-huggingface-hub
+onnxruntime
src/data/objaverse.py CHANGED
@@ -22,7 +22,7 @@ from src.utils.train_util import instantiate_from_config
 from src.utils.camera_util import (
     FOV_to_intrinsics,
     center_looking_at_camera_pose,
-    get_surrounding_views,
+    get_circular_camera_poses,
 )


@@ -78,7 +78,7 @@ class ObjaverseData(Dataset):
         input_image_dir='rendering_random_32views',
         target_image_dir='rendering_random_32views',
         input_view_num=6,
-        target_view_num=2,
+        target_view_num=4,
         total_view_n=32,
         fov=50,
         camera_rotation=True,
@@ -99,7 +99,7 @@ class ObjaverseData(Dataset):
         paths = filtered_dict['good_objs']
         self.paths = paths

-        self.depth_scale = 4.0
+        self.depth_scale = 6.0

         total_objects = len(self.paths)
         print('============= length of dataset %d =============' % len(self.paths))
@@ -122,7 +122,6 @@ class ObjaverseData(Dataset):
         return image, alpha

     def __getitem__(self, index):
-        # load data
         while True:
             input_image_path = os.path.join(self.root_dir, self.input_image_dir, self.paths[index])
             target_image_path = os.path.join(self.root_dir, self.target_image_dir, self.paths[index])
@@ -212,7 +211,7 @@ class ObjaverseData(Dataset):

         # random scaling
         if np.random.rand() < 0.5:
-            scale = np.random.uniform(0.8, 1.0)
+            scale = np.random.uniform(0.7, 1.1)
             c2ws[:, :3, 3] *= scale
             depths *= scale

@@ -221,11 +220,11 @@ class ObjaverseData(Dataset):
         Ks = K.unsqueeze(0).repeat(self.input_view_num + self.target_view_num, 1, 1).float()

         data = {
-            'input_images': images[:self.input_view_num],           # (6, 3, H, W)
+            'input_images': images[:self.input_view_num],           # (6, 3, H, W)
             'input_alphas': alphas[:self.input_view_num],           # (6, 1, H, W)
             'input_depths': depths[:self.input_view_num],           # (6, 1, H, W)
             'input_normals': normals[:self.input_view_num],         # (6, 3, H, W)
-            'input_c2ws': c2ws_input[:self.input_view_num],         # (6, 4, 4)
+            'input_c2ws': c2ws[:self.input_view_num],               # (6, 4, 4)
             'input_Ks': Ks[:self.input_view_num],                   # (6, 3, 3)

             # lrm generator input and supervision
@@ -235,8 +234,6 @@ class ObjaverseData(Dataset):
             'target_normals': normals[self.input_view_num:],        # (V, 3, H, W)
             'target_c2ws': c2ws[self.input_view_num:],              # (V, 4, 4)
             'target_Ks': Ks[self.input_view_num:],                  # (V, 3, 3)
-
-            'depth_available': 1,
         }
         return data

@@ -245,8 +242,8 @@ class ValidationData(Dataset):
     def __init__(self,
         root_dir='objaverse/',
         input_view_num=6,
-        input_image_size=256,
-        fov=50,
+        input_image_size=320,
+        fov=30,
     ):
         self.root_dir = Path(root_dir)
         self.input_view_num = input_view_num
@@ -256,9 +253,9 @@ class ValidationData(Dataset):
         self.paths = sorted(os.listdir(self.root_dir))
         print('============= length of dataset %d =============' % len(self.paths))

-        cam_distance = 2.5
+        cam_distance = 4.0
         azimuths = np.array([30, 90, 150, 210, 270, 330])
-        elevations = np.array([30, -20, 30, -20, 30, -20])
+        elevations = np.array([20, -10, 20, -10, 20, -10])
         azimuths = np.deg2rad(azimuths)
         elevations = np.deg2rad(elevations)

@@ -272,7 +269,7 @@ class ValidationData(Dataset):
         self.c2ws = c2ws.float()
         self.Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(6, 1, 1).float()

-        render_c2ws = get_surrounding_views(M=8, radius=cam_distance)
+        render_c2ws = get_circular_camera_poses(M=8, radius=cam_distance, elevation=20.0)
         render_Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(render_c2ws.shape[0], 1, 1)
         self.render_c2ws = render_c2ws.float()
         self.render_Ks = render_Ks.float()
@@ -303,7 +300,6 @@ class ValidationData(Dataset):
         input_image_path = os.path.join(self.root_dir, self.paths[index])

         '''background color, default: white'''
-        # color = np.random.uniform(0.48, 0.52)
         bkg_color = [1.0, 1.0, 1.0]

         image_list = []
@@ -314,14 +310,14 @@ class ValidationData(Dataset):
             image_list.append(image)
             alpha_list.append(alpha)

-        images = torch.stack(image_list, dim=0).float()     # (6+V, 3, H, W)
-        alphas = torch.stack(alpha_list, dim=0).float()     # (6+V, 1, H, W)
+        images = torch.stack(image_list, dim=0).float()
+        alphas = torch.stack(alpha_list, dim=0).float()

         data = {
-            'input_images': images,         # (6, 3, H, W)
-            'input_alphas': alphas,         # (6, 1, H, W)
-            'input_c2ws': self.c2ws,        # (6, 4, 4)
-            'input_Ks': self.Ks,            # (6, 3, 3)
+            'input_images': images,
+            'input_alphas': alphas,
+            'input_c2ws': self.c2ws,
+            'input_Ks': self.Ks,

             'render_c2ws': self.render_c2ws,
             'render_Ks': self.render_Ks,
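A small sketch (an assumption about the spherical convention, not code from the repo) of how the new validation `cam_distance`, `azimuths`, and `elevations` translate into camera positions; the actual `c2ws` are built by the helpers in `src/utils/camera_util.py`:

```python
# Sketch only: spherical -> Cartesian camera positions for the new validation setup.
# The exact axis/up convention lives in src/utils/camera_util.py; this assumes
# elevation is measured from the horizontal plane.
import numpy as np
import torch

cam_distance = 4.0
azimuths = np.deg2rad(np.array([30, 90, 150, 210, 270, 330], dtype=np.float32))
elevations = np.deg2rad(np.array([20, -10, 20, -10, 20, -10], dtype=np.float32))

x = cam_distance * np.cos(elevations) * np.cos(azimuths)
y = cam_distance * np.cos(elevations) * np.sin(azimuths)
z = cam_distance * np.sin(elevations)

cam_positions = torch.from_numpy(np.stack([x, y, z], axis=-1))   # (6, 3)
print(cam_positions)
```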
src/data/objaverse_zero123plus.py ADDED
@@ -0,0 +1,124 @@
+import os
+import json
+import numpy as np
+import webdataset as wds
+import pytorch_lightning as pl
+import torch
+from torch.utils.data import Dataset
+from torch.utils.data.distributed import DistributedSampler
+from PIL import Image
+from pathlib import Path
+
+from src.utils.train_util import instantiate_from_config
+
+
+class DataModuleFromConfig(pl.LightningDataModule):
+    def __init__(
+        self,
+        batch_size=8,
+        num_workers=4,
+        train=None,
+        validation=None,
+        test=None,
+        **kwargs,
+    ):
+        super().__init__()
+
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+
+        self.dataset_configs = dict()
+        if train is not None:
+            self.dataset_configs['train'] = train
+        if validation is not None:
+            self.dataset_configs['validation'] = validation
+        if test is not None:
+            self.dataset_configs['test'] = test
+
+    def setup(self, stage):
+
+        if stage in ['fit']:
+            self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)
+        else:
+            raise NotImplementedError
+
+    def train_dataloader(self):
+
+        sampler = DistributedSampler(self.datasets['train'])
+        return wds.WebLoader(self.datasets['train'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler)
+
+    def val_dataloader(self):
+
+        sampler = DistributedSampler(self.datasets['validation'])
+        return wds.WebLoader(self.datasets['validation'], batch_size=4, num_workers=self.num_workers, shuffle=False, sampler=sampler)
+
+    def test_dataloader(self):
+
+        return wds.WebLoader(self.datasets['test'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
+
+
+class ObjaverseData(Dataset):
+    def __init__(self,
+        root_dir='objaverse/',
+        meta_fname='valid_paths.json',
+        image_dir='rendering_zero123plus',
+        validation=False,
+    ):
+        self.root_dir = Path(root_dir)
+        self.image_dir = image_dir
+
+        with open(os.path.join(root_dir, meta_fname)) as f:
+            lvis_dict = json.load(f)
+        paths = []
+        for k in lvis_dict.keys():
+            paths.extend(lvis_dict[k])
+        self.paths = paths
+
+        total_objects = len(self.paths)
+        if validation:
+            self.paths = self.paths[-16:]   # used last 16 as validation
+        else:
+            self.paths = self.paths[:-16]
+        print('============= length of dataset %d =============' % len(self.paths))
+
+    def __len__(self):
+        return len(self.paths)
+
+    def load_im(self, path, color):
+        pil_img = Image.open(path)
+
+        image = np.asarray(pil_img, dtype=np.float32) / 255.
+        alpha = image[:, :, 3:]
+        image = image[:, :, :3] * alpha + color * (1 - alpha)
+
+        image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
+        alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
+        return image, alpha
+
+    def __getitem__(self, index):
+        while True:
+            image_path = os.path.join(self.root_dir, self.image_dir, self.paths[index])
+
+            '''background color, default: white'''
+            bkg_color = [1., 1., 1.]
+
+            img_list = []
+            try:
+                for idx in range(7):
+                    img, alpha = self.load_im(os.path.join(image_path, '%03d.png' % idx), bkg_color)
+                    img_list.append(img)
+
+            except Exception as e:
+                print(e)
+                index = np.random.randint(0, len(self.paths))
+                continue
+
+            break
+
+        imgs = torch.stack(img_list, dim=0).float()
+
+        data = {
+            'cond_imgs': imgs[0],           # (3, H, W)
+            'target_imgs': imgs[1:],        # (6, 3, H, W)
+        }
+        return data
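A usage sketch for the newly added dataset (not part of the commit; it assumes a local `objaverse/` folder containing `valid_paths.json` and `rendering_zero123plus/<uid>/000.png … 006.png`, and the paths are illustrative):

```python
# Usage sketch: iterate the new zero123plus dataset.
from torch.utils.data import DataLoader
from src.data.objaverse_zero123plus import ObjaverseData

dataset = ObjaverseData(root_dir='objaverse/', meta_fname='valid_paths.json', validation=False)
loader = DataLoader(dataset, batch_size=2, num_workers=0, shuffle=True)

batch = next(iter(loader))
print(batch['cond_imgs'].shape)     # (B, 3, H, W)    conditioning view (000.png)
print(batch['target_imgs'].shape)   # (B, 6, 3, H, W) six target views (001-006.png)
```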
src/model.py CHANGED
@@ -295,16 +295,9 @@ class MVRecon(pl.LightningModule):

         params = []

-        lrm_params_fast, lrm_params_slow = [], []
-        for n, p in self.lrm_generator.named_parameters():
-            if 'adaLN_modulation' in n or 'camera_embedder' in n:
-                lrm_params_fast.append(p)
-            else:
-                lrm_params_slow.append(p)
-        params.append({"params": lrm_params_fast, "lr": lr, "weight_decay": 0.01 })
-        params.append({"params": lrm_params_slow, "lr": lr / 10.0, "weight_decay": 0.01 })
+        params.append({"params": self.lrm_generator.parameters(), "lr": lr, "weight_decay": 0.01 })

         optimizer = torch.optim.AdamW(params, lr=lr, betas=(0.90, 0.95))
-        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 3000, eta_min=lr/4)
+        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 3000, eta_min=lr/10)

         return {'optimizer': optimizer, 'lr_scheduler': scheduler}
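The optimizer change collapses the fast/slow parameter groups into a single group and lowers the scheduler floor from lr/4 to lr/10. A toy sketch (illustrative lr value, not repo code) of what the resulting schedule does:

```python
# Toy sketch of the new schedule: cosine decay from lr to lr/10, restarting every 3000 steps.
import torch

lr = 4e-4   # illustrative; the real value comes from the training config
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW(
    [{"params": [param], "lr": lr, "weight_decay": 0.01}], lr=lr, betas=(0.90, 0.95)
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 3000, eta_min=lr / 10)

for step in range(9001):
    optimizer.step()
    scheduler.step()
    if step % 1500 == 0:
        print(step, optimizer.param_groups[0]["lr"])
```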
src/model_mesh.py CHANGED
@@ -56,7 +56,7 @@ class MVRecon(pl.LightningModule):
                 if 'weight' in k:
                     sd_fc[k.replace('net.', 'net_sdf.')] = -v[0:1]
                 else:
-                    sd_fc[k.replace('net.', 'net_sdf.')] = 3.0 - v[0:1]
+                    sd_fc[k.replace('net.', 'net_sdf.')] = 10.0 - v[0:1]
                 sd_fc[k.replace('net.', 'net_rgb.')] = v[1:4]
             else:
                 sd_fc[k.replace('net.', 'net_sdf.')] = v
@@ -274,7 +274,7 @@ class MVRecon(pl.LightningModule):

         loss_reg = sdf_reg_loss_entropy + flexicubes_surface_reg + flexicubes_weights_reg

-        loss = loss_mse + loss_lpips + loss_mask + loss_normal + loss_reg
+        loss = loss_mse + loss_lpips + loss_mask + loss_depth + loss_normal + loss_reg

         prefix = 'train'
         loss_dict = {}
src/models/decoder/transformer.py CHANGED
@@ -53,14 +53,37 @@ class BasicTransformerBlock(nn.Module):
             nn.Dropout(mlp_drop),
         )

-    def forward(self, x, cond):
-        # x: [N, L, D]
-        # cond: [N, L_cond, D_cond]
-        x = x + self.cross_attn(self.norm1(x), cond, cond)[0]
-        before_sa = self.norm2(x)
-        x = x + self.self_attn(before_sa, before_sa, before_sa)[0]
-        x = x + self.mlp(self.norm3(x))
-        return x
+    def forward(self, x, cond, i, alpha, content_layers):
+        # x: [N, L, D] or [x1, x2]
+        # cond: [content_feats] or [content_feats, style_feats]
+        if len(cond) == 2:
+            # Style injection mode
+            x1, x2 = x[0], x[1]
+            content, style = cond[0], cond[1]
+            if i <= content_layers:
+                x1 = x1 + self.cross_attn(self.norm1(x1), content, content)[0]
+            else:
+                x1 = x1 + (1-alpha)*self.cross_attn(self.norm1(x1), content, content)[0] + (alpha)*self.cross_attn(self.norm1(x1), style, style)[0]
+            x2 = x2 + self.cross_attn(self.norm1(x2), style, style)[0]
+
+            before_sa1 = self.norm2(x1)
+            before_sa2 = self.norm2(x2)
+            x1 = x1 + self.self_attn(before_sa1, before_sa1, before_sa1)[0]
+            x2 = x2 + self.self_attn(before_sa2, before_sa2, before_sa2)[0]
+
+            x1 = x1 + self.mlp(self.norm3(x1))
+            x2 = x2 + self.mlp(self.norm3(x2))
+
+            return [x1, x2]
+        else:
+            # No style, only content
+            x1 = x[0] if isinstance(x, list) else x
+            content = cond[0]
+            x1 = x1 + self.cross_attn(self.norm1(x1), content, content)[0]
+            before_sa1 = self.norm2(x1)
+            x1 = x1 + self.self_attn(before_sa1, before_sa1, before_sa1)[0]
+            x1 = x1 + self.mlp(self.norm3(x1))
+            return [x1]


 class TriplaneTransformer(nn.Module):
@@ -98,18 +121,34 @@ class TriplaneTransformer(nn.Module):
         ])
         self.norm = nn.LayerNorm(inner_dim, eps=eps)
         self.deconv = nn.ConvTranspose2d(inner_dim, triplane_dim, kernel_size=2, stride=2, padding=0)
+        self.num_layers = num_layers

-    def forward(self, image_feats):
-        # image_feats: [N, L_cond, D_cond]
-
-        N = image_feats.shape[0]
+    def forward(self, image_feats, alpha, style_layers):
+        # image_feats: [content_feats] or [content_feats, style_feats]
+        N = image_feats[0].shape[0]
         H = W = self.triplane_low_res
         L = 3 * H * W
-
+        content_layers = self.num_layers - style_layers
         x = self.pos_embed.repeat(N, 1, 1)  # [N, L, D]
-        for layer in self.layers:
-            x = layer(x, image_feats)
-        x = self.norm(x)
+        i = 1
+        if len(image_feats) == 2:
+            # Style injection mode
+            for layer in self.layers:
+                if i == 1:
+                    x = layer([x, x], image_feats, i, alpha, content_layers)
+                else:
+                    x = layer(x, image_feats, i, alpha, content_layers)
+                i += 1
+            x = self.norm(x[0])
+        else:
+            # No style, only content
+            for layer in self.layers:
+                if i == 1:
+                    x = layer([x], image_feats, i, alpha, content_layers)
+                else:
+                    x = layer(x, image_feats, i, alpha, content_layers)
+                i += 1
+            x = self.norm(x[0])

         # separate each plane and apply deconv
         x = x.view(N, 3, H, W, -1)
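The core of this file's change is the blended cross-attention in `BasicTransformerBlock`: the first `content_layers` layers attend to content features only, and the remaining `style_layers` layers mix content and style cross-attention with weight `alpha`. A minimal standalone sketch of that rule (illustrative shapes and module name; the real block also keeps a parallel style stream `x2`, self-attention on both streams, and an MLP):

```python
# Minimal sketch of the blended cross-attention rule (not the repo's module).
import torch
import torch.nn as nn

class BlendedCrossAttention(nn.Module):
    def __init__(self, dim=64, num_heads=4):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, x, content, style, i, alpha, content_layers):
        if i <= content_layers or style is None:
            # early layers: content-only cross-attention
            return x + self.cross_attn(self.norm(x), content, content)[0]
        # later layers: blend content and style cross-attention with weight alpha
        return (
            x
            + (1 - alpha) * self.cross_attn(self.norm(x), content, content)[0]
            + alpha * self.cross_attn(self.norm(x), style, style)[0]
        )

x = torch.randn(1, 3 * 32 * 32, 64)     # triplane tokens
content = torch.randn(1, 6 * 401, 64)   # content (multi-view) features
style = torch.randn(1, 401, 64)         # style image features
block = BlendedCrossAttention()
print(block(x, content, style, i=5, alpha=0.7, content_layers=4).shape)
```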
src/models/lrm.py CHANGED
@@ -18,6 +18,7 @@ import torch.nn as nn
 import mcubes
 import nvdiffrast.torch as dr
 from einops import rearrange, repeat
+from PIL import Image

 from .encoder.dino_wrapper import DinoWrapper
 from .decoder.transformer import TriplaneTransformer
@@ -65,19 +66,46 @@ class InstantNeRF(nn.Module):
             samples_per_ray=rendering_samples_per_ray,
         )

-    def forward_planes(self, images, cameras):
+    def forward_planes(self, images, cameras, style, alpha, style_layers):
         # images: [B, V, C_img, H_img, W_img]
         # cameras: [B, V, 16]
         B = images.shape[0]
+        style_feats = None
+
+        if style is not None:
+            style_img = np.asarray(Image.open(style), dtype=np.float32) / 255.0
+            if style_img.ndim == 2:  # Handle depth image
+                style_img = np.stack([style_img] * 3, axis=-1)
+            style_img = torch.from_numpy(style_img).permute(2, 0, 1).contiguous().float()
+            style_img = torch.nn.functional.interpolate(
+                style_img.unsqueeze(0), size=(320, 320), mode='bilinear', align_corners=False
+            )  # Shape: [1, 3, 320, 320]
+            style_img = style_img.unsqueeze(1)
+            style_img = style_img.to(images.device)  # torch.Size([1, 1, 3, 320, 320])
+            if style_img.shape[2] == 4:  # Check if there are 4 channels
+                style_img = style_img[:, :, :3, :, :]
+            style_feats = self.encoder(style_img, cameras[:, :1, :])  # torch.Size([6, 401, 768]) cameras:torch.Size([1, 6, 16])
+            style_feats = rearrange(style_feats, '(b v) l d -> b (v l) d', b=B)

         # encode images
         image_feats = self.encoder(images, cameras)
         image_feats = rearrange(image_feats, '(b v) l d -> b (v l) d', b=B)
-
+
         # transformer generating planes
-        planes = self.transformer(image_feats)
+        if style_feats is not None:
+            planes = self.transformer([image_feats, style_feats], alpha, style_layers)
+        else:
+            planes = self.transformer([image_feats], alpha, style_layers)

         return planes
+
+    def forward_synthesizer(self, planes, render_cameras, render_size: int):
+        render_results = self.synthesizer(
+            planes,
+            render_cameras,
+            render_size,
+        )
+        return render_results

     def forward(self, images, cameras, render_cameras, render_size: int):
         # images: [B, V, C_img, H_img, W_img]
@@ -125,7 +153,12 @@ class InstantNeRF(nn.Module):
             sample_tex_pose_list.append(tex_pos_one_shape)
         tex_pos = torch.cat(sample_tex_pose_list, dim=0)

-        tex_feat = self.synthesizer.forward_points(planes, tex_pos)['rgb']
+        tex_feat = torch.utils.checkpoint.checkpoint(
+            self.synthesizer.forward_points,
+            planes,
+            tex_pos,
+            use_reentrant=False,
+        )['rgb']

         if hard_mask is not None:
             final_tex_feat = torch.zeros(
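Besides the style path, the texture-prediction step now wraps `synthesizer.forward_points` in gradient checkpointing, which recomputes activations during the backward pass instead of storing them. A toy example of the same pattern (illustrative module, not the synthesizer):

```python
# Toy sketch of torch.utils.checkpoint.checkpoint as used above: trade compute for memory.
import torch
import torch.utils.checkpoint
import torch.nn as nn

net = nn.Sequential(nn.Linear(64, 256), nn.ReLU(), nn.Linear(256, 3))

points = torch.randn(4096, 64, requires_grad=True)
rgb = torch.utils.checkpoint.checkpoint(net, points, use_reentrant=False)  # activations recomputed in backward
rgb.sum().backward()
print(points.grad.shape)   # torch.Size([4096, 64])
```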
src/models/lrm_mesh.py CHANGED
@@ -17,6 +17,7 @@ import torch
 import torch.nn as nn
 import nvdiffrast.torch as dr
 from einops import rearrange, repeat
+from PIL import Image

 from .encoder.dino_wrapper import DinoWrapper
 from .decoder.transformer import TriplaneTransformer
@@ -74,12 +75,9 @@ class InstantMesh(nn.Module):
             samples_per_ray=rendering_samples_per_ray,
         )

-    def init_flexicubes_geometry(self, device, fovy=50.0, use_renderer=True):
+    def init_flexicubes_geometry(self, device, fovy=50.0):
         camera = PerspectiveCamera(fovy=fovy, device=device)
-        if use_renderer:
-            renderer = NeuralRender(device, camera_model=camera)
-        else:
-            renderer = None
+        renderer = NeuralRender(device, camera_model=camera)
         self.geometry = FlexiCubesGeometry(
             grid_res=self.grid_res,
             scale=self.grid_scale,
@@ -88,17 +86,36 @@ class InstantMesh(nn.Module):
             device=device,
         )

-    def forward_planes(self, images, cameras):
+    def forward_planes(self, images, cameras, style, alpha, style_layers):
         # images: [B, V, C_img, H_img, W_img]
         # cameras: [B, V, 16]
         B = images.shape[0]
+        style_feats = None
+
+        if style is not None:
+            style_img = np.asarray(Image.open(style), dtype=np.float32) / 255.0
+            if style_img.ndim == 2:  # Handle depth image
+                style_img = np.stack([style_img] * 3, axis=-1)
+            style_img = torch.from_numpy(style_img).permute(2, 0, 1).contiguous().float()
+            style_img = torch.nn.functional.interpolate(
+                style_img.unsqueeze(0), size=(320, 320), mode='bilinear', align_corners=False
+            )  # Shape: [1, 3, 320, 320]
+            style_img = style_img.unsqueeze(1)
+            style_img = style_img.to(images.device)  # torch.Size([1, 1, 3, 320, 320])
+            if style_img.shape[2] == 4:  # Check if there are 4 channels
+                style_img = style_img[:, :, :3, :, :]
+            style_feats = self.encoder(style_img, cameras[:, :1, :])  # torch.Size([6, 401, 768]) cameras:torch.Size([1, 6, 16])
+            style_feats = rearrange(style_feats, '(b v) l d -> b (v l) d', b=B)

         # encode images
         image_feats = self.encoder(images, cameras)
         image_feats = rearrange(image_feats, '(b v) l d -> b (v l) d', b=B)
-
-        # decode triplanes
-        planes = self.transformer(image_feats)
+
+        # transformer generating planes
+        if style_feats is not None:
+            planes = self.transformer([image_feats, style_feats], alpha, style_layers)
+        else:
+            planes = self.transformer([image_feats], alpha, style_layers)

         return planes

src/utils/infer_util.py CHANGED
@@ -81,4 +81,17 @@ def images_to_video(
         assert frame.min() >= 0 and frame.max() <= 255, \
             f"Frame value out of range: {frame.min()} ~ {frame.max()}"
         frames.append(frame)
-    imageio.mimwrite(output_path, np.stack(frames), fps=fps, quality=10)
+    imageio.mimwrite(output_path, np.stack(frames), fps=fps, quality=10)
+
+
+def save_video(
+    frames: torch.Tensor,
+    output_path: str,
+    fps: int = 30,
+) -> None:
+    # images: (N, C, H, W)
+    frames = [(frame.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) for frame in frames]
+    writer = imageio.get_writer(output_path, fps=fps)
+    for frame in frames:
+        writer.append_data(frame)
+    writer.close()
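A usage sketch for the new `save_video` helper (random frames as a stand-in; the output filename is illustrative, and `imageio[ffmpeg]` is already pinned in requirements.txt):

```python
# Usage sketch for save_video: frames are (N, C, H, W) floats in [0, 1].
import torch
from src.utils.infer_util import save_video

frames = torch.rand(30, 3, 256, 256)
save_video(frames, "turntable.mp4", fps=30)
```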