Commit
Β·
dfb5c47
1
Parent(s):
12f6482
Fix app
Browse files- app.py +1 -43
- requirements.txt +1 -1
app.py
CHANGED
@@ -22,48 +22,6 @@ os.system(f'cp {model_inference} .')
|
|
22 |
|
23 |
from inference import process_channel_group, _convert_np_uint8, load_example, run_model
|
24 |
|
25 |
-
def extract_rgb_imgs(input_img, pred_img, channels):
|
26 |
-
""" Wrapper function to save Geotiff images (original, reconstructed, masked) per timestamp.
|
27 |
-
Args:
|
28 |
-
input_img: input torch.Tensor with shape (C, H, W).
|
29 |
-
rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
|
30 |
-
pred_img: mask torch.Tensor with shape (C, T, H, W).
|
31 |
-
channels: list of indices representing RGB channels.
|
32 |
-
mean: list of mean values for each band.
|
33 |
-
std: list of std values for each band.
|
34 |
-
output_dir: directory where to save outputs.
|
35 |
-
meta_data: list of dicts with geotiff meta info.
|
36 |
-
"""
|
37 |
-
rgb_orig_list = []
|
38 |
-
rgb_mask_list = []
|
39 |
-
rgb_pred_list = []
|
40 |
-
|
41 |
-
for t in range(input_img.shape[1]):
|
42 |
-
rgb_orig, rgb_pred = process_channel_group(orig_img=input_img[:, t, :, :],
|
43 |
-
new_img=rec_img[:, t, :, :],
|
44 |
-
channels=channels,
|
45 |
-
mean=mean,
|
46 |
-
std=std)
|
47 |
-
|
48 |
-
rgb_mask = mask_img[channels, t, :, :] * rgb_orig
|
49 |
-
|
50 |
-
# extract images
|
51 |
-
rgb_orig_list.append(_convert_np_uint8(rgb_orig).transpose(1, 2, 0))
|
52 |
-
rgb_mask_list.append(_convert_np_uint8(rgb_mask).transpose(1, 2, 0))
|
53 |
-
rgb_pred_list.append(_convert_np_uint8(rgb_pred).transpose(1, 2, 0))
|
54 |
-
|
55 |
-
# Add white dummy image values for missing timestamps
|
56 |
-
dummy = np.ones((20, 20), dtype=np.uint8) * 255
|
57 |
-
num_dummies = 4 - len(rgb_orig_list)
|
58 |
-
if num_dummies:
|
59 |
-
rgb_orig_list.extend([dummy] * num_dummies)
|
60 |
-
rgb_mask_list.extend([dummy] * num_dummies)
|
61 |
-
rgb_pred_list.extend([dummy] * num_dummies)
|
62 |
-
|
63 |
-
outputs = rgb_orig_list + rgb_mask_list + rgb_pred_list
|
64 |
-
|
65 |
-
return outputs
|
66 |
-
|
67 |
|
68 |
def predict_on_images(data_file: str | Path, config_path: str, checkpoint: str):
|
69 |
try:
|
@@ -81,7 +39,7 @@ def predict_on_images(data_file: str | Path, config_path: str, checkpoint: str):
|
|
81 |
# Load model ---------------------------------------------------------------------------------
|
82 |
|
83 |
lightning_model = LightningInferenceModel.from_config(config_path, checkpoint)
|
84 |
-
img_size =
|
85 |
|
86 |
# Loading data ---------------------------------------------------------------------------------
|
87 |
|
|
|
22 |
|
23 |
from inference import process_channel_group, _convert_np_uint8, load_example, run_model
|
24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
def predict_on_images(data_file: str | Path, config_path: str, checkpoint: str):
|
27 |
try:
|
|
|
39 |
# Load model ---------------------------------------------------------------------------------
|
40 |
|
41 |
lightning_model = LightningInferenceModel.from_config(config_path, checkpoint)
|
42 |
+
img_size = 512 # Size from Sen1Floods11 training
|
43 |
|
44 |
# Loading data ---------------------------------------------------------------------------------
|
45 |
|
requirements.txt
CHANGED
@@ -5,4 +5,4 @@ rasterio
|
|
5 |
einops
|
6 |
huggingface_hub
|
7 |
gradio
|
8 |
-
|
|
|
5 |
einops
|
6 |
huggingface_hub
|
7 |
gradio
|
8 |
+
terratorch==1.0.2
|