TerraTorch
Prithvi-EO-2.0-300M-BurnScars / burn_scars_config.yaml
blumenstiel's picture
Upload 3 files
6a5609b verified
# lightning.pytorch==2.4.0
# Seed for Lightning's `seed_everything`, fixed so runs are repeatable.
seed_everything: 2
# Lightning Trainer options: at most 100 epochs, bf16 mixed precision,
# metrics logged on every step, progress bar disabled.
trainer:
  logger: true
  max_epochs: 100
  log_every_n_steps: 1
  callbacks:
    # Stop once val/loss has gone 15 validation checks without improving.
    - class_path: EarlyStopping
      init_args:
        monitor: val/loss
        patience: 15
    # Log the current learning rate once per epoch.
    - class_path: LearningRateMonitor
      init_args:
        logging_interval: epoch
  enable_progress_bar: false
  precision: bf16-mixed
# Task definition: two-class semantic segmentation using the
# `prithvi_eo_v2_300` backbone and a UNet decoder, assembled by the
# EncoderDecoderFactory.
model:
  class_path: terratorch.tasks.SemanticSegmentationTask
  init_args:
    model_factory: EncoderDecoderFactory
    # Arguments forwarded to the model factory.
    model_args:
      backbone: prithvi_eo_v2_300
      backbone_pretrained: true
      backbone_bands: ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"]
      # Necks are listed in the order they are applied between backbone and
      # decoder. NOTE(review): the four indices are presumably 0-based
      # backbone layer outputs feeding the four decoder stages -- confirm.
      necks:
        - name: SelectIndices
          indices: [5, 11, 17, 23]
        - name: ReshapeTokensToImage
        - name: LearnedInterpolateToPyramidal
      decoder: UNetDecoder
      decoder_channels: [512, 256, 128, 64]
      num_classes: 2
    # Task-level options (siblings of model_args, not factory arguments).
    loss: ce
    # Same value as the data section's `no_label_replace`, so unlabeled
    # pixels are presumably excluded from the loss -- confirm.
    ignore_index: -1
    freeze_backbone: false
    plot_on_val: false
    # One name per class; two entries to match num_classes: 2.
    class_names: [Not burned, Burn scar]
# Optimizer: AdamW with a learning rate of 1e-4.
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    # Keep the dot in `1.e-4`: YAML 1.1 loaders such as PyYAML only read
    # exponent notation as a float when the mantissa contains a dot;
    # `1e-4` would be loaded as a string.
    lr: 1.e-4
# Learning-rate schedule: halve the rate when val/loss stops improving.
lr_scheduler:
  class_path: ReduceLROnPlateau
  init_args:
    monitor: val/loss
    # Multiplier applied to the learning rate on each reduction.
    factor: 0.5
    # Epochs without improvement tolerated before reducing.
    patience: 4
# Data module: six-band imagery with burn-scar masks, read from
# `hls_burn_scars/` and divided into train/val/test by split files.
data:
  class_path: GenericNonGeoSegmentationDataModule
  init_args:
    batch_size: 8
    num_workers: 8
    dataset_bands:  # Dataset bands
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
    output_bands:  # Model input bands
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
    # Positions of RED, GREEN, BLUE in the band lists above
    # (presumably used when rendering RGB previews).
    rgb_indices:
      - 2
      - 1
      - 0
    # All three roots point at the same directory; the split files below
    # are what differ between train, val and test.
    train_data_root: hls_burn_scars/data
    val_data_root: hls_burn_scars/data
    test_data_root: hls_burn_scars/data
    train_split: hls_burn_scars/splits/train.txt
    val_split: hls_burn_scars/splits/val.txt
    test_split: hls_burn_scars/splits/test.txt
    # Glob patterns stay quoted: a plain scalar starting with `*` would be
    # parsed as a YAML alias.
    img_grep: "*_merged.tif"
    label_grep: "*.mask.tif"
    # Per-band statistics: six values each, one per band listed above.
    # NOTE(review): order presumably follows the band lists -- confirm.
    means:
      - 0.033349706741586264
      - 0.05701185520536176
      - 0.05889748132001316
      - 0.2323245113436119
      - 0.1972854853760658
      - 0.11944914225186566
    stds:
      - 0.02269135568823774
      - 0.026807560223070237
      - 0.04004109844362779
      - 0.07791732423672691
      - 0.08708738838140137
      - 0.07241979477437814
    num_classes: 2
    # Training applies albumentations.D4 before tensor conversion; the test
    # pipeline only converts to tensors.
    train_transform:
      - class_path: albumentations.D4
      - class_path: ToTensorV2
    test_transform:
      - class_path: ToTensorV2
    # Substitute values for missing data in images and labels respectively.
    # The label value equals the model's `ignore_index: -1`.
    no_data_replace: 0
    no_label_replace: -1