blumenstiel commited on
Commit
dfb5c47
Β·
1 Parent(s): 12f6482
Files changed (2) hide show
  1. app.py +1 -43
  2. requirements.txt +1 -1
app.py CHANGED
@@ -22,48 +22,6 @@ os.system(f'cp {model_inference} .')
22
 
23
  from inference import process_channel_group, _convert_np_uint8, load_example, run_model
24
 
25
- def extract_rgb_imgs(input_img, pred_img, channels):
26
- """ Wrapper function to save Geotiff images (original, reconstructed, masked) per timestamp.
27
- Args:
28
- input_img: input torch.Tensor with shape (C, H, W).
29
- rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
30
- pred_img: mask torch.Tensor with shape (C, T, H, W).
31
- channels: list of indices representing RGB channels.
32
- mean: list of mean values for each band.
33
- std: list of std values for each band.
34
- output_dir: directory where to save outputs.
35
- meta_data: list of dicts with geotiff meta info.
36
- """
37
- rgb_orig_list = []
38
- rgb_mask_list = []
39
- rgb_pred_list = []
40
-
41
- for t in range(input_img.shape[1]):
42
- rgb_orig, rgb_pred = process_channel_group(orig_img=input_img[:, t, :, :],
43
- new_img=rec_img[:, t, :, :],
44
- channels=channels,
45
- mean=mean,
46
- std=std)
47
-
48
- rgb_mask = mask_img[channels, t, :, :] * rgb_orig
49
-
50
- # extract images
51
- rgb_orig_list.append(_convert_np_uint8(rgb_orig).transpose(1, 2, 0))
52
- rgb_mask_list.append(_convert_np_uint8(rgb_mask).transpose(1, 2, 0))
53
- rgb_pred_list.append(_convert_np_uint8(rgb_pred).transpose(1, 2, 0))
54
-
55
- # Add white dummy image values for missing timestamps
56
- dummy = np.ones((20, 20), dtype=np.uint8) * 255
57
- num_dummies = 4 - len(rgb_orig_list)
58
- if num_dummies:
59
- rgb_orig_list.extend([dummy] * num_dummies)
60
- rgb_mask_list.extend([dummy] * num_dummies)
61
- rgb_pred_list.extend([dummy] * num_dummies)
62
-
63
- outputs = rgb_orig_list + rgb_mask_list + rgb_pred_list
64
-
65
- return outputs
66
-
67
 
68
  def predict_on_images(data_file: str | Path, config_path: str, checkpoint: str):
69
  try:
@@ -81,7 +39,7 @@ def predict_on_images(data_file: str | Path, config_path: str, checkpoint: str):
81
  # Load model ---------------------------------------------------------------------------------
82
 
83
  lightning_model = LightningInferenceModel.from_config(config_path, checkpoint)
84
- img_size = 256 # Size of Sen1Floods11
85
 
86
  # Loading data ---------------------------------------------------------------------------------
87
 
 
22
 
23
  from inference import process_channel_group, _convert_np_uint8, load_example, run_model
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  def predict_on_images(data_file: str | Path, config_path: str, checkpoint: str):
27
  try:
 
39
  # Load model ---------------------------------------------------------------------------------
40
 
41
  lightning_model = LightningInferenceModel.from_config(config_path, checkpoint)
42
+ img_size = 512 # Size from Sen1Floods11 training
43
 
44
  # Loading data ---------------------------------------------------------------------------------
45
 
requirements.txt CHANGED
@@ -5,4 +5,4 @@ rasterio
5
  einops
6
  huggingface_hub
7
  gradio
8
- git+https://github.com/IBM/terratorch.git
 
5
  einops
6
  huggingface_hub
7
  gradio
8
+ terratorch==1.0.2