JiantaoLin committed
Commit d346594 · 1 Parent(s): c8bf07b
Files changed (3):
  1. app.py +424 -322
  2. app_demo.py +384 -0
  3. app_demo_.py +0 -491
app.py CHANGED
@@ -1,10 +1,16 @@
- import gradio as gr
  import os
  import subprocess
- import shlex
  import spaces
  import torch
- access_token = os.getenv("HUGGINGFACE_TOKEN")
  subprocess.run(
      shlex.split(
          "pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt240/download.html"
@@ -41,6 +47,7 @@ def install_cuda_toolkit():
      os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
      print("==> finish install")
  install_cuda_toolkit()
  @spaces.GPU
  def check_gpu():
      os.environ['CUDA_HOME'] = '/usr/local/cuda-12.1'
@@ -51,334 +58,429 @@ def check_gpu():
      print(f"torch.cuda.is_available:{torch.cuda.is_available()}")
  check_gpu()

  from PIL import Image
- from einops import rearrange
- from diffusers import FluxPipeline
- from models.lrm.utils.camera_util import get_flux_input_cameras
- from models.lrm.utils.infer_util import save_video
- from models.lrm.utils.mesh_util import save_obj, save_obj_with_mtl
- from models.lrm.utils.render_utils import rotate_x, rotate_y
- from models.lrm.utils.train_util import instantiate_from_config
- from models.ISOMER.reconstruction_func import reconstruction
- from models.ISOMER.projection_func import projection
- import os
- from einops import rearrange
- from omegaconf import OmegaConf
- import torch
- import numpy as np
  import trimesh
- import torchvision
- import torch.nn.functional as F
- from PIL import Image
- from torchvision import transforms
- from torchvision.transforms import v2
- from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL
- from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
- from diffusers import FluxPipeline
- from pytorch_lightning import seed_everything
- import os
- from huggingface_hub import hf_hub_download
-
-
- from utils.tool import NormalTransfer, get_background, get_render_cameras_video, load_mipmap, render_frames
-
- device_0 = "cuda"
- device_1 = "cuda"
- resolution = 512
- save_dir = "./outputs"
- normal_transfer = NormalTransfer()
- isomer_azimuths = torch.from_numpy(np.array([0, 90, 180, 270])).float().to(device_1)
- isomer_elevations = torch.from_numpy(np.array([5, 5, 5, 5])).float().to(device_1)
- isomer_radius = 4.5
- isomer_geo_weights = torch.from_numpy(np.array([1, 0.9, 1, 0.9])).float().to(device_1)
- isomer_color_weights = torch.from_numpy(np.array([1, 0.5, 1, 0.5])).float().to(device_1)
-
- # model initialization and loading
- # flux
- # # taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=torch.bfloat16).to(device_0)
- # # good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16, token=access_token).to(device_0)
- # flux_pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, token=access_token).to(device=device_0, dtype=torch.bfloat16)
- # # flux_pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, vae=taef1, token=access_token).to(device_0)
- # flux_lora_ckpt_path = hf_hub_download(repo_id="LTT/xxx-ckpt", filename="rgb_normal_large.safetensors", repo_type="model", token=access_token)
- # flux_pipe.load_lora_weights(flux_lora_ckpt_path)
- # flux_pipe.to(device=device_0, dtype=torch.bfloat16)
- # torch.cuda.empty_cache()
- # flux_pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(flux_pipe)
-
-
- # lrm
- config = OmegaConf.load("./models/lrm/config/PRM_inference.yaml")
- model_config = config.model_config
- infer_config = config.infer_config
- model = instantiate_from_config(model_config)
- model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt", repo_type="model")
- state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
- state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
- model.load_state_dict(state_dict, strict=True)
- model = model.to(device_1)
- torch.cuda.empty_cache()
- @spaces.GPU
- def lrm_reconstructions(image, input_cameras, save_path=None, name="temp", export_texmap=False, if_save_video=False):
-     images = image.unsqueeze(0).to(device_1)
-     images = v2.functional.resize(images, 512, interpolation=3, antialias=True).clamp(0, 1)
-     # breakpoint()
-     with torch.no_grad():
-         # get triplane
-         planes = model.forward_planes(images, input_cameras)
-
-         mesh_path_idx = os.path.join(save_path, f'{name}.obj')
-
-         mesh_out = model.extract_mesh(
-             planes,
-             use_texture_map=export_texmap,
-             **infer_config,
-         )
-         if export_texmap:
-             vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
-             save_obj_with_mtl(
-                 vertices.data.cpu().numpy(),
-                 uvs.data.cpu().numpy(),
-                 faces.data.cpu().numpy(),
-                 mesh_tex_idx.data.cpu().numpy(),
-                 tex_map.permute(1, 2, 0).data.cpu().numpy(),
-                 mesh_path_idx,
-             )
-         else:
-             vertices, faces, vertex_colors = mesh_out
-             save_obj(vertices, faces, vertex_colors, mesh_path_idx)
-         print(f"Mesh saved to {mesh_path_idx}")
-
-         render_size = 512
-         if if_save_video:
-             video_path_idx = os.path.join(save_path, f'{name}.mp4')
-             render_size = infer_config.render_resolution
-             ENV = load_mipmap("models/lrm/env_mipmap/6")
-             materials = (0.0, 0.9)
-
-             all_mv, all_mvp, all_campos = get_render_cameras_video(
-                 batch_size=1,
-                 M=24,
-                 radius=4.5,
-                 elevation=(90, 60.0),
-                 is_flexicubes=True,
-                 fov=30
-             )
-
-             frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames(
-                 model,
-                 planes,
-                 render_cameras=all_mvp,
-                 camera_pos=all_campos,
-                 env=ENV,
-                 materials=materials,
-                 render_size=render_size,
-                 chunk_size=20,
-                 is_flexicubes=True,
-             )
-             normals = (torch.nn.functional.normalize(normals) + 1) / 2
-             normals = normals * alphas + (1 - alphas)
-             all_frames = torch.cat([frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals], dim=3)
-
-             save_video(
-                 all_frames,
-                 video_path_idx,
-                 fps=30,
-             )
-             print(f"Video saved to {video_path_idx}")
-
-     return vertices, faces
-
- def local_normal_global_transform(local_normal_images, azimuths_deg, elevations_deg):
-     if local_normal_images.min() >= 0:
-         local_normal = local_normal_images.float() * 2 - 1
-     else:
-         local_normal = local_normal_images.float()
-     global_normal = normal_transfer.trans_local_2_global(local_normal, azimuths_deg, elevations_deg, radius=4.5, for_lotus=False)
-     global_normal[..., 0] *= -1
-     global_normal = (global_normal + 1) / 2
-     global_normal = global_normal.permute(0, 3, 1, 2)
-     return global_normal
-
- # Generate the multi-view images
- @spaces.GPU(duration=120)
- def generate_multi_view_images(prompt, seed):
-     # torch.cuda.empty_cache()
-     # generator = torch.manual_seed(seed)
-     generator = torch.Generator().manual_seed(seed)
-     with torch.no_grad():
-         img = flux_pipe(
-             prompt=prompt,
-             num_inference_steps=5,
-             guidance_scale=3.5,
-             num_images_per_prompt=1,
-             width=resolution * 2,
-             height=resolution * 1,
-             output_type='np',
-             generator=generator,
-         ).images
-         # for img in flux_pipe.flux_pipe_call_that_returns_an_iterable_of_images(
-         #     prompt=prompt,
-         #     guidance_scale=3.5,
-         #     num_inference_steps=4,
-         #     width=resolution * 4,
-         #     height=resolution * 2,
-         #     generator=generator,
-         #     output_type="np",
-         #     good_vae=good_vae,
-         # ):
-         #     pass
-     # Return the final image; the seed is handled by the external caller
-     return img
-
- # Reconstruct the 3D model
  @spaces.GPU
- def reconstruct_3d_model(images, prompt):
-     global model
-     model.init_flexicubes_geometry(device_1, fovy=50.0)
-     model = model.eval()
-     rgb_normal_grid = images
-     save_dir_path = os.path.join(save_dir, prompt.replace(" ", "_"))
-     os.makedirs(save_dir_path, exist_ok=True)
-
-     images = torch.from_numpy(rgb_normal_grid).squeeze(0).permute(2, 0, 1).contiguous().float()  # (3, 1024, 2048)
-     images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=2, m=4)  # (8, 3, 512, 512)
-     rgb_multi_view = images[:4, :3, :, :]
-     normal_multi_view = images[4:, :3, :, :]
-     multi_view_mask = get_background(normal_multi_view)
-     rgb_multi_view = rgb_multi_view * multi_view_mask + (1 - multi_view_mask)
-     input_cameras = get_flux_input_cameras(batch_size=1, radius=4.2, fov=30).to(device_1)
-     vertices, faces = lrm_reconstructions(rgb_multi_view, input_cameras, save_path=save_dir_path, name='lrm', export_texmap=False, if_save_video=True)
-     # local normal to global normal
-     global_normal = local_normal_global_transform(normal_multi_view.permute(0, 2, 3, 1), isomer_azimuths, isomer_elevations)
-     global_normal = global_normal * multi_view_mask + (1 - multi_view_mask)
-
-     global_normal = global_normal.permute(0, 2, 3, 1)
-     rgb_multi_view = rgb_multi_view.permute(0, 2, 3, 1)
-     multi_view_mask = multi_view_mask.permute(0, 2, 3, 1).squeeze(-1)
-     vertices = torch.from_numpy(vertices).to(device_1)
-     faces = torch.from_numpy(faces).to(device_1)
-     vertices = vertices @ rotate_x(np.pi / 2, device=vertices.device)[:3, :3]
-     vertices = vertices @ rotate_y(np.pi / 2, device=vertices.device)[:3, :3]
-
-     # global_normal: B,H,W,3
-     # multi_view_mask: B,H,W
-     # rgb_multi_view: B,H,W,3
-
-     meshes = reconstruction(
-         normal_pils=global_normal,
-         masks=multi_view_mask,
-         weights=isomer_geo_weights,
-         fov=30,
-         radius=isomer_radius,
-         camera_angles_azi=isomer_azimuths,
-         camera_angles_ele=isomer_elevations,
-         expansion_weight_stage1=0.1,
-         init_type="file",
-         init_verts=vertices,
-         init_faces=faces,
-         stage1_steps=0,
-         stage2_steps=50,
-         start_edge_len_stage1=0.1,
-         end_edge_len_stage1=0.02,
-         start_edge_len_stage2=0.02,
-         end_edge_len_stage2=0.005,
-     )
-
-     save_glb_addr = projection(
-         meshes,
-         masks=multi_view_mask,
-         images=rgb_multi_view,
-         azimuths=isomer_azimuths,
-         elevations=isomer_elevations,
-         weights=isomer_color_weights,
-         fov=30,
-         radius=isomer_radius,
-         save_dir=f"{save_dir_path}/ISOMER/",
-     )
-
-     return save_glb_addr
-
- # Gradio interface function
  @spaces.GPU
- def gradio_pipeline(prompt, seed):
-     import ctypes
-     # Explicitly preload libnvrtc.so.12
-     cuda_lib_path = "/usr/local/cuda-12.1/lib64/libnvrtc.so.12"
-     try:
-         ctypes.CDLL(cuda_lib_path, mode=ctypes.RTLD_GLOBAL)
-         print(f"Successfully preloaded {cuda_lib_path}")
-     except OSError as e:
-         print(f"Failed to preload {cuda_lib_path}: {e}")
-     # Generate the multi-view images
-     # rgb_normal_grid = generate_multi_view_images(prompt, seed)
-     rgb_normal_grid = np.load("rgb_normal_grid.npy")
-     image_preview = Image.fromarray((rgb_normal_grid[0] * 255).astype(np.uint8))
-
-     # 3d reconstruction
-
-     # Reconstruct the 3D model and return the glb path
-     save_glb_addr = reconstruct_3d_model(rgb_normal_grid, prompt)
-     # save_glb_addr = None
-     return image_preview, save_glb_addr
-
- # Gradio Blocks app
- with gr.Blocks() as demo:
-     with gr.Row(variant="panel"):
-         # Left: input area
-         with gr.Column():
-             with gr.Row():
-                 prompt_input = gr.Textbox(
-                     label="Enter Prompt",
-                     placeholder="Describe your 3D model...",
-                     lines=2,
-                     elem_id="prompt_input"
-                 )
-
-             with gr.Row():
-                 sample_seed = gr.Number(value=42, label="Seed Value", precision=0)
-
-             with gr.Row():
-                 submit = gr.Button("Generate", elem_id="generate", variant="primary")
-
-             with gr.Row(variant="panel"):
-                 gr.Markdown("Examples:")
-                 gr.Examples(
-                     examples=[
-                         ["a castle on a hill"],
-                         ["an owl wearing a hat"],
-                         ["a futuristic car"]
-                     ],
-                     inputs=[prompt_input],
-                     label="Prompt Examples"
-                 )
-
-         # Right: output area
-         with gr.Column():
-             with gr.Row():
-                 rgb_normal_grid_image = gr.Image(
-                     label="RGB Normal Grid",
-                     type="pil",
-                     interactive=False
-                 )
-
-             with gr.Row():
-                 with gr.Tab("GLB"):
-                     output_glb_model = gr.Model3D(
-                         label="Generated 3D Model (GLB Format)",
-                         interactive=False
-                     )
-                     gr.Markdown("Download the model for proper visualization.")
-
-     # Event wiring
-     submit.click(
-         fn=gradio_pipeline, inputs=[prompt_input, sample_seed],
-         outputs=[rgb_normal_grid_image, output_glb_model]
-     )
-
- # Launch the app
- # demo.queue(max_size=10)
- demo.launch()
 
 
  import os
+ import gradio as gr
  import subprocess
  import spaces
+ import ctypes
+ import shlex
  import torch
+
+ subprocess.run(
+     shlex.split(
+         "pip install ./custom_diffusers --force-reinstall --no-deps"
+     )
+ )
  subprocess.run(
      shlex.split(
          "pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt240/download.html"
 
      os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
      print("==> finish install")
  install_cuda_toolkit()
+
  @spaces.GPU
  def check_gpu():
      os.environ['CUDA_HOME'] = '/usr/local/cuda-12.1'
 
      print(f"torch.cuda.is_available:{torch.cuda.is_available()}")
  check_gpu()

+
+ import base64
+ import re
+ import sys
+
+ sys.path.append(os.path.abspath(os.path.join(__file__, '../')))
+ if 'OMP_NUM_THREADS' not in os.environ:
+     os.environ['OMP_NUM_THREADS'] = '32'
+
+ import shutil
+ import json
+ import requests
+ import shutil
+ import threading
  from PIL import Image
+ import time
  import trimesh
+
+ import random
+ import time
+ import numpy as np
+ from video_render import render_video_from_obj
+
+ access_token = os.getenv("HUGGINGFACE_TOKEN")
+ from pipeline.kiss3d_wrapper import init_wrapper_from_config, run_text_to_3d, run_image_to_3d, image2mesh_preprocess, image2mesh_main
+
+
+ # Add logo file path and hyperlinks
+ LOGO_PATH = "app_assets/logo_temp_.png"  # Update this to the actual path of your logo
+ ARXIV_LINK = "https://arxiv.org/abs/example"
+ GITHUB_LINK = "https://github.com/example"
+
+
+ k3d_wrapper = init_wrapper_from_config('./pipeline/pipeline_config/default.yaml')
+
+
+ from models.ISOMER.scripts.utils import fix_vert_color_glb
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+
+ TEMP_MESH_ADDRESS = ''
+
+ mesh_cache = None
+ preprocessed_input_image = None
+
+ def save_cached_mesh():
+     global mesh_cache
+     return mesh_cache
+     # if mesh_cache is None:
+     #     return None
+     # return save_py3dmesh_with_trimesh_fast(mesh_cache)
+
+ def save_py3dmesh_with_trimesh_fast(meshes, save_glb_path=TEMP_MESH_ADDRESS, apply_sRGB_to_LinearRGB=True):
+     from pytorch3d.structures import Meshes
+     import trimesh
+
+     # convert from pytorch3d meshes to trimesh mesh
+     vertices = meshes.verts_packed().cpu().float().numpy()
+     triangles = meshes.faces_packed().cpu().long().numpy()
+     np_color = meshes.textures.verts_features_packed().cpu().float().numpy()
+     if save_glb_path.endswith(".glb"):
+         # rotate 180 along +Y
+         vertices[:, [0, 2]] = -vertices[:, [0, 2]]
+
+     def srgb_to_linear(c_srgb):
+         c_linear = np.where(c_srgb <= 0.04045, c_srgb / 12.92, ((c_srgb + 0.055) / 1.055) ** 2.4)
+         return c_linear.clip(0, 1.)
+     if apply_sRGB_to_LinearRGB:
+         np_color = srgb_to_linear(np_color)
+     assert vertices.shape[0] == np_color.shape[0]
+     assert np_color.shape[1] == 3
+     assert 0 <= np_color.min() and np_color.max() <= 1, f"min={np_color.min()}, max={np_color.max()}"
+     mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, vertex_colors=np_color)
+     mesh.remove_unreferenced_vertices()
+     # save mesh
+     mesh.export(save_glb_path)
+     if save_glb_path.endswith(".glb"):
+         fix_vert_color_glb(save_glb_path)
+     print(f"saving to {save_glb_path}")
+ #
+ #
+ # @spaces.GPU
+ def text_to_detailed(prompt, seed=None):
+     # print(torch.cuda.is_available())
+     # print(f"Before text_to_detailed: {torch.cuda.memory_allocated() / 1024**3} GB")
+     return k3d_wrapper.get_detailed_prompt(prompt, seed)
+
+ def text_to_image(prompt, seed=None, strength=1.0, lora_scale=1.0, num_inference_steps=30, redux_hparam=None, init_image=None, **kwargs):
+     # print(f"Before text_to_image: {torch.cuda.memory_allocated() / 1024**3} GB")
+     k3d_wrapper.renew_uuid()
+     init_image = None
+     # if init_image_path is not None:
+     #     init_image = Image.open(init_image_path)
+     result = k3d_wrapper.generate_3d_bundle_image_text(
+         prompt,
+         image=init_image,
+         strength=strength,
+         lora_scale=lora_scale,
+         num_inference_steps=num_inference_steps,
+         seed=int(seed) if seed is not None else None,
+         redux_hparam=redux_hparam,
+         save_intermediate_results=True,
+         **kwargs)
+     return result[-1]
+
+ def image2mesh_preprocess_(input_image_, seed, use_mv_rgb=True):
+     global preprocessed_input_image
+
+     seed = int(seed) if seed is not None else None
+
+     # TODO: delete this later
+     k3d_wrapper.del_llm_model()
+
+     input_image_save_path, reference_save_path, caption = image2mesh_preprocess(k3d_wrapper, input_image_, seed, use_mv_rgb)
+
+     preprocessed_input_image = Image.open(input_image_save_path)
+     return reference_save_path, caption
+
  @spaces.GPU
+ def image2mesh_main_(reference_3d_bundle_image, caption, seed, strength1=0.5, strength2=0.95, enable_redux=True, use_controlnet=True, if_video=True):
+     global mesh_cache
+     seed = int(seed) if seed is not None else None
+
+
+     # TODO: delete this later
+     k3d_wrapper.del_llm_model()
+
+     input_image = preprocessed_input_image
+
+     reference_3d_bundle_image = torch.tensor(reference_3d_bundle_image).permute(2, 0, 1) / 255
+
+     gen_save_path, recon_mesh_path = image2mesh_main(k3d_wrapper, input_image, reference_3d_bundle_image, caption=caption, seed=seed, strength1=strength1, strength2=strength2, enable_redux=enable_redux, use_controlnet=use_controlnet)
+     mesh_cache = recon_mesh_path
+
+
+     # gen_save_ = Image.open(gen_save_path)
+
+     if if_video:
+         video_path = recon_mesh_path.replace('.obj', '.mp4').replace('.glb', '.mp4')
+         render_video_from_obj(recon_mesh_path, video_path)
+         print(f"After bundle_image_to_mesh: {torch.cuda.memory_allocated() / 1024**3} GB")
+         return gen_save_path, video_path
+     else:
+         return gen_save_path, recon_mesh_path
+     # return gen_save_path, recon_mesh_path

  @spaces.GPU
+ def bundle_image_to_mesh(
+         gen_3d_bundle_image,
+         lrm_radius=4.15,
+         isomer_radius=4.5,
+         reconstruction_stage1_steps=10,
+         reconstruction_stage2_steps=50,
+         save_intermediate_results=True,
+         if_video=True
+     ):
+     global mesh_cache
+     print(f"Before bundle_image_to_mesh: {torch.cuda.memory_allocated() / 1024**3} GB")
+     k3d_wrapper.recon_model.init_flexicubes_geometry("cuda:0", fovy=50.0)
+     # TODO: delete this later
+     k3d_wrapper.del_llm_model()
+
+     print(f"Before bundle_image_to_mesh after deleting llm model: {torch.cuda.memory_allocated() / 1024**3} GB")
+
+     gen_3d_bundle_image = torch.tensor(gen_3d_bundle_image).permute(2, 0, 1) / 255
+     # recon from 3D Bundle image
+     recon_mesh_path = k3d_wrapper.reconstruct_3d_bundle_image(gen_3d_bundle_image, lrm_render_radius=lrm_radius, isomer_radius=isomer_radius, save_intermediate_results=save_intermediate_results, reconstruction_stage1_steps=int(reconstruction_stage1_steps), reconstruction_stage2_steps=int(reconstruction_stage2_steps))
+     mesh_cache = recon_mesh_path
+
+     if if_video:
+         video_path = recon_mesh_path.replace('.obj', '.mp4').replace('.glb', '.mp4')
+         # # If video_path is smaller than 50 KB, treat it as an empty file and re-render
+         # if os.path.exists(video_path):
+         #     print(f"file size:{os.path.getsize(video_path)}")
+         #     if os.path.getsize(video_path) > 50*1024:
+         #         print(f"video path:{video_path}")
+         #         return video_path
+         render_video_from_obj(recon_mesh_path, video_path)
+         print(f"After bundle_image_to_mesh: {torch.cuda.memory_allocated() / 1024**3} GB")
+         return video_path
+     else:
+         return recon_mesh_path
+
+ _HEADER_ = f"""
+ <img src="{LOGO_PATH}">
+ <h2><b>Official 🤗 Gradio Demo</b></h2><h2>
+ <b>Kiss3DGen: Repurposing Image Diffusion Models for 3D Asset Generation</b></h2>
+
+ <p>**Kiss3DGen** is xxxxxxxxx</p>
+
+ [![arXiv](https://img.shields.io/badge/arXiv-Link-red)]({ARXIV_LINK}) [![GitHub](https://img.shields.io/badge/GitHub-Repo-blue)]({GITHUB_LINK})
+ """
+
+ _CITE_ = r"""
+ <h2>If Kiss3DGen is helpful, please help to ⭐ the <a href='""" + GITHUB_LINK + r"""' target='_blank'>Github Repo</a>. Thanks!</h2>
+
+ 📝 **Citation**
+
+ If you find our work useful for your research or applications, please cite using this bibtex:
+ ```bibtex
+ @article{xxxx,
+     title={xxxx},
+     author={xxxx},
+     journal={xxxx},
+     year={xxxx}
+ }
+ ```
+
+ 📋 **License**
+
+ Apache-2.0 LICENSE. Please refer to the [LICENSE file](https://huggingface.co/spaces/TencentARC/InstantMesh/blob/main/LICENSE) for details.
+
+ 📧 **Contact**
+
+ If you have any questions, feel free to open a discussion or contact us at <b>xxx@xxxx</b>.
+ """
+
+ def image_to_base64(image_path):
+     """Converts an image file to a base64-encoded string."""
+     with open(image_path, "rb") as img_file:
+         return base64.b64encode(img_file.read()).decode('utf-8')
+
+ def main():
+
+     torch.set_grad_enabled(False)
+
+     # Convert the logo image to base64
+     logo_base64 = image_to_base64(LOGO_PATH)
+     # with gr.Blocks() as demo:
+     with gr.Blocks(css="""
+         body {
+             display: flex;
+             justify-content: center;
+             align-items: center;
+             min-height: 100vh;
+             margin: 0;
+             padding: 0;
+         }
+         #col-container { margin: 0px auto; max-width: 200px; }
+
+         .gradio-container {
+             max-width: 1000px;
+             margin: auto;
+             width: 100%;
+         }
+         #center-align-column {
+             display: flex;
+             justify-content: center;
+             align-items: center;
+         }
+         #right-align-column {
+             display: flex;
+             justify-content: flex-end;
+             align-items: center;
+         }
+         h1 {text-align: center;}
+         h2 {text-align: center;}
+         h3 {text-align: center;}
+         p {text-align: center;}
+         img {text-align: right;}
+         .right {
+             display: block;
+             margin-left: auto;
+         }
+         .center {
+             display: block;
+             margin-left: auto;
+             margin-right: auto;
+             width: 50%;
+         }
+         #content-container {
+             max-width: 1200px;
+             margin: 0 auto;
+         }
+         #example-container {
+             max-width: 300px;
+             margin: 0 auto;
+         }
+     """, elem_id="col-container") as demo:
+         # Header Section
+         # gr.Image(value=LOGO_PATH, width=64, height=64)
+         # gr.Markdown(_HEADER_)
+         with gr.Row(elem_id="content-container"):
+             # with gr.Column(scale=1):
+             #     pass
+             # with gr.Column(scale=1, elem_id="right-align-column"):
+             #     # gr.Image(value=LOGO_PATH, interactive=False, show_label=False, width=64, height=64, elem_id="logo-image")
+             #     # gr.Markdown(f"<img src='{LOGO_PATH}' alt='Logo' style='width:64px;height:64px;border:0;'>")
+             #     # gr.HTML(f"<img src='data:image/png;base64,{logo_base64}' alt='Logo' class='right' style='width:64px;height:64px;border:0;text-align:right;'>")
+             #     pass
+             with gr.Column(scale=7, elem_id="center-align-column"):
+                 gr.Markdown(f"""
+                 ## Official 🤗 Gradio Demo
+                 # Kiss3DGen: Repurposing Image Diffusion Models for 3D Asset Generation""")
+                 gr.HTML(f"<img src='data:image/png;base64,{logo_base64}' alt='Logo' class='center' style='width:64px;height:64px;border:0;text-align:center;'>")
+
+                 gr.HTML(f"""
+                 <div style="display: flex; justify-content: center; align-items: center; gap: 10px;">
+                     <a href="{ARXIV_LINK}" target="_blank">
+                         <img src="https://img.shields.io/badge/arXiv-Link-red" alt="arXiv">
+                     </a>
+                     <a href="{GITHUB_LINK}" target="_blank">
+                         <img src="https://img.shields.io/badge/GitHub-Repo-blue" alt="GitHub">
+                     </a>
+                 </div>
+                 """)
+
+             # gr.HTML(f"""
+             # <div style="display: flex; gap: 10px; align-items: center;"><a href="{ARXIV_LINK}" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/arXiv-Link-red" alt="arXiv"></a> <a href="{GITHUB_LINK}" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/GitHub-Repo-blue" alt="GitHub"></a></div>
+             # """)
+
+             # gr.Markdown(f"""
+             # [![arXiv](https://img.shields.io/badge/arXiv-Link-red)]({ARXIV_LINK}) [![GitHub](https://img.shields.io/badge/GitHub-Repo-blue)]({GITHUB_LINK})
+             # """, elem_id="title")
+             # with gr.Column(scale=1):
+             #     pass
+             # with gr.Row():
+             #     gr.Markdown(f"[![arXiv](https://img.shields.io/badge/arXiv-Link-red)]({ARXIV_LINK})")
+             #     gr.Markdown(f"[![GitHub](https://img.shields.io/badge/GitHub-Repo-blue)]({GITHUB_LINK})")
+
+         # Tabs Section
+         with gr.Tabs(selected='tab_text_to_3d', elem_id="content-container") as main_tabs:
+             with gr.TabItem('Text-to-3D', id='tab_text_to_3d'):
+                 with gr.Row():
+                     with gr.Column(scale=1):
+                         prompt = gr.Textbox(value="", label="Input Prompt", lines=4)
+                         seed1 = gr.Number(value=10, label="Seed")
+
+                         with gr.Row(elem_id="example-container"):
+                             gr.Examples(
+                                 examples=[
+                                     # ["A tree with red leaves"],
+                                     # ["A dragon with black texture"],
+                                     ["A girl with pink hair"],
+                                     ["A boy playing guitar"],
+                                     ["A dog wearing a hat"],
+                                     ["A boy playing basketball"],
+                                     # [""],
+                                     # [""],
+                                     # [""],
+                                 ],
+                                 inputs=[prompt],  # fill the selected example into the prompt textbox
+                                 label="Example Prompts"
+                             )
+                         btn_text2detailed = gr.Button("Refine to detailed prompt")
+                         detailed_prompt = gr.Textbox(value="", label="Detailed Prompt", placeholder="A detailed prompt will be generated here based on your input prompt. You can also edit this prompt.", lines=4, interactive=True)
+                         btn_text2img = gr.Button("Generate Images")
+
+                     with gr.Column(scale=1):
+                         output_image1 = gr.Image(label="Generated image", interactive=False)
+
+                         # lrm_radius = gr.Number(value=4.15, label="lrm_radius")
+                         # isomer_radius = gr.Number(value=4.5, label="isomer_radius")
+                         # reconstruction_stage1_steps = gr.Number(value=10, label="reconstruction_stage1_steps")
+                         # reconstruction_stage2_steps = gr.Number(value=50, label="reconstruction_stage2_steps")
+
+                         btn_gen_mesh = gr.Button("Generate Mesh")
+                         output_video1 = gr.Video(label="Generated Video", interactive=False, loop=True, autoplay=True)
+                         btn_download1 = gr.Button("Download Mesh")
+
+                         file_output1 = gr.File()
+
+             with gr.TabItem('Image-to-3D', id='tab_image_to_3d'):
+                 with gr.Row():
+                     with gr.Column(scale=1):
+                         image = gr.Image(label="Input Image", type="pil")
+
+                         seed2 = gr.Number(value=10, label="Seed (0 for random)")
+
+                         btn_img2mesh_preprocess = gr.Button("Preprocess Image")
+
+                         image_caption = gr.Textbox(value="", label="Image Caption", placeholder="A caption will be generated here based on your input image. You can also edit this caption.", lines=4, interactive=True)
+
+                         output_image2 = gr.Image(label="Generated image", interactive=False)
+                         strength1 = gr.Slider(minimum=0, maximum=1.0, step=0.01, value=0.5, label="strength1")
+                         strength2 = gr.Slider(minimum=0, maximum=1.0, step=0.01, value=0.95, label="strength2")
+                         enable_redux = gr.Checkbox(label="enable redux", value=True)
+                         use_controlnet = gr.Checkbox(label="use controlnet", value=True)
+
+                         btn_img2mesh_main = gr.Button("Generate Mesh")
+
+                     with gr.Column(scale=1):
+                         # output_mesh2 = gr.Model3D(label="Generated Mesh", interactive=False)
+                         output_image3 = gr.Image(label="gen save image", interactive=False)
+                         output_video2 = gr.Video(label="Generated Video", interactive=False, loop=True, autoplay=True)
+                         btn_download2 = gr.Button("Download Mesh")
+                         file_output2 = gr.File()
+
+                 # Image-to-3D events
+                 btn_img2mesh_preprocess.click(fn=image2mesh_preprocess_, inputs=[image, seed2], outputs=[output_image2, image_caption])
+                 btn_img2mesh_main.click(fn=image2mesh_main_, inputs=[output_image2, image_caption, seed2, strength1, strength2, enable_redux, use_controlnet], outputs=[output_image3, output_video2])
+
+                 btn_download2.click(fn=save_cached_mesh, inputs=[], outputs=file_output2)
+
+             # Button Click Events
+             # Text-to-3D events
+             btn_text2detailed.click(fn=text_to_detailed, inputs=[prompt, seed1], outputs=detailed_prompt)
+             btn_text2img.click(fn=text_to_image, inputs=[detailed_prompt, seed1], outputs=output_image1)
+             btn_gen_mesh.click(fn=bundle_image_to_mesh, inputs=[output_image1], outputs=output_video1)
+             # btn_gen_mesh.click(fn=bundle_image_to_mesh, inputs=[output_image1, lrm_radius, isomer_radius, reconstruction_stage1_steps, reconstruction_stage2_steps], outputs=output_video1)
+
+         with gr.Row():
+             pass
+         with gr.Row():
+             gr.Markdown(_CITE_)
+
+     # demo.queue(default_concurrency_limit=1)
+     # demo.launch(server_name="0.0.0.0", server_port=9239)
+     # subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
+     demo.launch()
+
+
+ if __name__ == "__main__":
+     main()
 
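The Text-to-3D tab above wires its three buttons to `text_to_detailed`, `text_to_image`, and `bundle_image_to_mesh` in sequence. As a reading aid, here is a minimal hypothetical driver for that chain; it is not part of the commit, and it assumes it runs inside this Space's repo, where importing `app` also executes the module-level setup (pip installs, loading `k3d_wrapper`):

```python
# Hypothetical sketch of the Text-to-3D chain behind the .click() wiring above.
# Assumption: run inside this Space's environment; `from app import ...`
# triggers app.py's module-level installs and model loading.
from app import text_to_detailed, text_to_image, bundle_image_to_mesh

prompt = "A dog wearing a hat"                # one of the UI example prompts
detailed = text_to_detailed(prompt, seed=10)  # btn_text2detailed: refine the prompt
bundle = text_to_image(detailed, seed=10)     # btn_text2img: generate the 3D bundle image
# In the UI, Gradio passes the displayed image onward as an HxWxC array, which
# bundle_image_to_mesh rescales to [0, 1] before reconstruction.
video_path = bundle_image_to_mesh(bundle)     # btn_gen_mesh: mesh + turntable .mp4
print(video_path)
```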
app_demo.py ADDED
@@ -0,0 +1,384 @@
+ import gradio as gr
+ import os
+ import subprocess
+ import shlex
+ import spaces
+ import torch
+ access_token = os.getenv("HUGGINGFACE_TOKEN")
+ subprocess.run(
+     shlex.split(
+         "pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt240/download.html"
+     )
+ )
+
+ subprocess.run(
+     shlex.split(
+         "pip install ./extension/nvdiffrast-0.3.1+torch-py3-none-any.whl --force-reinstall --no-deps"
+     )
+ )
+
+ subprocess.run(
+     shlex.split(
+         "pip install ./extension/renderutils_plugin-0.1.0-cp310-cp310-linux_x86_64.whl --force-reinstall --no-deps"
+     )
+ )
+ def install_cuda_toolkit():
+     # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"
+     # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run"
+     CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run"
+     CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
+     subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
+     subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])
+     subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"])
+
+     os.environ["CUDA_HOME"] = "/usr/local/cuda"
+     os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ["PATH"])
+     os.environ["LD_LIBRARY_PATH"] = "%s/lib:%s" % (
+         os.environ["CUDA_HOME"],
+         "" if "LD_LIBRARY_PATH" not in os.environ else os.environ["LD_LIBRARY_PATH"],
+     )
+     # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
+     os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
+     print("==> finish install")
+ install_cuda_toolkit()
+ @spaces.GPU
+ def check_gpu():
+     os.environ['CUDA_HOME'] = '/usr/local/cuda-12.1'
+     os.environ['PATH'] += ':/usr/local/cuda-12.1/bin'
+     # os.environ['LD_LIBRARY_PATH'] += ':/usr/local/cuda-12.1/lib64'
+     os.environ['LD_LIBRARY_PATH'] = "/usr/local/cuda-12.1/lib64:" + os.environ.get('LD_LIBRARY_PATH', '')
+     subprocess.run(['nvidia-smi'])  # test whether CUDA is available
+     print(f"torch.cuda.is_available:{torch.cuda.is_available()}")
+ check_gpu()
+
+ from PIL import Image
+ from einops import rearrange
+ from diffusers import FluxPipeline
+ from models.lrm.utils.camera_util import get_flux_input_cameras
+ from models.lrm.utils.infer_util import save_video
+ from models.lrm.utils.mesh_util import save_obj, save_obj_with_mtl
+ from models.lrm.utils.render_utils import rotate_x, rotate_y
+ from models.lrm.utils.train_util import instantiate_from_config
+ from models.ISOMER.reconstruction_func import reconstruction
+ from models.ISOMER.projection_func import projection
+ import os
+ from einops import rearrange
+ from omegaconf import OmegaConf
+ import torch
+ import numpy as np
+ import trimesh
+ import torchvision
+ import torch.nn.functional as F
+ from PIL import Image
+ from torchvision import transforms
+ from torchvision.transforms import v2
+ from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+ from diffusers import FluxPipeline
+ from pytorch_lightning import seed_everything
+ import os
+ from huggingface_hub import hf_hub_download
+
+
+ from utils.tool import NormalTransfer, get_background, get_render_cameras_video, load_mipmap, render_frames
+
+ device_0 = "cuda"
+ device_1 = "cuda"
+ resolution = 512
+ save_dir = "./outputs"
+ normal_transfer = NormalTransfer()
+ isomer_azimuths = torch.from_numpy(np.array([0, 90, 180, 270])).float().to(device_1)
+ isomer_elevations = torch.from_numpy(np.array([5, 5, 5, 5])).float().to(device_1)
+ isomer_radius = 4.5
+ isomer_geo_weights = torch.from_numpy(np.array([1, 0.9, 1, 0.9])).float().to(device_1)
+ isomer_color_weights = torch.from_numpy(np.array([1, 0.5, 1, 0.5])).float().to(device_1)
+
+ # model initialization and loading
+ # flux
+ # # taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=torch.bfloat16).to(device_0)
+ # # good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16, token=access_token).to(device_0)
+ # flux_pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, token=access_token).to(device=device_0, dtype=torch.bfloat16)
+ # # flux_pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, vae=taef1, token=access_token).to(device_0)
+ # flux_lora_ckpt_path = hf_hub_download(repo_id="LTT/xxx-ckpt", filename="rgb_normal_large.safetensors", repo_type="model", token=access_token)
+ # flux_pipe.load_lora_weights(flux_lora_ckpt_path)
+ # flux_pipe.to(device=device_0, dtype=torch.bfloat16)
+ # torch.cuda.empty_cache()
+ # flux_pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(flux_pipe)
+
+
+ # lrm
+ config = OmegaConf.load("./models/lrm/config/PRM_inference.yaml")
+ model_config = config.model_config
+ infer_config = config.infer_config
+ model = instantiate_from_config(model_config)
+ model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt", repo_type="model")
+ state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
+ state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
+ model.load_state_dict(state_dict, strict=True)
+ model = model.to(device_1)
+ torch.cuda.empty_cache()
+ @spaces.GPU
+ def lrm_reconstructions(image, input_cameras, save_path=None, name="temp", export_texmap=False, if_save_video=False):
+     images = image.unsqueeze(0).to(device_1)
+     images = v2.functional.resize(images, 512, interpolation=3, antialias=True).clamp(0, 1)
+     # breakpoint()
+     with torch.no_grad():
+         # get triplane
+         planes = model.forward_planes(images, input_cameras)
+
+         mesh_path_idx = os.path.join(save_path, f'{name}.obj')
+
+         mesh_out = model.extract_mesh(
+             planes,
+             use_texture_map=export_texmap,
+             **infer_config,
+         )
+         if export_texmap:
+             vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
+             save_obj_with_mtl(
+                 vertices.data.cpu().numpy(),
+                 uvs.data.cpu().numpy(),
+                 faces.data.cpu().numpy(),
+                 mesh_tex_idx.data.cpu().numpy(),
+                 tex_map.permute(1, 2, 0).data.cpu().numpy(),
+                 mesh_path_idx,
+             )
+         else:
+             vertices, faces, vertex_colors = mesh_out
+             save_obj(vertices, faces, vertex_colors, mesh_path_idx)
+         print(f"Mesh saved to {mesh_path_idx}")
+
+         render_size = 512
+         if if_save_video:
+             video_path_idx = os.path.join(save_path, f'{name}.mp4')
+             render_size = infer_config.render_resolution
+             ENV = load_mipmap("models/lrm/env_mipmap/6")
+             materials = (0.0, 0.9)
+
+             all_mv, all_mvp, all_campos = get_render_cameras_video(
+                 batch_size=1,
+                 M=24,
+                 radius=4.5,
+                 elevation=(90, 60.0),
+                 is_flexicubes=True,
+                 fov=30
+             )
+
+             frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames(
+                 model,
+                 planes,
+                 render_cameras=all_mvp,
+                 camera_pos=all_campos,
+                 env=ENV,
+                 materials=materials,
+                 render_size=render_size,
+                 chunk_size=20,
+                 is_flexicubes=True,
+             )
+             normals = (torch.nn.functional.normalize(normals) + 1) / 2
+             normals = normals * alphas + (1 - alphas)
+             all_frames = torch.cat([frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals], dim=3)
+
+             save_video(
+                 all_frames,
+                 video_path_idx,
+                 fps=30,
+             )
+             print(f"Video saved to {video_path_idx}")
+
+     return vertices, faces
+
+
+ def local_normal_global_transform(local_normal_images, azimuths_deg, elevations_deg):
+     if local_normal_images.min() >= 0:
+         local_normal = local_normal_images.float() * 2 - 1
+     else:
+         local_normal = local_normal_images.float()
+     global_normal = normal_transfer.trans_local_2_global(local_normal, azimuths_deg, elevations_deg, radius=4.5, for_lotus=False)
+     global_normal[..., 0] *= -1
+     global_normal = (global_normal + 1) / 2
+     global_normal = global_normal.permute(0, 3, 1, 2)
+     return global_normal
+
+ # Generate the multi-view images
+ @spaces.GPU(duration=120)
+ def generate_multi_view_images(prompt, seed):
+     # torch.cuda.empty_cache()
+     # generator = torch.manual_seed(seed)
+     generator = torch.Generator().manual_seed(seed)
+     with torch.no_grad():
+         img = flux_pipe(
+             prompt=prompt,
+             num_inference_steps=5,
+             guidance_scale=3.5,
+             num_images_per_prompt=1,
+             width=resolution * 2,
+             height=resolution * 1,
+             output_type='np',
+             generator=generator,
+         ).images
+         # for img in flux_pipe.flux_pipe_call_that_returns_an_iterable_of_images(
+         #     prompt=prompt,
+         #     guidance_scale=3.5,
+         #     num_inference_steps=4,
+         #     width=resolution * 4,
+         #     height=resolution * 2,
+         #     generator=generator,
+         #     output_type="np",
+         #     good_vae=good_vae,
+         # ):
+         #     pass
+     # Return the final image; the seed is handled by the external caller
+     return img
+
+ # Reconstruct the 3D model
+ @spaces.GPU
+ def reconstruct_3d_model(images, prompt):
+     global model
+     model.init_flexicubes_geometry(device_1, fovy=50.0)
+     model = model.eval()
+     rgb_normal_grid = images
+     save_dir_path = os.path.join(save_dir, prompt.replace(" ", "_"))
+     os.makedirs(save_dir_path, exist_ok=True)
+
+     images = torch.from_numpy(rgb_normal_grid).squeeze(0).permute(2, 0, 1).contiguous().float()  # (3, 1024, 2048)
+     images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=2, m=4)  # (8, 3, 512, 512)
+     rgb_multi_view = images[:4, :3, :, :]
+     normal_multi_view = images[4:, :3, :, :]
+     multi_view_mask = get_background(normal_multi_view)
+     rgb_multi_view = rgb_multi_view * multi_view_mask + (1 - multi_view_mask)
+     input_cameras = get_flux_input_cameras(batch_size=1, radius=4.2, fov=30).to(device_1)
+     vertices, faces = lrm_reconstructions(rgb_multi_view, input_cameras, save_path=save_dir_path, name='lrm', export_texmap=False, if_save_video=True)
+     # local normal to global normal
+     global_normal = local_normal_global_transform(normal_multi_view.permute(0, 2, 3, 1), isomer_azimuths, isomer_elevations)
+     global_normal = global_normal * multi_view_mask + (1 - multi_view_mask)
+
+     global_normal = global_normal.permute(0, 2, 3, 1)
+     rgb_multi_view = rgb_multi_view.permute(0, 2, 3, 1)
+     multi_view_mask = multi_view_mask.permute(0, 2, 3, 1).squeeze(-1)
+     vertices = torch.from_numpy(vertices).to(device_1)
+     faces = torch.from_numpy(faces).to(device_1)
+     vertices = vertices @ rotate_x(np.pi / 2, device=vertices.device)[:3, :3]
+     vertices = vertices @ rotate_y(np.pi / 2, device=vertices.device)[:3, :3]
+
+     # global_normal: B,H,W,3
+     # multi_view_mask: B,H,W
+     # rgb_multi_view: B,H,W,3
+
+     meshes = reconstruction(
+         normal_pils=global_normal,
+         masks=multi_view_mask,
+         weights=isomer_geo_weights,
+         fov=30,
+         radius=isomer_radius,
+         camera_angles_azi=isomer_azimuths,
+         camera_angles_ele=isomer_elevations,
+         expansion_weight_stage1=0.1,
+         init_type="file",
+         init_verts=vertices,
+         init_faces=faces,
+         stage1_steps=0,
+         stage2_steps=50,
+         start_edge_len_stage1=0.1,
+         end_edge_len_stage1=0.02,
+         start_edge_len_stage2=0.02,
+         end_edge_len_stage2=0.005,
+     )
+
+     save_glb_addr = projection(
+         meshes,
+         masks=multi_view_mask,
+         images=rgb_multi_view,
+         azimuths=isomer_azimuths,
+         elevations=isomer_elevations,
+         weights=isomer_color_weights,
+         fov=30,
+         radius=isomer_radius,
+         save_dir=f"{save_dir_path}/ISOMER/",
+     )
+
+     return save_glb_addr
+
+ # Gradio interface function
+ @spaces.GPU
+ def gradio_pipeline(prompt, seed):
+     import ctypes
+     # Explicitly preload libnvrtc.so.12
+     cuda_lib_path = "/usr/local/cuda-12.1/lib64/libnvrtc.so.12"
+     try:
+         ctypes.CDLL(cuda_lib_path, mode=ctypes.RTLD_GLOBAL)
+         print(f"Successfully preloaded {cuda_lib_path}")
+     except OSError as e:
+         print(f"Failed to preload {cuda_lib_path}: {e}")
+     # Generate the multi-view images
+     # rgb_normal_grid = generate_multi_view_images(prompt, seed)
+     rgb_normal_grid = np.load("rgb_normal_grid.npy")
+     image_preview = Image.fromarray((rgb_normal_grid[0] * 255).astype(np.uint8))
+
+     # 3d reconstruction
+
+     # Reconstruct the 3D model and return the glb path
+     save_glb_addr = reconstruct_3d_model(rgb_normal_grid, prompt)
+     # save_glb_addr = None
+     return image_preview, save_glb_addr
+
+ # Gradio Blocks app
+ with gr.Blocks() as demo:
+     with gr.Row(variant="panel"):
+         # Left: input area
+         with gr.Column():
+             with gr.Row():
+                 prompt_input = gr.Textbox(
+                     label="Enter Prompt",
+                     placeholder="Describe your 3D model...",
+                     lines=2,
+                     elem_id="prompt_input"
+                 )
+
+             with gr.Row():
+                 sample_seed = gr.Number(value=42, label="Seed Value", precision=0)
+
+             with gr.Row():
+                 submit = gr.Button("Generate", elem_id="generate", variant="primary")
+
+             with gr.Row(variant="panel"):
+                 gr.Markdown("Examples:")
+                 gr.Examples(
+                     examples=[
+                         ["a castle on a hill"],
+                         ["an owl wearing a hat"],
+                         ["a futuristic car"]
+                     ],
+                     inputs=[prompt_input],
+                     label="Prompt Examples"
+                 )
+
+         # Right: output area
+         with gr.Column():
+             with gr.Row():
+                 rgb_normal_grid_image = gr.Image(
+                     label="RGB Normal Grid",
+                     type="pil",
+                     interactive=False
+                 )
+
+             with gr.Row():
+                 with gr.Tab("GLB"):
+                     output_glb_model = gr.Model3D(
+                         label="Generated 3D Model (GLB Format)",
+                         interactive=False
+                     )
+                     gr.Markdown("Download the model for proper visualization.")
+
+     # Event wiring
+     submit.click(
+         fn=gradio_pipeline, inputs=[prompt_input, sample_seed],
+         outputs=[rgb_normal_grid_image, output_glb_model]
+     )
+
+ # Launch the app
+ # demo.queue(max_size=10)
+ demo.launch()
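
One detail of `reconstruct_3d_model` above is worth isolating: the generated `rgb_normal_grid` is a 2x4 mosaic (four RGB views on the top row, four normal maps below), and a single `einops.rearrange` splits it into eight 512x512 views. A self-contained illustration with dummy data (editor's sketch; the shapes come from the comments in the code):

```python
# Standalone demo of the 2x4 bundle-grid split used in reconstruct_3d_model.
import torch
from einops import rearrange

grid = torch.rand(3, 1024, 2048)  # (c, 2*512, 4*512): 2 rows x 4 columns of views
views = rearrange(grid, 'c (n h) (m w) -> (n m) c h w', n=2, m=4)
print(views.shape)                # torch.Size([8, 3, 512, 512]), rows in order

rgb_multi_view = views[:4]        # top row: RGB views at azimuths 0/90/180/270
normal_multi_view = views[4:]     # bottom row: the corresponding normal maps
```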
app_demo_.py DELETED
@@ -1,491 +0,0 @@
- import os
- import gradio as gr
- import subprocess
- import spaces
- import ctypes
- import shlex
- import torch
-
- subprocess.run(
-     shlex.split(
-         "pip install ./custom_diffusers --force-reinstall --no-deps"
-     )
- )
- subprocess.run(
-     shlex.split(
-         "pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt240/download.html"
-     )
- )
-
- subprocess.run(
-     shlex.split(
-         "pip install ./extension/nvdiffrast-0.3.1+torch-py3-none-any.whl --force-reinstall --no-deps"
-     )
- )
-
- subprocess.run(
-     shlex.split(
-         "pip install ./extension/renderutils_plugin-0.1.0-cp310-cp310-linux_x86_64.whl --force-reinstall --no-deps"
-     )
- )
- # download cudatoolkit
- def install_cuda_toolkit():
-     # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"
-     # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run"
-     CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run"
-     CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
-     subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
-     subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])
-     subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"])
-
-     os.environ["CUDA_HOME"] = "/usr/local/cuda"
-     os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ["PATH"])
-     os.environ["LD_LIBRARY_PATH"] = "%s/lib:%s" % (
-         os.environ["CUDA_HOME"],
-         "" if "LD_LIBRARY_PATH" not in os.environ else os.environ["LD_LIBRARY_PATH"],
-     )
-     # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
-     os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
-     print("==> finish install")
- install_cuda_toolkit()
-
-
- import base64
- import re
- import sys
-
- sys.path.append(os.path.abspath(os.path.join(__file__, '../')))
- if 'OMP_NUM_THREADS' not in os.environ:
-     os.environ['OMP_NUM_THREADS'] = '32'
-
- import shutil
- import json
- import requests
- import shutil
- import threading
- from PIL import Image
- import time
- import trimesh
-
- import random
- import time
- import numpy as np
- from video_render import render_video_from_obj
-
- access_token = os.getenv("HUGGINGFACE_TOKEN")
- from pipeline.kiss3d_wrapper import init_wrapper_from_config, run_text_to_3d, run_image_to_3d, image2mesh_preprocess, image2mesh_main
-
-
- # Add logo file path and hyperlinks
- LOGO_PATH = "app_assets/logo_temp_.png"  # Update this to the actual path of your logo
- ARXIV_LINK = "https://arxiv.org/abs/example"
- GITHUB_LINK = "https://github.com/example"
-
-
- k3d_wrapper = init_wrapper_from_config('./pipeline/pipeline_config/default.yaml')
-
-
- from models.ISOMER.scripts.utils import fix_vert_color_glb
- torch.backends.cuda.matmul.allow_tf32 = True
-
- def check_gpu():
-     os.environ['CUDA_HOME'] = '/usr/local/cuda-12.1'
-     os.environ['PATH'] += ':/usr/local/cuda-12.1/bin'
-     # os.environ['LD_LIBRARY_PATH'] += ':/usr/local/cuda-12.1/lib64'
-     os.environ['LD_LIBRARY_PATH'] = "/usr/local/cuda-12.1/lib64:" + os.environ.get('LD_LIBRARY_PATH', '')
-     # Explicitly preload libnvrtc.so.12
-     cuda_lib_path = "/usr/local/cuda-12.1/lib64/libnvrtc.so.12"
-     try:
-         ctypes.CDLL(cuda_lib_path, mode=ctypes.RTLD_GLOBAL)
-         print(f"Successfully preloaded {cuda_lib_path}")
-     except OSError as e:
-         print(f"Failed to preload {cuda_lib_path}: {e}")
- check_gpu()
- print(f"GPU: {torch.cuda.is_available()}")
- subprocess.run(['nvidia-smi'])
-
- TEMP_MESH_ADDRESS = ''
-
- mesh_cache = None
- preprocessed_input_image = None
-
- def save_cached_mesh():
-     global mesh_cache
-     return mesh_cache
-     # if mesh_cache is None:
-     #     return None
-     # return save_py3dmesh_with_trimesh_fast(mesh_cache)
-
- def save_py3dmesh_with_trimesh_fast(meshes, save_glb_path=TEMP_MESH_ADDRESS, apply_sRGB_to_LinearRGB=True):
-     from pytorch3d.structures import Meshes
-     import trimesh
-
-     # convert from pytorch3d meshes to trimesh mesh
-     vertices = meshes.verts_packed().cpu().float().numpy()
-     triangles = meshes.faces_packed().cpu().long().numpy()
-     np_color = meshes.textures.verts_features_packed().cpu().float().numpy()
-     if save_glb_path.endswith(".glb"):
-         # rotate 180 along +Y
-         vertices[:, [0, 2]] = -vertices[:, [0, 2]]
-
-     def srgb_to_linear(c_srgb):
-         c_linear = np.where(c_srgb <= 0.04045, c_srgb / 12.92, ((c_srgb + 0.055) / 1.055) ** 2.4)
-         return c_linear.clip(0, 1.)
-     if apply_sRGB_to_LinearRGB:
-         np_color = srgb_to_linear(np_color)
-     assert vertices.shape[0] == np_color.shape[0]
-     assert np_color.shape[1] == 3
-     assert 0 <= np_color.min() and np_color.max() <= 1, f"min={np_color.min()}, max={np_color.max()}"
-     mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, vertex_colors=np_color)
-     mesh.remove_unreferenced_vertices()
-     # save mesh
-     mesh.export(save_glb_path)
-     if save_glb_path.endswith(".glb"):
-         fix_vert_color_glb(save_glb_path)
-     print(f"saving to {save_glb_path}")
- #
- #
- # @spaces.GPU
- def text_to_detailed(prompt, seed=None):
-     # print(torch.cuda.is_available())
-     # print(f"Before text_to_detailed: {torch.cuda.memory_allocated() / 1024**3} GB")
-     return k3d_wrapper.get_detailed_prompt(prompt, seed)
-
- def text_to_image(prompt, seed=None, strength=1.0, lora_scale=1.0, num_inference_steps=30, redux_hparam=None, init_image=None, **kwargs):
-     # print(f"Before text_to_image: {torch.cuda.memory_allocated() / 1024**3} GB")
-     k3d_wrapper.renew_uuid()
-     init_image = None
-     # if init_image_path is not None:
-     #     init_image = Image.open(init_image_path)
-     result = k3d_wrapper.generate_3d_bundle_image_text(
-         prompt,
-         image=init_image,
-         strength=strength,
-         lora_scale=lora_scale,
-         num_inference_steps=num_inference_steps,
-         seed=int(seed) if seed is not None else None,
-         redux_hparam=redux_hparam,
-         save_intermediate_results=True,
-         **kwargs)
-     return result[-1]
-
- def image2mesh_preprocess_(input_image_, seed, use_mv_rgb=True):
-     global preprocessed_input_image
-
-     seed = int(seed) if seed is not None else None
-
-     # TODO: delete this later
-     k3d_wrapper.del_llm_model()
-
-     input_image_save_path, reference_save_path, caption = image2mesh_preprocess(k3d_wrapper, input_image_, seed, use_mv_rgb)
-
-     preprocessed_input_image = Image.open(input_image_save_path)
-     return reference_save_path, caption
-
- @spaces.GPU
- def image2mesh_main_(reference_3d_bundle_image, caption, seed, strength1=0.5, strength2=0.95, enable_redux=True, use_controlnet=True, if_video=True):
-     global mesh_cache
-     seed = int(seed) if seed is not None else None
-
-
-     # TODO: delete this later
-     k3d_wrapper.del_llm_model()
-
-     input_image = preprocessed_input_image
-
-     reference_3d_bundle_image = torch.tensor(reference_3d_bundle_image).permute(2, 0, 1) / 255
-
-     gen_save_path, recon_mesh_path = image2mesh_main(k3d_wrapper, input_image, reference_3d_bundle_image, caption=caption, seed=seed, strength1=strength1, strength2=strength2, enable_redux=enable_redux, use_controlnet=use_controlnet)
-     mesh_cache = recon_mesh_path
-
-
-     # gen_save_ = Image.open(gen_save_path)
-
-     if if_video:
-         video_path = recon_mesh_path.replace('.obj', '.mp4').replace('.glb', '.mp4')
-         render_video_from_obj(recon_mesh_path, video_path)
-         print(f"After bundle_image_to_mesh: {torch.cuda.memory_allocated() / 1024**3} GB")
-         return gen_save_path, video_path
-     else:
-         return gen_save_path, recon_mesh_path
-     # return gen_save_path, recon_mesh_path
-
213
- @spaces.GPU
214
- def bundle_image_to_mesh(
215
- gen_3d_bundle_image,
216
- lrm_radius = 4.15,
217
- isomer_radius = 4.5,
218
- reconstruction_stage1_steps = 10,
219
- reconstruction_stage2_steps = 50,
220
- save_intermediate_results=True,
221
- if_video=True
222
- ):
223
- global mesh_cache
224
- print(f"Before bundle_image_to_mesh: {torch.cuda.memory_allocated() / 1024**3} GB")
225
- k3d_wrapper.recon_model.init_flexicubes_geometry("cuda:0", fovy=50.0)
226
- # TODO: delete this later
227
- k3d_wrapper.del_llm_model()
228
-
229
- print(f"Before bundle_image_to_mesh after deleting llm model: {torch.cuda.memory_allocated() / 1024**3} GB")
230
-
231
- gen_3d_bundle_image = torch.tensor(gen_3d_bundle_image).permute(2,0,1)/255
232
- # recon from 3D Bundle image
233
- recon_mesh_path = k3d_wrapper.reconstruct_3d_bundle_image(gen_3d_bundle_image, lrm_render_radius=lrm_radius, isomer_radius=isomer_radius, save_intermediate_results=save_intermediate_results, reconstruction_stage1_steps=int(reconstruction_stage1_steps), reconstruction_stage2_steps=int(reconstruction_stage2_steps))
234
- mesh_cache = recon_mesh_path
235
-
236
- if if_video:
237
- video_path = recon_mesh_path.replace('.obj','.mp4').replace('.glb','.mp4')
238
- # # 检查这个video_path文件大小是是否超过50KB,不超过的话就认为是空文件,需要重新渲染
239
- # if os.path.exists(video_path):
240
- # print(f"file size:{os.path.getsize(video_path)}")
241
- # if os.path.getsize(video_path) > 50*1024:
242
- # print(f"video path:{video_path}")
243
- # return video_path
244
- render_video_from_obj(recon_mesh_path, video_path)
245
- print(f"After bundle_image_to_mesh: {torch.cuda.memory_allocated() / 1024**3} GB")
246
- return video_path
247
- else:
248
- return recon_mesh_path
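# Hedged alternative (sketch only) to the chained str.replace above for
# deriving the preview-video path: os.path.splitext handles '.obj' and
# '.glb' meshes uniformly and avoids accidental substring replacement.
import os

def mesh_to_video_path(mesh_path: str) -> str:
    base, _ext = os.path.splitext(mesh_path)  # drops '.obj' or '.glb'
    return base + '.mp4'

assert mesh_to_video_path('outputs/recon.obj') == 'outputs/recon.mp4'
assert mesh_to_video_path('outputs/recon.glb') == 'outputs/recon.mp4'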
- 
- _HEADER_ = f"""
- <img src="{LOGO_PATH}">
- <h2><b>Official 🤗 Gradio Demo</b></h2><h2>
- <b>Kiss3DGen: Repurposing Image Diffusion Models for 3D Asset Generation</b></h2>
- 
- <p>**Kiss3DGen** is xxxxxxxxx</p>
- 
- [![arXiv](https://img.shields.io/badge/arXiv-Link-red)]({ARXIV_LINK}) [![GitHub](https://img.shields.io/badge/GitHub-Repo-blue)]({GITHUB_LINK})
- """
- 
- _CITE_ = r"""
- <h2>If Kiss3DGen is helpful, please help to ⭐ the <a href='""" + GITHUB_LINK + r"""' target='_blank'>GitHub Repo</a>. Thanks!</h2>
- 
- 📝 **Citation**
- 
- If you find our work useful for your research or applications, please cite using this BibTeX:
- ```bibtex
- @article{xxxx,
-   title={xxxx},
-   author={xxxx},
-   journal={xxxx},
-   year={xxxx}
- }
- ```
- 
- 📋 **License**
- 
- Apache-2.0 license. Please refer to the [LICENSE file](https://huggingface.co/spaces/TencentARC/InstantMesh/blob/main/LICENSE) for details.
- 
- 📧 **Contact**
- 
- If you have any questions, feel free to open a discussion or contact us at <b>xxx@xxxx</b>.
- """
- 
- def image_to_base64(image_path):
-     """Convert an image file to a base64-encoded string."""
-     with open(image_path, "rb") as img_file:
-         return base64.b64encode(img_file.read()).decode('utf-8')
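# Usage sketch: the returned base64 string is meant for an inline data URI,
# as main() does for the logo below. 'figs/logo.png' is a hypothetical path
# standing in for LOGO_PATH.
#
#   logo_b64 = image_to_base64('figs/logo.png')
#   html = f"<img src='data:image/png;base64,{logo_b64}' alt='Logo'>"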
- 
- def main():
- 
-     torch.set_grad_enabled(False)
- 
-     # Convert the logo image to base64
-     logo_base64 = image_to_base64(LOGO_PATH)
-     # with gr.Blocks() as demo:
-     with gr.Blocks(css="""
-         body {
-             display: flex;
-             justify-content: center;
-             align-items: center;
-             min-height: 100vh;
-             margin: 0;
-             padding: 0;
-         }
-         #col-container { margin: 0px auto; max-width: 200px; }
- 
-         .gradio-container {
-             max-width: 1000px;
-             margin: auto;
-             width: 100%;
-         }
-         #center-align-column {
-             display: flex;
-             justify-content: center;
-             align-items: center;
-         }
-         #right-align-column {
-             display: flex;
-             justify-content: flex-end;
-             align-items: center;
-         }
-         h1 {text-align: center;}
-         h2 {text-align: center;}
-         h3 {text-align: center;}
-         p {text-align: center;}
-         img {text-align: right;}
-         .right {
-             display: block;
-             margin-left: auto;
-         }
-         .center {
-             display: block;
-             margin-left: auto;
-             margin-right: auto;
-             width: 50%;
-         }
-         #content-container {
-             max-width: 1200px;
-             margin: 0 auto;
-         }
-         #example-container {
-             max-width: 300px;
-             margin: 0 auto;
-         }
-     """, elem_id="col-container") as demo:
-         # Header Section
-         # gr.Image(value=LOGO_PATH, width=64, height=64)
-         # gr.Markdown(_HEADER_)
-         with gr.Row(elem_id="content-container"):
-             # with gr.Column(scale=1):
-             #     pass
-             # with gr.Column(scale=1, elem_id="right-align-column"):
-             #     # gr.Image(value=LOGO_PATH, interactive=False, show_label=False, width=64, height=64, elem_id="logo-image")
-             #     # gr.Markdown(f"<img src='{LOGO_PATH}' alt='Logo' style='width:64px;height:64px;border:0;'>")
-             #     # gr.HTML(f"<img src='data:image/png;base64,{logo_base64}' alt='Logo' class='right' style='width:64px;height:64px;border:0;text-align:right;'>")
-             #     pass
-             with gr.Column(scale=7, elem_id="center-align-column"):
-                 gr.Markdown(f"""
-                 ## Official 🤗 Gradio Demo
-                 # Kiss3DGen: Repurposing Image Diffusion Models for 3D Asset Generation""")
-                 gr.HTML(f"<img src='data:image/png;base64,{logo_base64}' alt='Logo' class='center' style='width:64px;height:64px;border:0;text-align:center;'>")
- 
-                 gr.HTML(f"""
-                 <div style="display: flex; justify-content: center; align-items: center; gap: 10px;">
-                     <a href="{ARXIV_LINK}" target="_blank">
-                         <img src="https://img.shields.io/badge/arXiv-Link-red" alt="arXiv">
-                     </a>
-                     <a href="{GITHUB_LINK}" target="_blank">
-                         <img src="https://img.shields.io/badge/GitHub-Repo-blue" alt="GitHub">
-                     </a>
-                 </div>
-                 """)
- 
-                 # gr.HTML(f"""
-                 # <div style="display: flex; gap: 10px; align-items: center;"><a href="{ARXIV_LINK}" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/arXiv-Link-red" alt="arXiv"></a> <a href="{GITHUB_LINK}" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/GitHub-Repo-blue" alt="GitHub"></a></div>
-                 # """)
- 
-                 # gr.Markdown(f"""
-                 # [![arXiv](https://img.shields.io/badge/arXiv-Link-red)]({ARXIV_LINK}) [![GitHub](https://img.shields.io/badge/GitHub-Repo-blue)]({GITHUB_LINK})
-                 # """, elem_id="title")
-         # with gr.Column(scale=1):
-         #     pass
-         # with gr.Row():
-         #     gr.Markdown(f"[![arXiv](https://img.shields.io/badge/arXiv-Link-red)]({ARXIV_LINK})")
-         #     gr.Markdown(f"[![GitHub](https://img.shields.io/badge/GitHub-Repo-blue)]({GITHUB_LINK})")
- 
-         # Tabs Section
-         with gr.Tabs(selected='tab_text_to_3d', elem_id="content-container") as main_tabs:
-             with gr.TabItem('Text-to-3D', id='tab_text_to_3d'):
-                 with gr.Row():
-                     with gr.Column(scale=1):
-                         prompt = gr.Textbox(value="", label="Input Prompt", lines=4)
-                         seed1 = gr.Number(value=10, label="Seed")
- 
-                         with gr.Row(elem_id="example-container"):
-                             gr.Examples(
-                                 examples=[
-                                     # ["A tree with red leaves"],
-                                     # ["A dragon with black texture"],
-                                     ["A girl with pink hair"],
-                                     ["A boy playing guitar"],
-                                     ["A dog wearing a hat"],
-                                     ["A boy playing basketball"],
-                                 ],
-                                 inputs=[prompt],  # fill the selected example into the prompt textbox
-                                 label="Example Prompts"
-                             )
-                         btn_text2detailed = gr.Button("Refine to detailed prompt")
-                         detailed_prompt = gr.Textbox(value="", label="Detailed Prompt", placeholder="A detailed prompt will be generated here based on your input prompt. You can also edit it.", lines=4, interactive=True)
-                         btn_text2img = gr.Button("Generate Images")
- 
-                     with gr.Column(scale=1):
-                         output_image1 = gr.Image(label="Generated image", interactive=False)
- 
-                         # lrm_radius = gr.Number(value=4.15, label="lrm_radius")
-                         # isomer_radius = gr.Number(value=4.5, label="isomer_radius")
-                         # reconstruction_stage1_steps = gr.Number(value=10, label="reconstruction_stage1_steps")
-                         # reconstruction_stage2_steps = gr.Number(value=50, label="reconstruction_stage2_steps")
- 
-                         btn_gen_mesh = gr.Button("Generate Mesh")
-                         output_video1 = gr.Video(label="Generated Video", interactive=False, loop=True, autoplay=True)
-                         btn_download1 = gr.Button("Download Mesh")
- 
-                         file_output1 = gr.File()
- 
-             with gr.TabItem('Image-to-3D', id='tab_image_to_3d'):
-                 with gr.Row():
-                     with gr.Column(scale=1):
-                         image = gr.Image(label="Input Image", type="pil")
- 
-                         seed2 = gr.Number(value=10, label="Seed (0 for random)")
- 
-                         btn_img2mesh_preprocess = gr.Button("Preprocess Image")
- 
-                         image_caption = gr.Textbox(value="", label="Image Caption", placeholder="A caption will be generated here based on your input image. You can also edit it.", lines=4, interactive=True)
- 
-                         output_image2 = gr.Image(label="Generated image", interactive=False)
-                         strength1 = gr.Slider(minimum=0, maximum=1.0, step=0.01, value=0.5, label="strength1")
-                         strength2 = gr.Slider(minimum=0, maximum=1.0, step=0.01, value=0.95, label="strength2")
-                         enable_redux = gr.Checkbox(label="enable redux", value=True)
-                         use_controlnet = gr.Checkbox(label="use controlnet", value=True)
- 
-                         btn_img2mesh_main = gr.Button("Generate Mesh")
- 
-                     with gr.Column(scale=1):
-                         # output_mesh2 = gr.Model3D(label="Generated Mesh", interactive=False)
-                         output_image3 = gr.Image(label="Generated bundle image", interactive=False)
-                         output_video2 = gr.Video(label="Generated Video", interactive=False, loop=True, autoplay=True)
-                         btn_download2 = gr.Button("Download Mesh")
-                         file_output2 = gr.File()
- 
-                 # Image2
-                 btn_img2mesh_preprocess.click(fn=image2mesh_preprocess_, inputs=[image, seed2], outputs=[output_image2, image_caption])
- 
-                 btn_img2mesh_main.click(fn=image2mesh_main_, inputs=[output_image2, image_caption, seed2, strength1, strength2, enable_redux, use_controlnet], outputs=[output_image3, output_video2])
- 
-                 btn_download2.click(fn=save_cached_mesh, inputs=[], outputs=file_output2)
- 
-         # Button Click Events
-         # Text2
-         btn_text2detailed.click(fn=text_to_detailed, inputs=[prompt, seed1], outputs=detailed_prompt)
-         btn_text2img.click(fn=text_to_image, inputs=[detailed_prompt, seed1], outputs=output_image1)
-         btn_gen_mesh.click(fn=bundle_image_to_mesh, inputs=[output_image1], outputs=output_video1)
-         # Download button for the Text-to-3D tab, mirroring btn_download2 above
-         btn_download1.click(fn=save_cached_mesh, inputs=[], outputs=file_output1)
-         # btn_gen_mesh.click(fn=bundle_image_to_mesh, inputs=[output_image1, lrm_radius, isomer_radius, reconstruction_stage1_steps, reconstruction_stage2_steps], outputs=output_video1)
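# Sketch (assumes Gradio's event chaining via .then(), available in recent
# Gradio releases) of how the two Text-to-3D steps could run from one click:
#
#   btn_text2detailed.click(fn=text_to_detailed, inputs=[prompt, seed1], outputs=detailed_prompt) \
#       .then(fn=text_to_image, inputs=[detailed_prompt, seed1], outputs=output_image1)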
- 
-         with gr.Row():
-             pass
-         with gr.Row():
-             gr.Markdown(_CITE_)
- 
-     # demo.queue(default_concurrency_limit=1)
-     # demo.launch(server_name="0.0.0.0", server_port=9239)
-     # subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
-     demo.launch()
- 
- 
- if __name__ == "__main__":
-     main()