Ali Mohsin committed
Commit 42733e7 · 1 Parent(s): fac18b7
Files changed (4)
  1. inference.py +5 -2
  2. train_resnet.py +10 -7
  3. train_vit_triplet.py +8 -5
  4. utils/data_fetch.py +29 -1
inference.py CHANGED
@@ -82,8 +82,11 @@ class InferenceService:
     def embed_images(self, images: List[Image.Image]) -> List[np.ndarray]:
         if len(images) == 0:
             return []
-        batch = torch.stack([self.transform(img) for img in images]).to(self.device)
-        emb = self.resnet(batch)
+        batch = torch.stack([self.transform(img) for img in images])
+        batch = batch.to(self.device, memory_format=torch.channels_last)
+        use_amp = (self.device == "cuda")
+        with torch.autocast(device_type=("cuda" if use_amp else "cpu"), enabled=use_amp):
+            emb = self.resnet(batch)
         emb = nn.functional.normalize(emb, dim=-1)
         return [e.detach().cpu().numpy().astype(np.float32) for e in emb]
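Note on the inference change: the new lines gate autocast on the service's device, so CPU and MPS fall back to full precision. Below is a minimal, self-contained sketch of the same pattern; the tiny conv net and transform are stand-ins for self.resnet and self.transform, which this hunk does not show.

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
# Toy stand-in for the ResNet embedder behind self.resnet.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.AdaptiveAvgPool2d(1), nn.Flatten()).to(device).eval()
transform = T.Compose([T.Resize((64, 64)), T.ToTensor()])

images = [Image.new("RGB", (80, 80)) for _ in range(2)]
batch = torch.stack([transform(img) for img in images])
batch = batch.to(device, memory_format=torch.channels_last)  # NHWC layout; faster conv kernels on GPU

use_amp = (device == "cuda")  # autocast is disabled off-GPU, so fp32 results are unchanged there
with torch.no_grad(), torch.autocast(device_type=("cuda" if use_amp else "cpu"), enabled=use_amp):
    emb = model(batch)
emb = nn.functional.normalize(emb, dim=-1)
vectors = [e.cpu().numpy().astype(np.float32) for e in emb]  # float32 output regardless of autocast dtype
print(len(vectors), vectors[0].shape)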
 
train_resnet.py CHANGED
@@ -27,10 +27,12 @@ def parse_args() -> argparse.Namespace:
 def main() -> None:
     args = parse_args()
     device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
+    if device == "cuda":
+        torch.backends.cudnn.benchmark = True
 
     dataset = PolyvoreTripletDataset(args.data_root, split="train")
 
-    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)
+    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=(device=="cuda"))
     model = ResNetItemEmbedder(embedding_dim=args.embedding_dim).to(device)
     optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
     criterion = nn.TripletMarginLoss(margin=0.2, p=2)
@@ -45,12 +47,13 @@ def main() -> None:
         for batch in loader:
             # Expect batch as (anchor, positive, negative)
             anchor, positive, negative = batch
-            anchor = anchor.to(device)
-            positive = positive.to(device)
-            negative = negative.to(device)
-            emb_a = model(anchor)
-            emb_p = model(positive)
-            emb_n = model(negative)
+            anchor = anchor.to(device, memory_format=torch.channels_last, non_blocking=True)
+            positive = positive.to(device, memory_format=torch.channels_last, non_blocking=True)
+            negative = negative.to(device, memory_format=torch.channels_last, non_blocking=True)
+            with torch.autocast(device_type=("cuda" if device=="cuda" else "cpu"), enabled=(device=="cuda")):
+                emb_a = model(anchor)
+                emb_p = model(positive)
+                emb_n = model(negative)
             loss = criterion(emb_a, emb_p, emb_n)
             optimizer.zero_grad(set_to_none=True)
             loss.backward()
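Note on the training change: autocast alone leaves fp16 gradients unscaled on CUDA, which can underflow for small triplet losses. A common companion is torch.cuda.amp.GradScaler; the sketch below is a self-contained toy (a Linear layer stands in for ResNetItemEmbedder and random tensors for the Polyvore batches), not the committed code.

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(16, 8).to(device)                    # toy stand-in for ResNetItemEmbedder
criterion = nn.TripletMarginLoss(margin=0.2, p=2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))  # no-op when disabled

anchor, positive, negative = (torch.randn(4, 16, device=device) for _ in range(3))
with torch.autocast(device_type=("cuda" if device == "cuda" else "cpu"), enabled=(device == "cuda")):
    loss = criterion(model(anchor), model(positive), model(negative))

optimizer.zero_grad(set_to_none=True)
scaler.scale(loss).backward()   # scale the loss so small fp16 gradients survive
scaler.step(optimizer)          # unscales, skips the step on inf/nan gradients
scaler.update()

Because the scaler is created with enabled=(device == "cuda"), scale/step/update degrade to a plain optimizer.step() on CPU and MPS, so one loop body serves every device.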
train_vit_triplet.py CHANGED
@@ -42,13 +42,15 @@ def embed_outfit(imgs: List[torch.Tensor], embedder: ResNetItemEmbedder, device:
 def main() -> None:
     args = parse_args()
     device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
+    if device == "cuda":
+        torch.backends.cudnn.benchmark = True
 
     dataset = PolyvoreOutfitTripletDataset(args.data_root, split="train")
 
     def collate(batch):
         return batch  # variable length handled inside training loop
 
-    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
+    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=(device=="cuda"), collate_fn=collate)
 
     model = OutfitCompatibilityModel(embedding_dim=args.embedding_dim).to(device)
     embedder = ResNetItemEmbedder(embedding_dim=args.embedding_dim).to(device).eval()
@@ -80,10 +82,11 @@ def main() -> None:
         N = torch.cat(negative_tokens, dim=0)
 
         # get outfit-level embeddings via ViT encoder pooled output
-        ea = model.encoder(A).mean(dim=1)
-        ep = model.encoder(P).mean(dim=1)
-        en = model.encoder(N).mean(dim=1)
-        loss = triplet(ea, ep, en)
+        with torch.autocast(device_type=("cuda" if device=="cuda" else "cpu"), enabled=(device=="cuda")):
+            ea = model.encoder(A).mean(dim=1)
+            ep = model.encoder(P).mean(dim=1)
+            en = model.encoder(N).mean(dim=1)
+            loss = triplet(ea, ep, en)
         optimizer.zero_grad(set_to_none=True)
         loss.backward()
         optimizer.step()
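Note on the pooling: the "pooled output" comment assumes model.encoder returns a (batch, tokens, dim) tensor, so mean(dim=1) collapses the item tokens into one vector per outfit. A shape-only sketch with assumed sizes (the real encoder is not shown here):

import torch

tokens = torch.randn(4, 7, 256)    # 4 outfits, 7 item tokens each, 256-dim embeddings (assumed)
outfit_emb = tokens.mean(dim=1)    # average over the token axis -> (4, 256), one vector per outfit
print(outfit_emb.shape)

triplet = torch.nn.TripletMarginLoss(margin=0.2)  # margin value assumed for illustration
loss = triplet(outfit_emb, torch.randn(4, 256), torch.randn(4, 256))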
utils/data_fetch.py CHANGED
@@ -43,7 +43,35 @@ def ensure_dataset_ready() -> Optional[str]:
 
     # Download the HF dataset snapshot into root
     try:
-        snapshot_download("Stylique/Polyvore", repo_type="dataset", local_dir=root, local_dir_use_symlinks=False)
+        # Only fetch what's needed to run and prepare splits
+        allow = [
+            "images.zip",
+            "images/*.jpg",
+            "images/*.jpeg",
+            "images/*.png",
+            "train.json",
+            "valid.json",
+            "test.json",
+            "fill_in_blank_*.json",
+            "compatibility_*.txt",
+            "polyvore_item_metadata.json",
+            "polyvore_outfit_titles.json",
+            "categories.csv",
+        ]
+        ignore = [
+            "**/*hglmm*",
+            "disjoint/*",
+            "nondisjoint/*",
+            "*/large/*",
+        ]
+        snapshot_download(
+            "Stylique/Polyvore",
+            repo_type="dataset",
+            local_dir=root,
+            local_dir_use_symlinks=False,
+            allow_patterns=allow,
+            ignore_patterns=ignore,
+        )
     except Exception as e:  # pragma: no cover
         print(f"Failed to download Stylique/Polyvore dataset: {e}")
         return None
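Note on the filtered download: as far as I know, huggingface_hub keeps a repo file only if it matches at least one allow_patterns entry and no ignore_patterns entry, with both lists treated as fnmatch-style globs. The stdlib sketch below approximates that rule so the lists above can be sanity-checked locally; the example paths are hypothetical.

from fnmatch import fnmatch

def kept(path, allow, ignore):
    # Approximation of snapshot_download's filtering:
    # the path must match some allow pattern and no ignore pattern.
    if allow and not any(fnmatch(path, p) for p in allow):
        return False
    return not any(fnmatch(path, p) for p in ignore)

allow = ["images/*.jpg", "train.json", "compatibility_*.txt"]
ignore = ["**/*hglmm*", "disjoint/*", "nondisjoint/*"]

print(kept("images/0001.jpg", allow, ignore))       # True: matches an allow pattern, no ignore pattern
print(kept("train.json", allow, ignore))            # True
print(kept("images/foo_hglmm.jpg", allow, ignore))  # False: matches allow but also the hglmm ignore pattern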