""" ðŸĶī core-dino | YOLO Backbone Wrapper for Feature Extraction 🔍 Wraps a YOLO model to extract intermediate feature maps for DINO-style self-supervised training. Optionally applies an MLP projection head. Author: Gajesh Ladhar 🔗 LinkedIn: https://www.linkedin.com/in/gajeshladhar/ ðŸĪ— Hugging Face: https://huggingface.co/gajeshladhar """ import torch import torch.nn as nn from ultralytics import YOLO class YOLOBackBone(nn.Module): """ ðŸ§Đ Extracts multi-scale spatial features from YOLO backbone. Args: model_path (str): Path to YOLO weights (.pt) stop_at (int): Layer index to cut the model use_mlp (bool): Whether to apply MLP projection head mlp_dim (int): Output dim of MLP head (if enabled) """ def __init__(self, model_path='yolo11x.pt', stop_at=23, use_mlp=True, mlp_dim=512): super().__init__() raw_model = YOLO(model_path).model.train() self.layers = nn.ModuleList(raw_model.model[:stop_at]) self.layer_defs = raw_model.yaml["backbone"] + raw_model.yaml["head"] self.use_mlp = use_mlp if use_mlp: self.init_mlp(self._get_out_channels(self.layers[-1]), mlp_dim) for p in self.parameters(): p.requires_grad = True def _get_out_channels(self, layer): return 768 def init_mlp(self, in_channels, out_channels): self.mlp_head = nn.Identity() # self.mlp_head = nn.Sequential( # nn.Conv2d(in_channels, 2048, 1), # nn.GELU(), # nn.Conv2d(2048, out_channels, 1), # nn.GELU(), # nn.Conv2d(out_channels, in_channels, 1) # ) def apply_mlp(self, x): return self.mlp_head(x) if self.use_mlp else x def forward(self, x): """ 🚀 Forward pass through selected YOLO layers and optional MLP. Args: x (Tensor): Input image tensor (B, C, H, W) Returns: Tensor: Final feature map """ outputs = [] for i, layer in enumerate(self.layers): from_ids = self.layer_defs[i][0] from_ids = [from_ids] if isinstance(from_ids, int) else from_ids inputs = [x if j == -1 else outputs[j] for j in from_ids] x = layer(inputs if len(inputs) > 1 else inputs[0]) outputs.append(x) return self.apply_mlp(x) def count_params(self): total = sum(p.numel() for p in self.parameters()) trainable = sum(p.numel() for p in self.parameters() if p.requires_grad) return total, trainable