surajpaib commited on
Commit
23e05e9
·
verified ·
1 Parent(s): a005e18

Upload README.md with huggingface_hub

Browse files
Files changed (1) hide show
  1. README.md +37 -61
README.md CHANGED
@@ -18,15 +18,13 @@ The backbone is based on a SegResNet, a 3D U-Net variant. If you want to just lo
18
  ## Running instructions
19
 
20
 
21
- # Whole Body Segmentation Inference
22
 
23
  This notebook demonstrates how to:
24
- 1. Load a pre-trained whole body segmentation model from HuggingFace Hub
25
  2. Set up preprocessing and postprocessing pipelines
26
- 3. Perform sliding window inference on CT volumes
27
- 4. Save the segmentation results
28
-
29
- The model segments 118 different anatomical structures from CT scans.
30
 
31
  ## Setup
32
  Install requirements and import necessary packages
@@ -44,55 +42,27 @@ Install requirements and import necessary packages
44
  ```python
45
  # Imports
46
  import torch
47
- from lighter_zoo import SegResNet
48
  from monai.transforms import (
49
  Compose, LoadImage, EnsureType, Orientation,
50
- ScaleIntensityRange, CropForeground, Invert,
51
- Activations, AsDiscrete, KeepLargestConnectedComponent,
52
- SaveImage
53
  )
54
  from monai.inferers import SlidingWindowInferer
55
  ```
56
 
57
- Note: you may need to restart the kernel to use updated packages.
58
-
59
-
60
  ## Load Model
61
  Download and initialize the pre-trained model from HuggingFace Hub
62
 
63
 
64
  ```python
65
  # Load pre-trained model
66
- model = SegResNet.from_pretrained(
67
- "project-lighter/whole_body_segmentation",
68
- force_download=True
69
- )
70
- ```
71
-
72
-
73
- config.json: 0%| | 0.00/162 [00:00<?, ?B/s]
74
-
75
-
76
-
77
- model.safetensors: 0%| | 0.00/349M [00:00<?, ?B/s]
78
-
79
-
80
- ## Configure Inference
81
- Set up sliding window inference for processing large volumes
82
-
83
-
84
- ```python
85
- # Configure sliding window inference
86
- inferer = SlidingWindowInferer(
87
- roi_size=[96, 160, 160], # Size of patches to process
88
- sw_batch_size=2, # Number of windows to process in parallel
89
- overlap=0.625, # Overlap between windows (reduces boundary artifacts)
90
- mode="gaussian" # Gaussian weighting for overlap regions
91
  )
92
  ```
93
 
94
  ## Setup Processing Pipelines
95
- Define preprocessing and postprocessing transforms
96
 
97
 
98
  ```python
@@ -111,24 +81,13 @@ preprocess = Compose([
111
  ),
112
  CropForeground() # Remove background to reduce computation
113
  ])
114
-
115
- # Postprocessing pipeline
116
- postprocess = Compose([
117
- Activations(softmax=True), # Apply softmax to get probabilities
118
- AsDiscrete(argmax=True, dtype=torch.int32), # Convert to class labels
119
- KeepLargestConnectedComponent(), # Remove small disconnected regions
120
- Invert(transform=preprocess), # Restore original space
121
- # Save the result
122
- SaveImage(output_dir="./segmentations")
123
- ])
124
  ```
125
 
126
- /home/suraj/miniconda3/lib/python3.10/site-packages/monai/utils/deprecate_utils.py:321: FutureWarning: monai.transforms.croppad.array CropForeground.__init__:allow_smaller: Current default value of argument `allow_smaller=True` has been deprecated since version 1.2. It will be changed to `allow_smaller=False` in version 1.5.
127
- warn_deprecated(argname, msg, warning_category)
128
 
129
 
130
  ## Run Inference
131
- Process an input CT scan and generate segmentation
132
 
133
 
134
  ```python
@@ -140,19 +99,36 @@ input_tensor = preprocess(input_path)
140
 
141
  # Run inference
142
  with torch.no_grad():
143
- output = inferer(input_tensor.unsqueeze(dim=0), model)[0]
144
 
145
- # Copy metadata from input
146
- output.applied_operations = input_tensor.applied_operations
147
- output.affine = input_tensor.affine
148
 
149
- # Postprocess and save result
150
- result = postprocess(output[0])
151
- print("✅ Segmentation completed and saved")
152
  ```
153
 
154
- 2025-01-16 18:41:57,674 INFO image_writer.py:197 - writing: /home/suraj/Repositories/lighter-ct-fm/semantic-search-app/assets/segmentations/0/0_trans.nii.gz
155
- Segmentation completed and saved
 
 
 
 
 
 
 
 
156
 
157
 
 
 
 
 
 
 
 
 
 
 
158
 
 
18
  ## Running instructions
19
 
20
 
21
+ # CT-FM Feature Extractor
22
 
23
  This notebook demonstrates how to:
24
+ 1. Load a SSL pre-trained model
25
  2. Set up preprocessing and postprocessing pipelines
26
+ 3. Perform inference on CT volumes
27
+ 4. Plot distribution of features extracted
 
 
28
 
29
  ## Setup
30
  Install requirements and import necessary packages
 
42
  ```python
43
  # Imports
44
  import torch
45
+ from lighter_zoo import SegResEncoder
46
  from monai.transforms import (
47
  Compose, LoadImage, EnsureType, Orientation,
48
+ ScaleIntensityRange, CropForeground
 
 
49
  )
50
  from monai.inferers import SlidingWindowInferer
51
  ```
52
 
 
 
 
53
  ## Load Model
54
  Download and initialize the pre-trained model from HuggingFace Hub
55
 
56
 
57
  ```python
58
  # Load pre-trained model
59
+ model = SegResEncoder.from_pretrained(
60
+ "project-lighter/ct_fm_feature_extractor"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  )
62
  ```
63
 
64
  ## Setup Processing Pipelines
65
+ Define preprocessing transforms
66
 
67
 
68
  ```python
 
81
  ),
82
  CropForeground() # Remove background to reduce computation
83
  ])
 
 
 
 
 
 
 
 
 
 
84
  ```
85
 
86
+ monai.transforms.croppad.array CropForeground.__init__:allow_smaller: Current default value of argument `allow_smaller=True` has been deprecated since version 1.2. It will be changed to `allow_smaller=False` in version 1.5.
 
87
 
88
 
89
  ## Run Inference
90
+ Process an input CT scan and extract features
91
 
92
 
93
  ```python
 
99
 
100
  # Run inference
101
  with torch.no_grad():
102
+ output = model(input_tensor.unsqueeze(0))[-1]
103
 
104
+ # Average pooling compressed the feature vector across all patches. If this is not desired, remove this line and
105
+ # use the output tensor directly which will give you the feature maps in a low-dimensional space.
106
+ avg_output = torch.nn.functional.adaptive_avg_pool3d(output, 1).squeeze()
107
 
108
+ print("✅ Feature extraction completed")
109
+ print(f"Output shape: {avg_output.shape}")
 
110
  ```
111
 
112
+ Feature extraction completed
113
+ Output shape: torch.Size([512])
114
+
115
+
116
+
117
+ ```python
118
+ # Plot distribution of features
119
+ import matplotlib.pyplot as plt
120
+ _ = plt.hist(avg_output.cpu().numpy(), bins=100)
121
+ ```
122
 
123
 
124
+
125
+ ![png](ct_fm_feature_extractor_files/ct_fm_feature_extractor_10_0.png)
126
+
127
+
128
+
129
+
130
+ ```python
131
+
132
+ ```
133
+
134