---
library_name: project-lighter
tags:
- lighter
- model_hub_mixin
- pytorch_model_hub_mixin
language: en
license: apache-2.0
arxiv: 2501.09001
---
# CT-FM SegResNet
This model is a SegResNet that carries the weights of the pre-trained CT-FM, which was trained with contrastive self-supervised learning on a dataset of 148,000 CT scans from the Imaging Data Commons.
## Running instructions
# CT-FM SegResNet Fine-tuning
This notebook demonstrates how to:
1. Load an SSL pre-trained model into a SegResNet
2. Apply the recommended preprocessing and postprocessing steps that were used during pre-training
3. Fine-tune the model on your own data (overview)
## Setup
Install requirements and import necessary packages
```python
# Install lighter_zoo package
%pip install lighter_zoo -U -qq
```
Note: you may need to restart the kernel to use updated packages.
```python
# Imports
import torch
from lighter_zoo import SegResNet
from monai.transforms import (
    Compose, LoadImage, EnsureType, Orientation,
    ScaleIntensityRange, CropForeground, Invert,
    Activations, AsDiscrete, KeepLargestConnectedComponent,
    SaveImage
)
from monai.inferers import SlidingWindowInferer
```
## Load Model
Download and initialize the pre-trained model from the Hugging Face Hub
```python
# Load pre-trained model
model = SegResNet.from_pretrained(
    "project-lighter/ct_fm_segresnet"
)
```
## Setup Processing Pipelines
Define preprocessing and postprocessing transforms
```python
# Preprocessing pipeline
preprocess = Compose([
    LoadImage(ensure_channel_first=True),  # Load image and ensure channel dimension
    EnsureType(),                          # Ensure correct data type
    Orientation(axcodes="SPL"),            # Standardize orientation
    # Scale intensity to [0,1] range, clipping outliers
    ScaleIntensityRange(
        a_min=-1024,    # Min HU value
        a_max=2048,     # Max HU value
        b_min=0,        # Target min
        b_max=1,        # Target max
        clip=True       # Clip values outside range
    ),
    CropForeground()                       # Remove background to reduce computation
])
```
Note: MONAI may emit a deprecation warning here from `CropForeground.__init__`: the current default `allow_smaller=True` has been deprecated since version 1.2 and will change to `allow_smaller=False` in version 1.5.
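To make the intensity scaling step concrete, here is a minimal pure-Python sketch of the linear mapping that `ScaleIntensityRange` applies with the settings above. The helper name `scale_hu` is hypothetical and not part of MONAI, and it works on a single value rather than on a tensor.

```python
def scale_hu(value, a_min=-1024.0, a_max=2048.0, b_min=0.0, b_max=1.0, clip=True):
    """Linearly map a Hounsfield-unit value from [a_min, a_max] to [b_min, b_max].

    Hypothetical helper for illustration only; the pipeline itself uses
    MONAI's ScaleIntensityRange on whole volumes.
    """
    scaled = (value - a_min) / (a_max - a_min) * (b_max - b_min) + b_min
    if clip:
        # Values outside the HU window are clamped to the target range.
        scaled = min(max(scaled, b_min), b_max)
    return scaled


# Print the scaled value for a few sample HU values, two of them outside the window.
for hu in (-2000, -1024, 0, 512, 2048, 3000):
    print(f"{hu:>6} HU -> {scale_hu(hu):.3f}")
```

Values below -1024 HU and above 2048 HU are clamped to 0 and 1 respectively, so the model always sees inputs in the [0, 1] range.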
## Run Inference
Process an input CT scan and extract features
```python
# Configure sliding window inference
inferer = SlidingWindowInferer(
    roi_size=[96, 160, 160],  # Size of patches to process
    sw_batch_size=2,          # Number of windows to process in parallel
    overlap=0.625,            # Overlap between windows (reduces boundary artifacts)
    mode="gaussian"           # Gaussian weighting for overlap regions
)

# Input path (the author's local file; replace with the path to your own scan)
input_path = "/home/suraj/Repositories/semantic-search-app/assets/scans/s0114.nii.gz"

# Preprocess input
input_tensor = preprocess(input_path)

# Run inference without tracking gradients
with torch.no_grad():
    model = model.to("cuda")
    input_tensor = input_tensor.to("cuda")
    # Add a batch dimension for the inferer, then drop it from the result
    output = inferer(input_tensor.unsqueeze(dim=0), model)[0]
    output = output.to("cpu")

print(output.shape)
```
torch.Size([2, 227, 181, 258])
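To get a feel for the cost of these sliding window settings, the sketch below estimates how many windows cover a volume of the shape printed above. It assumes that the step between windows along each axis is `int(roi_size * (1 - overlap))`, which is our reading of how MONAI places windows and is not verified here. The helper name `count_windows` is hypothetical.

```python
import math


def count_windows(image_size, roi_size, overlap):
    """Estimate the number of sliding windows per axis and in total.

    Assumption: the step along each axis is int(roi * (1 - overlap)),
    and the last window is shifted so that it reaches the image border.
    """
    per_axis = []
    for image_dim, roi_dim in zip(image_size, roi_size):
        if image_dim <= roi_dim:
            # A single window already covers this axis.
            per_axis.append(1)
            continue
        step = max(int(roi_dim * (1 - overlap)), 1)
        per_axis.append(math.ceil((image_dim - roi_dim) / step) + 1)
    return per_axis, math.prod(per_axis)


# Spatial shape of the cropped scan and the inferer settings used above
per_axis, total = count_windows((227, 181, 258), (96, 160, 160), overlap=0.625)
forward_passes = math.ceil(total / 2)  # sw_batch_size=2
print(per_axis, total, forward_passes)
```

Under that assumption the scan above is covered by 5 x 2 x 3 = 30 windows, which corresponds to 15 forward passes at `sw_batch_size=2`. Lowering `overlap` reduces the window count at the price of less smoothing at window boundaries.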
## Fine-tuning Instructions
The model above does not include a trained decoder, so its predictions will be nonsensical until it has been fine-tuned.
However, you can leverage the pre-trained encoder and the model architecture to fine-tune on your own datasets, which is especially useful when they are small. A simple way to integrate this into your pipeline is to replace the model in your training process with the pre-trained version. For example:
```python
model = SegResNet.from_pretrained('project-lighter/ct_fm_segresnet')
```
We recommend using Auto3DSeg in conjunction with our model. For detailed guidance, please refer to the instructions here:
https://project-lighter.github.io/CT-FM/replication-guide/downstream/#tumor-segmentation-with-auto3dseg