import os
import argparse
import itertools
import json
import sys
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# Fix import paths so the project-local packages below resolve when this
# script is executed directly.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from data.polyvore import PolyvoreOutfitTripletDataset
from models.vit_outfit import OutfitCompatibilityModel
from models.resnet_embedder import ResNetItemEmbedder
from utils.export import ensure_export_dir
from utils.advanced_metrics import AdvancedMetrics, calculate_outfit_compatibility_metrics

# Validation is deliberately capped so it stays fast and cannot hang training.
_VAL_MAX_SAMPLES = 50
_VAL_MAX_BATCHES = 10


def parse_args() -> argparse.Namespace:
    """Parse command-line options for ViT outfit-compatibility training."""
    p = argparse.ArgumentParser()
    p.add_argument("--data_root", type=str, default=os.getenv("POLYVORE_ROOT", "/home/user/app/data/Polyvore"))
    p.add_argument("--epochs", type=int, default=50)
    p.add_argument("--batch_size", type=int, default=4)
    p.add_argument("--lr", type=float, default=5e-4)
    p.add_argument("--embedding_dim", type=int, default=512)
    p.add_argument("--triplet_margin", type=float, default=0.5)
    p.add_argument("--export", type=str, default="models/exports/vit_outfit_model.pth")
    p.add_argument("--eval_every", type=int, default=1)
    p.add_argument("--skip_validation", action="store_true", help="Skip validation for faster training")
    p.add_argument("--max_samples", type=int, default=5000, help="Maximum number of training samples (for better quality)")
    p.add_argument("--early_stopping_patience", type=int, default=5, help="Early stopping patience")
    p.add_argument("--min_delta", type=float, default=1e-4, help="Minimum change to qualify as improvement")
    p.add_argument("--gradient_clip", type=float, default=1.0, help="Gradient clipping value")
    p.add_argument("--warmup_epochs", type=int, default=2, help="Learning rate warmup epochs")
    return p.parse_args()


def embed_outfit(imgs: List[torch.Tensor], embedder: ResNetItemEmbedder, device: str, max_len: int = 4) -> torch.Tensor:
    """Embed up to ``max_len`` item images into a fixed-size, zero-padded token matrix.

    Args:
        imgs: Item image tensors of one outfit. They are stacked along a new
            leading axis, so they must all share one shape.
        embedder: Frozen item embedder; it is run under ``torch.no_grad()``.
        device: Device on which the images are placed and the tokens created.
        max_len: Number of token rows returned; items beyond it are dropped.

    Returns:
        A ``(max_len, D)`` tensor whose first ``min(len(imgs), max_len)`` rows
        are item embeddings and whose remaining rows are zeros. For an empty
        outfit ``D`` is taken from ``embedder.proj.out_features``.
    """
    if len(imgs) == 0:
        return torch.zeros((max_len, embedder.proj.out_features), device=device)
    k = min(len(imgs), max_len)
    x = torch.stack(imgs[:k], dim=0).to(device)
    with torch.no_grad():
        e = embedder(x)  # (k, D)
    tokens = torch.zeros((max_len, e.shape[-1]), device=device)
    tokens[:k] = e
    return tokens


def _collate_list(batch):
    """Return the batch unchanged; variable-length outfits are embedded per sample later.

    Defined at module level (not as a closure or lambda) so DataLoader worker
    processes can pickle it under the ``spawn`` start method (macOS, Windows).
    """
    return batch


def _stack_triplet_tokens(batch, embedder: ResNetItemEmbedder, device: str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Embed every (anchor, positive, negative) outfit in ``batch`` into (B, max_len, D) token tensors."""
    anchors, positives, negatives = [], [], []
    for ga, gb, bd in batch:
        anchors.append(embed_outfit(ga, embedder, device))
        positives.append(embed_outfit(gb, embedder, device))
        negatives.append(embed_outfit(bd, embedder, device))
    return torch.stack(anchors), torch.stack(positives), torch.stack(negatives)


def _encode_outfits(model, tokens: torch.Tensor) -> torch.Tensor:
    """Mean-pool the model's encoder output over the token axis into one embedding per outfit."""
    # NOTE(review): zero-padded token rows from embed_outfit are included in
    # this mean (no mask is passed) — confirm this is intended.
    return model.encoder(tokens).mean(dim=1)


def _ensure_outfit_triplets(data_root: str) -> bool:
    """Make sure ``splits/outfit_triplets_train.json`` exists, trying to generate it if missing.

    Returns:
        ``True`` when the file already existed or preparation ran without
        raising; ``False`` when preparation failed.
    """
    splits_dir = os.path.join(data_root, "splits")
    trip_path = os.path.join(splits_dir, "outfit_triplets_train.json")
    if os.path.exists(trip_path):
        print(f"✅ Found existing outfit triplets: {trip_path}")
        return True

    print(f"⚠️ Outfit triplet file not found: {trip_path}")
    print("🔧 Attempting to prepare dataset...")
    os.makedirs(splits_dir, exist_ok=True)
    try:
        # Try to import and run the prepare script
        sys.path.append(os.path.join(os.path.dirname(__file__), "scripts"))
        from prepare_polyvore import main as prepare_main
        print("✅ Successfully imported prepare_polyvore")
        # Prepare dataset without random splits.
        # NOTE(review): prepare_main() receives no data_root here — confirm it
        # writes to the same location this script reads from.
        prepare_main()
        print("✅ Dataset preparation completed")
    except Exception as e:
        print(f"❌ Failed to prepare dataset: {e}")
        print("💡 Please ensure the dataset is prepared manually")
        return False
    return True


def _load_train_dataset(data_root: str, max_samples_arg: int):
    """Load the training triplet dataset, truncated to ``max_samples_arg`` samples.

    Returns ``None`` (after printing the error) when loading fails.
    """
    try:
        dataset = PolyvoreOutfitTripletDataset(data_root, split="train")
        # Limit dataset size for faster training/testing
        max_samples = min(len(dataset), max_samples_arg)
        print(f"🔍 Debug: Original dataset size: {len(dataset)}, max_samples: {max_samples_arg}")
        if len(dataset) > max_samples:
            dataset.samples = dataset.samples[:max_samples]
            print(f"📊 Dataset limited to {max_samples} samples for faster training")
        else:
            print(f"📊 Dataset loaded: {len(dataset)} samples (no limiting needed)")
    except Exception as e:
        print(f"❌ Failed to load dataset: {e}")
        return None
    return dataset


def _train_one_epoch(loader, model, embedder, triplet, optimizer, scheduler, scaler, device: str, gradient_clip: float) -> Tuple[float, int]:
    """Run one training epoch.

    Batches that raise are logged and skipped (best-effort training).

    Returns:
        ``(mean loss over successful batches, number of successful batches)``;
        the mean is ``0.0`` when no batch succeeded.
    """
    model.train()
    use_amp = device == "cuda"
    amp_device = "cuda" if use_amp else "cpu"
    running_loss = 0.0
    steps = 0
    last_idx: Optional[int] = None
    last_loss: Optional[float] = None

    for batch_idx, batch in enumerate(loader):
        try:
            # batch: List[(ga_imgs, gb_imgs, bd_imgs)]
            A, P, N = _stack_triplet_tokens(batch, embedder, device)  # each (B, N, D)

            # Outfit-level embeddings via the ViT encoder's pooled output.
            with torch.autocast(device_type=amp_device, enabled=use_amp):
                ea = _encode_outfits(model, A)
                ep = _encode_outfits(model, P)
                en = _encode_outfits(model, N)
                loss = triplet(ea, ep, en)

            optimizer.zero_grad(set_to_none=True)
            # GradScaler guards fp16 autocast gradients against underflow; it is
            # a transparent no-op when AMP is disabled (non-CUDA devices).
            scaler.scale(loss).backward()
            if gradient_clip > 0:
                # Gradients must be unscaled before clipping to a true-norm threshold.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()  # Update learning rate

            batch_loss = loss.item()
            running_loss += batch_loss
            steps += 1
            last_idx, last_loss = batch_idx, batch_loss

            if batch_idx % 10 == 0:  # More frequent logging
                print(f"  Batch {batch_idx}/{len(loader)}: loss={batch_loss:.4f}")
        except Exception as e:
            print(f"❌ Error in batch {batch_idx}: {e}")
            continue

    # Report the last batch that actually completed (if any).
    if last_loss is not None:
        print(f"  ✅ Batch {last_idx}/{len(loader)}: loss={last_loss:.4f}")
    return running_loss / max(1, steps), steps


def _validation_loss(val_ds, model, embedder, triplet, device: str, batch_size: int) -> Optional[float]:
    """Return the mean triplet loss over at most ``_VAL_MAX_BATCHES`` validation batches.

    Returns ``None`` when the validation set yields no batches, so an empty set
    is never mistaken for a perfect (0.0) loss.
    """
    val_loader = DataLoader(val_ds, batch_size=min(batch_size, 8), shuffle=False, num_workers=0, collate_fn=_collate_list)
    model.eval()
    losses = []
    with torch.no_grad():
        for vbatch in itertools.islice(val_loader, _VAL_MAX_BATCHES):
            A, P, N = _stack_triplet_tokens(vbatch, embedder, device)
            ea = _encode_outfits(model, A)
            ep = _encode_outfits(model, P)
            en = _encode_outfits(model, N)
            losses.append(triplet(ea, ep, en).item())
    if not losses:
        print("  ⚠️ Validation produced no batches; skipping")
        return None
    val_loss = sum(losses) / len(losses)
    print(f"  📊 Validation loss: {val_loss:.4f} (from {len(losses)} batches)")
    return val_loss


def main() -> None:
    """Train the outfit ViT encoder with a cosine-distance triplet loss on Polyvore triplets.

    Saves a checkpoint each epoch, keeps the best-validation model, applies
    early stopping, and writes a ``vit_metrics.json`` summary to the export dir.
    """
    args = parse_args()
    device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
    if device == "cuda":
        torch.backends.cudnn.benchmark = True

    print(f"🚀 Starting ViT Outfit training on {device}")
    print(f"📁 Data root: {args.data_root}")
    print(f"⚙️ Config: {args.epochs} epochs, batch_size={args.batch_size}, lr={args.lr}")

    if args.epochs < 1:
        print(f"❌ --epochs must be >= 1 (got {args.epochs})")
        return

    # Ensure outfit triplets exist
    if not _ensure_outfit_triplets(args.data_root):
        return

    dataset = _load_train_dataset(args.data_root, args.max_samples)
    if dataset is None:
        return

    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=2,
                        pin_memory=(device == "cuda"), collate_fn=_collate_list)
    if len(loader) == 0:
        print("❌ Training dataset is empty; nothing to train on")
        return

    model = OutfitCompatibilityModel(embedding_dim=args.embedding_dim).to(device)
    embedder = ResNetItemEmbedder(embedding_dim=args.embedding_dim).to(device).eval()
    for param in embedder.parameters():
        param.requires_grad_(False)

    print(f"🏗️ Models created:")
    print(f"  - ViT Outfit: {model.__class__.__name__}")
    print(f"  - ResNet Embedder: {embedder.__class__.__name__}")
    print(f"📈 Total parameters: {sum(param.numel() for param in model.parameters()):,}")

    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=5e-2)
    triplet = nn.TripletMarginWithDistanceLoss(
        distance_function=lambda x, y: 1 - nn.functional.cosine_similarity(x, y),
        margin=args.triplet_margin,
    )

    # Learning rate scheduler with warmup
    total_steps = len(loader) * args.epochs
    warmup_steps = len(loader) * args.warmup_epochs
    # OneCycleLR needs 0 <= pct_start <= 1, and the warmup phase must end before
    # the last step or the annealing phase collapses to zero length.
    pct_start = min(max(warmup_steps, 0), total_steps - 1) / total_steps
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=args.lr,
        total_steps=total_steps,
        pct_start=pct_start,
        anneal_strategy='cos',
    )
    scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

    export_dir = ensure_export_dir(os.path.dirname(args.export) or "models/exports")
    out_path = args.export
    if not out_path.startswith("models/"):
        out_path = os.path.join(export_dir, os.path.basename(args.export))
    best_path = os.path.join(export_dir, "vit_outfit_model_best.pth")
    val_path = os.path.join(args.data_root, "splits", "outfit_triplets_valid.json")

    best_loss = float("inf")
    hist = []
    patience_counter = 0
    best_epoch = 0
    epochs_run = 0
    avg_loss = 0.0
    val_ds = None  # loaded lazily once, then reused across epochs

    print(f"💾 Checkpoints will be saved to: {export_dir}")
    print(f"🛑 Early stopping patience: {args.early_stopping_patience} epochs")

    for epoch in range(args.epochs):
        epochs_run = epoch + 1
        print(f"🔄 Epoch {epoch+1}/{args.epochs}")

        avg_loss, steps = _train_one_epoch(loader, model, embedder, triplet, optimizer, scheduler,
                                           scaler, device, args.gradient_clip)
        print(f"  📊 Epoch {epoch+1} completed: {steps}/{len(loader)} batches processed")
        if steps == 0:
            print("  ⚠️ No batch succeeded this epoch; the reported training loss is not meaningful")

        # Fast validation with limited samples to prevent hanging.
        # eval_every <= 0 disables validation instead of raising ZeroDivisionError.
        val_loss = None
        if (not args.skip_validation and args.eval_every > 0 and os.path.exists(val_path)
                and (epoch + 1) % args.eval_every == 0):
            try:
                print(f"  🔍 Starting validation (limited to {_VAL_MAX_SAMPLES} samples for speed)...")
                if val_ds is None:
                    val_ds = PolyvoreOutfitTripletDataset(args.data_root, split="valid")
                    # Limit validation samples to prevent hanging
                    val_ds.samples = val_ds.samples[:_VAL_MAX_SAMPLES]
                val_loss = _validation_loss(val_ds, model, embedder, triplet, device, args.batch_size)
            except Exception as e:
                print(f"  ⚠️ Validation failed: {e}")
                val_loss = None

        # Save checkpoint
        torch.save({"state_dict": model.state_dict(), "epoch": epoch+1, "loss": avg_loss}, out_path)

        if val_loss is None:
            print(f"✅ Epoch {epoch+1}/{args.epochs} triplet_loss={avg_loss:.4f} saved -> {out_path}")
            hist.append({"epoch": epoch + 1, "triplet_loss": float(avg_loss)})
            continue

        print(f"✅ Epoch {epoch+1}/{args.epochs} triplet_loss={avg_loss:.4f} val_triplet_loss={val_loss:.4f} saved -> {out_path}")
        hist.append({"epoch": epoch + 1, "triplet_loss": float(avg_loss), "val_triplet_loss": float(val_loss)})

        # Early stopping logic
        if val_loss < best_loss - args.min_delta:
            best_loss = val_loss
            best_epoch = epoch + 1
            patience_counter = 0
            torch.save({"state_dict": model.state_dict(), "epoch": epoch+1, "loss": avg_loss, "val_loss": val_loss}, best_path)
            print(f"🏆 New best model saved: {best_path} (val_loss: {val_loss:.4f})")
        else:
            patience_counter += 1
            print(f"⏳ No improvement for {patience_counter} epochs (best: {best_loss:.4f} at epoch {best_epoch})")
            if patience_counter >= args.early_stopping_patience:
                print(f"🛑 Early stopping triggered after {patience_counter} epochs without improvement")
                print(f"🏆 Best model was at epoch {best_epoch} with val_loss {best_loss:.4f}")
                break

    # Write comprehensive metrics
    metrics_path = os.path.join(export_dir, "vit_metrics.json")

    # ViT training collects outfit-level embeddings rather than classification
    # predictions, so the advanced-metrics section is reported as empty.
    advanced_metrics = {
        "total_predictions": 0,
        "total_targets": 0,
        "total_scores": 0,
        "total_embeddings": 0,
        "total_outfit_scores": 0,
    }

    final_metrics = {
        "best_val_triplet_loss": best_loss if best_loss != float("inf") else None,
        "best_epoch": best_epoch,
        "total_epochs": epochs_run,
        "early_stopping_triggered": patience_counter >= args.early_stopping_patience,
        "patience_counter": patience_counter,
        "training_config": {
            "epochs": args.epochs,
            "batch_size": args.batch_size,
            "learning_rate": args.lr,
            "embedding_dim": args.embedding_dim,
            "triplet_margin": args.triplet_margin,
            "early_stopping_patience": args.early_stopping_patience,
            "min_delta": args.min_delta,
        },
        "history": hist,
        "advanced_metrics": advanced_metrics,
    }

    with open(metrics_path, "w") as f:
        json.dump(final_metrics, f, indent=2)

    # Always save a best model (use final model if no validation was done)
    if best_loss == float("inf"):
        torch.save({"state_dict": model.state_dict(), "epoch": epochs_run, "loss": avg_loss}, best_path)
        print(f"🏆 Final model saved as best: {best_path} (loss: {avg_loss:.4f})")

    print(f"📊 Training completed! Best val_loss: {best_loss:.4f} at epoch {best_epoch}")
    print(f"📈 Comprehensive metrics saved to: {metrics_path}")
    print(f"🔬 Advanced metrics: {advanced_metrics}")


if __name__ == "__main__":
    main()