Spaces: Paused

Ali Mohsin committed · Commit 4c539b3 · 1 Parent(s): e2b6169

final fixes

Browse files
- app.py +190 -31
- loop.py +19 -9
- utilities/clip_spatial.py +52 -5
- utils.py +5 -0
app.py CHANGED

@@ -27,6 +27,94 @@ try:
 except:
     pass
 
+# Function to convert OBJ to GLB format
+def convert_obj_to_glb(obj_file_path, glb_file_path=None):
+    """
+    Convert OBJ file to GLB format using trimesh
+
+    Args:
+        obj_file_path: Path to the OBJ file
+        glb_file_path: Path for the output GLB file (optional)
+
+    Returns:
+        Path to the created GLB file or None if conversion failed
+    """
+    try:
+        import trimesh
+        print(f"Converting {obj_file_path} to GLB format...")
+
+        # Check if input file exists
+        if not os.path.exists(obj_file_path):
+            print(f"Error: OBJ file {obj_file_path} does not exist")
+            return None
+
+        # Load the OBJ file
+        mesh = trimesh.load(obj_file_path)
+
+        # If no GLB path specified, create one in the same directory
+        if glb_file_path is None:
+            glb_file_path = str(Path(obj_file_path).with_suffix('.glb'))
+
+        # Export as GLB
+        mesh.export(glb_file_path, file_type='glb')
+        print(f"Successfully converted to GLB: {glb_file_path}")
+        return glb_file_path
+
+    except ImportError:
+        print("trimesh not available for GLB conversion")
+        return None
+    except Exception as e:
+        print(f"Error converting OBJ to GLB: {e}")
+        return None
+
+# Function to ensure both OBJ and GLB files are created
+def ensure_mesh_formats(output_dir):
+    """
+    Ensure both OBJ and GLB files are available in the output directory
+
+    Args:
+        output_dir: Directory containing mesh files
+
+    Returns:
+        tuple: (obj_files, glb_files) - lists of available files
+    """
+    obj_files = []
+    glb_files = []
+
+    output_path = Path(output_dir)
+
+    if not output_path.exists():
+        print(f"Warning: Output directory {output_dir} does not exist")
+        return obj_files, glb_files
+
+    # Find all OBJ files
+    for obj_file in output_path.rglob("*.obj"):
+        obj_files.append(str(obj_file))
+        print(f"Found OBJ file: {obj_file}")
+
+        # Try to create corresponding GLB file
+        glb_file = obj_file.with_suffix('.glb')
+        if not glb_file.exists():
+            print(f"Creating GLB file for {obj_file}")
+            glb_path = convert_obj_to_glb(str(obj_file), str(glb_file))
+            if glb_path:
+                glb_files.append(glb_path)
+                print(f"Successfully created GLB: {glb_path}")
+            else:
+                print(f"Failed to create GLB for {obj_file}")
+        else:
+            glb_files.append(str(glb_file))
+            print(f"GLB file already exists: {glb_file}")
+
+    # Also check for existing GLB files that might not have corresponding OBJ
+    for glb_file in output_path.rglob("*.glb"):
+        if str(glb_file) not in glb_files:
+            glb_files.append(str(glb_file))
+            print(f"Found standalone GLB file: {glb_file}")
+
+    print(f"Total files found: {len(obj_files)} OBJ files, {len(glb_files)} GLB files")
+    return obj_files, glb_files
+
 # Check if complex dependencies are installed
 def check_complex_dependencies():
     """Check if complex dependencies are available"""

@@ -668,25 +756,20 @@ def process_garment(input_type, text_prompt, base_text_prompt, mesh_target_image
 
     progress(0.9, desc="Processing complete, preparing output...")
 
-    # Look for output files
+    # Look for output files and ensure both OBJ and GLB formats are available
     obj_files = []
     glb_files = []
     image_files = []
 
-    print("Searching for output files...")
+    print("Searching for output files and ensuring GLB conversion...")
 
     # First check for mesh files in mesh_final directory (priority)
    mesh_final_dir = Path(temp_dir) / "mesh_final"
    if mesh_final_dir.exists():
        print(f"Found mesh_final directory at {mesh_final_dir}")
-
-
-
-                obj_files.append(str(file_path))
-                print(f"Found OBJ file: {file_path}")
-            elif file_path.suffix.lower() == '.glb':
-                glb_files.append(str(file_path))
-                print(f"Found GLB file: {file_path}")
+        # Ensure both OBJ and GLB formats are available
+        obj_files, glb_files = ensure_mesh_formats(mesh_final_dir)
+        print(f"Found {len(obj_files)} OBJ files and {len(glb_files)} GLB files in mesh_final")
    else:
        print("mesh_final directory not found")
 

@@ -694,14 +777,9 @@ def process_garment(input_type, text_prompt, base_text_prompt, mesh_target_image
     for mesh_dir in Path(temp_dir).glob("mesh_*"):
         if mesh_dir.is_dir() and mesh_dir.name != 'mesh_final':
             print(f"Checking directory: {mesh_dir}")
-
-
-
-                obj_files.append(str(file_path))
-                print(f"Found OBJ file: {file_path}")
-            elif file_path.suffix.lower() == '.glb':
-                glb_files.append(str(file_path))
-                print(f"Found GLB file: {file_path}")
+            dir_obj_files, dir_glb_files = ensure_mesh_formats(mesh_dir)
+            obj_files.extend(dir_obj_files)
+            glb_files.extend(dir_glb_files)
 
     # Collect image files for visualization
     for file_path in Path(temp_dir).rglob("*"):

@@ -732,6 +810,28 @@ def process_garment(input_type, text_prompt, base_text_prompt, mesh_target_image
         # Return None instead of an error string to avoid file not found errors with Gradio
         return None
 
+def create_combined_mesh_output(output_dir):
+    """
+    Create a combined output showing both OBJ and GLB files if available
+
+    Args:
+        output_dir: Directory containing mesh files
+
+    Returns:
+        tuple: (primary_file, secondary_file, status_message)
+    """
+    obj_files, glb_files = ensure_mesh_formats(output_dir)
+
+    if glb_files and obj_files:
+        # Both formats available - return GLB as primary (better for web viewing)
+        return glb_files[0], obj_files[0], "🎉 Success! Both GLB and OBJ files generated. GLB file is displayed (better for web viewing), OBJ file is also available."
+    elif glb_files:
+        return glb_files[0], None, "🎉 Success! GLB file generated and ready for download."
+    elif obj_files:
+        return obj_files[0], None, "🎉 Success! OBJ file generated and ready for download."
+    else:
+        return None, None, "❌ No mesh files were generated. Please check the processing logs."
+
 def create_interface():
     """
     Create the Gradio interface with simplified components

@@ -748,6 +848,7 @@ def create_interface():
         2. For **Text** mode: Enter descriptions of your target and base garment styles
         3. For **Image to Mesh** mode: Upload an image to generate a 3D mesh directly and select a base mesh type
         4. Click "Generate 3D Garment" to create your 3D mesh file
+        5. **GLB files** are automatically generated for better web viewing and virtual try-on compatibility
         """)
 
         with gr.Row():

@@ -838,12 +939,21 @@ def create_interface():
                generate_btn = gr.Button("Generate 3D Garment")
 
            with gr.Column():
+                # Primary output (GLB preferred)
                output = gr.File(
-                    label="Generated 3D Garment",
+                    label="Generated 3D Garment (GLB/OBJ)",
                    file_types=[".obj", ".glb", ".png", ".jpg"],
                    file_count="single"
                )
 
+                # Secondary output (OBJ if GLB is primary)
+                secondary_output = gr.File(
+                    label="Alternative Format (OBJ/GLB)",
+                    file_types=[".obj", ".glb"],
+                    file_count="single",
+                    visible=False
+                )
+
                gr.Markdown("""
                ## Tips:
 

@@ -851,6 +961,8 @@ def create_interface():
                - For image to mesh mode: Use clear, front-facing garment images to generate a 3D mesh directly
                - Choose the appropriate base mesh type that matches your target garment
                - Higher epochs = better quality but longer processing time
+                - **GLB files** are automatically generated for better web viewing and virtual try-on compatibility
+                - **OBJ files** are also available for traditional 3D software compatibility
                - Output files can be downloaded by clicking on them
 
                Processing may take several minutes.

@@ -887,33 +999,55 @@ def create_interface():
            status_msg
        )
 
-        # Function to handle processing with better error feedback
+        # Function to handle processing with better error feedback and dual output
        def process_with_feedback(*args):
            try:
                # Check if processing engine is available
                if loop is None:
-                    return None, "❌ ERROR: Processing engine not available. Please check that all dependencies are properly installed."
+                    return None, None, "❌ ERROR: Processing engine not available. Please check that all dependencies are properly installed."
 
                result = process_garment(*args)
                if result is None:
-                    return None, "Processing completed but no output files were generated. Please check the logs for more details."
+                    return None, None, "Processing completed but no output files were generated. Please check the logs for more details."
                elif isinstance(result, str) and result.startswith("Error:"):
-                    # Return None for the file
-                    return None, result
+                    # Return None for the file outputs and the error message for status
+                    return None, None, result
                elif isinstance(result, str) and os.path.exists(result):
-                    # Valid file path
-
+                    # Valid file path - check if we can create a combined output
+                    result_path = Path(result)
+                    if result_path.suffix.lower() == '.glb':
+                        # GLB file - try to find corresponding OBJ
+                        obj_file = result_path.with_suffix('.obj')
+                        if obj_file.exists():
+                            return result, str(obj_file), "🎉 Success! Both GLB and OBJ files generated. GLB file is displayed (better for web viewing), OBJ file is also available."
+                        else:
+                            return result, None, "🎉 Success! GLB file generated and ready for download."
+                    elif result_path.suffix.lower() == '.obj':
+                        # OBJ file - try to find corresponding GLB or create one
+                        glb_file = result_path.with_suffix('.glb')
+                        if glb_file.exists():
+                            return str(glb_file), result, "🎉 Success! Both GLB and OBJ files generated. GLB file is displayed (better for web viewing), OBJ file is also available."
+                        else:
+                            # Try to convert OBJ to GLB
+                            glb_path = convert_obj_to_glb(result)
+                            if glb_path:
+                                return glb_path, result, "🎉 Success! Both GLB and OBJ files generated. GLB file is displayed (better for web viewing), OBJ file is also available."
+                            else:
+                                return result, None, "🎉 Success! OBJ file generated and ready for download."
+                    else:
+                        # Some other file type
+                        return result, None, "🎉 Processing completed successfully! Download your file below."
                elif isinstance(result, str):
                    # Some other string that's not an error and not a file path
-                    return None, f"Unexpected result: {result}"
+                    return None, None, f"Unexpected result: {result}"
                else:
                    # Should be a file path or None
-                    return result, "🎉 Processing completed successfully! Download your 3D garment file below."
+                    return result, None, "🎉 Processing completed successfully! Download your 3D garment file below."
            except Exception as e:
                import traceback
                print(f"Error in interface: {str(e)}")
                print(traceback.format_exc())
-                return None, f"❌ Error: {str(e)}"
+                return None, None, f"❌ Error: {str(e)}"
 
        # Toggle visibility based on input mode with better feedback
        input_type.change(

@@ -923,7 +1057,7 @@ def create_interface():
            show_progress=True
        )
 
-        # Connect the button to the processing function with error handling
+        # Connect the button to the processing function with error handling and dual output
        generate_btn.click(
            fn=process_with_feedback,
            inputs=[

@@ -938,7 +1072,32 @@ def create_interface():
                clip_weight,
                delta_clip_weight
            ],
-            outputs=[output, status_output]
+            outputs=[output, secondary_output, status_output]
+        )
+
+        # Update secondary output visibility when primary output changes
+        def update_secondary_visibility(primary_file):
+            """Update secondary output visibility based on whether both formats are available"""
+            if primary_file is not None and primary_file != "":
+                # Check if there's a corresponding file in the other format
+                primary_path = Path(primary_file)
+                if primary_path.suffix.lower() == '.glb':
+                    # Check if corresponding OBJ exists
+                    obj_file = primary_path.with_suffix('.obj')
+                    if obj_file.exists():
+                        return gr.update(visible=True)
+                elif primary_path.suffix.lower() == '.obj':
+                    # Check if corresponding GLB exists
+                    glb_file = primary_path.with_suffix('.glb')
+                    if glb_file.exists():
+                        return gr.update(visible=True)
+            return gr.update(visible=False)
+
+        # Connect the secondary output visibility to the primary output
+        output.change(
+            fn=update_secondary_visibility,
+            inputs=[output],
+            outputs=[secondary_output]
        )
 
    return interface
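For orientation, a minimal sketch of how the new app.py helpers fit together outside the Gradio UI. The run directory path and the import line are hypothetical; trimesh must be installed for the GLB conversion to succeed:

    from pathlib import Path
    from app import ensure_mesh_formats, create_combined_mesh_output  # illustrative import path

    # Hypothetical directory written by process_garment()
    run_dir = Path("/tmp/garment_run/mesh_final")

    # Scan for *.obj meshes and convert any that lack a *.glb sibling
    obj_files, glb_files = ensure_mesh_formats(run_dir)

    # Pick what the UI should show: GLB first (web viewing), OBJ as the alternative
    primary, secondary, status = create_combined_mesh_output(run_dir)
    print(status)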
loop.py CHANGED

@@ -1,4 +1,3 @@
-import clip
 import kornia
 import os
 import sys

@@ -108,15 +107,26 @@ def loop(cfg):
     fashion_image = cfg.get('fashion_image', False)
     fashion_text = cfg.get('fashion_text', True) # Default to fashion text mode
     use_target_mesh = cfg.get('use_target_mesh', True)
-    CLIP_embeddings = False
+    CLIP_embeddings = False # Always use FashionCLIP to avoid CLIP issues
 
-
-
-
-    else:
+    # Always use FashionCLIP to avoid CLIP loading issues
+    print('Loading FashionCLIP model...')
+    try:
        fclip = FashionCLIP('fashion-clip')
-
-
+        print('FashionCLIP loaded successfully')
+    except Exception as e:
+        print(f'Error loading FashionCLIP: {e}')
+        raise RuntimeError(f"Failed to load FashionCLIP: {e}")
+
+    # Load CLIPVisualEncoder with error handling
+    print('Loading CLIPVisualEncoder...')
+    try:
+        fe = CLIPVisualEncoder(cfg.consistency_clip_model, cfg.consistency_vit_stride, device)
+        print('CLIPVisualEncoder loaded successfully')
+    except Exception as e:
+        print(f'Error loading CLIPVisualEncoder: {e}')
+        print('Continuing without CLIPVisualEncoder...')
+        fe = None
 
    # Use FashionCLIP for all modes to avoid CLIP loading issues
    if fashion_image:

@@ -434,7 +444,7 @@ def loop(cfg):
        r_loss = (((gt_jacobians) - torch.eye(3, 3, device=device)) ** 2).mean()
        logger.add_scalar('jacobian_regularization', r_loss, global_step=it)
 
-        if cfg.consistency_loss_weight != 0:
+        if cfg.consistency_loss_weight != 0 and fe is not None:
            consistency_loss = compute_mv_cl(final_mesh, fe, normalized_clip_render, params_camera, train_rast_map, cfg, device)
        else:
            consistency_loss = r_loss
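The loop.py change follows a required-vs-optional loading pattern: a FashionCLIP failure aborts the run, while a missing CLIPVisualEncoder only sets fe to None and disables the consistency term downstream. A stripped-down sketch of that shape; the load_encoders wrapper is illustrative and the import paths are assumed from the repo layout:

    from packages.fashion_clip.fashion_clip.fashion_clip import FashionCLIP
    from utilities.clip_spatial import CLIPVisualEncoder  # assumed import path

    def load_encoders(cfg, device):
        # Required encoder: abort the run if it cannot be constructed
        try:
            fclip = FashionCLIP('fashion-clip')
        except Exception as e:
            raise RuntimeError(f"Failed to load FashionCLIP: {e}")

        # Optional encoder: fall back to None and guard every use site
        try:
            fe = CLIPVisualEncoder(cfg.consistency_clip_model, cfg.consistency_vit_stride, device)
        except Exception as e:
            print(f"CLIPVisualEncoder unavailable ({e}); consistency loss will be skipped")
            fe = None

        return fclip, fe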
utilities/clip_spatial.py CHANGED

@@ -3,11 +3,18 @@ import math
 import types
 import typing
 
-import clip
 import torch
 import torch.nn as nn
 from torchvision import models, transforms
 
+# Try to import CLIP, but handle import errors gracefully
+try:
+    import clip
+    CLIP_AVAILABLE = True
+except ImportError:
+    print("Warning: CLIP not available, using FashionCLIP fallback")
+    CLIP_AVAILABLE = False
+
 # code lifted from CLIPasso
 
 # For ViT

@@ -25,10 +32,50 @@ class CLIPVisualEncoder(nn.Module):
 
 
     def load_model(self, model_name, device):
-
-
-
-
+        if CLIP_AVAILABLE:
+            try:
+                model, preprocess = clip.load(model_name, device=device)
+                self.model = model.visual
+                self.mean = torch.tensor(preprocess.transforms[-1].mean, device=device)
+                self.std = torch.tensor(preprocess.transforms[-1].std, device=device)
+            except Exception as e:
+                print(f"Error loading CLIP model: {e}")
+                print("Falling back to FashionCLIP...")
+                self._load_fashion_clip_fallback(device)
+        else:
+            print("CLIP not available, using FashionCLIP fallback...")
+            self._load_fashion_clip_fallback(device)
+
+    def _load_fashion_clip_fallback(self, device):
+        """Fallback method using FashionCLIP when regular CLIP is not available"""
+        try:
+            from packages.fashion_clip.fashion_clip.fashion_clip import FashionCLIP
+            fclip = FashionCLIP('fashion-clip')
+            self.model = fclip.model.vision_model
+            # Use standard CLIP mean and std values
+            self.mean = torch.tensor([0.48154660, 0.45782750, 0.40821073], device=device)
+            self.std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=device)
+            print("Successfully loaded FashionCLIP fallback")
+        except Exception as e:
+            print(f"Error loading FashionCLIP fallback: {e}")
+            # Create a dummy model if all else fails
+            self._create_dummy_model(device)
+
+    def _create_dummy_model(self, device):
+        """Create a dummy model when all CLIP options fail"""
+        print("Creating dummy CLIP model - functionality will be limited")
+        # Create a simple dummy model structure
+        class DummyModel:
+            def __init__(self):
+                self.conv1 = nn.Conv2d(3, 768, kernel_size=16, stride=16)
+                self.class_embedding = nn.Parameter(torch.randn(768))
+                self.positional_embedding = nn.Parameter(torch.randn(197, 768))
+                self.ln_pre = nn.LayerNorm(768)
+                self.transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(768, 12), 12)
+
+        self.model = DummyModel()
+        self.mean = torch.tensor([0.48154660, 0.45782750, 0.40821073], device=device)
+        self.std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=device)
 
     @staticmethod
     def _fix_pos_enc(patch_size: int, stride_hw: typing.Tuple[int, int]):
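The guarded import at the top of clip_spatial.py is a general pattern for optional dependencies: probe once at import time, record the result in a module-level flag, and branch on the flag instead of retrying the import. A minimal self-contained sketch of the pattern; the helper name is illustrative:

    # Probe the optional dependency once, at import time
    try:
        import clip  # OpenAI CLIP package
        CLIP_AVAILABLE = True
    except ImportError:
        CLIP_AVAILABLE = False

    def build_visual_backbone(model_name="ViT-B/32", device="cpu"):
        """Return a CLIP visual backbone, or None when the package is missing."""
        if not CLIP_AVAILABLE:
            return None
        model, preprocess = clip.load(model_name, device=device)
        return model.visual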
utils.py CHANGED

@@ -77,6 +77,11 @@ def get_og_mesh(mesh_path, output_path, triangulate_flag, bsdf_flag, mesh_name='
 
 def compute_mv_cl(final_mesh, fe, normalized_clip_render, params_camera, train_rast_map, cfg, device):
     # Consistency loss
+    # Check if fe is available
+    if fe is None:
+        print("Warning: CLIPVisualEncoder not available, skipping consistency loss")
+        return torch.tensor(0.0, device=device)
+
     # Get mapping from vertex to pixels
     curr_vp_map = get_vp_map(final_mesh.v_pos, params_camera['mvp'], 224)
     for idx, rast_faces in enumerate(train_rast_map[:, :, :, 3].view(cfg.batch_size, -1)):
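With the early return in place, a caller that reaches compute_mv_cl with fe set to None receives a zero tensor instead of an exception, so any weighted sum it feeds remains well defined. A tiny self-contained illustration of that contract; the 0.5 weight and the 1.25 stand-in value are arbitrary:

    import torch

    consistency_loss = torch.tensor(0.0)      # what compute_mv_cl returns when fe is None
    base_loss = torch.tensor(1.25)            # stand-in for the remaining loss terms
    total = base_loss + 0.5 * consistency_loss  # the weighted term contributes nothing
    assert total.item() == 1.25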