Files changed (1) hide show
  1. config.yaml +0 -202
config.yaml DELETED
@@ -1,202 +0,0 @@
1
- # lightning.pytorch==2.1.1
2
- seed_everything: 42
3
- out_dtype: float32
4
- custom_modules_path: ./../custom_modules/
5
- ### Trainer configuration
6
- trainer:
7
- accelerator: auto
8
- strategy: auto
9
- devices: auto
10
- num_nodes: 1
11
- # precision: 16-mixed
12
- logger:
13
- class_path: TensorBoardLogger
14
- init_args:
15
- save_dir: ./../data/
16
- name: model_runs
17
- callbacks:
18
- - class_path: LearningRateMonitor
19
- init_args:
20
- logging_interval: epoch
21
- - class_path: EarlyStopping
22
- init_args:
23
- monitor: val/loss
24
- patience: 100
25
- max_epochs: 1
26
- check_val_every_n_epoch: 1
27
- log_every_n_steps: 5
28
- enable_checkpointing: true
29
- default_root_dir: ./../data/
30
-
31
- ### Data configuration
32
- data:
33
- class_path: terratorch.datamodules.GenericNonGeoPixelwiseRegressionDataModule
34
- init_args:
35
- batch_size: 8
36
- num_workers: 2
37
- train_transform:
38
- - class_path: albumentations.HorizontalFlip
39
- init_args:
40
- p: 0.5
41
- - class_path: albumentations.RandomCrop
42
- init_args:
43
- height: 42
44
- width: 42
45
- - class_path: albumentations.Rotate
46
- init_args:
47
- limit: 30
48
- border_mode: 0 # cv2.BORDER_CONSTANT
49
- value: 0
50
- # mask_value: 1
51
- p: 0.5
52
- - class_path: ToTensorV2
53
- # Specify all bands which are in the input data.
54
- # -1 are placeholders for bands that are in the data but that we will discard
55
- dataset_bands:
56
- - Oa01_reflectance
57
- - Oa02_reflectance
58
- - Oa03_reflectance
59
- - Oa04_reflectance
60
- - Oa05_reflectance
61
- - Oa06_reflectance
62
- - Oa07_reflectance
63
- - Oa08_reflectance
64
- - Oa09_reflectance
65
- - Oa10_reflectance
66
- - Oa11_reflectance
67
- - Oa12_reflectance
68
- - Oa16_reflectance
69
- - Oa17_reflectance
70
- - Oa18_reflectance
71
- - Oa21_reflectance
72
- - SST
73
- output_bands: #Specify the bands which are used from the input data.
74
- - Oa01_reflectance
75
- - Oa02_reflectance
76
- - Oa03_reflectance
77
- - Oa04_reflectance
78
- - Oa05_reflectance
79
- - Oa06_reflectance
80
- - Oa07_reflectance
81
- - Oa08_reflectance
82
- - Oa09_reflectance
83
- - Oa10_reflectance
84
- - Oa11_reflectance
85
- - Oa12_reflectance
86
- - Oa16_reflectance
87
- - Oa17_reflectance
88
- - Oa18_reflectance
89
- - Oa21_reflectance
90
- rgb_indices:
91
- - 2
92
- - 1
93
- - 0
94
- # Directory roots to training, validation and test datasplits:
95
- test_data_root: ./../data/fine-tuning
96
- test_label_data_root: ./../data/fine-tuning
97
- test_split: ./../data/fine-tuning/test_data.txt
98
- train_data_root: ./../data/fine-tuning
99
- train_label_data_root: ./../data/fine-tuning
100
- train_split: ./../data/fine-tuning/train_data.txt
101
- val_data_root: ./../data/fine-tuning
102
- val_label_data_root: ./../data/fine-tuning
103
- val_split: ./../data/fine-tuning/val_data.txt
104
- img_grep: "*_img.tif"
105
- label_grep: "*_lab.tif"
106
- means: # Mean value of the training dataset per band
107
- - 11378.33724842
108
- - 11379.51141294
109
- - 11291.99698672
110
- - 11116.38807044
111
- - 10898.95680699
112
- - 10686.41604621
113
- - 10466.67864162
114
- - 10456.52999209
115
- - 10462.41327758
116
- - 10464.24100298
117
- - 10443.59591923
118
- - 10448.53157824
119
- - 10470.36129347
120
- - 10454.74328843
121
- - 10453.79858959
122
- - 10452.88001737
123
- stds: # Standard deviation of the training dataset per band
124
- - 3125.36214152
125
- - 3118.65965249
126
- - 3088.88720386
127
- - 3055.0881767
128
- - 3026.73186213
129
- - 2997.72812315
130
- - 2968.12838628
131
- - 2968.75857855
132
- - 2969.94390514
133
- - 2970.39202078
134
- - 2964.1543642
135
- - 2973.0155451
136
- - 2985.89318262
137
- - 2975.50852528
138
- - 2973.00652761
139
- - 2973.00330406
140
- # Nodata value in label data
141
- no_label_replace: -1
142
- # Nodata value in the input data
143
- no_data_replace: 0
144
- ### Model configuration
145
- model:
146
- class_path: terratorch.tasks.PixelwiseRegressionTask
147
- init_args:
148
- model_args:
149
- backbone_pretrained: true
150
- backbone: prithvi_s3_v1
151
- backbone_pretrained_cfg_overlay:
152
- file: ./../data/checkpoints/checkpoint.pt
153
- backbone_pretrain_img_size: 42
154
- backbone_drop_path: 0.1
155
- backbone_bands:
156
- - Oa01_reflectance
157
- - Oa02_reflectance
158
- - Oa03_reflectance
159
- - Oa04_reflectance
160
- - Oa05_reflectance
161
- - Oa06_reflectance
162
- - Oa07_reflectance
163
- - Oa08_reflectance
164
- - Oa09_reflectance
165
- - Oa10_reflectance
166
- - Oa11_reflectance
167
- - Oa12_reflectance
168
- - Oa16_reflectance
169
- - Oa17_reflectance
170
- - Oa18_reflectance
171
- - Oa21_reflectance
172
- head_dropout: 0.16194593880230534
173
- head_channel_list: [64]
174
- necks:
175
- - name: SelectIndices
176
- indices: [2, 5, 8, 11]
177
- - name: ReshapeTokensToImage
178
- - name: LearnedInterpolateToPyramidal
179
- decoder: UNetDecoder
180
- decoder_channels: [256, 128, 64, 32]
181
- head_dropout: 0.1
182
- loss: rmse
183
- ignore_index: -1
184
- freeze_backbone: false
185
- freeze_decoder: false
186
- model_factory: EncoderDecoderFactory
187
- tiled_inference_parameters:
188
- h_crop: 64
189
- h_stride: 4
190
- w_crop: 64
191
- w_stride: 4
192
- delta: 8
193
- average_patches: true
194
- optimizer:
195
- class_path: torch.optim.AdamW
196
- init_args:
197
- lr: 0.00012
198
- weight_decay: 0.3
199
- lr_scheduler:
200
- class_path: ReduceLROnPlateau
201
- init_args:
202
- monitor: val/loss