# stasis_derm_calibration/gradio_app.py
import cv2
import gradio as gr
import numpy as np
from sklearn.linear_model import LinearRegression
# Physical side lengths of the colour-checker card (presumably in cm,
# matching the "Cm/px" output labels of the interface below).
checker_large_real = 10.8
checker_small_real = 6.35
def check_orientation(image):
    """Rotate portrait images to landscape; return the image and its original orientation."""
    orientation = np.argmax(image.shape)  # 0 if height > width (portrait)
    if orientation == 0:
        image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
    return image, orientation
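# Hedged example (illustrative values): a portrait frame is rotated to landscape.
#   demo = np.zeros((300, 200, 3), dtype=np.uint8)
#   rotated, orientation = check_orientation(demo)
#   # orientation == 0 and rotated.shape[:2] == (200, 300)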
def get_color_checker_table(data_points, y, yend):
    """Group detected patch centres into rows of the colour-checker grid.

    Returns the number of rows missing above and below the detected points,
    the expected row pitch in pixels, and the points grouped into rows.
    """
    sorted_points = sorted(data_points, key=lambda point: (point[1], point[0]))
    # Vertical gaps: top edge -> first point, between consecutive points, last point -> bottom edge.
    differences_y = [sorted_points[0][1] - y] + \
                    [abs(sorted_points[i][1] - sorted_points[i + 1][1]) for i in range(len(sorted_points) - 1)] + \
                    [yend - sorted_points[-1][1]]
    most_usual_y = 10  # gaps larger than this start a new row
    local_max = round((yend - y) * 0.2184)  # expected row pitch (empirical fraction of card height)
    lines = []
    last_id = 0
    # Rows missing at the top/bottom edge show up as oversized edge gaps.
    label_upper = differences_y[0] // local_max if differences_y[0] > local_max + 10 else 0
    label_lower = differences_y[-1] // local_max if differences_y[-1] > local_max + 10 else 0
    # Reserve empty placeholder rows for the rows missing at the top.
    lines.extend([[] for _ in range(label_upper)])
    for i in range(1, len(differences_y) - 1):
        if differences_y[i] > most_usual_y:
            lines.append(sorted_points[last_id:i])
            last_id = i
    lines.append(sorted_points[last_id:])
    # Reserve empty placeholder rows for the rows missing at the bottom.
    lines.extend([[] for _ in range(label_lower)])
    lines = [sorted(line, key=lambda point: point[0]) for line in lines]
    return label_upper, label_lower, local_max, lines
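# Hedged sanity sketch (made-up coordinates): four centres at two y levels
# come back as two rows, each sorted left-to-right.
#   pts = [(30, 10), (10, 10), (10, 50), (30, 50)]
#   _, _, _, rows = get_color_checker_table(pts, 0, 60)
#   # rows == [[(10, 10), (30, 10)], [(10, 50), (30, 50)]]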
def check_points(data_points, x, xend, y, yend, image):
    """Fill in patch centres missed by the contour detector, row by row."""
    most_usual = int((xend - x) / 7.016)  # expected horizontal patch pitch (empirical divisor)
    label_upper, label_lower, usual_y, lines = get_color_checker_table(data_points, y, yend)
    for q in lines:
        if not q:
            continue
        # Horizontal gaps: left edge -> first point, between points, last point -> right edge.
        differences_x = [q[0][0] - x] + [q[i + 1][0] - q[i][0] for i in range(len(q) - 1)] + [xend - q[-1][0]]
        threshold_x = int(most_usual * (1 + 1 / 5.6))
        for j, distance in enumerate(differences_x[:-1]):
            if distance > threshold_x:
                # Interpolate the missing centres to the left of point j.
                positions = distance // int(most_usual * (1 - 1 / 11.2)) - 1
                for t in range(positions):
                    cnt = (q[j][0] - (t + 1) * most_usual, q[j][1])
                    data_points.append(cnt)
        if differences_x[-1] > threshold_x:
            # Extrapolate the missing centres to the right of the last point.
            positions = differences_x[-1] // int(most_usual * (1 - 1 / 11.2)) - 1
            for t in range(positions):
                cnt = (q[-1][0] + (t + 1) * most_usual, q[-1][1])
                data_points.append(cnt)
    data_points.sort(key=lambda point: (point[1], point[0]))
    _, _, _, new_lines = get_color_checker_table(data_points, y, yend)
    return label_upper, label_lower, usual_y, image, new_lines, data_points
def get_reference_values(points, image):
    """Sample the image at each (x, y) centre; note NumPy indexing is image[y, x]."""
    return [image[py, px] for px, py in points]
def detect_RGB_values(image, dst):
    """Locate the 24 colour patches inside the detected checker polygon and
    return their sampled values alongside the reference sRGB values."""
    x1, y1 = map(round, dst[0][0])
    x2, y2 = map(round, dst[2][0])
    y2 = max(0, y2)
    image_checker = image[y1:y2, x2:x1]
    # Fallback defaults, used when no patches can be detected.
    label_upper, label_lower, usual, new_lines, M_T = 0, 0, 0, [], []
    if image_checker.size != 0:
        # Apply GaussianBlur to reduce noise and improve edge detection
        blurred = cv2.GaussianBlur(image_checker, (5, 5), 0)
        # Apply edge detection
        edges = cv2.Canny(blurred, 50, 120)
        # Find contours
        contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        # Keep roughly square contours of reasonable area; record their centres
        # in full-image coordinates.
        centers = [
            (x + w // 2 + x2, y + h // 2 + y1)
            for contour in contours
            for x, y, w, h in [cv2.boundingRect(contour)]
            if 0.8 < w / float(h) < 1.2 and cv2.contourArea(contour) > 100
        ]
        # Filter out centers too close to the left/right edges of the checker.
        centers = [
            center for center in centers
            if abs(center[0] - x2) >= (x1 - x2) / 7.29 and abs(center[0] - x1) >= (x1 - x2) / 7.29
        ]
        if centers:
            label_upper, label_lower, usual, image, new_lines, M_T = check_points(centers, x2, x1, y1, y2, image)
    # Reference sRGB values of the 24-patch colour checker.
    M_R = np.array([
        [52, 52, 52], [85, 85, 85], [122, 122, 121], [160, 160, 160],
        [200, 200, 200], [243, 243, 242], [8, 133, 161], [187, 86, 149],
        [231, 199, 31], [175, 54, 60], [70, 148, 73], [56, 61, 150],
        [224, 163, 46], [157, 188, 64], [94, 60, 108], [193, 90, 99],
        [80, 91, 166], [214, 126, 44], [103, 189, 170], [133, 128, 177],
        [87, 108, 67], [98, 122, 157], [194, 150, 130], [115, 82, 68]
    ])
    if 0 < len(M_T) < 24:
        # Reconstruct rows missing at the top/bottom by shifting the nearest
        # known row by the row pitch, working outward from the detected rows.
        for i in range(label_upper - 1, -1, -1):
            new_lines[i] = [(px, py - round(usual)) for px, py in new_lines[i + 1]]
        for j in range(len(new_lines) - label_lower, len(new_lines)):
            new_lines[j] = [(px, py + round(usual)) for px, py in new_lines[j - 1]]
    # Flatten *after* the reconstruction so the filled-in rows are counted,
    # then discard the detection entirely if it still is not a full 24-patch grid.
    M_T = [point for sublist in new_lines for point in sublist]
    if len(M_T) != 24:
        M_T = []
    M_T_values = np.array(get_reference_values(M_T, image))
    return M_T_values, M_R
css = ".input_image {height: 10% !important; width: 10% !important;}"
def detect_template(image, orientation):
    MIN_MATCH_COUNT = 10
    template_path = 'template_img.png'
    template_image = cv2.imread(template_path, cv2.IMREAD_GRAYSCALE)
    gray_image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    image_c = image.copy()
    # Initiate SIFT detector
    sift = cv2.SIFT_create()
    keypoints1, descriptors1 = sift.detectAndCompute(template_image, None)
    keypoints2, descriptors2 = sift.detectAndCompute(gray_image, None)
    # FLANN parameters (algorithm=1 is the KD-tree index)
    index_params = dict(algorithm=1, trees=5)
    search_params = dict(checks=50)
    flann = cv2.FlannBasedMatcher(index_params, search_params)
    matches = flann.knnMatch(descriptors1, descriptors2, k=2)
    # Apply Lowe's ratio test; knnMatch can return fewer than two neighbours.
    good_matches = [m for m, n in (pair for pair in matches if len(pair) == 2)
                    if m.distance < 0.7 * n.distance]
    if len(good_matches) > MIN_MATCH_COUNT:
        src_points = np.float32([keypoints1[m.queryIdx].pt for m in good_matches]).reshape(-1, 1, 2)
        dst_points = np.float32([keypoints2[m.trainIdx].pt for m in good_matches]).reshape(-1, 1, 2)
        M, mask = cv2.findHomography(src_points, dst_points, cv2.RANSAC, 5.0)
        h, w = template_image.shape
        template_corners = np.float32([[0, 0], [0, h - 1], [w - 1, h - 1], [w - 1, 0]]).reshape(-1, 1, 2)
        dst_corners = cv2.perspectiveTransform(template_corners, M)
        x1, y1 = map(round, dst_corners[0][0])
        x2, y2 = map(round, dst_corners[2][0])
        checker_small_im = abs(y2 - y1)
        checker_large_im = abs(x2 - x1)
        if checker_small_im != 0 and checker_large_im != 0:
            px_cm_ratio_small = checker_small_real / checker_small_im
            px_cm_ratio_large = checker_large_real / checker_large_im
        else:
            px_cm_ratio_small = 0
            px_cm_ratio_large = 0
        annotated_image = cv2.polylines(image_c, [np.int32(dst_corners)], True, 255, 3, cv2.LINE_AA)
        if orientation == 0:
            annotated_image = cv2.rotate(annotated_image, cv2.ROTATE_90_COUNTERCLOCKWISE)
    else:
        print(f"Not enough matches are found - {len(good_matches)}/{MIN_MATCH_COUNT}")
        # Keep the return arity consistent with the success path below.
        return None, None, 0, 0, 0, 0
    if orientation == 0:
        cm_per_pixel_width = px_cm_ratio_small
        cm_per_pixel_height = px_cm_ratio_large
    else:
        cm_per_pixel_width = px_cm_ratio_large
        cm_per_pixel_height = px_cm_ratio_small
    return annotated_image, dst_corners, cm_per_pixel_width, cm_per_pixel_height, checker_small_im, checker_large_im
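# Hedged usage sketch (assumes template_img.png sits next to this script;
# "photo.jpg" is a hypothetical image containing the checker card):
#   img = cv2.cvtColor(cv2.imread("photo.jpg"), cv2.COLOR_BGR2RGB)
#   img, orientation = check_orientation(img)
#   annotated, corners, cm_w, cm_h, small_px, large_px = detect_template(img, orientation)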
def srgb_to_linear(rgb):
    """Convert 8-bit sRGB values to linear light in [0, 1]."""
    rgb = rgb / 255.0
    linear_rgb = np.where(rgb <= 0.04045, rgb / 12.92, ((rgb + 0.055) / 1.055) ** 2.4)
    return linear_rgb
def linear_to_srgb(linear_rgb):
    """Convert linear light in [0, 1] back to 8-bit sRGB."""
    # Clip linear_rgb to ensure no negative values
    linear_rgb = np.clip(linear_rgb, 0, 1)
    srgb = np.where(linear_rgb <= 0.0031308, linear_rgb * 12.92, 1.055 * (linear_rgb ** (1 / 2.4)) - 0.055)
    # Round rather than truncate so the round-trip is lossless for valid inputs.
    srgb = np.rint(np.clip(srgb * 255, 0, 255))
    return srgb.astype(np.uint8)
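# Hedged round-trip check: converting to linear light and back returns the
# original 8-bit values exactly now that the output is rounded.
#   px = np.array([[128.0, 64.0, 32.0]])
#   assert (linear_to_srgb(srgb_to_linear(px)) == [128, 64, 32]).all()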
def calculate_color_correction_matrix_ransac(sample_rgb, reference_rgb):
    """Fit a 3x3 colour-correction matrix in linear light.

    Despite the name, this fits ordinary least squares (sklearn
    LinearRegression), not a RANSAC estimator; the name is kept for callers.
    """
    # The detected patches are ordered opposite to the reference list.
    sample_rgb = sample_rgb[::-1]
    sample_rgb_linear = srgb_to_linear(sample_rgb)
    reference_rgb_linear = srgb_to_linear(reference_rgb)
    X = sample_rgb_linear
    y = reference_rgb_linear
    # Fit one regression per output channel. fit_intercept=False keeps the fit
    # consistent with apply_color_correction, which applies only the matrix.
    models = []
    scores = []
    for i in range(3):  # For each RGB channel
        model = LinearRegression(fit_intercept=False)
        model.fit(X, y[:, i])
        scores.append(model.score(X, y[:, i]))
        models.append(model.coef_)
    score = np.mean(scores)
    # Stack coefficients to form the transformation matrix
    M = np.stack(models, axis=-1)
    return M, score
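# Hedged sanity check: with identical (reversed) sample and reference patches
# the fitted matrix should be the identity and the mean R^2 score ~1.0.
#   patches = np.random.default_rng(0).integers(5, 250, size=(24, 3)).astype(float)
#   M, score = calculate_color_correction_matrix_ransac(patches[::-1], patches)
#   # np.allclose(M, np.eye(3)) and score approx 1.0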
def apply_color_correction(image, M):
    """Apply the correction matrix in linear light and convert back to sRGB."""
    image_linear = srgb_to_linear(image)
    corrected_image_linear = np.dot(image_linear, M)
    corrected_image_srgb = linear_to_srgb(corrected_image_linear)
    return corrected_image_srgb
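# Hedged usage sketch: an identity matrix leaves colours unchanged.
#   img = np.full((4, 4, 3), 100, dtype=np.uint8)
#   assert (apply_color_correction(img, np.eye(3)) == img).all()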
def calibrate_img(img):
    image, orientation = check_orientation(img)
    annotated_image, polygon, px_width, px_height, small_side, large_side = detect_template(image, orientation)
    if polygon is None:
        # Template not found: return the input unchanged with zeroed metrics.
        return img, img, 0, 0, 0, 0
    a, b = detect_RGB_values(image, polygon)
    if orientation == 0:
        image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
    if len(a) == 24:
        M, score = calculate_color_correction_matrix_ransac(a, b)
        corrected_image = apply_color_correction(image, M)
    else:
        # Not all 24 patches were recovered; skip the colour correction.
        corrected_image = image
    if orientation == 0:
        width = small_side
        height = large_side
    else:
        width = large_side
        height = small_side
    return annotated_image, corrected_image, px_width, px_height, width, height
def process_img(img):
    return calibrate_img(img)
app = gr.Interface(
    fn=process_img,
    inputs=gr.Image(label="Input"),
    css=css,
    outputs=[
        gr.Image(label="Output"),
        gr.Image(label="Corrected"),
        gr.Label(label="Cm/px for Width"),
        gr.Label(label="Cm/px for Height"),
        gr.Label(label="Checker Width (px)"),
        gr.Label(label="Checker Height (px)"),
    ],
    allow_flagging='never')
app.launch(share=True)