Ali Mohsin committed on
Commit
24ea486
·
1 Parent(s): 6086b2f

Next level fix

app.py CHANGED
@@ -12,10 +12,191 @@ from pydantic import BaseModel
12
  from PIL import Image
13
  from starlette.staticfiles import StaticFiles
14
  import threading
 
15
 
16
  from inference import InferenceService
17
  from utils.data_fetch import ensure_dataset_ready
18
19
 
20
  AI_API_KEY = os.getenv("AI_API_KEY")
21
 
@@ -254,7 +435,9 @@ def start_training_advanced(
254
  if not DATASET_ROOT:
255
  return "❌ Dataset not ready. Please wait for bootstrap to complete."
256
 
 
257
  def _runner():
 
258
  try:
259
  import subprocess
260
  import json
@@ -327,10 +510,10 @@ def start_training_advanced(
327
  json.dump(vit_config, f, indent=2)
328
 
329
  # Train ResNet with custom parameters
330
- train_log.value = f"πŸš€ Starting ResNet training with custom parameters...\n"
331
- train_log.value += f"Backbone: {resnet_backbone}, Embedding Dim: {resnet_embedding_dim}\n"
332
- train_log.value += f"Epochs: {resnet_epochs}, Batch Size: {resnet_batch_size}, LR: {resnet_lr}\n"
333
- train_log.value += f"Optimizer: {resnet_optimizer}, Triplet Margin: {resnet_triplet_margin}\n"
334
 
335
  resnet_cmd = [
336
  "python", "train_resnet.py",
@@ -350,16 +533,16 @@ def start_training_advanced(
350
  result = subprocess.run(resnet_cmd, capture_output=True, text=True, check=False)
351
 
352
  if result.returncode == 0:
353
- train_log.value += "βœ… ResNet training completed successfully!\n\n"
354
  else:
355
- train_log.value += f"❌ ResNet training failed: {result.stderr}\n\n"
356
  return
357
 
358
  # Train ViT with custom parameters
359
- train_log.value += f"πŸš€ Starting ViT training with custom parameters...\n"
360
- train_log.value += f"Layers: {vit_num_layers}, Heads: {vit_num_heads}, FF Multiplier: {vit_ff_multiplier}\n"
361
- train_log.value += f"Epochs: {vit_epochs}, Batch Size: {vit_batch_size}, LR: {vit_lr}\n"
362
- train_log.value += f"Optimizer: {vit_optimizer}, Triplet Margin: {vit_triplet_margin}\n"
363
 
364
  vit_cmd = [
365
  "python", "train_vit_triplet.py",
@@ -376,47 +559,87 @@ def start_training_advanced(
376
  result = subprocess.run(vit_cmd, capture_output=True, text=True, check=False)
377
 
378
  if result.returncode == 0:
379
- train_log.value += "βœ… ViT training completed successfully!\n\n"
380
- train_log.value += "πŸŽ‰ All training completed! Models saved to models/exports/\n"
381
- train_log.value += "πŸ”„ Reloading models for inference...\n"
382
  service.reload_models()
383
- train_log.value += "βœ… Models reloaded and ready for inference!\n"
384
  else:
385
- train_log.value += f"❌ ViT training failed: {result.stderr}\n"
386
 
387
  except Exception as e:
388
- train_log.value += f"\n❌ Training error: {str(e)}"
389
 
390
  threading.Thread(target=_runner, daemon=True).start()
391
- return "πŸš€ Advanced training started with custom parameters! Check the log below for progress."
392
 
393
 
394
  def start_training_simple(res_epochs: int, vit_epochs: int):
395
  """Start simple training with basic parameters."""
 
396
  def _runner():
 
397
  try:
398
  import subprocess
399
  if not DATASET_ROOT:
400
- train_log.value = "Dataset not ready."
401
  return
402
  export_dir = os.getenv("EXPORT_DIR", "models/exports")
403
  os.makedirs(export_dir, exist_ok=True)
404
- train_log.value = "Training ResNet…\n"
405
  subprocess.run([
406
  "python", "train_resnet.py", "--data_root", DATASET_ROOT, "--epochs", str(res_epochs),
407
  "--out", os.path.join(export_dir, "resnet_item_embedder.pth")
408
  ], check=False)
409
- train_log.value += "\nTraining ViT (triplet)…\n"
410
  subprocess.run([
411
  "python", "train_vit_triplet.py", "--data_root", DATASET_ROOT, "--epochs", str(vit_epochs),
412
  "--export", os.path.join(export_dir, "vit_outfit_model.pth")
413
  ], check=False)
414
  service.reload_models()
415
- train_log.value += "\nDone. Artifacts in models/exports."
416
  except Exception as e:
417
- train_log.value += f"\nError: {e}"
418
  threading.Thread(target=_runner, daemon=True).start()
419
- return "Started"
420
 
421
 
422
  with gr.Blocks(fill_height=True, title="Dressify - Advanced Outfit Recommendation") as demo:
@@ -563,6 +786,56 @@ with gr.Blocks(fill_height=True, title="Dressify - Advanced Outfit Recommendatio
563
  outputs=train_log
564
  )
565
 
566
  with gr.Tab("πŸ”§ Simple Training"):
567
  gr.Markdown("### πŸš€ Quick Training with Default Parameters\nFast training with proven configurations for immediate results.")
568
  epochs_res = gr.Slider(1, 50, value=10, step=1, label="ResNet epochs")
@@ -577,23 +850,7 @@ with gr.Blocks(fill_height=True, title="Dressify - Advanced Outfit Recommendatio
577
  btn = gr.Button("Compute Embeddings")
578
  btn.click(fn=gradio_embed, inputs=inp, outputs=out)
579
 
580
- with gr.Tab("πŸ“₯ Downloads"):
581
- gr.Markdown("### πŸ“¦ Download Trained Models and Artifacts\nAccess all exported models, checkpoints, and training metrics.")
582
- file_list = gr.JSON(label="Available Artifacts")
583
- def list_artifacts_for_ui():
584
- export_dir = os.getenv("EXPORT_DIR", "models/exports")
585
- files = []
586
- if os.path.isdir(export_dir):
587
- for fn in os.listdir(export_dir):
588
- if fn.endswith((".pth", ".pt", ".onnx", ".ts", ".json")):
589
- files.append({
590
- "name": fn,
591
- "path": f"{export_dir}/{fn}",
592
- "url": f"/files/{fn}",
593
- })
594
- return {"artifacts": files}
595
- refresh = gr.Button("πŸ”„ Refresh Artifacts")
596
- refresh.click(fn=lambda: list_artifacts_for_ui(), inputs=[], outputs=file_list)
597
 
598
  with gr.Tab("πŸ“ˆ Status"):
599
  gr.Markdown("### 🚦 System Status and Monitoring\nReal-time status of dataset preparation, training, and system health.")
 
12
  from PIL import Image
13
  from starlette.staticfiles import StaticFiles
14
  import threading
15
+ import json
16
 
17
  from inference import InferenceService
18
  from utils.data_fetch import ensure_dataset_ready
19
 
20
+ # Global state
21
+ BOOT_STATUS = "starting"
22
+ DATASET_ROOT: Optional[str] = None
23
+
24
+ def get_artifact_overview():
25
+ """Get comprehensive artifact overview."""
26
+ try:
27
+ from utils.artifact_manager import create_artifact_manager
28
+ manager = create_artifact_manager()
29
+ return manager.get_artifact_summary()
30
+ except Exception as e:
31
+ return {"error": str(e)}
32
+
33
+ def export_artifact_summary():
34
+ """Export artifact summary as JSON file."""
35
+ try:
36
+ from utils.artifact_manager import create_artifact_manager
37
+ manager = create_artifact_manager()
38
+ summary = manager.get_artifact_summary()
39
+
40
+ # Save to exports directory
41
+ export_dir = os.getenv("EXPORT_DIR", "models/exports")
42
+ os.makedirs(export_dir, exist_ok=True)
43
+
44
+ summary_path = os.path.join(export_dir, "system_summary.json")
45
+ with open(summary_path, 'w') as f:
46
+ json.dump(summary, f, indent=2)
47
+
48
+ return summary_path
49
+ except Exception as e:
50
+ return None
51
+
52
+ def create_download_package(package_type: str):
53
+ """Create a downloadable package."""
54
+ try:
55
+ from utils.artifact_manager import create_artifact_manager
56
+ manager = create_artifact_manager()
57
+
58
+ # Extract package type from the dropdown choice
59
+ if "complete" in package_type:
60
+ pkg_type = "complete"
61
+ elif "splits_only" in package_type:
62
+ pkg_type = "splits_only"
63
+ elif "models_only" in package_type:
64
+ pkg_type = "models_only"
65
+ else:
66
+ return f"❌ Invalid package type: {package_type}", get_available_packages()
67
+
68
+ package_path = manager.create_download_package(pkg_type)
69
+ package_name = os.path.basename(package_path)
70
+
71
+ return f"βœ… Package created: {package_name}", get_available_packages()
72
+
73
+ except Exception as e:
74
+ return f"❌ Failed to create package: {e}", get_available_packages()
75
+
76
+ def get_available_packages():
77
+ """Get list of available packages."""
78
+ try:
79
+ export_dir = os.getenv("EXPORT_DIR", "models/exports")
80
+ packages = []
81
+
82
+ if os.path.exists(export_dir):
83
+ for file in os.listdir(export_dir):
84
+ if file.endswith((".tar.gz", ".zip")):
85
+ file_path = os.path.join(export_dir, file)
86
+ packages.append({
87
+ "name": file,
88
+ "size_mb": round(os.path.getsize(file_path) / (1024 * 1024), 2),
89
+ "path": file_path,
90
+ "url": f"/files/{file}"
91
+ })
92
+
93
+ return {"packages": packages}
94
+ except Exception as e:
95
+ return {"error": str(e)}
96
+
97
+ def get_individual_files():
98
+ """Get list of individual downloadable files."""
99
+ try:
100
+ from utils.artifact_manager import create_artifact_manager
101
+ manager = create_artifact_manager()
102
+ files = manager.get_downloadable_files()
103
+
104
+ # Group by category
105
+ categorized = {}
106
+ for file in files:
107
+ category = file["category"]
108
+ if category not in categorized:
109
+ categorized[category] = []
110
+ categorized[category].append(file)
111
+
112
+ return categorized
113
+ except Exception as e:
114
+ return {"error": str(e)}
115
+
116
+ def download_all_files():
117
+ """Download all files as a ZIP archive."""
118
+ try:
119
+ from utils.artifact_manager import create_artifact_manager
120
+ manager = create_artifact_manager()
121
+ files = manager.get_downloadable_files()
122
+
123
+ # Create ZIP with all files
124
+ export_dir = os.getenv("EXPORT_DIR", "models/exports")
125
+ os.makedirs(export_dir, exist_ok=True)
126
+
127
+ zip_path = os.path.join(export_dir, "all_artifacts.zip")
128
+ import zipfile
129
+
130
+ with zipfile.ZipFile(zip_path, 'w') as zipf:
131
+ for file in files:
132
+ if os.path.exists(file["path"]):
133
+ zipf.write(file["path"], file["name"])
134
+
135
+ return zip_path
136
+ except Exception as e:
137
+ return None
138
+
139
+ def get_training_status():
140
+ """Get current training status from the monitor."""
141
+ try:
142
+ from training_monitor import create_monitor
143
+ monitor = create_monitor()
144
+ status = monitor.get_status()
145
+ return status if status else {"status": "no-training"}
146
+ except Exception as e:
147
+ return {"status": "error", "error": str(e)}
148
+
149
+ def push_splits_to_hf(token, username):
150
+ """Push splits to HF Hub."""
151
+ if not token or not username:
152
+ return "❌ Please provide HF token and username"
153
+
154
+ try:
155
+ from utils.hf_hub_integration import create_hf_integration
156
+ hf = create_hf_integration(token)
157
+ result = hf.upload_splits_to_hf()
158
+
159
+ if result.get("success"):
160
+ return f"βœ… Successfully uploaded splits to {username}/Dressify-Helper"
161
+ else:
162
+ return f"❌ Failed to upload splits: {result.get('error', 'Unknown error')}"
163
+ except Exception as e:
164
+ return f"❌ Upload failed: {e}"
165
+
166
+ def push_models_to_hf(token, username):
167
+ """Push models to HF Hub."""
168
+ if not token or not username:
169
+ return "❌ Please provide HF token and username"
170
+
171
+ try:
172
+ from utils.hf_hub_integration import create_hf_integration
173
+ hf = create_hf_integration(token)
174
+ result = hf.upload_models_to_hf()
175
+
176
+ if result.get("success"):
177
+ return f"βœ… Successfully uploaded models to {username}/dressify-models"
178
+ else:
179
+ return f"❌ Failed to upload models: {result.get('error', 'Unknown error')}"
180
+ except Exception as e:
181
+ return f"❌ Upload failed: {e}"
182
+
183
+ def push_everything_to_hf(token, username):
184
+ """Push everything to HF Hub."""
185
+ if not token or not username:
186
+ return "❌ Please provide HF token and username"
187
+
188
+ try:
189
+ from utils.hf_hub_integration import create_hf_integration
190
+ hf = create_hf_integration(token)
191
+ result = hf.upload_everything_to_hf()
192
+
193
+ if result.get("success"):
194
+ return f"βœ… Successfully uploaded everything to HF Hub"
195
+ else:
196
+ return f"❌ Failed to upload everything: {result.get('error', 'Unknown error')}"
197
+ except Exception as e:
198
+ return f"❌ Upload failed: {e}"
199
+
200
 
201
  AI_API_KEY = os.getenv("AI_API_KEY")
202
 
 
435
  if not DATASET_ROOT:
436
  return "❌ Dataset not ready. Please wait for bootstrap to complete."
437
 
438
+ log_message = "πŸš€ Advanced training started with custom parameters! Check the log below for progress."
439
  def _runner():
440
+ nonlocal log_message
441
  try:
442
  import subprocess
443
  import json
 
510
  json.dump(vit_config, f, indent=2)
511
 
512
  # Train ResNet with custom parameters
513
+ log_message = f"πŸš€ Starting ResNet training with custom parameters...\n"
514
+ log_message += f"Backbone: {resnet_backbone}, Embedding Dim: {resnet_embedding_dim}\n"
515
+ log_message += f"Epochs: {resnet_epochs}, Batch Size: {resnet_batch_size}, LR: {resnet_lr}\n"
516
+ log_message += f"Optimizer: {resnet_optimizer}, Triplet Margin: {resnet_triplet_margin}\n"
517
 
518
  resnet_cmd = [
519
  "python", "train_resnet.py",
 
533
  result = subprocess.run(resnet_cmd, capture_output=True, text=True, check=False)
534
 
535
  if result.returncode == 0:
536
+ log_message += "βœ… ResNet training completed successfully!\n\n"
537
  else:
538
+ log_message += f"❌ ResNet training failed: {result.stderr}\n\n"
539
  return
540
 
541
  # Train ViT with custom parameters
542
+ log_message += f"πŸš€ Starting ViT training with custom parameters...\n"
543
+ log_message += f"Layers: {vit_num_layers}, Heads: {vit_num_heads}, FF Multiplier: {vit_ff_multiplier}\n"
544
+ log_message += f"Epochs: {vit_epochs}, Batch Size: {vit_batch_size}, LR: {vit_lr}\n"
545
+ log_message += f"Optimizer: {vit_optimizer}, Triplet Margin: {vit_triplet_margin}\n"
546
 
547
  vit_cmd = [
548
  "python", "train_vit_triplet.py",
 
559
  result = subprocess.run(vit_cmd, capture_output=True, text=True, check=False)
560
 
561
  if result.returncode == 0:
562
+ log_message += "βœ… ViT training completed successfully!\n\n"
563
+ log_message += "πŸŽ‰ All training completed! Models saved to models/exports/\n"
564
+ log_message += "πŸ”„ Reloading models for inference...\n"
565
  service.reload_models()
566
+ log_message += "βœ… Models reloaded and ready for inference!\n"
567
+
568
+ # Auto-upload to HF Hub if token is available
569
+ hf_token = os.getenv("HF_TOKEN")
570
+ if hf_token:
571
+ log_message += "πŸ“€ Auto-uploading artifacts to Hugging Face Hub...\n"
572
+ try:
573
+ from utils.hf_hub_integration import create_hf_integration
574
+ hf = create_hf_integration(hf_token)
575
+ result = hf.upload_everything_to_hf()
576
+ if result.get("success"):
577
+ log_message += "βœ… Successfully uploaded to HF Hub!\n"
578
+ log_message += "πŸ”— Models: https://huggingface.co/Stylique/dressify-models\n"
579
+ log_message += "πŸ”— Data: https://huggingface.co/datasets/Stylique/Dressify-Helper\n"
580
+ else:
581
+ log_message += f"⚠️ Upload failed: {result.get('error', 'Unknown error')}\n"
582
+ except Exception as e:
583
+ log_message += f"⚠️ Auto-upload failed: {str(e)}\n"
584
+ else:
585
+ log_message += "πŸ’‘ Set HF_TOKEN env var for automatic uploads\n"
586
  else:
587
+ log_message += f"❌ ViT training failed: {result.stderr}\n"
588
 
589
  except Exception as e:
590
+ log_message += f"\n❌ Training error: {str(e)}"
591
 
592
  threading.Thread(target=_runner, daemon=True).start()
593
+ return log_message
594
 
595
 
596
  def start_training_simple(res_epochs: int, vit_epochs: int):
597
  """Start simple training with basic parameters."""
598
+ log_message = "Starting training..."
599
  def _runner():
600
+ nonlocal log_message
601
  try:
602
  import subprocess
603
  if not DATASET_ROOT:
604
+ log_message = "Dataset not ready."
605
  return
606
  export_dir = os.getenv("EXPORT_DIR", "models/exports")
607
  os.makedirs(export_dir, exist_ok=True)
608
+ log_message = "Training ResNet…\n"
609
  subprocess.run([
610
  "python", "train_resnet.py", "--data_root", DATASET_ROOT, "--epochs", str(res_epochs),
611
  "--out", os.path.join(export_dir, "resnet_item_embedder.pth")
612
  ], check=False)
613
+ log_message += "\nTraining ViT (triplet)…\n"
614
  subprocess.run([
615
  "python", "train_vit_triplet.py", "--data_root", DATASET_ROOT, "--epochs", str(vit_epochs),
616
  "--export", os.path.join(export_dir, "vit_outfit_model.pth")
617
  ], check=False)
618
  service.reload_models()
619
+ log_message += "\nDone. Artifacts in models/exports."
620
+
621
+ # Auto-upload to HF Hub if token is available
622
+ hf_token = os.getenv("HF_TOKEN")
623
+ if hf_token:
624
+ log_message += "\nπŸ“€ Auto-uploading artifacts to Hugging Face Hub...\n"
625
+ try:
626
+ from utils.hf_hub_integration import create_hf_integration
627
+ hf = create_hf_integration(hf_token)
628
+ result = hf.upload_everything_to_hf()
629
+ if result.get("success"):
630
+ log_message += "βœ… Successfully uploaded to HF Hub!\n"
631
+ log_message += "πŸ”— Models: https://huggingface.co/Stylique/dressify-models\n"
632
+ log_message += "πŸ”— Data: https://huggingface.co/datasets/Stylique/Dressify-Helper\n"
633
+ else:
634
+ log_message += f"⚠️ Upload failed: {result.get('error', 'Unknown error')}\n"
635
+ except Exception as e:
636
+ log_message += f"⚠️ Auto-upload failed: {str(e)}\n"
637
+ else:
638
+ log_message += "\nπŸ’‘ Set HF_TOKEN env var for automatic uploads\n"
639
  except Exception as e:
640
+ log_message += f"\nError: {e}"
641
  threading.Thread(target=_runner, daemon=True).start()
642
+ return log_message
643
 
644
 
645
  with gr.Blocks(fill_height=True, title="Dressify - Advanced Outfit Recommendation") as demo:
 
786
  outputs=train_log
787
  )
788
 
789
+ with gr.Tab("πŸ“¦ Artifact Management"):
790
+ gr.Markdown("### 🎯 Comprehensive Artifact Management\nManage, package, and upload all system artifacts to Hugging Face Hub.")
791
+
792
+ with gr.Row():
793
+ with gr.Column(scale=1):
794
+ gr.Markdown("#### πŸ“Š Artifact Overview")
795
+ artifact_overview = gr.JSON(label="System Artifacts", value=get_artifact_overview)
796
+ refresh_overview = gr.Button("πŸ”„ Refresh Overview")
797
+ refresh_overview.click(fn=get_artifact_overview, inputs=[], outputs=artifact_overview)
798
+
799
+ gr.Markdown("#### πŸ“¦ Create Packages")
800
+ package_type = gr.Dropdown(
801
+ choices=["complete", "splits_only", "models_only"],
802
+ value="complete",
803
+ label="Package Type"
804
+ )
805
+ create_package_btn = gr.Button("πŸ“¦ Create Package")
806
+ package_result = gr.Textbox(label="Package Result", interactive=False)
807
+ available_packages = gr.JSON(label="Available Packages", value=get_available_packages)
808
+
809
+ create_package_btn.click(
810
+ fn=create_download_package,
811
+ inputs=[package_type],
812
+ outputs=[package_result, available_packages]
813
+ )
814
+
815
+ with gr.Column(scale=1):
816
+ gr.Markdown("#### πŸš€ Hugging Face Hub Integration")
817
+ gr.Markdown("πŸ’‘ **Pro Tip**: Set `HF_TOKEN` environment variable for automatic uploads after training!")
818
+ hf_token = gr.Textbox(label="HF Token", type="password", placeholder="hf_...")
819
+ hf_username = gr.Textbox(label="Username", placeholder="your-username")
820
+
821
+ with gr.Row():
822
+ push_splits_btn = gr.Button("πŸ“€ Push Splits", variant="secondary")
823
+ push_models_btn = gr.Button("πŸ“€ Push Models", variant="secondary")
824
+
825
+ push_everything_btn = gr.Button("πŸ“€ Push Everything", variant="primary")
826
+ hf_result = gr.Textbox(label="Upload Result", interactive=False, lines=3)
827
+
828
+ push_splits_btn.click(fn=push_splits_to_hf, inputs=[hf_token, hf_username], outputs=hf_result)
829
+ push_models_btn.click(fn=push_models_to_hf, inputs=[hf_token, hf_username], outputs=hf_result)
830
+ push_everything_btn.click(fn=push_everything_to_hf, inputs=[hf_token, hf_username], outputs=hf_result)
831
+
832
+ gr.Markdown("#### πŸ“₯ Download Management")
833
+ individual_files = gr.JSON(label="Individual Files", value=get_individual_files)
834
+ download_all_btn = gr.Button("πŸ“₯ Download All as ZIP")
835
+ download_result = gr.Textbox(label="Download Result", interactive=False)
836
+
837
+ download_all_btn.click(fn=download_all_files, inputs=[], outputs=download_result)
838
+
839
  with gr.Tab("πŸ”§ Simple Training"):
840
  gr.Markdown("### πŸš€ Quick Training with Default Parameters\nFast training with proven configurations for immediate results.")
841
  epochs_res = gr.Slider(1, 50, value=10, step=1, label="ResNet epochs")
 
850
  btn = gr.Button("Compute Embeddings")
851
  btn.click(fn=gradio_embed, inputs=inp, outputs=out)
852
 
853
+
854
 
855
  with gr.Tab("πŸ“ˆ Status"):
856
  gr.Markdown("### 🚦 System Status and Monitoring\nReal-time status of dataset preparation, training, and system health.")
artifact_management_ui.py ADDED
@@ -0,0 +1,427 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Comprehensive Gradio interface for Dressify artifact management.
4
+ Provides download, upload, and organization features for all system artifacts.
5
+ """
6
+
7
+ import os
8
+ import json
9
+ import gradio as gr
10
+ from typing import Dict, List, Any
11
+ from utils.artifact_manager import create_artifact_manager
12
+
13
+ def create_artifact_management_interface():
14
+ """Create the main artifact management interface."""
15
+
16
+ with gr.Blocks(title="Dressify Artifact Management", theme=gr.themes.Soft()) as interface:
17
+ gr.Markdown("# 🎯 Dressify Artifact Management System")
18
+ gr.Markdown("## πŸ“¦ Download, Upload, and Organize All System Artifacts")
19
+
20
+ with gr.Tabs():
21
+
22
+ # Overview Tab
23
+ with gr.Tab("πŸ“Š System Overview"):
24
+ gr.Markdown("### πŸš€ Complete System Status and Artifact Summary")
25
+
26
+ with gr.Row():
27
+ refresh_overview = gr.Button("πŸ”„ Refresh Overview", variant="primary")
28
+ export_summary = gr.Button("πŸ“₯ Export Summary JSON", variant="secondary")
29
+
30
+ overview_display = gr.JSON(label="System Overview", value=get_system_overview())
31
+
32
+ refresh_overview.click(
33
+ fn=get_system_overview,
34
+ outputs=overview_display
35
+ )
36
+
37
+ export_summary.click(
38
+ fn=export_system_summary,
39
+ outputs=gr.File(label="Download Summary")
40
+ )
41
+
42
+ # Download Management Tab
43
+ with gr.Tab("πŸ“₯ Download Management"):
44
+ gr.Markdown("### 🎁 Create Downloadable Packages")
45
+
46
+ with gr.Row():
47
+ with gr.Column(scale=1):
48
+ gr.Markdown("#### πŸ“¦ Package Types")
49
+ package_type = gr.Dropdown(
50
+ choices=[
51
+ "complete - Everything (splits + models + metadata + configs)",
52
+ "splits_only - Dataset splits only (lightweight)",
53
+ "models_only - Trained models only"
54
+ ],
55
+ value="splits_only - Dataset splits only (lightweight)",
56
+ label="Package Type"
57
+ )
58
+
59
+ create_package_btn = gr.Button("πŸš€ Create Package", variant="primary")
60
+ package_status = gr.Textbox(label="Package Status", interactive=False)
61
+
62
+ with gr.Column(scale=1):
63
+ gr.Markdown("#### πŸ“‹ Available Packages")
64
+ packages_list = gr.JSON(label="Created Packages", value=get_available_packages())
65
+ refresh_packages = gr.Button("πŸ”„ Refresh Packages")
66
+
67
+ create_package_btn.click(
68
+ fn=create_download_package,
69
+ inputs=[package_type],
70
+ outputs=[package_status, packages_list]
71
+ )
72
+
73
+ refresh_packages.click(
74
+ fn=get_available_packages,
75
+ outputs=packages_list
76
+ )
77
+
78
+ # Individual Files Tab
79
+ with gr.Tab("πŸ“ Individual Files"):
80
+ gr.Markdown("### πŸ” Browse and Download Individual Artifacts")
81
+
82
+ with gr.Row():
83
+ refresh_files = gr.Button("πŸ”„ Refresh Files", variant="primary")
84
+ download_all_btn = gr.Button("πŸ“₯ Download All as ZIP", variant="secondary")
85
+
86
+ files_display = gr.JSON(label="Available Files", value=get_individual_files())
87
+
88
+ refresh_files.click(
89
+ fn=get_individual_files,
90
+ outputs=files_display
91
+ )
92
+
93
+ download_all_btn.click(
94
+ fn=download_all_files,
95
+ outputs=gr.File(label="Download All Files")
96
+ )
97
+
98
+ # Upload & Restore Tab
99
+ with gr.Tab("πŸ“€ Upload & Restore"):
100
+ gr.Markdown("### πŸ”„ Upload Pre-processed Artifacts")
101
+ gr.Markdown("Upload previously downloaded packages to avoid reprocessing.")
102
+
103
+ with gr.Row():
104
+ with gr.Column(scale=1):
105
+ gr.Markdown("#### πŸ“€ Upload Package")
106
+ upload_package = gr.File(
107
+ label="Upload Artifact Package (.tar.gz)",
108
+ file_types=[".tar.gz", ".zip"]
109
+ )
110
+ upload_btn = gr.Button("πŸ“€ Upload & Extract", variant="primary")
111
+ upload_status = gr.Textbox(label="Upload Status", interactive=False)
112
+
113
+ with gr.Column(scale=1):
114
+ gr.Markdown("#### πŸ“‹ Restore Options")
115
+ restore_splits = gr.Button("πŸ”„ Restore Splits Only", variant="secondary")
116
+ restore_models = gr.Button("πŸ”„ Restore Models Only", variant="secondary")
117
+ restore_all = gr.Button("πŸ”„ Restore Everything", variant="secondary")
118
+
119
+ upload_btn.click(
120
+ fn=upload_and_extract_package,
121
+ inputs=[upload_package],
122
+ outputs=upload_status
123
+ )
124
+
125
+ restore_splits.click(
126
+ fn=restore_splits_only,
127
+ outputs=gr.Textbox(label="Restore Status")
128
+ )
129
+
130
+ restore_models.click(
131
+ fn=restore_models_only,
132
+ outputs=gr.Textbox(label="Restore Status")
133
+ )
134
+
135
+ restore_all.click(
136
+ fn=restore_everything,
137
+ outputs=gr.Textbox(label="Restore Status")
138
+ )
139
+
140
+ # Hugging Face Integration Tab
141
+ with gr.Tab("πŸ€— HF Hub Integration"):
142
+ gr.Markdown("### πŸš€ Push Artifacts to Hugging Face Hub")
143
+ gr.Markdown("Upload your artifacts to HF Hub for easy access and sharing.")
144
+
145
+ with gr.Row():
146
+ with gr.Column(scale=1):
147
+ gr.Markdown("#### πŸ”‘ Authentication")
148
+ hf_token = gr.Textbox(
149
+ label="Hugging Face Token",
150
+ placeholder="hf_...",
151
+ type="password"
152
+ )
153
+ hf_username = gr.Textbox(
154
+ label="HF Username",
155
+ placeholder="yourusername"
156
+ )
157
+
158
+ with gr.Column(scale=1):
159
+ gr.Markdown("#### πŸ“€ Push Options")
160
+ push_splits = gr.Button("πŸ“€ Push Splits to HF", variant="primary")
161
+ push_models = gr.Button("πŸ“€ Push Models to HF", variant="primary")
162
+ push_all = gr.Button("πŸ“€ Push Everything to HF", variant="primary")
163
+
164
+ push_status = gr.Textbox(label="Push Status", interactive=False)
165
+
166
+ push_splits.click(
167
+ fn=push_splits_to_hf,
168
+ inputs=[hf_token, hf_username],
169
+ outputs=push_status
170
+ )
171
+
172
+ push_models.click(
173
+ fn=push_models_to_hf,
174
+ inputs=[hf_token, hf_username],
175
+ outputs=push_status
176
+ )
177
+
178
+ push_all.click(
179
+ fn=push_everything_to_hf,
180
+ inputs=[hf_token, hf_username],
181
+ outputs=push_status
182
+ )
183
+
184
+ # Runtime Fetching Tab
185
+ with gr.Tab("⚑ Runtime Fetching"):
186
+ gr.Markdown("### πŸ”„ Fetch Artifacts at Runtime")
187
+ gr.Markdown("Configure the system to fetch artifacts from HF Hub instead of reprocessing.")
188
+
189
+ with gr.Row():
190
+ with gr.Column(scale=1):
191
+ gr.Markdown("#### πŸ”— HF Hub Sources")
192
+ splits_repo = gr.Textbox(
193
+ label="Splits Repository",
194
+ placeholder="yourusername/dressify-splits",
195
+ value="Stylique/dressify-splits"
196
+ )
197
+ models_repo = gr.Textbox(
198
+ label="Models Repository",
199
+ placeholder="yourusername/dressify-models",
200
+ value="Stylique/dressify-models"
201
+ )
202
+
203
+ enable_runtime_fetch = gr.Checkbox(
204
+ label="Enable Runtime Fetching",
205
+ value=False
206
+ )
207
+
208
+ with gr.Column(scale=1):
209
+ gr.Markdown("#### πŸš€ Fetch Actions")
210
+ fetch_splits = gr.Button("πŸ”„ Fetch Splits", variant="primary")
211
+ fetch_models = gr.Button("πŸ”„ Fetch Models", variant="primary")
212
+ fetch_all = gr.Button("πŸ”„ Fetch Everything", variant="primary")
213
+
214
+ fetch_status = gr.Textbox(label="Fetch Status", interactive=False)
215
+
216
+ fetch_splits.click(
217
+ fn=fetch_splits_from_hf,
218
+ inputs=[splits_repo],
219
+ outputs=fetch_status
220
+ )
221
+
222
+ fetch_models.click(
223
+ fn=fetch_models_from_hf,
224
+ inputs=[models_repo],
225
+ outputs=fetch_status
226
+ )
227
+
228
+ fetch_all.click(
229
+ fn=fetch_everything_from_hf,
230
+ inputs=[splits_repo, models_repo],
231
+ outputs=fetch_status
232
+ )
233
+
234
+ # Footer
235
+ gr.Markdown("---")
236
+ gr.Markdown("### πŸ’‘ Usage Instructions")
237
+ gr.Markdown("""
238
+ 1. **System Overview**: Check what artifacts are available and their sizes
239
+ 2. **Download Management**: Create packaged downloads for easy sharing
240
+ 3. **Individual Files**: Browse and download specific artifacts
241
+ 4. **Upload & Restore**: Upload previously downloaded packages
242
+ 5. **HF Hub Integration**: Push artifacts to Hugging Face for sharing
243
+ 6. **Runtime Fetching**: Configure automatic fetching from HF Hub
244
+ """)
245
+
246
+ gr.Markdown("### 🎯 Benefits")
247
+ gr.Markdown("""
248
+ - **Save Time**: No more reprocessing expensive splits
249
+ - **Save Resources**: Avoid re-downloading and re-extracting
250
+ - **Easy Sharing**: Download packages and share with others
251
+ - **HF Integration**: Push to Hub for community access
252
+ - **Runtime Fetching**: Automatic artifact retrieval
253
+ """)
254
+
255
+ return interface
256
+
257
+ # Helper functions for the interface
258
+ def get_system_overview():
259
+ """Get comprehensive system overview."""
260
+ try:
261
+ manager = create_artifact_manager()
262
+ return manager.get_artifact_summary()
263
+ except Exception as e:
264
+ return {"error": str(e)}
265
+
266
+ def export_system_summary():
267
+ """Export system summary as JSON file."""
268
+ try:
269
+ manager = create_artifact_manager()
270
+ summary = manager.get_artifact_summary()
271
+
272
+ # Save to exports directory
273
+ export_dir = os.getenv("EXPORT_DIR", "models/exports")
274
+ os.makedirs(export_dir, exist_ok=True)
275
+
276
+ summary_path = os.path.join(export_dir, "system_summary.json")
277
+ with open(summary_path, 'w') as f:
278
+ json.dump(summary, f, indent=2)
279
+
280
+ return summary_path
281
+ except Exception as e:
282
+ return None
283
+
284
+ def create_download_package(package_type: str):
285
+ """Create a downloadable package."""
286
+ try:
287
+ manager = create_artifact_manager()
288
+
289
+ # Extract package type from the dropdown choice
290
+ if "complete" in package_type:
291
+ pkg_type = "complete"
292
+ elif "splits_only" in package_type:
293
+ pkg_type = "splits_only"
294
+ elif "models_only" in package_type:
295
+ pkg_type = "models_only"
296
+ else:
297
+ return f"❌ Invalid package type: {package_type}", get_available_packages()
298
+
299
+ package_path = manager.create_download_package(pkg_type)
300
+ package_name = os.path.basename(package_path)
301
+
302
+ return f"βœ… Package created: {package_name}", get_available_packages()
303
+
304
+ except Exception as e:
305
+ return f"❌ Failed to create package: {e}", get_available_packages()
306
+
307
+ def get_available_packages():
308
+ """Get list of available packages."""
309
+ try:
310
+ export_dir = os.getenv("EXPORT_DIR", "models/exports")
311
+ packages = []
312
+
313
+ if os.path.exists(export_dir):
314
+ for file in os.listdir(export_dir):
315
+ if file.endswith((".tar.gz", ".zip")):
316
+ file_path = os.path.join(export_dir, file)
317
+ packages.append({
318
+ "name": file,
319
+ "size_mb": round(os.path.getsize(file_path) / (1024 * 1024), 2),
320
+ "path": file_path,
321
+ "url": f"/files/{file}"
322
+ })
323
+
324
+ return {"packages": packages}
325
+ except Exception as e:
326
+ return {"error": str(e)}
327
+
328
+ def get_individual_files():
329
+ """Get list of individual downloadable files."""
330
+ try:
331
+ manager = create_artifact_manager()
332
+ files = manager.get_downloadable_files()
333
+
334
+ # Group by category
335
+ categorized = {}
336
+ for file in files:
337
+ category = file["category"]
338
+ if category not in categorized:
339
+ categorized[category] = []
340
+ categorized[category].append(file)
341
+
342
+ return categorized
343
+ except Exception as e:
344
+ return {"error": str(e)}
345
+
346
+ def download_all_files():
347
+ """Download all files as a ZIP archive."""
348
+ try:
349
+ manager = create_artifact_manager()
350
+ files = manager.get_downloadable_files()
351
+
352
+ # Create ZIP with all files
353
+ export_dir = os.getenv("EXPORT_DIR", "models/exports")
354
+ os.makedirs(export_dir, exist_ok=True)
355
+
356
+ zip_path = os.path.join(export_dir, "all_artifacts.zip")
357
+ import zipfile
358
+
359
+ with zipfile.ZipFile(zip_path, 'w') as zipf:
360
+ for file in files:
361
+ if os.path.exists(file["path"]):
362
+ zipf.write(file["path"], file["name"])
363
+
364
+ return zip_path
365
+ except Exception as e:
366
+ return None
367
+
368
+ # Placeholder functions for upload/restore features
369
+ def upload_and_extract_package(upload_file):
370
+ """Upload and extract a package."""
371
+ if upload_file is None:
372
+ return "❌ No file uploaded"
373
+
374
+ try:
375
+ # This would implement actual upload and extraction logic
376
+ return f"βœ… Package uploaded: {upload_file.name}"
377
+ except Exception as e:
378
+ return f"❌ Upload failed: {e}"
379
+
380
+ def restore_splits_only():
381
+ """Restore splits only."""
382
+ return "πŸ”„ Splits restoration not yet implemented"
383
+
384
+ def restore_models_only():
385
+ """Restore models only."""
386
+ return "πŸ”„ Models restoration not yet implemented"
387
+
388
+ def restore_everything():
389
+ """Restore everything."""
390
+ return "πŸ”„ Full restoration not yet implemented"
391
+
392
+ # Placeholder functions for HF Hub integration
393
+ def push_splits_to_hf(token, username):
394
+ """Push splits to HF Hub."""
395
+ if not token or not username:
396
+ return "❌ Please provide HF token and username"
397
+ return f"πŸ“€ Pushing splits to {username}/dressify-splits..."
398
+
399
+ def push_models_to_hf(token, username):
400
+ """Push models to HF Hub."""
401
+ if not token or not username:
402
+ return "❌ Please provide HF token and username"
403
+ return f"πŸ“€ Pushing models to {username}/dressify-models..."
404
+
405
+ def push_everything_to_hf(token, username):
406
+ """Push everything to HF Hub."""
407
+ if not token or not username:
408
+ return "❌ Please provide HF token and username"
409
+ return f"πŸ“€ Pushing everything to {username}/dressify..."
410
+
411
+ # Placeholder functions for runtime fetching
412
+ def fetch_splits_from_hf(repo):
413
+ """Fetch splits from HF Hub."""
414
+ return f"πŸ”„ Fetching splits from {repo}..."
415
+
416
+ def fetch_models_from_hf(repo):
417
+ """Fetch models from HF Hub."""
418
+ return f"πŸ”„ Fetching models from {repo}..."
419
+
420
+ def fetch_everything_from_hf(splits_repo, models_repo):
421
+ """Fetch everything from HF Hub."""
422
+ return f"πŸ”„ Fetching everything from {splits_repo} and {models_repo}..."
423
+
424
+ if __name__ == "__main__":
425
+ # Test the interface
426
+ interface = create_artifact_management_interface()
427
+ interface.launch()
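Both app.py and artifact_management_ui.py call into utils.artifact_manager.create_artifact_manager(), which is not part of this diff. A hedged sketch of the minimal interface those callers assume — the method names come from the calls above, while the return shapes and field values are guesses, not the real implementation:

import os
from dataclasses import dataclass
from typing import Any, Dict, List

@dataclass
class _ArtifactManagerSketch:
    # Stand-in for the real utils/artifact_manager.py (not shown in this commit).
    export_dir: str = os.getenv("EXPORT_DIR", "models/exports")

    def get_artifact_summary(self) -> Dict[str, Any]:
        # get_artifact_overview() / export_artifact_summary() only need a JSON-serializable dict.
        files = os.listdir(self.export_dir) if os.path.isdir(self.export_dir) else []
        return {"export_dir": self.export_dir, "num_files": len(files), "files": files}

    def get_downloadable_files(self) -> List[Dict[str, str]]:
        # get_individual_files() expects a "category" key; download_all_files() expects "name" and "path".
        out = []
        if os.path.isdir(self.export_dir):
            for fn in os.listdir(self.export_dir):
                out.append({"name": fn, "path": os.path.join(self.export_dir, fn), "category": "exports"})
        return out

    def create_download_package(self, package_type: str) -> str:
        # create_download_package() in the UI expects a filesystem path it can basename().
        return os.path.join(self.export_dir, f"{package_type}_package.tar.gz")

def create_artifact_manager() -> _ArtifactManagerSketch:
    return _ArtifactManagerSketch()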
models/resnet_embedder.py CHANGED
@@ -2,6 +2,7 @@ from typing import Optional
2
 
3
  import torch
4
  import torch.nn as nn
 
5
  import torchvision.models as tvm
6
 
7
 
@@ -27,6 +28,8 @@ class ResNetItemEmbedder(nn.Module):
27
  feats = self.backbone(x) # (B, C, 1, 1)
28
  feats = feats.flatten(1) # (B, C)
29
  emb = self.proj(feats) # (B, D)
 
 
30
  return emb
31
 
32
 
 
2
 
3
  import torch
4
  import torch.nn as nn
5
+ import torch.nn.functional as F
6
  import torchvision.models as tvm
7
 
8
 
 
28
  feats = self.backbone(x) # (B, C, 1, 1)
29
  feats = feats.flatten(1) # (B, C)
30
  emb = self.proj(feats) # (B, D)
31
+ # Apply L2 normalization as specified in requirements
32
+ emb = F.normalize(emb, p=2, dim=1)
33
  return emb
34
 
35
 
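The resnet_embedder.py change is a single functional addition, F.normalize(emb, p=2, dim=1): every item embedding now leaves the embedder with unit L2 norm, so cosine similarity between embeddings reduces to a plain dot product. A small sketch of that property (batch size and dimension are arbitrary):

import torch
import torch.nn.functional as F

emb = torch.randn(4, 512)              # raw projections, as produced by self.proj(feats)
emb = F.normalize(emb, p=2, dim=1)     # the line added by this commit

print(emb.norm(p=2, dim=1))            # every row has norm ~1.0
a, b = emb[0], emb[1]
cos = F.cosine_similarity(a, b, dim=0)
dot = a @ b
print(torch.allclose(cos, dot, atol=1e-6))  # True: cosine similarity equals the dot product on unit vectors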
test_training.py ADDED
@@ -0,0 +1,178 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Simple test script to verify training components work.
4
+ Run this to test if the system is ready for training.
5
+ """
6
+
7
+ import os
8
+ import sys
9
+ import torch
10
+
11
+ def test_imports():
12
+ """Test if all required modules can be imported."""
13
+ print("πŸ” Testing imports...")
14
+
15
+ try:
16
+ from models.resnet_embedder import ResNetItemEmbedder
17
+ print("βœ… ResNet embedder imported successfully")
18
+ except Exception as e:
19
+ print(f"❌ Failed to import ResNet embedder: {e}")
20
+ return False
21
+
22
+ try:
23
+ from models.vit_outfit import OutfitCompatibilityModel
24
+ print("βœ… ViT outfit model imported successfully")
25
+ except Exception as e:
26
+ print(f"❌ Failed to import ViT outfit model: {e}")
27
+ return False
28
+
29
+ try:
30
+ from data.polyvore import PolyvoreTripletDataset
31
+ print("βœ… Polyvore dataset imported successfully")
32
+ except Exception as e:
33
+ print(f"❌ Failed to import Polyvore dataset: {e}")
34
+ return False
35
+
36
+ try:
37
+ from utils.transforms import build_train_transforms
38
+ print("βœ… Transforms imported successfully")
39
+ except Exception as e:
40
+ print(f"❌ Failed to import transforms: {e}")
41
+ return False
42
+
43
+ return True
44
+
45
+ def test_models():
46
+ """Test if models can be created and run forward pass."""
47
+ print("\nπŸ—οΈ Testing model creation...")
48
+
49
+ try:
50
+ # Test ResNet embedder
51
+ device = "cuda" if torch.cuda.is_available() else "cpu"
52
+ print(f"Using device: {device}")
53
+
54
+ resnet = ResNetItemEmbedder(embedding_dim=512).to(device)
55
+ print(f"βœ… ResNet created with {sum(p.numel() for p in resnet.parameters()):,} parameters")
56
+
57
+ # Test forward pass
58
+ dummy_input = torch.randn(2, 3, 224, 224).to(device)
59
+ with torch.no_grad():
60
+ output = resnet(dummy_input)
61
+ print(f"βœ… ResNet forward pass: input {dummy_input.shape} -> output {output.shape}")
62
+
63
+ # Test ViT outfit model
64
+ vit = OutfitCompatibilityModel(embedding_dim=512).to(device)
65
+ print(f"βœ… ViT created with {sum(p.numel() for p in vit.parameters()):,} parameters")
66
+
67
+ # Test forward pass
68
+ dummy_tokens = torch.randn(2, 4, 512).to(device)
69
+ with torch.no_grad():
70
+ output = vit(dummy_tokens)
71
+ print(f"βœ… ViT forward pass: input {dummy_tokens.shape} -> output {output.shape}")
72
+
73
+ return True
74
+
75
+ except Exception as e:
76
+ print(f"❌ Model test failed: {e}")
77
+ return False
78
+
79
+ def test_dataset():
80
+ """Test if dataset can be loaded (if available)."""
81
+ print("\nπŸ“Š Testing dataset loading...")
82
+
83
+ data_root = os.getenv("POLYVORE_ROOT", "/home/user/app/data/Polyvore")
84
+ splits_dir = os.path.join(data_root, "splits")
85
+ train_file = os.path.join(splits_dir, "train.json")
86
+
87
+ if not os.path.exists(train_file):
88
+ print(f"⚠️ Training data not found at {train_file}")
89
+ print("πŸ’‘ Dataset preparation may be needed")
90
+ return True # Not a failure, just not ready
91
+
92
+ try:
93
+ dataset = PolyvoreTripletDataset(data_root, split="train")
94
+ print(f"βœ… Dataset loaded successfully: {len(dataset)} samples")
95
+
96
+ # Test getting one sample
97
+ if len(dataset) > 0:
98
+ sample = dataset[0]
99
+ print(f"βœ… Sample loaded: {len(sample)} tensors with shapes {[s.shape for s in sample]}")
100
+
101
+ return True
102
+
103
+ except Exception as e:
104
+ print(f"❌ Dataset test failed: {e}")
105
+ return False
106
+
107
+ def test_training_components():
108
+ """Test if training components can be created."""
109
+ print("\nπŸš€ Testing training components...")
110
+
111
+ try:
112
+ from torch.utils.data import DataLoader
113
+ from torch.optim import AdamW
114
+ from torch.nn import TripletMarginLoss
115
+
116
+ # Test optimizer creation
117
+ device = "cuda" if torch.cuda.is_available() else "cpu"
118
+ model = ResNetItemEmbedder(embedding_dim=512).to(device)
119
+ optimizer = AdamW(model.parameters(), lr=1e-3)
120
+ print("βœ… Optimizer created successfully")
121
+
122
+ # Test loss function
123
+ criterion = TripletMarginLoss(margin=0.2)
124
+ print("βœ… Loss function created successfully")
125
+
126
+ return True
127
+
128
+ except Exception as e:
129
+ print(f"❌ Training components test failed: {e}")
130
+ return False
131
+
132
+ def main():
133
+ """Run all tests."""
134
+ print("πŸ§ͺ Starting Dressify Training System Tests\n")
135
+
136
+ tests = [
137
+ ("Imports", test_imports),
138
+ ("Models", test_models),
139
+ ("Dataset", test_dataset),
140
+ ("Training Components", test_training_components),
141
+ ]
142
+
143
+ results = []
144
+ for test_name, test_func in tests:
145
+ try:
146
+ result = test_func()
147
+ results.append((test_name, result))
148
+ except Exception as e:
149
+ print(f"οΏ½οΏ½ {test_name} test crashed: {e}")
150
+ results.append((test_name, False))
151
+
152
+ # Summary
153
+ print("\n" + "="*50)
154
+ print("πŸ“Š TEST RESULTS SUMMARY")
155
+ print("="*50)
156
+
157
+ passed = 0
158
+ total = len(results)
159
+
160
+ for test_name, result in results:
161
+ status = "βœ… PASS" if result else "❌ FAIL"
162
+ print(f"{test_name:20} {status}")
163
+ if result:
164
+ passed += 1
165
+
166
+ print("="*50)
167
+ print(f"Overall: {passed}/{total} tests passed")
168
+
169
+ if passed == total:
170
+ print("πŸŽ‰ All tests passed! System is ready for training.")
171
+ return True
172
+ else:
173
+ print("⚠️ Some tests failed. Please check the errors above.")
174
+ return False
175
+
176
+ if __name__ == "__main__":
177
+ success = main()
178
+ sys.exit(0 if success else 1)
train_resnet.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import argparse
 
3
  from typing import Tuple
4
 
5
  import torch
@@ -7,6 +8,9 @@ import torch.nn as nn
7
  import torch.optim as optim
8
  from torch.utils.data import DataLoader
9
 
 
 
 
10
  from data.polyvore import PolyvoreTripletDataset
11
  from models.resnet_embedder import ResNetItemEmbedder
12
  from utils.export import ensure_export_dir
@@ -15,7 +19,7 @@ import json
15
 
16
  def parse_args() -> argparse.Namespace:
17
  p = argparse.ArgumentParser()
18
- p.add_argument("--data_root", type=str, default=os.getenv("POLYVORE_ROOT", "/kaggle/input/polyvore-outfits"))
19
  p.add_argument("--epochs", type=int, default=20)
20
  p.add_argument("--batch_size", type=int, default=64)
21
  p.add_argument("--lr", type=float, default=1e-3)
@@ -30,80 +34,119 @@ def main() -> None:
30
  if device == "cuda":
31
  torch.backends.cudnn.benchmark = True
32
 
 
 
 
 
33
  # Ensure splits exist; if missing, prepare from official splits
34
  splits_dir = os.path.join(args.data_root, "splits")
35
  triplet_path = os.path.join(splits_dir, "train.json")
 
36
  if not os.path.exists(triplet_path):
 
 
37
  os.makedirs(splits_dir, exist_ok=True)
38
  try:
39
- from scripts.prepare_polyvore import main as prepare_main
40
- import sys
41
- argv_bak = sys.argv
42
- try:
43
- # First try using official splits (no random)
44
- sys.argv = ["prepare_polyvore.py", "--root", args.data_root]
45
- prepare_main()
46
- finally:
47
- sys.argv = argv_bak
48
- except Exception:
49
- # As a fallback, try random split on any available aggregate file
50
- try:
51
- from scripts.prepare_polyvore import main as prepare_main
52
- import sys
53
- argv_bak = sys.argv
54
- try:
55
- sys.argv = ["prepare_polyvore.py", "--root", args.data_root, "--random_split"]
56
- prepare_main()
57
- finally:
58
- sys.argv = argv_bak
59
- except Exception:
60
- pass
61
-
62
- dataset = PolyvoreTripletDataset(args.data_root, split="train")
63
-
64
- loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=(device=="cuda"))
65
  model = ResNetItemEmbedder(embedding_dim=args.embedding_dim).to(device)
66
  optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
67
  criterion = nn.TripletMarginLoss(margin=0.2, p=2)
68
 
 
 
 
69
  export_dir = ensure_export_dir(os.path.dirname(args.out) or "models/exports")
70
  best_loss = float("inf")
71
  history = []
 
 
 
72
  for epoch in range(args.epochs):
73
  model.train()
74
- running = 0.0
75
  steps = 0
76
- for batch in loader:
77
- # Expect batch as (anchor, positive, negative)
78
- anchor, positive, negative = batch
79
- anchor = anchor.to(device, memory_format=torch.channels_last, non_blocking=True)
80
- positive = positive.to(device, memory_format=torch.channels_last, non_blocking=True)
81
- negative = negative.to(device, memory_format=torch.channels_last, non_blocking=True)
82
- with torch.autocast(device_type=("cuda" if device=="cuda" else "cpu"), enabled=(device=="cuda")):
83
- emb_a = model(anchor)
84
- emb_p = model(positive)
85
- emb_n = model(negative)
86
- loss = criterion(emb_a, emb_p, emb_n)
87
- optimizer.zero_grad(set_to_none=True)
88
- loss.backward()
89
- optimizer.step()
90
- running += loss.item()
91
- steps += 1
92
- avg_loss = running / max(1, steps)
93
  out_path = args.out
94
  if not out_path.startswith("models/"):
95
  out_path = os.path.join(export_dir, os.path.basename(args.out))
96
- torch.save({"state_dict": model.state_dict()}, out_path)
97
- print(f"Epoch {epoch+1}/{args.epochs} avg_triplet_loss={avg_loss:.4f} saved -> {out_path}")
 
 
 
 
 
 
98
  history.append({"epoch": epoch + 1, "avg_triplet_loss": avg_loss})
 
99
  if avg_loss < best_loss:
100
  best_loss = avg_loss
101
- torch.save({"state_dict": model.state_dict()}, os.path.join(export_dir, "resnet_item_embedder_best.pth"))
 
 
102
 
103
- # write metrics
104
  metrics_path = os.path.join(export_dir, "resnet_metrics.json")
105
  with open(metrics_path, "w") as f:
106
  json.dump({"best_triplet_loss": best_loss, "history": history}, f)
 
 
 
107
 
108
 
109
  if __name__ == "__main__":
 
1
  import os
2
  import argparse
3
+ import sys
4
  from typing import Tuple
5
 
6
  import torch
 
8
  import torch.optim as optim
9
  from torch.utils.data import DataLoader
10
 
11
+ # Fix import paths
12
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
13
+
14
  from data.polyvore import PolyvoreTripletDataset
15
  from models.resnet_embedder import ResNetItemEmbedder
16
  from utils.export import ensure_export_dir
 
19
 
20
  def parse_args() -> argparse.Namespace:
21
  p = argparse.ArgumentParser()
22
+ p.add_argument("--data_root", type=str, default=os.getenv("POLYVORE_ROOT", "/home/user/app/data/Polyvore"))
23
  p.add_argument("--epochs", type=int, default=20)
24
  p.add_argument("--batch_size", type=int, default=64)
25
  p.add_argument("--lr", type=float, default=1e-3)
 
34
  if device == "cuda":
35
  torch.backends.cudnn.benchmark = True
36
 
37
+ print(f"πŸš€ Starting ResNet training on {device}")
38
+ print(f"πŸ“ Data root: {args.data_root}")
39
+ print(f"βš™οΈ Config: {args.epochs} epochs, batch_size={args.batch_size}, lr={args.lr}")
40
+
41
  # Ensure splits exist; if missing, prepare from official splits
42
  splits_dir = os.path.join(args.data_root, "splits")
43
  triplet_path = os.path.join(splits_dir, "train.json")
44
+
45
  if not os.path.exists(triplet_path):
46
+ print(f"⚠️ Triplet file not found: {triplet_path}")
47
+ print("πŸ”§ Attempting to prepare dataset...")
48
  os.makedirs(splits_dir, exist_ok=True)
49
  try:
50
+ # Try to import and run the prepare script
51
+ sys.path.append(os.path.join(os.path.dirname(__file__), "scripts"))
52
+ from prepare_polyvore import main as prepare_main
53
+ print("βœ… Successfully imported prepare_polyvore")
54
+
55
+ # Prepare dataset without random splits
56
+ prepare_main()
57
+ print("βœ… Dataset preparation completed")
58
+ except Exception as e:
59
+ print(f"❌ Failed to prepare dataset: {e}")
60
+ print("πŸ’‘ Please ensure the dataset is prepared manually")
61
+ return
62
+ else:
63
+ print(f"βœ… Found existing splits: {triplet_path}")
64
+
65
+ try:
66
+ dataset = PolyvoreTripletDataset(args.data_root, split="train")
67
+ print(f"πŸ“Š Dataset loaded: {len(dataset)} samples")
68
+ except Exception as e:
69
+ print(f"❌ Failed to load dataset: {e}")
70
+ return
71
+
72
+ loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=2, pin_memory=(device=="cuda"))
 
 
 
73
  model = ResNetItemEmbedder(embedding_dim=args.embedding_dim).to(device)
74
  optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
75
  criterion = nn.TripletMarginLoss(margin=0.2, p=2)
76
 
77
+ print(f"πŸ—οΈ Model created: {model.__class__.__name__}")
78
+ print(f"πŸ“ˆ Total parameters: {sum(p.numel() for p in model.parameters()):,}")
79
+
80
  export_dir = ensure_export_dir(os.path.dirname(args.out) or "models/exports")
81
  best_loss = float("inf")
82
  history = []
83
+
84
+ print(f"πŸ’Ύ Checkpoints will be saved to: {export_dir}")
85
+
86
  for epoch in range(args.epochs):
87
  model.train()
88
+ running_loss = 0.0
89
  steps = 0
90
+
91
+ print(f"πŸ”„ Epoch {epoch+1}/{args.epochs}")
92
+
93
+ for batch_idx, batch in enumerate(loader):
94
+ try:
95
+ # Expect batch as (anchor, positive, negative)
96
+ anchor, positive, negative = batch
97
+ anchor = anchor.to(device, memory_format=torch.channels_last, non_blocking=True)
98
+ positive = positive.to(device, memory_format=torch.channels_last, non_blocking=True)
99
+ negative = negative.to(device, memory_format=torch.channels_last, non_blocking=True)
100
+
101
+ with torch.autocast(device_type=("cuda" if device=="cuda" else "cpu"), enabled=(device=="cuda")):
102
+ emb_a = model(anchor)
103
+ emb_p = model(positive)
104
+ emb_n = model(negative)
105
+
106
+ loss = criterion(emb_a, emb_p, emb_n)
107
+ optimizer.zero_grad(set_to_none=True)
108
+ loss.backward()
109
+ optimizer.step()
110
+
111
+ running_loss += loss.item()
112
+ steps += 1
113
+
114
+ if batch_idx % 100 == 0:
115
+ print(f" Batch {batch_idx}/{len(loader)}: loss={loss.item():.4f}")
116
+
117
+ except Exception as e:
118
+ print(f"❌ Error in batch {batch_idx}: {e}")
119
+ continue
120
+
121
+ avg_loss = running_loss / max(1, steps)
122
+
123
+ # Save checkpoint with better path handling
124
  out_path = args.out
125
  if not out_path.startswith("models/"):
126
  out_path = os.path.join(export_dir, os.path.basename(args.out))
127
+
128
+ # Ensure the output directory exists
129
+ os.makedirs(os.path.dirname(out_path), exist_ok=True)
130
+
131
+ # Save checkpoint
132
+ torch.save({"state_dict": model.state_dict(), "epoch": epoch+1, "loss": avg_loss}, out_path)
133
+ print(f"βœ… Epoch {epoch+1}/{args.epochs} avg_triplet_loss={avg_loss:.4f} saved -> {out_path}")
134
+
135
  history.append({"epoch": epoch + 1, "avg_triplet_loss": avg_loss})
136
+
137
  if avg_loss < best_loss:
138
  best_loss = avg_loss
139
+ best_path = os.path.join(export_dir, "resnet_item_embedder_best.pth")
140
+ torch.save({"state_dict": model.state_dict(), "epoch": epoch+1, "loss": avg_loss}, best_path)
141
+ print(f"πŸ† New best model saved: {best_path}")
142
 
143
+ # Write metrics
144
  metrics_path = os.path.join(export_dir, "resnet_metrics.json")
145
  with open(metrics_path, "w") as f:
146
  json.dump({"best_triplet_loss": best_loss, "history": history}, f)
147
+
148
+ print(f"πŸ“Š Training completed! Best loss: {best_loss:.4f}")
149
+ print(f"πŸ“ˆ Metrics saved to: {metrics_path}")
150
 
151
 
152
  if __name__ == "__main__":
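train_resnet.py now writes checkpoints as a dict carrying state_dict, epoch, and loss (both the per-epoch file and resnet_item_embedder_best.pth). A hedged sketch of restoring such a checkpoint for inference; the path and embedding dimension are assumptions matching the script's defaults:

import torch
from models.resnet_embedder import ResNetItemEmbedder

ckpt_path = "models/exports/resnet_item_embedder_best.pth"  # assumed default export location
ckpt = torch.load(ckpt_path, map_location="cpu")

model = ResNetItemEmbedder(embedding_dim=512)   # must match the --embedding_dim used at training time
model.load_state_dict(ckpt["state_dict"])
model.eval()

print(f"Restored epoch {ckpt['epoch']}, avg triplet loss {ckpt['loss']:.4f}")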
train_vit_triplet.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import argparse
 
3
  from typing import List
4
 
5
  import torch
@@ -7,6 +8,9 @@ import torch.nn as nn
7
  import torch.optim as optim
8
  from torch.utils.data import DataLoader
9
 
 
 
 
10
  from data.polyvore import PolyvoreOutfitTripletDataset
11
  from models.vit_outfit import OutfitCompatibilityModel
12
  from models.resnet_embedder import ResNetItemEmbedder
@@ -16,7 +20,7 @@ import json
16
 
17
  def parse_args() -> argparse.Namespace:
18
  p = argparse.ArgumentParser()
19
- p.add_argument("--data_root", type=str, default=os.getenv("POLYVORE_ROOT", "/kaggle/input/polyvore-outfits"))
20
  p.add_argument("--epochs", type=int, default=30)
21
  p.add_argument("--batch_size", type=int, default=32)
22
  p.add_argument("--lr", type=float, default=5e-4)
@@ -45,128 +49,182 @@ def main() -> None:
45
  if device == "cuda":
46
  torch.backends.cudnn.benchmark = True
47
 
 
 
 
 
48
  # Ensure outfit triplets exist
49
  splits_dir = os.path.join(args.data_root, "splits")
50
  trip_path = os.path.join(splits_dir, "outfit_triplets_train.json")
 
51
  if not os.path.exists(trip_path):
 
 
52
  os.makedirs(splits_dir, exist_ok=True)
53
  try:
54
- from scripts.prepare_polyvore import main as prepare_main
55
- import sys
56
- argv_bak = sys.argv
57
- try:
58
- sys.argv = ["prepare_polyvore.py", "--root", args.data_root]
59
- prepare_main()
60
- finally:
61
- sys.argv = argv_bak
62
- except Exception:
63
- try:
64
- from scripts.prepare_polyvore import main as prepare_main
65
- import sys
66
- argv_bak = sys.argv
67
- try:
68
- sys.argv = ["prepare_polyvore.py", "--root", args.data_root, "--random_split"]
69
- prepare_main()
70
- finally:
71
- sys.argv = argv_bak
72
- except Exception:
73
- pass
74
-
75
- dataset = PolyvoreOutfitTripletDataset(args.data_root, split="train")
76
 
77
  def collate(batch):
78
  return batch # variable length handled inside training loop
79
 
80
- loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=(device=="cuda"), collate_fn=collate)
81
 
82
  model = OutfitCompatibilityModel(embedding_dim=args.embedding_dim).to(device)
83
  embedder = ResNetItemEmbedder(embedding_dim=args.embedding_dim).to(device).eval()
84
  for p in embedder.parameters():
85
  p.requires_grad_(False)
86
 
 
 
 
 
 
87
  optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=5e-2)
88
  triplet = nn.TripletMarginWithDistanceLoss(distance_function=lambda x, y: 1 - nn.functional.cosine_similarity(x, y), margin=args.triplet_margin)
89
 
90
  export_dir = ensure_export_dir(os.path.dirname(args.export) or "models/exports")
91
  best_loss = float("inf")
92
  hist = []
 
 
 
93
  for epoch in range(args.epochs):
94
  model.train()
95
- for batch in loader:
96
- # batch: List[(ga_imgs, gb_imgs, bd_imgs)]
97
- anchor_tokens = []
98
- positive_tokens = []
99
- negative_tokens = []
100
- for ga, gb, bd in batch:
101
- ta = embed_outfit(ga, embedder, device)
102
- tb = embed_outfit(gb, embedder, device)
103
- tn = embed_outfit(bd, embedder, device)
104
- anchor_tokens.append(ta.unsqueeze(0))
105
- positive_tokens.append(tb.unsqueeze(0))
106
- negative_tokens.append(tn.unsqueeze(0))
107
- A = torch.cat(anchor_tokens, dim=0) # (B, N, D)
108
- P = torch.cat(positive_tokens, dim=0)
109
- N = torch.cat(negative_tokens, dim=0)
110
-
111
- # get outfit-level embeddings via ViT encoder pooled output
112
- with torch.autocast(device_type=("cuda" if device=="cuda" else "cpu"), enabled=(device=="cuda")):
113
- ea = model.encoder(A).mean(dim=1)
114
- ep = model.encoder(P).mean(dim=1)
115
- en = model.encoder(N).mean(dim=1)
116
- loss = triplet(ea, ep, en)
117
- optimizer.zero_grad(set_to_none=True)
118
- loss.backward()
119
- optimizer.step()
120
-
 
121
  # Simple validation using a subset of training data as a proxy if no val split here
122
  # For true 70/10/10, prepare_polyvore.py will create outfit_triplets_valid.json
123
  val_path = os.path.join(args.data_root, "splits", "outfit_triplets_valid.json")
124
  val_loss = None
 
125
  if os.path.exists(val_path) and (epoch + 1) % args.eval_every == 0:
126
- val_ds = PolyvoreOutfitTripletDataset(args.data_root, split="valid")
127
- val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, num_workers=2, collate_fn=lambda x: x)
128
- model.eval()
129
- losses = []
130
- with torch.no_grad():
131
- for vbatch in val_loader:
132
- anchor_tokens = []
133
- positive_tokens = []
134
- negative_tokens = []
135
- for ga, gb, bd in vbatch:
136
- ta = embed_outfit(ga, embedder, device)
137
- tb = embed_outfit(gb, embedder, device)
138
- tn = embed_outfit(bd, embedder, device)
139
- anchor_tokens.append(ta.unsqueeze(0))
140
- positive_tokens.append(tb.unsqueeze(0))
141
- negative_tokens.append(tn.unsqueeze(0))
142
- A = torch.cat(anchor_tokens, dim=0)
143
- P = torch.cat(positive_tokens, dim=0)
144
- N = torch.cat(negative_tokens, dim=0)
145
- ea = model.encoder(A).mean(dim=1)
146
- ep = model.encoder(P).mean(dim=1)
147
- en = model.encoder(N).mean(dim=1)
148
- l = triplet(ea, ep, en).item()
149
- losses.append(l)
150
- val_loss = sum(losses) / max(1, len(losses))
 
 
151
 
152
  out_path = args.export
153
  if not out_path.startswith("models/"):
154
  out_path = os.path.join(export_dir, os.path.basename(args.export))
155
- torch.save({"state_dict": model.state_dict()}, out_path)
156
  if val_loss is not None:
157
- print(f"Epoch {epoch+1}/{args.epochs} triplet_loss={loss.item():.4f} val_triplet_loss={val_loss:.4f} saved -> {out_path}")
158
- hist.append({"epoch": epoch + 1, "triplet_loss": float(loss.item()), "val_triplet_loss": float(val_loss)})
159
  if val_loss < best_loss:
160
  best_loss = val_loss
161
- torch.save({"state_dict": model.state_dict()}, os.path.join(export_dir, "vit_outfit_model_best.pth"))
 
 
162
  else:
163
- print(f"Epoch {epoch+1}/{args.epochs} triplet_loss={loss.item():.4f} saved -> {out_path}")
164
- hist.append({"epoch": epoch + 1, "triplet_loss": float(loss.item())})
165
 
 
166
  metrics_path = os.path.join(export_dir, "vit_metrics.json")
167
  payload = {"best_val_triplet_loss": best_loss if best_loss != float("inf") else None, "history": hist}
168
  with open(metrics_path, "w") as f:
169
  json.dump(payload, f)
 
170
 
171
 
172
  if __name__ == "__main__":
 
1
  import os
2
  import argparse
3
+ import sys
4
  from typing import List
5
 
6
  import torch
 
8
  import torch.optim as optim
9
  from torch.utils.data import DataLoader
10
 
11
+ # Fix import paths
12
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
13
+
14
  from data.polyvore import PolyvoreOutfitTripletDataset
15
  from models.vit_outfit import OutfitCompatibilityModel
16
  from models.resnet_embedder import ResNetItemEmbedder
 
20
 
21
  def parse_args() -> argparse.Namespace:
22
  p = argparse.ArgumentParser()
23
+ p.add_argument("--data_root", type=str, default=os.getenv("POLYVORE_ROOT", "/home/user/app/data/Polyvore"))
24
  p.add_argument("--epochs", type=int, default=30)
25
  p.add_argument("--batch_size", type=int, default=32)
26
  p.add_argument("--lr", type=float, default=5e-4)
 
49
  if device == "cuda":
50
  torch.backends.cudnn.benchmark = True
51
 
52
+ print(f"πŸš€ Starting ViT Outfit training on {device}")
53
+ print(f"πŸ“ Data root: {args.data_root}")
54
+ print(f"βš™οΈ Config: {args.epochs} epochs, batch_size={args.batch_size}, lr={args.lr}")
55
+
56
  # Ensure outfit triplets exist
57
  splits_dir = os.path.join(args.data_root, "splits")
58
  trip_path = os.path.join(splits_dir, "outfit_triplets_train.json")
59
+
60
  if not os.path.exists(trip_path):
61
+ print(f"⚠️ Outfit triplet file not found: {trip_path}")
62
+ print("πŸ”§ Attempting to prepare dataset...")
63
  os.makedirs(splits_dir, exist_ok=True)
64
  try:
65
+ # Try to import and run the prepare script
66
+ sys.path.append(os.path.join(os.path.dirname(__file__), "scripts"))
67
+ from prepare_polyvore import main as prepare_main
68
+ print("βœ… Successfully imported prepare_polyvore")
69
+
70
+ # Prepare dataset without random splits
71
+ prepare_main()
72
+ print("βœ… Dataset preparation completed")
73
+ except Exception as e:
74
+ print(f"❌ Failed to prepare dataset: {e}")
75
+ print("πŸ’‘ Please ensure the dataset is prepared manually")
76
+ return
77
+ else:
78
+ print(f"βœ… Found existing outfit triplets: {trip_path}")
79
+
80
+ try:
81
+ dataset = PolyvoreOutfitTripletDataset(args.data_root, split="train")
82
+ print(f"πŸ“Š Dataset loaded: {len(dataset)} samples")
83
+ except Exception as e:
84
+ print(f"❌ Failed to load dataset: {e}")
85
+ return
 
86
 
87
  def collate(batch):
88
  return batch # variable length handled inside training loop
89
 
90
+ loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=2, pin_memory=(device=="cuda"), collate_fn=collate)
91
 
92
  model = OutfitCompatibilityModel(embedding_dim=args.embedding_dim).to(device)
93
  embedder = ResNetItemEmbedder(embedding_dim=args.embedding_dim).to(device).eval()
94
  for p in embedder.parameters():
95
  p.requires_grad_(False)
96
 
97
+ print(f"πŸ—οΈ Models created:")
98
+ print(f" - ViT Outfit: {model.__class__.__name__}")
99
+ print(f" - ResNet Embedder: {embedder.__class__.__name__}")
100
+ print(f"πŸ“ˆ Total parameters: {sum(p.numel() for p in model.parameters()):,}")
101
+
102
  optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=5e-2)
103
  triplet = nn.TripletMarginWithDistanceLoss(distance_function=lambda x, y: 1 - nn.functional.cosine_similarity(x, y), margin=args.triplet_margin)
104
 
105
  export_dir = ensure_export_dir(os.path.dirname(args.export) or "models/exports")
106
  best_loss = float("inf")
107
  hist = []
108
+
109
+ print(f"πŸ’Ύ Checkpoints will be saved to: {export_dir}")
110
+
111
  for epoch in range(args.epochs):
112
  model.train()
113
+ running_loss = 0.0
114
+ steps = 0
115
+
116
+ print(f"πŸ”„ Epoch {epoch+1}/{args.epochs}")
117
+
118
+ for batch_idx, batch in enumerate(loader):
119
+ try:
120
+ # batch: List[(ga_imgs, gb_imgs, bd_imgs)]
121
+ anchor_tokens = []
122
+ positive_tokens = []
123
+ negative_tokens = []
124
+
125
+ for ga, gb, bd in batch:
126
+ ta = embed_outfit(ga, embedder, device)
127
+ tb = embed_outfit(gb, embedder, device)
128
+ tn = embed_outfit(bd, embedder, device)
129
+ anchor_tokens.append(ta.unsqueeze(0))
130
+ positive_tokens.append(tb.unsqueeze(0))
131
+ negative_tokens.append(tn.unsqueeze(0))
132
+
133
+ A = torch.cat(anchor_tokens, dim=0) # (B, N, D)
134
+ P = torch.cat(positive_tokens, dim=0)
135
+ N = torch.cat(negative_tokens, dim=0)
136
+
137
+ # get outfit-level embeddings via ViT encoder pooled output
138
+ with torch.autocast(device_type=("cuda" if device=="cuda" else "cpu"), enabled=(device=="cuda")):
139
+ ea = model.encoder(A).mean(dim=1)
140
+ ep = model.encoder(P).mean(dim=1)
141
+ en = model.encoder(N).mean(dim=1)
142
+ loss = triplet(ea, ep, en)
143
+
144
+ optimizer.zero_grad(set_to_none=True)
145
+ loss.backward()
146
+ optimizer.step()
147
+
148
+ running_loss += loss.item()
149
+ steps += 1
150
+
151
+ if batch_idx % 50 == 0:
152
+ print(f" Batch {batch_idx}/{len(loader)}: loss={loss.item():.4f}")
153
+
154
+ except Exception as e:
155
+ print(f"❌ Error in batch {batch_idx}: {e}")
156
+ continue
157
+
158
+ avg_loss = running_loss / max(1, steps)
159
+
160
  # Simple validation using a subset of training data as a proxy if no val split here
161
  # For true 70/10/10, prepare_polyvore.py will create outfit_triplets_valid.json
162
  val_path = os.path.join(args.data_root, "splits", "outfit_triplets_valid.json")
163
  val_loss = None
164
+
165
  if os.path.exists(val_path) and (epoch + 1) % args.eval_every == 0:
166
+ try:
167
+ val_ds = PolyvoreOutfitTripletDataset(args.data_root, split="valid")
168
+ val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, num_workers=2, collate_fn=lambda x: x)
169
+ model.eval()
170
+ losses = []
171
+
172
+ with torch.no_grad():
173
+ for vbatch in val_loader:
174
+ anchor_tokens = []
175
+ positive_tokens = []
176
+ negative_tokens = []
177
+ for ga, gb, bd in vbatch:
178
+ ta = embed_outfit(ga, embedder, device)
179
+ tb = embed_outfit(gb, embedder, device)
180
+ tn = embed_outfit(bd, embedder, device)
181
+ anchor_tokens.append(ta.unsqueeze(0))
182
+ positive_tokens.append(tb.unsqueeze(0))
183
+ negative_tokens.append(tn.unsqueeze(0))
184
+ A = torch.cat(anchor_tokens, dim=0)
185
+ P = torch.cat(positive_tokens, dim=0)
186
+ N = torch.cat(negative_tokens, dim=0)
187
+ ea = model.encoder(A).mean(dim=1)
188
+ ep = model.encoder(P).mean(dim=1)
189
+ en = model.encoder(N).mean(dim=1)
190
+ l = triplet(ea, ep, en).item()
191
+ losses.append(l)
192
+
193
+ val_loss = sum(losses) / max(1, len(losses))
194
+ print(f" πŸ“Š Validation loss: {val_loss:.4f}")
195
+
196
+ except Exception as e:
197
+ print(f" ⚠️ Validation failed: {e}")
198
 
199
  out_path = args.export
200
  if not out_path.startswith("models/"):
201
  out_path = os.path.join(export_dir, os.path.basename(args.export))
202
+
203
+ # Save checkpoint
204
+ torch.save({"state_dict": model.state_dict(), "epoch": epoch+1, "loss": avg_loss}, out_path)
205
+
206
  if val_loss is not None:
207
+ print(f"βœ… Epoch {epoch+1}/{args.epochs} triplet_loss={avg_loss:.4f} val_triplet_loss={val_loss:.4f} saved -> {out_path}")
208
+ hist.append({"epoch": epoch + 1, "triplet_loss": float(avg_loss), "val_triplet_loss": float(val_loss)})
209
  if val_loss < best_loss:
210
  best_loss = val_loss
211
+ best_path = os.path.join(export_dir, "vit_outfit_model_best.pth")
212
+ torch.save({"state_dict": model.state_dict(), "epoch": epoch+1, "loss": avg_loss, "val_loss": val_loss}, best_path)
213
+ print(f"πŸ† New best model saved: {best_path}")
214
  else:
215
+ print(f"βœ… Epoch {epoch+1}/{args.epochs} triplet_loss={avg_loss:.4f} saved -> {out_path}")
216
+ hist.append({"epoch": epoch + 1, "triplet_loss": float(avg_loss)})
217
 
218
+ # Write metrics
219
  metrics_path = os.path.join(export_dir, "vit_metrics.json")
220
  payload = {"best_val_triplet_loss": best_loss if best_loss != float("inf") else None, "history": hist}
221
  with open(metrics_path, "w") as f:
222
  json.dump(payload, f)
223
+
224
+ print(f"πŸ“Š Training completed!")
225
+ if best_loss != float("inf"):
226
+ print(f"πŸ† Best validation loss: {best_loss:.4f}")
227
+ print(f"πŸ“ˆ Metrics saved to: {metrics_path}")
228
 
229
 
230
  if __name__ == "__main__":
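For reference, a minimal sketch of launching this updated training script the same way app.py launches it, assuming the file shown above is train_vit_triplet.py and that POLYVORE_ROOT points at a prepared Polyvore dataset (both are assumptions, not guaranteed by this diff):

```python
# Minimal sketch, assuming the diff above is train_vit_triplet.py and that the
# dataset has already been prepared under POLYVORE_ROOT.
import os
import subprocess

cmd = [
    "python", "train_vit_triplet.py",
    "--data_root", os.getenv("POLYVORE_ROOT", "/home/user/app/data/Polyvore"),
    "--epochs", "30",
    "--batch_size", "32",
    "--lr", "5e-4",
]
result = subprocess.run(cmd, capture_output=True, text=True, check=False)
print(result.stdout if result.returncode == 0 else result.stderr)
```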
training_monitor.py ADDED
@@ -0,0 +1,132 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Simple training monitor for Dressify.
4
+ Shows real-time training progress and status.
5
+ """
6
+
7
+ import os
8
+ import time
9
+ import json
10
+ from datetime import datetime
11
+ from typing import Dict, Any, Optional
12
+
13
+ class TrainingMonitor:
14
+ """Monitor training progress and status."""
15
+
16
+ def __init__(self, export_dir: str = "models/exports"):
17
+ self.export_dir = export_dir
18
+ self.status_file = os.path.join(export_dir, "training_status.json")
19
+ self.start_time = None
20
+
21
+ def start_training(self, model_name: str, config: Dict[str, Any]):
22
+ """Start monitoring a training session."""
23
+ self.start_time = datetime.now()
24
+ status = {
25
+ "model": model_name,
26
+ "status": "training",
27
+ "start_time": self.start_time.isoformat(),
28
+ "config": config,
29
+ "epochs_completed": 0,
30
+ "current_loss": None,
31
+ "best_loss": float("inf"),
32
+ "last_update": datetime.now().isoformat()
33
+ }
34
+ self._save_status(status)
35
+ print(f"πŸš€ Started monitoring {model_name} training")
36
+
37
+ def update_progress(self, epoch: int, loss: float, is_best: bool = False):
38
+ """Update training progress."""
39
+ if not self.start_time:
40
+ return
41
+
42
+ try:
43
+ with open(self.status_file, 'r') as f:
44
+ status = json.load(f)
45
+ except Exception:  # tolerate a missing or unreadable status file
46
+ return
47
+
48
+ status["epochs_completed"] = epoch
49
+ status["current_loss"] = loss
50
+ status["last_update"] = datetime.now().isoformat()
51
+
52
+ if is_best:
53
+ status["best_loss"] = min(status["best_loss"], loss)
54
+
55
+ self._save_status(status)
56
+
57
+ def complete_training(self, final_loss: float, total_epochs: int):
58
+ """Mark training as completed."""
59
+ try:
60
+ with open(self.status_file, 'r') as f:
61
+ status = json.load(f)
62
+ except Exception:  # tolerate a missing or unreadable status file
63
+ return
64
+
65
+ status["status"] = "completed"
66
+ status["epochs_completed"] = total_epochs
67
+ status["current_loss"] = final_loss
68
+ status["best_loss"] = min(status["best_loss"], final_loss)
69
+ status["completion_time"] = datetime.now().isoformat()
70
+ status["duration"] = str(datetime.now() - self.start_time) if self.start_time else None
71
+
72
+ self._save_status(status)
73
+ print(f"βœ… Training completed in {status['duration']}")
74
+
75
+ def fail_training(self, error: str):
76
+ """Mark training as failed."""
77
+ try:
78
+ with open(self.status_file, 'r') as f:
79
+ status = json.load(f)
80
+ except Exception:  # tolerate a missing or unreadable status file
81
+ return
82
+
83
+ status["status"] = "failed"
84
+ status["error"] = error
85
+ status["failure_time"] = datetime.now().isoformat()
86
+
87
+ self._save_status(status)
88
+ print(f"❌ Training failed: {error}")
89
+
90
+ def get_status(self) -> Optional[Dict[str, Any]]:
91
+ """Get current training status."""
92
+ try:
93
+ with open(self.status_file, 'r') as f:
94
+ return json.load(f)
95
+ except Exception:  # no status file yet
96
+ return None
97
+
98
+ def _save_status(self, status: Dict[str, Any]):
99
+ """Save status to file."""
100
+ os.makedirs(self.export_dir, exist_ok=True)
101
+ with open(self.status_file, 'w') as f:
102
+ json.dump(status, f, indent=2)
103
+
104
+ def print_status(self):
105
+ """Print current training status."""
106
+ status = self.get_status()
107
+ if not status:
108
+ print("πŸ“Š No training status available")
109
+ return
110
+
111
+ print(f"\nπŸ“Š Training Status: {status['model']}")
112
+ print(f"Status: {status['status']}")
113
+ print(f"Started: {status['start_time']}")
114
+ print(f"Epochs: {status['epochs_completed']}")
115
+ print(f"Current Loss: {status['current_loss']:.4f}" if status['current_loss'] else "Current Loss: N/A")
116
+ print(f"Best Loss: {status['best_loss']:.4f}" if status['best_loss'] != float("inf") else "Best Loss: N/A")
117
+ print(f"Last Update: {status['last_update']}")
118
+
119
+ if status['status'] == 'completed':
120
+ print(f"Duration: {status['duration']}")
121
+ elif status['status'] == 'failed':
122
+ print(f"Error: {status['error']}")
123
+
124
+ def create_monitor() -> TrainingMonitor:
125
+ """Create a training monitor instance."""
126
+ export_dir = os.getenv("EXPORT_DIR", "models/exports")
127
+ return TrainingMonitor(export_dir)
128
+
129
+ if __name__ == "__main__":
130
+ # Test the monitor
131
+ monitor = create_monitor()
132
+ monitor.print_status()
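A hedged usage sketch for the new TrainingMonitor, showing how it could be wired into one of the training loops; the epoch count and loss values below are placeholders, not real training output:

```python
# Sketch only: stand-in losses, assuming training_monitor.py is importable from the app root.
from training_monitor import create_monitor

monitor = create_monitor()
monitor.start_training("vit_outfit", {"epochs": 3, "batch_size": 32, "lr": 5e-4})
try:
    best = float("inf")
    for epoch, loss in enumerate([0.42, 0.31, 0.27], start=1):  # placeholder losses
        monitor.update_progress(epoch, loss, is_best=(loss < best))
        best = min(best, loss)
    monitor.complete_training(final_loss=best, total_epochs=3)
except Exception as exc:
    monitor.fail_training(str(exc))
monitor.print_status()
```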
utils/artifact_manager.py ADDED
@@ -0,0 +1,417 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Comprehensive artifact manager for Dressify.
4
+ Handles packaging, downloading, and organizing all system artifacts.
5
+ """
6
+
7
+ import os
8
+ import json
9
+ import shutil
10
+ import zipfile
11
+ import tarfile
12
+ from datetime import datetime
13
+ from typing import Dict, List, Any, Optional
14
+ from pathlib import Path
15
+
16
+ class ArtifactManager:
17
+ """Manages all system artifacts for easy download and upload."""
18
+
19
+ def __init__(self, base_dir: str = "/home/user/app"):
20
+ self.base_dir = base_dir
21
+ self.data_dir = os.path.join(base_dir, "data/Polyvore")
22
+ self.splits_dir = os.path.join(self.data_dir, "splits")
23
+ self.export_dir = os.getenv("EXPORT_DIR", "models/exports")
24
+
25
+ # Default HF repositories - updated to use your specific repos
26
+ self.default_repos = {
27
+ "splits": "Stylique/Dressify-Helper",
28
+ "models": "Stylique/dressify-models",
29
+ "metadata": "Stylique/Dressify-Helper"
30
+ }
31
+
32
+ # Repository organization structure
33
+ self.repo_structure = {
34
+ "Stylique/dressify-models": {
35
+ "description": "Dressify trained models and checkpoints",
36
+ "files": {
37
+ "resnet_item_embedder_best.pth": "ResNet50 item embedder (best checkpoint)",
38
+ "vit_outfit_model_best.pth": "ViT outfit compatibility model (best checkpoint)",
39
+ "resnet_metrics.json": "ResNet training metrics and history",
40
+ "vit_metrics.json": "ViT training metrics and history",
41
+ "model_cards/": "Model documentation and cards"
42
+ }
43
+ },
44
+ "Stylique/Dressify-Helper": {
45
+ "description": "Dressify dataset splits, metadata, and helper files",
46
+ "files": {
47
+ "splits/": "Dataset splits (train/valid/test)",
48
+ "metadata/": "Item metadata and outfit information",
49
+ "configs/": "Training configurations",
50
+ "packages/": "Pre-packaged downloads"
51
+ }
52
+ }
53
+ }
54
+
55
+ def get_artifact_summary(self) -> Dict[str, Any]:
56
+ """Get comprehensive summary of all available artifacts."""
57
+ summary = {
58
+ "timestamp": datetime.now().isoformat(),
59
+ "datasets": self._get_dataset_info(),
60
+ "splits": self._get_splits_info(),
61
+ "models": self._get_models_info(),
62
+ "configs": self._get_configs_info(),
63
+ "metadata": self._get_metadata_info(),
64
+ "hf_repos": self.repo_structure,
65
+ "total_size_mb": 0
66
+ }
67
+
68
+ # Calculate total size
69
+ total_size = 0
70
+ for category in summary.values():
71
+ if isinstance(category, dict) and "size_mb" in category:
72
+ total_size += category["size_mb"]
73
+ summary["total_size_mb"] = round(total_size, 2)
74
+
75
+ return summary
76
+
77
+ def _get_dataset_info(self) -> Dict[str, Any]:
78
+ """Get information about the Polyvore dataset."""
79
+ info = {
80
+ "status": "not_found",
81
+ "size_mb": 0,
82
+ "files": [],
83
+ "images_count": 0
84
+ }
85
+
86
+ if os.path.exists(self.data_dir):
87
+ info["status"] = "available"
88
+
89
+ # Count images
90
+ images_dir = os.path.join(self.data_dir, "images")
91
+ if os.path.exists(images_dir):
92
+ try:
93
+ image_files = [f for f in os.listdir(images_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))]
94
+ info["images_count"] = len(image_files)
95
+ except:
96
+ pass
97
+
98
+ # Calculate size
99
+ try:
100
+ total_size = sum(os.path.getsize(os.path.join(dirpath, filename))
101
+ for dirpath, dirnames, filenames in os.walk(self.data_dir)
102
+ for filename in filenames)
103
+ info["size_mb"] = round(total_size / (1024 * 1024), 2)
104
+ except:
105
+ pass
106
+
107
+ # List key files
108
+ key_files = ["images.zip", "polyvore_item_metadata.json", "polyvore_outfit_titles.json", "categories.csv"]
109
+ for file in key_files:
110
+ file_path = os.path.join(self.data_dir, file)
111
+ if os.path.exists(file_path):
112
+ info["files"].append({
113
+ "name": file,
114
+ "size_mb": round(os.path.getsize(file_path) / (1024 * 1024), 2),
115
+ "path": file_path
116
+ })
117
+
118
+ return info
119
+
120
+ def _get_splits_info(self) -> Dict[str, Any]:
121
+ """Get information about dataset splits."""
122
+ info = {
123
+ "status": "not_found",
124
+ "size_mb": 0,
125
+ "files": [],
126
+ "splits_available": []
127
+ }
128
+
129
+ if os.path.exists(self.splits_dir):
130
+ info["status"] = "available"
131
+
132
+ split_files = [
133
+ "train.json", "valid.json", "test.json",
134
+ "outfits_train.json", "outfits_valid.json", "outfits_test.json",
135
+ "outfit_triplets_train.json", "outfit_triplets_valid.json", "outfit_triplets_test.json"
136
+ ]
137
+
138
+ total_size = 0
139
+ for file in split_files:
140
+ file_path = os.path.join(self.splits_dir, file)
141
+ if os.path.exists(file_path):
142
+ size_mb = round(os.path.getsize(file_path) / (1024 * 1024), 2)
143
+ total_size += size_mb
144
+ info["files"].append({
145
+ "name": file,
146
+ "size_mb": size_mb,
147
+ "path": file_path
148
+ })
149
+ info["splits_available"].append(file.replace(".json", ""))
150
+
151
+ info["size_mb"] = round(total_size, 2)
152
+
153
+ return info
154
+
155
+ def _get_models_info(self) -> Dict[str, Any]:
156
+ """Get information about trained models."""
157
+ info = {
158
+ "status": "not_found",
159
+ "size_mb": 0,
160
+ "files": [],
161
+ "models_available": []
162
+ }
163
+
164
+ if os.path.exists(self.export_dir):
165
+ info["status"] = "available"
166
+
167
+ model_files = [
168
+ "resnet_item_embedder.pth", "resnet_item_embedder_best.pth",
169
+ "vit_outfit_model.pth", "vit_outfit_model_best.pth",
170
+ "resnet_metrics.json", "vit_metrics.json"
171
+ ]
172
+
173
+ total_size = 0
174
+ for file in model_files:
175
+ file_path = os.path.join(self.export_dir, file)
176
+ if os.path.exists(file_path):
177
+ size_mb = round(os.path.getsize(file_path) / (1024 * 1024), 2)
178
+ total_size += size_mb
179
+ info["files"].append({
180
+ "name": file,
181
+ "size_mb": size_mb,
182
+ "path": file_path,
183
+ "type": "checkpoint" if file.endswith(".pth") else "metrics"
184
+ })
185
+ if file.endswith(".pth"):
186
+ info["models_available"].append(file.replace(".pth", ""))
187
+
188
+ info["size_mb"] = round(total_size, 2)
189
+
190
+ return info
191
+
192
+ def _get_configs_info(self) -> Dict[str, Any]:
193
+ """Get information about configuration files."""
194
+ info = {
195
+ "status": "not_found",
196
+ "size_mb": 0,
197
+ "files": []
198
+ }
199
+
200
+ config_files = [
201
+ "resnet_config_custom.json", "vit_config_custom.json",
202
+ "item.yaml", "outfit.yaml", "default.yaml"
203
+ ]
204
+
205
+ total_size = 0
206
+ for file in config_files:
207
+ # Check export dir first, then configs dir
208
+ file_path = os.path.join(self.export_dir, file)
209
+ if not os.path.exists(file_path):
210
+ file_path = os.path.join("configs", file)
211
+
212
+ if os.path.exists(file_path):
213
+ size_mb = round(os.path.getsize(file_path) / (1024 * 1024), 2)
214
+ total_size += size_mb
215
+ info["files"].append({
216
+ "name": file,
217
+ "size_mb": size_mb,
218
+ "path": file_path
219
+ })
220
+
221
+ if info["files"]:
222
+ info["status"] = "available"
223
+ info["size_mb"] = round(total_size, 2)
224
+
225
+ return info
226
+
227
+ def _get_metadata_info(self) -> Dict[str, Any]:
228
+ """Get information about metadata files."""
229
+ info = {
230
+ "status": "not_found",
231
+ "size_mb": 0,
232
+ "files": []
233
+ }
234
+
235
+ metadata_files = [
236
+ "polyvore_item_metadata.json", "polyvore_outfit_titles.json", "categories.csv"
237
+ ]
238
+
239
+ total_size = 0
240
+ for file in metadata_files:
241
+ file_path = os.path.join(self.data_dir, file)
242
+ if os.path.exists(file_path):
243
+ size_mb = round(os.path.getsize(file_path) / (1024 * 1024), 2)
244
+ total_size += size_mb
245
+ info["files"].append({
246
+ "name": file,
247
+ "size_mb": size_mb,
248
+ "path": file_path
249
+ })
250
+
251
+ if info["files"]:
252
+ info["status"] = "available"
253
+ info["size_mb"] = round(total_size, 2)
254
+
255
+ return info
256
+
257
+ def create_download_package(self, package_type: str = "complete") -> str:
258
+ """Create a downloadable package of artifacts."""
259
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
260
+
261
+ if package_type == "complete":
262
+ # Complete package with everything
263
+ package_name = f"dressify_complete_{timestamp}"
264
+ package_path = os.path.join(self.export_dir, f"{package_name}.tar.gz")
265
+
266
+ with tarfile.open(package_path, "w:gz") as tar:
267
+ # Add splits
268
+ if os.path.exists(self.splits_dir):
269
+ tar.add(self.splits_dir, arcname="splits")
270
+
271
+ # Add models
272
+ if os.path.exists(self.export_dir):
273
+ for file in os.listdir(self.export_dir):
274
+ if file.endswith((".pth", ".json", ".yaml")):
275
+ tar.add(os.path.join(self.export_dir, file), arcname=f"models/{file}")
276
+
277
+ # Add metadata
278
+ metadata_files = ["polyvore_item_metadata.json", "polyvore_outfit_titles.json", "categories.csv"]
279
+ for file in metadata_files:
280
+ file_path = os.path.join(self.data_dir, file)
281
+ if os.path.exists(file_path):
282
+ tar.add(file_path, arcname=f"metadata/{file}")
283
+
284
+ # Add configs
285
+ configs_dir = "configs"
286
+ if os.path.exists(configs_dir):
287
+ tar.add(configs_dir, arcname="configs")
288
+
289
+ elif package_type == "splits_only":
290
+ # Only splits (lightweight)
291
+ package_name = f"dressify_splits_{timestamp}"
292
+ package_path = os.path.join(self.export_dir, f"{package_name}.tar.gz")
293
+
294
+ with tarfile.open(package_path, "w:gz") as tar:
295
+ if os.path.exists(self.splits_dir):
296
+ tar.add(self.splits_dir, arcname="splits")
297
+
298
+ elif package_type == "models_only":
299
+ # Only trained models
300
+ package_name = f"dressify_models_{timestamp}"
301
+ package_path = os.path.join(self.export_dir, f"{package_name}.tar.gz")
302
+
303
+ with tarfile.open(package_path, "w:gz") as tar:
304
+ if os.path.exists(self.export_dir):
305
+ for file in os.listdir(self.export_dir):
306
+ if file.endswith((".pth", ".json")):
307
+ tar.add(os.path.join(self.export_dir, file), arcname=f"models/{file}")
308
+
309
+ else:
310
+ raise ValueError(f"Unknown package type: {package_type}")
311
+
312
+ return package_path
313
+
314
+ def get_downloadable_files(self) -> List[Dict[str, Any]]:
315
+ """Get list of all downloadable files."""
316
+ files = []
317
+
318
+ # Add splits
319
+ if os.path.exists(self.splits_dir):
320
+ for file in os.listdir(self.splits_dir):
321
+ if file.endswith(".json"):
322
+ file_path = os.path.join(self.splits_dir, file)
323
+ files.append({
324
+ "name": f"splits/{file}",
325
+ "size_mb": round(os.path.getsize(file_path) / (1024 * 1024), 2),
326
+ "path": file_path,
327
+ "category": "splits",
328
+ "description": f"Dataset split: {file.replace('.json', '')}"
329
+ })
330
+
331
+ # Add models
332
+ if os.path.exists(self.export_dir):
333
+ for file in os.listdir(self.export_dir):
334
+ if file.endswith((".pth", ".json")):
335
+ file_path = os.path.join(self.export_dir, file)
336
+ files.append({
337
+ "name": f"models/{file}",
338
+ "size_mb": round(os.path.getsize(file_path) / (1024 * 1024), 2),
339
+ "path": file_path,
340
+ "category": "models",
341
+ "description": "Trained model or metrics"
342
+ })
343
+
344
+ # Add metadata
345
+ metadata_files = ["polyvore_item_metadata.json", "polyvore_outfit_titles.json", "categories.csv"]
346
+ for file in metadata_files:
347
+ file_path = os.path.join(self.data_dir, file)
348
+ if os.path.exists(file_path):
349
+ files.append({
350
+ "name": f"metadata/{file}",
351
+ "size_mb": round(os.path.getsize(file_path) / (1024 * 1024), 2),
352
+ "path": file_path,
353
+ "category": "metadata",
354
+ "description": "Dataset metadata"
355
+ })
356
+
357
+ return files
358
+
359
+ def create_hf_upload_plan(self) -> Dict[str, Any]:
360
+ """Create a plan for uploading to HF Hub."""
361
+ plan = {
362
+ "Stylique/dressify-models": {
363
+ "description": "Upload trained models and checkpoints",
364
+ "files_to_upload": [],
365
+ "estimated_size_mb": 0
366
+ },
367
+ "Stylique/Dressify-Helper": {
368
+ "description": "Upload dataset splits and metadata",
369
+ "files_to_upload": [],
370
+ "estimated_size_mb": 0
371
+ }
372
+ }
373
+
374
+ # Plan for models repo
375
+ if os.path.exists(self.export_dir):
376
+ for file in os.listdir(self.export_dir):
377
+ if file.endswith((".pth", ".json")):
378
+ file_path = os.path.join(self.export_dir, file)
379
+ size_mb = round(os.path.getsize(file_path) / (1024 * 1024), 2)
380
+ plan["Stylique/dressify-models"]["files_to_upload"].append({
381
+ "name": file,
382
+ "path": file_path,
383
+ "size_mb": size_mb
384
+ })
385
+ plan["Stylique/dressify-models"]["estimated_size_mb"] += size_mb
386
+
387
+ # Plan for helper repo
388
+ if os.path.exists(self.splits_dir):
389
+ for file in os.listdir(self.splits_dir):
390
+ if file.endswith(".json"):
391
+ file_path = os.path.join(self.splits_dir, file)
392
+ size_mb = round(os.path.getsize(file_path) / (1024 * 1024), 2)
393
+ plan["Stylique/Dressify-Helper"]["files_to_upload"].append({
394
+ "name": f"splits/{file}",
395
+ "path": file_path,
396
+ "size_mb": size_mb
397
+ })
398
+ plan["Stylique/Dressify-Helper"]["estimated_size_mb"] += size_mb
399
+
400
+ # Add metadata files
401
+ metadata_files = ["polyvore_item_metadata.json", "polyvore_outfit_titles.json", "categories.csv"]
402
+ for file in metadata_files:
403
+ file_path = os.path.join(self.data_dir, file)
404
+ if os.path.exists(file_path):
405
+ size_mb = round(os.path.getsize(file_path) / (1024 * 1024), 2)
406
+ plan["Stylique/Dressify-Helper"]["files_to_upload"].append({
407
+ "name": f"metadata/{file}",
408
+ "path": file_path,
409
+ "size_mb": size_mb
410
+ })
411
+ plan["Stylique/Dressify-Helper"]["estimated_size_mb"] += size_mb
412
+
413
+ return plan
414
+
415
+ def create_artifact_manager() -> ArtifactManager:
416
+ """Create an artifact manager instance."""
417
+ return ArtifactManager()
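A short, hedged example of the new ArtifactManager: summarising what is on disk and packaging just the trained checkpoints. Paths follow the defaults above and will differ outside the Space:

```python
# Sketch only: assumes the default /home/user/app layout and EXPORT_DIR used above,
# and that the export directory already exists when the package is written.
from utils.artifact_manager import create_artifact_manager

manager = create_artifact_manager()
summary = manager.get_artifact_summary()
print("Total artifact size (MB):", summary["total_size_mb"])
print("Models available:", summary["models"].get("models_available", []))

# Bundle only checkpoints and metrics into a tar.gz under the export directory.
package_path = manager.create_download_package(package_type="models_only")
print("Package written to:", package_path)
```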
utils/export.py CHANGED
@@ -5,11 +5,19 @@ import torch
5
 
6
 
7
  def ensure_export_dir(path: str) -> str:
 
8
  os.makedirs(path, exist_ok=True)
9
  return path
10
 
11
 
12
  def export_torchscript(model: torch.nn.Module, example_inputs: torch.Tensor, out_path: str) -> str:
 
13
  model.eval()
14
  traced = torch.jit.trace(model, example_inputs)
15
  torch.jit.save(traced, out_path)
@@ -17,6 +25,7 @@ def export_torchscript(model: torch.nn.Module, example_inputs: torch.Tensor, out
17
 
18
 
19
  def export_onnx(model: torch.nn.Module, example_inputs: torch.Tensor, out_path: str, opset: int = 17) -> str:
 
20
  model.eval()
21
  torch.onnx.export(
22
  model,
@@ -32,6 +41,20 @@ def export_onnx(model: torch.nn.Module, example_inputs: torch.Tensor, out_path:
32
  return out_path
33
 
34
 
 
35
 
36
 
37
 
 
5
 
6
 
7
  def ensure_export_dir(path: str) -> str:
8
+ """Create export directory and all parent directories if they don't exist."""
9
  os.makedirs(path, exist_ok=True)
10
  return path
11
 
12
 
13
+ def get_export_dir() -> str:
14
+ """Get the default export directory, creating it if necessary."""
15
+ export_dir = os.getenv("EXPORT_DIR", "models/exports")
16
+ return ensure_export_dir(export_dir)
17
+
18
+
19
  def export_torchscript(model: torch.nn.Module, example_inputs: torch.Tensor, out_path: str) -> str:
20
+ """Export model to TorchScript format."""
21
  model.eval()
22
  traced = torch.jit.trace(model, example_inputs)
23
  torch.jit.save(traced, out_path)
 
25
 
26
 
27
  def export_onnx(model: torch.nn.Module, example_inputs: torch.Tensor, out_path: str, opset: int = 17) -> str:
28
+ """Export model to ONNX format."""
29
  model.eval()
30
  torch.onnx.export(
31
  model,
 
41
  return out_path
42
 
43
 
44
+ def save_checkpoint(model: torch.nn.Module, path: str, **kwargs) -> str:
45
+ """Save model checkpoint with metadata."""
46
+ # Ensure directory exists
47
+ os.makedirs(os.path.dirname(path), exist_ok=True)
48
+
49
+ # Save checkpoint
50
+ checkpoint = {
51
+ "state_dict": model.state_dict(),
52
+ **kwargs
53
+ }
54
+ torch.save(checkpoint, path)
55
+ return path
56
+
57
+
58
 
59
 
60
 
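A hedged sketch of the two helpers added to utils/export.py, using a throwaway nn.Linear as a stand-in model:

```python
# Sketch only: the tiny Linear model is a placeholder, not one of the Dressify models.
import os
import torch
from utils.export import get_export_dir, save_checkpoint

model = torch.nn.Linear(8, 2)
ckpt_path = os.path.join(get_export_dir(), "demo_checkpoint.pth")
save_checkpoint(model, ckpt_path, epoch=1, loss=0.5)

checkpoint = torch.load(ckpt_path, map_location="cpu")
print(sorted(checkpoint.keys()))  # expected: ['epoch', 'loss', 'state_dict']
```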
utils/hf_hub_integration.py ADDED
@@ -0,0 +1,413 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Hugging Face Hub integration for Dressify.
4
+ Handles uploading artifacts to specific HF repositories.
5
+ """
6
+
7
+ import os
8
+ import json
9
+ import shutil
10
+ from datetime import datetime
11
+ from typing import Dict, List, Any, Optional
12
+ from huggingface_hub import HfApi, create_repo, upload_file, upload_folder
13
+ from pathlib import Path
14
+
15
+ class HFHubIntegration:
16
+ """Integrates with Hugging Face Hub for artifact management."""
17
+
18
+ def __init__(self, token: str = None):
19
+ self.api = HfApi(token=token)
20
+ self.token = token
21
+
22
+ # Your specific repositories
23
+ self.repos = {
24
+ "models": "Stylique/dressify-models",
25
+ "helper": "Stylique/Dressify-Helper"
26
+ }
27
+
28
+ # Repository descriptions and metadata
29
+ self.repo_metadata = {
30
+ "Stylique/dressify-models": {
31
+ "description": "Dressify trained models and checkpoints for outfit recommendation",
32
+ "tags": ["computer-vision", "fashion", "outfit-recommendation", "deep-learning"],
33
+ "license": "mit",
34
+ "language": "en"
35
+ },
36
+ "Stylique/Dressify-Helper": {
37
+ "description": "Dressify dataset splits, metadata, and helper files",
38
+ "tags": ["dataset", "fashion", "outfit-recommendation", "polyvore"],
39
+ "license": "mit",
40
+ "language": "en"
41
+ }
42
+ }
43
+
44
+ def ensure_repos_exist(self) -> Dict[str, bool]:
45
+ """Ensure all required repositories exist, create if they don't."""
46
+ results = {}
47
+
48
+ for repo_id in self.repos.values():
49
+ try:
50
+ # Try to get repo info
51
+ repo_info = self.api.repo_info(repo_id)
52
+ results[repo_id] = True
53
+ print(f"βœ… Repository exists: {repo_id}")
54
+ except Exception:
55
+ try:
56
+ # Create repository
57
+ if "models" in repo_id:
58
+ create_repo(
59
+ repo_id=repo_id,
60
+ repo_type="model",
61
+ token=self.token,
62
+ description=self.repo_metadata[repo_id]["description"],
63
+ license=self.repo_metadata[repo_id]["license"],
64
+ tags=self.repo_metadata[repo_id]["tags"]
65
+ )
66
+ else:
67
+ create_repo(
68
+ repo_id=repo_id,
69
+ repo_type="dataset",
70
+ token=self.token,
71
+ description=self.repo_metadata[repo_id]["description"],
72
+ license=self.repo_metadata[repo_id]["license"],
73
+ tags=self.repo_metadata[repo_id]["tags"]
74
+ )
75
+
76
+ results[repo_id] = True
77
+ print(f"βœ… Created repository: {repo_id}")
78
+ except Exception as e:
79
+ results[repo_id] = False
80
+ print(f"❌ Failed to create repository {repo_id}: {e}")
81
+
82
+ return results
83
+
84
+ def upload_models_to_hf(self, models_dir: str = None) -> Dict[str, Any]:
85
+ """Upload trained models to the models repository."""
86
+ if models_dir is None:
87
+ models_dir = os.getenv("EXPORT_DIR", "models/exports")
88
+
89
+ if not os.path.exists(models_dir):
90
+ return {"success": False, "error": f"Models directory not found: {models_dir}"}
91
+
92
+ try:
93
+ print(f"πŸš€ Uploading models to {self.repos['models']}...")
94
+
95
+ # Files to upload
96
+ model_files = [
97
+ "resnet_item_embedder_best.pth",
98
+ "vit_outfit_model_best.pth",
99
+ "resnet_metrics.json",
100
+ "vit_metrics.json"
101
+ ]
102
+
103
+ uploaded_files = []
104
+ total_size = 0
105
+
106
+ for file in model_files:
107
+ file_path = os.path.join(models_dir, file)
108
+ if os.path.exists(file_path):
109
+ try:
110
+ # Upload file
111
+ self.api.upload_file(
112
+ path_or_fileobj=file_path,
113
+ path_in_repo=file,
114
+ repo_id=self.repos['models'],
115
+ token=self.token
116
+ )
117
+
118
+ size_mb = round(os.path.getsize(file_path) / (1024 * 1024), 2)
119
+ total_size += size_mb
120
+ uploaded_files.append({
121
+ "name": file,
122
+ "size_mb": size_mb,
123
+ "status": "uploaded"
124
+ })
125
+
126
+ print(f"βœ… Uploaded: {file} ({size_mb} MB)")
127
+
128
+ except Exception as e:
129
+ uploaded_files.append({
130
+ "name": file,
131
+ "status": "failed",
132
+ "error": str(e)
133
+ })
134
+ print(f"❌ Failed to upload {file}: {e}")
135
+
136
+ # Create model card
137
+ self._create_model_card()
138
+
139
+ result = {
140
+ "success": True,
141
+ "repository": self.repos['models'],
142
+ "uploaded_files": uploaded_files,
143
+ "total_size_mb": total_size,
144
+ "timestamp": datetime.now().isoformat()
145
+ }
146
+
147
+ print(f"πŸŽ‰ Models upload completed! Total size: {total_size} MB")
148
+ return result
149
+
150
+ except Exception as e:
151
+ return {"success": False, "error": str(e)}
152
+
153
+ def upload_splits_to_hf(self, splits_dir: str = None) -> Dict[str, Any]:
154
+ """Upload dataset splits to the helper repository."""
155
+ if splits_dir is None:
156
+ splits_dir = os.path.join(os.getenv("POLYVORE_ROOT", "/home/user/app/data/Polyvore"), "splits")
157
+
158
+ if not os.path.exists(splits_dir):
159
+ return {"success": False, "error": f"Splits directory not found: {splits_dir}"}
160
+
161
+ try:
162
+ print(f"πŸš€ Uploading splits to {self.repos['helper']}...")
163
+
164
+ # Upload entire splits directory
165
+ self.api.upload_folder(
166
+ folder_path=splits_dir,
167
+ path_in_repo="splits",
168
+ repo_id=self.repos['helper'],
169
+ token=self.token
170
+ )
171
+
172
+ # Calculate total size
173
+ total_size = 0
174
+ for root, dirs, files in os.walk(splits_dir):
175
+ for file in files:
176
+ file_path = os.path.join(root, file)
177
+ total_size += os.path.getsize(file_path)
178
+
179
+ total_size_mb = round(total_size / (1024 * 1024), 2)
180
+
181
+ result = {
182
+ "success": True,
183
+ "repository": self.repos['helper'],
184
+ "uploaded_folder": "splits",
185
+ "total_size_mb": total_size_mb,
186
+ "timestamp": datetime.now().isoformat()
187
+ }
188
+
189
+ print(f"πŸŽ‰ Splits upload completed! Total size: {total_size_mb} MB")
190
+ return result
191
+
192
+ except Exception as e:
193
+ return {"success": False, "error": str(e)}
194
+
195
+ def upload_metadata_to_hf(self, data_dir: str = None) -> Dict[str, Any]:
196
+ """Upload metadata files to the helper repository."""
197
+ if data_dir is None:
198
+ data_dir = os.getenv("POLYVORE_ROOT", "/home/user/app/data/Polyvore")
199
+
200
+ if not os.path.exists(data_dir):
201
+ return {"success": False, "error": f"Data directory not found: {data_dir}"}
202
+
203
+ try:
204
+ print(f"πŸš€ Uploading metadata to {self.repos['helper']}...")
205
+
206
+ # Metadata files to upload
207
+ metadata_files = [
208
+ "polyvore_item_metadata.json",
209
+ "polyvore_outfit_titles.json",
210
+ "categories.csv"
211
+ ]
212
+
213
+ uploaded_files = []
214
+ total_size = 0
215
+
216
+ for file in metadata_files:
217
+ file_path = os.path.join(data_dir, file)
218
+ if os.path.exists(file_path):
219
+ try:
220
+ # Upload to metadata subfolder
221
+ self.api.upload_file(
222
+ path_or_fileobj=file_path,
223
+ path_in_repo=f"metadata/{file}",
224
+ repo_id=self.repos['helper'],
225
+ token=self.token
226
+ )
227
+
228
+ size_mb = round(os.path.getsize(file_path) / (1024 * 1024), 2)
229
+ total_size += size_mb
230
+ uploaded_files.append({
231
+ "name": file,
232
+ "size_mb": size_mb,
233
+ "status": "uploaded"
234
+ })
235
+
236
+ print(f"βœ… Uploaded: {file} ({size_mb} MB)")
237
+
238
+ except Exception as e:
239
+ uploaded_files.append({
240
+ "name": file,
241
+ "status": "failed",
242
+ "error": str(e)
243
+ })
244
+ print(f"❌ Failed to upload {file}: {e}")
245
+
246
+ result = {
247
+ "success": True,
248
+ "repository": self.repos['helper'],
249
+ "uploaded_files": uploaded_files,
250
+ "total_size_mb": total_size,
251
+ "timestamp": datetime.now().isoformat()
252
+ }
253
+
254
+ print(f"πŸŽ‰ Metadata upload completed! Total size: {total_size} MB")
255
+ return result
256
+
257
+ except Exception as e:
258
+ return {"success": False, "error": str(e)}
259
+
260
+ def upload_everything_to_hf(self) -> Dict[str, Any]:
261
+ """Upload all artifacts to HF Hub."""
262
+ print("πŸš€ Starting comprehensive upload to HF Hub...")
263
+
264
+ # Ensure repositories exist
265
+ repo_status = self.ensure_repos_exist()
266
+ if not all(repo_status.values()):
267
+ return {"success": False, "error": "Failed to ensure repositories exist"}
268
+
269
+ # Upload everything
270
+ results = {
271
+ "models": self.upload_models_to_hf(),
272
+ "splits": self.upload_splits_to_hf(),
273
+ "metadata": self.upload_metadata_to_hf(),
274
+ "timestamp": datetime.now().isoformat()
275
+ }
276
+
277
+ # Summary
278
+ success_count = sum(1 for r in results.values() if isinstance(r, dict) and r.get("success", False))
279
+ total_count = len([r for r in results.values() if isinstance(r, dict)])
280
+
281
+ print(f"\nπŸ“Š Upload Summary: {success_count}/{total_count} successful")
282
+ for category, result in results.items():
283
+ if isinstance(result, dict):
284
+ status = "βœ…" if result.get("success", False) else "❌"
285
+ print(f" {status} {category}")
286
+
287
+ return results
288
+
289
+ def _create_model_card(self):
290
+ """Create a model card for the models repository."""
291
+ model_card_content = """---
292
+ language: en
293
+ license: mit
294
+ tags:
295
+ - computer-vision
296
+ - fashion
297
+ - outfit-recommendation
298
+ - deep-learning
299
+ - resnet
300
+ - vision-transformer
301
+ ---
302
+
303
+ # Dressify Outfit Recommendation Models
304
+
305
+ This repository contains the trained models for the Dressify outfit recommendation system.
306
+
307
+ ## Models
308
+
309
+ ### ResNet Item Embedder
310
+ - **Architecture**: ResNet50 with custom projection head
311
+ - **Purpose**: Generate 512-dimensional embeddings for fashion items
312
+ - **Training**: Triplet loss with semi-hard negative mining
313
+ - **Input**: Fashion item images (224x224)
314
+ - **Output**: L2-normalized 512D embeddings
315
+
316
+ ### ViT Outfit Compatibility Model
317
+ - **Architecture**: Vision Transformer encoder
318
+ - **Purpose**: Score outfit compatibility from item embeddings
319
+ - **Training**: Triplet loss with cosine distance
320
+ - **Input**: Variable-length sequence of item embeddings
321
+ - **Output**: Compatibility score (0-1)
322
+
323
+ ## Usage
324
+
325
+ ```python
326
+ from huggingface_hub import hf_hub_download
327
+ import torch
328
+
329
+ # Download models
330
+ resnet_path = hf_hub_download(repo_id="Stylique/dressify-models", filename="resnet_item_embedder_best.pth")
331
+ vit_path = hf_hub_download(repo_id="Stylique/dressify-models", filename="vit_outfit_model_best.pth")
332
+
333
+ # Load models
334
+ resnet_model = torch.load(resnet_path)
335
+ vit_model = torch.load(vit_path)
336
+ ```
337
+
338
+ ## Training Details
339
+
340
+ - **Dataset**: Polyvore Outfits (Stylique/Polyvore)
341
+ - **Loss**: Triplet margin loss
342
+ - **Optimizer**: AdamW
343
+ - **Mixed Precision**: Enabled
344
+ - **Hardware**: NVIDIA GPU with CUDA
345
+
346
+ ## Performance
347
+
348
+ - **ResNet**: ~25M parameters, fast inference
349
+ - **ViT**: ~12M parameters, efficient outfit scoring
350
+ - **Memory**: Optimized for deployment on Hugging Face Spaces
351
+
352
+ ## Citation
353
+
354
+ If you use these models in your research, please cite:
355
+
356
+ ```bibtex
357
+ @misc{dressify2024,
358
+ title={Dressify: Deep Learning for Fashion Outfit Recommendation},
359
+ author={Stylique},
360
+ year={2024},
361
+ url={https://huggingface.co/Stylique/dressify-models}
362
+ }
363
+ ```
364
+ """
365
+
366
+ # Save model card
367
+ model_card_path = "model_card.md"
368
+ with open(model_card_path, 'w') as f:
369
+ f.write(model_card_content)
370
+
371
+ # Upload model card
372
+ try:
373
+ self.api.upload_file(
374
+ path_or_fileobj=model_card_path,
375
+ path_in_repo="README.md",
376
+ repo_id=self.repos['models'],
377
+ token=self.token
378
+ )
379
+ print("βœ… Model card uploaded")
380
+
381
+ # Clean up
382
+ os.remove(model_card_path)
383
+ except Exception as e:
384
+ print(f"⚠️ Failed to upload model card: {e}")
385
+
386
+ def get_upload_status(self) -> Dict[str, Any]:
387
+ """Get current upload status and repository information."""
388
+ status = {
389
+ "repositories": {},
390
+ "last_upload": None,
391
+ "total_uploads": 0
392
+ }
393
+
394
+ for repo_id in self.repos.values():
395
+ try:
396
+ repo_info = self.api.repo_info(repo_id)
397
+ status["repositories"][repo_id] = {
398
+ "exists": True,
399
+ "last_modified": repo_info.last_modified.isoformat() if repo_info.last_modified else None,
400
+ "size": repo_info.size_on_disk if hasattr(repo_info, 'size_on_disk') else None
401
+ }
402
+ except Exception:
403
+ status["repositories"][repo_id] = {
404
+ "exists": False,
405
+ "last_modified": None,
406
+ "size": None
407
+ }
408
+
409
+ return status
410
+
411
+ def create_hf_integration(token: str = None) -> HFHubIntegration:
412
+ """Create an HF Hub integration instance."""
413
+ return HFHubIntegration(token=token)
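A hedged sketch of driving the new HF Hub integration end to end; it assumes a write-capable token is available in HF_TOKEN (how app.py actually supplies the token is not shown in this diff):

```python
# Sketch only: HF_TOKEN as the token source is an assumption, and uploads require
# write access to the Stylique repositories referenced above.
import os
from utils.hf_hub_integration import create_hf_integration

hf = create_hf_integration(token=os.getenv("HF_TOKEN"))
results = hf.upload_everything_to_hf()
for category, result in results.items():
    if isinstance(result, dict) and "success" in result:
        print(category, "->", "ok" if result["success"] else result.get("error"))
```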
utils/runtime_fetcher.py ADDED
@@ -0,0 +1,312 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Runtime artifact fetcher for Dressify.
4
+ Downloads pre-processed artifacts from Hugging Face Hub to avoid reprocessing.
5
+ """
6
+
7
+ import os
8
+ import json
9
+ import shutil
10
+ import tarfile
11
+ import zipfile
12
+ from pathlib import Path
13
+ from typing import Dict, List, Any, Optional
14
+ from huggingface_hub import hf_hub_download, snapshot_download
15
+
16
+ class RuntimeArtifactFetcher:
17
+ """Fetches artifacts from HF Hub at runtime to avoid reprocessing."""
18
+
19
+ def __init__(self, base_dir: str = "/home/user/app"):
20
+ self.base_dir = base_dir
21
+ self.data_dir = os.path.join(base_dir, "data/Polyvore")
22
+ self.splits_dir = os.path.join(self.data_dir, "splits")
23
+ self.export_dir = os.getenv("EXPORT_DIR", "models/exports")
24
+
25
+ # Default HF repositories - updated to use your specific repos
26
+ self.default_repos = {
27
+ "splits": "Stylique/Dressify-Helper",
28
+ "models": "Stylique/dressify-models",
29
+ "metadata": "Stylique/Dressify-Helper"
30
+ }
31
+
32
+ def check_artifacts_needed(self) -> Dict[str, Any]:
33
+ """Check what artifacts need to be fetched."""
34
+ needs = {
35
+ "splits": False,
36
+ "models": False,
37
+ "metadata": False,
38
+ "total_size_mb": 0
39
+ }
40
+
41
+ # Check splits
42
+ if not os.path.exists(self.splits_dir) or not self._has_complete_splits():
43
+ needs["splits"] = True
44
+ needs["total_size_mb"] += 50 # Estimate splits size
45
+
46
+ # Check models
47
+ if not os.path.exists(self.export_dir) or not self._has_trained_models():
48
+ needs["models"] = True
49
+ needs["total_size_mb"] += 200 # Estimate models size
50
+
51
+ # Check metadata
52
+ if not self._has_complete_metadata():
53
+ needs["metadata"] = True
54
+ needs["total_size_mb"] += 100 # Estimate metadata size
55
+
56
+ return needs
57
+
58
+ def _has_complete_splits(self) -> bool:
59
+ """Check if complete splits are available."""
60
+ required_files = [
61
+ "train.json", "valid.json", "test.json",
62
+ "outfit_triplets_train.json", "outfit_triplets_valid.json", "outfit_triplets_test.json"
63
+ ]
64
+
65
+ for file in required_files:
66
+ if not os.path.exists(os.path.join(self.splits_dir, file)):
67
+ return False
68
+ return True
69
+
70
+ def _has_trained_models(self) -> bool:
71
+ """Check if trained models are available."""
72
+ required_files = [
73
+ "resnet_item_embedder_best.pth",
74
+ "vit_outfit_model_best.pth"
75
+ ]
76
+
77
+ for file in required_files:
78
+ if not os.path.exists(os.path.join(self.export_dir, file)):
79
+ return False
80
+ return True
81
+
82
+ def _has_complete_metadata(self) -> bool:
83
+ """Check if complete metadata is available."""
84
+ required_files = [
85
+ "polyvore_item_metadata.json",
86
+ "polyvore_outfit_titles.json",
87
+ "categories.csv"
88
+ ]
89
+
90
+ for file in required_files:
91
+ if not os.path.exists(os.path.join(self.data_dir, file)):
92
+ return False
93
+ return True
94
+
95
+ def fetch_splits_from_hf(self, repo: str = None, token: str = None) -> bool:
96
+ """Fetch dataset splits from HF Hub."""
97
+ if repo is None:
98
+ repo = self.default_repos["splits"]
99
+
100
+ try:
101
+ print(f"πŸ”„ Fetching splits from {repo}...")
102
+
103
+ # Create splits directory
104
+ os.makedirs(self.splits_dir, exist_ok=True)
105
+
106
+ # Download splits files
107
+ split_files = [
108
+ "train.json", "valid.json", "test.json",
109
+ "outfits_train.json", "outfits_valid.json", "outfits_test.json",
110
+ "outfit_triplets_train.json", "outfit_triplets_valid.json", "outfit_triplets_test.json"
111
+ ]
112
+
113
+ for file in split_files:
114
+ try:
115
+ local_path = hf_hub_download(
116
+ repo_id=repo,
117
+ filename=f"splits/{file}",
118
+ local_dir=self.splits_dir,
119
+ token=token
120
+ )
121
+ print(f"βœ… Downloaded: {file}")
122
+ except Exception as e:
123
+ print(f"⚠️ Failed to download {file}: {e}")
124
+
125
+ print(f"βœ… Splits fetched successfully to {self.splits_dir}")
126
+ return True
127
+
128
+ except Exception as e:
129
+ print(f"❌ Failed to fetch splits: {e}")
130
+ return False
131
+
132
+ def fetch_models_from_hf(self, repo: str = None, token: str = None) -> bool:
133
+ """Fetch trained models from HF Hub."""
134
+ if repo is None:
135
+ repo = self.default_repos["models"]
136
+
137
+ try:
138
+ print(f"πŸ”„ Fetching models from {repo}...")
139
+
140
+ # Create export directory
141
+ os.makedirs(self.export_dir, exist_ok=True)
142
+
143
+ # Download model files
144
+ model_files = [
145
+ "resnet_item_embedder_best.pth",
146
+ "vit_outfit_model_best.pth",
147
+ "resnet_metrics.json",
148
+ "vit_metrics.json"
149
+ ]
150
+
151
+ for file in model_files:
152
+ try:
153
+ local_path = hf_hub_download(
154
+ repo_id=repo,
155
+ filename=file,
156
+ local_dir=self.export_dir,
157
+ token=token
158
+ )
159
+ print(f"βœ… Downloaded: {file}")
160
+ except Exception as e:
161
+ print(f"⚠️ Failed to download {file}: {e}")
162
+
163
+ print(f"βœ… Models fetched successfully to {self.export_dir}")
164
+ return True
165
+
166
+ except Exception as e:
167
+ print(f"❌ Failed to fetch models: {e}")
168
+ return False
169
+
170
+ def fetch_metadata_from_hf(self, repo: str = None, token: str = None) -> bool:
171
+ """Fetch metadata from HF Hub."""
172
+ if repo is None:
173
+ repo = self.default_repos["metadata"]
174
+
175
+ try:
176
+ print(f"πŸ”„ Fetching metadata from {repo}...")
177
+
178
+ # Create data directory
179
+ os.makedirs(self.data_dir, exist_ok=True)
180
+
181
+ # Download metadata files
182
+ metadata_files = [
183
+ "polyvore_item_metadata.json",
184
+ "polyvore_outfit_titles.json",
185
+ "categories.csv"
186
+ ]
187
+
188
+ for file in metadata_files:
189
+ try:
190
+ local_path = hf_hub_download(
191
+ repo_id=repo,
192
+ filename=f"metadata/{file}",
193
+ local_dir=self.data_dir,
194
+ token=token
195
+ )
196
+ print(f"βœ… Downloaded: {file}")
197
+ except Exception as e:
198
+ print(f"⚠️ Failed to download {file}: {e}")
199
+
200
+ print(f"βœ… Metadata fetched successfully to {self.data_dir}")
201
+ return True
202
+
203
+ except Exception as e:
204
+ print(f"❌ Failed to fetch metadata: {e}")
205
+ return False
206
+
207
+ def fetch_everything_from_hf(self, splits_repo: str = None, models_repo: str = None,
208
+ metadata_repo: str = None, token: str = None) -> Dict[str, bool]:
209
+ """Fetch all artifacts from HF Hub."""
210
+ results = {}
211
+
212
+ print("πŸš€ Starting comprehensive artifact fetch from HF Hub...")
213
+
214
+ # Fetch splits
215
+ results["splits"] = self.fetch_splits_from_hf(splits_repo, token)
216
+
217
+ # Fetch models
218
+ results["models"] = self.fetch_models_from_hf(models_repo, token)
219
+
220
+ # Fetch metadata
221
+ results["metadata"] = self.fetch_metadata_from_hf(metadata_repo, token)
222
+
223
+ # Summary
224
+ success_count = sum(results.values())
225
+ total_count = len(results)
226
+
227
+ print(f"\nπŸ“Š Fetch Summary: {success_count}/{total_count} successful")
228
+ for artifact, success in results.items():
229
+ status = "βœ…" if success else "❌"
230
+ print(f" {status} {artifact}")
231
+
232
+ return results
233
+
234
+ def download_and_extract_package(self, package_path: str, extract_to: str = None) -> bool:
235
+ """Download and extract a package from HF Hub."""
236
+ try:
237
+ if extract_to is None:
238
+ extract_to = self.base_dir
239
+
240
+ print(f"πŸ”„ Downloading and extracting package: {package_path}")
241
+
242
+ # Download the package
243
+ local_path = hf_hub_download(
244
+ repo_id="Stylique/Dressify-Helper",
245
+ filename=f"packages/{os.path.basename(package_path)}",
246
+ local_dir=extract_to,
247
+ token=None
248
+ )
249
+
250
+ # Extract based on file type
251
+ if package_path.endswith(".tar.gz"):
252
+ with tarfile.open(local_path, 'r:gz') as tar:
253
+ tar.extractall(extract_to)
254
+ elif package_path.endswith(".zip"):
255
+ with zipfile.ZipFile(local_path, 'r') as zipf:
256
+ zipf.extractall(extract_to)
257
+
258
+ print(f"βœ… Package extracted to {extract_to}")
259
+ return True
260
+
261
+ except Exception as e:
262
+ print(f"❌ Failed to download/extract package: {e}")
263
+ return False
264
+
265
+ def get_fetch_status(self) -> Dict[str, Any]:
266
+ """Get current fetch status."""
267
+ return {
268
+ "splits_available": self._has_complete_splits(),
269
+ "models_available": self._has_trained_models(),
270
+ "metadata_available": self._has_complete_metadata(),
271
+ "artifacts_needed": self.check_artifacts_needed(),
272
+ "base_dir": self.base_dir,
273
+ "splits_dir": self.splits_dir,
274
+ "export_dir": self.export_dir,
275
+ "hf_repos": self.default_repos
276
+ }
277
+
278
+ def create_runtime_fetcher() -> RuntimeArtifactFetcher:
279
+ """Create a runtime fetcher instance."""
280
+ return RuntimeArtifactFetcher()
281
+
282
+ def auto_fetch_if_needed(token: str = None) -> Dict[str, bool]:
283
+ """Automatically fetch artifacts if they're needed."""
284
+ fetcher = create_runtime_fetcher()
285
+
286
+ # Check what's needed
287
+ needs = fetcher.check_artifacts_needed()
288
+
289
+ if not any([needs["splits"], needs["models"], needs["metadata"]]):
290
+ print("βœ… All artifacts are already available - no fetching needed")
291
+ return {"splits": True, "models": True, "metadata": True}
292
+
293
+ print(f"πŸ”„ Auto-fetching needed artifacts (estimated size: {needs['total_size_mb']} MB)")
294
+
295
+ # Fetch what's needed
296
+ results = {}
297
+ if needs["splits"]:
298
+ results["splits"] = fetcher.fetch_splits_from_hf(token=token)
299
+
300
+ if needs["models"]:
301
+ results["models"] = fetcher.fetch_models_from_hf(token=token)
302
+
303
+ if needs["metadata"]:
304
+ results["metadata"] = fetcher.fetch_metadata_from_hf(token=token)
305
+
306
+ return results
307
+
308
+ if __name__ == "__main__":
309
+ # Test the fetcher
310
+ fetcher = create_runtime_fetcher()
311
+ status = fetcher.get_fetch_status()
312
+ print("Current fetch status:", json.dumps(status, indent=2))