Spaces: Paused
Ali Mohsin committed · Commit 24ea486
1 Parent(s): 6086b2f
Next level fix
Browse files
- app.py +297 -40
- artifact_management_ui.py +427 -0
- models/resnet_embedder.py +3 -0
- test_training.py +178 -0
- train_resnet.py +92 -49
- train_vit_triplet.py +139 -81
- training_monitor.py +132 -0
- utils/artifact_manager.py +417 -0
- utils/export.py +23 -0
- utils/hf_hub_integration.py +413 -0
- utils/runtime_fetcher.py +312 -0
app.py
CHANGED
|
@@ -12,10 +12,191 @@ from pydantic import BaseModel
|
|
| 12 |
from PIL import Image
|
| 13 |
from starlette.staticfiles import StaticFiles
|
| 14 |
import threading
|
|
|
|
| 15 |
|
| 16 |
from inference import InferenceService
|
| 17 |
from utils.data_fetch import ensure_dataset_ready
|
| 18 |
|
| 19 |
|
| 20 |
AI_API_KEY = os.getenv("AI_API_KEY")
|
| 21 |
|
|
@@ -254,7 +435,9 @@ def start_training_advanced(
|
|
| 254 |
if not DATASET_ROOT:
|
| 255 |
return "β Dataset not ready. Please wait for bootstrap to complete."
|
| 256 |
|
|
|
|
| 257 |
def _runner():
|
|
|
|
| 258 |
try:
|
| 259 |
import subprocess
|
| 260 |
import json
|
|
@@ -327,10 +510,10 @@ def start_training_advanced(
|
|
| 327 |
json.dump(vit_config, f, indent=2)
|
| 328 |
|
| 329 |
# Train ResNet with custom parameters
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
|
| 335 |
resnet_cmd = [
|
| 336 |
"python", "train_resnet.py",
|
|
@@ -350,16 +533,16 @@ def start_training_advanced(
|
|
| 350 |
result = subprocess.run(resnet_cmd, capture_output=True, text=True, check=False)
|
| 351 |
|
| 352 |
if result.returncode == 0:
|
| 353 |
-
|
| 354 |
else:
|
| 355 |
-
|
| 356 |
return
|
| 357 |
|
| 358 |
# Train ViT with custom parameters
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
|
| 364 |
vit_cmd = [
|
| 365 |
"python", "train_vit_triplet.py",
|
|
@@ -376,47 +559,87 @@ def start_training_advanced(
|
|
| 376 |
result = subprocess.run(vit_cmd, capture_output=True, text=True, check=False)
|
| 377 |
|
| 378 |
if result.returncode == 0:
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
service.reload_models()
|
| 383 |
-
|
| 384 |
else:
|
| 385 |
-
|
| 386 |
|
| 387 |
except Exception as e:
|
| 388 |
-
|
| 389 |
|
| 390 |
threading.Thread(target=_runner, daemon=True).start()
|
| 391 |
-
return
|
| 392 |
|
| 393 |
|
| 394 |
def start_training_simple(res_epochs: int, vit_epochs: int):
|
| 395 |
"""Start simple training with basic parameters."""
|
|
|
|
| 396 |
def _runner():
|
|
|
|
| 397 |
try:
|
| 398 |
import subprocess
|
| 399 |
if not DATASET_ROOT:
|
| 400 |
-
|
| 401 |
return
|
| 402 |
export_dir = os.getenv("EXPORT_DIR", "models/exports")
|
| 403 |
os.makedirs(export_dir, exist_ok=True)
|
| 404 |
-
|
| 405 |
subprocess.run([
|
| 406 |
"python", "train_resnet.py", "--data_root", DATASET_ROOT, "--epochs", str(res_epochs),
|
| 407 |
"--out", os.path.join(export_dir, "resnet_item_embedder.pth")
|
| 408 |
], check=False)
|
| 409 |
-
|
| 410 |
subprocess.run([
|
| 411 |
"python", "train_vit_triplet.py", "--data_root", DATASET_ROOT, "--epochs", str(vit_epochs),
|
| 412 |
"--export", os.path.join(export_dir, "vit_outfit_model.pth")
|
| 413 |
], check=False)
|
| 414 |
service.reload_models()
|
| 415 |
-
|
| 416 |
except Exception as e:
|
| 417 |
-
|
| 418 |
threading.Thread(target=_runner, daemon=True).start()
|
| 419 |
-
return
|
| 420 |
|
| 421 |
|
| 422 |
with gr.Blocks(fill_height=True, title="Dressify - Advanced Outfit Recommendation") as demo:
|
|
@@ -563,6 +786,56 @@ with gr.Blocks(fill_height=True, title="Dressify - Advanced Outfit Recommendatio
|
|
| 563 |
outputs=train_log
|
| 564 |
)
|
| 565 |
|
| 566 |
with gr.Tab("π§ Simple Training"):
|
| 567 |
gr.Markdown("### π Quick Training with Default Parameters\nFast training with proven configurations for immediate results.")
|
| 568 |
epochs_res = gr.Slider(1, 50, value=10, step=1, label="ResNet epochs")
|
|
@@ -577,23 +850,7 @@ with gr.Blocks(fill_height=True, title="Dressify - Advanced Outfit Recommendatio
|
|
| 577 |
btn = gr.Button("Compute Embeddings")
|
| 578 |
btn.click(fn=gradio_embed, inputs=inp, outputs=out)
|
| 579 |
|
| 580 |
-
|
| 581 |
-
gr.Markdown("### π¦ Download Trained Models and Artifacts\nAccess all exported models, checkpoints, and training metrics.")
|
| 582 |
-
file_list = gr.JSON(label="Available Artifacts")
|
| 583 |
-
def list_artifacts_for_ui():
|
| 584 |
-
export_dir = os.getenv("EXPORT_DIR", "models/exports")
|
| 585 |
-
files = []
|
| 586 |
-
if os.path.isdir(export_dir):
|
| 587 |
-
for fn in os.listdir(export_dir):
|
| 588 |
-
if fn.endswith((".pth", ".pt", ".onnx", ".ts", ".json")):
|
| 589 |
-
files.append({
|
| 590 |
-
"name": fn,
|
| 591 |
-
"path": f"{export_dir}/{fn}",
|
| 592 |
-
"url": f"/files/{fn}",
|
| 593 |
-
})
|
| 594 |
-
return {"artifacts": files}
|
| 595 |
-
refresh = gr.Button("π Refresh Artifacts")
|
| 596 |
-
refresh.click(fn=lambda: list_artifacts_for_ui(), inputs=[], outputs=file_list)
|
| 597 |
|
| 598 |
with gr.Tab("π Status"):
|
| 599 |
gr.Markdown("### π¦ System Status and Monitoring\nReal-time status of dataset preparation, training, and system health.")
|
|
|
|
| 12 |
from PIL import Image
|
| 13 |
from starlette.staticfiles import StaticFiles
|
| 14 |
import threading
|
| 15 |
+
import json
|
| 16 |
|
| 17 |
from inference import InferenceService
|
| 18 |
from utils.data_fetch import ensure_dataset_ready
|
| 19 |
|
| 20 |
+
# Global state
|
| 21 |
+
BOOT_STATUS = "starting"
|
| 22 |
+
DATASET_ROOT: Optional[str] = None
|
| 23 |
+
|
| 24 |
+
def get_artifact_overview():
|
| 25 |
+
"""Get comprehensive artifact overview."""
|
| 26 |
+
try:
|
| 27 |
+
from utils.artifact_manager import create_artifact_manager
|
| 28 |
+
manager = create_artifact_manager()
|
| 29 |
+
return manager.get_artifact_summary()
|
| 30 |
+
except Exception as e:
|
| 31 |
+
return {"error": str(e)}
|
| 32 |
+
|
| 33 |
+
def export_artifact_summary():
|
| 34 |
+
"""Export artifact summary as JSON file."""
|
| 35 |
+
try:
|
| 36 |
+
from utils.artifact_manager import create_artifact_manager
|
| 37 |
+
manager = create_artifact_manager()
|
| 38 |
+
summary = manager.get_artifact_summary()
|
| 39 |
+
|
| 40 |
+
# Save to exports directory
|
| 41 |
+
export_dir = os.getenv("EXPORT_DIR", "models/exports")
|
| 42 |
+
os.makedirs(export_dir, exist_ok=True)
|
| 43 |
+
|
| 44 |
+
summary_path = os.path.join(export_dir, "system_summary.json")
|
| 45 |
+
with open(summary_path, 'w') as f:
|
| 46 |
+
json.dump(summary, f, indent=2)
|
| 47 |
+
|
| 48 |
+
return summary_path
|
| 49 |
+
except Exception as e:
|
| 50 |
+
return None
|
| 51 |
+
|
| 52 |
+
def create_download_package(package_type: str):
|
| 53 |
+
"""Create a downloadable package."""
|
| 54 |
+
try:
|
| 55 |
+
from utils.artifact_manager import create_artifact_manager
|
| 56 |
+
manager = create_artifact_manager()
|
| 57 |
+
|
| 58 |
+
# Extract package type from the dropdown choice
|
| 59 |
+
if "complete" in package_type:
|
| 60 |
+
pkg_type = "complete"
|
| 61 |
+
elif "splits_only" in package_type:
|
| 62 |
+
pkg_type = "splits_only"
|
| 63 |
+
elif "models_only" in package_type:
|
| 64 |
+
pkg_type = "models_only"
|
| 65 |
+
else:
|
| 66 |
+
return f"β Invalid package type: {package_type}", get_available_packages()
|
| 67 |
+
|
| 68 |
+
package_path = manager.create_download_package(pkg_type)
|
| 69 |
+
package_name = os.path.basename(package_path)
|
| 70 |
+
|
| 71 |
+
return f"β
Package created: {package_name}", get_available_packages()
|
| 72 |
+
|
| 73 |
+
except Exception as e:
|
| 74 |
+
return f"β Failed to create package: {e}", get_available_packages()
|
| 75 |
+
|
| 76 |
+
def get_available_packages():
|
| 77 |
+
"""Get list of available packages."""
|
| 78 |
+
try:
|
| 79 |
+
export_dir = os.getenv("EXPORT_DIR", "models/exports")
|
| 80 |
+
packages = []
|
| 81 |
+
|
| 82 |
+
if os.path.exists(export_dir):
|
| 83 |
+
for file in os.listdir(export_dir):
|
| 84 |
+
if file.endswith((".tar.gz", ".zip")):
|
| 85 |
+
file_path = os.path.join(export_dir, file)
|
| 86 |
+
packages.append({
|
| 87 |
+
"name": file,
|
| 88 |
+
"size_mb": round(os.path.getsize(file_path) / (1024 * 1024), 2),
|
| 89 |
+
"path": file_path,
|
| 90 |
+
"url": f"/files/{file}"
|
| 91 |
+
})
|
| 92 |
+
|
| 93 |
+
return {"packages": packages}
|
| 94 |
+
except Exception as e:
|
| 95 |
+
return {"error": str(e)}
|
| 96 |
+
|
| 97 |
+
def get_individual_files():
|
| 98 |
+
"""Get list of individual downloadable files."""
|
| 99 |
+
try:
|
| 100 |
+
from utils.artifact_manager import create_artifact_manager
|
| 101 |
+
manager = create_artifact_manager()
|
| 102 |
+
files = manager.get_downloadable_files()
|
| 103 |
+
|
| 104 |
+
# Group by category
|
| 105 |
+
categorized = {}
|
| 106 |
+
for file in files:
|
| 107 |
+
category = file["category"]
|
| 108 |
+
if category not in categorized:
|
| 109 |
+
categorized[category] = []
|
| 110 |
+
categorized[category].append(file)
|
| 111 |
+
|
| 112 |
+
return categorized
|
| 113 |
+
except Exception as e:
|
| 114 |
+
return {"error": str(e)}
|
| 115 |
+
|
| 116 |
+
def download_all_files():
|
| 117 |
+
"""Download all files as a ZIP archive."""
|
| 118 |
+
try:
|
| 119 |
+
from utils.artifact_manager import create_artifact_manager
|
| 120 |
+
manager = create_artifact_manager()
|
| 121 |
+
files = manager.get_downloadable_files()
|
| 122 |
+
|
| 123 |
+
# Create ZIP with all files
|
| 124 |
+
export_dir = os.getenv("EXPORT_DIR", "models/exports")
|
| 125 |
+
os.makedirs(export_dir, exist_ok=True)
|
| 126 |
+
|
| 127 |
+
zip_path = os.path.join(export_dir, "all_artifacts.zip")
|
| 128 |
+
import zipfile
|
| 129 |
+
|
| 130 |
+
with zipfile.ZipFile(zip_path, 'w') as zipf:
|
| 131 |
+
for file in files:
|
| 132 |
+
if os.path.exists(file["path"]):
|
| 133 |
+
zipf.write(file["path"], file["name"])
|
| 134 |
+
|
| 135 |
+
return zip_path
|
| 136 |
+
except Exception as e:
|
| 137 |
+
return None
|
| 138 |
+
|
| 139 |
+
def get_training_status():
|
| 140 |
+
"""Get current training status from the monitor."""
|
| 141 |
+
try:
|
| 142 |
+
from training_monitor import create_monitor
|
| 143 |
+
monitor = create_monitor()
|
| 144 |
+
status = monitor.get_status()
|
| 145 |
+
return status if status else {"status": "no-training"}
|
| 146 |
+
except Exception as e:
|
| 147 |
+
return {"status": "error", "error": str(e)}
|
| 148 |
+
|
| 149 |
+
def push_splits_to_hf(token, username):
|
| 150 |
+
"""Push splits to HF Hub."""
|
| 151 |
+
if not token or not username:
|
| 152 |
+
return "β Please provide HF token and username"
|
| 153 |
+
|
| 154 |
+
try:
|
| 155 |
+
from utils.hf_hub_integration import create_hf_integration
|
| 156 |
+
hf = create_hf_integration(token)
|
| 157 |
+
result = hf.upload_splits_to_hf()
|
| 158 |
+
|
| 159 |
+
if result.get("success"):
|
| 160 |
+
return f"β
Successfully uploaded splits to {username}/Dressify-Helper"
|
| 161 |
+
else:
|
| 162 |
+
return f"β Failed to upload splits: {result.get('error', 'Unknown error')}"
|
| 163 |
+
except Exception as e:
|
| 164 |
+
return f"β Upload failed: {e}"
|
| 165 |
+
|
| 166 |
+
def push_models_to_hf(token, username):
|
| 167 |
+
"""Push models to HF Hub."""
|
| 168 |
+
if not token or not username:
|
| 169 |
+
return "β Please provide HF token and username"
|
| 170 |
+
|
| 171 |
+
try:
|
| 172 |
+
from utils.hf_hub_integration import create_hf_integration
|
| 173 |
+
hf = create_hf_integration(token)
|
| 174 |
+
result = hf.upload_models_to_hf()
|
| 175 |
+
|
| 176 |
+
if result.get("success"):
|
| 177 |
+
return f"β
Successfully uploaded models to {username}/dressify-models"
|
| 178 |
+
else:
|
| 179 |
+
return f"β Failed to upload models: {result.get('error', 'Unknown error')}"
|
| 180 |
+
except Exception as e:
|
| 181 |
+
return f"β Upload failed: {e}"
|
| 182 |
+
|
| 183 |
+
def push_everything_to_hf(token, username):
|
| 184 |
+
"""Push everything to HF Hub."""
|
| 185 |
+
if not token or not username:
|
| 186 |
+
return "β Please provide HF token and username"
|
| 187 |
+
|
| 188 |
+
try:
|
| 189 |
+
from utils.hf_hub_integration import create_hf_integration
|
| 190 |
+
hf = create_hf_integration(token)
|
| 191 |
+
result = hf.upload_everything_to_hf()
|
| 192 |
+
|
| 193 |
+
if result.get("success"):
|
| 194 |
+
return f"β
Successfully uploaded everything to HF Hub"
|
| 195 |
+
else:
|
| 196 |
+
return f"β Failed to upload everything: {result.get('error', 'Unknown error')}"
|
| 197 |
+
except Exception as e:
|
| 198 |
+
return f"β Upload failed: {e}"
|
| 199 |
+
|
| 200 |
|
| 201 |
AI_API_KEY = os.getenv("AI_API_KEY")
|
| 202 |
|
|
|
|
| 435 |
if not DATASET_ROOT:
|
| 436 |
return "β Dataset not ready. Please wait for bootstrap to complete."
|
| 437 |
|
| 438 |
+
log_message = "π Advanced training started with custom parameters! Check the log below for progress."
|
| 439 |
def _runner():
|
| 440 |
+
nonlocal log_message
|
| 441 |
try:
|
| 442 |
import subprocess
|
| 443 |
import json
|
|
|
|
| 510 |
json.dump(vit_config, f, indent=2)
|
| 511 |
|
| 512 |
# Train ResNet with custom parameters
|
| 513 |
+
log_message = f"π Starting ResNet training with custom parameters...\n"
|
| 514 |
+
log_message += f"Backbone: {resnet_backbone}, Embedding Dim: {resnet_embedding_dim}\n"
|
| 515 |
+
log_message += f"Epochs: {resnet_epochs}, Batch Size: {resnet_batch_size}, LR: {resnet_lr}\n"
|
| 516 |
+
log_message += f"Optimizer: {resnet_optimizer}, Triplet Margin: {resnet_triplet_margin}\n"
|
| 517 |
|
| 518 |
resnet_cmd = [
|
| 519 |
"python", "train_resnet.py",
|
|
|
|
| 533 |
result = subprocess.run(resnet_cmd, capture_output=True, text=True, check=False)
|
| 534 |
|
| 535 |
if result.returncode == 0:
|
| 536 |
+
log_message += "✅ ResNet training completed successfully!\n\n"
|
| 537 |
else:
|
| 538 |
+
log_message += f"β ResNet training failed: {result.stderr}\n\n"
|
| 539 |
return
|
| 540 |
|
| 541 |
# Train ViT with custom parameters
|
| 542 |
+
log_message += f"π Starting ViT training with custom parameters...\n"
|
| 543 |
+
log_message += f"Layers: {vit_num_layers}, Heads: {vit_num_heads}, FF Multiplier: {vit_ff_multiplier}\n"
|
| 544 |
+
log_message += f"Epochs: {vit_epochs}, Batch Size: {vit_batch_size}, LR: {vit_lr}\n"
|
| 545 |
+
log_message += f"Optimizer: {vit_optimizer}, Triplet Margin: {vit_triplet_margin}\n"
|
| 546 |
|
| 547 |
vit_cmd = [
|
| 548 |
"python", "train_vit_triplet.py",
|
|
|
|
| 559 |
result = subprocess.run(vit_cmd, capture_output=True, text=True, check=False)
|
| 560 |
|
| 561 |
if result.returncode == 0:
|
| 562 |
+
log_message += "✅ ViT training completed successfully!\n\n"
|
| 563 |
+
log_message += "π All training completed! Models saved to models/exports/\n"
|
| 564 |
+
log_message += "π Reloading models for inference...\n"
|
| 565 |
service.reload_models()
|
| 566 |
+
log_message += "✅ Models reloaded and ready for inference!\n"
|
| 567 |
+
|
| 568 |
+
# Auto-upload to HF Hub if token is available
|
| 569 |
+
hf_token = os.getenv("HF_TOKEN")
|
| 570 |
+
if hf_token:
|
| 571 |
+
log_message += "π€ Auto-uploading artifacts to Hugging Face Hub...\n"
|
| 572 |
+
try:
|
| 573 |
+
from utils.hf_hub_integration import create_hf_integration
|
| 574 |
+
hf = create_hf_integration(hf_token)
|
| 575 |
+
result = hf.upload_everything_to_hf()
|
| 576 |
+
if result.get("success"):
|
| 577 |
+
log_message += "✅ Successfully uploaded to HF Hub!\n"
|
| 578 |
+
log_message += "π Models: https://huggingface.co/Stylique/dressify-models\n"
|
| 579 |
+
log_message += "π Data: https://huggingface.co/datasets/Stylique/Dressify-Helper\n"
|
| 580 |
+
else:
|
| 581 |
+
log_message += f"β οΈ Upload failed: {result.get('error', 'Unknown error')}\n"
|
| 582 |
+
except Exception as e:
|
| 583 |
+
log_message += f"β οΈ Auto-upload failed: {str(e)}\n"
|
| 584 |
+
else:
|
| 585 |
+
log_message += "π‘ Set HF_TOKEN env var for automatic uploads\n"
|
| 586 |
else:
|
| 587 |
+
log_message += f"β ViT training failed: {result.stderr}\n"
|
| 588 |
|
| 589 |
except Exception as e:
|
| 590 |
+
log_message += f"\nβ Training error: {str(e)}"
|
| 591 |
|
| 592 |
threading.Thread(target=_runner, daemon=True).start()
|
| 593 |
+
return log_message
|
| 594 |
|
| 595 |
|
| 596 |
def start_training_simple(res_epochs: int, vit_epochs: int):
|
| 597 |
"""Start simple training with basic parameters."""
|
| 598 |
+
log_message = "Starting training..."
|
| 599 |
def _runner():
|
| 600 |
+
nonlocal log_message
|
| 601 |
try:
|
| 602 |
import subprocess
|
| 603 |
if not DATASET_ROOT:
|
| 604 |
+
log_message = "Dataset not ready."
|
| 605 |
return
|
| 606 |
export_dir = os.getenv("EXPORT_DIR", "models/exports")
|
| 607 |
os.makedirs(export_dir, exist_ok=True)
|
| 608 |
+
log_message = "Training ResNetβ¦\n"
|
| 609 |
subprocess.run([
|
| 610 |
"python", "train_resnet.py", "--data_root", DATASET_ROOT, "--epochs", str(res_epochs),
|
| 611 |
"--out", os.path.join(export_dir, "resnet_item_embedder.pth")
|
| 612 |
], check=False)
|
| 613 |
+
log_message += "\nTraining ViT (triplet)β¦\n"
|
| 614 |
subprocess.run([
|
| 615 |
"python", "train_vit_triplet.py", "--data_root", DATASET_ROOT, "--epochs", str(vit_epochs),
|
| 616 |
"--export", os.path.join(export_dir, "vit_outfit_model.pth")
|
| 617 |
], check=False)
|
| 618 |
service.reload_models()
|
| 619 |
+
log_message += "\nDone. Artifacts in models/exports."
|
| 620 |
+
|
| 621 |
+
# Auto-upload to HF Hub if token is available
|
| 622 |
+
hf_token = os.getenv("HF_TOKEN")
|
| 623 |
+
if hf_token:
|
| 624 |
+
log_message += "\nπ€ Auto-uploading artifacts to Hugging Face Hub...\n"
|
| 625 |
+
try:
|
| 626 |
+
from utils.hf_hub_integration import create_hf_integration
|
| 627 |
+
hf = create_hf_integration(hf_token)
|
| 628 |
+
result = hf.upload_everything_to_hf()
|
| 629 |
+
if result.get("success"):
|
| 630 |
+
log_message += "✅ Successfully uploaded to HF Hub!\n"
|
| 631 |
+
log_message += "π Models: https://huggingface.co/Stylique/dressify-models\n"
|
| 632 |
+
log_message += "π Data: https://huggingface.co/datasets/Stylique/Dressify-Helper\n"
|
| 633 |
+
else:
|
| 634 |
+
log_message += f"β οΈ Upload failed: {result.get('error', 'Unknown error')}\n"
|
| 635 |
+
except Exception as e:
|
| 636 |
+
log_message += f"β οΈ Auto-upload failed: {str(e)}\n"
|
| 637 |
+
else:
|
| 638 |
+
log_message += "\nπ‘ Set HF_TOKEN env var for automatic uploads\n"
|
| 639 |
except Exception as e:
|
| 640 |
+
log_message += f"\nError: {e}"
|
| 641 |
threading.Thread(target=_runner, daemon=True).start()
|
| 642 |
+
return log_message
|
| 643 |
|
| 644 |
|
| 645 |
with gr.Blocks(fill_height=True, title="Dressify - Advanced Outfit Recommendation") as demo:
|
|
|
|
| 786 |
outputs=train_log
|
| 787 |
)
|
| 788 |
|
| 789 |
+
with gr.Tab("π¦ Artifact Management"):
|
| 790 |
+
gr.Markdown("### π― Comprehensive Artifact Management\nManage, package, and upload all system artifacts to Hugging Face Hub.")
|
| 791 |
+
|
| 792 |
+
with gr.Row():
|
| 793 |
+
with gr.Column(scale=1):
|
| 794 |
+
gr.Markdown("#### π Artifact Overview")
|
| 795 |
+
artifact_overview = gr.JSON(label="System Artifacts", value=get_artifact_overview)
|
| 796 |
+
refresh_overview = gr.Button("π Refresh Overview")
|
| 797 |
+
refresh_overview.click(fn=get_artifact_overview, inputs=[], outputs=artifact_overview)
|
| 798 |
+
|
| 799 |
+
gr.Markdown("#### π¦ Create Packages")
|
| 800 |
+
package_type = gr.Dropdown(
|
| 801 |
+
choices=["complete", "splits_only", "models_only"],
|
| 802 |
+
value="complete",
|
| 803 |
+
label="Package Type"
|
| 804 |
+
)
|
| 805 |
+
create_package_btn = gr.Button("π¦ Create Package")
|
| 806 |
+
package_result = gr.Textbox(label="Package Result", interactive=False)
|
| 807 |
+
available_packages = gr.JSON(label="Available Packages", value=get_available_packages)
|
| 808 |
+
|
| 809 |
+
create_package_btn.click(
|
| 810 |
+
fn=create_download_package,
|
| 811 |
+
inputs=[package_type],
|
| 812 |
+
outputs=[package_result, available_packages]
|
| 813 |
+
)
|
| 814 |
+
|
| 815 |
+
with gr.Column(scale=1):
|
| 816 |
+
gr.Markdown("#### π Hugging Face Hub Integration")
|
| 817 |
+
gr.Markdown("π‘ **Pro Tip**: Set `HF_TOKEN` environment variable for automatic uploads after training!")
|
| 818 |
+
hf_token = gr.Textbox(label="HF Token", type="password", placeholder="hf_...")
|
| 819 |
+
hf_username = gr.Textbox(label="Username", placeholder="your-username")
|
| 820 |
+
|
| 821 |
+
with gr.Row():
|
| 822 |
+
push_splits_btn = gr.Button("π€ Push Splits", variant="secondary")
|
| 823 |
+
push_models_btn = gr.Button("π€ Push Models", variant="secondary")
|
| 824 |
+
|
| 825 |
+
push_everything_btn = gr.Button("π€ Push Everything", variant="primary")
|
| 826 |
+
hf_result = gr.Textbox(label="Upload Result", interactive=False, lines=3)
|
| 827 |
+
|
| 828 |
+
push_splits_btn.click(fn=push_splits_to_hf, inputs=[hf_token, hf_username], outputs=hf_result)
|
| 829 |
+
push_models_btn.click(fn=push_models_to_hf, inputs=[hf_token, hf_username], outputs=hf_result)
|
| 830 |
+
push_everything_btn.click(fn=push_everything_to_hf, inputs=[hf_token, hf_username], outputs=hf_result)
|
| 831 |
+
|
| 832 |
+
gr.Markdown("#### π₯ Download Management")
|
| 833 |
+
individual_files = gr.JSON(label="Individual Files", value=get_individual_files)
|
| 834 |
+
download_all_btn = gr.Button("π₯ Download All as ZIP")
|
| 835 |
+
download_result = gr.Textbox(label="Download Result", interactive=False)
|
| 836 |
+
|
| 837 |
+
download_all_btn.click(fn=download_all_files, inputs=[], outputs=download_result)
|
| 838 |
+
|
| 839 |
with gr.Tab("π§ Simple Training"):
|
| 840 |
gr.Markdown("### π Quick Training with Default Parameters\nFast training with proven configurations for immediate results.")
|
| 841 |
epochs_res = gr.Slider(1, 50, value=10, step=1, label="ResNet epochs")
|
|
|
|
| 850 |
btn = gr.Button("Compute Embeddings")
|
| 851 |
btn.click(fn=gradio_embed, inputs=inp, outputs=out)
|
| 852 |
|
| 853 |
+
|
| 854 |
|
| 855 |
with gr.Tab("π Status"):
|
| 856 |
gr.Markdown("### π¦ System Status and Monitoring\nReal-time status of dataset preparation, training, and system health.")
|
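Both training entry points added to app.py follow the same launch pattern: the heavy work runs in a daemon thread, progress strings are appended to a log_message variable captured with nonlocal, and the initial message is returned to the Gradio textbox right away. A minimal, self-contained sketch of that pattern (the script name and flag below are illustrative, not the exact command the Space assembles):

import subprocess
import threading

def start_job(epochs: int) -> str:
    """Launch a long-running training script in the background and return an initial status line."""
    log_message = "Training started; work continues in the background."

    def _runner():
        nonlocal log_message
        try:
            # check=False mirrors app.py: failures surface through returncode, not exceptions.
            result = subprocess.run(
                ["python", "train_resnet.py", "--epochs", str(epochs)],  # illustrative command
                capture_output=True, text=True, check=False,
            )
            if result.returncode == 0:
                log_message += "\nTraining finished."
            else:
                log_message += f"\nTraining failed: {result.stderr}"
        except Exception as e:
            log_message += f"\nError: {e}"

    threading.Thread(target=_runner, daemon=True).start()
    # The caller only ever sees this initial string; later appends happen inside the thread.
    return log_message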
artifact_management_ui.py
ADDED
|
@@ -0,0 +1,427 @@
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Comprehensive Gradio interface for Dressify artifact management.
|
| 4 |
+
Provides download, upload, and organization features for all system artifacts.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import json
|
| 9 |
+
import gradio as gr
|
| 10 |
+
from typing import Dict, List, Any
|
| 11 |
+
from utils.artifact_manager import create_artifact_manager
|
| 12 |
+
|
| 13 |
+
def create_artifact_management_interface():
|
| 14 |
+
"""Create the main artifact management interface."""
|
| 15 |
+
|
| 16 |
+
with gr.Blocks(title="Dressify Artifact Management", theme=gr.themes.Soft()) as interface:
|
| 17 |
+
gr.Markdown("# π― Dressify Artifact Management System")
|
| 18 |
+
gr.Markdown("## π¦ Download, Upload, and Organize All System Artifacts")
|
| 19 |
+
|
| 20 |
+
with gr.Tabs():
|
| 21 |
+
|
| 22 |
+
# Overview Tab
|
| 23 |
+
with gr.Tab("π System Overview"):
|
| 24 |
+
gr.Markdown("### π Complete System Status and Artifact Summary")
|
| 25 |
+
|
| 26 |
+
with gr.Row():
|
| 27 |
+
refresh_overview = gr.Button("π Refresh Overview", variant="primary")
|
| 28 |
+
export_summary = gr.Button("π₯ Export Summary JSON", variant="secondary")
|
| 29 |
+
|
| 30 |
+
overview_display = gr.JSON(label="System Overview", value=get_system_overview())
|
| 31 |
+
|
| 32 |
+
refresh_overview.click(
|
| 33 |
+
fn=get_system_overview,
|
| 34 |
+
outputs=overview_display
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
export_summary.click(
|
| 38 |
+
fn=export_system_summary,
|
| 39 |
+
outputs=gr.File(label="Download Summary")
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
# Download Management Tab
|
| 43 |
+
with gr.Tab("π₯ Download Management"):
|
| 44 |
+
gr.Markdown("### π Create Downloadable Packages")
|
| 45 |
+
|
| 46 |
+
with gr.Row():
|
| 47 |
+
with gr.Column(scale=1):
|
| 48 |
+
gr.Markdown("#### π¦ Package Types")
|
| 49 |
+
package_type = gr.Dropdown(
|
| 50 |
+
choices=[
|
| 51 |
+
"complete - Everything (splits + models + metadata + configs)",
|
| 52 |
+
"splits_only - Dataset splits only (lightweight)",
|
| 53 |
+
"models_only - Trained models only"
|
| 54 |
+
],
|
| 55 |
+
value="splits_only - Dataset splits only (lightweight)",
|
| 56 |
+
label="Package Type"
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
create_package_btn = gr.Button("π Create Package", variant="primary")
|
| 60 |
+
package_status = gr.Textbox(label="Package Status", interactive=False)
|
| 61 |
+
|
| 62 |
+
with gr.Column(scale=1):
|
| 63 |
+
gr.Markdown("#### π Available Packages")
|
| 64 |
+
packages_list = gr.JSON(label="Created Packages", value=get_available_packages())
|
| 65 |
+
refresh_packages = gr.Button("π Refresh Packages")
|
| 66 |
+
|
| 67 |
+
create_package_btn.click(
|
| 68 |
+
fn=create_download_package,
|
| 69 |
+
inputs=[package_type],
|
| 70 |
+
outputs=[package_status, packages_list]
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
refresh_packages.click(
|
| 74 |
+
fn=get_available_packages,
|
| 75 |
+
outputs=packages_list
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
# Individual Files Tab
|
| 79 |
+
with gr.Tab("π Individual Files"):
|
| 80 |
+
gr.Markdown("### π Browse and Download Individual Artifacts")
|
| 81 |
+
|
| 82 |
+
with gr.Row():
|
| 83 |
+
refresh_files = gr.Button("π Refresh Files", variant="primary")
|
| 84 |
+
download_all_btn = gr.Button("π₯ Download All as ZIP", variant="secondary")
|
| 85 |
+
|
| 86 |
+
files_display = gr.JSON(label="Available Files", value=get_individual_files())
|
| 87 |
+
|
| 88 |
+
refresh_files.click(
|
| 89 |
+
fn=get_individual_files,
|
| 90 |
+
outputs=files_display
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
download_all_btn.click(
|
| 94 |
+
fn=download_all_files,
|
| 95 |
+
outputs=gr.File(label="Download All Files")
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
# Upload & Restore Tab
|
| 99 |
+
with gr.Tab("π€ Upload & Restore"):
|
| 100 |
+
gr.Markdown("### π Upload Pre-processed Artifacts")
|
| 101 |
+
gr.Markdown("Upload previously downloaded packages to avoid reprocessing.")
|
| 102 |
+
|
| 103 |
+
with gr.Row():
|
| 104 |
+
with gr.Column(scale=1):
|
| 105 |
+
gr.Markdown("#### π€ Upload Package")
|
| 106 |
+
upload_package = gr.File(
|
| 107 |
+
label="Upload Artifact Package (.tar.gz)",
|
| 108 |
+
file_types=[".tar.gz", ".zip"]
|
| 109 |
+
)
|
| 110 |
+
upload_btn = gr.Button("π€ Upload & Extract", variant="primary")
|
| 111 |
+
upload_status = gr.Textbox(label="Upload Status", interactive=False)
|
| 112 |
+
|
| 113 |
+
with gr.Column(scale=1):
|
| 114 |
+
gr.Markdown("#### π Restore Options")
|
| 115 |
+
restore_splits = gr.Button("π Restore Splits Only", variant="secondary")
|
| 116 |
+
restore_models = gr.Button("π Restore Models Only", variant="secondary")
|
| 117 |
+
restore_all = gr.Button("π Restore Everything", variant="secondary")
|
| 118 |
+
|
| 119 |
+
upload_btn.click(
|
| 120 |
+
fn=upload_and_extract_package,
|
| 121 |
+
inputs=[upload_package],
|
| 122 |
+
outputs=upload_status
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
restore_splits.click(
|
| 126 |
+
fn=restore_splits_only,
|
| 127 |
+
outputs=gr.Textbox(label="Restore Status")
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
restore_models.click(
|
| 131 |
+
fn=restore_models_only,
|
| 132 |
+
outputs=gr.Textbox(label="Restore Status")
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
restore_all.click(
|
| 136 |
+
fn=restore_everything,
|
| 137 |
+
outputs=gr.Textbox(label="Restore Status")
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
# Hugging Face Integration Tab
|
| 141 |
+
with gr.Tab("π€ HF Hub Integration"):
|
| 142 |
+
gr.Markdown("### π Push Artifacts to Hugging Face Hub")
|
| 143 |
+
gr.Markdown("Upload your artifacts to HF Hub for easy access and sharing.")
|
| 144 |
+
|
| 145 |
+
with gr.Row():
|
| 146 |
+
with gr.Column(scale=1):
|
| 147 |
+
gr.Markdown("#### π Authentication")
|
| 148 |
+
hf_token = gr.Textbox(
|
| 149 |
+
label="Hugging Face Token",
|
| 150 |
+
placeholder="hf_...",
|
| 151 |
+
type="password"
|
| 152 |
+
)
|
| 153 |
+
hf_username = gr.Textbox(
|
| 154 |
+
label="HF Username",
|
| 155 |
+
placeholder="yourusername"
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
with gr.Column(scale=1):
|
| 159 |
+
gr.Markdown("#### π€ Push Options")
|
| 160 |
+
push_splits = gr.Button("π€ Push Splits to HF", variant="primary")
|
| 161 |
+
push_models = gr.Button("π€ Push Models to HF", variant="primary")
|
| 162 |
+
push_all = gr.Button("π€ Push Everything to HF", variant="primary")
|
| 163 |
+
|
| 164 |
+
push_status = gr.Textbox(label="Push Status", interactive=False)
|
| 165 |
+
|
| 166 |
+
push_splits.click(
|
| 167 |
+
fn=push_splits_to_hf,
|
| 168 |
+
inputs=[hf_token, hf_username],
|
| 169 |
+
outputs=push_status
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
push_models.click(
|
| 173 |
+
fn=push_models_to_hf,
|
| 174 |
+
inputs=[hf_token, hf_username],
|
| 175 |
+
outputs=push_status
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
push_all.click(
|
| 179 |
+
fn=push_everything_to_hf,
|
| 180 |
+
inputs=[hf_token, hf_username],
|
| 181 |
+
outputs=push_status
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
# Runtime Fetching Tab
|
| 185 |
+
with gr.Tab("β‘ Runtime Fetching"):
|
| 186 |
+
gr.Markdown("### π Fetch Artifacts at Runtime")
|
| 187 |
+
gr.Markdown("Configure the system to fetch artifacts from HF Hub instead of reprocessing.")
|
| 188 |
+
|
| 189 |
+
with gr.Row():
|
| 190 |
+
with gr.Column(scale=1):
|
| 191 |
+
gr.Markdown("#### π HF Hub Sources")
|
| 192 |
+
splits_repo = gr.Textbox(
|
| 193 |
+
label="Splits Repository",
|
| 194 |
+
placeholder="yourusername/dressify-splits",
|
| 195 |
+
value="Stylique/dressify-splits"
|
| 196 |
+
)
|
| 197 |
+
models_repo = gr.Textbox(
|
| 198 |
+
label="Models Repository",
|
| 199 |
+
placeholder="yourusername/dressify-models",
|
| 200 |
+
value="Stylique/dressify-models"
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
enable_runtime_fetch = gr.Checkbox(
|
| 204 |
+
label="Enable Runtime Fetching",
|
| 205 |
+
value=False
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
with gr.Column(scale=1):
|
| 209 |
+
gr.Markdown("#### π Fetch Actions")
|
| 210 |
+
fetch_splits = gr.Button("π Fetch Splits", variant="primary")
|
| 211 |
+
fetch_models = gr.Button("π Fetch Models", variant="primary")
|
| 212 |
+
fetch_all = gr.Button("π Fetch Everything", variant="primary")
|
| 213 |
+
|
| 214 |
+
fetch_status = gr.Textbox(label="Fetch Status", interactive=False)
|
| 215 |
+
|
| 216 |
+
fetch_splits.click(
|
| 217 |
+
fn=fetch_splits_from_hf,
|
| 218 |
+
inputs=[splits_repo],
|
| 219 |
+
outputs=fetch_status
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
fetch_models.click(
|
| 223 |
+
fn=fetch_models_from_hf,
|
| 224 |
+
inputs=[models_repo],
|
| 225 |
+
outputs=fetch_status
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
fetch_all.click(
|
| 229 |
+
fn=fetch_everything_from_hf,
|
| 230 |
+
inputs=[splits_repo, models_repo],
|
| 231 |
+
outputs=fetch_status
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
# Footer
|
| 235 |
+
gr.Markdown("---")
|
| 236 |
+
gr.Markdown("### π‘ Usage Instructions")
|
| 237 |
+
gr.Markdown("""
|
| 238 |
+
1. **System Overview**: Check what artifacts are available and their sizes
|
| 239 |
+
2. **Download Management**: Create packaged downloads for easy sharing
|
| 240 |
+
3. **Individual Files**: Browse and download specific artifacts
|
| 241 |
+
4. **Upload & Restore**: Upload previously downloaded packages
|
| 242 |
+
5. **HF Hub Integration**: Push artifacts to Hugging Face for sharing
|
| 243 |
+
6. **Runtime Fetching**: Configure automatic fetching from HF Hub
|
| 244 |
+
""")
|
| 245 |
+
|
| 246 |
+
gr.Markdown("### π― Benefits")
|
| 247 |
+
gr.Markdown("""
|
| 248 |
+
- **Save Time**: No more reprocessing expensive splits
|
| 249 |
+
- **Save Resources**: Avoid re-downloading and re-extracting
|
| 250 |
+
- **Easy Sharing**: Download packages and share with others
|
| 251 |
+
- **HF Integration**: Push to Hub for community access
|
| 252 |
+
- **Runtime Fetching**: Automatic artifact retrieval
|
| 253 |
+
""")
|
| 254 |
+
|
| 255 |
+
return interface
|
| 256 |
+
|
| 257 |
+
# Helper functions for the interface
|
| 258 |
+
def get_system_overview():
|
| 259 |
+
"""Get comprehensive system overview."""
|
| 260 |
+
try:
|
| 261 |
+
manager = create_artifact_manager()
|
| 262 |
+
return manager.get_artifact_summary()
|
| 263 |
+
except Exception as e:
|
| 264 |
+
return {"error": str(e)}
|
| 265 |
+
|
| 266 |
+
def export_system_summary():
|
| 267 |
+
"""Export system summary as JSON file."""
|
| 268 |
+
try:
|
| 269 |
+
manager = create_artifact_manager()
|
| 270 |
+
summary = manager.get_artifact_summary()
|
| 271 |
+
|
| 272 |
+
# Save to exports directory
|
| 273 |
+
export_dir = os.getenv("EXPORT_DIR", "models/exports")
|
| 274 |
+
os.makedirs(export_dir, exist_ok=True)
|
| 275 |
+
|
| 276 |
+
summary_path = os.path.join(export_dir, "system_summary.json")
|
| 277 |
+
with open(summary_path, 'w') as f:
|
| 278 |
+
json.dump(summary, f, indent=2)
|
| 279 |
+
|
| 280 |
+
return summary_path
|
| 281 |
+
except Exception as e:
|
| 282 |
+
return None
|
| 283 |
+
|
| 284 |
+
def create_download_package(package_type: str):
|
| 285 |
+
"""Create a downloadable package."""
|
| 286 |
+
try:
|
| 287 |
+
manager = create_artifact_manager()
|
| 288 |
+
|
| 289 |
+
# Extract package type from the dropdown choice
|
| 290 |
+
if "complete" in package_type:
|
| 291 |
+
pkg_type = "complete"
|
| 292 |
+
elif "splits_only" in package_type:
|
| 293 |
+
pkg_type = "splits_only"
|
| 294 |
+
elif "models_only" in package_type:
|
| 295 |
+
pkg_type = "models_only"
|
| 296 |
+
else:
|
| 297 |
+
return f"β Invalid package type: {package_type}", get_available_packages()
|
| 298 |
+
|
| 299 |
+
package_path = manager.create_download_package(pkg_type)
|
| 300 |
+
package_name = os.path.basename(package_path)
|
| 301 |
+
|
| 302 |
+
return f"β
Package created: {package_name}", get_available_packages()
|
| 303 |
+
|
| 304 |
+
except Exception as e:
|
| 305 |
+
return f"β Failed to create package: {e}", get_available_packages()
|
| 306 |
+
|
| 307 |
+
def get_available_packages():
|
| 308 |
+
"""Get list of available packages."""
|
| 309 |
+
try:
|
| 310 |
+
export_dir = os.getenv("EXPORT_DIR", "models/exports")
|
| 311 |
+
packages = []
|
| 312 |
+
|
| 313 |
+
if os.path.exists(export_dir):
|
| 314 |
+
for file in os.listdir(export_dir):
|
| 315 |
+
if file.endswith((".tar.gz", ".zip")):
|
| 316 |
+
file_path = os.path.join(export_dir, file)
|
| 317 |
+
packages.append({
|
| 318 |
+
"name": file,
|
| 319 |
+
"size_mb": round(os.path.getsize(file_path) / (1024 * 1024), 2),
|
| 320 |
+
"path": file_path,
|
| 321 |
+
"url": f"/files/{file}"
|
| 322 |
+
})
|
| 323 |
+
|
| 324 |
+
return {"packages": packages}
|
| 325 |
+
except Exception as e:
|
| 326 |
+
return {"error": str(e)}
|
| 327 |
+
|
| 328 |
+
def get_individual_files():
|
| 329 |
+
"""Get list of individual downloadable files."""
|
| 330 |
+
try:
|
| 331 |
+
manager = create_artifact_manager()
|
| 332 |
+
files = manager.get_downloadable_files()
|
| 333 |
+
|
| 334 |
+
# Group by category
|
| 335 |
+
categorized = {}
|
| 336 |
+
for file in files:
|
| 337 |
+
category = file["category"]
|
| 338 |
+
if category not in categorized:
|
| 339 |
+
categorized[category] = []
|
| 340 |
+
categorized[category].append(file)
|
| 341 |
+
|
| 342 |
+
return categorized
|
| 343 |
+
except Exception as e:
|
| 344 |
+
return {"error": str(e)}
|
| 345 |
+
|
| 346 |
+
def download_all_files():
|
| 347 |
+
"""Download all files as a ZIP archive."""
|
| 348 |
+
try:
|
| 349 |
+
manager = create_artifact_manager()
|
| 350 |
+
files = manager.get_downloadable_files()
|
| 351 |
+
|
| 352 |
+
# Create ZIP with all files
|
| 353 |
+
export_dir = os.getenv("EXPORT_DIR", "models/exports")
|
| 354 |
+
os.makedirs(export_dir, exist_ok=True)
|
| 355 |
+
|
| 356 |
+
zip_path = os.path.join(export_dir, "all_artifacts.zip")
|
| 357 |
+
import zipfile
|
| 358 |
+
|
| 359 |
+
with zipfile.ZipFile(zip_path, 'w') as zipf:
|
| 360 |
+
for file in files:
|
| 361 |
+
if os.path.exists(file["path"]):
|
| 362 |
+
zipf.write(file["path"], file["name"])
|
| 363 |
+
|
| 364 |
+
return zip_path
|
| 365 |
+
except Exception as e:
|
| 366 |
+
return None
|
| 367 |
+
|
| 368 |
+
# Placeholder functions for upload/restore features
|
| 369 |
+
def upload_and_extract_package(upload_file):
|
| 370 |
+
"""Upload and extract a package."""
|
| 371 |
+
if upload_file is None:
|
| 372 |
+
return "β No file uploaded"
|
| 373 |
+
|
| 374 |
+
try:
|
| 375 |
+
# This would implement actual upload and extraction logic
|
| 376 |
+
return f"β
Package uploaded: {upload_file.name}"
|
| 377 |
+
except Exception as e:
|
| 378 |
+
return f"β Upload failed: {e}"
|
| 379 |
+
|
| 380 |
+
def restore_splits_only():
|
| 381 |
+
"""Restore splits only."""
|
| 382 |
+
return "π Splits restoration not yet implemented"
|
| 383 |
+
|
| 384 |
+
def restore_models_only():
|
| 385 |
+
"""Restore models only."""
|
| 386 |
+
return "π Models restoration not yet implemented"
|
| 387 |
+
|
| 388 |
+
def restore_everything():
|
| 389 |
+
"""Restore everything."""
|
| 390 |
+
return "π Full restoration not yet implemented"
|
| 391 |
+
|
| 392 |
+
# Placeholder functions for HF Hub integration
|
| 393 |
+
def push_splits_to_hf(token, username):
|
| 394 |
+
"""Push splits to HF Hub."""
|
| 395 |
+
if not token or not username:
|
| 396 |
+
return "β Please provide HF token and username"
|
| 397 |
+
return f"π€ Pushing splits to {username}/dressify-splits..."
|
| 398 |
+
|
| 399 |
+
def push_models_to_hf(token, username):
|
| 400 |
+
"""Push models to HF Hub."""
|
| 401 |
+
if not token or not username:
|
| 402 |
+
return "β Please provide HF token and username"
|
| 403 |
+
return f"π€ Pushing models to {username}/dressify-models..."
|
| 404 |
+
|
| 405 |
+
def push_everything_to_hf(token, username):
|
| 406 |
+
"""Push everything to HF Hub."""
|
| 407 |
+
if not token or not username:
|
| 408 |
+
return "β Please provide HF token and username"
|
| 409 |
+
return f"π€ Pushing everything to {username}/dressify..."
|
| 410 |
+
|
| 411 |
+
# Placeholder functions for runtime fetching
|
| 412 |
+
def fetch_splits_from_hf(repo):
|
| 413 |
+
"""Fetch splits from HF Hub."""
|
| 414 |
+
return f"π Fetching splits from {repo}..."
|
| 415 |
+
|
| 416 |
+
def fetch_models_from_hf(repo):
|
| 417 |
+
"""Fetch models from HF Hub."""
|
| 418 |
+
return f"π Fetching models from {repo}..."
|
| 419 |
+
|
| 420 |
+
def fetch_everything_from_hf(splits_repo, models_repo):
|
| 421 |
+
"""Fetch everything from HF Hub."""
|
| 422 |
+
return f"π Fetching everything from {splits_repo} and {models_repo}..."
|
| 423 |
+
|
| 424 |
+
if __name__ == "__main__":
|
| 425 |
+
# Test the interface
|
| 426 |
+
interface = create_artifact_management_interface()
|
| 427 |
+
interface.launch()
|
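The push_* and fetch_* helpers near the end of this new file are placeholders; in app.py the real uploads are delegated to utils/hf_hub_integration.py, whose contents are not shown in this diff. For orientation, a rough sketch of what such an upload can look like with the public huggingface_hub API — the repo id and folder path here are assumptions, not values read from that module:

import os
from huggingface_hub import HfApi

def upload_exports(token: str, repo_id: str = "your-username/dressify-models") -> dict:
    """Push the local exports directory to a Hugging Face model repo (illustrative sketch only)."""
    export_dir = os.getenv("EXPORT_DIR", "models/exports")
    try:
        api = HfApi(token=token)
        # Create the target repo if it does not exist yet, then upload the whole folder.
        api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)
        api.upload_folder(folder_path=export_dir, repo_id=repo_id, repo_type="model")
        return {"success": True, "repo_id": repo_id}
    except Exception as e:
        return {"success": False, "error": str(e)}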
models/resnet_embedder.py
CHANGED
|
@@ -2,6 +2,7 @@ from typing import Optional
|
|
| 2 |
|
| 3 |
import torch
|
| 4 |
import torch.nn as nn
|
|
|
|
| 5 |
import torchvision.models as tvm
|
| 6 |
|
| 7 |
|
|
@@ -27,6 +28,8 @@ class ResNetItemEmbedder(nn.Module):
|
|
| 27 |
feats = self.backbone(x) # (B, C, 1, 1)
|
| 28 |
feats = feats.flatten(1) # (B, C)
|
| 29 |
emb = self.proj(feats) # (B, D)
|
|
|
|
|
|
|
| 30 |
return emb
|
| 31 |
|
| 32 |
|
|
|
|
| 2 |
|
| 3 |
import torch
|
| 4 |
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
import torchvision.models as tvm
|
| 7 |
|
| 8 |
|
|
|
|
| 28 |
feats = self.backbone(x) # (B, C, 1, 1)
|
| 29 |
feats = feats.flatten(1) # (B, C)
|
| 30 |
emb = self.proj(feats) # (B, D)
|
| 31 |
+
# Apply L2 normalization as specified in requirements
|
| 32 |
+
emb = F.normalize(emb, p=2, dim=1)
|
| 33 |
return emb
|
| 34 |
|
| 35 |
|
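The only behavioural change to ResNetItemEmbedder is the L2 normalization of the projected embedding, which puts every item vector on the unit sphere so the fixed triplet margin used in train_resnet.py (0.2) operates on a consistent distance scale. A standalone illustration of the effect (not part of the diff):

import torch
import torch.nn.functional as F

# Illustrative batch of raw projection outputs: 4 items, 512-dim embeddings.
raw = torch.randn(4, 512)
emb = F.normalize(raw, p=2, dim=1)  # the same call the embedder now applies

print(emb.norm(dim=1))  # every row now has norm ~1.0
# For unit vectors, ||a - b||^2 = 2 - 2*cos(a, b), so the Euclidean triplet margin
# effectively bounds the cosine-similarity gap between positive and negative pairs.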
test_training.py
ADDED
|
@@ -0,0 +1,178 @@
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Simple test script to verify training components work.
|
| 4 |
+
Run this to test if the system is ready for training.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
def test_imports():
|
| 12 |
+
"""Test if all required modules can be imported."""
|
| 13 |
+
print("π Testing imports...")
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
from models.resnet_embedder import ResNetItemEmbedder
|
| 17 |
+
print("β
ResNet embedder imported successfully")
|
| 18 |
+
except Exception as e:
|
| 19 |
+
print(f"β Failed to import ResNet embedder: {e}")
|
| 20 |
+
return False
|
| 21 |
+
|
| 22 |
+
try:
|
| 23 |
+
from models.vit_outfit import OutfitCompatibilityModel
|
| 24 |
+
print("β
ViT outfit model imported successfully")
|
| 25 |
+
except Exception as e:
|
| 26 |
+
print(f"β Failed to import ViT outfit model: {e}")
|
| 27 |
+
return False
|
| 28 |
+
|
| 29 |
+
try:
|
| 30 |
+
from data.polyvore import PolyvoreTripletDataset
|
| 31 |
+
print("β
Polyvore dataset imported successfully")
|
| 32 |
+
except Exception as e:
|
| 33 |
+
print(f"β Failed to import Polyvore dataset: {e}")
|
| 34 |
+
return False
|
| 35 |
+
|
| 36 |
+
try:
|
| 37 |
+
from utils.transforms import build_train_transforms
|
| 38 |
+
print("β
Transforms imported successfully")
|
| 39 |
+
except Exception as e:
|
| 40 |
+
print(f"β Failed to import transforms: {e}")
|
| 41 |
+
return False
|
| 42 |
+
|
| 43 |
+
return True
|
| 44 |
+
|
| 45 |
+
def test_models():
|
| 46 |
+
"""Test if models can be created and run forward pass."""
|
| 47 |
+
print("\nποΈ Testing model creation...")
|
| 48 |
+
|
| 49 |
+
try:
|
| 50 |
+
# Test ResNet embedder
|
| 51 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 52 |
+
print(f"Using device: {device}")
|
| 53 |
+
|
| 54 |
+
resnet = ResNetItemEmbedder(embedding_dim=512).to(device)
|
| 55 |
+
print(f"β
ResNet created with {sum(p.numel() for p in resnet.parameters()):,} parameters")
|
| 56 |
+
|
| 57 |
+
# Test forward pass
|
| 58 |
+
dummy_input = torch.randn(2, 3, 224, 224).to(device)
|
| 59 |
+
with torch.no_grad():
|
| 60 |
+
output = resnet(dummy_input)
|
| 61 |
+
print(f"β
ResNet forward pass: input {dummy_input.shape} -> output {output.shape}")
|
| 62 |
+
|
| 63 |
+
# Test ViT outfit model
|
| 64 |
+
vit = OutfitCompatibilityModel(embedding_dim=512).to(device)
|
| 65 |
+
print(f"β
ViT created with {sum(p.numel() for p in vit.parameters()):,} parameters")
|
| 66 |
+
|
| 67 |
+
# Test forward pass
|
| 68 |
+
dummy_tokens = torch.randn(2, 4, 512).to(device)
|
| 69 |
+
with torch.no_grad():
|
| 70 |
+
output = vit(dummy_tokens)
|
| 71 |
+
print(f"β
ViT forward pass: input {dummy_tokens.shape} -> output {output.shape}")
|
| 72 |
+
|
| 73 |
+
return True
|
| 74 |
+
|
| 75 |
+
except Exception as e:
|
| 76 |
+
print(f"β Model test failed: {e}")
|
| 77 |
+
return False
|
| 78 |
+
|
| 79 |
+
def test_dataset():
|
| 80 |
+
"""Test if dataset can be loaded (if available)."""
|
| 81 |
+
print("\nπ Testing dataset loading...")
|
| 82 |
+
|
| 83 |
+
data_root = os.getenv("POLYVORE_ROOT", "/home/user/app/data/Polyvore")
|
| 84 |
+
splits_dir = os.path.join(data_root, "splits")
|
| 85 |
+
train_file = os.path.join(splits_dir, "train.json")
|
| 86 |
+
|
| 87 |
+
if not os.path.exists(train_file):
|
| 88 |
+
print(f"β οΈ Training data not found at {train_file}")
|
| 89 |
+
print("π‘ Dataset preparation may be needed")
|
| 90 |
+
return True # Not a failure, just not ready
|
| 91 |
+
|
| 92 |
+
try:
|
| 93 |
+
dataset = PolyvoreTripletDataset(data_root, split="train")
|
| 94 |
+
print(f"β
Dataset loaded successfully: {len(dataset)} samples")
|
| 95 |
+
|
| 96 |
+
# Test getting one sample
|
| 97 |
+
if len(dataset) > 0:
|
| 98 |
+
sample = dataset[0]
|
| 99 |
+
print(f"β
Sample loaded: {len(sample)} tensors with shapes {[s.shape for s in sample]}")
|
| 100 |
+
|
| 101 |
+
return True
|
| 102 |
+
|
| 103 |
+
except Exception as e:
|
| 104 |
+
print(f"β Dataset test failed: {e}")
|
| 105 |
+
return False
|
| 106 |
+
|
| 107 |
+
def test_training_components():
|
| 108 |
+
"""Test if training components can be created."""
|
| 109 |
+
print("\nπ Testing training components...")
|
| 110 |
+
|
| 111 |
+
try:
|
| 112 |
+
from torch.utils.data import DataLoader
|
| 113 |
+
from torch.optim import AdamW
|
| 114 |
+
from torch.nn import TripletMarginLoss
|
| 115 |
+
|
| 116 |
+
# Test optimizer creation
|
| 117 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 118 |
+
model = ResNetItemEmbedder(embedding_dim=512).to(device)
|
| 119 |
+
optimizer = AdamW(model.parameters(), lr=1e-3)
|
| 120 |
+
print("β
Optimizer created successfully")
|
| 121 |
+
|
| 122 |
+
# Test loss function
|
| 123 |
+
criterion = TripletMarginLoss(margin=0.2)
|
| 124 |
+
print("β
Loss function created successfully")
|
| 125 |
+
|
| 126 |
+
return True
|
| 127 |
+
|
| 128 |
+
except Exception as e:
|
| 129 |
+
print(f"β Training components test failed: {e}")
|
| 130 |
+
return False
|
| 131 |
+
|
| 132 |
+
def main():
|
| 133 |
+
"""Run all tests."""
|
| 134 |
+
print("π§ͺ Starting Dressify Training System Tests\n")
|
| 135 |
+
|
| 136 |
+
tests = [
|
| 137 |
+
("Imports", test_imports),
|
| 138 |
+
("Models", test_models),
|
| 139 |
+
("Dataset", test_dataset),
|
| 140 |
+
("Training Components", test_training_components),
|
| 141 |
+
]
|
| 142 |
+
|
| 143 |
+
results = []
|
| 144 |
+
for test_name, test_func in tests:
|
| 145 |
+
try:
|
| 146 |
+
result = test_func()
|
| 147 |
+
results.append((test_name, result))
|
| 148 |
+
except Exception as e:
|
| 149 |
+
print(f"οΏ½οΏ½ {test_name} test crashed: {e}")
|
| 150 |
+
results.append((test_name, False))
|
| 151 |
+
|
| 152 |
+
# Summary
|
| 153 |
+
print("\n" + "="*50)
|
| 154 |
+
print("π TEST RESULTS SUMMARY")
|
| 155 |
+
print("="*50)
|
| 156 |
+
|
| 157 |
+
passed = 0
|
| 158 |
+
total = len(results)
|
| 159 |
+
|
| 160 |
+
for test_name, result in results:
|
| 161 |
+
status = "β
PASS" if result else "β FAIL"
|
| 162 |
+
print(f"{test_name:20} {status}")
|
| 163 |
+
if result:
|
| 164 |
+
passed += 1
|
| 165 |
+
|
| 166 |
+
print("="*50)
|
| 167 |
+
print(f"Overall: {passed}/{total} tests passed")
|
| 168 |
+
|
| 169 |
+
if passed == total:
|
| 170 |
+
print("π All tests passed! System is ready for training.")
|
| 171 |
+
return True
|
| 172 |
+
else:
|
| 173 |
+
print("β οΈ Some tests failed. Please check the errors above.")
|
| 174 |
+
return False
|
| 175 |
+
|
| 176 |
+
if __name__ == "__main__":
|
| 177 |
+
success = main()
|
| 178 |
+
sys.exit(0 if success else 1)
|
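test_training.py is a standalone smoke test: main() returns True only when every check passes, and the script exits with code 0 or 1 accordingly. One way to drive it programmatically, mirroring how app.py shells out to the training scripts (the snippet is illustrative):

import subprocess
import sys

# Run the smoke test in a child process; returncode 0 means all checks passed.
result = subprocess.run([sys.executable, "test_training.py"], check=False)
if result.returncode == 0:
    print("System is ready for training.")
else:
    print("Some checks failed; see the test output above.")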
train_resnet.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
import os
|
| 2 |
import argparse
|
|
|
|
| 3 |
from typing import Tuple
|
| 4 |
|
| 5 |
import torch
|
|
@@ -7,6 +8,9 @@ import torch.nn as nn
|
|
| 7 |
import torch.optim as optim
|
| 8 |
from torch.utils.data import DataLoader
|
| 9 |
|
|
|
|
|
|
|
|
|
|
| 10 |
from data.polyvore import PolyvoreTripletDataset
|
| 11 |
from models.resnet_embedder import ResNetItemEmbedder
|
| 12 |
from utils.export import ensure_export_dir
|
|
@@ -15,7 +19,7 @@ import json
|
|
| 15 |
|
| 16 |
def parse_args() -> argparse.Namespace:
|
| 17 |
p = argparse.ArgumentParser()
|
| 18 |
-
p.add_argument("--data_root", type=str, default=os.getenv("POLYVORE_ROOT", "/
|
| 19 |
p.add_argument("--epochs", type=int, default=20)
|
| 20 |
p.add_argument("--batch_size", type=int, default=64)
|
| 21 |
p.add_argument("--lr", type=float, default=1e-3)
|
|
@@ -30,80 +34,119 @@ def main() -> None:
|
|
| 30 |
if device == "cuda":
|
| 31 |
torch.backends.cudnn.benchmark = True
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
# Ensure splits exist; if missing, prepare from official splits
|
| 34 |
splits_dir = os.path.join(args.data_root, "splits")
|
| 35 |
triplet_path = os.path.join(splits_dir, "train.json")
|
|
|
|
| 36 |
if not os.path.exists(triplet_path):
|
|
|
|
|
|
|
| 37 |
os.makedirs(splits_dir, exist_ok=True)
|
| 38 |
try:
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
dataset = PolyvoreTripletDataset(args.data_root, split="train")
|
| 63 |
-
|
| 64 |
-
loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=(device=="cuda"))
|
| 65 |
model = ResNetItemEmbedder(embedding_dim=args.embedding_dim).to(device)
|
| 66 |
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
|
| 67 |
criterion = nn.TripletMarginLoss(margin=0.2, p=2)
|
| 68 |
|
|
|
|
|
|
|
|
|
|
| 69 |
export_dir = ensure_export_dir(os.path.dirname(args.out) or "models/exports")
|
| 70 |
best_loss = float("inf")
|
| 71 |
history = []
|
|
|
|
|
|
|
|
|
|
| 72 |
for epoch in range(args.epochs):
|
| 73 |
model.train()
|
| 74 |
-
|
| 75 |
steps = 0
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
|
|
| 93 |
out_path = args.out
|
| 94 |
if not out_path.startswith("models/"):
|
| 95 |
out_path = os.path.join(export_dir, os.path.basename(args.out))
|
| 96 |
-
|
| 97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
history.append({"epoch": epoch + 1, "avg_triplet_loss": avg_loss})
|
|
|
|
| 99 |
if avg_loss < best_loss:
|
| 100 |
best_loss = avg_loss
|
| 101 |
-
|
|
|
|
|
|
|
| 102 |
|
| 103 |
-
#
|
| 104 |
metrics_path = os.path.join(export_dir, "resnet_metrics.json")
|
| 105 |
with open(metrics_path, "w") as f:
|
| 106 |
json.dump({"best_triplet_loss": best_loss, "history": history}, f)
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
|
| 109 |
if __name__ == "__main__":
|
|
|
|
| 1 |
import os
|
| 2 |
import argparse
|
| 3 |
+
import sys
|
| 4 |
from typing import Tuple
|
| 5 |
|
| 6 |
import torch
|
|
|
|
| 8 |
import torch.optim as optim
|
| 9 |
from torch.utils.data import DataLoader
|
| 10 |
|
| 11 |
+
# Fix import paths
|
| 12 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 13 |
+
|
| 14 |
from data.polyvore import PolyvoreTripletDataset
|
| 15 |
from models.resnet_embedder import ResNetItemEmbedder
|
| 16 |
from utils.export import ensure_export_dir
|
|
|
|
| 19 |
|
| 20 |
def parse_args() -> argparse.Namespace:
|
| 21 |
p = argparse.ArgumentParser()
|
| 22 |
+
p.add_argument("--data_root", type=str, default=os.getenv("POLYVORE_ROOT", "/home/user/app/data/Polyvore"))
|
| 23 |
p.add_argument("--epochs", type=int, default=20)
|
| 24 |
p.add_argument("--batch_size", type=int, default=64)
|
| 25 |
p.add_argument("--lr", type=float, default=1e-3)
|
|
|
|
| 34 |
if device == "cuda":
|
| 35 |
torch.backends.cudnn.benchmark = True
|
| 36 |
|
| 37 |
+
print(f"π Starting ResNet training on {device}")
|
| 38 |
+
print(f"π Data root: {args.data_root}")
|
| 39 |
+
print(f"βοΈ Config: {args.epochs} epochs, batch_size={args.batch_size}, lr={args.lr}")
|
| 40 |
+
|
| 41 |
# Ensure splits exist; if missing, prepare from official splits
|
| 42 |
splits_dir = os.path.join(args.data_root, "splits")
|
| 43 |
triplet_path = os.path.join(splits_dir, "train.json")
|
| 44 |
+
|
| 45 |
if not os.path.exists(triplet_path):
|
| 46 |
+
print(f"β οΈ Triplet file not found: {triplet_path}")
|
| 47 |
+
print("π§ Attempting to prepare dataset...")
|
| 48 |
os.makedirs(splits_dir, exist_ok=True)
|
| 49 |
try:
|
| 50 |
+
# Try to import and run the prepare script
|
| 51 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), "scripts"))
|
| 52 |
+
from prepare_polyvore import main as prepare_main
|
| 53 |
+
print("β
Successfully imported prepare_polyvore")
|
| 54 |
+
|
| 55 |
+
# Prepare dataset without random splits
|
| 56 |
+
prepare_main()
|
| 57 |
+
print("β
Dataset preparation completed")
|
| 58 |
+
except Exception as e:
|
| 59 |
+
print(f"β Failed to prepare dataset: {e}")
|
| 60 |
+
print("π‘ Please ensure the dataset is prepared manually")
|
| 61 |
+
return
|
| 62 |
+
else:
|
| 63 |
+
print(f"β
Found existing splits: {triplet_path}")
|
| 64 |
+
|
| 65 |
+
try:
|
| 66 |
+
dataset = PolyvoreTripletDataset(args.data_root, split="train")
|
| 67 |
+
print(f"π Dataset loaded: {len(dataset)} samples")
|
| 68 |
+
except Exception as e:
|
| 69 |
+
print(f"β Failed to load dataset: {e}")
|
| 70 |
+
return
|
| 71 |
+
|
| 72 |
+
loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=2, pin_memory=(device=="cuda"))
|
|
|
|
|
|
|
|
|
|
| 73 |
model = ResNetItemEmbedder(embedding_dim=args.embedding_dim).to(device)
|
| 74 |
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
|
| 75 |
criterion = nn.TripletMarginLoss(margin=0.2, p=2)
|
| 76 |
|
| 77 |
+
print(f"ποΈ Model created: {model.__class__.__name__}")
|
| 78 |
+
print(f"π Total parameters: {sum(p.numel() for p in model.parameters()):,}")
|
| 79 |
+
|
| 80 |
export_dir = ensure_export_dir(os.path.dirname(args.out) or "models/exports")
|
| 81 |
best_loss = float("inf")
|
| 82 |
history = []
|
| 83 |
+
|
| 84 |
+
print(f"πΎ Checkpoints will be saved to: {export_dir}")
|
| 85 |
+
|
| 86 |
for epoch in range(args.epochs):
|
| 87 |
model.train()
|
| 88 |
+
running_loss = 0.0
|
| 89 |
steps = 0
|
| 90 |
+
|
| 91 |
+
print(f"π Epoch {epoch+1}/{args.epochs}")
|
| 92 |
+
|
| 93 |
+
for batch_idx, batch in enumerate(loader):
|
| 94 |
+
try:
|
| 95 |
+
# Expect batch as (anchor, positive, negative)
|
| 96 |
+
anchor, positive, negative = batch
|
| 97 |
+
anchor = anchor.to(device, memory_format=torch.channels_last, non_blocking=True)
|
| 98 |
+
positive = positive.to(device, memory_format=torch.channels_last, non_blocking=True)
|
| 99 |
+
negative = negative.to(device, memory_format=torch.channels_last, non_blocking=True)
|
| 100 |
+
|
| 101 |
+
with torch.autocast(device_type=("cuda" if device=="cuda" else "cpu"), enabled=(device=="cuda")):
|
| 102 |
+
emb_a = model(anchor)
|
| 103 |
+
emb_p = model(positive)
|
| 104 |
+
emb_n = model(negative)
|
| 105 |
+
|
| 106 |
+
loss = criterion(emb_a, emb_p, emb_n)
|
| 107 |
+
optimizer.zero_grad(set_to_none=True)
|
| 108 |
+
loss.backward()
|
| 109 |
+
optimizer.step()
|
| 110 |
+
|
| 111 |
+
running_loss += loss.item()
|
| 112 |
+
steps += 1
|
| 113 |
+
|
| 114 |
+
if batch_idx % 100 == 0:
|
| 115 |
+
print(f" Batch {batch_idx}/{len(loader)}: loss={loss.item():.4f}")
|
| 116 |
+
|
| 117 |
+
except Exception as e:
|
| 118 |
+
print(f"β Error in batch {batch_idx}: {e}")
|
| 119 |
+
continue
|
| 120 |
+
|
| 121 |
+
avg_loss = running_loss / max(1, steps)
|
| 122 |
+
|
| 123 |
+
# Save checkpoint with better path handling
|
| 124 |
out_path = args.out
|
| 125 |
if not out_path.startswith("models/"):
|
| 126 |
out_path = os.path.join(export_dir, os.path.basename(args.out))
|
| 127 |
+
|
| 128 |
+
# Ensure the output directory exists
|
| 129 |
+
os.makedirs(os.path.dirname(out_path), exist_ok=True)
|
| 130 |
+
|
| 131 |
+
# Save checkpoint
|
| 132 |
+
torch.save({"state_dict": model.state_dict(), "epoch": epoch+1, "loss": avg_loss}, out_path)
|
| 133 |
+
print(f"β
Epoch {epoch+1}/{args.epochs} avg_triplet_loss={avg_loss:.4f} saved -> {out_path}")
|
| 134 |
+
|
| 135 |
history.append({"epoch": epoch + 1, "avg_triplet_loss": avg_loss})
|
| 136 |
+
|
| 137 |
if avg_loss < best_loss:
|
| 138 |
best_loss = avg_loss
|
| 139 |
+
best_path = os.path.join(export_dir, "resnet_item_embedder_best.pth")
|
| 140 |
+
torch.save({"state_dict": model.state_dict(), "epoch": epoch+1, "loss": avg_loss}, best_path)
|
| 141 |
+
print(f"π New best model saved: {best_path}")
|
| 142 |
|
| 143 |
+
# Write metrics
|
| 144 |
metrics_path = os.path.join(export_dir, "resnet_metrics.json")
|
| 145 |
with open(metrics_path, "w") as f:
|
| 146 |
json.dump({"best_triplet_loss": best_loss, "history": history}, f)
|
| 147 |
+
|
| 148 |
+
print(f"π Training completed! Best loss: {best_loss:.4f}")
|
| 149 |
+
print(f"π Metrics saved to: {metrics_path}")
|
| 150 |
|
| 151 |
|
| 152 |
if __name__ == "__main__":
|
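For reference, the checkpoints written above are plain torch.save dictionaries with "state_dict", "epoch", and "loss" keys. A minimal reload sketch (assuming the default export path and the 512-dimensional embedding described in the model card later in this commit; adjust embedding_dim to whatever was used at training time):

import torch
from models.resnet_embedder import ResNetItemEmbedder

# Assumed path: the best-checkpoint file written by the training loop above.
ckpt_path = "models/exports/resnet_item_embedder_best.pth"

ckpt = torch.load(ckpt_path, map_location="cpu")
model = ResNetItemEmbedder(embedding_dim=512)  # must match the training embedding_dim
model.load_state_dict(ckpt["state_dict"])
model.eval()
print(f"Loaded epoch {ckpt['epoch']} with triplet loss {ckpt['loss']:.4f}")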
train_vit_triplet.py
CHANGED
|
(Previous version, shown for context across the hunks @@ -1,5 +1,6 @@, @@ -7,6 +8,9 @@ import torch.nn as nn, @@ -16,7 +20,7 @@ import json, and @@ -45,128 +49,182 @@ def main() -> None: the same imports, parse_args() with a --data_root default truncated by the viewer, and a main() that built the PolyvoreOutfitTripletDataset, the collate helper, the DataLoader, the OutfitCompatibilityModel plus a frozen ResNetItemEmbedder, AdamW with weight_decay=5e-2, the cosine-distance TripletMarginWithDistanceLoss, the epoch loop, optional validation against outfit_triplets_valid.json, per-epoch logging into hist, best_loss tracking, and the vit_metrics.json write. Removed lines are collapsed or truncated by the diff viewer; the surviving context lines reappear in the new version below.)

New version:

import os
import argparse
import sys
from typing import List

import torch
# ... (unchanged import collapsed by the diff viewer) ...
import torch.optim as optim
from torch.utils.data import DataLoader

# Fix import paths
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from data.polyvore import PolyvoreOutfitTripletDataset
from models.vit_outfit import OutfitCompatibilityModel
from models.resnet_embedder import ResNetItemEmbedder
# ... (unchanged imports collapsed by the diff viewer) ...


def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser()
    p.add_argument("--data_root", type=str, default=os.getenv("POLYVORE_ROOT", "/home/user/app/data/Polyvore"))
    p.add_argument("--epochs", type=int, default=30)
    p.add_argument("--batch_size", type=int, default=32)
    p.add_argument("--lr", type=float, default=5e-4)
    # ... (unchanged lines collapsed by the diff viewer) ...

    if device == "cuda":
        torch.backends.cudnn.benchmark = True

    print(f"π Starting ViT Outfit training on {device}")
    print(f"π Data root: {args.data_root}")
    print(f"βοΈ Config: {args.epochs} epochs, batch_size={args.batch_size}, lr={args.lr}")

    # Ensure outfit triplets exist
    splits_dir = os.path.join(args.data_root, "splits")
    trip_path = os.path.join(splits_dir, "outfit_triplets_train.json")

    if not os.path.exists(trip_path):
        print(f"β οΈ Outfit triplet file not found: {trip_path}")
        print("π§ Attempting to prepare dataset...")
        os.makedirs(splits_dir, exist_ok=True)
        try:
            # Try to import and run the prepare script
            sys.path.append(os.path.join(os.path.dirname(__file__), "scripts"))
            from prepare_polyvore import main as prepare_main
            print("β Successfully imported prepare_polyvore")

            # Prepare dataset without random splits
            prepare_main()
            print("β Dataset preparation completed")
        except Exception as e:
            print(f"β Failed to prepare dataset: {e}")
            print("π‘ Please ensure the dataset is prepared manually")
            return
    else:
        print(f"β Found existing outfit triplets: {trip_path}")

    try:
        dataset = PolyvoreOutfitTripletDataset(args.data_root, split="train")
        print(f"π Dataset loaded: {len(dataset)} samples")
    except Exception as e:
        print(f"β Failed to load dataset: {e}")
        return

    def collate(batch):
        return batch  # variable length handled inside training loop

    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=2, pin_memory=(device=="cuda"), collate_fn=collate)

    model = OutfitCompatibilityModel(embedding_dim=args.embedding_dim).to(device)
    embedder = ResNetItemEmbedder(embedding_dim=args.embedding_dim).to(device).eval()
    for p in embedder.parameters():
        p.requires_grad_(False)

    print(f"ποΈ Models created:")
    print(f"  - ViT Outfit: {model.__class__.__name__}")
    print(f"  - ResNet Embedder: {embedder.__class__.__name__}")
    print(f"π Total parameters: {sum(p.numel() for p in model.parameters()):,}")

    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=5e-2)
    triplet = nn.TripletMarginWithDistanceLoss(distance_function=lambda x, y: 1 - nn.functional.cosine_similarity(x, y), margin=args.triplet_margin)

    export_dir = ensure_export_dir(os.path.dirname(args.export) or "models/exports")
    best_loss = float("inf")
    hist = []

    print(f"πΎ Checkpoints will be saved to: {export_dir}")

    for epoch in range(args.epochs):
        model.train()
        running_loss = 0.0
        steps = 0

        print(f"π Epoch {epoch+1}/{args.epochs}")

        for batch_idx, batch in enumerate(loader):
            try:
                # batch: List[(ga_imgs, gb_imgs, bd_imgs)]
                anchor_tokens = []
                positive_tokens = []
                negative_tokens = []

                for ga, gb, bd in batch:
                    ta = embed_outfit(ga, embedder, device)
                    tb = embed_outfit(gb, embedder, device)
                    tn = embed_outfit(bd, embedder, device)
                    anchor_tokens.append(ta.unsqueeze(0))
                    positive_tokens.append(tb.unsqueeze(0))
                    negative_tokens.append(tn.unsqueeze(0))

                A = torch.cat(anchor_tokens, dim=0)  # (B, N, D)
                P = torch.cat(positive_tokens, dim=0)
                N = torch.cat(negative_tokens, dim=0)

                # get outfit-level embeddings via ViT encoder pooled output
                with torch.autocast(device_type=("cuda" if device=="cuda" else "cpu"), enabled=(device=="cuda")):
                    ea = model.encoder(A).mean(dim=1)
                    ep = model.encoder(P).mean(dim=1)
                    en = model.encoder(N).mean(dim=1)
                    loss = triplet(ea, ep, en)

                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                steps += 1

                if batch_idx % 50 == 0:
                    print(f"  Batch {batch_idx}/{len(loader)}: loss={loss.item():.4f}")

            except Exception as e:
                print(f"β Error in batch {batch_idx}: {e}")
                continue

        avg_loss = running_loss / max(1, steps)

        # Simple validation using a subset of training data as a proxy if no val split here
        # For true 70/10/10, prepare_polyvore.py will create outfit_triplets_valid.json
        val_path = os.path.join(args.data_root, "splits", "outfit_triplets_valid.json")
        val_loss = None

        if os.path.exists(val_path) and (epoch + 1) % args.eval_every == 0:
            try:
                val_ds = PolyvoreOutfitTripletDataset(args.data_root, split="valid")
                val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, num_workers=2, collate_fn=lambda x: x)
                model.eval()
                losses = []

                with torch.no_grad():
                    for vbatch in val_loader:
                        anchor_tokens = []
                        positive_tokens = []
                        negative_tokens = []
                        for ga, gb, bd in vbatch:
                            ta = embed_outfit(ga, embedder, device)
                            tb = embed_outfit(gb, embedder, device)
                            tn = embed_outfit(bd, embedder, device)
                            anchor_tokens.append(ta.unsqueeze(0))
                            positive_tokens.append(tb.unsqueeze(0))
                            negative_tokens.append(tn.unsqueeze(0))
                        A = torch.cat(anchor_tokens, dim=0)
                        P = torch.cat(positive_tokens, dim=0)
                        N = torch.cat(negative_tokens, dim=0)
                        ea = model.encoder(A).mean(dim=1)
                        ep = model.encoder(P).mean(dim=1)
                        en = model.encoder(N).mean(dim=1)
                        l = triplet(ea, ep, en).item()
                        losses.append(l)

                val_loss = sum(losses) / max(1, len(losses))
                print(f"  π Validation loss: {val_loss:.4f}")

            except Exception as e:
                print(f"  β οΈ Validation failed: {e}")

        out_path = args.export
        if not out_path.startswith("models/"):
            out_path = os.path.join(export_dir, os.path.basename(args.export))

        # Save checkpoint
        torch.save({"state_dict": model.state_dict(), "epoch": epoch+1, "loss": avg_loss}, out_path)

        if val_loss is not None:
            print(f"β Epoch {epoch+1}/{args.epochs} triplet_loss={avg_loss:.4f} val_triplet_loss={val_loss:.4f} saved -> {out_path}")
            hist.append({"epoch": epoch + 1, "triplet_loss": float(avg_loss), "val_triplet_loss": float(val_loss)})
            if val_loss < best_loss:
                best_loss = val_loss
                best_path = os.path.join(export_dir, "vit_outfit_model_best.pth")
                torch.save({"state_dict": model.state_dict(), "epoch": epoch+1, "loss": avg_loss, "val_loss": val_loss}, best_path)
                print(f"π New best model saved: {best_path}")
        else:
            print(f"β Epoch {epoch+1}/{args.epochs} triplet_loss={avg_loss:.4f} saved -> {out_path}")
            hist.append({"epoch": epoch + 1, "triplet_loss": float(avg_loss)})

    # Write metrics
    metrics_path = os.path.join(export_dir, "vit_metrics.json")
    payload = {"best_val_triplet_loss": best_loss if best_loss != float("inf") else None, "history": hist}
    with open(metrics_path, "w") as f:
        json.dump(payload, f)

    print(f"π Training completed!")
    if best_loss != float("inf"):
        print(f"π Best validation loss: {best_loss:.4f}")
    print(f"π Metrics saved to: {metrics_path}")


if __name__ == "__main__":
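For reference, a minimal inference sketch built only from the pieces used above (the embed_outfit helper with the signature shown in the training loop, the same mean-pooling over model.encoder, and the default export path are assumptions; embedding_dim must match training):

import torch
import torch.nn.functional as F
from models.vit_outfit import OutfitCompatibilityModel

# Assumed checkpoint path written by the training loop above.
ckpt = torch.load("models/exports/vit_outfit_model_best.pth", map_location="cpu")
model = OutfitCompatibilityModel(embedding_dim=512)
model.load_state_dict(ckpt["state_dict"])
model.eval()

def outfit_embedding(tokens: torch.Tensor) -> torch.Tensor:
    # tokens: (1, N, D) item embeddings for one outfit, e.g. embed_outfit(imgs, embedder, device).unsqueeze(0)
    with torch.no_grad():
        return model.encoder(tokens).mean(dim=1)  # same pooling as the training loop

# Two outfits can then be compared with cosine similarity, mirroring the training distance:
# sim = F.cosine_similarity(outfit_embedding(tokens_a), outfit_embedding(tokens_b))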
training_monitor.py
ADDED
|
@@ -0,0 +1,132 @@

#!/usr/bin/env python3
"""
Simple training monitor for Dressify.
Shows real-time training progress and status.
"""

import os
import time
import json
from datetime import datetime
from typing import Dict, Any, Optional

class TrainingMonitor:
    """Monitor training progress and status."""

    def __init__(self, export_dir: str = "models/exports"):
        self.export_dir = export_dir
        self.status_file = os.path.join(export_dir, "training_status.json")
        self.start_time = None

    def start_training(self, model_name: str, config: Dict[str, Any]):
        """Start monitoring a training session."""
        self.start_time = datetime.now()
        status = {
            "model": model_name,
            "status": "training",
            "start_time": self.start_time.isoformat(),
            "config": config,
            "epochs_completed": 0,
            "current_loss": None,
            "best_loss": float("inf"),
            "last_update": datetime.now().isoformat()
        }
        self._save_status(status)
        print(f"π Started monitoring {model_name} training")

    def update_progress(self, epoch: int, loss: float, is_best: bool = False):
        """Update training progress."""
        if not self.start_time:
            return

        try:
            with open(self.status_file, 'r') as f:
                status = json.load(f)
        except:
            return

        status["epochs_completed"] = epoch
        status["current_loss"] = loss
        status["last_update"] = datetime.now().isoformat()

        if is_best:
            status["best_loss"] = min(status["best_loss"], loss)

        self._save_status(status)

    def complete_training(self, final_loss: float, total_epochs: int):
        """Mark training as completed."""
        try:
            with open(self.status_file, 'r') as f:
                status = json.load(f)
        except:
            return

        status["status"] = "completed"
        status["epochs_completed"] = total_epochs
        status["current_loss"] = final_loss
        status["best_loss"] = min(status["best_loss"], final_loss)
        status["completion_time"] = datetime.now().isoformat()
        status["duration"] = str(datetime.now() - self.start_time) if self.start_time else None

        self._save_status(status)
        print(f"β Training completed in {status['duration']}")

    def fail_training(self, error: str):
        """Mark training as failed."""
        try:
            with open(self.status_file, 'r') as f:
                status = json.load(f)
        except:
            return

        status["status"] = "failed"
        status["error"] = error
        status["failure_time"] = datetime.now().isoformat()

        self._save_status(status)
        print(f"β Training failed: {error}")

    def get_status(self) -> Optional[Dict[str, Any]]:
        """Get current training status."""
        try:
            with open(self.status_file, 'r') as f:
                return json.load(f)
        except:
            return None

    def _save_status(self, status: Dict[str, Any]):
        """Save status to file."""
        os.makedirs(self.export_dir, exist_ok=True)
        with open(self.status_file, 'w') as f:
            json.dump(status, f, indent=2)

    def print_status(self):
        """Print current training status."""
        status = self.get_status()
        if not status:
            print("π No training status available")
            return

        print(f"\nπ Training Status: {status['model']}")
        print(f"Status: {status['status']}")
        print(f"Started: {status['start_time']}")
        print(f"Epochs: {status['epochs_completed']}")
        print(f"Current Loss: {status['current_loss']:.4f}" if status['current_loss'] else "Current Loss: N/A")
        print(f"Best Loss: {status['best_loss']:.4f}" if status['best_loss'] != float("inf") else "Best Loss: N/A")
        print(f"Last Update: {status['last_update']}")

        if status['status'] == 'completed':
            print(f"Duration: {status['duration']}")
        elif status['status'] == 'failed':
            print(f"Error: {status['error']}")

def create_monitor() -> TrainingMonitor:
    """Create a training monitor instance."""
    export_dir = os.getenv("EXPORT_DIR", "models/exports")
    return TrainingMonitor(export_dir)

if __name__ == "__main__":
    # Test the monitor
    monitor = create_monitor()
    monitor.print_status()
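A minimal sketch of how a training script could wire this monitor into its epoch loop (run from the repo root so training_monitor is importable; the loop and loss values below are placeholders, not the real training code):

from training_monitor import create_monitor

monitor = create_monitor()
monitor.start_training("resnet_item_embedder", {"epochs": 3, "batch_size": 64, "lr": 1e-3})

best = float("inf")
try:
    for epoch in range(1, 4):
        avg_loss = 1.0 / epoch                      # placeholder for the real epoch loss
        monitor.update_progress(epoch, avg_loss, is_best=(avg_loss < best))
        best = min(best, avg_loss)
    monitor.complete_training(final_loss=avg_loss, total_epochs=3)
except Exception as exc:
    monitor.fail_training(str(exc))

monitor.print_status()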
utils/artifact_manager.py
ADDED
|
@@ -0,0 +1,417 @@

#!/usr/bin/env python3
"""
Comprehensive artifact manager for Dressify.
Handles packaging, downloading, and organizing all system artifacts.
"""

import os
import json
import shutil
import zipfile
import tarfile
from datetime import datetime
from typing import Dict, List, Any, Optional
from pathlib import Path

class ArtifactManager:
    """Manages all system artifacts for easy download and upload."""

    def __init__(self, base_dir: str = "/home/user/app"):
        self.base_dir = base_dir
        self.data_dir = os.path.join(base_dir, "data/Polyvore")
        self.splits_dir = os.path.join(self.data_dir, "splits")
        self.export_dir = os.getenv("EXPORT_DIR", "models/exports")

        # Default HF repositories - updated to use your specific repos
        self.default_repos = {
            "splits": "Stylique/Dressify-Helper",
            "models": "Stylique/dressify-models",
            "metadata": "Stylique/Dressify-Helper"
        }

        # Repository organization structure
        self.repo_structure = {
            "Stylique/dressify-models": {
                "description": "Dressify trained models and checkpoints",
                "files": {
                    "resnet_item_embedder_best.pth": "ResNet50 item embedder (best checkpoint)",
                    "vit_outfit_model_best.pth": "ViT outfit compatibility model (best checkpoint)",
                    "resnet_metrics.json": "ResNet training metrics and history",
                    "vit_metrics.json": "ViT training metrics and history",
                    "model_cards/": "Model documentation and cards"
                }
            },
            "Stylique/Dressify-Helper": {
                "description": "Dressify dataset splits, metadata, and helper files",
                "files": {
                    "splits/": "Dataset splits (train/valid/test)",
                    "metadata/": "Item metadata and outfit information",
                    "configs/": "Training configurations",
                    "packages/": "Pre-packaged downloads"
                }
            }
        }

    def get_artifact_summary(self) -> Dict[str, Any]:
        """Get comprehensive summary of all available artifacts."""
        summary = {
            "timestamp": datetime.now().isoformat(),
            "datasets": self._get_dataset_info(),
            "splits": self._get_splits_info(),
            "models": self._get_models_info(),
            "configs": self._get_configs_info(),
            "metadata": self._get_metadata_info(),
            "hf_repos": self.repo_structure,
            "total_size_mb": 0
        }

        # Calculate total size
        total_size = 0
        for category in summary.values():
            if isinstance(category, dict) and "size_mb" in category:
                total_size += category["size_mb"]
        summary["total_size_mb"] = round(total_size, 2)

        return summary

    def _get_dataset_info(self) -> Dict[str, Any]:
        """Get information about the Polyvore dataset."""
        info = {
            "status": "not_found",
            "size_mb": 0,
            "files": [],
            "images_count": 0
        }

        if os.path.exists(self.data_dir):
            info["status"] = "available"

            # Count images
            images_dir = os.path.join(self.data_dir, "images")
            if os.path.exists(images_dir):
                try:
                    image_files = [f for f in os.listdir(images_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))]
                    info["images_count"] = len(image_files)
                except:
                    pass

            # Calculate size
            try:
                total_size = sum(os.path.getsize(os.path.join(dirpath, filename))
                                 for dirpath, dirnames, filenames in os.walk(self.data_dir)
                                 for filename in filenames)
                info["size_mb"] = round(total_size / (1024 * 1024), 2)
            except:
                pass

            # List key files
            key_files = ["images.zip", "polyvore_item_metadata.json", "polyvore_outfit_titles.json", "categories.csv"]
            for file in key_files:
                file_path = os.path.join(self.data_dir, file)
                if os.path.exists(file_path):
                    info["files"].append({
                        "name": file,
                        "size_mb": round(os.path.getsize(file_path) / (1024 * 1024), 2),
                        "path": file_path
                    })

        return info

    def _get_splits_info(self) -> Dict[str, Any]:
        """Get information about dataset splits."""
        info = {
            "status": "not_found",
            "size_mb": 0,
            "files": [],
            "splits_available": []
        }

        if os.path.exists(self.splits_dir):
            info["status"] = "available"

            split_files = [
                "train.json", "valid.json", "test.json",
                "outfits_train.json", "outfits_valid.json", "outfits_test.json",
                "outfit_triplets_train.json", "outfit_triplets_valid.json", "outfit_triplets_test.json"
            ]

            total_size = 0
            for file in split_files:
                file_path = os.path.join(self.splits_dir, file)
                if os.path.exists(file_path):
                    size_mb = round(os.path.getsize(file_path) / (1024 * 1024), 2)
                    total_size += size_mb
                    info["files"].append({
                        "name": file,
                        "size_mb": size_mb,
                        "path": file_path
                    })
                    info["splits_available"].append(file.replace(".json", ""))

            info["size_mb"] = round(total_size, 2)

        return info

    def _get_models_info(self) -> Dict[str, Any]:
        """Get information about trained models."""
        info = {
            "status": "not_found",
            "size_mb": 0,
            "files": [],
            "models_available": []
        }

        if os.path.exists(self.export_dir):
            info["status"] = "available"

            model_files = [
                "resnet_item_embedder.pth", "resnet_item_embedder_best.pth",
                "vit_outfit_model.pth", "vit_outfit_model_best.pth",
                "resnet_metrics.json", "vit_metrics.json"
            ]

            total_size = 0
            for file in model_files:
                file_path = os.path.join(self.export_dir, file)
                if os.path.exists(file_path):
                    size_mb = round(os.path.getsize(file_path) / (1024 * 1024), 2)
                    total_size += size_mb
                    info["files"].append({
                        "name": file,
                        "size_mb": size_mb,
                        "path": file_path,
                        "type": "checkpoint" if file.endswith(".pth") else "metrics"
                    })
                    if file.endswith(".pth"):
                        info["models_available"].append(file.replace(".pth", ""))

            info["size_mb"] = round(total_size, 2)

        return info

    def _get_configs_info(self) -> Dict[str, Any]:
        """Get information about configuration files."""
        info = {
            "status": "not_found",
            "size_mb": 0,
            "files": []
        }

        config_files = [
            "resnet_config_custom.json", "vit_config_custom.json",
            "item.yaml", "outfit.yaml", "default.yaml"
        ]

        total_size = 0
        for file in config_files:
            # Check export dir first, then configs dir
            file_path = os.path.join(self.export_dir, file)
            if not os.path.exists(file_path):
                file_path = os.path.join("configs", file)

            if os.path.exists(file_path):
                size_mb = round(os.path.getsize(file_path) / (1024 * 1024), 2)
                total_size += size_mb
                info["files"].append({
                    "name": file,
                    "size_mb": size_mb,
                    "path": file_path
                })

        if info["files"]:
            info["status"] = "available"
            info["size_mb"] = round(total_size, 2)

        return info

    def _get_metadata_info(self) -> Dict[str, Any]:
        """Get information about metadata files."""
        info = {
            "status": "not_found",
            "size_mb": 0,
            "files": []
        }

        metadata_files = [
            "polyvore_item_metadata.json", "polyvore_outfit_titles.json", "categories.csv"
        ]

        total_size = 0
        for file in metadata_files:
            file_path = os.path.join(self.data_dir, file)
            if os.path.exists(file_path):
                size_mb = round(os.path.getsize(file_path) / (1024 * 1024), 2)
                total_size += size_mb
                info["files"].append({
                    "name": file,
                    "size_mb": size_mb,
                    "path": file_path
                })

        if info["files"]:
            info["status"] = "available"
            info["size_mb"] = round(total_size, 2)

        return info

    def create_download_package(self, package_type: str = "complete") -> str:
        """Create a downloadable package of artifacts."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        if package_type == "complete":
            # Complete package with everything
            package_name = f"dressify_complete_{timestamp}"
            package_path = os.path.join(self.export_dir, f"{package_name}.tar.gz")

            with tarfile.open(package_path, "w:gz") as tar:
                # Add splits
                if os.path.exists(self.splits_dir):
                    tar.add(self.splits_dir, arcname="splits")

                # Add models
                if os.path.exists(self.export_dir):
                    for file in os.listdir(self.export_dir):
                        if file.endswith((".pth", ".json", ".yaml")):
                            tar.add(os.path.join(self.export_dir, file), arcname=f"models/{file}")

                # Add metadata
                metadata_files = ["polyvore_item_metadata.json", "polyvore_outfit_titles.json", "categories.csv"]
                for file in metadata_files:
                    file_path = os.path.join(self.data_dir, file)
                    if os.path.exists(file_path):
                        tar.add(file_path, arcname=f"metadata/{file}")

                # Add configs
                configs_dir = "configs"
                if os.path.exists(configs_dir):
                    tar.add(configs_dir, arcname="configs")

        elif package_type == "splits_only":
            # Only splits (lightweight)
            package_name = f"dressify_splits_{timestamp}"
            package_path = os.path.join(self.export_dir, f"{package_name}.tar.gz")

            with tarfile.open(package_path, "w:gz") as tar:
                if os.path.exists(self.splits_dir):
                    tar.add(self.splits_dir, arcname="splits")

        elif package_type == "models_only":
            # Only trained models
            package_name = f"dressify_models_{timestamp}"
            package_path = os.path.join(self.export_dir, f"{package_name}.tar.gz")

            with tarfile.open(package_path, "w:gz") as tar:
                if os.path.exists(self.export_dir):
                    for file in os.listdir(self.export_dir):
                        if file.endswith((".pth", ".json")):
                            tar.add(os.path.join(self.export_dir, file), arcname=f"models/{file}")

        else:
            raise ValueError(f"Unknown package type: {package_type}")

        return package_path

    def get_downloadable_files(self) -> List[Dict[str, Any]]:
        """Get list of all downloadable files."""
        files = []

        # Add splits
        if os.path.exists(self.splits_dir):
            for file in os.listdir(self.splits_dir):
                if file.endswith(".json"):
                    file_path = os.path.join(self.splits_dir, file)
                    files.append({
                        "name": f"splits/{file}",
                        "size_mb": round(os.path.getsize(file_path) / (1024 * 1024), 2),
                        "path": file_path,
                        "category": "splits",
                        "description": f"Dataset split: {file.replace('.json', '')}"
                    })

        # Add models
        if os.path.exists(self.export_dir):
            for file in os.listdir(self.export_dir):
                if file.endswith((".pth", ".json")):
                    file_path = os.path.join(self.export_dir, file)
                    files.append({
                        "name": f"models/{file}",
                        "size_mb": round(os.path.getsize(file_path) / (1024 * 1024), 2),
                        "path": file_path,
                        "category": "models",
                        "description": "Trained model or metrics"
                    })

        # Add metadata
        metadata_files = ["polyvore_item_metadata.json", "polyvore_outfit_titles.json", "categories.csv"]
        for file in metadata_files:
            file_path = os.path.join(self.data_dir, file)
            if os.path.exists(file_path):
                files.append({
                    "name": f"metadata/{file}",
                    "size_mb": round(os.path.getsize(file_path) / (1024 * 1024), 2),
                    "path": file_path,
                    "category": "metadata",
                    "description": "Dataset metadata"
                })

        return files

    def create_hf_upload_plan(self) -> Dict[str, Any]:
        """Create a plan for uploading to HF Hub."""
        plan = {
            "Stylique/dressify-models": {
                "description": "Upload trained models and checkpoints",
                "files_to_upload": [],
                "estimated_size_mb": 0
            },
            "Stylique/Dressify-Helper": {
                "description": "Upload dataset splits and metadata",
                "files_to_upload": [],
                "estimated_size_mb": 0
            }
        }

        # Plan for models repo
        if os.path.exists(self.export_dir):
            for file in os.listdir(self.export_dir):
                if file.endswith((".pth", ".json")):
                    file_path = os.path.join(self.export_dir, file)
                    size_mb = round(os.path.getsize(file_path) / (1024 * 1024), 2)
                    plan["Stylique/dressify-models"]["files_to_upload"].append({
                        "name": file,
                        "path": file_path,
                        "size_mb": size_mb
                    })
                    plan["Stylique/dressify-models"]["estimated_size_mb"] += size_mb

        # Plan for helper repo
        if os.path.exists(self.splits_dir):
            for file in os.listdir(self.splits_dir):
                if file.endswith(".json"):
                    file_path = os.path.join(self.splits_dir, file)
                    size_mb = round(os.path.getsize(file_path) / (1024 * 1024), 2)
                    plan["Stylique/Dressify-Helper"]["files_to_upload"].append({
                        "name": f"splits/{file}",
                        "path": file_path,
                        "size_mb": size_mb
                    })
                    plan["Stylique/Dressify-Helper"]["estimated_size_mb"] += size_mb

        # Add metadata files
        metadata_files = ["polyvore_item_metadata.json", "polyvore_outfit_titles.json", "categories.csv"]
        for file in metadata_files:
            file_path = os.path.join(self.data_dir, file)
            if os.path.exists(file_path):
                size_mb = round(os.path.getsize(file_path) / (1024 * 1024), 2)
                plan["Stylique/Dressify-Helper"]["files_to_upload"].append({
                    "name": f"metadata/{file}",
                    "path": file_path,
                    "size_mb": size_mb
                })
                plan["Stylique/Dressify-Helper"]["estimated_size_mb"] += size_mb

        return plan

def create_artifact_manager() -> ArtifactManager:
    """Create an artifact manager instance."""
    return ArtifactManager()
utils/export.py
CHANGED
|
@@ -5,11 +5,19 @@ import torch
@@ -17,6 +25,7 @@ def export_torchscript(model: torch.nn.Module, example_inputs: torch.Tensor, out
@@ -32,6 +41,20 @@ def export_onnx(model: torch.nn.Module, example_inputs: torch.Tensor, out_path:

(Previous version, shown for context: ensure_export_dir, export_torchscript, and export_onnx without docstrings and without the new helpers; unchanged lines are collapsed by the diff viewer.)

New version:

def ensure_export_dir(path: str) -> str:
    """Create export directory and all parent directories if they don't exist."""
    os.makedirs(path, exist_ok=True)
    return path


def get_export_dir() -> str:
    """Get the default export directory, creating it if necessary."""
    export_dir = os.getenv("EXPORT_DIR", "models/exports")
    return ensure_export_dir(export_dir)


def export_torchscript(model: torch.nn.Module, example_inputs: torch.Tensor, out_path: str) -> str:
    """Export model to TorchScript format."""
    model.eval()
    traced = torch.jit.trace(model, example_inputs)
    torch.jit.save(traced, out_path)
    # ... (unchanged lines collapsed by the diff viewer) ...


def export_onnx(model: torch.nn.Module, example_inputs: torch.Tensor, out_path: str, opset: int = 17) -> str:
    """Export model to ONNX format."""
    model.eval()
    torch.onnx.export(
        model,
        # ... (unchanged lines collapsed by the diff viewer) ...
    )
    return out_path


def save_checkpoint(model: torch.nn.Module, path: str, **kwargs) -> str:
    """Save model checkpoint with metadata."""
    # Ensure directory exists
    os.makedirs(os.path.dirname(path), exist_ok=True)

    # Save checkpoint
    checkpoint = {
        "state_dict": model.state_dict(),
        **kwargs
    }
    torch.save(checkpoint, path)
    return path
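A minimal sketch of the new helpers in use (the tiny Linear model here is only a stand-in for a real Dressify model; EXPORT_DIR falls back to models/exports as above):

import torch
import torch.nn as nn
from utils.export import get_export_dir, save_checkpoint, export_torchscript

model = nn.Linear(8, 2)            # illustrative stand-in model
export_dir = get_export_dir()      # honours EXPORT_DIR, defaults to models/exports

# Checkpoint with arbitrary metadata merged into the saved dict
ckpt_path = save_checkpoint(model, f"{export_dir}/demo.pth", epoch=1, loss=0.123)

# TorchScript export needs an example input of the right shape
export_torchscript(model, torch.randn(1, 8), f"{export_dir}/demo.ts")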
utils/hf_hub_integration.py
ADDED
|
@@ -0,0 +1,413 @@

#!/usr/bin/env python3
"""
Hugging Face Hub integration for Dressify.
Handles uploading artifacts to specific HF repositories.
"""

import os
import json
import shutil
from datetime import datetime
from typing import Dict, List, Any, Optional
from huggingface_hub import HfApi, create_repo, upload_file, upload_folder
from pathlib import Path

class HFHubIntegration:
    """Integrates with Hugging Face Hub for artifact management."""

    def __init__(self, token: str = None):
        self.api = HfApi(token=token)
        self.token = token

        # Your specific repositories
        self.repos = {
            "models": "Stylique/dressify-models",
            "helper": "Stylique/Dressify-Helper"
        }

        # Repository descriptions and metadata
        self.repo_metadata = {
            "Stylique/dressify-models": {
                "description": "Dressify trained models and checkpoints for outfit recommendation",
                "tags": ["computer-vision", "fashion", "outfit-recommendation", "deep-learning"],
                "license": "mit",
                "language": "en"
            },
            "Stylique/Dressify-Helper": {
                "description": "Dressify dataset splits, metadata, and helper files",
                "tags": ["dataset", "fashion", "outfit-recommendation", "polyvore"],
                "license": "mit",
                "language": "en"
            }
        }

    def ensure_repos_exist(self) -> Dict[str, bool]:
        """Ensure all required repositories exist, create if they don't."""
        results = {}

        for repo_id in self.repos.values():
            try:
                # Try to get repo info
                repo_info = self.api.repo_info(repo_id)
                results[repo_id] = True
                print(f"β Repository exists: {repo_id}")
            except Exception:
                try:
                    # Create repository
                    if "models" in repo_id:
                        create_repo(
                            repo_id=repo_id,
                            repo_type="model",
                            token=self.token,
                            description=self.repo_metadata[repo_id]["description"],
                            license=self.repo_metadata[repo_id]["license"],
                            tags=self.repo_metadata[repo_id]["tags"]
                        )
                    else:
                        create_repo(
                            repo_id=repo_id,
                            repo_type="dataset",
                            token=self.token,
                            description=self.repo_metadata[repo_id]["description"],
                            license=self.repo_metadata[repo_id]["license"],
                            tags=self.repo_metadata[repo_id]["tags"]
                        )

                    results[repo_id] = True
                    print(f"β Created repository: {repo_id}")
                except Exception as e:
                    results[repo_id] = False
                    print(f"β Failed to create repository {repo_id}: {e}")

        return results

    def upload_models_to_hf(self, models_dir: str = None) -> Dict[str, Any]:
        """Upload trained models to the models repository."""
        if models_dir is None:
            models_dir = os.getenv("EXPORT_DIR", "models/exports")

        if not os.path.exists(models_dir):
            return {"success": False, "error": f"Models directory not found: {models_dir}"}

        try:
            print(f"π Uploading models to {self.repos['models']}...")

            # Files to upload
            model_files = [
                "resnet_item_embedder_best.pth",
                "vit_outfit_model_best.pth",
                "resnet_metrics.json",
                "vit_metrics.json"
            ]

            uploaded_files = []
            total_size = 0

            for file in model_files:
                file_path = os.path.join(models_dir, file)
                if os.path.exists(file_path):
                    try:
                        # Upload file
                        self.api.upload_file(
                            path_or_fileobj=file_path,
                            path_in_repo=file,
                            repo_id=self.repos['models'],
                            token=self.token
                        )

                        size_mb = round(os.path.getsize(file_path) / (1024 * 1024), 2)
                        total_size += size_mb
                        uploaded_files.append({
                            "name": file,
                            "size_mb": size_mb,
                            "status": "uploaded"
                        })

                        print(f"β Uploaded: {file} ({size_mb} MB)")

                    except Exception as e:
                        uploaded_files.append({
                            "name": file,
                            "status": "failed",
                            "error": str(e)
                        })
                        print(f"β Failed to upload {file}: {e}")

            # Create model card
            self._create_model_card()

            result = {
                "success": True,
                "repository": self.repos['models'],
                "uploaded_files": uploaded_files,
                "total_size_mb": total_size,
                "timestamp": datetime.now().isoformat()
            }

            print(f"π Models upload completed! Total size: {total_size} MB")
            return result

        except Exception as e:
            return {"success": False, "error": str(e)}

    def upload_splits_to_hf(self, splits_dir: str = None) -> Dict[str, Any]:
        """Upload dataset splits to the helper repository."""
        if splits_dir is None:
            splits_dir = os.path.join(os.getenv("POLYVORE_ROOT", "/home/user/app/data/Polyvore"), "splits")

        if not os.path.exists(splits_dir):
            return {"success": False, "error": f"Splits directory not found: {splits_dir}"}

        try:
            print(f"π Uploading splits to {self.repos['helper']}...")

            # Upload entire splits directory
            self.api.upload_folder(
                folder_path=splits_dir,
                path_in_repo="splits",
                repo_id=self.repos['helper'],
                token=self.token
            )

            # Calculate total size
            total_size = 0
            for root, dirs, files in os.walk(splits_dir):
                for file in files:
                    file_path = os.path.join(root, file)
                    total_size += os.path.getsize(file_path)

            total_size_mb = round(total_size / (1024 * 1024), 2)

            result = {
                "success": True,
                "repository": self.repos['helper'],
                "uploaded_folder": "splits",
                "total_size_mb": total_size_mb,
                "timestamp": datetime.now().isoformat()
            }

            print(f"π Splits upload completed! Total size: {total_size_mb} MB")
            return result

        except Exception as e:
            return {"success": False, "error": str(e)}

    def upload_metadata_to_hf(self, data_dir: str = None) -> Dict[str, Any]:
        """Upload metadata files to the helper repository."""
        if data_dir is None:
            data_dir = os.getenv("POLYVORE_ROOT", "/home/user/app/data/Polyvore")

        if not os.path.exists(data_dir):
            return {"success": False, "error": f"Data directory not found: {data_dir}"}

        try:
            print(f"π Uploading metadata to {self.repos['helper']}...")

            # Metadata files to upload
            metadata_files = [
                "polyvore_item_metadata.json",
                "polyvore_outfit_titles.json",
                "categories.csv"
            ]

            uploaded_files = []
            total_size = 0

            for file in metadata_files:
                file_path = os.path.join(data_dir, file)
                if os.path.exists(file_path):
                    try:
                        # Upload to metadata subfolder
                        self.api.upload_file(
                            path_or_fileobj=file_path,
                            path_in_repo=f"metadata/{file}",
                            repo_id=self.repos['helper'],
                            token=self.token
                        )

                        size_mb = round(os.path.getsize(file_path) / (1024 * 1024), 2)
                        total_size += size_mb
                        uploaded_files.append({
                            "name": file,
                            "size_mb": size_mb,
                            "status": "uploaded"
                        })

                        print(f"β Uploaded: {file} ({size_mb} MB)")

                    except Exception as e:
                        uploaded_files.append({
                            "name": file,
                            "status": "failed",
                            "error": str(e)
                        })
                        print(f"β Failed to upload {file}: {e}")

            result = {
                "success": True,
                "repository": self.repos['helper'],
                "uploaded_files": uploaded_files,
                "total_size_mb": total_size,
                "timestamp": datetime.now().isoformat()
            }

            print(f"π Metadata upload completed! Total size: {total_size} MB")
            return result

        except Exception as e:
            return {"success": False, "error": str(e)}

    def upload_everything_to_hf(self) -> Dict[str, Any]:
        """Upload all artifacts to HF Hub."""
        print("π Starting comprehensive upload to HF Hub...")

        # Ensure repositories exist
        repo_status = self.ensure_repos_exist()
        if not all(repo_status.values()):
            return {"success": False, "error": "Failed to ensure repositories exist"}

        # Upload everything
        results = {
            "models": self.upload_models_to_hf(),
            "splits": self.upload_splits_to_hf(),
            "metadata": self.upload_metadata_to_hf(),
            "timestamp": datetime.now().isoformat()
        }

        # Summary
        success_count = sum(1 for r in results.values() if isinstance(r, dict) and r.get("success", False))
        total_count = len([r for r in results.values() if isinstance(r, dict)])

        print(f"\nπ Upload Summary: {success_count}/{total_count} successful")
        for category, result in results.items():
            if isinstance(result, dict):
                status = "β" if result.get("success", False) else "β"
                print(f"  {status} {category}")

        return results

    def _create_model_card(self):
        """Create a model card for the models repository."""
        model_card_content = """---
language: en
license: mit
tags:
- computer-vision
- fashion
- outfit-recommendation
- deep-learning
- resnet
- vision-transformer
---

# Dressify Outfit Recommendation Models

This repository contains the trained models for the Dressify outfit recommendation system.

## Models

### ResNet Item Embedder
- **Architecture**: ResNet50 with custom projection head
- **Purpose**: Generate 512-dimensional embeddings for fashion items
- **Training**: Triplet loss with semi-hard negative mining
- **Input**: Fashion item images (224x224)
- **Output**: L2-normalized 512D embeddings

### ViT Outfit Compatibility Model
- **Architecture**: Vision Transformer encoder
- **Purpose**: Score outfit compatibility from item embeddings
- **Training**: Triplet loss with cosine distance
- **Input**: Variable-length sequence of item embeddings
- **Output**: Compatibility score (0-1)

## Usage

```python
from huggingface_hub import hf_hub_download
import torch
| 328 |
+
|
| 329 |
+
# Download models
|
| 330 |
+
resnet_path = hf_hub_download(repo_id="Stylique/dressify-models", filename="resnet_item_embedder_best.pth")
|
| 331 |
+
vit_path = hf_hub_download(repo_id="Stylique/dressify-models", filename="vit_outfit_model_best.pth")
|
| 332 |
+
|
| 333 |
+
# Load models
|
| 334 |
+
resnet_model = torch.load(resnet_path)
|
| 335 |
+
vit_model = torch.load(vit_path)
|
| 336 |
+
```
|
| 337 |
+
|
| 338 |
+
## Training Details
|
| 339 |
+
|
| 340 |
+
- **Dataset**: Polyvore Outfits (Stylique/Polyvore)
|
| 341 |
+
- **Loss**: Triplet margin loss
|
| 342 |
+
- **Optimizer**: AdamW
|
| 343 |
+
- **Mixed Precision**: Enabled
|
| 344 |
+
- **Hardware**: NVIDIA GPU with CUDA
|
| 345 |
+
|
| 346 |
+
## Performance
|
| 347 |
+
|
| 348 |
+
- **ResNet**: ~25M parameters, fast inference
|
| 349 |
+
- **ViT**: ~12M parameters, efficient outfit scoring
|
| 350 |
+
- **Memory**: Optimized for deployment on Hugging Face Spaces
|
| 351 |
+
|
| 352 |
+
## Citation
|
| 353 |
+
|
| 354 |
+
If you use these models in your research, please cite:
|
| 355 |
+
|
| 356 |
+
```bibtex
|
| 357 |
+
@misc{dressify2024,
|
| 358 |
+
title={Dressify: Deep Learning for Fashion Outfit Recommendation},
|
| 359 |
+
author={Stylique},
|
| 360 |
+
year={2024},
|
| 361 |
+
url={https://huggingface.co/Stylique/dressify-models}
|
| 362 |
+
}
|
| 363 |
+
```
|
| 364 |
+
"""
|
| 365 |
+
|
| 366 |
+
# Save model card
|
| 367 |
+
model_card_path = "model_card.md"
|
| 368 |
+
with open(model_card_path, 'w') as f:
|
| 369 |
+
f.write(model_card_content)
|
| 370 |
+
|
| 371 |
+
# Upload model card
|
| 372 |
+
try:
|
| 373 |
+
self.api.upload_file(
|
| 374 |
+
path_or_fileobj=model_card_path,
|
| 375 |
+
path_in_repo="README.md",
|
| 376 |
+
repo_id=self.repos['models'],
|
| 377 |
+
token=self.token
|
| 378 |
+
)
|
| 379 |
+
print("β
Model card uploaded")
|
| 380 |
+
|
| 381 |
+
# Clean up
|
| 382 |
+
os.remove(model_card_path)
|
| 383 |
+
except Exception as e:
|
| 384 |
+
print(f"β οΈ Failed to upload model card: {e}")
|
| 385 |
+
|
| 386 |
+
def get_upload_status(self) -> Dict[str, Any]:
|
| 387 |
+
"""Get current upload status and repository information."""
|
| 388 |
+
status = {
|
| 389 |
+
"repositories": {},
|
| 390 |
+
"last_upload": None,
|
| 391 |
+
"total_uploads": 0
|
| 392 |
+
}
|
| 393 |
+
|
| 394 |
+
for repo_id in self.repos.values():
|
| 395 |
+
try:
|
| 396 |
+
repo_info = self.api.repo_info(repo_id)
|
| 397 |
+
status["repositories"][repo_id] = {
|
| 398 |
+
"exists": True,
|
| 399 |
+
"last_modified": repo_info.last_modified.isoformat() if repo_info.last_modified else None,
|
| 400 |
+
"size": repo_info.size_on_disk if hasattr(repo_info, 'size_on_disk') else None
|
| 401 |
+
}
|
| 402 |
+
except Exception:
|
| 403 |
+
status["repositories"][repo_id] = {
|
| 404 |
+
"exists": False,
|
| 405 |
+
"last_modified": None,
|
| 406 |
+
"size": None
|
| 407 |
+
}
|
| 408 |
+
|
| 409 |
+
return status
|
| 410 |
+
|
| 411 |
+
def create_hf_integration(token: str = None) -> HFHubIntegration:
|
| 412 |
+
"""Create an HF Hub integration instance."""
|
| 413 |
+
return HFHubIntegration(token=token)
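Editor's note: a minimal sketch (not part of this commit) of how the uploader added above might be driven from a one-off script. The `HF_TOKEN` environment variable name is an assumption; only `create_hf_integration`, `upload_everything_to_hf`, and `get_upload_status` come from the module itself.

```python
import json
import os

from utils.hf_hub_integration import create_hf_integration

# Token source is an assumption; the Space may read a different variable.
integration = create_hf_integration(token=os.getenv("HF_TOKEN"))

# Pushes models, splits, and metadata to the configured repositories.
results = integration.upload_everything_to_hf()
print(json.dumps(integration.get_upload_status(), indent=2))
```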
utils/runtime_fetcher.py
ADDED
@@ -0,0 +1,312 @@
#!/usr/bin/env python3
"""
Runtime artifact fetcher for Dressify.
Downloads pre-processed artifacts from Hugging Face Hub to avoid reprocessing.
"""

import os
import json
import shutil
import tarfile
import zipfile
from pathlib import Path
from typing import Dict, List, Any, Optional
from huggingface_hub import hf_hub_download, snapshot_download

class RuntimeArtifactFetcher:
    """Fetches artifacts from HF Hub at runtime to avoid reprocessing."""

    def __init__(self, base_dir: str = "/home/user/app"):
        self.base_dir = base_dir
        self.data_dir = os.path.join(base_dir, "data/Polyvore")
        self.splits_dir = os.path.join(self.data_dir, "splits")
        self.export_dir = os.getenv("EXPORT_DIR", "models/exports")

        # Default HF repositories - updated to use your specific repos
        self.default_repos = {
            "splits": "Stylique/Dressify-Helper",
            "models": "Stylique/dressify-models",
            "metadata": "Stylique/Dressify-Helper"
        }

    def check_artifacts_needed(self) -> Dict[str, Any]:
        """Check what artifacts need to be fetched."""
        needs = {
            "splits": False,
            "models": False,
            "metadata": False,
            "total_size_mb": 0
        }

        # Check splits
        if not os.path.exists(self.splits_dir) or not self._has_complete_splits():
            needs["splits"] = True
            needs["total_size_mb"] += 50  # Estimate splits size

        # Check models
        if not os.path.exists(self.export_dir) or not self._has_trained_models():
            needs["models"] = True
            needs["total_size_mb"] += 200  # Estimate models size

        # Check metadata
        if not self._has_complete_metadata():
            needs["metadata"] = True
            needs["total_size_mb"] += 100  # Estimate metadata size

        return needs

    def _has_complete_splits(self) -> bool:
        """Check if complete splits are available."""
        required_files = [
            "train.json", "valid.json", "test.json",
            "outfit_triplets_train.json", "outfit_triplets_valid.json", "outfit_triplets_test.json"
        ]

        for file in required_files:
            if not os.path.exists(os.path.join(self.splits_dir, file)):
                return False
        return True

    def _has_trained_models(self) -> bool:
        """Check if trained models are available."""
        required_files = [
            "resnet_item_embedder_best.pth",
            "vit_outfit_model_best.pth"
        ]

        for file in required_files:
            if not os.path.exists(os.path.join(self.export_dir, file)):
                return False
        return True

    def _has_complete_metadata(self) -> bool:
        """Check if complete metadata is available."""
        required_files = [
            "polyvore_item_metadata.json",
            "polyvore_outfit_titles.json",
            "categories.csv"
        ]

        for file in required_files:
            if not os.path.exists(os.path.join(self.data_dir, file)):
                return False
        return True

    def fetch_splits_from_hf(self, repo: str = None, token: str = None) -> bool:
        """Fetch dataset splits from HF Hub."""
        if repo is None:
            repo = self.default_repos["splits"]

        try:
            print(f"π Fetching splits from {repo}...")

            # Create splits directory
            os.makedirs(self.splits_dir, exist_ok=True)

            # Download splits files
            split_files = [
                "train.json", "valid.json", "test.json",
                "outfits_train.json", "outfits_valid.json", "outfits_test.json",
                "outfit_triplets_train.json", "outfit_triplets_valid.json", "outfit_triplets_test.json"
            ]

            for file in split_files:
                try:
                    local_path = hf_hub_download(
                        repo_id=repo,
                        filename=f"splits/{file}",
                        local_dir=self.splits_dir,
                        token=token
                    )
                    print(f"✅ Downloaded: {file}")
                except Exception as e:
                    print(f"⚠️ Failed to download {file}: {e}")

            print(f"✅ Splits fetched successfully to {self.splits_dir}")
            return True

        except Exception as e:
            print(f"❌ Failed to fetch splits: {e}")
            return False

    def fetch_models_from_hf(self, repo: str = None, token: str = None) -> bool:
        """Fetch trained models from HF Hub."""
        if repo is None:
            repo = self.default_repos["models"]

        try:
            print(f"π Fetching models from {repo}...")

            # Create export directory
            os.makedirs(self.export_dir, exist_ok=True)

            # Download model files
            model_files = [
                "resnet_item_embedder_best.pth",
                "vit_outfit_model_best.pth",
                "resnet_metrics.json",
                "vit_metrics.json"
            ]

            for file in model_files:
                try:
                    local_path = hf_hub_download(
                        repo_id=repo,
                        filename=file,
                        local_dir=self.export_dir,
                        token=token
                    )
                    print(f"✅ Downloaded: {file}")
                except Exception as e:
                    print(f"⚠️ Failed to download {file}: {e}")

            print(f"✅ Models fetched successfully to {self.export_dir}")
            return True

        except Exception as e:
            print(f"❌ Failed to fetch models: {e}")
            return False

    def fetch_metadata_from_hf(self, repo: str = None, token: str = None) -> bool:
        """Fetch metadata from HF Hub."""
        if repo is None:
            repo = self.default_repos["metadata"]

        try:
            print(f"π Fetching metadata from {repo}...")

            # Create data directory
            os.makedirs(self.data_dir, exist_ok=True)

            # Download metadata files
            metadata_files = [
                "polyvore_item_metadata.json",
                "polyvore_outfit_titles.json",
                "categories.csv"
            ]

            for file in metadata_files:
                try:
                    local_path = hf_hub_download(
                        repo_id=repo,
                        filename=f"metadata/{file}",
                        local_dir=self.data_dir,
                        token=token
                    )
                    print(f"✅ Downloaded: {file}")
                except Exception as e:
                    print(f"⚠️ Failed to download {file}: {e}")

            print(f"✅ Metadata fetched successfully to {self.data_dir}")
            return True

        except Exception as e:
            print(f"❌ Failed to fetch metadata: {e}")
            return False

    def fetch_everything_from_hf(self, splits_repo: str = None, models_repo: str = None,
                                 metadata_repo: str = None, token: str = None) -> Dict[str, bool]:
        """Fetch all artifacts from HF Hub."""
        results = {}

        print("π Starting comprehensive artifact fetch from HF Hub...")

        # Fetch splits
        results["splits"] = self.fetch_splits_from_hf(splits_repo, token)

        # Fetch models
        results["models"] = self.fetch_models_from_hf(models_repo, token)

        # Fetch metadata
        results["metadata"] = self.fetch_metadata_from_hf(metadata_repo, token)

        # Summary
        success_count = sum(results.values())
        total_count = len(results)

        print(f"\nπ Fetch Summary: {success_count}/{total_count} successful")
        for artifact, success in results.items():
            status = "✅" if success else "❌"
            print(f"  {status} {artifact}")

        return results

    def download_and_extract_package(self, package_path: str, extract_to: str = None) -> bool:
        """Download and extract a package from HF Hub."""
        try:
            if extract_to is None:
                extract_to = self.base_dir

            print(f"π Downloading and extracting package: {package_path}")

            # Download the package
            local_path = hf_hub_download(
                repo_id="Stylique/Dressify-Helper",
                filename=f"packages/{os.path.basename(package_path)}",
                local_dir=extract_to,
                token=None
            )

            # Extract based on file type
            if package_path.endswith(".tar.gz"):
                with tarfile.open(local_path, 'r:gz') as tar:
                    tar.extractall(extract_to)
            elif package_path.endswith(".zip"):
                with zipfile.ZipFile(local_path, 'r') as zipf:
                    zipf.extractall(extract_to)

            print(f"✅ Package extracted to {extract_to}")
            return True

        except Exception as e:
            print(f"❌ Failed to download/extract package: {e}")
            return False

    def get_fetch_status(self) -> Dict[str, Any]:
        """Get current fetch status."""
        return {
            "splits_available": self._has_complete_splits(),
            "models_available": self._has_trained_models(),
            "metadata_available": self._has_complete_metadata(),
            "artifacts_needed": self.check_artifacts_needed(),
            "base_dir": self.base_dir,
            "splits_dir": self.splits_dir,
            "export_dir": self.export_dir,
            "hf_repos": self.default_repos
        }

def create_runtime_fetcher() -> RuntimeArtifactFetcher:
    """Create a runtime fetcher instance."""
    return RuntimeArtifactFetcher()

def auto_fetch_if_needed(token: str = None) -> Dict[str, bool]:
    """Automatically fetch artifacts if they're needed."""
    fetcher = create_runtime_fetcher()

    # Check what's needed
    needs = fetcher.check_artifacts_needed()

    if not any([needs["splits"], needs["models"], needs["metadata"]]):
        print("✅ All artifacts are already available - no fetching needed")
        return {"splits": True, "models": True, "metadata": True}

    print(f"π Auto-fetching needed artifacts (estimated size: {needs['total_size_mb']} MB)")

    # Fetch what's needed
    results = {}
    if needs["splits"]:
        results["splits"] = fetcher.fetch_splits_from_hf(token=token)

    if needs["models"]:
        results["models"] = fetcher.fetch_models_from_hf(token=token)

    if needs["metadata"]:
        results["metadata"] = fetcher.fetch_metadata_from_hf(token=token)

    return results

if __name__ == "__main__":
    # Test the fetcher
    fetcher = create_runtime_fetcher()
    status = fetcher.get_fetch_status()
    print("Current fetch status:", json.dumps(status, indent=2))
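Editor's note: a minimal sketch (not part of this commit) of how a Space entrypoint might call the fetcher before loading models. The `HF_TOKEN` variable name and the fallback behaviour are assumptions; `auto_fetch_if_needed`, `create_runtime_fetcher`, and `get_fetch_status` come from the module above.

```python
import os

from utils.runtime_fetcher import auto_fetch_if_needed, create_runtime_fetcher

# Pull any missing splits/models/metadata from HF Hub before the app starts.
results = auto_fetch_if_needed(token=os.getenv("HF_TOKEN"))
if not all(results.values()):
    # Assumed fallback: report the gap and let local preprocessing/training fill it.
    print("Some artifacts could not be fetched:", results)

status = create_runtime_fetcher().get_fetch_status()
print("splits:", status["splits_available"], "models:", status["models_available"])
```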