Ali Mohsin commited on
Commit
e13ff13
Β·
1 Parent(s): 6fa8724
Files changed (2) hide show
  1. app.py +1 -1
  2. inference.py +110 -23
app.py CHANGED
@@ -809,7 +809,7 @@ with gr.Blocks(fill_height=True, title="Dressify - Advanced Outfit Recommendatio
809
  with gr.Row():
810
  occasion = gr.Dropdown(choices=["casual", "business", "formal", "sport"], value="casual", label="Occasion")
811
  weather = gr.Dropdown(choices=["any", "hot", "mild", "cold", "rain"], value="any", label="Weather")
812
- outfit_style = gr.Dropdown(choices=["casual", "smart_casual", "formal", "sporty"], value="casual", label="Outfit Style")
813
  with gr.Row():
814
  num_outfits = gr.Slider(minimum=1, maximum=8, step=1, value=3, label="Number of outfits")
815
  out_gallery = gr.Gallery(label="Recommended Outfits", columns=1, height=320)
 
809
  with gr.Row():
810
  occasion = gr.Dropdown(choices=["casual", "business", "formal", "sport"], value="casual", label="Occasion")
811
  weather = gr.Dropdown(choices=["any", "hot", "mild", "cold", "rain"], value="any", label="Weather")
812
+ outfit_style = gr.Dropdown(choices=["casual", "smart_casual", "formal", "sporty", "traditional"], value="casual", label="Outfit Style")
813
  with gr.Row():
814
  num_outfits = gr.Slider(minimum=1, maximum=8, step=1, value=3, label="Number of outfits")
815
  out_gallery = gr.Gallery(label="Recommended Outfits", columns=1, height=320)
inference.py CHANGED
@@ -89,7 +89,7 @@ class InferenceService:
89
  return "other"
90
 
91
  try:
92
- # Define clothing categories with descriptions
93
  categories = [
94
  "a shirt, t-shirt, blouse, or top",
95
  "pants, jeans, trousers, or bottoms",
@@ -101,7 +101,10 @@ class InferenceService:
101
  "a watch, ring, necklace, or jewelry",
102
  "a bag, purse, or handbag",
103
  "a hat, cap, or headwear",
104
- "a belt or accessory"
 
 
 
105
  ]
106
 
107
  # Prepare image and text
@@ -117,7 +120,7 @@ class InferenceService:
117
  similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
118
  values, indices = similarity[0].topk(1)
119
 
120
- # Map to outfit categories
121
  category_map = {
122
  0: "shirt", # shirt, t-shirt, blouse, top
123
  1: "pants", # pants, jeans, trousers, bottoms
@@ -129,7 +132,10 @@ class InferenceService:
129
  7: "accessory", # watch, ring, necklace, jewelry
130
  8: "accessory", # bag, purse, handbag
131
  9: "accessory", # hat, cap, headwear
132
- 10: "accessory" # belt, accessory
 
 
 
133
  }
134
 
135
  predicted_category = category_map.get(indices[0].item(), "other")
@@ -429,17 +435,18 @@ class InferenceService:
429
  },
430
  "formal": {
431
  "style": "professional, elegant, sophisticated",
432
- "preferred_categories": ["blazer", "dress shirt", "dress pant", "oxford", "suit", "shirt", "pants", "shoes"],
433
  "color_palette": ["navy", "black", "white", "gray", "charcoal"],
434
  "accessory_limit": 4,
 
435
  "weather_modifiers": {
436
- "hot": {"preferred_categories": ["light shirt", "light pant", "oxford"]},
437
- "cold": {"preferred_categories": ["blazer", "suit", "boot"]},
438
- "rain": {"preferred_categories": ["blazer", "boot", "umbrella"]}
439
  },
440
  "occasion_modifiers": {
441
- "business": {"preferred_categories": ["shirt", "pants", "shoes"], "accessory_limit": 4},
442
- "casual": {"preferred_categories": ["shirt", "pants", "shoes"], "accessory_limit": 3}
443
  }
444
  },
445
  "sporty": {
@@ -456,6 +463,23 @@ class InferenceService:
456
  "business": {"preferred_categories": ["shirt", "pants", "shoes"], "accessory_limit": 2},
457
  "formal": {"preferred_categories": ["shirt", "pants", "shoes"], "accessory_limit": 3}
458
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
459
  }
460
  }
461
 
@@ -591,15 +615,29 @@ class InferenceService:
591
  return "bottom"
592
  elif cat_lower == "shoes":
593
  return "shoe"
 
 
594
  elif cat_lower == "accessory":
595
  return "accessory"
 
 
 
 
 
 
596
 
597
- # Upper body items (tops, outerwear)
598
  upper_keywords = [
599
  "top", "shirt", "tshirt", "t-shirt", "blouse", "tank", "camisole", "cami",
600
- "jacket", "blazer", "coat", "hoodie", "sweater", "pullover", "cardigan",
601
- "vest", "waistcoat", "windbreaker", "bomber", "denim jacket", "leather jacket",
602
- "polo", "henley", "tunic", "crop top", "bodysuit", "romper", "jumpsuit"
 
 
 
 
 
 
603
  ]
604
 
605
  # Bottom items
@@ -608,7 +646,7 @@ class InferenceService:
608
  "skirt", "short", "shorts", "legging", "leggings", "tights",
609
  "chino", "khaki", "cargo", "jogger", "sweatpant", "sweatpants",
610
  "culotte", "palazzo", "mini skirt", "midi skirt", "maxi skirt",
611
- "bermuda", "capri", "bike short", "bike shorts"
612
  ]
613
 
614
  # Footwear
@@ -617,7 +655,8 @@ class InferenceService:
617
  "sandal", "sandals", "flat", "flats", "loafer", "loafers", "oxford",
618
  "pump", "pumps", "stiletto", "wedge", "ankle boot", "knee high boot",
619
  "combat boot", "hiking boot", "running shoe", "athletic shoe",
620
- "mule", "mules", "clog", "clogs", "espadrille", "espadrilles"
 
621
  ]
622
 
623
  # Accessories (can have multiple)
@@ -630,7 +669,9 @@ class InferenceService:
630
  ]
631
 
632
  # Check each category
633
- if any(k in cat_lower for k in upper_keywords):
 
 
634
  return "upper"
635
  elif any(k in cat_lower for k in bottom_keywords):
636
  return "bottom"
@@ -646,13 +687,14 @@ class InferenceService:
646
  uppers = [i for i in ids if get_category_type(cat_str(i)) == "upper"]
647
  bottoms = [i for i in ids if get_category_type(cat_str(i)) == "bottom"]
648
  shoes = [i for i in ids if get_category_type(cat_str(i)) == "shoe"]
 
649
  accs = [i for i in ids if get_category_type(cat_str(i)) == "accessory"]
650
  others = [i for i in ids if get_category_type(cat_str(i)) == "other"]
651
 
652
- print(f"πŸ” DEBUG: Category pools - uppers: {len(uppers)}, bottoms: {len(bottoms)}, shoes: {len(shoes)}, accessories: {len(accs)}, others: {len(others)}")
653
 
654
  # Check if we have enough items to create outfits
655
- total_items = len(uppers) + len(bottoms) + len(shoes) + len(accs) + len(others)
656
  if total_items < 2:
657
  print(f"πŸ” DEBUG: Not enough items to create outfits - total: {total_items}")
658
  return []
@@ -719,10 +761,39 @@ class InferenceService:
719
 
720
  # Strategy 1: Core outfit (shirt + pants + shoes) + accessories
721
  if strategy == 0 and uppers and bottoms and shoes:
722
- # Core outfit: exactly 1 of each required slot
723
- subset.append(int(rng.choice(uppers)))
724
- subset.append(int(rng.choice(bottoms)))
725
- subset.append(int(rng.choice(shoes)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
726
 
727
  # Prioritize accessories - higher chance of including them
728
  remaining_slots = outfit_length - len(subset)
@@ -851,6 +922,10 @@ class InferenceService:
851
  if category_counts.get("shoe", 0) == 0:
852
  penalty += -1000.0
853
 
 
 
 
 
854
  # Duplicate non-accessory categories: -∞ penalty
855
  for cat, count in category_counts.items():
856
  if cat != "accessory" and count > 1:
@@ -884,6 +959,18 @@ class InferenceService:
884
  if "accessory" in categories:
885
  bonus += 0.3 # Bonus for including accessories
886
 
 
 
 
 
 
 
 
 
 
 
 
 
887
  return base_score + penalty + bonus
888
 
889
  # Score and filter valid outfits with penalty adjustment
 
89
  return "other"
90
 
91
  try:
92
+ # Define clothing categories with descriptions (including Pakistani traditional wear)
93
  categories = [
94
  "a shirt, t-shirt, blouse, or top",
95
  "pants, jeans, trousers, or bottoms",
 
101
  "a watch, ring, necklace, or jewelry",
102
  "a bag, purse, or handbag",
103
  "a hat, cap, or headwear",
104
+ "a belt or accessory",
105
+ "a kameez, kurta, or traditional Pakistani shirt",
106
+ "shalwar, traditional Pakistani pants, or loose trousers",
107
+ "Peshawari chappal, traditional Pakistani sandals, or ethnic footwear"
108
  ]
109
 
110
  # Prepare image and text
 
120
  similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
121
  values, indices = similarity[0].topk(1)
122
 
123
+ # Map to outfit categories (including Pakistani traditional wear)
124
  category_map = {
125
  0: "shirt", # shirt, t-shirt, blouse, top
126
  1: "pants", # pants, jeans, trousers, bottoms
 
132
  7: "accessory", # watch, ring, necklace, jewelry
133
  8: "accessory", # bag, purse, handbag
134
  9: "accessory", # hat, cap, headwear
135
+ 10: "accessory", # belt, accessory
136
+ 11: "kameez", # kameez, kurta, traditional Pakistani shirt
137
+ 12: "shalwar", # shalwar, traditional Pakistani pants
138
+ 13: "peshawari" # Peshawari chappal, traditional Pakistani sandals
139
  }
140
 
141
  predicted_category = category_map.get(indices[0].item(), "other")
 
435
  },
436
  "formal": {
437
  "style": "professional, elegant, sophisticated",
438
+ "preferred_categories": ["blazer", "jacket", "suit jacket", "dress shirt", "dress pant", "oxford", "suit", "shirt", "pants", "shoes"],
439
  "color_palette": ["navy", "black", "white", "gray", "charcoal"],
440
  "accessory_limit": 4,
441
+ "requires_outerwear": True, # Flag to indicate formal outfits should include jackets
442
  "weather_modifiers": {
443
+ "hot": {"preferred_categories": ["light shirt", "light pant", "oxford"], "requires_outerwear": False},
444
+ "cold": {"preferred_categories": ["blazer", "suit", "boot"], "requires_outerwear": True},
445
+ "rain": {"preferred_categories": ["blazer", "boot", "umbrella"], "requires_outerwear": True}
446
  },
447
  "occasion_modifiers": {
448
+ "business": {"preferred_categories": ["shirt", "pants", "shoes"], "accessory_limit": 4, "requires_outerwear": True},
449
+ "casual": {"preferred_categories": ["shirt", "pants", "shoes"], "accessory_limit": 3, "requires_outerwear": False}
450
  }
451
  },
452
  "sporty": {
 
463
  "business": {"preferred_categories": ["shirt", "pants", "shoes"], "accessory_limit": 2},
464
  "formal": {"preferred_categories": ["shirt", "pants", "shoes"], "accessory_limit": 3}
465
  }
466
+ },
467
+ "traditional": {
468
+ "style": "Pakistani traditional, cultural, ethnic",
469
+ "preferred_categories": ["kameez", "kurta", "shalwar", "peshawari", "chappal", "traditional", "ethnic"],
470
+ "color_palette": ["white", "black", "navy", "maroon", "gold", "green", "traditional colors"],
471
+ "accessory_limit": 3,
472
+ "requires_traditional": True, # Flag for traditional outfit combinations
473
+ "weather_modifiers": {
474
+ "hot": {"preferred_categories": ["light kameez", "cotton shalwar", "peshawari chappal"]},
475
+ "cold": {"preferred_categories": ["warm kameez", "thick shalwar", "traditional boots"]},
476
+ "rain": {"preferred_categories": ["waterproof kameez", "shalwar", "traditional boots"]}
477
+ },
478
+ "occasion_modifiers": {
479
+ "business": {"preferred_categories": ["formal kameez", "shalwar", "peshawari"], "accessory_limit": 2},
480
+ "formal": {"preferred_categories": ["elegant kameez", "shalwar", "peshawari"], "accessory_limit": 3},
481
+ "casual": {"preferred_categories": ["casual kameez", "shalwar", "chappal"], "accessory_limit": 2}
482
+ }
483
  }
484
  }
485
 
 
615
  return "bottom"
616
  elif cat_lower == "shoes":
617
  return "shoe"
618
+ elif cat_lower == "jacket":
619
+ return "outerwear" # Separate category for jackets/blazers
620
  elif cat_lower == "accessory":
621
  return "accessory"
622
+ elif cat_lower == "kameez":
623
+ return "upper" # Kameez is upper body wear
624
+ elif cat_lower == "shalwar":
625
+ return "bottom" # Shalwar is bottom wear
626
+ elif cat_lower == "peshawari":
627
+ return "shoe" # Peshawari chappal is footwear
628
 
629
+ # Upper body items (tops, innerwear)
630
  upper_keywords = [
631
  "top", "shirt", "tshirt", "t-shirt", "blouse", "tank", "camisole", "cami",
632
+ "hoodie", "sweater", "pullover", "cardigan", "polo", "henley", "tunic",
633
+ "crop top", "bodysuit", "romper", "jumpsuit", "kameez", "kurta", "shalwar kameez"
634
+ ]
635
+
636
+ # Outerwear items (jackets, coats, blazers)
637
+ outerwear_keywords = [
638
+ "jacket", "blazer", "coat", "vest", "waistcoat", "windbreaker", "bomber",
639
+ "denim jacket", "leather jacket", "suit jacket", "sport coat", "trench coat",
640
+ "pea coat", "overcoat", "cardigan", "sweater jacket"
641
  ]
642
 
643
  # Bottom items
 
646
  "skirt", "short", "shorts", "legging", "leggings", "tights",
647
  "chino", "khaki", "cargo", "jogger", "sweatpant", "sweatpants",
648
  "culotte", "palazzo", "mini skirt", "midi skirt", "maxi skirt",
649
+ "bermuda", "capri", "bike short", "bike shorts", "shalwar", "shalwar kameez"
650
  ]
651
 
652
  # Footwear
 
655
  "sandal", "sandals", "flat", "flats", "loafer", "loafers", "oxford",
656
  "pump", "pumps", "stiletto", "wedge", "ankle boot", "knee high boot",
657
  "combat boot", "hiking boot", "running shoe", "athletic shoe",
658
+ "mule", "mules", "clog", "clogs", "espadrille", "espadrilles",
659
+ "peshawari", "chappal", "peshawari chappal", "traditional sandal"
660
  ]
661
 
662
  # Accessories (can have multiple)
 
669
  ]
670
 
671
  # Check each category
672
+ if any(k in cat_lower for k in outerwear_keywords):
673
+ return "outerwear"
674
+ elif any(k in cat_lower for k in upper_keywords):
675
  return "upper"
676
  elif any(k in cat_lower for k in bottom_keywords):
677
  return "bottom"
 
687
  uppers = [i for i in ids if get_category_type(cat_str(i)) == "upper"]
688
  bottoms = [i for i in ids if get_category_type(cat_str(i)) == "bottom"]
689
  shoes = [i for i in ids if get_category_type(cat_str(i)) == "shoe"]
690
+ outerwear = [i for i in ids if get_category_type(cat_str(i)) == "outerwear"]
691
  accs = [i for i in ids if get_category_type(cat_str(i)) == "accessory"]
692
  others = [i for i in ids if get_category_type(cat_str(i)) == "other"]
693
 
694
+ print(f"πŸ” DEBUG: Category pools - uppers: {len(uppers)}, bottoms: {len(bottoms)}, shoes: {len(shoes)}, outerwear: {len(outerwear)}, accessories: {len(accs)}, others: {len(others)}")
695
 
696
  # Check if we have enough items to create outfits
697
+ total_items = len(uppers) + len(bottoms) + len(shoes) + len(outerwear) + len(accs) + len(others)
698
  if total_items < 2:
699
  print(f"πŸ” DEBUG: Not enough items to create outfits - total: {total_items}")
700
  return []
 
761
 
762
  # Strategy 1: Core outfit (shirt + pants + shoes) + accessories
763
  if strategy == 0 and uppers and bottoms and shoes:
764
+ # Special handling for traditional Pakistani outfits: kameez + shalwar + peshawari
765
+ if outfit_style == "traditional":
766
+ # Check for traditional items
767
+ traditional_uppers = [i for i in uppers if "kameez" in cat_str(i) or "kurta" in cat_str(i)]
768
+ traditional_bottoms = [i for i in bottoms if "shalwar" in cat_str(i)]
769
+ traditional_shoes = [i for i in shoes if "peshawari" in cat_str(i) or "chappal" in cat_str(i)]
770
+
771
+ if traditional_uppers and traditional_bottoms and traditional_shoes:
772
+ # Traditional Pakistani outfit: kameez + shalwar + peshawari
773
+ subset.append(int(rng.choice(traditional_uppers))) # Kameez/Kurta
774
+ subset.append(int(rng.choice(traditional_bottoms))) # Shalwar
775
+ subset.append(int(rng.choice(traditional_shoes))) # Peshawari chappal
776
+ print(f"πŸ” DEBUG: Generated traditional Pakistani outfit: kameez + shalwar + peshawari")
777
+ else:
778
+ # Fallback to regular outfit if traditional items not available
779
+ subset.append(int(rng.choice(uppers)))
780
+ subset.append(int(rng.choice(bottoms)))
781
+ subset.append(int(rng.choice(shoes)))
782
+ print(f"πŸ” DEBUG: Generated regular outfit (traditional items not available)")
783
+
784
+ # Special handling for formal outfits: require jacket + shirt + pants + shoes
785
+ elif occasion == "formal" and outerwear and len(outerwear) > 0:
786
+ # Formal 3-piece suit: jacket + shirt + pants + shoes
787
+ subset.append(int(rng.choice(outerwear))) # Jacket/blazer
788
+ subset.append(int(rng.choice(uppers))) # Shirt
789
+ subset.append(int(rng.choice(bottoms))) # Pants
790
+ subset.append(int(rng.choice(shoes))) # Shoes
791
+ print(f"πŸ” DEBUG: Generated formal 3-piece suit with jacket")
792
+ else:
793
+ # Regular core outfit: shirt + pants + shoes
794
+ subset.append(int(rng.choice(uppers)))
795
+ subset.append(int(rng.choice(bottoms)))
796
+ subset.append(int(rng.choice(shoes)))
797
 
798
  # Prioritize accessories - higher chance of including them
799
  remaining_slots = outfit_length - len(subset)
 
922
  if category_counts.get("shoe", 0) == 0:
923
  penalty += -1000.0
924
 
925
+ # Special penalty for formal outfits missing outerwear
926
+ if occasion == "formal" and category_counts.get("outerwear", 0) == 0:
927
+ penalty += -500.0 # Strong penalty but not -∞ for formal without jacket
928
+
929
  # Duplicate non-accessory categories: -∞ penalty
930
  for cat, count in category_counts.items():
931
  if cat != "accessory" and count > 1:
 
959
  if "accessory" in categories:
960
  bonus += 0.3 # Bonus for including accessories
961
 
962
+ # Formal outfit bonus for including outerwear
963
+ if occasion == "formal" and "outerwear" in categories:
964
+ bonus += 0.5 # Strong bonus for formal outfits with jackets
965
+
966
+ # Traditional Pakistani outfit bonus
967
+ if outfit_style == "traditional":
968
+ traditional_items = [cat for cat in categories if any(traditional in cat.lower() for traditional in ["kameez", "kurta", "shalwar", "peshawari", "chappal"])]
969
+ if len(traditional_items) >= 2: # At least 2 traditional items
970
+ bonus += 0.6 # Strong bonus for traditional outfit combinations
971
+ if len(traditional_items) == 3: # Complete traditional set
972
+ bonus += 0.3 # Additional bonus for complete traditional outfit
973
+
974
  return base_score + penalty + bonus
975
 
976
  # Score and filter valid outfits with penalty adjustment