mouadenna commited on
Commit
625dedf
·
verified ·
1 Parent(s): e2c1125

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -0
app.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ import rasterio
6
+ from rasterio.windows import Window
7
+ from tqdm.auto import tqdm
8
+ import io
9
+ import zipfile
10
+
11
+ # Assuming you have these functions defined elsewhere
12
+ from your_module import preprocess, best_model, DEVICE
13
+
14
+ def extract_tiles(map_file, model, tile_size=512, overlap=0, batch_size=4):
15
+ tiles = []
16
+
17
+ with rasterio.open(map_file) as src:
18
+ height = src.height
19
+ width = src.width
20
+
21
+ effective_tile_size = tile_size - overlap
22
+
23
+ for y in tqdm(range(0, height, effective_tile_size)):
24
+ for x in range(0, width, effective_tile_size):
25
+ batch_images = []
26
+ batch_metas = []
27
+
28
+ for i in range(batch_size):
29
+ curr_y = y + (i * effective_tile_size)
30
+ if curr_y >= height:
31
+ break
32
+
33
+ window = Window(x, curr_y, tile_size, tile_size)
34
+ out_image = src.read(window=window)
35
+
36
+ if out_image.shape[0] == 1:
37
+ out_image = np.repeat(out_image, 3, axis=0)
38
+ elif out_image.shape[0] != 3:
39
+ raise ValueError("The number of channels in the image is not supported")
40
+
41
+ out_image = np.transpose(out_image, (1, 2, 0))
42
+ tile_image = Image.fromarray(out_image.astype(np.uint8))
43
+
44
+ out_meta = src.meta.copy()
45
+ out_meta.update({
46
+ "driver": "GTiff",
47
+ "height": tile_size,
48
+ "width": tile_size,
49
+ "transform": rasterio.windows.transform(window, src.transform)
50
+ })
51
+ tile_image = np.array(tile_image)
52
+
53
+ preprocessed_tile = preprocess(image=tile_image)['image']
54
+ batch_images.append(preprocessed_tile)
55
+ batch_metas.append(out_meta)
56
+
57
+ if not batch_images:
58
+ break
59
+
60
+ # Concatenate batch images
61
+ batch_tensor = torch.cat([img.unsqueeze(0).to(DEVICE) for img in batch_images], dim=0)
62
+ # Perform inference on the batch
63
+ with torch.no_grad():
64
+ batch_masks = model(batch_tensor.to(DEVICE))
65
+
66
+ batch_masks = torch.sigmoid(batch_masks)
67
+ batch_masks = (batch_masks > 0.6).float()
68
+
69
+ # Process each mask in the batch
70
+ for j, mask_tensor in enumerate(batch_masks):
71
+ mask_resized = torch.nn.functional.interpolate(mask_tensor.unsqueeze(0), size=(tile_size, tile_size), mode='bilinear', align_corners=False).squeeze(0)
72
+
73
+ mask_array = mask_resized.squeeze().cpu().numpy()
74
+
75
+ if mask_array.any() == 1:
76
+ tiles.append([mask_array, batch_metas[j]])
77
+
78
+ return tiles
79
+
80
+ def main():
81
+ st.title("TIF File Processor")
82
+
83
+ uploaded_file = st.file_uploader("Choose a TIF file", type="tif")
84
+
85
+ if uploaded_file is not None:
86
+ st.write("File uploaded successfully!")
87
+
88
+ # Process button
89
+ if st.button("Process File"):
90
+ st.write("Processing...")
91
+
92
+ # Save the uploaded file temporarily
93
+ with open("temp.tif", "wb") as f:
94
+ f.write(uploaded_file.getbuffer())
95
+
96
+ # Process the file
97
+ best_model.float()
98
+ tiles = extract_tiles("temp.tif", best_model, tile_size=512, overlap=15, batch_size=4)
99
+
100
+ st.write("Processing complete!")
101
+
102
+ # Prepare zip file for download
103
+ zip_buffer = io.BytesIO()
104
+ with zipfile.ZipFile(zip_buffer, "a", zipfile.ZIP_DEFLATED, False) as zip_file:
105
+ for i, (mask_array, meta) in enumerate(tiles):
106
+ # Save each tile as a separate TIF file
107
+ with rasterio.open(f"tile_{i}.tif", 'w', **meta) as dst:
108
+ dst.write(mask_array, 1)
109
+
110
+ # Add the tile to the zip file
111
+ zip_file.write(f"tile_{i}.tif")
112
+
113
+ # Offer the zip file for download
114
+ st.download_button(
115
+ label="Download processed tiles",
116
+ data=zip_buffer.getvalue(),
117
+ file_name="processed_tiles.zip",
118
+ mime="application/zip"
119
+ )
120
+
121
+ if __name__ == "__main__":
122
+ main()