import os
import rasterio
import geopandas as gpd
from shapely.geometry import box
from rasterio.mask import mask
from PIL import Image
from PIL import ImageOps
import numpy as np
import warnings
from rasterio.errors import NodataShadowWarning
import sys

# Suppress rasterio's NodataShadowWarning process-wide. The filter is
# installed as soon as this module is imported, not only while cutting trees.
# NOTE(review): per rasterio's docs this warning signals that a dataset's
# nodata value shadows its alpha band -- confirm that silencing it is intended.
warnings.filterwarnings("ignore", category=NodataShadowWarning)

def _valid_pixel_mask(band, nodata):
    """Return a boolean array that is True where *band* differs from *nodata*.

    A NaN nodata value needs special handling because ``x != nan`` is always
    True.
    """
    if isinstance(nodata, float) and np.isnan(nodata):
        return ~np.isnan(band)
    return band != nodata


def cut_trees(output_dir, geojson_path, tif_path):
    """Cut one PNG per polygon of a GeoJSON out of a GeoTIFF.

    Each feature of *geojson_path* is processed only if its geometry
    intersects the bounds of *tif_path*. For such a feature, the raster is:

    - masked to the polygon and cropped to the polygon's window
      (``rasterio.mask.mask(..., crop=True)``);
    - trimmed of border rows/columns whose first band holds only the fill
      value;
    - saved as ``tree_<id>.png`` in *output_dir*.

    Features that yield no usable pixels are removed. The filtered
    GeoDataFrame is then written back over *geojson_path*.

    Args:
        output_dir: Directory for the PNG crops; created if missing.
        geojson_path: Polygon file with an ``id`` column. It is OVERWRITTEN
            with only the features that produced an image.
        tif_path: Raster to cut from.

    NOTE(review): geometries are used without reprojection, so they are
    assumed to already be in the raster's coordinate system. Reading a GeoJSON
    that has no ``crs`` member labels it EPSG:4326 whatever its real
    coordinates are, so check this before adding a ``to_crs(src.crs)`` call.
    """
    os.makedirs(output_dir, exist_ok=True)

    gdf = gpd.read_file(geojson_path)
    total = len(gdf)
    # Report progress roughly every 10 %. max() avoids a modulo by zero when
    # there are fewer than 10 polygons.
    report_every = max(1, total // 10)
    print(f"Processing {total} polygons...")

    skipped = []  # index labels removed after the loop (never mutate while iterating)
    image_counter = 0

    with rasterio.open(tif_path) as src:
        tif_bounds = box(*src.bounds)
        # mask() fills pixels outside the polygon with the raster's nodata
        # value, or with 0 when the raster defines none.
        fill_value = 0 if src.nodata is None else src.nodata

        rows = zip(gdf.index, gdf['geometry'], gdf['id'])
        for position, (idx, geom, name) in enumerate(rows):
            prefix = f"{round(position / total * 100)} % complete --> {position}/{total}"
            if position % report_every == 0:
                sys.stdout.write('\r' + prefix)
                sys.stdout.flush()

            if not geom.intersects(tif_bounds):
                skipped.append(idx)
                print(f"\r{prefix} | Polygon {idx} is outside the image bounds and will be skipped.")
                continue

            try:
                out_image, _ = mask(src, [geom], crop=True)
            except ValueError:
                # mask() raises this when the shape has no overlapping pixels,
                # e.g. a polygon that only touches the raster edge (which
                # intersects() still reports as True).
                skipped.append(idx)
                print(f"\r{prefix} | Polygon {idx} is outside the image bounds and will be skipped.")
                continue

            # (bands, H, W) -> (H, W, bands), the layout PIL expects.
            out_image = out_image.transpose(1, 2, 0)

            if out_image.size == 0:
                skipped.append(idx)
                print(f"\r{prefix} | Polygon {idx} resulted in an empty image and will be skipped.")
                continue

            # Trim only the padding around the valid pixels (judged on the
            # first band), so interior gaps do not distort the crop.
            valid = _valid_pixel_mask(out_image[:, :, 0], fill_value)
            row_idx = np.flatnonzero(valid.any(axis=1))
            col_idx = np.flatnonzero(valid.any(axis=0))
            if row_idx.size == 0 or col_idx.size == 0:
                skipped.append(idx)
                print(f"\r{prefix} | Polygon {idx} resulted in an invalid image area and will be skipped.")
                continue
            out_image = out_image[row_idx[0]:row_idx[-1] + 1, col_idx[0]:col_idx[-1] + 1]

            if out_image.shape[2] == 1:
                # PIL cannot build an image from an (H, W, 1) array.
                out_image = out_image[:, :, 0]

            # NOTE(review): values outside 0-255 (16-bit or float rasters) are
            # not rescaled here; astype() wraps/truncates them. Confirm the
            # inputs are 8-bit.
            image = Image.fromarray(out_image.astype(np.uint8))
            image.save(os.path.join(output_dir, f'tree_{name}.png'))
            image_counter += 1

    gdf = gdf.drop(index=skipped)
    gdf.to_file(geojson_path, driver='GeoJSON')
    print(f'\n {image_counter}/{total} Tree images have been successfully saved in the "{output_dir}" folder.')


def resize_images(input_folder, output_folder, target_size):
    """Letterbox every PNG in *input_folder* to exactly *target_size*.

    Each image is shrunk (never enlarged) with LANCZOS resampling so that it
    fits inside *target_size* while keeping its aspect ratio. It is then
    centred on an opaque black canvas of that size, so any transparent pixels
    end up black. The result is saved as RGB under the same file name in
    *output_folder*.

    Args:
        input_folder: Directory scanned (non-recursively) for ``.png`` files.
            The extension check is case-insensitive.
        output_folder: Destination directory; created if missing. Files with
            the same name are overwritten.
        target_size: ``(width, height)`` of the output images.
    """
    os.makedirs(output_folder, exist_ok=True)

    counter = 0
    # sorted() makes the processing order and progress output deterministic.
    for filename in sorted(os.listdir(input_folder)):
        if not filename.lower().endswith('.png'):
            continue

        with Image.open(os.path.join(input_folder, filename)) as img:
            # In-place, aspect-preserving downscale; never enlarges.
            img.thumbnail(target_size, Image.LANCZOS)
            # Converting to RGBA lets any source mode (L, LA, P, RGB, ...) be
            # pasted with its own alpha channel as the mask. The canvas below
            # already supplies the black border, so no separate padding step
            # is needed.
            rgba = img.convert("RGBA")

        canvas = Image.new("RGBA", target_size, (0, 0, 0, 255))
        # Centre the image; any odd leftover pixel goes to the right/bottom.
        offset = ((target_size[0] - rgba.width) // 2, (target_size[1] - rgba.height) // 2)
        canvas.paste(rgba, offset, rgba)
        # Dropping alpha flattens the result onto the black background.
        canvas.convert("RGB").save(os.path.join(output_folder, filename))

        counter += 1
        if counter % 100 == 0:
            print(f"Processed {counter} images", end='\r')

    print(f"Processed a total of {counter} images.")


# THIS IS THE FUNCTION TO IMPORT
def generate_tree_images(geojson_path, tif_path, target_size=(224, 224),
                         folder_cut_trees="detected_trees",
                         folder_finished_images="tree_images"):
    """Cut tree crowns out of a raster and prepare them for species recognition.

    Runs two steps in sequence:

    1. ``cut_trees`` writes one raw PNG per polygon of *geojson_path* that
       overlaps *tif_path* into *folder_cut_trees*. It also OVERWRITES
       *geojson_path*, removing the polygons that produced no image.
    2. ``resize_images`` letterboxes those PNGs to *target_size* into
       *folder_finished_images*.

    Args:
        geojson_path: Polygon file with an ``id`` column (rewritten in place).
        tif_path: Raster that contains the trees.
        target_size: ``(width, height)`` of the final images.
        folder_cut_trees: Output folder for the raw crops. It is relative to
            the current working directory unless an absolute path is given.
        folder_finished_images: Output folder for the resized images.

    Returns:
        None.
    """
    cut_trees(output_dir=folder_cut_trees, geojson_path=geojson_path, tif_path=tif_path)
    resize_images(input_folder=folder_cut_trees, output_folder=folder_finished_images,
                  target_size=target_size)