# WeCanopy / generate_tree_images / generate_tree_images.py
import os
import rasterio
import geopandas as gpd
from shapely.geometry import box
from rasterio.mask import mask
from PIL import Image
from PIL import ImageOps
import numpy as np
import warnings
from rasterio.errors import NodataShadowWarning
import sys
# Suppress rasterio's NodataShadowWarning for the whole process.
# NOTE(review): presumably done to keep the '\r' progress output readable
# while masking many polygons -- confirm the warning is safe to ignore.
warnings.filterwarnings("ignore", category=NodataShadowWarning)
def _crop_to_data(pixels, fill_value):
    """Trim rows/columns whose first band contains only ``fill_value``.

    Args:
        pixels: Array laid out as (height, width, bands).
        fill_value: Value that marks pixels outside the polygon.

    Returns:
        The cropped array, or ``None`` when no data pixel remains.
    """
    has_data = pixels[:, :, 0] != fill_value
    keep_rows = np.any(has_data, axis=1)
    keep_cols = np.any(has_data, axis=0)
    if not np.any(keep_rows) or not np.any(keep_cols):
        return None
    return pixels[keep_rows][:, keep_cols]


def cut_trees(output_dir, geojson_path, tif_path):
    """Cut one PNG per polygon out of a raster and prune unusable polygons.

    Every polygon of the GeoJSON file that intersects the raster bounds is
    masked out of the raster, cropped to its data extent and saved in
    ``output_dir`` as ``tree_<id>.png``. Polygons that fall outside the
    raster or yield an empty crop are removed, and the pruned GeoDataFrame
    is written back to ``geojson_path`` (the input file is overwritten).

    Args:
        output_dir: Folder receiving the PNG crops; created if missing.
        geojson_path: GeoJSON file with the tree polygons. Each feature
            needs an ``id`` property. No reprojection is performed, so the
            polygons are compared to the raster bounds as they are.
        tif_path: Raster file the polygons are cut from.
    """
    os.makedirs(output_dir, exist_ok=True)

    gdf = gpd.read_file(geojson_path)
    total = len(gdf)
    # Never 0, otherwise the modulo below fails for fewer than 10 polygons.
    progress_step = max(1, total // 10)
    print(f"Processing {total} polygons...")

    image_counter = 0
    # Rows are only recorded here and dropped after the loop: mutating the
    # frame while iterrows() is walking it is unsafe.
    skipped = []

    with rasterio.open(tif_path) as src:
        tif_bounds = box(*src.bounds)
        # mask() fills pixels outside the polygon with the dataset's nodata
        # value, or with 0 when the dataset does not define one.
        fill_value = 0 if src.nodata is None else src.nodata

        for position, (idx, row) in enumerate(gdf.iterrows()):
            status = f"{round(position / total * 100)} % complete --> {position}/{total}"
            if position % progress_step == 0:
                sys.stdout.write('\r' + status)
                sys.stdout.flush()

            geom = row['geometry']
            name = row['id']

            if not geom.intersects(tif_bounds):
                skipped.append(idx)
                print("Does not intersect")
                sys.stdout.write(
                    '\r' + status
                    + f" | Polygon {idx} is outside the image bounds and will be skipped."
                )
                sys.stdout.flush()
                continue

            out_image, _ = mask(src, [geom], crop=True)
            # rasterio returns (bands, H, W); PIL expects (H, W, bands).
            out_image = out_image.transpose(1, 2, 0)

            if out_image.size == 0:
                print("Empty image")
                skipped.append(idx)
                sys.stdout.write(
                    '\r' + status
                    + f" | Polygon {idx} resulted in an empty image and will be skipped."
                )
                sys.stdout.flush()
                continue

            cropped = _crop_to_data(out_image, fill_value)
            if cropped is None:
                print("Non zero rows or columns")
                skipped.append(idx)
                sys.stdout.write(
                    '\r' + status
                    + f" | Polygon {idx} resulted in an invalid image area and will be skipped."
                )
                sys.stdout.flush()
                continue

            # PIL needs 8-bit data to build the image.
            tree_image = Image.fromarray(cropped.astype(np.uint8))
            tree_image.save(os.path.join(output_dir, f'tree_{name}.png'))
            image_counter += 1

    gdf = gdf.drop(index=skipped)
    print(len(gdf))
    gdf.to_file(geojson_path, driver='GeoJSON')
    print(f'\n {image_counter}/{total} Tree images have been successfully saved in the "{output_dir}" folder.')
def resize_images(input_folder, output_folder, target_size):
    """Fit every PNG of a folder onto a black canvas of a fixed size.

    Each ``.png`` file in ``input_folder`` is scaled down with LANCZOS
    resampling so that it fits inside ``target_size`` while keeping its
    aspect ratio, centred on an opaque black canvas of exactly
    ``target_size``, flattened to RGB (transparent pixels become black) and
    saved under the same file name in ``output_folder``.

    Args:
        input_folder: Folder containing the source PNG files.
        output_folder: Folder receiving the results; created if missing.
        target_size: ``(width, height)`` of the output images in pixels.
    """
    os.makedirs(output_folder, exist_ok=True)

    counter = 0
    for filename in os.listdir(input_folder):
        if not filename.endswith('.png'):
            continue

        with Image.open(os.path.join(input_folder, filename)) as img:
            # thumbnail() works in place and keeps the aspect ratio.
            img.thumbnail(target_size, Image.LANCZOS)
            tile = img.convert("RGBA")

        # Centring on the canvas provides the padding. A separate
        # ImageOps.expand step with an RGB fill would be redundant and is
        # rejected by Pillow for single- or two-band images.
        offset = (
            (target_size[0] - tile.size[0]) // 2,
            (target_size[1] - tile.size[1]) // 2,
        )
        canvas = Image.new("RGBA", target_size, (0, 0, 0, 255))
        # The tile doubles as its own alpha mask, so transparent pixels
        # leave the black background visible.
        canvas.paste(tile, offset, tile)
        canvas.convert("RGB").save(os.path.join(output_folder, filename))

        counter += 1
        if counter % 100 == 0:
            print(f"Processed {counter} images", end='\r')

    print(f"Processed a total of {counter} images.")
# THIS IS THE FUNCTION TO IMPORT
def generate_tree_images(geojson_path, tif_path, target_size=(224, 224), *,
                         folder_cut_trees="detected_trees",
                         folder_finished_images="tree_images"):
    """Cut tree crops from a raster and normalise them to a fixed size.

    Runs ``cut_trees`` followed by ``resize_images``.

    Args:
        geojson_path: GeoJSON file with the tree polygons. ``cut_trees``
            rewrites this file without the polygons it could not export.
        tif_path: Raster file that contains the trees.
        target_size: ``(width, height)`` of the final images.
        folder_cut_trees: Folder for the raw crops (created if missing).
        folder_finished_images: Folder for the resized images, ready to use
            for species recognition (created if missing).

    Returns:
        None. The results are the files written to the two folders.
    """
    cut_trees(
        geojson_path=geojson_path,
        tif_path=tif_path,
        output_dir=folder_cut_trees,
    )
    resize_images(
        input_folder=folder_cut_trees,
        output_folder=folder_finished_images,
        target_size=target_size,
    )