Spaces:
Sleeping
Sleeping
File size: 4,763 Bytes
625dedf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import streamlit as st
import torch
import numpy as np
from PIL import Image
import rasterio
from rasterio.windows import Window
from tqdm.auto import tqdm
import io
import zipfile
# Assuming you have these functions defined elsewhere
from your_module import preprocess, best_model, DEVICE
def _read_tile(src, window):
    """Read ``window`` from the open rasterio dataset ``src`` as (H, W, 3) uint8.

    Single-band rasters are replicated to three channels; any other band
    count other than 3 is rejected. Windows that run past the right/bottom
    edge of the raster can come back smaller than requested, so they are
    zero-padded back to the full window size. This keeps every tile the size
    its metadata claims and lets tiles in one batch be stacked.

    Raises:
        ValueError: if the raster has a band count other than 1 or 3.
    """
    image = src.read(window=window)
    if image.shape[0] == 1:
        image = np.repeat(image, 3, axis=0)
    elif image.shape[0] != 3:
        raise ValueError("The number of channels in the image is not supported")
    pad_rows = max(0, int(window.height) - image.shape[1])
    pad_cols = max(0, int(window.width) - image.shape[2])
    if pad_rows or pad_cols:
        # Windows start inside the raster, so only bottom/right can be short.
        image = np.pad(image, ((0, 0), (0, pad_rows), (0, pad_cols)))
    # Band-first -> band-last. Contiguous, matching what the old PIL round-trip
    # produced (image libraries used by preprocess may require it).
    return np.ascontiguousarray(np.transpose(image, (1, 2, 0)), dtype=np.uint8)


def _tile_meta(src, window):
    """Return GeoTIFF metadata for a single-band tile covering ``window``."""
    meta = src.meta.copy()
    meta.update({
        "driver": "GTiff",
        "height": int(window.height),
        "width": int(window.width),
        # Only the predicted mask (one band) is written for each tile.
        "count": 1,
        "transform": rasterio.windows.transform(window, src.transform),
    })
    return meta


def _predict_masks(model, images, tile_size, threshold):
    """Run ``model`` on a batch of preprocessed tiles and return binary masks.

    Assumes each entry of ``images`` is a CHW tensor and that the model
    returns (N, C, H, W) logits with C == 1 -- NOTE(review): confirm against
    ``preprocess`` / ``best_model``.
    Returns a list of numpy arrays with values 0.0 / 1.0, each
    (tile_size, tile_size) after squeezing.
    """
    batch = torch.stack(images).to(DEVICE)
    with torch.no_grad():
        probs = torch.sigmoid(model(batch))
        # Resize probabilities, then threshold, so the resized mask stays binary.
        probs = torch.nn.functional.interpolate(
            probs, size=(tile_size, tile_size), mode="bilinear", align_corners=False
        )
        masks = (probs > threshold).float()
    return [mask.squeeze().cpu().numpy() for mask in masks]


def extract_tiles(map_file, model, tile_size=512, overlap=0, batch_size=4, threshold=0.6):
    """Segment a raster tile by tile and return the tiles that contain detections.

    The raster is covered by ``tile_size`` x ``tile_size`` windows spaced
    ``tile_size - overlap`` pixels apart, in row-major order. Each window is
    read and predicted exactly once, with ``batch_size`` windows per forward
    pass.

    Args:
        map_file: path (or anything ``rasterio.open`` accepts) of a 1- or
            3-band raster.
        model: torch module mapping a batch of preprocessed tiles to logits.
        tile_size: tile edge length in pixels.
        overlap: pixels shared by neighbouring tiles; must be < ``tile_size``.
        batch_size: tiles per inference call; must be >= 1.
        threshold: probability above which a pixel is marked positive.

    Returns:
        A list of ``[mask_array, meta]`` pairs, only for tiles whose mask has
        at least one positive pixel. ``meta`` is ready to be passed as
        ``rasterio.open(path, "w", **meta)`` for writing ``mask_array`` to
        band 1.

    Raises:
        ValueError: on invalid ``overlap``/``batch_size``, or an unsupported
            band count.
    """
    stride = tile_size - overlap
    if stride <= 0:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than tile_size ({tile_size})"
        )
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    tiles = []
    with rasterio.open(map_file) as src:
        # Enumerate every window once up front. Batches are consecutive
        # slices, so no tile is processed twice.
        windows = [
            Window(x, y, tile_size, tile_size)
            for y in range(0, src.height, stride)
            for x in range(0, src.width, stride)
        ]
        for start in tqdm(range(0, len(windows), batch_size)):
            batch_windows = windows[start:start + batch_size]
            images = [
                preprocess(image=_read_tile(src, window))["image"]
                for window in batch_windows
            ]
            masks = _predict_masks(model, images, tile_size, threshold)
            for window, mask_array in zip(batch_windows, masks):
                if mask_array.any():
                    tiles.append([mask_array, _tile_meta(src, window)])
    return tiles
def main():
    """Streamlit entry point: upload a TIF, segment it, offer the mask tiles as a zip.

    All intermediate files (the uploaded raster and each per-tile GeoTIFF)
    live in a private temporary directory that is removed after the zip is
    built. Concurrent sessions therefore never share or clobber files in the
    working directory.
    """
    import tempfile
    from pathlib import Path

    st.title("TIF File Processor")
    uploaded_file = st.file_uploader("Choose a TIF file", type=["tif", "tiff"])
    if uploaded_file is None:
        return
    st.write("File uploaded successfully!")

    if not st.button("Process File"):
        return
    st.write("Processing...")

    zip_buffer = io.BytesIO()
    with tempfile.TemporaryDirectory() as tmp:
        workdir = Path(tmp)
        input_path = workdir / "temp.tif"
        input_path.write_bytes(uploaded_file.getbuffer())

        best_model.float()
        # Inference should run with dropout / batch-norm in eval mode; this is
        # a no-op if the model was already switched elsewhere.
        best_model.eval()
        tiles = extract_tiles(str(input_path), best_model, tile_size=512, overlap=15, batch_size=4)
        st.write("Processing complete!")

        # "w": the buffer is new, there is nothing to append to.
        with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED, False) as zip_file:
            for i, (mask_array, meta) in enumerate(tiles):
                tile_name = f"tile_{i}.tif"
                tile_path = workdir / tile_name
                with rasterio.open(str(tile_path), "w", **meta) as dst:
                    dst.write(mask_array, 1)
                # arcname keeps the temp directory path out of the archive.
                zip_file.write(tile_path, arcname=tile_name)

    st.download_button(
        label="Download processed tiles",
        data=zip_buffer.getvalue(),
        file_name="processed_tiles.zip",
        mime="application/zip"
    )
# Streamlit executes the script with __name__ == "__main__" (`streamlit run ...`).
if __name__ == "__main__":
    main()