File size: 4,763 Bytes
625dedf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import io
import tempfile
import zipfile
from pathlib import Path

import numpy as np
import rasterio
import streamlit as st
import torch
from PIL import Image
from rasterio.windows import Window
from tqdm.auto import tqdm

# Assuming you have these functions defined elsewhere
from your_module import preprocess, best_model, DEVICE

def _read_rgb_tile(src, window, tile_size):
    """Read ``window`` from ``src`` as a C-contiguous (tile_size, tile_size, 3) uint8 array."""
    data = src.read(window=window)
    if data.shape[0] == 1:
        data = np.repeat(data, 3, axis=0)
    elif data.shape[0] != 3:
        raise ValueError("The number of channels in the image is not supported")
    # Windows overhanging the right/bottom raster edge may be returned smaller
    # than requested; zero-pad them so every tile in a batch has one shape and
    # the pixels line up with the full-tile geotransform stored in the metadata.
    pad_rows = max(tile_size - data.shape[1], 0)
    pad_cols = max(tile_size - data.shape[2], 0)
    if pad_rows or pad_cols:
        data = np.pad(data, ((0, 0), (0, pad_rows), (0, pad_cols)))
    # Contiguous HWC layout, as the former PIL round-trip produced, so
    # downstream image libraries get a standard array.
    return np.ascontiguousarray(np.transpose(data, (1, 2, 0)).astype(np.uint8))


def _tile_meta(src, window, tile_size):
    """Return a copy of ``src.meta`` describing ``window`` as a standalone GeoTIFF."""
    meta = src.meta.copy()
    meta.update({
        "driver": "GTiff",
        "height": tile_size,
        "width": tile_size,
        "transform": rasterio.windows.transform(window, src.transform),
    })
    return meta


def _predict_masks(model, images, tile_size, threshold):
    """Run ``model`` on one batch of preprocessed tiles; return binarised masks as NumPy arrays."""
    # assumes preprocess returns equally-shaped (C, H, W) torch tensors -- TODO confirm
    batch = torch.stack(images).to(DEVICE)
    with torch.no_grad():
        probs = torch.sigmoid(model(batch))
        masks = (probs > threshold).float()
        # Resize back to tile resolution in case preprocess changed the input size.
        masks = torch.nn.functional.interpolate(
            masks, size=(tile_size, tile_size), mode="bilinear", align_corners=False
        )
    return [mask.squeeze().cpu().numpy() for mask in masks]


def extract_tiles(map_file, model, tile_size=512, overlap=0, batch_size=4, threshold=0.6):
    """Run ``model`` over a raster tile by tile and keep the tiles that contain detections.

    Tile origins are spaced ``tile_size - overlap`` pixels apart and visited
    exactly once, in row-major order. Each tile is read as a 3-channel uint8
    image (single-band rasters are repeated to 3 bands; edge tiles are
    zero-padded), passed through ``preprocess`` and fed to ``model`` in
    batches of ``batch_size``. ``sigmoid(output) > threshold`` gives a binary
    mask, which is bilinearly resized to ``tile_size`` x ``tile_size``.

    Args:
        map_file: Raster path (anything ``rasterio.open`` accepts).
        model: Callable producing a 4-D output for a stacked batch of
            preprocessed tiles (presumably a torch segmentation module).
        tile_size: Side length of each tile in pixels.
        overlap: Pixels shared by neighbouring tiles; must be < ``tile_size``.
        batch_size: Tiles per forward pass; must be >= 1.
        threshold: Probability above which a pixel counts as positive.

    Returns:
        List of ``[mask_array, meta]`` pairs, one per tile whose mask has any
        non-zero pixel. ``meta`` is a copy of the source ``meta`` updated with
        the GTiff driver, the tile size and the tile's geotransform.

    Raises:
        ValueError: If ``overlap``/``batch_size`` are out of range, or the
            raster has neither 1 nor 3 bands.
    """
    step = tile_size - overlap
    if step <= 0:
        raise ValueError("overlap must be smaller than tile_size")
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    tiles = []
    with rasterio.open(map_file) as src:
        # Every origin appears exactly once, so no tile is inferred or returned twice.
        origins = [
            (y, x)
            for y in range(0, src.height, step)
            for x in range(0, src.width, step)
        ]
        for start in tqdm(range(0, len(origins), batch_size)):
            batch_images = []
            batch_metas = []
            for y, x in origins[start:start + batch_size]:
                window = Window(x, y, tile_size, tile_size)
                tile_image = _read_rgb_tile(src, window, tile_size)
                batch_images.append(preprocess(image=tile_image)["image"])
                batch_metas.append(_tile_meta(src, window, tile_size))

            masks = _predict_masks(model, batch_images, tile_size, threshold)
            for mask_array, meta in zip(masks, batch_metas):
                if mask_array.any():
                    tiles.append([mask_array, meta])

    return tiles

def main():
    """Streamlit page: upload a TIF, run the model over it, and offer detected tiles as a ZIP.

    All intermediate files (the uploaded raster and per-tile GeoTIFFs) live in
    a private temporary directory that is removed after the ZIP is built.
    """
    st.title("TIF File Processor")

    uploaded_file = st.file_uploader("Choose a TIF file", type=["tif", "tiff"])
    if uploaded_file is None:
        return
    st.write("File uploaded successfully!")

    if not st.button("Process File"):
        return
    st.write("Processing...")

    # Per-run scratch directory: concurrent sessions cannot overwrite each
    # other's temp.tif / tile_*.tif, and nothing is left in the working dir.
    with tempfile.TemporaryDirectory() as tmp:
        workdir = Path(tmp)
        input_path = workdir / "temp.tif"
        input_path.write_bytes(uploaded_file.getbuffer())

        # eval() so layers such as dropout / batch-norm use inference behaviour.
        best_model.float()
        best_model.eval()
        tiles = extract_tiles(str(input_path), best_model, tile_size=512, overlap=15, batch_size=4)

        st.write("Processing complete!")

        zip_buffer = io.BytesIO()
        with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zip_file:
            for i, (mask_array, meta) in enumerate(tiles):
                tile_name = f"tile_{i}.tif"
                tile_path = workdir / tile_name
                # Only one mask band is written, and its dtype is the mask's
                # own rather than the source image's, so describe the file
                # accordingly instead of inheriting the source count/dtype.
                tile_meta = {**meta, "count": 1, "dtype": str(mask_array.dtype)}
                with rasterio.open(tile_path, "w", **tile_meta) as dst:
                    dst.write(mask_array, 1)
                zip_file.write(tile_path, arcname=tile_name)

    st.download_button(
        label="Download processed tiles",
        data=zip_buffer.getvalue(),
        file_name="processed_tiles.zip",
        mime="application/zip",
    )


if __name__ == "__main__":
    main()