import io
import tempfile
import zipfile
from pathlib import Path

import numpy as np
import rasterio
import streamlit as st
import torch
from PIL import Image
from rasterio.windows import Window
from tqdm.auto import tqdm

# Assuming you have these functions defined elsewhere
from your_module import preprocess, best_model, DEVICE


def _tile_windows(height, width, tile_size, stride):
    """Yield one ``tile_size`` x ``tile_size`` window per grid position, row by row.

    Windows start every ``stride`` pixels, so the last row/column of windows may
    extend past the raster's bottom/right edge.
    """
    for row_off in range(0, height, stride):
        for col_off in range(0, width, stride):
            yield Window(col_off, row_off, tile_size, tile_size)


def _batched(items, size):
    """Split the list *items* into consecutive sub-lists of at most *size* elements."""
    return [items[start:start + size] for start in range(0, len(items), size)]


def _read_tile(src, window, tile_size):
    """Read *window* from *src* as a (tile_size, tile_size, 3) uint8 array.

    Reads near the right/bottom edge may return fewer rows/columns than
    ``tile_size``; those are zero-padded on the bottom/right so every tile has
    the same shape and still lines up with the window's geotransform.
    """
    data = src.read(window=window)
    pad_rows = tile_size - data.shape[1]
    pad_cols = tile_size - data.shape[2]
    if pad_rows or pad_cols:
        data = np.pad(data, ((0, 0), (0, pad_rows), (0, pad_cols)))
    if data.shape[0] == 1:
        # Single-band rasters are replicated to three channels for the model.
        data = np.repeat(data, 3, axis=0)
    # NOTE(review): astype wraps/truncates values outside 0-255, i.e. this
    # assumes 8-bit imagery -- confirm for your inputs.
    return np.transpose(data, (1, 2, 0)).astype(np.uint8)


def _tile_meta(src, window, tile_size):
    """Build a GeoTIFF profile for a single-band uint8 mask covering *window*."""
    meta = src.meta.copy()
    meta.update({
        "driver": "GTiff",
        "height": tile_size,
        "width": tile_size,
        # The mask is one 0/1 band regardless of the source band count/dtype.
        "count": 1,
        "dtype": "uint8",
        # A source nodata value may not fit in uint8, and 0 is a real mask value.
        "nodata": None,
        "transform": rasterio.windows.transform(window, src.transform),
    })
    return meta


def extract_tiles(map_file, model, tile_size=512, overlap=0, batch_size=4, threshold=0.6):
    """Run *model* over a raster tile by tile and return tiles containing detections.

    Args:
        map_file: Path (or anything ``rasterio.open`` accepts) of a 1- or 3-band raster.
        model: Model called on a batched tensor; its output is treated as
            logits (a sigmoid is applied before thresholding).
        tile_size: Side length, in pixels, of each square tile.
        overlap: Pixels shared by neighbouring tiles; must be smaller than tile_size.
        batch_size: Number of tiles sent to the model per forward pass.
        threshold: Probability above which a pixel is marked positive.

    Returns:
        A list of ``[mask_array, meta]`` pairs, one per tile with at least one
        positive pixel. ``mask_array`` is a (tile_size, tile_size) uint8 array of
        0/1 values (for a single-channel model output) and ``meta`` is a rasterio
        profile that georeferences it.

    Raises:
        ValueError: if tile_size/batch_size are not positive, overlap >= tile_size,
            or the raster has neither 1 nor 3 bands.
    """
    if tile_size <= 0:
        raise ValueError("tile_size must be positive")
    stride = tile_size - overlap
    if stride <= 0:
        raise ValueError("overlap must be smaller than tile_size")
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")

    tiles = []
    with rasterio.open(map_file) as src:
        if src.count not in (1, 3):
            raise ValueError("The number of channels in the image is not supported")

        # Each grid position is visited exactly once, so no tile is inferred
        # or emitted twice.
        windows = list(_tile_windows(src.height, src.width, tile_size, stride))
        for batch_windows in tqdm(_batched(windows, batch_size)):
            # NOTE(review): assumes preprocess(image=...)["image"] returns a
            # (C, H, W) tensor, as torch.stack below requires.
            batch_tensor = torch.stack([
                preprocess(image=_read_tile(src, window, tile_size))["image"]
                for window in batch_windows
            ]).to(DEVICE)

            with torch.no_grad():
                # Model output assumed to be (B, C, h, w) logits.
                probs = torch.sigmoid(model(batch_tensor))
                # Resize probabilities, then threshold, so the mask stays strictly 0/1.
                probs = torch.nn.functional.interpolate(
                    probs, size=(tile_size, tile_size), mode="bilinear", align_corners=False
                )
                masks = (probs > threshold).to(torch.uint8).cpu().numpy()

            for window, mask in zip(batch_windows, masks):
                # Drop the channel axis for a single-channel output -> (H, W).
                mask_array = mask.squeeze(0)
                if mask_array.any():
                    tiles.append([mask_array, _tile_meta(src, window, tile_size)])
    return tiles


def main():
    """Streamlit entry point: upload a TIF, segment it, and offer the mask tiles as a zip."""
    st.title("TIF File Processor")
    uploaded_file = st.file_uploader("Choose a TIF file", type=["tif", "tiff"])
    if uploaded_file is None:
        return

    st.write("File uploaded successfully!")
    if not st.button("Process File"):
        return

    st.write("Processing...")
    # A private temp directory avoids clobbering files between concurrent
    # sessions and cleans up the input and tile files afterwards.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = Path(tmp_dir)
        input_path = tmp_path / "temp.tif"
        input_path.write_bytes(uploaded_file.getbuffer())

        best_model.float()
        # Inference mode for dropout/batch-norm layers.
        best_model.eval()
        tiles = extract_tiles(str(input_path), best_model, tile_size=512, overlap=15, batch_size=4)
        st.write("Processing complete!")

        zip_buffer = io.BytesIO()
        with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zip_file:
            for i, (mask_array, meta) in enumerate(tiles):
                tile_name = f"tile_{i}.tif"
                tile_path = tmp_path / tile_name
                # Each tile is written as a separate GeoTIFF, then added to the zip.
                with rasterio.open(str(tile_path), "w", **meta) as dst:
                    dst.write(mask_array, 1)
                zip_file.write(tile_path, arcname=tile_name)

    if not tiles:
        st.warning("No tiles with detections were found.")

    st.download_button(
        label="Download processed tiles",
        data=zip_buffer.getvalue(),
        file_name="processed_tiles.zip",
        mime="application/zip",
    )


if __name__ == "__main__":
    main()