mouadenna commited on
Commit
74cc41d
·
verified ·
1 Parent(s): fc6e330

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -3
app.py CHANGED
@@ -76,7 +76,7 @@ def process_and_predict(image, model):
76
  return mask_image
77
 
78
 
79
- def extract_tiles(map_file, model, tile_size=512, overlap=0, batch_size=4):
80
  tiles = []
81
 
82
  with rasterio.open(map_file) as src:
@@ -127,7 +127,7 @@ def extract_tiles(map_file, model, tile_size=512, overlap=0, batch_size=4):
127
  batch_masks = model(batch_tensor.to(DEVICE))
128
 
129
  batch_masks = torch.sigmoid(batch_masks)
130
- batch_masks = (batch_masks > 0.6).float()
131
 
132
  for j, mask_tensor in enumerate(batch_masks):
133
  mask_resized = torch.nn.functional.interpolate(mask_tensor.unsqueeze(0),
@@ -182,6 +182,23 @@ def main():
182
  if uploaded_file is not None:
183
  st.write("File uploaded successfully!")
184
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  if st.button("Process File"):
186
  st.write("Processing...")
187
 
@@ -189,7 +206,7 @@ def main():
189
  f.write(uploaded_file.getbuffer())
190
 
191
  best_model.float()
192
- tiles = extract_tiles("temp.tif", best_model, tile_size=512, overlap=15, batch_size=4)
193
 
194
  st.write("Processing complete!")
195
 
 
76
  return mask_image
77
 
78
 
79
+ def extract_tiles(map_file, model, tile_size=512, overlap=0, batch_size=4,threshold=0.6):
80
  tiles = []
81
 
82
  with rasterio.open(map_file) as src:
 
127
  batch_masks = model(batch_tensor.to(DEVICE))
128
 
129
  batch_masks = torch.sigmoid(batch_masks)
130
+ batch_masks = (batch_masks > threshold).float()
131
 
132
  for j, mask_tensor in enumerate(batch_masks):
133
  mask_resized = torch.nn.functional.interpolate(mask_tensor.unsqueeze(0),
 
182
  if uploaded_file is not None:
183
  st.write("File uploaded successfully!")
184
 
185
+ threshold= st.slider(
186
+ 'Select a float value',
187
+ min_value=0.1,
188
+ max_value=0.9,
189
+ value=0.5,
190
+ step=0.05
191
+ )
192
+ overlap= st.slider(
193
+ 'Select a float value',
194
+ min_value=50,
195
+ max_value=150,
196
+ value=100,
197
+ step=25
198
+ )
199
+ st.write('Selected threshold value:', threshold)
200
+ st.write('Selected overlap value:', overlap)
201
+
202
  if st.button("Process File"):
203
  st.write("Processing...")
204
 
 
206
  f.write(uploaded_file.getbuffer())
207
 
208
  best_model.float()
209
+ tiles = extract_tiles("temp.tif", best_model, tile_size=512, overlap=overlap, batch_size=4,threshold=threshold)
210
 
211
  st.write("Processing complete!")
212