---
library_name: project-lighter
tags:
- lighter
- model_hub_mixin
- pytorch_model_hub_mixin
language: en
license: apache-2.0
arxiv: 2501.09001
---

# CT-FM SegResNet

This model is a SegResNet initialized with the weights of the pre-trained CT-FM, which was trained with contrastive self-supervised learning on a dataset of 148,000 CT scans from the Imaging Data Commons.

## Running instructions

# CT-FM SegResNet Fine-tuning

This notebook demonstrates how to:
1. Load an SSL pre-trained model into a SegResNet
2. Apply the recommended preprocessing and postprocessing steps that were used during pre-training
3. Fine-tune the model (overview of instructions)

## Setup
Install requirements and import necessary packages

```python
# Install lighter_zoo package
%pip install lighter_zoo -U -qq
```

    Note: you may need to restart the kernel to use updated packages.

```python
# Imports
import torch
from lighter_zoo import SegResNet
from monai.transforms import (
    Compose, LoadImage, EnsureType, Orientation,
    ScaleIntensityRange, CropForeground, Invert,
    Activations, AsDiscrete, KeepLargestConnectedComponent,
    SaveImage
)
from monai.inferers import SlidingWindowInferer
```

## Load Model
Download and initialize the pre-trained model from HuggingFace Hub

```python
# Load pre-trained model
model = SegResNet.from_pretrained(
    "project-lighter/ct_fm_segresnet"
)
```

## Setup Processing Pipelines
Define preprocessing and postprocessing transforms

```python
# Preprocessing pipeline
preprocess = Compose([
    LoadImage(ensure_channel_first=True),  # Load image and ensure channel dimension
    EnsureType(),                          # Ensure correct data type
    Orientation(axcodes="SPL"),            # Standardize orientation
    # Scale intensity to [0,1] range, clipping outliers
    ScaleIntensityRange(
        a_min=-1024,  # Min HU value
        a_max=2048,   # Max HU value
        b_min=0,      # Target min
        b_max=1,      # Target max
        clip=True     # Clip values outside range
    ),
    CropForeground()  # Remove background to reduce computation
])
```

    monai.transforms.croppad.array CropForeground.__init__:allow_smaller: Current default value of argument `allow_smaller=True` has been deprecated since version 1.2. It will be changed to `allow_smaller=False` in version 1.5.

## Run Inference
Process an input CT scan and extract features

```python
# Configure sliding window inference
inferer = SlidingWindowInferer(
    roi_size=[96, 160, 160],  # Size of patches to process
    sw_batch_size=2,          # Number of windows to process in parallel
    overlap=0.625,            # Overlap between windows (reduces boundary artifacts)
    mode="gaussian"           # Gaussian weighting for overlap regions
)

# Input path (replace with the path to your own CT scan)
input_path = "/home/suraj/Repositories/semantic-search-app/assets/scans/s0114.nii.gz"

# Preprocess input
input_tensor = preprocess(input_path)

# Run inference
with torch.no_grad():
    model = model.to("cuda")
    input_tensor = input_tensor.to("cuda")
    # Add a batch dimension for the inferer, then drop it from the result
    output = inferer(input_tensor.unsqueeze(dim=0), model)[0]
    output = output.to("cpu")

print(output.shape)
```

    torch.Size([2, 227, 181, 258])

## Fine-tuning Instructions

The model above does not include a trained decoder, so the predictions it produces will be nonsensical. However, you can use the pre-trained encoder and model architecture to fine-tune on your own datasets, especially if they are small.

A simple way to integrate this into your pipeline is to replace the model in your training process with the pre-trained version. For example:

```python
model = SegResNet.from_pretrained('project-lighter/ct_fm_segresnet')
```

We recommend using Auto3DSeg in conjunction with our model. For detailed guidance, please refer to the instructions here: https://project-lighter.github.io/CT-FM/replication-guide/downstream/#tumor-segmentation-with-auto3dseg
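
As a rough check on the sliding-window settings used in the inference step above (`roi_size=[96, 160, 160]`, `overlap=0.625`, `sw_batch_size=2`), the number of windows can be estimated by hand. The sketch below is plain Python and assumes that windows are placed at a stride of `int(roi_size * (1 - overlap))` along each axis, with a final window added to reach the border. The helper names `windows_per_axis` and `total_windows` are hypothetical, and MONAI's own placement logic may differ in detail. The spatial shape is taken from the printed output shape of the example scan.

```python
import math


def windows_per_axis(image_size: int, roi_size: int, overlap: float) -> int:
    """Estimate how many windows are needed to cover one axis (assumed stride rule)."""
    if image_size <= roi_size:
        return 1  # a single window already covers the whole axis
    # Assumed stride: the non-overlapping part of the window, at least 1 voxel
    stride = max(int(roi_size * (1 - overlap)), 1)
    # Number of strides needed for the last window to reach the border, plus the first window
    return math.ceil((image_size - roi_size) / stride) + 1


def total_windows(image_shape, roi_shape, overlap):
    """Return the per-axis window counts and their product."""
    counts = [
        windows_per_axis(size, roi, overlap)
        for size, roi in zip(image_shape, roi_shape)
    ]
    return counts, math.prod(counts)


# Spatial shape of the example output, with the settings from the inference step
counts, total = total_windows((227, 181, 258), (96, 160, 160), 0.625)

# With sw_batch_size=2, two windows go through the model per forward pass
forward_passes = math.ceil(total / 2)

print(counts, total, forward_passes)
```

Under these assumptions the example scan needs 5 x 2 x 3 = 30 windows, or 15 forward passes at `sw_batch_size=2`. Lowering `overlap` reduces this count at the cost of more visible boundary artifacts.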