Ali Mohsin committed on
Commit 55c158e · 1 Parent(s): 42733e7
data/polyvore.py CHANGED
@@ -118,9 +118,7 @@ class PolyvoreOutfitTripletDataset(Dataset):
         self.samples: List[Dict[str, Any]] = json.load(f)

     def _load_image(self, item_id: str) -> Image.Image:
-        img_path = os.path.join(self.root, "images", f"{item_id}.jpg")
-        if not os.path.exists(img_path):
-            raise FileNotFoundError(img_path)
+        img_path = PolyvoreTripletDataset._find_image_path(self, item_id)
         return Image.open(img_path).convert("RGB")

     def __len__(self) -> int:
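This hunk delegates path resolution to `PolyvoreTripletDataset._find_image_path`, so both dataset classes resolve images the same way (the unbound call works only because both classes expose a compatible `self.root`). The helper itself is outside this hunk; a minimal sketch of what it presumably does, assuming it probes common extensions instead of hard-coding `.jpg`:

```python
# Hypothetical sketch of the shared helper this hunk calls; the real
# implementation lives on PolyvoreTripletDataset and may differ.
import os

def _find_image_path(self, item_id: str) -> str:
    # Probe common extensions rather than assuming every item is a .jpg.
    for ext in (".jpg", ".jpeg", ".png"):
        p = os.path.join(self.root, "images", f"{item_id}{ext}")
        if os.path.exists(p):
            return p
    raise FileNotFoundError(f"No image for item {item_id} under {self.root}/images")
```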
scripts/prepare_polyvore.py CHANGED
@@ -3,7 +3,79 @@ import json
 import random
 import argparse
 from pathlib import Path
-from typing import Dict, Any, List, Set
+from typing import Dict, Any, List, Set, Union
+
+
+def _normalize_outfits(obj: Union[List[Any], Dict[str, Any]]) -> List[Dict[str, Any]]:
+    """Normalize various Polyvore JSON formats into a list of {"items": [id, ...]} dicts.
+
+    Accepts:
+    - List of objects where each object may be:
+      - {"items": [id, ...]} already
+      - {"items": [{"item_id": id}, ...]} (extract item_id, id, or itemId)
+      - {"set_id": ..., "items": [...]}
+    - List of ids directly
+    - Dict mapping outfit_id -> list of item ids or an object with items.
+    """
+    result: List[Dict[str, Any]] = []
+    if isinstance(obj, dict):
+        # Values could be lists of ids, lists of dicts, or dicts with items.
+        for v in obj.values():
+            if isinstance(v, list):
+                # List of ids or list of dicts.
+                if len(v) > 0 and isinstance(v[0], dict):
+                    items = []
+                    for it in v:
+                        if isinstance(it, dict):
+                            iid = it.get("item_id") or it.get("id") or it.get("itemId")
+                            if iid is not None:
+                                items.append(str(iid))
+                    if items:
+                        result.append({"items": items})
+                else:
+                    result.append({"items": [str(x) for x in v]})
+            elif isinstance(v, dict):
+                if "items" in v:
+                    itm = v["items"]
+                    if isinstance(itm, list):
+                        if itm and isinstance(itm[0], dict):
+                            items = []
+                            for it in itm:
+                                iid = it.get("item_id") or it.get("id") or it.get("itemId")
+                                if iid is not None:
+                                    items.append(str(iid))
+                            if items:
+                                result.append({"items": items})
+                        else:
+                            result.append({"items": [str(x) for x in itm]})
+        return result
+    if isinstance(obj, list):
+        for e in obj:
+            if isinstance(e, dict):
+                if "items" in e:
+                    itm = e["items"]
+                    if isinstance(itm, list):
+                        if itm and isinstance(itm[0], dict):
+                            items = []
+                            for it in itm:
+                                iid = it.get("item_id") or it.get("id") or it.get("itemId")
+                                if iid is not None:
+                                    items.append(str(iid))
+                            if items:
+                                result.append({"items": items})
+                        else:
+                            result.append({"items": [str(x) for x in itm]})
+                else:
+                    # Some variants use different key names but include a list of item ids.
+                    for k in ("good", "outfit", "products"):
+                        if k in e and isinstance(e[k], list):
+                            result.append({"items": [str(x) for x in e[k]]})
+                            break
+            elif isinstance(e, list):
+                result.append({"items": [str(x) for x in e]})
+        return result
+    return result


 def load_outfits_json(root: str, split: str) -> List[Dict[str, Any]]:
@@ -16,30 +88,33 @@ def load_outfits_json(root: str, split: str) -> List[Dict[str, Any]]:
     for p in candidates:
         if os.path.exists(p):
             with open(p, "r") as f:
-                data = json.load(f)
-                # Expect list where each item has key "items" listing item ids
-                return data
-    raise FileNotFoundError(f"Could not find {split}.json in {root} or {root}/splits")
+                raw = json.load(f)
+                data = _normalize_outfits(raw)
+                if data:
+                    return data
+    raise FileNotFoundError(f"Could not find usable {split} split in {root} or {root}/splits")


 def try_load_any_outfits(root: str) -> List[Dict[str, Any]]:
-    candidates = [
-        os.path.join(root, "outfits.json"),
-        os.path.join(root, "all.json"),
-        os.path.join(root, "data.json"),
-    ]
-    for p in candidates:
-        if os.path.exists(p):
-            with open(p, "r") as f:
-                return json.load(f)
-    # As a last resort, merge available splits
+    # Prefer official splits if present
     merged: List[Dict[str, Any]] = []
     for sp in ["train", "valid", "test"]:
         try:
             merged.extend(load_outfits_json(root, sp))
         except FileNotFoundError:
             continue
-    return merged
+    if merged:
+        return merged
+    # Fallback: common aggregated files
+    for name in ("outfits.json", "all.json", "data.json"):
+        p = os.path.join(root, name)
+        if os.path.exists(p):
+            with open(p, "r") as f:
+                raw = json.load(f)
+            data = _normalize_outfits(raw)
+            if data:
+                return data
+    return []


 def collect_all_items(outfits: List[Dict[str, Any]]) -> List[str]:
@@ -141,7 +216,20 @@ def main() -> None:
     out_dir = args.out or os.path.join(args.root, "splits")
     Path(out_dir).mkdir(parents=True, exist_ok=True)

-    if args.random_split:
+    # Prefer official splits; if missing, optionally create random split
+    splits = {}
+    found_any_official = False
+    for split in ["train", "valid", "test"]:
+        try:
+            data = load_outfits_json(args.root, split)
+            splits[split] = data
+            if data:
+                found_any_official = True
+        except FileNotFoundError as e:
+            print(f"Skipping {split}: {e}")
+            splits[split] = []
+
+    if not found_any_official and args.random_split:
         all_outfits = try_load_any_outfits(args.root)
         if not all_outfits:
             raise FileNotFoundError("No outfits found to split. Provide official splits or an outfits.json file.")
@@ -155,14 +243,6 @@ def main() -> None:
             "valid": all_outfits[n_train:n_train + n_valid],
             "test": all_outfits[n_train + n_valid:],
         }
-    else:
-        splits = {}
-        for split in ["train", "valid", "test"]:
-            try:
-                splits[split] = load_outfits_json(args.root, split)
-            except FileNotFoundError as e:
-                print(f"Skipping {split}: {e}")
-                splits[split] = []

     for split, outfits in splits.items():
         if not outfits:
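To make the normalization concrete, a quick check of the two most common shapes (the literals below are made-up examples of the formats the docstring describes, not real dataset records):

```python
# Hypothetical inputs illustrating formats _normalize_outfits accepts.
nested = [{"set_id": "s1", "items": [{"item_id": 1}, {"item_id": 2}]}]
mapping = {"outfit_1": ["10", "11"], "outfit_2": ["12"]}

assert _normalize_outfits(nested) == [{"items": ["1", "2"]}]
assert _normalize_outfits(mapping) == [{"items": ["10", "11"]}, {"items": ["12"]}]
```

One caveat in the id extraction: the `it.get("item_id") or it.get("id") or ...` chain treats falsy ids (e.g. `0` or `""`) as missing and falls through to the next key, which is harmless as long as item ids are non-empty strings.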
train_resnet.py CHANGED
@@ -30,6 +30,35 @@ def main() -> None:
     if device == "cuda":
         torch.backends.cudnn.benchmark = True

+    # Ensure splits exist; if missing, prepare from official splits
+    splits_dir = os.path.join(args.data_root, "splits")
+    triplet_path = os.path.join(splits_dir, "train.json")
+    if not os.path.exists(triplet_path):
+        os.makedirs(splits_dir, exist_ok=True)
+        try:
+            from scripts.prepare_polyvore import main as prepare_main
+            import sys
+            argv_bak = sys.argv
+            try:
+                # First try using official splits (no random)
+                sys.argv = ["prepare_polyvore.py", "--root", args.data_root]
+                prepare_main()
+            finally:
+                sys.argv = argv_bak
+        except Exception:
+            # As a fallback, try random split on any available aggregate file
+            try:
+                from scripts.prepare_polyvore import main as prepare_main
+                import sys
+                argv_bak = sys.argv
+                try:
+                    sys.argv = ["prepare_polyvore.py", "--root", args.data_root, "--random_split"]
+                    prepare_main()
+                finally:
+                    sys.argv = argv_bak
+            except Exception:
+                pass
+
     dataset = PolyvoreTripletDataset(args.data_root, split="train")

     loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=(device=="cuda"))
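The argv save/restore dance appears twice in this hunk and again in train_vit_triplet.py below. A possible consolidation (a hypothetical helper, not part of this commit) that both training scripts could share:

```python
# Hypothetical helper: invoke scripts.prepare_polyvore.main with a temporary
# argv and report success, instead of repeating the swap at each call site.
import sys
from typing import List

def run_prepare(extra_args: List[str]) -> bool:
    from scripts.prepare_polyvore import main as prepare_main
    argv_bak = sys.argv
    try:
        sys.argv = ["prepare_polyvore.py"] + extra_args
        prepare_main()
        return True
    except Exception:
        return False
    finally:
        sys.argv = argv_bak

# Usage mirroring the hunk above:
# if not run_prepare(["--root", args.data_root]):
#     run_prepare(["--root", args.data_root, "--random_split"])
```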
train_vit_triplet.py CHANGED
@@ -45,6 +45,33 @@ def main() -> None:
     if device == "cuda":
         torch.backends.cudnn.benchmark = True

+    # Ensure outfit triplets exist
+    splits_dir = os.path.join(args.data_root, "splits")
+    trip_path = os.path.join(splits_dir, "outfit_triplets_train.json")
+    if not os.path.exists(trip_path):
+        os.makedirs(splits_dir, exist_ok=True)
+        try:
+            from scripts.prepare_polyvore import main as prepare_main
+            import sys
+            argv_bak = sys.argv
+            try:
+                sys.argv = ["prepare_polyvore.py", "--root", args.data_root]
+                prepare_main()
+            finally:
+                sys.argv = argv_bak
+        except Exception:
+            try:
+                from scripts.prepare_polyvore import main as prepare_main
+                import sys
+                argv_bak = sys.argv
+                try:
+                    sys.argv = ["prepare_polyvore.py", "--root", args.data_root, "--random_split"]
+                    prepare_main()
+                finally:
+                    sys.argv = argv_bak
+            except Exception:
+                pass
+
     dataset = PolyvoreOutfitTripletDataset(args.data_root, split="train")

     def collate(batch):
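This mirrors the train_resnet.py hunk, differing only in the sentinel file checked (outfit_triplets_train.json instead of train.json); the `run_prepare` helper sketched above would cover both call sites unchanged.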
utils/data_fetch.py CHANGED
@@ -36,42 +36,43 @@ def ensure_dataset_ready() -> Optional[str]:
     root = os.path.abspath(os.path.join(os.getcwd(), "data", "Polyvore"))
     Path(root).mkdir(parents=True, exist_ok=True)

-    # If already present, ensure images are unzipped and return
+    # If images are already present, don't return early; still ensure metadata JSONs exist
     _unzip_images_if_needed(root)
-    if os.path.isdir(os.path.join(root, "images")):
-        return root

     # Download the HF dataset snapshot into root
     try:
         # Only fetch what's needed to run and prepare splits
         allow = [
             "images.zip",
-            "images/*.jpg",
-            "images/*.jpeg",
-            "images/*.png",
             "train.json",
             "valid.json",
             "test.json",
-            "fill_in_blank_*.json",
-            "compatibility_*.txt",
             "polyvore_item_metadata.json",
             "polyvore_outfit_titles.json",
             "categories.csv",
         ]
+        # Explicit ignores to prevent huge downloads (>10GB)
         ignore = [
             "**/*hglmm*",
-            "disjoint/*",
-            "nondisjoint/*",
-            "*/large/*",
+            "disjoint/**",
+            "nondisjoint/**",
+            "*/large/**",
+            "**/*.tar",
+            "**/*.tar.gz",
+            "**/*.7z",
         ]
-        snapshot_download(
-            "Stylique/Polyvore",
-            repo_type="dataset",
-            local_dir=root,
-            local_dir_use_symlinks=False,
-            allow_patterns=allow,
-            ignore_patterns=ignore,
-        )
+        need_meta = not all(os.path.exists(os.path.join(root, f)) for f in [
+            "train.json", "valid.json", "test.json", "categories.csv"
+        ])
+        if need_meta or not os.path.isdir(os.path.join(root, "images")):
+            snapshot_download(
+                "Stylique/Polyvore",
+                repo_type="dataset",
+                local_dir=root,
+                local_dir_use_symlinks=False,
+                allow_patterns=allow,
+                ignore_patterns=ignore,
+            )
     except Exception as e:  # pragma: no cover
         print(f"Failed to download Stylique/Polyvore dataset: {e}")
         return None
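The tightened patterns can be sanity-checked offline: recent huggingface_hub versions expose the same filtering helper that snapshot_download applies internally (import path assumed; the sample paths below are hypothetical, not the actual Stylique/Polyvore listing):

```python
# Verify that only the allowed metadata files survive the allow/ignore filter.
from huggingface_hub.utils import filter_repo_objects

allow = ["images.zip", "train.json", "valid.json", "test.json", "categories.csv"]
ignore = ["disjoint/**", "nondisjoint/**", "**/*.tar.gz"]
sample = ["images.zip", "train.json", "disjoint/train.json", "backup/images.tar.gz"]

print(list(filter_repo_objects(sample, allow_patterns=allow, ignore_patterns=ignore)))
# Expected: ['images.zip', 'train.json']
```

Note that allow_patterns is applied first, so files like `disjoint/train.json` are already excluded by the allow list; the explicit `disjoint/**` ignore is belt-and-braces against future allow-pattern widening.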