Ali Mohsin commited on
Commit
0ba16db
·
1 Parent(s): 63bea2f

live preview

Browse files
nvdiffmodeling/src/material.py CHANGED
@@ -22,14 +22,25 @@ def load_mtl(fn, clear_ks=True):
22
  import re
23
  mtl_path = os.path.dirname(fn)
24
 
 
 
 
 
 
25
  # Read file
26
- with open(fn) as f:
27
- lines = f.readlines()
 
 
 
 
28
 
29
  # Parse materials
30
  materials = []
31
  for line in lines:
32
  split_line = re.split(' +|\t+|\n+', line.strip())
 
 
33
  prefix = split_line[0].lower()
34
  data = split_line[1:]
35
  if 'newmtl' in prefix:
@@ -39,33 +50,64 @@ def load_mtl(fn, clear_ks=True):
39
  if 'bsdf' in prefix or 'map_kd' in prefix or 'map_ks' in prefix or 'bump' in prefix:
40
  material[prefix] = data[0]
41
  else:
42
- material[prefix] = torch.tensor(tuple(float(d) for d in data), dtype=torch.float32, device='cuda')
 
 
 
 
43
 
44
  # Convert everything to textures. Our code expects 'kd' and 'ks' to be texture maps. So replace constants with 1x1 maps
45
  for mat in materials:
46
  if not 'bsdf' in mat:
47
  mat['bsdf'] = 'pbr'
48
 
 
49
  if 'map_kd' in mat:
50
- mat['kd'] = texture.load_texture2D(os.path.join(mtl_path, mat['map_kd']))
51
- else:
 
 
 
 
52
  mat['kd'] = texture.Texture2D(mat['kd'])
 
 
 
53
 
 
54
  if 'map_ks' in mat:
55
- mat['ks'] = texture.load_texture2D(os.path.join(mtl_path, mat['map_ks']), channels=3)
56
- else:
 
 
 
 
57
  mat['ks'] = texture.Texture2D(mat['ks'])
 
 
 
58
 
 
59
  if 'bump' in mat:
60
- mat['normal'] = texture.load_texture2D(os.path.join(mtl_path, mat['bump']), lambda_fn=lambda x: x * 2 - 1, channels=3)
 
 
 
 
61
 
62
  # Convert Kd from sRGB to linear RGB
63
- mat['kd'] = texture.srgb_to_rgb(mat['kd'])
 
 
 
64
 
65
  if clear_ks:
66
  # Override ORM occlusion (red) channel by zeros. We hijack this channel
67
- for mip in mat['ks'].getMips():
68
- mip[..., 0] = 0.0
 
 
 
69
 
70
  return materials
71
 
 
22
  import re
23
  mtl_path = os.path.dirname(fn)
24
 
25
+ # Check if file exists
26
+ if not os.path.exists(fn):
27
+ print(f"Warning: Material file {fn} does not exist, returning empty material list")
28
+ return []
29
+
30
  # Read file
31
+ try:
32
+ with open(fn) as f:
33
+ lines = f.readlines()
34
+ except Exception as e:
35
+ print(f"Warning: Could not read material file {fn}: {e}, returning empty material list")
36
+ return []
37
 
38
  # Parse materials
39
  materials = []
40
  for line in lines:
41
  split_line = re.split(' +|\t+|\n+', line.strip())
42
+ if len(split_line) == 0:
43
+ continue
44
  prefix = split_line[0].lower()
45
  data = split_line[1:]
46
  if 'newmtl' in prefix:
 
50
  if 'bsdf' in prefix or 'map_kd' in prefix or 'map_ks' in prefix or 'bump' in prefix:
51
  material[prefix] = data[0]
52
  else:
53
+ try:
54
+ material[prefix] = torch.tensor(tuple(float(d) for d in data), dtype=torch.float32, device='cuda')
55
+ except (ValueError, IndexError) as e:
56
+ print(f"Warning: Could not parse material property {prefix} with data {data}: {e}")
57
+ continue
58
 
59
  # Convert everything to textures. Our code expects 'kd' and 'ks' to be texture maps. So replace constants with 1x1 maps
60
  for mat in materials:
61
  if not 'bsdf' in mat:
62
  mat['bsdf'] = 'pbr'
63
 
64
+ # Handle kd (diffuse color)
65
  if 'map_kd' in mat:
66
+ try:
67
+ mat['kd'] = texture.load_texture2D(os.path.join(mtl_path, mat['map_kd']))
68
+ except Exception as e:
69
+ print(f"Warning: Could not load kd texture {mat['map_kd']}: {e}, using default")
70
+ mat['kd'] = texture.Texture2D(torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32, device='cuda'))
71
+ elif 'kd' in mat:
72
  mat['kd'] = texture.Texture2D(mat['kd'])
73
+ else:
74
+ # Default diffuse color
75
+ mat['kd'] = texture.Texture2D(torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32, device='cuda'))
76
 
77
+ # Handle ks (specular color)
78
  if 'map_ks' in mat:
79
+ try:
80
+ mat['ks'] = texture.load_texture2D(os.path.join(mtl_path, mat['map_ks']), channels=3)
81
+ except Exception as e:
82
+ print(f"Warning: Could not load ks texture {mat['map_ks']}: {e}, using default")
83
+ mat['ks'] = texture.Texture2D(torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device='cuda'))
84
+ elif 'ks' in mat:
85
  mat['ks'] = texture.Texture2D(mat['ks'])
86
+ else:
87
+ # Default specular color
88
+ mat['ks'] = texture.Texture2D(torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device='cuda'))
89
 
90
+ # Handle normal map
91
  if 'bump' in mat:
92
+ try:
93
+ mat['normal'] = texture.load_texture2D(os.path.join(mtl_path, mat['bump']), lambda_fn=lambda x: x * 2 - 1, channels=3)
94
+ except Exception as e:
95
+ print(f"Warning: Could not load normal texture {mat['bump']}: {e}, using default")
96
+ mat['normal'] = texture.Texture2D(torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device='cuda'))
97
 
98
  # Convert Kd from sRGB to linear RGB
99
+ try:
100
+ mat['kd'] = texture.srgb_to_rgb(mat['kd'])
101
+ except Exception as e:
102
+ print(f"Warning: Could not convert kd to linear RGB: {e}")
103
 
104
  if clear_ks:
105
  # Override ORM occlusion (red) channel by zeros. We hijack this channel
106
+ try:
107
+ for mip in mat['ks'].getMips():
108
+ mip[..., 0] = 0.0
109
+ except Exception as e:
110
+ print(f"Warning: Could not clear ks occlusion channel: {e}")
111
 
112
  return materials
113
 
nvdiffmodeling/src/obj.py CHANGED
@@ -41,6 +41,8 @@ def _find_mat(materials, name):
41
 
42
  def load_obj(filename, clear_ks=True, mtl_override=None):
43
  obj_path = os.path.dirname(filename)
 
 
44
 
45
  # Read entire file
46
  with open(filename) as f:
@@ -60,9 +62,27 @@ def load_obj(filename, clear_ks=True, mtl_override=None):
60
  if len(line.split()) == 0:
61
  continue
62
  if line.split()[0] == 'mtllib':
63
- all_materials += material.load_mtl(obj_path + os.path.join(line.split()[1]), clear_ks) # Read in entire material library #obj_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  else:
65
- all_materials += material.load_mtl(mtl_override)
 
 
 
66
 
67
  # load vertices
68
  vertices, texcoords, normals = [], [], []
 
41
 
42
  def load_obj(filename, clear_ks=True, mtl_override=None):
43
  obj_path = os.path.dirname(filename)
44
+ print(f"Loading OBJ from: {filename}")
45
+ print(f"OBJ path: {obj_path}")
46
 
47
  # Read entire file
48
  with open(filename) as f:
 
62
  if len(line.split()) == 0:
63
  continue
64
  if line.split()[0] == 'mtllib':
65
+ try:
66
+ mtl_filename = line.split()[1].strip()
67
+ # Handle both relative and absolute paths
68
+ if os.path.isabs(mtl_filename):
69
+ mtl_path = mtl_filename
70
+ else:
71
+ mtl_path = os.path.join(obj_path, mtl_filename)
72
+
73
+ print(f"Looking for material file: {mtl_path}")
74
+ if os.path.exists(mtl_path):
75
+ print(f"Found material file: {mtl_path}")
76
+ all_materials += material.load_mtl(mtl_path, clear_ks)
77
+ else:
78
+ print(f"Warning: Material file {mtl_path} not found, using default materials")
79
+ except Exception as e:
80
+ print(f"Warning: Could not load material file {line.split()[1]}: {e}, using default materials")
81
  else:
82
+ try:
83
+ all_materials += material.load_mtl(mtl_override)
84
+ except Exception as e:
85
+ print(f"Warning: Could not load material override {mtl_override}: {e}, using default materials")
86
 
87
  # load vertices
88
  vertices, texcoords, normals = [], [], []
utils.py CHANGED
@@ -11,36 +11,50 @@ normal_map = texture.create_trainable(np.array([0, 0, 1]), [512] * 2, True)
11
  specular_map = texture.create_trainable(np.array([0, 0, 0]), [512] * 2, True)
12
 
13
  def get_mesh(mesh_path, output_path, triangulate_flag, bsdf_flag, mesh_name='mesh.obj'):
14
- ms = pymeshlab.MeshSet()
15
- ms.load_new_mesh(mesh_path)
16
-
17
- if triangulate_flag:
18
- print('Retriangulating shape')
19
- ms.meshing_isotropic_explicit_remeshing()
20
-
21
- if not ms.current_mesh().has_wedge_tex_coord():
22
- # some arbitrarily high number
23
- ms.compute_texcoord_parametrization_triangle_trivial_per_wedge(textdim=10000)
24
-
25
- ms.save_current_mesh(str(output_path / 'tmp' / mesh_name))
26
-
27
- load_mesh = obj.load_obj(str(output_path / 'tmp' / mesh_name))
28
- load_mesh = mesh.unit_size(load_mesh)
29
-
30
- ms.add_mesh(
31
- pymeshlab.Mesh(vertex_matrix=load_mesh.v_pos.cpu().numpy(), face_matrix=load_mesh.t_pos_idx.cpu().numpy()))
32
- ms.save_current_mesh(str(output_path / 'tmp' / mesh_name), save_vertex_color=False)
33
-
34
- load_mesh = mesh.Mesh(
35
- material={
36
- 'bsdf': bsdf_flag,
37
- 'kd': texture_map,
38
- 'ks': specular_map,
39
- 'normal': normal_map,
40
- },
41
- base=load_mesh # Get UVs from original loaded mesh
42
- )
43
- return load_mesh
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
 
46
  def get_og_mesh(mesh_path, output_path, triangulate_flag, bsdf_flag, mesh_name='mesh.obj'):
 
11
  specular_map = texture.create_trainable(np.array([0, 0, 0]), [512] * 2, True)
12
 
13
  def get_mesh(mesh_path, output_path, triangulate_flag, bsdf_flag, mesh_name='mesh.obj'):
14
+ try:
15
+ print(f"Loading mesh from: {mesh_path}")
16
+ ms = pymeshlab.MeshSet()
17
+ ms.load_new_mesh(mesh_path)
18
+
19
+ if triangulate_flag:
20
+ print('Retriangulating shape')
21
+ ms.meshing_isotropic_explicit_remeshing()
22
+
23
+ if not ms.current_mesh().has_wedge_tex_coord():
24
+ # some arbitrarily high number
25
+ ms.compute_texcoord_parametrization_triangle_trivial_per_wedge(textdim=10000)
26
+
27
+ # Ensure the tmp directory exists
28
+ tmp_dir = output_path / 'tmp'
29
+ tmp_dir.mkdir(exist_ok=True)
30
+
31
+ tmp_mesh_path = tmp_dir / mesh_name
32
+ print(f"Saving temporary mesh to: {tmp_mesh_path}")
33
+ ms.save_current_mesh(str(tmp_mesh_path))
34
+
35
+ print(f"Loading OBJ from temporary path: {tmp_mesh_path}")
36
+ load_mesh = obj.load_obj(str(tmp_mesh_path))
37
+ load_mesh = mesh.unit_size(load_mesh)
38
+
39
+ ms.add_mesh(
40
+ pymeshlab.Mesh(vertex_matrix=load_mesh.v_pos.cpu().numpy(), face_matrix=load_mesh.t_pos_idx.cpu().numpy()))
41
+ ms.save_current_mesh(str(tmp_mesh_path), save_vertex_color=False)
42
+
43
+ load_mesh = mesh.Mesh(
44
+ material={
45
+ 'bsdf': bsdf_flag,
46
+ 'kd': texture_map,
47
+ 'ks': specular_map,
48
+ 'normal': normal_map,
49
+ },
50
+ base=load_mesh # Get UVs from original loaded mesh
51
+ )
52
+ return load_mesh
53
+ except Exception as e:
54
+ print(f"Error in get_mesh: {e}")
55
+ import traceback
56
+ traceback.print_exc()
57
+ raise
58
 
59
 
60
  def get_og_mesh(mesh_path, output_path, triangulate_flag, bsdf_flag, mesh_name='mesh.obj'):