Ali Mohsin committed on
Commit
941ea8d
Β·
1 Parent(s): 1216fc5

final commit 5000

Browse files
Files changed (3) hide show
  1. app.py +12 -8
  2. inference.py +267 -60
  3. train_vit_triplet.py +22 -3
app.py CHANGED
@@ -278,9 +278,10 @@ def _background_bootstrap():
278
  if not os.path.exists(vit_ckpt):
279
  BOOT_STATUS = "training-vit"
280
  subprocess.run([
281
- "python", "train_vit_triplet.py", "--data_root", ds_root, "--epochs", "3",
282
- "--batch_size", "4", "--lr", "5e-4", "--early_stopping_patience", "3",
283
- "--skip_validation", "--max_samples", "200", "--export", os.path.join(export_dir, "vit_outfit_model.pth")
 
284
  ], check=False)
285
  service.reload_models()
286
  BOOT_STATUS = "ready"
@@ -389,7 +390,7 @@ def _stitch_strip(imgs: List[Image.Image], height: int = 256, pad: int = 6, bg=(
389
  return out
390
 
391
 
392
- def gradio_recommend(files: List[str], occasion: str, weather: str, num_outfits: int):
393
  # Check model status first
394
  model_status = service.get_model_status()
395
  if not model_status["can_recommend"]:
@@ -415,7 +416,7 @@ def gradio_recommend(files: List[str], occasion: str, weather: str, num_outfits:
415
  {"id": f"item_{i}", "image": images[i], "category": None}
416
  for i in range(len(images))
417
  ]
418
- res = service.compose_outfits(items, context={"occasion": occasion, "weather": weather, "num_outfits": int(num_outfits)})
419
 
420
  # Check if compose_outfits returned an error
421
  if res and isinstance(res[0], dict) and "error" in res[0]:
@@ -710,8 +711,9 @@ def start_training_simple(dataset_size: str, res_epochs: int, vit_epochs: int):
710
  log_message += f"\nπŸš€ Starting ViT training on {dataset_size} samples...\n"
711
  vit_result = subprocess.run([
712
  "python", "train_vit_triplet.py", "--data_root", DATASET_ROOT, "--epochs", str(vit_epochs),
713
- "--batch_size", "4", "--lr", "5e-4", "--early_stopping_patience", "3",
714
- "--skip_validation", "--max_samples", "200", "--export", os.path.join(export_dir, "vit_outfit_model.pth")
 
715
  ] + dataset_args, capture_output=True, text=True, check=False)
716
 
717
  if vit_result.returncode == 0:
@@ -770,11 +772,13 @@ with gr.Blocks(fill_height=True, title="Dressify - Advanced Outfit Recommendatio
770
  with gr.Row():
771
  occasion = gr.Dropdown(choices=["casual", "business", "formal", "sport"], value="casual", label="Occasion")
772
  weather = gr.Dropdown(choices=["any", "hot", "mild", "cold", "rain"], value="any", label="Weather")
 
 
773
  num_outfits = gr.Slider(minimum=1, maximum=8, step=1, value=3, label="Number of outfits")
774
  out_gallery = gr.Gallery(label="Recommended Outfits", columns=1, height=320)
775
  out_json = gr.JSON(label="Outfit Details")
776
  btn2 = gr.Button("Generate Outfits", variant="primary")
777
- btn2.click(fn=gradio_recommend, inputs=[inp2, occasion, weather, num_outfits], outputs=[out_gallery, out_json])
778
 
779
  with gr.Tab("πŸ”¬ Advanced Training"):
780
  gr.Markdown("### 🎯 Comprehensive Training Parameter Control\nCustomize every aspect of model training for research and experimentation.")
 
278
  if not os.path.exists(vit_ckpt):
279
  BOOT_STATUS = "training-vit"
280
  subprocess.run([
281
+ "python", "train_vit_triplet.py", "--data_root", ds_root, "--epochs", "10",
282
+ "--batch_size", "4", "--lr", "5e-4", "--early_stopping_patience", "5",
283
+ "--max_samples", "5000", "--triplet_margin", "0.5", "--gradient_clip", "1.0",
284
+ "--warmup_epochs", "2", "--export", os.path.join(export_dir, "vit_outfit_model.pth")
285
  ], check=False)
286
  service.reload_models()
287
  BOOT_STATUS = "ready"
 
390
  return out
391
 
392
 
393
+ def gradio_recommend(files: List[str], occasion: str, weather: str, num_outfits: int, outfit_style: str = "casual"):
394
  # Check model status first
395
  model_status = service.get_model_status()
396
  if not model_status["can_recommend"]:
 
416
  {"id": f"item_{i}", "image": images[i], "category": None}
417
  for i in range(len(images))
418
  ]
419
+ res = service.compose_outfits(items, context={"occasion": occasion, "weather": weather, "num_outfits": int(num_outfits), "outfit_style": outfit_style})
420
 
421
  # Check if compose_outfits returned an error
422
  if res and isinstance(res[0], dict) and "error" in res[0]:
 
711
  log_message += f"\nπŸš€ Starting ViT training on {dataset_size} samples...\n"
712
  vit_result = subprocess.run([
713
  "python", "train_vit_triplet.py", "--data_root", DATASET_ROOT, "--epochs", str(vit_epochs),
714
+ "--batch_size", "4", "--lr", "5e-4", "--early_stopping_patience", "5",
715
+ "--max_samples", "5000", "--triplet_margin", "0.5", "--gradient_clip", "1.0",
716
+ "--warmup_epochs", "2", "--export", os.path.join(export_dir, "vit_outfit_model.pth")
717
  ] + dataset_args, capture_output=True, text=True, check=False)
718
 
719
  if vit_result.returncode == 0:
 
772
  with gr.Row():
773
  occasion = gr.Dropdown(choices=["casual", "business", "formal", "sport"], value="casual", label="Occasion")
774
  weather = gr.Dropdown(choices=["any", "hot", "mild", "cold", "rain"], value="any", label="Weather")
775
+ outfit_style = gr.Dropdown(choices=["casual", "smart_casual", "formal", "sporty"], value="casual", label="Outfit Style")
776
+ with gr.Row():
777
  num_outfits = gr.Slider(minimum=1, maximum=8, step=1, value=3, label="Number of outfits")
778
  out_gallery = gr.Gallery(label="Recommended Outfits", columns=1, height=320)
779
  out_json = gr.JSON(label="Outfit Details")
780
  btn2 = gr.Button("Generate Outfits", variant="primary")
781
+ btn2.click(fn=gradio_recommend, inputs=[inp2, occasion, weather, num_outfits, outfit_style], outputs=[out_gallery, out_json])
782
 
783
  with gr.Tab("πŸ”¬ Advanced Training"):
784
  gr.Markdown("### 🎯 Comprehensive Training Parameter Control\nCustomize every aspect of model training for research and experimentation.")
inference.py CHANGED
@@ -233,25 +233,148 @@ class InferenceService:
233
  if len(proc_items) < 2:
234
  return []
235
 
236
- # 2) Candidate generation
237
  rng = np.random.default_rng(int(context.get("seed", 42)))
238
  num_outfits = int(context.get("num_outfits", 3))
239
  min_size, max_size = 4, 6
240
  ids = list(range(len(proc_items)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
 
242
  # Enhanced category-aware pools with diversity checks
243
  def cat_str(i: int) -> str:
244
  return (proc_items[i].get("category") or "").lower()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
 
246
  def get_category_type(cat: str) -> str:
247
- """Map category to outfit slot type"""
248
- if any(k in cat for k in ["top", "shirt", "tshirt", "blouse", "jacket", "hoodie", "sweater", "cardigan"]):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  return "upper"
250
- elif any(k in cat for k in ["pant", "trouser", "jean", "skirt", "short", "legging"]):
251
  return "bottom"
252
- elif any(k in cat for k in ["shoe", "sneaker", "boot", "heel", "sandal", "flat"]):
253
  return "shoe"
254
- elif any(k in cat for k in ["watch", "belt", "ring", "bracelet", "accessor", "bag", "hat", "scarf", "necklace"]):
255
  return "accessory"
256
  else:
257
  return "other"
@@ -275,52 +398,50 @@ class InferenceService:
275
 
276
  for _ in range(num_samples):
277
  subset = []
278
- used_categories = set()
279
 
280
- # Ensure one item from each major category (no duplicates)
281
- if uppers:
 
282
  subset.append(int(rng.choice(uppers)))
283
- used_categories.add("upper")
284
-
285
- if bottoms:
286
  subset.append(int(rng.choice(bottoms)))
287
- used_categories.add("bottom")
288
-
289
- if shoes:
290
  subset.append(int(rng.choice(shoes)))
291
- used_categories.add("shoe")
292
-
293
- # Add accessories (can have multiple, but ensure diversity)
294
- if accs:
295
- # Add 1-2 accessories
296
- num_accs = rng.integers(1, min(3, len(accs) + 1))
297
- available_accs = [i for i in accs if i not in subset]
298
- if available_accs:
299
- selected_accs = rng.choice(available_accs, size=min(num_accs, len(available_accs)), replace=False)
300
- subset.extend(selected_accs.tolist())
301
- used_categories.add("accessory")
302
-
303
- # Add other items if available and we have space
304
- if others and len(subset) < max_size:
305
- available_others = [i for i in others if i not in subset]
306
- if available_others:
307
- num_others = rng.integers(0, min(2, len(available_others) + 1))
308
- if num_others > 0:
309
- selected_others = rng.choice(available_others, size=min(num_others, len(available_others)), replace=False)
310
- subset.extend(selected_others.tolist())
311
- used_categories.add("other")
312
-
313
- # Ensure we have at least min_size items
314
- if len(subset) < min_size:
315
- remaining = [i for i in ids if i not in subset]
316
- needed = min_size - len(subset)
317
- if remaining:
318
- additional = rng.choice(remaining, size=min(needed, len(remaining)), replace=False)
319
- subset.extend(additional.tolist())
 
 
 
 
320
 
321
- # Remove duplicates and ensure valid outfit
322
  subset = list(set(subset))
323
- if len(subset) >= 2: # At least 2 items for a valid outfit
324
  candidates.append(subset)
325
 
326
  # 3) Score using ViT
@@ -334,46 +455,132 @@ class InferenceService:
334
  s = self.vit.score_compatibility(embs).item()
335
  return float(s)
336
 
337
- # Filter out invalid outfits (duplicate categories)
338
  def is_valid_outfit(subset: List[int]) -> bool:
339
- """Check if outfit has no duplicate categories"""
340
  categories = [get_category_type(cat_str(i)) for i in subset]
341
- # Allow multiple accessories, but only one of each other category
342
  category_counts = {}
 
343
  for cat in categories:
344
- if cat == "accessory":
345
- continue # Allow multiple accessories
346
- if cat in category_counts:
347
- return False # Duplicate non-accessory category
348
- category_counts[cat] = 1
 
 
 
 
 
 
 
 
 
 
 
 
349
  return True
350
 
351
- # Score and filter valid outfits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
352
  valid_candidates = [subset for subset in candidates if is_valid_outfit(subset)]
353
  if not valid_candidates:
354
  # Fallback: use all candidates if no valid ones found
355
  valid_candidates = candidates
356
 
357
- scored = [(subset, score_subset(subset)) for subset in valid_candidates]
 
 
 
 
 
 
 
358
  scored.sort(key=lambda x: x[1], reverse=True)
359
  topk = scored[:num_outfits]
360
 
361
  results = []
362
- for subset, score in topk:
363
  # Double-check validity and get item details
364
  outfit_items = []
365
  for i in subset:
366
  item = proc_items[i]
367
  outfit_items.append({
368
  "id": item["id"],
369
- "category": item.get("category", "unknown")
 
370
  })
371
 
 
 
 
 
 
372
  results.append({
373
  "item_ids": [item["id"] for item in outfit_items],
374
  "items": outfit_items,
375
- "score": float(score),
376
- "categories": [item["category"] for item in outfit_items]
 
 
 
 
 
 
 
 
 
 
 
 
377
  })
378
 
379
  return results
 
233
  if len(proc_items) < 2:
234
  return []
235
 
236
+ # 2) Candidate generation with outfit templates
237
  rng = np.random.default_rng(int(context.get("seed", 42)))
238
  num_outfits = int(context.get("num_outfits", 3))
239
  min_size, max_size = 4, 6
240
  ids = list(range(len(proc_items)))
241
+
242
+ # Outfit templates for cohesive styling
243
+ outfit_templates = {
244
+ "casual": {
245
+ "style": "relaxed, comfortable, everyday",
246
+ "preferred_categories": ["tshirt", "jean", "sneaker", "hoodie", "sweatpant"],
247
+ "color_palette": ["neutral", "denim", "white", "black", "gray"],
248
+ "accessory_limit": 2
249
+ },
250
+ "smart_casual": {
251
+ "style": "polished but relaxed, business casual",
252
+ "preferred_categories": ["shirt", "chino", "loafer", "blazer", "polo"],
253
+ "color_palette": ["navy", "white", "khaki", "brown", "gray"],
254
+ "accessory_limit": 3
255
+ },
256
+ "formal": {
257
+ "style": "professional, elegant, sophisticated",
258
+ "preferred_categories": ["blazer", "dress shirt", "dress pant", "oxford", "suit"],
259
+ "color_palette": ["navy", "black", "white", "gray", "charcoal"],
260
+ "accessory_limit": 4
261
+ },
262
+ "sporty": {
263
+ "style": "athletic, active, performance",
264
+ "preferred_categories": ["athletic shirt", "jogger", "running shoe", "tank", "legging"],
265
+ "color_palette": ["bright", "neon", "white", "black", "primary colors"],
266
+ "accessory_limit": 1
267
+ }
268
+ }
269
+
270
+ # Select outfit template (can be passed in context or randomly selected)
271
+ template_name = context.get("outfit_style", rng.choice(list(outfit_templates.keys())))
272
+ template = outfit_templates[template_name]
273
 
274
  # Enhanced category-aware pools with diversity checks
275
  def cat_str(i: int) -> str:
276
  return (proc_items[i].get("category") or "").lower()
277
+
278
+ def extract_color_from_category(category: str) -> str:
279
+ """Extract color information from category name"""
280
+ category_lower = category.lower()
281
+ color_keywords = {
282
+ "black": ["black", "dark", "charcoal", "navy"],
283
+ "white": ["white", "cream", "ivory", "off-white"],
284
+ "gray": ["gray", "grey", "silver", "ash"],
285
+ "brown": ["brown", "tan", "beige", "khaki", "camel"],
286
+ "blue": ["blue", "navy", "denim", "indigo", "royal"],
287
+ "red": ["red", "burgundy", "maroon", "crimson"],
288
+ "green": ["green", "olive", "emerald", "forest"],
289
+ "yellow": ["yellow", "gold", "mustard", "lemon"],
290
+ "pink": ["pink", "rose", "coral", "salmon"],
291
+ "purple": ["purple", "violet", "lavender", "plum"],
292
+ "orange": ["orange", "peach", "apricot", "tangerine"],
293
+ "neutral": ["neutral", "nude", "natural", "earth"]
294
+ }
295
+
296
+ for color, keywords in color_keywords.items():
297
+ if any(kw in category_lower for kw in keywords):
298
+ return color
299
+ return "unknown"
300
+
301
def calculate_color_consistency_score(items: List[int]) -> float:
    """Score how color-coherent an outfit is, in [0.3, 1.0].

    Counts how many colors appear on at least two items ("dominant" colors)
    and rewards outfits with one or two dominant colors; too many dominant
    colors (or none) scores lower.
    """
    tally = {}
    for idx in items:
        shade = extract_color_from_category(cat_str(idx))
        tally[shade] = tally.get(shade, 0) + 1

    # A color is "dominant" when it shows up on two or more items.
    dominant = sum(1 for count in tally.values() if count >= 2)

    # 0 dominant -> neutral, 1 -> good, 2 -> ideal balance, 3+ -> clashing.
    return {0: 0.5, 1: 0.8, 2: 1.0}.get(dominant, 0.3)
318
+
319
def calculate_style_consistency_score(items: List[int]) -> float:
    """Return the fraction of items whose category matches the active template.

    An item "matches" when any of the template's preferred category substrings
    occurs in its category string.  Empty input yields 0.0.
    """
    labels = [cat_str(idx) for idx in items]
    if not labels:
        return 0.0

    preferred = template["preferred_categories"]
    hits = sum(1 for label in labels if any(pref in label for pref in preferred))
    return hits / len(labels)
330
 
331
def get_category_type(cat: str) -> str:
    """Map a raw category string to an outfit slot.

    Returns one of "upper", "bottom", "shoe", "accessory", or "other".
    Slots are tested in a fixed priority order (upper first), so a label
    like "denim jacket" resolves to "upper" before "denim" can match the
    bottom keywords.
    """
    label = cat.lower().strip()

    # (slot, keyword substrings) in priority order; first hit wins.
    slot_table = (
        # Upper body items (tops, outerwear)
        ("upper", (
            "top", "shirt", "tshirt", "t-shirt", "blouse", "tank", "camisole", "cami",
            "jacket", "blazer", "coat", "hoodie", "sweater", "pullover", "cardigan",
            "vest", "waistcoat", "windbreaker", "bomber", "denim jacket", "leather jacket",
            "polo", "henley", "tunic", "crop top", "bodysuit", "romper", "jumpsuit",
        )),
        # Bottom items
        ("bottom", (
            "pant", "pants", "trouser", "trousers", "jean", "jeans", "denim",
            "skirt", "short", "shorts", "legging", "leggings", "tights",
            "chino", "khaki", "cargo", "jogger", "sweatpant", "sweatpants",
            "culotte", "palazzo", "mini skirt", "midi skirt", "maxi skirt",
            "bermuda", "capri", "bike short", "bike shorts",
        )),
        # Footwear
        ("shoe", (
            "shoe", "shoes", "sneaker", "sneakers", "boot", "boots", "heel", "heels",
            "sandal", "sandals", "flat", "flats", "loafer", "loafers", "oxford",
            "pump", "pumps", "stiletto", "wedge", "ankle boot", "knee high boot",
            "combat boot", "hiking boot", "running shoe", "athletic shoe",
            "mule", "mules", "clog", "clogs", "espadrille", "espadrilles",
        )),
        # Accessories (an outfit may carry several)
        ("accessory", (
            "watch", "belt", "ring", "rings", "bracelet", "bracelets", "necklace", "necklaces",
            "earring", "earrings", "bag", "bags", "handbag", "purse", "clutch", "tote",
            "hat", "cap", "beanie", "scarf", "scarves", "glove", "gloves", "sunglass", "sunglasses",
            "tie", "bow tie", "pocket square", "cufflink", "cufflinks", "brooch", "pin",
            "hair accessory", "headband", "hair clip", "barrette", "scrunchy", "scrunchies",
        )),
    )

    for slot, keywords in slot_table:
        if any(kw in label for kw in keywords):
            return slot
    return "other"
 
398
 
399
  for _ in range(num_samples):
400
  subset = []
 
401
 
402
+ # EXACT SLOT CONSTRAINTS: Exactly 1 upper, 1 bottom, 1 shoe, ≀2 accessories
403
+ if uppers and bottoms and shoes:
404
+ # Core outfit: exactly 1 of each required slot
405
  subset.append(int(rng.choice(uppers)))
 
 
 
406
  subset.append(int(rng.choice(bottoms)))
 
 
 
407
  subset.append(int(rng.choice(shoes)))
408
+
409
+ # Add accessories based on template limit
410
+ if accs:
411
+ max_accs = template["accessory_limit"]
412
+ num_accs = rng.integers(1, min(max_accs + 1, len(accs) + 1))
413
+ available_accs = [i for i in accs if i not in subset]
414
+ if available_accs:
415
+ selected_accs = rng.choice(available_accs, size=min(num_accs, len(available_accs)), replace=False)
416
+ subset.extend(selected_accs.tolist())
417
+
418
+ # Add 0-1 other items for variety (but not if it would exceed max_size)
419
+ if others and len(subset) < max_size:
420
+ available_others = [i for i in others if i not in subset]
421
+ if available_others and rng.random() < 0.3: # 30% chance to add other item
422
+ subset.append(int(rng.choice(available_others)))
423
+ else:
424
+ # Fallback: ensure we have at least 3 items with category diversity
425
+ required_categories = []
426
+ if uppers: required_categories.append(("upper", uppers))
427
+ if bottoms: required_categories.append(("bottom", bottoms))
428
+ if shoes: required_categories.append(("shoe", shoes))
429
+
430
+ # Add one from each available required category
431
+ for cat_type, cat_items in required_categories:
432
+ subset.append(int(rng.choice(cat_items)))
433
+
434
+ # Add accessories if available
435
+ if accs and len(subset) < max_size:
436
+ num_accs = rng.integers(1, min(3, len(accs) + 1))
437
+ available_accs = [i for i in accs if i not in subset]
438
+ if available_accs:
439
+ selected_accs = rng.choice(available_accs, size=min(num_accs, len(available_accs)), replace=False)
440
+ subset.extend(selected_accs.tolist())
441
 
442
+ # Remove duplicates and validate
443
  subset = list(set(subset))
444
+ if len(subset) >= 3: # At least 3 items for a valid outfit
445
  candidates.append(subset)
446
 
447
  # 3) Score using ViT
 
455
  s = self.vit.score_compatibility(embs).item()
456
  return float(s)
457
 
458
+ # Enhanced validation with strict slot constraints
459
def is_valid_outfit(subset: List[int]) -> bool:
    """Validate strict slot constraints for a candidate outfit.

    A valid outfit has exactly one upper, one bottom, and one shoe,
    at most two accessories, and at most one "other" item.
    """
    counts = {}
    for idx in subset:
        slot = get_category_type(cat_str(idx))
        counts[slot] = counts.get(slot, 0) + 1

    return (
        counts.get("upper", 0) == 1
        and counts.get("bottom", 0) == 1
        and counts.get("shoe", 0) == 1
        and counts.get("accessory", 0) <= 2
        and counts.get("other", 0) <= 1
    )
483
 
484
def calculate_outfit_penalty(subset: List[int], base_score: float) -> float:
    """Return the base ViT score adjusted by structural penalties and
    style/color bonuses.

    Penalties: a huge -1000 for each missing core slot (upper/bottom/shoe)
    and for each duplicated non-accessory slot, -2 for exceeding the
    template's accessory limit, and a small size penalty for outfits
    smaller than 3 or larger than 6 items.  Bonuses reward style
    consistency (up to 0.5 plus 0.2 for a strong match) and color
    consistency (up to 0.3).
    """
    counts = {}
    for idx in subset:
        slot = get_category_type(cat_str(idx))
        counts[slot] = counts.get(slot, 0) + 1

    penalty = 0.0
    bonus = 0.0

    # Missing any core slot effectively disqualifies the outfit.
    for core in ("upper", "bottom", "shoe"):
        if counts.get(core, 0) == 0:
            penalty += -1000.0

    # Duplicates are only tolerated for accessories.
    for slot, count in counts.items():
        if slot != "accessory" and count > 1:
            penalty += -1000.0

    # Exceeding the template's accessory budget is a soft violation.
    if counts.get("accessory", 0) > template["accessory_limit"]:
        penalty += -2.0

    # Mild size shaping: prefer 3-6 items.
    size = len(subset)
    if size < 3:
        penalty += -1.0
    elif size > 6:
        penalty += -0.5

    style_score = calculate_style_consistency_score(subset)
    bonus += style_score * 0.5  # up to +0.5 for template-matching categories

    color_score = calculate_color_consistency_score(subset)
    bonus += color_score * 0.3  # up to +0.3 for a coherent palette

    if style_score > 0.6:  # extra reward for a strong template match
        bonus += 0.2

    return base_score + penalty + bonus
532
+
533
+ # Score and filter valid outfits with penalty adjustment
534
  valid_candidates = [subset for subset in candidates if is_valid_outfit(subset)]
535
  if not valid_candidates:
536
  # Fallback: use all candidates if no valid ones found
537
  valid_candidates = candidates
538
 
539
+ # Score with penalty adjustment
540
+ scored = []
541
+ for subset in valid_candidates:
542
+ base_score = score_subset(subset)
543
+ adjusted_score = calculate_outfit_penalty(subset, base_score)
544
+ scored.append((subset, adjusted_score, base_score))
545
+
546
+ # Sort by penalty-adjusted score
547
  scored.sort(key=lambda x: x[1], reverse=True)
548
  topk = scored[:num_outfits]
549
 
550
  results = []
551
+ for subset, adjusted_score, base_score in topk:
552
  # Double-check validity and get item details
553
  outfit_items = []
554
  for i in subset:
555
  item = proc_items[i]
556
  outfit_items.append({
557
  "id": item["id"],
558
+ "category": item.get("category", "unknown"),
559
+ "category_type": get_category_type(item.get("category", ""))
560
  })
561
 
562
+ # Calculate additional metrics
563
+ style_score = calculate_style_consistency_score(subset)
564
+ color_score = calculate_color_consistency_score(subset)
565
+ colors = [extract_color_from_category(cat_str(i)) for i in subset]
566
+
567
  results.append({
568
  "item_ids": [item["id"] for item in outfit_items],
569
  "items": outfit_items,
570
+ "score": float(adjusted_score),
571
+ "base_score": float(base_score),
572
+ "categories": [item["category"] for item in outfit_items],
573
+ "category_types": [item["category_type"] for item in outfit_items],
574
+ "outfit_size": len(outfit_items),
575
+ "is_valid": is_valid_outfit(subset),
576
+ "template": {
577
+ "name": template_name,
578
+ "style": template["style"],
579
+ "style_score": float(style_score),
580
+ "color_score": float(color_score),
581
+ "colors": colors,
582
+ "accessory_limit": template["accessory_limit"]
583
+ }
584
  })
585
 
586
  return results
train_vit_triplet.py CHANGED
@@ -26,13 +26,15 @@ def parse_args() -> argparse.Namespace:
26
  p.add_argument("--batch_size", type=int, default=4)
27
  p.add_argument("--lr", type=float, default=5e-4)
28
  p.add_argument("--embedding_dim", type=int, default=512)
29
- p.add_argument("--triplet_margin", type=float, default=0.3)
30
  p.add_argument("--export", type=str, default="models/exports/vit_outfit_model.pth")
31
  p.add_argument("--eval_every", type=int, default=1)
32
  p.add_argument("--skip_validation", action="store_true", help="Skip validation for faster training")
33
- p.add_argument("--max_samples", type=int, default=500, help="Maximum number of training samples (for faster testing)")
34
- p.add_argument("--early_stopping_patience", type=int, default=10, help="Early stopping patience")
35
  p.add_argument("--min_delta", type=float, default=1e-4, help="Minimum change to qualify as improvement")
 
 
36
  return p.parse_args()
37
 
38
 
@@ -113,6 +115,17 @@ def main() -> None:
113
 
114
  optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=5e-2)
115
  triplet = nn.TripletMarginWithDistanceLoss(distance_function=lambda x, y: 1 - nn.functional.cosine_similarity(x, y), margin=args.triplet_margin)
 
 
 
 
 
 
 
 
 
 
 
116
 
117
  export_dir = ensure_export_dir(os.path.dirname(args.export) or "models/exports")
118
  best_loss = float("inf")
@@ -159,7 +172,13 @@ def main() -> None:
159
 
160
  optimizer.zero_grad(set_to_none=True)
161
  loss.backward()
 
 
 
 
 
162
  optimizer.step()
 
163
 
164
  # Collect metrics (simplified for ViT training)
165
  # Note: ViT training uses outfit-level embeddings, not classification predictions
 
26
  p.add_argument("--batch_size", type=int, default=4)
27
  p.add_argument("--lr", type=float, default=5e-4)
28
  p.add_argument("--embedding_dim", type=int, default=512)
29
+ p.add_argument("--triplet_margin", type=float, default=0.5)
30
  p.add_argument("--export", type=str, default="models/exports/vit_outfit_model.pth")
31
  p.add_argument("--eval_every", type=int, default=1)
32
  p.add_argument("--skip_validation", action="store_true", help="Skip validation for faster training")
33
+ p.add_argument("--max_samples", type=int, default=5000, help="Maximum number of training samples (for better quality)")
34
+ p.add_argument("--early_stopping_patience", type=int, default=5, help="Early stopping patience")
35
  p.add_argument("--min_delta", type=float, default=1e-4, help="Minimum change to qualify as improvement")
36
+ p.add_argument("--gradient_clip", type=float, default=1.0, help="Gradient clipping value")
37
+ p.add_argument("--warmup_epochs", type=int, default=2, help="Learning rate warmup epochs")
38
  return p.parse_args()
39
 
40
 
 
115
 
116
  optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=5e-2)
117
  triplet = nn.TripletMarginWithDistanceLoss(distance_function=lambda x, y: 1 - nn.functional.cosine_similarity(x, y), margin=args.triplet_margin)
118
+
119
+ # Learning rate scheduler with warmup
120
+ total_steps = len(loader) * args.epochs
121
+ warmup_steps = len(loader) * args.warmup_epochs
122
+ scheduler = torch.optim.lr_scheduler.OneCycleLR(
123
+ optimizer,
124
+ max_lr=args.lr,
125
+ total_steps=total_steps,
126
+ pct_start=warmup_steps/total_steps,
127
+ anneal_strategy='cos'
128
+ )
129
 
130
  export_dir = ensure_export_dir(os.path.dirname(args.export) or "models/exports")
131
  best_loss = float("inf")
 
172
 
173
  optimizer.zero_grad(set_to_none=True)
174
  loss.backward()
175
+
176
+ # Gradient clipping for stability
177
+ if args.gradient_clip > 0:
178
+ torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_clip)
179
+
180
  optimizer.step()
181
+ scheduler.step() # Update learning rate
182
 
183
  # Collect metrics (simplified for ViT training)
184
  # Note: ViT training uses outfit-level embeddings, not classification predictions