import random

import skimage.io
from descartes import PolygonPatch  # noqa: F401  # no longer used below (see _polygon_patch); import kept unchanged
from matplotlib.collections import PatchCollection
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure
from matplotlib.patches import PathPatch
from matplotlib.path import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
import shapely.geometry

from lydorn_utils import math_utils
from torch_lydorn import torchvision

# Palette plot_polygons picks from when the caller passes no color_choices (RGBA floats in [0, 1]).
_DEFAULT_COLOR_CHOICES = [
    [0, 0, 1, 1],
    [0, 1, 0, 1],
    [1, 0, 0, 1],
    [1, 1, 0, 1],
    [1, 0, 1, 1],
    [0, 1, 1, 1],
    [0.5, 1, 0, 1],
    [1, 0.5, 0, 1],
    [0.5, 0, 1, 1],
    [1, 0, 0.5, 1],
    [0, 0.5, 1, 1],
    [0, 1, 0.5, 1],
]


def get_seg_display(seg):
    """Convert a segmentation array into an RGBA display array with the same dtype.

    A 2-D ``seg`` (H, W) is written to the red channel and also used as alpha.
    A 3-D ``seg`` (H, W, C) has channel ``i`` copied to output channel ``i``; alpha is the
    per-pixel channel sum clipped to 255 for uint8 input and to 1 otherwise.

    NOTE(review): with C > 4 the copy indexes past the 4 output channels (IndexError), and
    with C == 4 the copied 4th channel is overwritten by the computed alpha.

    Returns:
        np.ndarray of shape (H, W, 4) and dtype ``seg.dtype``.
    """
    dtype = seg.dtype
    seg_display = np.zeros([seg.shape[0], seg.shape[1], 4], dtype=dtype)
    if len(seg.shape) == 2:
        seg_display[..., 0] = seg
        seg_display[..., 3] = seg
    else:
        for i in range(seg.shape[-1]):
            seg_display[..., i] = seg[..., i]
        clip_max = 255 if dtype == np.uint8 else 1
        seg_display[..., 3] = np.clip(np.sum(seg, axis=-1), 0, clip_max)
    return seg_display


def get_tensorboard_image_seg_display(image, seg, crossfield=None):
    """Blend a segmentation (and optionally a rendered crossfield) over a batch of images.

    Args:
        image: tensor of shape (N, 3, H, W).
        seg: tensor of shape (N, C, H, W) with C <= 3. Seg channel ``c`` is drawn into image
            channel ``c``; the channel sum clamped to [0, 1] is the blending alpha.
        crossfield: optional tensor of shape (N, 4, H, W). Each item is rendered with
            ``get_image_plot_crossfield`` (stride 10) and alpha-composited on top.

    Returns:
        CPU tensor of shape (N, 3, H, W).
    """
    assert len(image.shape) == 4 and image.shape[1] == 3, f"image should be (N, 3, H, W), not {image.shape}."
    assert len(seg.shape) == 4 and seg.shape[1] <= 3, f"seg should be (N, C<=3, H, W), not {seg.shape}."
    assert image.shape[0] == seg.shape[0], "image and seg should have the same batch size."
    assert image.shape[2] == seg.shape[2], "image and seg should have the same image height."
    assert image.shape[3] == seg.shape[3], "image and seg should have the same image width."
    if crossfield is not None:
        assert len(crossfield.shape) == 4 and crossfield.shape[1] == 4, \
            f"crossfield should be (N, 4, H, W), not {crossfield.shape}."
        assert image.shape[0] == crossfield.shape[0], "image and crossfield should have the same batch size."
        assert image.shape[2] == crossfield.shape[2], "image and crossfield should have the same image height."
        assert image.shape[3] == crossfield.shape[3], "image and crossfield should have the same image width."

    alpha = torch.clamp(torch.sum(seg, dim=1, keepdim=True), 0, 1)

    # Pad seg up to 3 channels so it can be blended with the RGB image.
    seg_display = torch.zeros_like(image)
    seg_display[:, :seg.shape[1], ...] = seg

    image_seg_display = (1 - alpha) * image + alpha * seg_display
    image_seg_display = image_seg_display.cpu()

    if crossfield is not None:
        # (N, 4, H, W) -> (N, H, W, 4): the layout plot_crossfield indexes as crossfield[i, j, :].
        np_crossfield = crossfield.cpu().detach().numpy().transpose(0, 2, 3, 1)
        # NOTE(review): if torch_lydorn's to_tensor already scales uint8 to [0, 1] (as
        # torchvision's does), the extra "/ 255" divides twice — confirm against that library.
        image_plot_crossfield_list = [
            torchvision.transforms.functional.to_tensor(
                get_image_plot_crossfield(single_crossfield, crossfield_stride=10)
            ).float() / 255
            for single_crossfield in np_crossfield
        ]
        # Assumes the rendered plots are exactly H x W so they broadcast against the images.
        image_plot_crossfield = torch.stack(image_plot_crossfield_list, dim=0)
        alpha = image_plot_crossfield[:, 3:4, :, :]
        image_seg_display = (1 - alpha) * image_seg_display + alpha * image_plot_crossfield[:, :3, :, :]

    return image_seg_display


def plot_crossfield(axis, crossfield, crossfield_stride, alpha=0.5, width=0.5, add_scale=1, invert_y=True):
    """Draw the two crossfield directions as headless quiver arrows on ``axis``.

    Args:
        axis: matplotlib Axes to draw on.
        crossfield: array indexed as ``crossfield[row, col, :]``; callers in this module
            pass (H, W, 4) arrays holding the c0/c2 coefficients.
        crossfield_stride: sample one arrow every ``crossfield_stride`` pixels.
        alpha: arrow opacity (arrows are blue).
        width: quiver shaft width, in data units.
        add_scale: multiplier on the quiver ``scale`` (``add_scale / crossfield_stride``).
        invert_y: when True, read rows bottom-up (row ``H - 1 - y``) while plotting at ``y``.
    """
    x = np.arange(0, crossfield.shape[1], crossfield_stride)
    y = np.arange(0, crossfield.shape[0], crossfield_stride)
    x, y = np.meshgrid(x, y)
    i = y
    if invert_y:
        i = crossfield.shape[0] - 1 - y
    j = x
    scale = add_scale * 1 / crossfield_stride

    c0c2 = crossfield[i, j, :]
    # u and v are used through .real/.imag below, i.e. as complex direction vectors.
    u, v = math_utils.compute_crossfield_uv(c0c2)

    quiveropts = dict(color=(0, 0, 1, alpha), headaxislength=0, headlength=0, pivot='middle', angles="xy",
                      units='xy', scale=scale, width=width, headwidth=1)
    axis.quiver(x, y, u.imag, -u.real, **quiveropts)
    axis.quiver(x, y, v.imag, -v.real, **quiveropts)


def get_image_plot_crossfield(crossfield, crossfield_stride):
    """Rasterize a crossfield quiver plot into an RGBA uint8 image.

    The figure is sized ``W/100 x H/100`` inches at 100 dpi so the raster is meant to match
    the crossfield's (H, W). The white background is turned transparent by deriving alpha
    from the distance of the RGB values to white.

    Returns:
        np.ndarray of shape (height, width, 4), dtype uint8, RGBA order.
    """
    fig = Figure(figsize=(crossfield.shape[1] / 100, crossfield.shape[0] / 100), dpi=100)
    canvas = FigureCanvasAgg(fig)
    ax = fig.gca()
    plot_crossfield(ax, crossfield, crossfield_stride, alpha=1.0, width=2.0, add_scale=1)
    ax.axis('off')
    fig.tight_layout(pad=0)  # To remove the huge white borders
    ax.margins(0)
    canvas.draw()

    # buffer_rgba() exposes the Agg buffer already in RGBA order (no ARGB -> RGBA roll needed);
    # copy so the array is writable and independent of the canvas.
    canvas_width, canvas_height = canvas.get_width_height()
    image_from_plot = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8)
    image_from_plot = image_from_plot.reshape(canvas_height, canvas_width, 4).copy()

    # White -> transparent: alpha is the largest per-channel distance from white, offset by
    # the global minimum. Every channel value is >= mini, so 255 - x + mini stays <= 255 in uint8.
    mini = image_from_plot.min()
    image_from_plot[:, :, 3] = np.max(255 - image_from_plot[:, :, :3] + mini, axis=2)
    return image_from_plot


def _ring_path_codes(num_vertices):
    """Return matplotlib Path codes for one ring: MOVETO first, LINETO for the rest."""
    codes = np.full(num_vertices, Path.LINETO, dtype=Path.code_type)
    codes[0] = Path.MOVETO
    return codes


def _polygon_patch(polygon, **kwargs):
    """Build a matplotlib PathPatch for a shapely Polygon (exterior ring plus holes).

    Used instead of descartes.PolygonPatch, which relies on the NumPy array interface of
    shapely geometries that Shapely 2.x removed. Coordinates are read via ``ring.coords``,
    which works on both Shapely 1.x and 2.x.
    """
    rings = [polygon.exterior, *polygon.interiors]
    vertices = np.concatenate([np.asarray(ring.coords)[:, :2] for ring in rings])
    codes = np.concatenate([_ring_path_codes(len(ring.coords)) for ring in rings])
    return PathPatch(Path(vertices, codes), **kwargs)


def plot_polygons(axis, polygons, polygon_probs=None, draw_vertices=True, linewidths=2, markersize=10, alpha=0.2,
                  color_choices=None):
    """Draw filled polygons (and optionally their vertices) on a matplotlib axis.

    Args:
        axis: matplotlib Axes to draw on.
        polygons: sequence of shapely Polygons or anything ``shapely.geometry.Polygon`` accepts
            (e.g. an (n, 2) coordinate array). Empty polygons are skipped.
        polygon_probs: optional sequence aligned with ``polygons``. When given, face alpha is
            ``alpha * prob + 0.1``; otherwise every face uses ``alpha``.
        draw_vertices: also plot exterior and interior ring vertices as "o" markers.
        linewidths: edge width for the patch collection.
        markersize: vertex marker size.
        alpha: face opacity (scaled by probability when ``polygon_probs`` is given).
        color_choices: RGBA colors to sample from (defaults to a 12-color palette). Sampling
            uses a fixed seed, so the same number of polygons always gets the same colors.
    """
    if len(polygons) == 0:
        return

    # Normalize to shapely polygons, dropping empty ones while keeping probs aligned.
    shapely_polygons = []
    kept_probs = []
    for index, geometry in enumerate(polygons):
        if isinstance(geometry, shapely.geometry.Polygon):
            polygon = geometry
        else:
            polygon = shapely.geometry.Polygon(geometry)
        if polygon.is_empty:
            continue
        shapely_polygons.append(polygon)
        if polygon_probs is not None:
            kept_probs.append(polygon_probs[index])
    if not shapely_polygons:
        return

    if color_choices is None:
        color_choices = _DEFAULT_COLOR_CHOICES
    # A private RNG seeded with 1 draws the same colors as seeding the global generator would,
    # without resetting the global random state for the rest of the program.
    rng = random.Random(1)
    colors = rng.choices(color_choices, k=len(shapely_polygons))
    edgecolors = np.array(colors, dtype=float)  # np.float was removed in NumPy 1.24
    facecolors = edgecolors.copy()
    if polygon_probs is not None:
        facecolors[:, -1] = alpha * np.array(kept_probs) + 0.1
    else:
        facecolors[:, -1] = alpha

    patches = [_polygon_patch(polygon) for polygon in shapely_polygons]
    collection = PatchCollection(patches, facecolors=facecolors, edgecolors=edgecolors, linewidths=linewidths)
    axis.add_collection(collection)

    if draw_vertices:
        # Iterate the filtered shapely polygons so each one is paired with its own edge color.
        for polygon, edgecolor in zip(shapely_polygons, edgecolors):
            axis.plot(*polygon.exterior.xy, marker="o", color=edgecolor, markersize=markersize)
            for interior in polygon.interiors:
                axis.plot(*interior.xy, marker="o", color=edgecolor, markersize=markersize)


def plot_line_strings(axis, line_strings, draw_vertices=True, linewidths=2, markersize=5):
    """Plot each shapely LineString on ``axis``.

    Args:
        axis: matplotlib Axes to draw on.
        line_strings: iterable of shapely LineStrings.
        draw_vertices: draw "o" markers at the vertices.
        linewidths: line width of each plotted line.
        markersize: vertex marker size.

    Returns:
        list with one matplotlib artist per line string.
    """
    marker = "o" if draw_vertices else None
    artists = []
    for line_string in line_strings:
        artist, = axis.plot(*line_string.xy, marker=marker, markersize=markersize, linewidth=linewidths)
        artists.append(artist)
    return artists


def plot_geometries(axis, geometries, draw_vertices=True, linewidths=2, markersize=3):
    """Plot a mix of shapely Polygons, LineStrings and MultiLineStrings on ``axis``.

    Polygons go through ``plot_polygons``. LineStrings, and the parts of each MultiLineString,
    go through ``plot_line_strings``.

    Raises:
        NotImplementedError: for any other geometry type.

    Returns:
        list of the line-string artists (empty when no line strings were plotted).
    """
    polygons = []
    line_strings = []
    for geometry in geometries:
        if isinstance(geometry, shapely.geometry.Polygon):
            polygons.append(geometry)
        elif isinstance(geometry, shapely.geometry.LineString):
            line_strings.append(geometry)
        elif isinstance(geometry, shapely.geometry.MultiLineString):
            # .geoms works on Shapely 1.x and 2.x; direct iteration was removed in 2.0.
            line_strings.extend(geometry.geoms)
        else:
            raise NotImplementedError(f"Geometry type {type(geometry)} not implemented")

    artists = []  # previously unbound when there were no line strings
    if polygons:
        plot_polygons(axis, polygons, draw_vertices=draw_vertices, linewidths=linewidths, markersize=markersize)
    if line_strings:
        artists = plot_line_strings(axis, line_strings, draw_vertices=draw_vertices, linewidths=linewidths,
                                    markersize=markersize)
    return artists


def save_poly_viz(image, polygons, out_filepath, linewidths=2, markersize=20, alpha=0.2, draw_vertices=True,
                  corners=None, crossfield=None, polygon_probs=None, seg=None, color_choices=None, dpi=10):
    """Render polygons (plus optional seg, crossfield and corners) over ``image`` and save to disk.

    Args:
        image: array shown with ``imshow``; its first two dims are used as (height, width).
        polygons: list of shapely Polygons or coordinate ndarrays.
        out_filepath: output path passed to ``savefig`` (saved with a transparent background).
        linewidths, markersize, alpha, draw_vertices, polygon_probs, color_choices:
            forwarded to ``plot_polygons``.
        corners: optional list of (k, 2) arrays of (x, y) points drawn as red markers.
        crossfield: optional array drawn with ``plot_crossfield`` (stride 1, no y inversion).
        seg: optional array drawn over the image at 90% intensity. The caller's array is not modified.
        dpi: resolution used when saving the figure.
    """
    assert isinstance(polygons, list), f"polygons should be of type list, not {type(polygons)}"
    if len(polygons):
        assert isinstance(polygons[0], (np.ndarray, shapely.geometry.Polygon)), \
            f"Item of the polygons list should be of type ndarray or shapely Polygon, not {type(polygons[0])}"
    if polygon_probs is not None:
        assert isinstance(polygon_probs, list)
        assert len(polygons) == len(polygon_probs), \
            "len(polygons)={} should be equal to len(polygon_probs)={}".format(len(polygons), len(polygon_probs))

    # Setup plot: 10 px per figure inch at dpi=10, so the canvas matches the image size.
    height = image.shape[0]
    width = image.shape[1]
    fig, axis = plt.subplots(1, 1, figsize=(width / 10, height / 10), dpi=10)
    try:
        axis.imshow(image)

        if seg is not None:
            # Scaled copy: the previous in-place "seg *= 0.9" mutated the caller's array.
            axis.imshow(seg * 0.9)

        if crossfield is not None:
            plot_crossfield(axis, crossfield, crossfield_stride=1, alpha=0.5, width=0.1, add_scale=1.1,
                            invert_y=False)

        plot_polygons(axis, polygons, polygon_probs=polygon_probs, draw_vertices=draw_vertices,
                      linewidths=linewidths, markersize=markersize, alpha=alpha, color_choices=color_choices)

        if corners is not None and len(corners):
            assert len(corners[0].shape) == 2
            for corner_array in corners:
                axis.plot(corner_array[:, 0], corner_array[:, 1], marker="o", linewidth=0, markersize=20,
                          color="red")

        axis.autoscale(False)
        axis.axis('equal')
        axis.axis('off')
        fig.subplots_adjust(left=0, right=1, top=1, bottom=0)  # Plot without margins
        fig.savefig(out_filepath, transparent=True, dpi=dpi)
    finally:
        # Always release the pyplot figure, even if plotting fails.
        plt.close(fig)


def main():
    """Render a synthetic image + seg + constant crossfield and save the first item to a PNG."""
    image = torch.zeros((2, 3, 512, 512)) + 0.5
    seg = torch.zeros((2, 2, 512, 512))
    seg[:, 0, 100:200, 100:200] = 1
    crossfield = torch.zeros((2, 4, 512, 512))

    # A constant cross: u at 0.25 rad and v perpendicular to it.
    u_angle = 0.25
    v_angle = u_angle + np.pi / 2
    u = np.cos(u_angle) + 1j * np.sin(u_angle)
    v = np.cos(v_angle) + 1j * np.sin(v_angle)
    c0 = np.power(u, 2) * np.power(v, 2)
    c2 = - (np.power(u, 2) + np.power(v, 2))

    crossfield[:, 0, :, :] = c0.real
    crossfield[:, 1, :, :] = c0.imag
    crossfield[:, 2, :, :] = c2.real
    crossfield[:, 3, :, :] = c2.imag

    image_seg_display = get_tensorboard_image_seg_display(image, seg, crossfield=crossfield)
    image_seg_display = image_seg_display.cpu().numpy().transpose(0, 2, 3, 1)
    skimage.io.imsave("image_seg_display.png", image_seg_display[0])


if __name__ == "__main__":
    main()