Spaces: ipekoztas (Runtime error)

Commit b7c5eaf · committed by ipekoztas
1 Parent(s): 7860b91
Code upload.

Browse files:
- README.md +1 -1
- app.py +93 -78
- requirements.txt +11 -3
- src/data/objaverse.py +17 -21
- src/data/objaverse_zero123plus.py +124 -0
- src/model.py +2 -9
- src/model_mesh.py +2 -2
- src/models/decoder/transformer.py +55 -16
- src/models/lrm.py +37 -4
- src/models/lrm_mesh.py +26 -9
- src/utils/infer_util.py +14 -1
README.md
CHANGED

@@ -1,5 +1,5 @@
 ---
-title: …
+title: 3D Stylization LRM
 emoji: π
 colorFrom: indigo
 colorTo: green
app.py
CHANGED

@@ -30,6 +30,7 @@ from huggingface_hub import hf_hub_download
 import gradio as gr
 
 
+
 def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False):
     """
     Get the rendering camera parameters.
@@ -90,7 +91,7 @@ if cuda_path:
 else:
     print("CUDA installation not found")
 
-config_path = 'configs/instant-…
+config_path = 'configs/instant-nerf-large.yaml'
 config = OmegaConf.load(config_path)
 config_name = os.path.basename(config_path).replace('.yaml', '')
 model_config = config.model_config
@@ -120,7 +121,7 @@ pipeline = pipeline.to(device)
 
 # load reconstruction model
 print('Loading reconstruction model ...')
-model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="…
+model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="instant_nerf_large.ckpt", repo_type="model")
 model = instantiate_from_config(model_config)
 state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
 state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k}
@@ -134,6 +135,10 @@ print('Loading Finished!')
 def check_input_image(input_image):
     if input_image is None:
         raise gr.Error("No image uploaded!")
+
+def check_style_image(style_image):
+    if style_image is None:
+        raise gr.Error("No style image uploaded!")
 
 
 def preprocess(input_image, do_remove_background):
@@ -158,7 +163,7 @@ def generate_mvs(input_image, sample_steps, sample_seed):
         num_inference_steps=sample_steps
     ).images[0]
 
-    show_image = np.asarray(z123_image, dtype=np.uint8)
+    show_image = np.asarray(z123_image, dtype=np.uint8).copy()
     show_image = torch.from_numpy(show_image)  # (960, 640, 3)
     show_image = rearrange(show_image, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
     show_image = rearrange(show_image, '(n m) h w c -> (n h) (m w) c', n=2, m=3)
@@ -166,66 +171,53 @@ def generate_mvs(input_image, sample_steps, sample_seed):
 
     return z123_image, show_image
 
-
 @spaces.GPU
-def make3d(…
-…
+def make3d(mv_images, style_image, alpha, style_layers):
+    """
+    mv_images: single multi-view image (pil or numpy)
+    style_image: PIL image
+    alpha: float
+    style_layers: int
+    """
     global model
+
+    # Save the uploaded style image to a temporary file, so the model can read it from disk
+    style_path = tempfile.NamedTemporaryFile(suffix=".png", delete=False).name
+    style_image.save(style_path)
+
     if IS_FLEXICUBES:
         model.init_flexicubes_geometry(device, use_renderer=False)
-    model = model.eval()
 
-    images = np.asarray(…
-    images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float()
-    images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2)
+    images = np.asarray(mv_images, dtype=np.float32) / 255.0
+    images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float()  # (3, 960, 640)
+    images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2)     # (6, 3, 320, 320)
 
     input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to(device)
-    render_cameras = get_render_cameras(…
+    render_cameras = get_render_cameras(
+        batch_size=1,
+        radius=2.5,
+        is_flexicubes=IS_FLEXICUBES
+    ).to(device)
 
     images = images.unsqueeze(0).to(device)
     images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)
 
-    mesh_fpath = tempfile.NamedTemporaryFile(suffix=…
-    print(mesh_fpath)
+    mesh_fpath = tempfile.NamedTemporaryFile(suffix=".obj", delete=False).name
     mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
     mesh_dirname = os.path.dirname(mesh_fpath)
-    video_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.mp4")
     mesh_glb_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.glb")
 
     with torch.no_grad():
-        # get triplane
-        planes = model.forward_planes(…
-        …
-        # frame = model.forward_geometry(
-        #     planes,
-        #     render_cameras[:, i:i+chunk_size],
-        #     render_size=render_size,
-        # )['img']
-        # else:
-        #     frame = model.synthesizer(
-        #         planes,
-        #         cameras=render_cameras[:, i:i+chunk_size],
-        #         render_size=render_size,
-        #     )['images_rgb']
-        #     frames.append(frame)
-        # frames = torch.cat(frames, dim=1)
-
-        # images_to_video(
-        #     frames[0],
-        #     video_fpath,
-        #     fps=30,
-        # )
-
-        # print(f"Video saved to {video_fpath}")
-
-        # get mesh
+        # get triplane, now passing style_path, alpha, style_layers
+        planes = model.forward_planes(
+            images,
+            input_cameras,
+            style_path,
+            float(alpha),
+            int(style_layers),
+        )
+
+        # extract mesh
         mesh_out = model.extract_mesh(
             planes,
             use_texture_map=False,
@@ -234,52 +226,40 @@ def make3d(images):
 
     vertices, faces, vertex_colors = mesh_out
     vertices = vertices[:, [1, 2, 0]]
 
     save_glb(vertices, faces, vertex_colors, mesh_glb_fpath)
     save_obj(vertices, faces, vertex_colors, mesh_fpath)
-
     print(f"Mesh saved to {mesh_fpath}")
-
     return mesh_fpath, mesh_glb_fpath
 
-
 _HEADER_ = '''
-<h2><b>…
-…
-**InstantMesh** is a feed-forward framework for efficient 3D mesh generation from a single image based on the LRM/Instant3D architecture.
+<h2><b>3DStylizationLRM</b></h2>
+This demo lets you provide a content image, a style image, an alpha blending value, and the number of style layers to inject. It will generate 3D geometry stylized accordingly.
 
-…
-- The 3D mesh generation results highly depend on the quality of generated multi-view images. Please try a different **seed value** if the result is unsatisfying (Default: 42).
+βοΈβοΈβοΈ **Notes:**
+- Content image background can be removed automatically.
+- Adjust the **Alpha** slider to control style blending strength.
+- Adjust **Style Layers** to choose how many layers of style to inject.
 '''
 
 _CITE_ = r"""
-If …
+If 3D Stylization LRM is helpful, please help to β the <a href='https://github.com/ipekoztas/3D-Stylization-LRM' target='_blank'>Github Repo</a>. Thanks!
 ---
 π **Citation**
 
 If you find our work useful for your research or applications, please cite using this bibtex:
 ```bibtex
-@article{…
-…
+@article{oztas20253dstylizationlargereconstruction,
+  title={3D Stylization via Large Reconstruction Model},
+  author={Ipek Oztas and Duygu Ceylan and Aysegul Dundar},
+  journal={https://arxiv.org/abs/2504.21836},
+  year={2025}
 }
 ```
-
-Apache-2.0 LICENSE. Please refer to the [LICENSE file](https://huggingface.co/spaces/TencentARC/InstantMesh/blob/main/LICENSE) for details.
-
-π§ **Contact**
-
-If you have any questions, feel free to open a discussion or contact us at <b>[email protected]</b>.
+π **License**
+Apache-2.0 LICENSE.
 """
 
-
 with gr.Blocks() as demo:
     gr.Markdown(_HEADER_)
     with gr.Row(variant="panel"):
@@ -294,6 +274,13 @@ with gr.Blocks() as demo:
                     type="pil",
                     elem_id="content_image",
                 )
+                # Style Image Upload
+                style_image = gr.Image(
+                    label="Style Image",
+                    image_mode="RGB",
+                    type="pil",
+                    elem_id="style_image",
+                )
                 processed_image = gr.Image(
                     label="Processed Image",
                     image_mode="RGBA",
@@ -317,6 +304,22 @@ with gr.Blocks() as demo:
                     step=5
                 )
 
+            with gr.Row():
+                alpha = gr.Slider(
+                    label="Alpha Value",
+                    minimum=0.0,
+                    maximum=1.0,
+                    value=0.7,
+                    step=0.01,
+                )
+                style_layers = gr.Slider(
+                    label="Style Layers",
+                    minimum=1,
+                    maximum=10,
+                    value=4,
+                    step=1,
+                )
+
             with gr.Row():
                 submit = gr.Button("Generate", elem_id="generate", variant="primary")
 
@@ -330,6 +333,16 @@ with gr.Blocks() as demo:
                    cache_examples=False,
                    examples_per_page=16
                )
+            with gr.Row(variant="panel"):
+                gr.Examples(
+                    examples=[
+                        os.path.join("styles", img_name) for img_name in sorted(os.listdir("styles"))
+                    ],
+                    inputs=[input_image],
+                    label="Styles",
+                    cache_examples=False,
+                    examples_per_page=16
+                )
 
        with gr.Column():
 
@@ -372,19 +385,21 @@ with gr.Blocks() as demo:
 
    mv_images = gr.State()
 
+    # Chain of actions:
    submit.click(fn=check_input_image, inputs=[input_image]).success(
+        fn=check_style_image, inputs=[style_image]
+    ).success(
        fn=preprocess,
        inputs=[input_image, do_remove_background],
        outputs=[processed_image],
    ).success(
        fn=generate_mvs,
        inputs=[processed_image, sample_steps, sample_seed],
-        outputs=[mv_images, mv_show_images]
-
+        outputs=[mv_images, mv_show_images],
    ).success(
        fn=make3d,
-        inputs=[mv_images],
-        outputs=[output_model_obj, output_model_glb]
+        inputs=[mv_images, style_image, alpha, style_layers],
+        outputs=[output_model_obj, output_model_glb],
    )
 
demo.launch()
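For reference, a minimal sketch of how the updated `make3d` signature is exercised by the chain above, written as a plain script. It assumes the module-level globals from app.py (`pipeline`, `model`, `device`, `IS_FLEXICUBES`) are already initialized; the file paths and the step/seed/alpha/style-layer values below are placeholders, not values fixed by this commit.

```python
from PIL import Image

# Hypothetical standalone driver mirroring the Gradio chain in app.py:
# check inputs -> preprocess -> generate_mvs -> make3d(mv_images, style_image, alpha, style_layers)
content = Image.open("examples/content.png")  # placeholder path
style = Image.open("styles/style.png")        # placeholder path

check_input_image(content)
check_style_image(style)

processed = preprocess(content, do_remove_background=True)
mv_images, mv_show = generate_mvs(processed, sample_steps=75, sample_seed=42)

# alpha in [0, 1] controls how strongly style cross-attention is blended in;
# style_layers picks how many decoder layers receive the style features.
obj_path, glb_path = make3d(mv_images, style, alpha=0.7, style_layers=4)
print(obj_path, glb_path)
```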
requirements.txt
CHANGED

@@ -1,7 +1,12 @@
+pydantic==2.10.6
+gradio==4.44.1
+gradio-client==1.3.0
+huggingface-hub==0.25.2
 torch==2.1.0
 torchvision==0.16.0
 torchaudio==2.1.0
 pytorch-lightning==2.1.2
+
 einops
 omegaconf
 deepspeed
@@ -12,12 +17,15 @@ tensorboard
 PyMCubes
 trimesh
 rembg
-transformers==4.…
-diffusers==0.…
+transformers==4.39.3
+diffusers==0.27.0
+tokenizers==0.15.2
+
 bitsandbytes
 imageio[ffmpeg]
 xatlas
 plyfile
+
 xformers==0.0.22.post7
 git+https://github.com/NVlabs/nvdiffrast/
-…
+onnxruntime
src/data/objaverse.py
CHANGED

@@ -22,7 +22,7 @@ from src.utils.train_util import instantiate_from_config
 from src.utils.camera_util import (
     FOV_to_intrinsics,
     center_looking_at_camera_pose,
-    …
+    get_circular_camera_poses,
 )
 
 
@@ -78,7 +78,7 @@ class ObjaverseData(Dataset):
         input_image_dir='rendering_random_32views',
         target_image_dir='rendering_random_32views',
         input_view_num=6,
-        target_view_num=…
+        target_view_num=4,
         total_view_n=32,
         fov=50,
         camera_rotation=True,
@@ -99,7 +99,7 @@ class ObjaverseData(Dataset):
         paths = filtered_dict['good_objs']
         self.paths = paths
 
-        self.depth_scale = …
+        self.depth_scale = 6.0
 
         total_objects = len(self.paths)
         print('============= length of dataset %d =============' % len(self.paths))
@@ -122,7 +122,6 @@ class ObjaverseData(Dataset):
         return image, alpha
 
     def __getitem__(self, index):
-        # load data
         while True:
             input_image_path = os.path.join(self.root_dir, self.input_image_dir, self.paths[index])
             target_image_path = os.path.join(self.root_dir, self.target_image_dir, self.paths[index])
@@ -212,7 +211,7 @@ class ObjaverseData(Dataset):
 
         # random scaling
         if np.random.rand() < 0.5:
-            scale = np.random.uniform(0.…
+            scale = np.random.uniform(0.7, 1.1)
             c2ws[:, :3, 3] *= scale
             depths *= scale
 
@@ -221,11 +220,11 @@ class ObjaverseData(Dataset):
         Ks = K.unsqueeze(0).repeat(self.input_view_num + self.target_view_num, 1, 1).float()
 
         data = {
-            'input_images': images[:self.input_view_num],
+            'input_images': images[:self.input_view_num],      # (6, 3, H, W)
             'input_alphas': alphas[:self.input_view_num],      # (6, 1, H, W)
             'input_depths': depths[:self.input_view_num],      # (6, 1, H, W)
             'input_normals': normals[:self.input_view_num],    # (6, 3, H, W)
-            'input_c2ws': …
+            'input_c2ws': c2ws[:self.input_view_num],           # (6, 4, 4)
             'input_Ks': Ks[:self.input_view_num],               # (6, 3, 3)
 
             # lrm generator input and supervision
@@ -235,8 +234,6 @@ class ObjaverseData(Dataset):
             'target_normals': normals[self.input_view_num:],    # (V, 3, H, W)
             'target_c2ws': c2ws[self.input_view_num:],          # (V, 4, 4)
             'target_Ks': Ks[self.input_view_num:],              # (V, 3, 3)
-
-            'depth_available': 1,
         }
         return data
 
@@ -245,8 +242,8 @@ class ValidationData(Dataset):
     def __init__(self,
         root_dir='objaverse/',
         input_view_num=6,
-        input_image_size=…
-        fov=…
+        input_image_size=320,
+        fov=30,
     ):
         self.root_dir = Path(root_dir)
         self.input_view_num = input_view_num
@@ -256,9 +253,9 @@ class ValidationData(Dataset):
         self.paths = sorted(os.listdir(self.root_dir))
         print('============= length of dataset %d =============' % len(self.paths))
 
-        cam_distance = …
+        cam_distance = 4.0
         azimuths = np.array([30, 90, 150, 210, 270, 330])
-        elevations = np.array([…
+        elevations = np.array([20, -10, 20, -10, 20, -10])
         azimuths = np.deg2rad(azimuths)
         elevations = np.deg2rad(elevations)
 
@@ -272,7 +269,7 @@ class ValidationData(Dataset):
         self.c2ws = c2ws.float()
         self.Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(6, 1, 1).float()
 
-        render_c2ws = …
+        render_c2ws = get_circular_camera_poses(M=8, radius=cam_distance, elevation=20.0)
         render_Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(render_c2ws.shape[0], 1, 1)
         self.render_c2ws = render_c2ws.float()
         self.render_Ks = render_Ks.float()
@@ -303,7 +300,6 @@ class ValidationData(Dataset):
         input_image_path = os.path.join(self.root_dir, self.paths[index])
 
         '''background color, default: white'''
-        # color = np.random.uniform(0.48, 0.52)
         bkg_color = [1.0, 1.0, 1.0]
 
         image_list = []
@@ -314,14 +310,14 @@ class ValidationData(Dataset):
             image_list.append(image)
             alpha_list.append(alpha)
 
-        images = torch.stack(image_list, dim=0).float()
-        alphas = torch.stack(alpha_list, dim=0).float()
+        images = torch.stack(image_list, dim=0).float()
+        alphas = torch.stack(alpha_list, dim=0).float()
 
         data = {
-            'input_images': images,
-            'input_alphas': alphas,
-            'input_c2ws': self.c2ws,
-            'input_Ks': self.Ks,
+            'input_images': images,
+            'input_alphas': alphas,
+            'input_c2ws': self.c2ws,
+            'input_Ks': self.Ks,
 
             'render_c2ws': self.render_c2ws,
             'render_Ks': self.render_Ks,
src/data/objaverse_zero123plus.py
ADDED

@@ -0,0 +1,124 @@
+import os
+import json
+import numpy as np
+import webdataset as wds
+import pytorch_lightning as pl
+import torch
+from torch.utils.data import Dataset
+from torch.utils.data.distributed import DistributedSampler
+from PIL import Image
+from pathlib import Path
+
+from src.utils.train_util import instantiate_from_config
+
+
+class DataModuleFromConfig(pl.LightningDataModule):
+    def __init__(
+        self,
+        batch_size=8,
+        num_workers=4,
+        train=None,
+        validation=None,
+        test=None,
+        **kwargs,
+    ):
+        super().__init__()
+
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+
+        self.dataset_configs = dict()
+        if train is not None:
+            self.dataset_configs['train'] = train
+        if validation is not None:
+            self.dataset_configs['validation'] = validation
+        if test is not None:
+            self.dataset_configs['test'] = test
+
+    def setup(self, stage):
+
+        if stage in ['fit']:
+            self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)
+        else:
+            raise NotImplementedError
+
+    def train_dataloader(self):
+
+        sampler = DistributedSampler(self.datasets['train'])
+        return wds.WebLoader(self.datasets['train'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler)
+
+    def val_dataloader(self):
+
+        sampler = DistributedSampler(self.datasets['validation'])
+        return wds.WebLoader(self.datasets['validation'], batch_size=4, num_workers=self.num_workers, shuffle=False, sampler=sampler)
+
+    def test_dataloader(self):
+
+        return wds.WebLoader(self.datasets['test'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
+
+
+class ObjaverseData(Dataset):
+    def __init__(self,
+        root_dir='objaverse/',
+        meta_fname='valid_paths.json',
+        image_dir='rendering_zero123plus',
+        validation=False,
+    ):
+        self.root_dir = Path(root_dir)
+        self.image_dir = image_dir
+
+        with open(os.path.join(root_dir, meta_fname)) as f:
+            lvis_dict = json.load(f)
+        paths = []
+        for k in lvis_dict.keys():
+            paths.extend(lvis_dict[k])
+        self.paths = paths
+
+        total_objects = len(self.paths)
+        if validation:
+            self.paths = self.paths[-16:]  # used last 16 as validation
+        else:
+            self.paths = self.paths[:-16]
+        print('============= length of dataset %d =============' % len(self.paths))
+
+    def __len__(self):
+        return len(self.paths)
+
+    def load_im(self, path, color):
+        pil_img = Image.open(path)
+
+        image = np.asarray(pil_img, dtype=np.float32) / 255.
+        alpha = image[:, :, 3:]
+        image = image[:, :, :3] * alpha + color * (1 - alpha)
+
+        image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
+        alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
+        return image, alpha
+
+    def __getitem__(self, index):
+        while True:
+            image_path = os.path.join(self.root_dir, self.image_dir, self.paths[index])
+
+            '''background color, default: white'''
+            bkg_color = [1., 1., 1.]
+
+            img_list = []
+            try:
+                for idx in range(7):
+                    img, alpha = self.load_im(os.path.join(image_path, '%03d.png' % idx), bkg_color)
+                    img_list.append(img)
+
+            except Exception as e:
+                print(e)
+                index = np.random.randint(0, len(self.paths))
+                continue
+
+            break
+
+        imgs = torch.stack(img_list, dim=0).float()
+
+        data = {
+            'cond_imgs': imgs[0],      # (3, H, W)
+            'target_imgs': imgs[1:],   # (6, 3, H, W)
+        }
+        return data
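A minimal usage sketch for the added dataset, assuming a local `objaverse/` tree with `valid_paths.json` and `rendering_zero123plus/<object>/000.png ... 006.png` as the class expects (the data itself is not part of this Space):

```python
from torch.utils.data import DataLoader
from src.data.objaverse_zero123plus import ObjaverseData

# Plain DataLoader here; the DataModule above wraps the same dataset in wds.WebLoader
# with a DistributedSampler for multi-GPU training.
dataset = ObjaverseData(root_dir='objaverse/', validation=False)
loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)

batch = next(iter(loader))
print(batch['cond_imgs'].shape)    # (B, 3, H, W): the conditioning view (000.png)
print(batch['target_imgs'].shape)  # (B, 6, 3, H, W): the six target views
```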
src/model.py
CHANGED

@@ -295,16 +295,9 @@ class MVRecon(pl.LightningModule):
 
         params = []
 
-
-        for n, p in self.lrm_generator.named_parameters():
-            if 'adaLN_modulation' in n or 'camera_embedder' in n:
-                lrm_params_fast.append(p)
-            else:
-                lrm_params_slow.append(p)
-        params.append({"params": lrm_params_fast, "lr": lr, "weight_decay": 0.01 })
-        params.append({"params": lrm_params_slow, "lr": lr / 10.0, "weight_decay": 0.01 })
+        params.append({"params": self.lrm_generator.parameters(), "lr": lr, "weight_decay": 0.01 })
 
         optimizer = torch.optim.AdamW(params, lr=lr, betas=(0.90, 0.95))
-        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 3000, eta_min=lr/…
+        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 3000, eta_min=lr/10)
 
         return {'optimizer': optimizer, 'lr_scheduler': scheduler}
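The change above drops the fast/slow parameter split and trains the whole generator as one AdamW group. A self-contained sketch of the resulting optimizer/scheduler setup on a stand-in module (the learning rate is a placeholder; the real value comes from the training config):

```python
import torch
import torch.nn as nn

lrm_generator = nn.Linear(16, 16)  # stand-in for the instantiated LRM generator
lr = 4e-4                          # placeholder; taken from the config in practice

params = [{"params": lrm_generator.parameters(), "lr": lr, "weight_decay": 0.01}]
optimizer = torch.optim.AdamW(params, lr=lr, betas=(0.90, 0.95))
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 3000, eta_min=lr / 10)

for step in range(3):
    optimizer.zero_grad()
    lrm_generator(torch.randn(2, 16)).sum().backward()
    optimizer.step()
    scheduler.step()
    print(step, scheduler.get_last_lr())
```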
src/model_mesh.py
CHANGED

@@ -56,7 +56,7 @@ class MVRecon(pl.LightningModule):
                 if 'weight' in k:
                     sd_fc[k.replace('net.', 'net_sdf.')] = -v[0:1]
                 else:
-                    sd_fc[k.replace('net.', 'net_sdf.')] = …
+                    sd_fc[k.replace('net.', 'net_sdf.')] = 10.0 - v[0:1]
                 sd_fc[k.replace('net.', 'net_rgb.')] = v[1:4]
             else:
                 sd_fc[k.replace('net.', 'net_sdf.')] = v
@@ -274,7 +274,7 @@ class MVRecon(pl.LightningModule):
 
         loss_reg = sdf_reg_loss_entropy + flexicubes_surface_reg + flexicubes_weights_reg
 
-        loss = loss_mse + loss_lpips + loss_mask + loss_normal + loss_reg
+        loss = loss_mse + loss_lpips + loss_mask + loss_depth + loss_normal + loss_reg
 
         prefix = 'train'
         loss_dict = {}
src/models/decoder/transformer.py
CHANGED

@@ -53,14 +53,37 @@ class BasicTransformerBlock(nn.Module):
             nn.Dropout(mlp_drop),
         )
 
-    def forward(self, x, cond):
-        # x: [N, L, D]
-        # cond: […
-        …
+    def forward(self, x, cond, i, alpha, content_layers):
+        # x: [N, L, D] or [x1, x2]
+        # cond: [content_feats] or [content_feats, style_feats]
+        if len(cond) == 2:
+            # Style injection mode
+            x1, x2 = x[0], x[1]
+            content, style = cond[0], cond[1]
+            if i <= content_layers:
+                x1 = x1 + self.cross_attn(self.norm1(x1), content, content)[0]
+            else:
+                x1 = x1 + (1-alpha)*self.cross_attn(self.norm1(x1), content, content)[0] + (alpha)*self.cross_attn(self.norm1(x1), style, style)[0]
+            x2 = x2 + self.cross_attn(self.norm1(x2), style, style)[0]
+
+            before_sa1 = self.norm2(x1)
+            before_sa2 = self.norm2(x2)
+            x1 = x1 + self.self_attn(before_sa1, before_sa1, before_sa1)[0]
+            x2 = x2 + self.self_attn(before_sa2, before_sa2, before_sa2)[0]
+
+            x1 = x1 + self.mlp(self.norm3(x1))
+            x2 = x2 + self.mlp(self.norm3(x2))
+
+            return [x1, x2]
+        else:
+            # No style, only content
+            x1 = x[0] if isinstance(x, list) else x
+            content = cond[0]
+            x1 = x1 + self.cross_attn(self.norm1(x1), content, content)[0]
+            before_sa1 = self.norm2(x1)
+            x1 = x1 + self.self_attn(before_sa1, before_sa1, before_sa1)[0]
+            x1 = x1 + self.mlp(self.norm3(x1))
+            return [x1]
 
 
 class TriplaneTransformer(nn.Module):
@@ -98,18 +121,34 @@
         ])
         self.norm = nn.LayerNorm(inner_dim, eps=eps)
         self.deconv = nn.ConvTranspose2d(inner_dim, triplane_dim, kernel_size=2, stride=2, padding=0)
+        self.num_layers = num_layers
 
-    def forward(self, image_feats):
-        # image_feats: […
-
-        N = image_feats.shape[0]
+    def forward(self, image_feats, alpha, style_layers):
+        # image_feats: [content_feats] or [content_feats, style_feats]
+        N = image_feats[0].shape[0]
         H = W = self.triplane_low_res
         L = 3 * H * W
-
+        content_layers = self.num_layers - style_layers
         x = self.pos_embed.repeat(N, 1, 1)  # [N, L, D]
-        …
-        …
-        …
+        i = 1
+        if len(image_feats) == 2:
+            # Style injection mode
+            for layer in self.layers:
+                if i == 1:
+                    x = layer([x, x], image_feats, i, alpha, content_layers)
+                else:
+                    x = layer(x, image_feats, i, alpha, content_layers)
+                i += 1
+            x = self.norm(x[0])
+        else:
+            # No style, only content
+            for layer in self.layers:
+                if i == 1:
+                    x = layer([x], image_feats, i, alpha, content_layers)
+                else:
+                    x = layer(x, image_feats, i, alpha, content_layers)
+                i += 1
+            x = self.norm(x[0])
 
         # separate each plane and apply deconv
         x = x.view(N, 3, H, W, -1)
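The key idea in the block above: for the first `content_layers = num_layers - style_layers` layers the triplane tokens attend only to content features, and in the remaining layers the content cross-attention output is linearly blended with a style cross-attention output using `alpha`. A stripped-down sketch of that blended update with a plain `nn.MultiheadAttention` stand-in (dimensions are illustrative; the real block also keeps a parallel style branch `x2` and the self-attention/MLP sublayers):

```python
import torch
import torch.nn as nn

D = 64
attn = nn.MultiheadAttention(D, num_heads=4, batch_first=True)  # stand-in for self.cross_attn
norm = nn.LayerNorm(D)                                          # stand-in for self.norm1

x1 = torch.randn(1, 3 * 8 * 8, D)    # triplane tokens (pos_embed-initialized in the real model)
content = torch.randn(1, 60, D)      # multi-view content features
style = torch.randn(1, 10, D)        # style-image features
alpha = 0.7

# Deep layers (i > content_layers): blend content and style cross-attention.
x1 = x1 + (1 - alpha) * attn(norm(x1), content, content)[0] \
        + alpha * attn(norm(x1), style, style)[0]
print(x1.shape)  # torch.Size([1, 192, 64])
```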
src/models/lrm.py
CHANGED

@@ -18,6 +18,7 @@ import torch.nn as nn
 import mcubes
 import nvdiffrast.torch as dr
 from einops import rearrange, repeat
+from PIL import Image
 
 from .encoder.dino_wrapper import DinoWrapper
 from .decoder.transformer import TriplaneTransformer
@@ -65,19 +66,46 @@ class InstantNeRF(nn.Module):
             samples_per_ray=rendering_samples_per_ray,
         )
 
-    def forward_planes(self, images, cameras):
+    def forward_planes(self, images, cameras, style, alpha, style_layers):
         # images: [B, V, C_img, H_img, W_img]
         # cameras: [B, V, 16]
         B = images.shape[0]
+        style_feats = None
+
+        if style is not None:
+            style_img = np.asarray(Image.open(style), dtype=np.float32) / 255.0
+            if style_img.ndim == 2:  # Handle depth image
+                style_img = np.stack([style_img] * 3, axis=-1)
+            style_img = torch.from_numpy(style_img).permute(2, 0, 1).contiguous().float()
+            style_img = torch.nn.functional.interpolate(
+                style_img.unsqueeze(0), size=(320, 320), mode='bilinear', align_corners=False
+            )  # Shape: [1, 3, 320, 320]
+            style_img = style_img.unsqueeze(1)
+            style_img = style_img.to(images.device)  # torch.Size([1, 1, 3, 320, 320])
+            if style_img.shape[2] == 4:  # Check if there are 4 channels
+                style_img = style_img[:, :, :3, :, :]
+            style_feats = self.encoder(style_img, cameras[:, :1, :])  # torch.Size([6, 401, 768]) cameras:torch.Size([1, 6, 16])
+            style_feats = rearrange(style_feats, '(b v) l d -> b (v l) d', b=B)
 
         # encode images
         image_feats = self.encoder(images, cameras)
         image_feats = rearrange(image_feats, '(b v) l d -> b (v l) d', b=B)
 
         # transformer generating planes
-        …
+        if style_feats is not None:
+            planes = self.transformer([image_feats, style_feats], alpha, style_layers)
+        else:
+            planes = self.transformer([image_feats], alpha, style_layers)
 
         return planes
+
+    def forward_synthesizer(self, planes, render_cameras, render_size: int):
+        render_results = self.synthesizer(
+            planes,
+            render_cameras,
+            render_size,
+        )
+        return render_results
 
     def forward(self, images, cameras, render_cameras, render_size: int):
         # images: [B, V, C_img, H_img, W_img]
@@ -125,7 +153,12 @@ class InstantNeRF(nn.Module):
             sample_tex_pose_list.append(tex_pos_one_shape)
         tex_pos = torch.cat(sample_tex_pose_list, dim=0)
 
-        tex_feat = …
+        tex_feat = torch.utils.checkpoint.checkpoint(
+            self.synthesizer.forward_points,
+            planes,
+            tex_pos,
+            use_reentrant=False,
+        )['rgb']
 
         if hard_mask is not None:
             final_tex_feat = torch.zeros(
src/models/lrm_mesh.py
CHANGED

@@ -17,6 +17,7 @@ import torch
 import torch.nn as nn
 import nvdiffrast.torch as dr
 from einops import rearrange, repeat
+from PIL import Image
 
 from .encoder.dino_wrapper import DinoWrapper
 from .decoder.transformer import TriplaneTransformer
@@ -74,12 +75,9 @@ class InstantMesh(nn.Module):
             samples_per_ray=rendering_samples_per_ray,
         )
 
-    def init_flexicubes_geometry(self, device, fovy=50.0…
+    def init_flexicubes_geometry(self, device, fovy=50.0):
         camera = PerspectiveCamera(fovy=fovy, device=device)
-        …
-            renderer = NeuralRender(device, camera_model=camera)
-        else:
-            renderer = None
+        renderer = NeuralRender(device, camera_model=camera)
         self.geometry = FlexiCubesGeometry(
             grid_res=self.grid_res,
             scale=self.grid_scale,
@@ -88,17 +86,36 @@ class InstantMesh(nn.Module):
             device=device,
         )
 
-    def forward_planes(self, images, cameras):
+    def forward_planes(self, images, cameras, style, alpha, style_layers):
         # images: [B, V, C_img, H_img, W_img]
         # cameras: [B, V, 16]
         B = images.shape[0]
+        style_feats = None
+
+        if style is not None:
+            style_img = np.asarray(Image.open(style), dtype=np.float32) / 255.0
+            if style_img.ndim == 2:  # Handle depth image
+                style_img = np.stack([style_img] * 3, axis=-1)
+            style_img = torch.from_numpy(style_img).permute(2, 0, 1).contiguous().float()
+            style_img = torch.nn.functional.interpolate(
+                style_img.unsqueeze(0), size=(320, 320), mode='bilinear', align_corners=False
+            )  # Shape: [1, 3, 320, 320]
+            style_img = style_img.unsqueeze(1)
+            style_img = style_img.to(images.device)  # torch.Size([1, 1, 3, 320, 320])
+            if style_img.shape[2] == 4:  # Check if there are 4 channels
+                style_img = style_img[:, :, :3, :, :]
+            style_feats = self.encoder(style_img, cameras[:, :1, :])  # torch.Size([6, 401, 768]) cameras:torch.Size([1, 6, 16])
+            style_feats = rearrange(style_feats, '(b v) l d -> b (v l) d', b=B)
 
         # encode images
         image_feats = self.encoder(images, cameras)
         image_feats = rearrange(image_feats, '(b v) l d -> b (v l) d', b=B)
 
-        # …
-        …
+        # transformer generating planes
+        if style_feats is not None:
+            planes = self.transformer([image_feats, style_feats], alpha, style_layers)
+        else:
+            planes = self.transformer([image_feats], alpha, style_layers)
 
         return planes
 
src/utils/infer_util.py
CHANGED

@@ -81,4 +81,17 @@ def images_to_video(
         assert frame.min() >= 0 and frame.max() <= 255, \
             f"Frame value out of range: {frame.min()} ~ {frame.max()}"
         frames.append(frame)
-    imageio.mimwrite(output_path, np.stack(frames), fps=fps, quality=10)
+    imageio.mimwrite(output_path, np.stack(frames), fps=fps, quality=10)
+
+
+def save_video(
+    frames: torch.Tensor,
+    output_path: str,
+    fps: int = 30,
+) -> None:
+    # images: (N, C, H, W)
+    frames = [(frame.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) for frame in frames]
+    writer = imageio.get_writer(output_path, fps=fps)
+    for frame in frames:
+        writer.append_data(frame)
+    writer.close()
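A quick usage sketch for the new `save_video` helper (a random tensor stands in for rendered frames; writing .mp4 relies on the `imageio[ffmpeg]` dependency already listed in requirements.txt):

```python
import torch
from src.utils.infer_util import save_video

frames = torch.rand(30, 3, 256, 256)   # (N, C, H, W) float frames in [0, 1]
save_video(frames, "preview.mp4", fps=30)
```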