Spaces:
Running
Running
Ali Mohsin
commited on
Commit
·
42733e7
1
Parent(s):
fac18b7
fixes
Browse files- inference.py +5 -2
- train_resnet.py +10 -7
- train_vit_triplet.py +8 -5
- utils/data_fetch.py +29 -1
inference.py
CHANGED
|
@@ -82,8 +82,11 @@ class InferenceService:
|
|
| 82 |
def embed_images(self, images: List[Image.Image]) -> List[np.ndarray]:
|
| 83 |
if len(images) == 0:
|
| 84 |
return []
|
| 85 |
-
batch = torch.stack([self.transform(img) for img in images])
|
| 86 |
-
|
|
|
|
|
|
|
|
|
|
| 87 |
emb = nn.functional.normalize(emb, dim=-1)
|
| 88 |
return [e.detach().cpu().numpy().astype(np.float32) for e in emb]
|
| 89 |
|
|
|
|
| 82 |
def embed_images(self, images: List[Image.Image]) -> List[np.ndarray]:
|
| 83 |
if len(images) == 0:
|
| 84 |
return []
|
| 85 |
+
batch = torch.stack([self.transform(img) for img in images])
|
| 86 |
+
batch = batch.to(self.device, memory_format=torch.channels_last)
|
| 87 |
+
use_amp = (self.device == "cuda")
|
| 88 |
+
with torch.autocast(device_type=("cuda" if use_amp else "cpu"), enabled=use_amp):
|
| 89 |
+
emb = self.resnet(batch)
|
| 90 |
emb = nn.functional.normalize(emb, dim=-1)
|
| 91 |
return [e.detach().cpu().numpy().astype(np.float32) for e in emb]
|
| 92 |
|
train_resnet.py
CHANGED
|
@@ -27,10 +27,12 @@ def parse_args() -> argparse.Namespace:
|
|
| 27 |
def main() -> None:
|
| 28 |
args = parse_args()
|
| 29 |
device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
|
|
|
|
|
|
|
| 30 |
|
| 31 |
dataset = PolyvoreTripletDataset(args.data_root, split="train")
|
| 32 |
|
| 33 |
-
loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=
|
| 34 |
model = ResNetItemEmbedder(embedding_dim=args.embedding_dim).to(device)
|
| 35 |
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
|
| 36 |
criterion = nn.TripletMarginLoss(margin=0.2, p=2)
|
|
@@ -45,12 +47,13 @@ def main() -> None:
|
|
| 45 |
for batch in loader:
|
| 46 |
# Expect batch as (anchor, positive, negative)
|
| 47 |
anchor, positive, negative = batch
|
| 48 |
-
anchor = anchor.to(device)
|
| 49 |
-
positive = positive.to(device)
|
| 50 |
-
negative = negative.to(device)
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
|
|
|
| 54 |
loss = criterion(emb_a, emb_p, emb_n)
|
| 55 |
optimizer.zero_grad(set_to_none=True)
|
| 56 |
loss.backward()
|
|
|
|
| 27 |
def main() -> None:
|
| 28 |
args = parse_args()
|
| 29 |
device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
|
| 30 |
+
if device == "cuda":
|
| 31 |
+
torch.backends.cudnn.benchmark = True
|
| 32 |
|
| 33 |
dataset = PolyvoreTripletDataset(args.data_root, split="train")
|
| 34 |
|
| 35 |
+
loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=(device=="cuda"))
|
| 36 |
model = ResNetItemEmbedder(embedding_dim=args.embedding_dim).to(device)
|
| 37 |
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
|
| 38 |
criterion = nn.TripletMarginLoss(margin=0.2, p=2)
|
|
|
|
| 47 |
for batch in loader:
|
| 48 |
# Expect batch as (anchor, positive, negative)
|
| 49 |
anchor, positive, negative = batch
|
| 50 |
+
anchor = anchor.to(device, memory_format=torch.channels_last, non_blocking=True)
|
| 51 |
+
positive = positive.to(device, memory_format=torch.channels_last, non_blocking=True)
|
| 52 |
+
negative = negative.to(device, memory_format=torch.channels_last, non_blocking=True)
|
| 53 |
+
with torch.autocast(device_type=("cuda" if device=="cuda" else "cpu"), enabled=(device=="cuda")):
|
| 54 |
+
emb_a = model(anchor)
|
| 55 |
+
emb_p = model(positive)
|
| 56 |
+
emb_n = model(negative)
|
| 57 |
loss = criterion(emb_a, emb_p, emb_n)
|
| 58 |
optimizer.zero_grad(set_to_none=True)
|
| 59 |
loss.backward()
|
train_vit_triplet.py
CHANGED
|
@@ -42,13 +42,15 @@ def embed_outfit(imgs: List[torch.Tensor], embedder: ResNetItemEmbedder, device:
|
|
| 42 |
def main() -> None:
|
| 43 |
args = parse_args()
|
| 44 |
device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
|
|
|
|
|
|
|
| 45 |
|
| 46 |
dataset = PolyvoreOutfitTripletDataset(args.data_root, split="train")
|
| 47 |
|
| 48 |
def collate(batch):
|
| 49 |
return batch # variable length handled inside training loop
|
| 50 |
|
| 51 |
-
loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=
|
| 52 |
|
| 53 |
model = OutfitCompatibilityModel(embedding_dim=args.embedding_dim).to(device)
|
| 54 |
embedder = ResNetItemEmbedder(embedding_dim=args.embedding_dim).to(device).eval()
|
|
@@ -80,10 +82,11 @@ def main() -> None:
|
|
| 80 |
N = torch.cat(negative_tokens, dim=0)
|
| 81 |
|
| 82 |
# get outfit-level embeddings via ViT encoder pooled output
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
|
|
|
| 87 |
optimizer.zero_grad(set_to_none=True)
|
| 88 |
loss.backward()
|
| 89 |
optimizer.step()
|
|
|
|
| 42 |
def main() -> None:
|
| 43 |
args = parse_args()
|
| 44 |
device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
|
| 45 |
+
if device == "cuda":
|
| 46 |
+
torch.backends.cudnn.benchmark = True
|
| 47 |
|
| 48 |
dataset = PolyvoreOutfitTripletDataset(args.data_root, split="train")
|
| 49 |
|
| 50 |
def collate(batch):
|
| 51 |
return batch # variable length handled inside training loop
|
| 52 |
|
| 53 |
+
loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=(device=="cuda"), collate_fn=collate)
|
| 54 |
|
| 55 |
model = OutfitCompatibilityModel(embedding_dim=args.embedding_dim).to(device)
|
| 56 |
embedder = ResNetItemEmbedder(embedding_dim=args.embedding_dim).to(device).eval()
|
|
|
|
| 82 |
N = torch.cat(negative_tokens, dim=0)
|
| 83 |
|
| 84 |
# get outfit-level embeddings via ViT encoder pooled output
|
| 85 |
+
with torch.autocast(device_type=("cuda" if device=="cuda" else "cpu"), enabled=(device=="cuda")):
|
| 86 |
+
ea = model.encoder(A).mean(dim=1)
|
| 87 |
+
ep = model.encoder(P).mean(dim=1)
|
| 88 |
+
en = model.encoder(N).mean(dim=1)
|
| 89 |
+
loss = triplet(ea, ep, en)
|
| 90 |
optimizer.zero_grad(set_to_none=True)
|
| 91 |
loss.backward()
|
| 92 |
optimizer.step()
|
utils/data_fetch.py
CHANGED
|
@@ -43,7 +43,35 @@ def ensure_dataset_ready() -> Optional[str]:
|
|
| 43 |
|
| 44 |
# Download the HF dataset snapshot into root
|
| 45 |
try:
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
except Exception as e: # pragma: no cover
|
| 48 |
print(f"Failed to download Stylique/Polyvore dataset: {e}")
|
| 49 |
return None
|
|
|
|
| 43 |
|
| 44 |
# Download the HF dataset snapshot into root
|
| 45 |
try:
|
| 46 |
+
# Only fetch what's needed to run and prepare splits
|
| 47 |
+
allow = [
|
| 48 |
+
"images.zip",
|
| 49 |
+
"images/*.jpg",
|
| 50 |
+
"images/*.jpeg",
|
| 51 |
+
"images/*.png",
|
| 52 |
+
"train.json",
|
| 53 |
+
"valid.json",
|
| 54 |
+
"test.json",
|
| 55 |
+
"fill_in_blank_*.json",
|
| 56 |
+
"compatibility_*.txt",
|
| 57 |
+
"polyvore_item_metadata.json",
|
| 58 |
+
"polyvore_outfit_titles.json",
|
| 59 |
+
"categories.csv",
|
| 60 |
+
]
|
| 61 |
+
ignore = [
|
| 62 |
+
"**/*hglmm*",
|
| 63 |
+
"disjoint/*",
|
| 64 |
+
"nondisjoint/*",
|
| 65 |
+
"*/large/*",
|
| 66 |
+
]
|
| 67 |
+
snapshot_download(
|
| 68 |
+
"Stylique/Polyvore",
|
| 69 |
+
repo_type="dataset",
|
| 70 |
+
local_dir=root,
|
| 71 |
+
local_dir_use_symlinks=False,
|
| 72 |
+
allow_patterns=allow,
|
| 73 |
+
ignore_patterns=ignore,
|
| 74 |
+
)
|
| 75 |
except Exception as e: # pragma: no cover
|
| 76 |
print(f"Failed to download Stylique/Polyvore dataset: {e}")
|
| 77 |
return None
|