import numpy as np
import cv2 as cv
from PIL import Image

def norm_mat(mat):
    """Min-max stretch *mat* into the range [0, 255] and return it as uint8.

    The smallest value of *mat* maps to 0 and the largest to 255
    (``cv.NORM_MINMAX``). The stretched result is then cast to ``np.uint8``.
    """
    stretched = cv.normalize(mat, None, alpha=0, beta=255, norm_type=cv.NORM_MINMAX)
    return stretched.astype(np.uint8)

def equalize_img(img):
    """Histogram-equalize the luma of a BGR image and return it as BGR.

    The image is converted to YCrCb. Only the Y plane is passed through
    ``cv.equalizeHist``; Cr and Cb are reused unchanged before converting back.
    """
    luma, chroma_r, chroma_b = cv.split(cv.cvtColor(img, cv.COLOR_BGR2YCrCb))
    equalized = cv.merge((cv.equalizeHist(luma), chroma_r, chroma_b))
    return cv.cvtColor(equalized, cv.COLOR_YCrCb2BGR)

def create_lut(intensity, gamma):
    """Build a 3-channel brightness-shift lookup table suitable for ``cv.LUT``.

    Each input level ``i`` maps to ``clip(i + intensity, 0, 255)``. The same
    table is used for all three channels.

    Args:
        intensity: Offset added to every level. The result is clamped to the
            8-bit range [0, 255].
        gamma: Accepted for call compatibility; this implementation does not
            use it.

    Returns:
        ``np.uint8`` array of shape ``(256, 1, 3)``.
    """
    # Vectorized clamp over all 256 levels. Doing the arithmetic in a wide
    # integer/float dtype avoids uint8 wraparound before clamping.
    levels = np.clip(np.arange(256) + intensity, 0, 255).astype(np.uint8)
    # Replicate the single curve across the three channels.
    return np.repeat(levels[:, np.newaxis, np.newaxis], 3, axis=2)

def _to_gray(image):
    """Return *image* as a single-channel array for ``cv.spatialGradient``."""
    if isinstance(image, Image.Image):
        # PIL pixel data is RGB(A)/palette. Feeding those bytes to
        # COLOR_BGR2GRAY would swap the red/blue luma weights, so let PIL do
        # the luma conversion. This also handles "L", "P" and "RGBA" modes.
        return np.array(image.convert("L"))
    array = np.ascontiguousarray(image)
    if array.ndim == 2:
        # Already single-channel.
        return array
    # Non-PIL colour input is assumed to follow OpenCV's BGR(A) order.
    return cv.cvtColor(array, cv.COLOR_BGR2GRAY)


def _encode_signed(channel):
    """Map a signed gradient to uint8: 0 -> 127, +/-peak -> 254/0."""
    peak = np.max(np.abs(channel))
    if peak == 0:
        # Flat input: dividing by zero would produce NaN, whose uint8 cast
        # is undefined. A zero gradient encodes as mid-grey.
        return np.full(channel.shape, 127, dtype=np.uint8)
    return (channel / peak * 127 + 127).astype(np.uint8)


def _blue_channel(blue_mode, dx, dy, shape):
    """Compute the blue plane selected by *blue_mode*; raise on unknown modes."""
    if blue_mode == "None":
        return np.zeros(shape, dtype=np.uint8)
    if blue_mode == "Flat":
        return np.full(shape, 255, dtype=np.uint8)
    if blue_mode == "Abs":
        # L1 gradient magnitude.
        return norm_mat(np.abs(dx) + np.abs(dy))
    if blue_mode == "Norm":
        # L2 gradient magnitude, computed on the raw gradients rather than
        # on their 127-offset uint8 encodings.
        return norm_mat(np.hypot(dx, dy))
    raise ValueError(
        f"unknown blue_mode {blue_mode!r}; expected 'None', 'Flat', 'Abs' or 'Norm'"
    )


def gradient_processing(image, intensity=90, blue_mode="Abs", invert=False, equalize=False):
    """Render the luminance gradient of *image* as a colour-coded PIL image.

    Red encodes the horizontal derivative and green the vertical derivative
    from ``cv.spatialGradient``. Each is scaled by its own peak magnitude so
    that 0 maps to 127 and the extremes map to 0/254. Blue is chosen by
    *blue_mode*.

    Args:
        image: A PIL image, or an array that is either single-channel or in
            OpenCV BGR(A) order. It must reduce to 8-bit grey for
            ``cv.spatialGradient``.
        intensity: Percentage converted to a brightness offset of
            ``int(intensity / 100 * 127)``. The offset is applied with
            ``create_lut`` when it is positive and *equalize* is False.
        blue_mode: ``"None"`` (0), ``"Flat"`` (255), ``"Abs"`` (normalized
            ``|dx| + |dy|``) or ``"Norm"`` (normalized ``sqrt(dx² + dy²)``).
        invert: Negate both derivatives before encoding.
        equalize: Histogram-equalize the luma instead of applying the
            intensity offset.

    Returns:
        An RGB ``PIL.Image.Image`` (R = dx, G = dy, B = blue plane).

    Raises:
        ValueError: If *blue_mode* is not one of the values above.
    """
    intensity = int(intensity / 100 * 127)
    dx, dy = cv.spatialGradient(_to_gray(image))
    dx = dx.astype(np.float32)
    dy = dy.astype(np.float32)
    if invert:
        dx, dy = -dx, -dy
    red = _encode_signed(dx)
    green = _encode_signed(dy)
    blue = _blue_channel(blue_mode, dx, dy, red.shape)
    # Assemble in OpenCV's BGR order, because equalize_img converts from BGR.
    gradient = cv.merge([blue, green, red])
    if equalize:
        gradient = equalize_img(gradient)
    elif intensity > 0:
        gradient = cv.LUT(gradient, create_lut(intensity, intensity))
    # Image.fromarray reads a 3-channel uint8 array as RGB. Without this
    # reorder, red and blue would be swapped in the output.
    return Image.fromarray(cv.cvtColor(gradient, cv.COLOR_BGR2RGB))