Ali Mohsin commited on
Commit
4c539b3
·
1 Parent(s): e2b6169

final fixes

Browse files
Files changed (4) hide show
  1. app.py +190 -31
  2. loop.py +19 -9
  3. utilities/clip_spatial.py +52 -5
  4. utils.py +5 -0
app.py CHANGED
@@ -27,6 +27,94 @@ try:
27
  except:
28
  pass
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  # Check if complex dependencies are installed
31
  def check_complex_dependencies():
32
  """Check if complex dependencies are available"""
@@ -668,25 +756,20 @@ def process_garment(input_type, text_prompt, base_text_prompt, mesh_target_image
668
 
669
  progress(0.9, desc="Processing complete, preparing output...")
670
 
671
- # Look for output files, prioritize mesh files
672
  obj_files = []
673
  glb_files = []
674
  image_files = []
675
 
676
- print("Searching for output files...")
677
 
678
  # First check for mesh files in mesh_final directory (priority)
679
  mesh_final_dir = Path(temp_dir) / "mesh_final"
680
  if mesh_final_dir.exists():
681
  print(f"Found mesh_final directory at {mesh_final_dir}")
682
- for file_path in mesh_final_dir.rglob("*"):
683
- if file_path.is_file():
684
- if file_path.suffix.lower() == '.obj':
685
- obj_files.append(str(file_path))
686
- print(f"Found OBJ file: {file_path}")
687
- elif file_path.suffix.lower() == '.glb':
688
- glb_files.append(str(file_path))
689
- print(f"Found GLB file: {file_path}")
690
  else:
691
  print("mesh_final directory not found")
692
 
@@ -694,14 +777,9 @@ def process_garment(input_type, text_prompt, base_text_prompt, mesh_target_image
694
  for mesh_dir in Path(temp_dir).glob("mesh_*"):
695
  if mesh_dir.is_dir() and mesh_dir.name != 'mesh_final':
696
  print(f"Checking directory: {mesh_dir}")
697
- for file_path in mesh_dir.rglob("*"):
698
- if file_path.is_file():
699
- if file_path.suffix.lower() == '.obj':
700
- obj_files.append(str(file_path))
701
- print(f"Found OBJ file: {file_path}")
702
- elif file_path.suffix.lower() == '.glb':
703
- glb_files.append(str(file_path))
704
- print(f"Found GLB file: {file_path}")
705
 
706
  # Collect image files for visualization
707
  for file_path in Path(temp_dir).rglob("*"):
@@ -732,6 +810,28 @@ def process_garment(input_type, text_prompt, base_text_prompt, mesh_target_image
732
  # Return None instead of an error string to avoid file not found errors with Gradio
733
  return None
734
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
735
  def create_interface():
736
  """
737
  Create the Gradio interface with simplified components
@@ -748,6 +848,7 @@ def create_interface():
748
  2. For **Text** mode: Enter descriptions of your target and base garment styles
749
  3. For **Image to Mesh** mode: Upload an image to generate a 3D mesh directly and select a base mesh type
750
  4. Click "Generate 3D Garment" to create your 3D mesh file
 
751
  """)
752
 
753
  with gr.Row():
@@ -838,12 +939,21 @@ def create_interface():
838
  generate_btn = gr.Button("Generate 3D Garment")
839
 
840
  with gr.Column():
 
841
  output = gr.File(
842
- label="Generated 3D Garment",
843
  file_types=[".obj", ".glb", ".png", ".jpg"],
844
  file_count="single"
845
  )
846
 
 
 
 
 
 
 
 
 
847
  gr.Markdown("""
848
  ## Tips:
849
 
@@ -851,6 +961,8 @@ def create_interface():
851
  - For image to mesh mode: Use clear, front-facing garment images to generate a 3D mesh directly
852
  - Choose the appropriate base mesh type that matches your target garment
853
  - Higher epochs = better quality but longer processing time
 
 
854
  - Output files can be downloaded by clicking on them
855
 
856
  Processing may take several minutes.
@@ -887,33 +999,55 @@ def create_interface():
887
  status_msg
888
  )
889
 
890
- # Function to handle processing with better error feedback
891
  def process_with_feedback(*args):
892
  try:
893
  # Check if processing engine is available
894
  if loop is None:
895
- return None, "❌ ERROR: Processing engine not available. Please check that all dependencies are properly installed."
896
 
897
  result = process_garment(*args)
898
  if result is None:
899
- return None, "Processing completed but no output files were generated. Please check the logs for more details."
900
  elif isinstance(result, str) and result.startswith("Error:"):
901
- # Return None for the file output and the error message for status
902
- return None, result
903
  elif isinstance(result, str) and os.path.exists(result):
904
- # Valid file path
905
- return result, "🎉 Processing completed successfully! Download your 3D garment file below."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
906
  elif isinstance(result, str):
907
  # Some other string that's not an error and not a file path
908
- return None, f"Unexpected result: {result}"
909
  else:
910
  # Should be a file path or None
911
- return result, "🎉 Processing completed successfully! Download your 3D garment file below."
912
  except Exception as e:
913
  import traceback
914
  print(f"Error in interface: {str(e)}")
915
  print(traceback.format_exc())
916
- return None, f"❌ Error: {str(e)}"
917
 
918
  # Toggle visibility based on input mode with better feedback
919
  input_type.change(
@@ -923,7 +1057,7 @@ def create_interface():
923
  show_progress=True
924
  )
925
 
926
- # Connect the button to the processing function with error handling
927
  generate_btn.click(
928
  fn=process_with_feedback,
929
  inputs=[
@@ -938,7 +1072,32 @@ def create_interface():
938
  clip_weight,
939
  delta_clip_weight
940
  ],
941
- outputs=[output, status_output]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
942
  )
943
 
944
  return interface
 
27
  except:
28
  pass
29
 
30
+ # Function to convert OBJ to GLB format
31
+ def convert_obj_to_glb(obj_file_path, glb_file_path=None):
32
+ """
33
+ Convert OBJ file to GLB format using trimesh
34
+
35
+ Args:
36
+ obj_file_path: Path to the OBJ file
37
+ glb_file_path: Path for the output GLB file (optional)
38
+
39
+ Returns:
40
+ Path to the created GLB file or None if conversion failed
41
+ """
42
+ try:
43
+ import trimesh
44
+ print(f"Converting {obj_file_path} to GLB format...")
45
+
46
+ # Check if input file exists
47
+ if not os.path.exists(obj_file_path):
48
+ print(f"Error: OBJ file {obj_file_path} does not exist")
49
+ return None
50
+
51
+ # Load the OBJ file
52
+ mesh = trimesh.load(obj_file_path)
53
+
54
+ # If no GLB path specified, create one in the same directory
55
+ if glb_file_path is None:
56
+ glb_file_path = str(Path(obj_file_path).with_suffix('.glb'))
57
+
58
+ # Export as GLB
59
+ mesh.export(glb_file_path, file_type='glb')
60
+ print(f"Successfully converted to GLB: {glb_file_path}")
61
+ return glb_file_path
62
+
63
+ except ImportError:
64
+ print("trimesh not available for GLB conversion")
65
+ return None
66
+ except Exception as e:
67
+ print(f"Error converting OBJ to GLB: {e}")
68
+ return None
69
+
70
+ # Function to ensure both OBJ and GLB files are created
71
+ def ensure_mesh_formats(output_dir):
72
+ """
73
+ Ensure both OBJ and GLB files are available in the output directory
74
+
75
+ Args:
76
+ output_dir: Directory containing mesh files
77
+
78
+ Returns:
79
+ tuple: (obj_files, glb_files) - lists of available files
80
+ """
81
+ obj_files = []
82
+ glb_files = []
83
+
84
+ output_path = Path(output_dir)
85
+
86
+ if not output_path.exists():
87
+ print(f"Warning: Output directory {output_dir} does not exist")
88
+ return obj_files, glb_files
89
+
90
+ # Find all OBJ files
91
+ for obj_file in output_path.rglob("*.obj"):
92
+ obj_files.append(str(obj_file))
93
+ print(f"Found OBJ file: {obj_file}")
94
+
95
+ # Try to create corresponding GLB file
96
+ glb_file = obj_file.with_suffix('.glb')
97
+ if not glb_file.exists():
98
+ print(f"Creating GLB file for {obj_file}")
99
+ glb_path = convert_obj_to_glb(str(obj_file), str(glb_file))
100
+ if glb_path:
101
+ glb_files.append(glb_path)
102
+ print(f"Successfully created GLB: {glb_path}")
103
+ else:
104
+ print(f"Failed to create GLB for {obj_file}")
105
+ else:
106
+ glb_files.append(str(glb_file))
107
+ print(f"GLB file already exists: {glb_file}")
108
+
109
+ # Also check for existing GLB files that might not have corresponding OBJ
110
+ for glb_file in output_path.rglob("*.glb"):
111
+ if str(glb_file) not in glb_files:
112
+ glb_files.append(str(glb_file))
113
+ print(f"Found standalone GLB file: {glb_file}")
114
+
115
+ print(f"Total files found: {len(obj_files)} OBJ files, {len(glb_files)} GLB files")
116
+ return obj_files, glb_files
117
+
118
  # Check if complex dependencies are installed
119
  def check_complex_dependencies():
120
  """Check if complex dependencies are available"""
 
756
 
757
  progress(0.9, desc="Processing complete, preparing output...")
758
 
759
+ # Look for output files and ensure both OBJ and GLB formats are available
760
  obj_files = []
761
  glb_files = []
762
  image_files = []
763
 
764
+ print("Searching for output files and ensuring GLB conversion...")
765
 
766
  # First check for mesh files in mesh_final directory (priority)
767
  mesh_final_dir = Path(temp_dir) / "mesh_final"
768
  if mesh_final_dir.exists():
769
  print(f"Found mesh_final directory at {mesh_final_dir}")
770
+ # Ensure both OBJ and GLB formats are available
771
+ obj_files, glb_files = ensure_mesh_formats(mesh_final_dir)
772
+ print(f"Found {len(obj_files)} OBJ files and {len(glb_files)} GLB files in mesh_final")
 
 
 
 
 
773
  else:
774
  print("mesh_final directory not found")
775
 
 
777
  for mesh_dir in Path(temp_dir).glob("mesh_*"):
778
  if mesh_dir.is_dir() and mesh_dir.name != 'mesh_final':
779
  print(f"Checking directory: {mesh_dir}")
780
+ dir_obj_files, dir_glb_files = ensure_mesh_formats(mesh_dir)
781
+ obj_files.extend(dir_obj_files)
782
+ glb_files.extend(dir_glb_files)
 
 
 
 
 
783
 
784
  # Collect image files for visualization
785
  for file_path in Path(temp_dir).rglob("*"):
 
810
  # Return None instead of an error string to avoid file not found errors with Gradio
811
  return None
812
 
813
+ def create_combined_mesh_output(output_dir):
814
+ """
815
+ Create a combined output showing both OBJ and GLB files if available
816
+
817
+ Args:
818
+ output_dir: Directory containing mesh files
819
+
820
+ Returns:
821
+ tuple: (primary_file, secondary_file, status_message)
822
+ """
823
+ obj_files, glb_files = ensure_mesh_formats(output_dir)
824
+
825
+ if glb_files and obj_files:
826
+ # Both formats available - return GLB as primary (better for web viewing)
827
+ return glb_files[0], obj_files[0], "🎉 Success! Both GLB and OBJ files generated. GLB file is displayed (better for web viewing), OBJ file is also available."
828
+ elif glb_files:
829
+ return glb_files[0], None, "🎉 Success! GLB file generated and ready for download."
830
+ elif obj_files:
831
+ return obj_files[0], None, "🎉 Success! OBJ file generated and ready for download."
832
+ else:
833
+ return None, None, "❌ No mesh files were generated. Please check the processing logs."
834
+
835
  def create_interface():
836
  """
837
  Create the Gradio interface with simplified components
 
848
  2. For **Text** mode: Enter descriptions of your target and base garment styles
849
  3. For **Image to Mesh** mode: Upload an image to generate a 3D mesh directly and select a base mesh type
850
  4. Click "Generate 3D Garment" to create your 3D mesh file
851
+ 5. **GLB files** are automatically generated for better web viewing and virtual try-on compatibility
852
  """)
853
 
854
  with gr.Row():
 
939
  generate_btn = gr.Button("Generate 3D Garment")
940
 
941
  with gr.Column():
942
+ # Primary output (GLB preferred)
943
  output = gr.File(
944
+ label="Generated 3D Garment (GLB/OBJ)",
945
  file_types=[".obj", ".glb", ".png", ".jpg"],
946
  file_count="single"
947
  )
948
 
949
+ # Secondary output (OBJ if GLB is primary)
950
+ secondary_output = gr.File(
951
+ label="Alternative Format (OBJ/GLB)",
952
+ file_types=[".obj", ".glb"],
953
+ file_count="single",
954
+ visible=False
955
+ )
956
+
957
  gr.Markdown("""
958
  ## Tips:
959
 
 
961
  - For image to mesh mode: Use clear, front-facing garment images to generate a 3D mesh directly
962
  - Choose the appropriate base mesh type that matches your target garment
963
  - Higher epochs = better quality but longer processing time
964
+ - **GLB files** are automatically generated for better web viewing and virtual try-on compatibility
965
+ - **OBJ files** are also available for traditional 3D software compatibility
966
  - Output files can be downloaded by clicking on them
967
 
968
  Processing may take several minutes.
 
999
  status_msg
1000
  )
1001
 
1002
+ # Function to handle processing with better error feedback and dual output
1003
  def process_with_feedback(*args):
1004
  try:
1005
  # Check if processing engine is available
1006
  if loop is None:
1007
+ return None, None, "❌ ERROR: Processing engine not available. Please check that all dependencies are properly installed."
1008
 
1009
  result = process_garment(*args)
1010
  if result is None:
1011
+ return None, None, "Processing completed but no output files were generated. Please check the logs for more details."
1012
  elif isinstance(result, str) and result.startswith("Error:"):
1013
+ # Return None for the file outputs and the error message for status
1014
+ return None, None, result
1015
  elif isinstance(result, str) and os.path.exists(result):
1016
+ # Valid file path - check if we can create a combined output
1017
+ result_path = Path(result)
1018
+ if result_path.suffix.lower() == '.glb':
1019
+ # GLB file - try to find corresponding OBJ
1020
+ obj_file = result_path.with_suffix('.obj')
1021
+ if obj_file.exists():
1022
+ return result, str(obj_file), "🎉 Success! Both GLB and OBJ files generated. GLB file is displayed (better for web viewing), OBJ file is also available."
1023
+ else:
1024
+ return result, None, "🎉 Success! GLB file generated and ready for download."
1025
+ elif result_path.suffix.lower() == '.obj':
1026
+ # OBJ file - try to find corresponding GLB or create one
1027
+ glb_file = result_path.with_suffix('.glb')
1028
+ if glb_file.exists():
1029
+ return str(glb_file), result, "🎉 Success! Both GLB and OBJ files generated. GLB file is displayed (better for web viewing), OBJ file is also available."
1030
+ else:
1031
+ # Try to convert OBJ to GLB
1032
+ glb_path = convert_obj_to_glb(result)
1033
+ if glb_path:
1034
+ return glb_path, result, "🎉 Success! Both GLB and OBJ files generated. GLB file is displayed (better for web viewing), OBJ file is also available."
1035
+ else:
1036
+ return result, None, "🎉 Success! OBJ file generated and ready for download."
1037
+ else:
1038
+ # Some other file type
1039
+ return result, None, "🎉 Processing completed successfully! Download your file below."
1040
  elif isinstance(result, str):
1041
  # Some other string that's not an error and not a file path
1042
+ return None, None, f"Unexpected result: {result}"
1043
  else:
1044
  # Should be a file path or None
1045
+ return result, None, "🎉 Processing completed successfully! Download your 3D garment file below."
1046
  except Exception as e:
1047
  import traceback
1048
  print(f"Error in interface: {str(e)}")
1049
  print(traceback.format_exc())
1050
+ return None, None, f"❌ Error: {str(e)}"
1051
 
1052
  # Toggle visibility based on input mode with better feedback
1053
  input_type.change(
 
1057
  show_progress=True
1058
  )
1059
 
1060
+ # Connect the button to the processing function with error handling and dual output
1061
  generate_btn.click(
1062
  fn=process_with_feedback,
1063
  inputs=[
 
1072
  clip_weight,
1073
  delta_clip_weight
1074
  ],
1075
+ outputs=[output, secondary_output, status_output]
1076
+ )
1077
+
1078
+ # Update secondary output visibility when primary output changes
1079
+ def update_secondary_visibility(primary_file):
1080
+ """Update secondary output visibility based on whether both formats are available"""
1081
+ if primary_file is not None and primary_file != "":
1082
+ # Check if there's a corresponding file in the other format
1083
+ primary_path = Path(primary_file)
1084
+ if primary_path.suffix.lower() == '.glb':
1085
+ # Check if corresponding OBJ exists
1086
+ obj_file = primary_path.with_suffix('.obj')
1087
+ if obj_file.exists():
1088
+ return gr.update(visible=True)
1089
+ elif primary_path.suffix.lower() == '.obj':
1090
+ # Check if corresponding GLB exists
1091
+ glb_file = primary_path.with_suffix('.glb')
1092
+ if glb_file.exists():
1093
+ return gr.update(visible=True)
1094
+ return gr.update(visible=False)
1095
+
1096
+ # Connect the secondary output visibility to the primary output
1097
+ output.change(
1098
+ fn=update_secondary_visibility,
1099
+ inputs=[output],
1100
+ outputs=[secondary_output]
1101
  )
1102
 
1103
  return interface
loop.py CHANGED
@@ -1,4 +1,3 @@
1
- import clip
2
  import kornia
3
  import os
4
  import sys
@@ -108,15 +107,26 @@ def loop(cfg):
108
  fashion_image = cfg.get('fashion_image', False)
109
  fashion_text = cfg.get('fashion_text', True) # Default to fashion text mode
110
  use_target_mesh = cfg.get('use_target_mesh', True)
111
- CLIP_embeddings = False
112
 
113
- if CLIP_embeddings:
114
- print('Loading CLIP Models')
115
- model, preprocess = clip.load(cfg.clip_model, device=device)
116
- else:
117
  fclip = FashionCLIP('fashion-clip')
118
-
119
- fe = CLIPVisualEncoder(cfg.consistency_clip_model, cfg.consistency_vit_stride, device)
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
  # Use FashionCLIP for all modes to avoid CLIP loading issues
122
  if fashion_image:
@@ -434,7 +444,7 @@ def loop(cfg):
434
  r_loss = (((gt_jacobians) - torch.eye(3, 3, device=device)) ** 2).mean()
435
  logger.add_scalar('jacobian_regularization', r_loss, global_step=it)
436
 
437
- if cfg.consistency_loss_weight != 0:
438
  consistency_loss = compute_mv_cl(final_mesh, fe, normalized_clip_render, params_camera, train_rast_map, cfg, device)
439
  else:
440
  consistency_loss = r_loss
 
 
1
  import kornia
2
  import os
3
  import sys
 
107
  fashion_image = cfg.get('fashion_image', False)
108
  fashion_text = cfg.get('fashion_text', True) # Default to fashion text mode
109
  use_target_mesh = cfg.get('use_target_mesh', True)
110
+ CLIP_embeddings = False # Always use FashionCLIP to avoid CLIP issues
111
 
112
+ # Always use FashionCLIP to avoid CLIP loading issues
113
+ print('Loading FashionCLIP model...')
114
+ try:
 
115
  fclip = FashionCLIP('fashion-clip')
116
+ print('FashionCLIP loaded successfully')
117
+ except Exception as e:
118
+ print(f'Error loading FashionCLIP: {e}')
119
+ raise RuntimeError(f"Failed to load FashionCLIP: {e}")
120
+
121
+ # Load CLIPVisualEncoder with error handling
122
+ print('Loading CLIPVisualEncoder...')
123
+ try:
124
+ fe = CLIPVisualEncoder(cfg.consistency_clip_model, cfg.consistency_vit_stride, device)
125
+ print('CLIPVisualEncoder loaded successfully')
126
+ except Exception as e:
127
+ print(f'Error loading CLIPVisualEncoder: {e}')
128
+ print('Continuing without CLIPVisualEncoder...')
129
+ fe = None
130
 
131
  # Use FashionCLIP for all modes to avoid CLIP loading issues
132
  if fashion_image:
 
444
  r_loss = (((gt_jacobians) - torch.eye(3, 3, device=device)) ** 2).mean()
445
  logger.add_scalar('jacobian_regularization', r_loss, global_step=it)
446
 
447
+ if cfg.consistency_loss_weight != 0 and fe is not None:
448
  consistency_loss = compute_mv_cl(final_mesh, fe, normalized_clip_render, params_camera, train_rast_map, cfg, device)
449
  else:
450
  consistency_loss = r_loss
utilities/clip_spatial.py CHANGED
@@ -3,11 +3,18 @@ import math
3
  import types
4
  import typing
5
 
6
- import clip
7
  import torch
8
  import torch.nn as nn
9
  from torchvision import models, transforms
10
 
 
 
 
 
 
 
 
 
11
  # code lifted from CLIPasso
12
 
13
  # For ViT
@@ -25,10 +32,50 @@ class CLIPVisualEncoder(nn.Module):
25
 
26
 
27
  def load_model(self, model_name, device):
28
- model, preprocess = clip.load(model_name, device=device)
29
- self.model = model.visual
30
- self.mean = torch.tensor(preprocess.transforms[-1].mean, device=device)
31
- self.std = torch.tensor(preprocess.transforms[-1].std, device=device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  @staticmethod
34
  def _fix_pos_enc(patch_size: int, stride_hw: typing.Tuple[int, int]):
 
3
  import types
4
  import typing
5
 
 
6
  import torch
7
  import torch.nn as nn
8
  from torchvision import models, transforms
9
 
10
+ # Try to import CLIP, but handle import errors gracefully
11
+ try:
12
+ import clip
13
+ CLIP_AVAILABLE = True
14
+ except ImportError:
15
+ print("Warning: CLIP not available, using FashionCLIP fallback")
16
+ CLIP_AVAILABLE = False
17
+
18
  # code lifted from CLIPasso
19
 
20
  # For ViT
 
32
 
33
 
34
  def load_model(self, model_name, device):
35
+ if CLIP_AVAILABLE:
36
+ try:
37
+ model, preprocess = clip.load(model_name, device=device)
38
+ self.model = model.visual
39
+ self.mean = torch.tensor(preprocess.transforms[-1].mean, device=device)
40
+ self.std = torch.tensor(preprocess.transforms[-1].std, device=device)
41
+ except Exception as e:
42
+ print(f"Error loading CLIP model: {e}")
43
+ print("Falling back to FashionCLIP...")
44
+ self._load_fashion_clip_fallback(device)
45
+ else:
46
+ print("CLIP not available, using FashionCLIP fallback...")
47
+ self._load_fashion_clip_fallback(device)
48
+
49
+ def _load_fashion_clip_fallback(self, device):
50
+ """Fallback method using FashionCLIP when regular CLIP is not available"""
51
+ try:
52
+ from packages.fashion_clip.fashion_clip.fashion_clip import FashionCLIP
53
+ fclip = FashionCLIP('fashion-clip')
54
+ self.model = fclip.model.vision_model
55
+ # Use standard CLIP mean and std values
56
+ self.mean = torch.tensor([0.48154660, 0.45782750, 0.40821073], device=device)
57
+ self.std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=device)
58
+ print("Successfully loaded FashionCLIP fallback")
59
+ except Exception as e:
60
+ print(f"Error loading FashionCLIP fallback: {e}")
61
+ # Create a dummy model if all else fails
62
+ self._create_dummy_model(device)
63
+
64
+ def _create_dummy_model(self, device):
65
+ """Create a dummy model when all CLIP options fail"""
66
+ print("Creating dummy CLIP model - functionality will be limited")
67
+ # Create a simple dummy model structure
68
+ class DummyModel:
69
+ def __init__(self):
70
+ self.conv1 = nn.Conv2d(3, 768, kernel_size=16, stride=16)
71
+ self.class_embedding = nn.Parameter(torch.randn(768))
72
+ self.positional_embedding = nn.Parameter(torch.randn(197, 768))
73
+ self.ln_pre = nn.LayerNorm(768)
74
+ self.transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(768, 12), 12)
75
+
76
+ self.model = DummyModel()
77
+ self.mean = torch.tensor([0.48154660, 0.45782750, 0.40821073], device=device)
78
+ self.std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=device)
79
 
80
  @staticmethod
81
  def _fix_pos_enc(patch_size: int, stride_hw: typing.Tuple[int, int]):
utils.py CHANGED
@@ -77,6 +77,11 @@ def get_og_mesh(mesh_path, output_path, triangulate_flag, bsdf_flag, mesh_name='
77
 
78
  def compute_mv_cl(final_mesh, fe, normalized_clip_render, params_camera, train_rast_map, cfg, device):
79
  # Consistency loss
 
 
 
 
 
80
  # Get mapping from vertex to pixels
81
  curr_vp_map = get_vp_map(final_mesh.v_pos, params_camera['mvp'], 224)
82
  for idx, rast_faces in enumerate(train_rast_map[:, :, :, 3].view(cfg.batch_size, -1)):
 
77
 
78
  def compute_mv_cl(final_mesh, fe, normalized_clip_render, params_camera, train_rast_map, cfg, device):
79
  # Consistency loss
80
+ # Check if fe is available
81
+ if fe is None:
82
+ print("Warning: CLIPVisualEncoder not available, skipping consistency loss")
83
+ return torch.tensor(0.0, device=device)
84
+
85
  # Get mapping from vertex to pixels
86
  curr_vp_map = get_vp_map(final_mesh.v_pos, params_camera['mvp'], 224)
87
  for idx, rast_faces in enumerate(train_rast_map[:, :, :, 3].view(cfg.batch_size, -1)):