Spaces:
Paused
Paused
Ali Mohsin
commited on
Commit
Β·
941ea8d
1
Parent(s):
1216fc5
final commit 5000
Browse files- app.py +12 -8
- inference.py +267 -60
- train_vit_triplet.py +22 -3
app.py
CHANGED
|
@@ -278,9 +278,10 @@ def _background_bootstrap():
|
|
| 278 |
if not os.path.exists(vit_ckpt):
|
| 279 |
BOOT_STATUS = "training-vit"
|
| 280 |
subprocess.run([
|
| 281 |
-
"python", "train_vit_triplet.py", "--data_root", ds_root, "--epochs", "
|
| 282 |
-
"--batch_size", "4", "--lr", "5e-4", "--early_stopping_patience", "
|
| 283 |
-
"--
|
|
|
|
| 284 |
], check=False)
|
| 285 |
service.reload_models()
|
| 286 |
BOOT_STATUS = "ready"
|
|
@@ -389,7 +390,7 @@ def _stitch_strip(imgs: List[Image.Image], height: int = 256, pad: int = 6, bg=(
|
|
| 389 |
return out
|
| 390 |
|
| 391 |
|
| 392 |
-
def gradio_recommend(files: List[str], occasion: str, weather: str, num_outfits: int):
|
| 393 |
# Check model status first
|
| 394 |
model_status = service.get_model_status()
|
| 395 |
if not model_status["can_recommend"]:
|
|
@@ -415,7 +416,7 @@ def gradio_recommend(files: List[str], occasion: str, weather: str, num_outfits:
|
|
| 415 |
{"id": f"item_{i}", "image": images[i], "category": None}
|
| 416 |
for i in range(len(images))
|
| 417 |
]
|
| 418 |
-
res = service.compose_outfits(items, context={"occasion": occasion, "weather": weather, "num_outfits": int(num_outfits)})
|
| 419 |
|
| 420 |
# Check if compose_outfits returned an error
|
| 421 |
if res and isinstance(res[0], dict) and "error" in res[0]:
|
|
@@ -710,8 +711,9 @@ def start_training_simple(dataset_size: str, res_epochs: int, vit_epochs: int):
|
|
| 710 |
log_message += f"\nπ Starting ViT training on {dataset_size} samples...\n"
|
| 711 |
vit_result = subprocess.run([
|
| 712 |
"python", "train_vit_triplet.py", "--data_root", DATASET_ROOT, "--epochs", str(vit_epochs),
|
| 713 |
-
"--batch_size", "4", "--lr", "5e-4", "--early_stopping_patience", "
|
| 714 |
-
"--
|
|
|
|
| 715 |
] + dataset_args, capture_output=True, text=True, check=False)
|
| 716 |
|
| 717 |
if vit_result.returncode == 0:
|
|
@@ -770,11 +772,13 @@ with gr.Blocks(fill_height=True, title="Dressify - Advanced Outfit Recommendatio
|
|
| 770 |
with gr.Row():
|
| 771 |
occasion = gr.Dropdown(choices=["casual", "business", "formal", "sport"], value="casual", label="Occasion")
|
| 772 |
weather = gr.Dropdown(choices=["any", "hot", "mild", "cold", "rain"], value="any", label="Weather")
|
|
|
|
|
|
|
| 773 |
num_outfits = gr.Slider(minimum=1, maximum=8, step=1, value=3, label="Number of outfits")
|
| 774 |
out_gallery = gr.Gallery(label="Recommended Outfits", columns=1, height=320)
|
| 775 |
out_json = gr.JSON(label="Outfit Details")
|
| 776 |
btn2 = gr.Button("Generate Outfits", variant="primary")
|
| 777 |
-
btn2.click(fn=gradio_recommend, inputs=[inp2, occasion, weather, num_outfits], outputs=[out_gallery, out_json])
|
| 778 |
|
| 779 |
with gr.Tab("π¬ Advanced Training"):
|
| 780 |
gr.Markdown("### π― Comprehensive Training Parameter Control\nCustomize every aspect of model training for research and experimentation.")
|
|
|
|
| 278 |
if not os.path.exists(vit_ckpt):
|
| 279 |
BOOT_STATUS = "training-vit"
|
| 280 |
subprocess.run([
|
| 281 |
+
"python", "train_vit_triplet.py", "--data_root", ds_root, "--epochs", "10",
|
| 282 |
+
"--batch_size", "4", "--lr", "5e-4", "--early_stopping_patience", "5",
|
| 283 |
+
"--max_samples", "5000", "--triplet_margin", "0.5", "--gradient_clip", "1.0",
|
| 284 |
+
"--warmup_epochs", "2", "--export", os.path.join(export_dir, "vit_outfit_model.pth")
|
| 285 |
], check=False)
|
| 286 |
service.reload_models()
|
| 287 |
BOOT_STATUS = "ready"
|
|
|
|
| 390 |
return out
|
| 391 |
|
| 392 |
|
| 393 |
+
def gradio_recommend(files: List[str], occasion: str, weather: str, num_outfits: int, outfit_style: str = "casual"):
|
| 394 |
# Check model status first
|
| 395 |
model_status = service.get_model_status()
|
| 396 |
if not model_status["can_recommend"]:
|
|
|
|
| 416 |
{"id": f"item_{i}", "image": images[i], "category": None}
|
| 417 |
for i in range(len(images))
|
| 418 |
]
|
| 419 |
+
res = service.compose_outfits(items, context={"occasion": occasion, "weather": weather, "num_outfits": int(num_outfits), "outfit_style": outfit_style})
|
| 420 |
|
| 421 |
# Check if compose_outfits returned an error
|
| 422 |
if res and isinstance(res[0], dict) and "error" in res[0]:
|
|
|
|
| 711 |
log_message += f"\nπ Starting ViT training on {dataset_size} samples...\n"
|
| 712 |
vit_result = subprocess.run([
|
| 713 |
"python", "train_vit_triplet.py", "--data_root", DATASET_ROOT, "--epochs", str(vit_epochs),
|
| 714 |
+
"--batch_size", "4", "--lr", "5e-4", "--early_stopping_patience", "5",
|
| 715 |
+
"--max_samples", "5000", "--triplet_margin", "0.5", "--gradient_clip", "1.0",
|
| 716 |
+
"--warmup_epochs", "2", "--export", os.path.join(export_dir, "vit_outfit_model.pth")
|
| 717 |
] + dataset_args, capture_output=True, text=True, check=False)
|
| 718 |
|
| 719 |
if vit_result.returncode == 0:
|
|
|
|
| 772 |
with gr.Row():
|
| 773 |
occasion = gr.Dropdown(choices=["casual", "business", "formal", "sport"], value="casual", label="Occasion")
|
| 774 |
weather = gr.Dropdown(choices=["any", "hot", "mild", "cold", "rain"], value="any", label="Weather")
|
| 775 |
+
outfit_style = gr.Dropdown(choices=["casual", "smart_casual", "formal", "sporty"], value="casual", label="Outfit Style")
|
| 776 |
+
with gr.Row():
|
| 777 |
num_outfits = gr.Slider(minimum=1, maximum=8, step=1, value=3, label="Number of outfits")
|
| 778 |
out_gallery = gr.Gallery(label="Recommended Outfits", columns=1, height=320)
|
| 779 |
out_json = gr.JSON(label="Outfit Details")
|
| 780 |
btn2 = gr.Button("Generate Outfits", variant="primary")
|
| 781 |
+
btn2.click(fn=gradio_recommend, inputs=[inp2, occasion, weather, num_outfits, outfit_style], outputs=[out_gallery, out_json])
|
| 782 |
|
| 783 |
with gr.Tab("π¬ Advanced Training"):
|
| 784 |
gr.Markdown("### π― Comprehensive Training Parameter Control\nCustomize every aspect of model training for research and experimentation.")
|
inference.py
CHANGED
|
@@ -233,25 +233,148 @@ class InferenceService:
|
|
| 233 |
if len(proc_items) < 2:
|
| 234 |
return []
|
| 235 |
|
| 236 |
-
# 2) Candidate generation
|
| 237 |
rng = np.random.default_rng(int(context.get("seed", 42)))
|
| 238 |
num_outfits = int(context.get("num_outfits", 3))
|
| 239 |
min_size, max_size = 4, 6
|
| 240 |
ids = list(range(len(proc_items)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
|
| 242 |
# Enhanced category-aware pools with diversity checks
|
| 243 |
def cat_str(i: int) -> str:
|
| 244 |
return (proc_items[i].get("category") or "").lower()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
|
| 246 |
def get_category_type(cat: str) -> str:
|
| 247 |
-
"""Map category to outfit slot type"""
|
| 248 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 249 |
return "upper"
|
| 250 |
-
elif any(k in
|
| 251 |
return "bottom"
|
| 252 |
-
elif any(k in
|
| 253 |
return "shoe"
|
| 254 |
-
elif any(k in
|
| 255 |
return "accessory"
|
| 256 |
else:
|
| 257 |
return "other"
|
|
@@ -275,52 +398,50 @@ class InferenceService:
|
|
| 275 |
|
| 276 |
for _ in range(num_samples):
|
| 277 |
subset = []
|
| 278 |
-
used_categories = set()
|
| 279 |
|
| 280 |
-
#
|
| 281 |
-
if uppers:
|
|
|
|
| 282 |
subset.append(int(rng.choice(uppers)))
|
| 283 |
-
used_categories.add("upper")
|
| 284 |
-
|
| 285 |
-
if bottoms:
|
| 286 |
subset.append(int(rng.choice(bottoms)))
|
| 287 |
-
used_categories.add("bottom")
|
| 288 |
-
|
| 289 |
-
if shoes:
|
| 290 |
subset.append(int(rng.choice(shoes)))
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
if
|
| 318 |
-
|
| 319 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 320 |
|
| 321 |
-
# Remove duplicates and
|
| 322 |
subset = list(set(subset))
|
| 323 |
-
if len(subset) >=
|
| 324 |
candidates.append(subset)
|
| 325 |
|
| 326 |
# 3) Score using ViT
|
|
@@ -334,46 +455,132 @@ class InferenceService:
|
|
| 334 |
s = self.vit.score_compatibility(embs).item()
|
| 335 |
return float(s)
|
| 336 |
|
| 337 |
-
#
|
| 338 |
def is_valid_outfit(subset: List[int]) -> bool:
|
| 339 |
-
"""Check if outfit
|
| 340 |
categories = [get_category_type(cat_str(i)) for i in subset]
|
| 341 |
-
# Allow multiple accessories, but only one of each other category
|
| 342 |
category_counts = {}
|
|
|
|
| 343 |
for cat in categories:
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 349 |
return True
|
| 350 |
|
| 351 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 352 |
valid_candidates = [subset for subset in candidates if is_valid_outfit(subset)]
|
| 353 |
if not valid_candidates:
|
| 354 |
# Fallback: use all candidates if no valid ones found
|
| 355 |
valid_candidates = candidates
|
| 356 |
|
| 357 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 358 |
scored.sort(key=lambda x: x[1], reverse=True)
|
| 359 |
topk = scored[:num_outfits]
|
| 360 |
|
| 361 |
results = []
|
| 362 |
-
for subset,
|
| 363 |
# Double-check validity and get item details
|
| 364 |
outfit_items = []
|
| 365 |
for i in subset:
|
| 366 |
item = proc_items[i]
|
| 367 |
outfit_items.append({
|
| 368 |
"id": item["id"],
|
| 369 |
-
"category": item.get("category", "unknown")
|
|
|
|
| 370 |
})
|
| 371 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 372 |
results.append({
|
| 373 |
"item_ids": [item["id"] for item in outfit_items],
|
| 374 |
"items": outfit_items,
|
| 375 |
-
"score": float(
|
| 376 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 377 |
})
|
| 378 |
|
| 379 |
return results
|
|
|
|
| 233 |
if len(proc_items) < 2:
|
| 234 |
return []
|
| 235 |
|
| 236 |
+
# 2) Candidate generation with outfit templates
|
| 237 |
rng = np.random.default_rng(int(context.get("seed", 42)))
|
| 238 |
num_outfits = int(context.get("num_outfits", 3))
|
| 239 |
min_size, max_size = 4, 6
|
| 240 |
ids = list(range(len(proc_items)))
|
| 241 |
+
|
| 242 |
+
# Outfit templates for cohesive styling
|
| 243 |
+
outfit_templates = {
|
| 244 |
+
"casual": {
|
| 245 |
+
"style": "relaxed, comfortable, everyday",
|
| 246 |
+
"preferred_categories": ["tshirt", "jean", "sneaker", "hoodie", "sweatpant"],
|
| 247 |
+
"color_palette": ["neutral", "denim", "white", "black", "gray"],
|
| 248 |
+
"accessory_limit": 2
|
| 249 |
+
},
|
| 250 |
+
"smart_casual": {
|
| 251 |
+
"style": "polished but relaxed, business casual",
|
| 252 |
+
"preferred_categories": ["shirt", "chino", "loafer", "blazer", "polo"],
|
| 253 |
+
"color_palette": ["navy", "white", "khaki", "brown", "gray"],
|
| 254 |
+
"accessory_limit": 3
|
| 255 |
+
},
|
| 256 |
+
"formal": {
|
| 257 |
+
"style": "professional, elegant, sophisticated",
|
| 258 |
+
"preferred_categories": ["blazer", "dress shirt", "dress pant", "oxford", "suit"],
|
| 259 |
+
"color_palette": ["navy", "black", "white", "gray", "charcoal"],
|
| 260 |
+
"accessory_limit": 4
|
| 261 |
+
},
|
| 262 |
+
"sporty": {
|
| 263 |
+
"style": "athletic, active, performance",
|
| 264 |
+
"preferred_categories": ["athletic shirt", "jogger", "running shoe", "tank", "legging"],
|
| 265 |
+
"color_palette": ["bright", "neon", "white", "black", "primary colors"],
|
| 266 |
+
"accessory_limit": 1
|
| 267 |
+
}
|
| 268 |
+
}
|
| 269 |
+
|
| 270 |
+
# Select outfit template (can be passed in context or randomly selected)
|
| 271 |
+
template_name = context.get("outfit_style", rng.choice(list(outfit_templates.keys())))
|
| 272 |
+
template = outfit_templates[template_name]
|
| 273 |
|
| 274 |
# Enhanced category-aware pools with diversity checks
|
| 275 |
def cat_str(i: int) -> str:
|
| 276 |
return (proc_items[i].get("category") or "").lower()
|
| 277 |
+
|
| 278 |
+
def extract_color_from_category(category: str) -> str:
|
| 279 |
+
"""Extract color information from category name"""
|
| 280 |
+
category_lower = category.lower()
|
| 281 |
+
color_keywords = {
|
| 282 |
+
"black": ["black", "dark", "charcoal", "navy"],
|
| 283 |
+
"white": ["white", "cream", "ivory", "off-white"],
|
| 284 |
+
"gray": ["gray", "grey", "silver", "ash"],
|
| 285 |
+
"brown": ["brown", "tan", "beige", "khaki", "camel"],
|
| 286 |
+
"blue": ["blue", "navy", "denim", "indigo", "royal"],
|
| 287 |
+
"red": ["red", "burgundy", "maroon", "crimson"],
|
| 288 |
+
"green": ["green", "olive", "emerald", "forest"],
|
| 289 |
+
"yellow": ["yellow", "gold", "mustard", "lemon"],
|
| 290 |
+
"pink": ["pink", "rose", "coral", "salmon"],
|
| 291 |
+
"purple": ["purple", "violet", "lavender", "plum"],
|
| 292 |
+
"orange": ["orange", "peach", "apricot", "tangerine"],
|
| 293 |
+
"neutral": ["neutral", "nude", "natural", "earth"]
|
| 294 |
+
}
|
| 295 |
+
|
| 296 |
+
for color, keywords in color_keywords.items():
|
| 297 |
+
if any(kw in category_lower for kw in keywords):
|
| 298 |
+
return color
|
| 299 |
+
return "unknown"
|
| 300 |
+
|
| 301 |
+
def calculate_color_consistency_score(items: List[int]) -> float:
|
| 302 |
+
"""Calculate color consistency score for outfit items"""
|
| 303 |
+
colors = [extract_color_from_category(cat_str(i)) for i in items]
|
| 304 |
+
color_counts = {}
|
| 305 |
+
for color in colors:
|
| 306 |
+
color_counts[color] = color_counts.get(color, 0) + 1
|
| 307 |
+
|
| 308 |
+
# Prefer outfits with 2-3 dominant colors
|
| 309 |
+
dominant_colors = [c for c, count in color_counts.items() if count >= 2]
|
| 310 |
+
if len(dominant_colors) == 0:
|
| 311 |
+
return 0.5 # Neutral score for all different colors
|
| 312 |
+
elif len(dominant_colors) == 1:
|
| 313 |
+
return 0.8 # Good consistency
|
| 314 |
+
elif len(dominant_colors) == 2:
|
| 315 |
+
return 1.0 # Perfect balance
|
| 316 |
+
else:
|
| 317 |
+
return 0.3 # Too many dominant colors
|
| 318 |
+
|
| 319 |
+
def calculate_style_consistency_score(items: List[int]) -> float:
|
| 320 |
+
"""Calculate style consistency based on template preferences"""
|
| 321 |
+
categories = [cat_str(i) for i in items]
|
| 322 |
+
preferred_cats = template["preferred_categories"]
|
| 323 |
+
|
| 324 |
+
matches = 0
|
| 325 |
+
for cat in categories:
|
| 326 |
+
if any(pref in cat for pref in preferred_cats):
|
| 327 |
+
matches += 1
|
| 328 |
+
|
| 329 |
+
return matches / len(categories) if categories else 0.0
|
| 330 |
|
| 331 |
def get_category_type(cat: str) -> str:
|
| 332 |
+
"""Map category to outfit slot type with comprehensive taxonomy"""
|
| 333 |
+
cat_lower = cat.lower().strip()
|
| 334 |
+
|
| 335 |
+
# Upper body items (tops, outerwear)
|
| 336 |
+
upper_keywords = [
|
| 337 |
+
"top", "shirt", "tshirt", "t-shirt", "blouse", "tank", "camisole", "cami",
|
| 338 |
+
"jacket", "blazer", "coat", "hoodie", "sweater", "pullover", "cardigan",
|
| 339 |
+
"vest", "waistcoat", "windbreaker", "bomber", "denim jacket", "leather jacket",
|
| 340 |
+
"polo", "henley", "tunic", "crop top", "bodysuit", "romper", "jumpsuit"
|
| 341 |
+
]
|
| 342 |
+
|
| 343 |
+
# Bottom items
|
| 344 |
+
bottom_keywords = [
|
| 345 |
+
"pant", "pants", "trouser", "trousers", "jean", "jeans", "denim",
|
| 346 |
+
"skirt", "short", "shorts", "legging", "leggings", "tights",
|
| 347 |
+
"chino", "khaki", "cargo", "jogger", "sweatpant", "sweatpants",
|
| 348 |
+
"culotte", "palazzo", "mini skirt", "midi skirt", "maxi skirt",
|
| 349 |
+
"bermuda", "capri", "bike short", "bike shorts"
|
| 350 |
+
]
|
| 351 |
+
|
| 352 |
+
# Footwear
|
| 353 |
+
shoe_keywords = [
|
| 354 |
+
"shoe", "shoes", "sneaker", "sneakers", "boot", "boots", "heel", "heels",
|
| 355 |
+
"sandal", "sandals", "flat", "flats", "loafer", "loafers", "oxford",
|
| 356 |
+
"pump", "pumps", "stiletto", "wedge", "ankle boot", "knee high boot",
|
| 357 |
+
"combat boot", "hiking boot", "running shoe", "athletic shoe",
|
| 358 |
+
"mule", "mules", "clog", "clogs", "espadrille", "espadrilles"
|
| 359 |
+
]
|
| 360 |
+
|
| 361 |
+
# Accessories (can have multiple)
|
| 362 |
+
accessory_keywords = [
|
| 363 |
+
"watch", "belt", "ring", "rings", "bracelet", "bracelets", "necklace", "necklaces",
|
| 364 |
+
"earring", "earrings", "bag", "bags", "handbag", "purse", "clutch", "tote",
|
| 365 |
+
"hat", "cap", "beanie", "scarf", "scarves", "glove", "gloves", "sunglass", "sunglasses",
|
| 366 |
+
"tie", "bow tie", "pocket square", "cufflink", "cufflinks", "brooch", "pin",
|
| 367 |
+
"hair accessory", "headband", "hair clip", "barrette", "scrunchy", "scrunchies"
|
| 368 |
+
]
|
| 369 |
+
|
| 370 |
+
# Check each category
|
| 371 |
+
if any(k in cat_lower for k in upper_keywords):
|
| 372 |
return "upper"
|
| 373 |
+
elif any(k in cat_lower for k in bottom_keywords):
|
| 374 |
return "bottom"
|
| 375 |
+
elif any(k in cat_lower for k in shoe_keywords):
|
| 376 |
return "shoe"
|
| 377 |
+
elif any(k in cat_lower for k in accessory_keywords):
|
| 378 |
return "accessory"
|
| 379 |
else:
|
| 380 |
return "other"
|
|
|
|
| 398 |
|
| 399 |
for _ in range(num_samples):
|
| 400 |
subset = []
|
|
|
|
| 401 |
|
| 402 |
+
# EXACT SLOT CONSTRAINTS: Exactly 1 upper, 1 bottom, 1 shoe, β€2 accessories
|
| 403 |
+
if uppers and bottoms and shoes:
|
| 404 |
+
# Core outfit: exactly 1 of each required slot
|
| 405 |
subset.append(int(rng.choice(uppers)))
|
|
|
|
|
|
|
|
|
|
| 406 |
subset.append(int(rng.choice(bottoms)))
|
|
|
|
|
|
|
|
|
|
| 407 |
subset.append(int(rng.choice(shoes)))
|
| 408 |
+
|
| 409 |
+
# Add accessories based on template limit
|
| 410 |
+
if accs:
|
| 411 |
+
max_accs = template["accessory_limit"]
|
| 412 |
+
num_accs = rng.integers(1, min(max_accs + 1, len(accs) + 1))
|
| 413 |
+
available_accs = [i for i in accs if i not in subset]
|
| 414 |
+
if available_accs:
|
| 415 |
+
selected_accs = rng.choice(available_accs, size=min(num_accs, len(available_accs)), replace=False)
|
| 416 |
+
subset.extend(selected_accs.tolist())
|
| 417 |
+
|
| 418 |
+
# Add 0-1 other items for variety (but not if it would exceed max_size)
|
| 419 |
+
if others and len(subset) < max_size:
|
| 420 |
+
available_others = [i for i in others if i not in subset]
|
| 421 |
+
if available_others and rng.random() < 0.3: # 30% chance to add other item
|
| 422 |
+
subset.append(int(rng.choice(available_others)))
|
| 423 |
+
else:
|
| 424 |
+
# Fallback: ensure we have at least 3 items with category diversity
|
| 425 |
+
required_categories = []
|
| 426 |
+
if uppers: required_categories.append(("upper", uppers))
|
| 427 |
+
if bottoms: required_categories.append(("bottom", bottoms))
|
| 428 |
+
if shoes: required_categories.append(("shoe", shoes))
|
| 429 |
+
|
| 430 |
+
# Add one from each available required category
|
| 431 |
+
for cat_type, cat_items in required_categories:
|
| 432 |
+
subset.append(int(rng.choice(cat_items)))
|
| 433 |
+
|
| 434 |
+
# Add accessories if available
|
| 435 |
+
if accs and len(subset) < max_size:
|
| 436 |
+
num_accs = rng.integers(1, min(3, len(accs) + 1))
|
| 437 |
+
available_accs = [i for i in accs if i not in subset]
|
| 438 |
+
if available_accs:
|
| 439 |
+
selected_accs = rng.choice(available_accs, size=min(num_accs, len(available_accs)), replace=False)
|
| 440 |
+
subset.extend(selected_accs.tolist())
|
| 441 |
|
| 442 |
+
# Remove duplicates and validate
|
| 443 |
subset = list(set(subset))
|
| 444 |
+
if len(subset) >= 3: # At least 3 items for a valid outfit
|
| 445 |
candidates.append(subset)
|
| 446 |
|
| 447 |
# 3) Score using ViT
|
|
|
|
| 455 |
s = self.vit.score_compatibility(embs).item()
|
| 456 |
return float(s)
|
| 457 |
|
| 458 |
+
# Enhanced validation with strict slot constraints
|
| 459 |
def is_valid_outfit(subset: List[int]) -> bool:
|
| 460 |
+
"""Check if outfit meets exact slot requirements"""
|
| 461 |
categories = [get_category_type(cat_str(i)) for i in subset]
|
|
|
|
| 462 |
category_counts = {}
|
| 463 |
+
|
| 464 |
for cat in categories:
|
| 465 |
+
category_counts[cat] = category_counts.get(cat, 0) + 1
|
| 466 |
+
|
| 467 |
+
# STRICT VALIDATION:
|
| 468 |
+
# - Exactly 1 upper, 1 bottom, 1 shoe
|
| 469 |
+
# - β€2 accessories
|
| 470 |
+
# - No other duplicates
|
| 471 |
+
if category_counts.get("upper", 0) != 1:
|
| 472 |
+
return False
|
| 473 |
+
if category_counts.get("bottom", 0) != 1:
|
| 474 |
+
return False
|
| 475 |
+
if category_counts.get("shoe", 0) != 1:
|
| 476 |
+
return False
|
| 477 |
+
if category_counts.get("accessory", 0) > 2:
|
| 478 |
+
return False
|
| 479 |
+
if category_counts.get("other", 0) > 1:
|
| 480 |
+
return False
|
| 481 |
+
|
| 482 |
return True
|
| 483 |
|
| 484 |
+
def calculate_outfit_penalty(subset: List[int], base_score: float) -> float:
|
| 485 |
+
"""Calculate penalty-adjusted score for outfit quality with style/color bonuses"""
|
| 486 |
+
categories = [get_category_type(cat_str(i)) for i in subset]
|
| 487 |
+
category_counts = {}
|
| 488 |
+
|
| 489 |
+
for cat in categories:
|
| 490 |
+
category_counts[cat] = category_counts.get(cat, 0) + 1
|
| 491 |
+
|
| 492 |
+
penalty = 0.0
|
| 493 |
+
bonus = 0.0
|
| 494 |
+
|
| 495 |
+
# Missing core slots: -β penalty
|
| 496 |
+
if category_counts.get("upper", 0) == 0:
|
| 497 |
+
penalty += -1000.0
|
| 498 |
+
if category_counts.get("bottom", 0) == 0:
|
| 499 |
+
penalty += -1000.0
|
| 500 |
+
if category_counts.get("shoe", 0) == 0:
|
| 501 |
+
penalty += -1000.0
|
| 502 |
+
|
| 503 |
+
# Duplicate non-accessory categories: -β penalty
|
| 504 |
+
for cat, count in category_counts.items():
|
| 505 |
+
if cat != "accessory" and count > 1:
|
| 506 |
+
penalty += -1000.0
|
| 507 |
+
|
| 508 |
+
# Too many accessories: moderate penalty
|
| 509 |
+
max_accs = template["accessory_limit"]
|
| 510 |
+
if category_counts.get("accessory", 0) > max_accs:
|
| 511 |
+
penalty += -2.0
|
| 512 |
+
|
| 513 |
+
# Unbalanced outfit: small penalty
|
| 514 |
+
if len(subset) < 3:
|
| 515 |
+
penalty += -1.0
|
| 516 |
+
elif len(subset) > 6:
|
| 517 |
+
penalty += -0.5
|
| 518 |
+
|
| 519 |
+
# Style consistency bonus
|
| 520 |
+
style_score = calculate_style_consistency_score(subset)
|
| 521 |
+
bonus += style_score * 0.5 # Up to 0.5 bonus for style consistency
|
| 522 |
+
|
| 523 |
+
# Color consistency bonus
|
| 524 |
+
color_score = calculate_color_consistency_score(subset)
|
| 525 |
+
bonus += color_score * 0.3 # Up to 0.3 bonus for color consistency
|
| 526 |
+
|
| 527 |
+
# Template adherence bonus
|
| 528 |
+
if style_score > 0.6: # Good style match
|
| 529 |
+
bonus += 0.2
|
| 530 |
+
|
| 531 |
+
return base_score + penalty + bonus
|
| 532 |
+
|
| 533 |
+
# Score and filter valid outfits with penalty adjustment
|
| 534 |
valid_candidates = [subset for subset in candidates if is_valid_outfit(subset)]
|
| 535 |
if not valid_candidates:
|
| 536 |
# Fallback: use all candidates if no valid ones found
|
| 537 |
valid_candidates = candidates
|
| 538 |
|
| 539 |
+
# Score with penalty adjustment
|
| 540 |
+
scored = []
|
| 541 |
+
for subset in valid_candidates:
|
| 542 |
+
base_score = score_subset(subset)
|
| 543 |
+
adjusted_score = calculate_outfit_penalty(subset, base_score)
|
| 544 |
+
scored.append((subset, adjusted_score, base_score))
|
| 545 |
+
|
| 546 |
+
# Sort by penalty-adjusted score
|
| 547 |
scored.sort(key=lambda x: x[1], reverse=True)
|
| 548 |
topk = scored[:num_outfits]
|
| 549 |
|
| 550 |
results = []
|
| 551 |
+
for subset, adjusted_score, base_score in topk:
|
| 552 |
# Double-check validity and get item details
|
| 553 |
outfit_items = []
|
| 554 |
for i in subset:
|
| 555 |
item = proc_items[i]
|
| 556 |
outfit_items.append({
|
| 557 |
"id": item["id"],
|
| 558 |
+
"category": item.get("category", "unknown"),
|
| 559 |
+
"category_type": get_category_type(item.get("category", ""))
|
| 560 |
})
|
| 561 |
|
| 562 |
+
# Calculate additional metrics
|
| 563 |
+
style_score = calculate_style_consistency_score(subset)
|
| 564 |
+
color_score = calculate_color_consistency_score(subset)
|
| 565 |
+
colors = [extract_color_from_category(cat_str(i)) for i in subset]
|
| 566 |
+
|
| 567 |
results.append({
|
| 568 |
"item_ids": [item["id"] for item in outfit_items],
|
| 569 |
"items": outfit_items,
|
| 570 |
+
"score": float(adjusted_score),
|
| 571 |
+
"base_score": float(base_score),
|
| 572 |
+
"categories": [item["category"] for item in outfit_items],
|
| 573 |
+
"category_types": [item["category_type"] for item in outfit_items],
|
| 574 |
+
"outfit_size": len(outfit_items),
|
| 575 |
+
"is_valid": is_valid_outfit(subset),
|
| 576 |
+
"template": {
|
| 577 |
+
"name": template_name,
|
| 578 |
+
"style": template["style"],
|
| 579 |
+
"style_score": float(style_score),
|
| 580 |
+
"color_score": float(color_score),
|
| 581 |
+
"colors": colors,
|
| 582 |
+
"accessory_limit": template["accessory_limit"]
|
| 583 |
+
}
|
| 584 |
})
|
| 585 |
|
| 586 |
return results
|
train_vit_triplet.py
CHANGED
|
@@ -26,13 +26,15 @@ def parse_args() -> argparse.Namespace:
|
|
| 26 |
p.add_argument("--batch_size", type=int, default=4)
|
| 27 |
p.add_argument("--lr", type=float, default=5e-4)
|
| 28 |
p.add_argument("--embedding_dim", type=int, default=512)
|
| 29 |
-
p.add_argument("--triplet_margin", type=float, default=0.
|
| 30 |
p.add_argument("--export", type=str, default="models/exports/vit_outfit_model.pth")
|
| 31 |
p.add_argument("--eval_every", type=int, default=1)
|
| 32 |
p.add_argument("--skip_validation", action="store_true", help="Skip validation for faster training")
|
| 33 |
-
p.add_argument("--max_samples", type=int, default=
|
| 34 |
-
p.add_argument("--early_stopping_patience", type=int, default=
|
| 35 |
p.add_argument("--min_delta", type=float, default=1e-4, help="Minimum change to qualify as improvement")
|
|
|
|
|
|
|
| 36 |
return p.parse_args()
|
| 37 |
|
| 38 |
|
|
@@ -113,6 +115,17 @@ def main() -> None:
|
|
| 113 |
|
| 114 |
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=5e-2)
|
| 115 |
triplet = nn.TripletMarginWithDistanceLoss(distance_function=lambda x, y: 1 - nn.functional.cosine_similarity(x, y), margin=args.triplet_margin)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
|
| 117 |
export_dir = ensure_export_dir(os.path.dirname(args.export) or "models/exports")
|
| 118 |
best_loss = float("inf")
|
|
@@ -159,7 +172,13 @@ def main() -> None:
|
|
| 159 |
|
| 160 |
optimizer.zero_grad(set_to_none=True)
|
| 161 |
loss.backward()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
optimizer.step()
|
|
|
|
| 163 |
|
| 164 |
# Collect metrics (simplified for ViT training)
|
| 165 |
# Note: ViT training uses outfit-level embeddings, not classification predictions
|
|
|
|
| 26 |
p.add_argument("--batch_size", type=int, default=4)
|
| 27 |
p.add_argument("--lr", type=float, default=5e-4)
|
| 28 |
p.add_argument("--embedding_dim", type=int, default=512)
|
| 29 |
+
p.add_argument("--triplet_margin", type=float, default=0.5)
|
| 30 |
p.add_argument("--export", type=str, default="models/exports/vit_outfit_model.pth")
|
| 31 |
p.add_argument("--eval_every", type=int, default=1)
|
| 32 |
p.add_argument("--skip_validation", action="store_true", help="Skip validation for faster training")
|
| 33 |
+
p.add_argument("--max_samples", type=int, default=5000, help="Maximum number of training samples (for better quality)")
|
| 34 |
+
p.add_argument("--early_stopping_patience", type=int, default=5, help="Early stopping patience")
|
| 35 |
p.add_argument("--min_delta", type=float, default=1e-4, help="Minimum change to qualify as improvement")
|
| 36 |
+
p.add_argument("--gradient_clip", type=float, default=1.0, help="Gradient clipping value")
|
| 37 |
+
p.add_argument("--warmup_epochs", type=int, default=2, help="Learning rate warmup epochs")
|
| 38 |
return p.parse_args()
|
| 39 |
|
| 40 |
|
|
|
|
| 115 |
|
| 116 |
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=5e-2)
|
| 117 |
triplet = nn.TripletMarginWithDistanceLoss(distance_function=lambda x, y: 1 - nn.functional.cosine_similarity(x, y), margin=args.triplet_margin)
|
| 118 |
+
|
| 119 |
+
# Learning rate scheduler with warmup
|
| 120 |
+
total_steps = len(loader) * args.epochs
|
| 121 |
+
warmup_steps = len(loader) * args.warmup_epochs
|
| 122 |
+
scheduler = torch.optim.lr_scheduler.OneCycleLR(
|
| 123 |
+
optimizer,
|
| 124 |
+
max_lr=args.lr,
|
| 125 |
+
total_steps=total_steps,
|
| 126 |
+
pct_start=warmup_steps/total_steps,
|
| 127 |
+
anneal_strategy='cos'
|
| 128 |
+
)
|
| 129 |
|
| 130 |
export_dir = ensure_export_dir(os.path.dirname(args.export) or "models/exports")
|
| 131 |
best_loss = float("inf")
|
|
|
|
| 172 |
|
| 173 |
optimizer.zero_grad(set_to_none=True)
|
| 174 |
loss.backward()
|
| 175 |
+
|
| 176 |
+
# Gradient clipping for stability
|
| 177 |
+
if args.gradient_clip > 0:
|
| 178 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_clip)
|
| 179 |
+
|
| 180 |
optimizer.step()
|
| 181 |
+
scheduler.step() # Update learning rate
|
| 182 |
|
| 183 |
# Collect metrics (simplified for ViT training)
|
| 184 |
# Note: ViT training uses outfit-level embeddings, not classification predictions
|