"""📦 core-dino | Data Loader for Self-Supervised DINO Training on Core-Five 🚀 This module defines the `DinoDataset` which streams multi-resolution satellite patches from the Core-Five dataset, preparing teacher-student views for resolution-agnostic self-supervised learning. """ import os import io import time import torch import random import requests import numpy as np import geopandas as gpd import h5py import xarray as xr from torch import nn from torch.utils.data import Dataset import albumentations as A import fsspec from utils import ( shared_store, process_pool, write_last_updated, AddPoissonNoise, AddSaltPepperNoise ) class DinoDataset(Dataset): """ 🧠 DinoDataset — resolution-agnostic loader for Core-Five 🌍 Streams random crops of HR satellite images from Hugging Face, creates clean (teacher) and augmented (student) views using Albumentations & torch. --- 👤 Author: Gajesh Ladhar 🔗 LinkedIn: 🔗 https://www.linkedin.com/in/gajeshladhar/ 🤗 Hugging Face: 🤗 https://huggingface.co/gajeshladhar """ def __init__(self, imgsz, batch_size=1, queue_size=50): """ 📐 Init the dataset with remote Core-Five metadata and start async patch fetching. Args: imgsz (int): Patch size (min 320 recommended) batch_size (int): Number of patches per batch queue_size (int): Max queue length for shared store """ if imgsz < 320: raise ValueError("❗️imgsz must be ≥ 320 for stable patch extraction — got {}".format(imgsz)) self.imgsz = imgsz metadata_url = "https://huggingface.co/datasets/gajeshladhar/core-five/resolve/main/metadata.parquet" self.df_metadata = gpd.read_parquet(fsspec.open(metadata_url).open()) self.batch_size = batch_size self.queue_size = queue_size self.store = shared_store for _ in range(6): process_pool.submit(self.fetch_and_store) @staticmethod def transform(batch): """ 🎛️ Apply augmentation pipeline to simulate degraded inputs for student; teacher gets clean view. Maintains shape consistency. Returns: Dict with 'student' and 'teacher' uint8 tensors """ augment_satellite = A.Compose([ A.GaussNoise(std_range=(0.01, 0.1), p=0.3), AddPoissonNoise(p=0.3), AddSaltPepperNoise(amount=0.02, p=0.3), A.MultiplicativeNoise(multiplier=(0.9, 1.1), elementwise=True, p=0.3), A.MotionBlur(blur_limit=(3, 11), p=0.3), A.GaussianBlur(blur_limit=(3, 11), p=0.3), A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.3), p=0.1), A.RandomBrightnessContrast(brightness_limit=0.5, contrast_limit=0.5, p=0.3), A.RGBShift(r_shift_limit=30, g_shift_limit=30, b_shift_limit=30, p=0.3), A.HueSaturationValue(hue_shift_limit=30, sat_shift_limit=30, val_shift_limit=30, p=0.3), A.CLAHE(clip_limit=2.0, tile_grid_size=(8, 8), p=0.2), A.CoarseDropout(num_holes_range=(1, 4), hole_height_range=(0.05, 0.2), hole_width_range=(0.05, 0.2), fill='random_uniform', p=0.1) ]) imgsz_half = batch[0].shape[-1] size = np.random.choice(np.arange(32 * 10, imgsz_half, 32)) student, teacher = [], [] for img in batch: student_data = nn.Upsample(size=size, mode='bilinear')(torch.tensor(img[np.newaxis, :]))[0].data.numpy().astype("uint8") student_data = augment_satellite(image=student_data.transpose(1, 2, 0))['image'].transpose(2, 0, 1) student.append(torch.tensor(student_data)) teacher.append(torch.tensor(img)) return { "student": torch.stack(student).to(torch.uint8), "teacher": torch.stack(teacher).to(torch.uint8) } def fetch_and_store(self): """ 🔄 Continuously samples random crops from Core-Five, augments them via `transform`, and updates the shared queue for training. 
""" np.random.seed(int.from_bytes(os.urandom(4), 'little')) while True: try: batch = [] for _ in range(self.batch_size): path = os.path.join("https://huggingface.co/datasets/gajeshladhar/core-five/resolve/main/", self.df_metadata.sample(n=1).path.iloc[0]) buffer = io.BytesIO(requests.get(path, headers={"User-Agent": "Mozilla/5.0"}).content) with h5py.File(buffer, "r") as f: x = f["hr/x"][:] y = f["hr/y"][:] data = f["/hr/data"][:] bands = list(range(data.shape[0])) ds = xr.DataArray(data, dims=['band', 'y', 'x'], coords=[bands, y, x]).astype("uint8") imgsz_half = self.imgsz // 2 yid = np.random.randint(imgsz_half, len(ds.y) - imgsz_half) xid = np.random.randint(imgsz_half, len(ds.x) - imgsz_half) ds = ds.isel(y=range(yid - imgsz_half, yid + imgsz_half), x=range(xid - imgsz_half, xid + imgsz_half)).compute() ds['y'], ds['x'] = np.linspace(ds.y.values[0], ds.y.values[-1], ds.shape[1]), \ np.linspace(ds.x.values[0], ds.x.values[-1], ds.shape[2]) batch.append(ds.data) result = DinoDataset.transform(batch) if len(self.store) >= self.queue_size: index = np.random.randint(0, self.queue_size - 1) self.store[index] = result else: self.store.append(result) # enable for getting recent updates if np.random.random() < 0.20: write_last_updated() except KeyboardInterrupt: break except Exception as e: print("ERROR:", e) continue if __name__=="__main__": dataset = DinoDataset(imgsz=1696,batch_size=3,queue_size=1000) while True : print(len(dataset.store)) time.sleep(5)