import random

import matplotlib.pyplot as plt
import numpy as np
import shapely.geometry
import skimage.io
import torch
from descartes import PolygonPatch
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.collections import PatchCollection
from matplotlib.figure import Figure

from lydorn_utils import math_utils
from torch_lydorn import torchvision
def get_seg_display(seg):
    """Expand a segmentation map into a 4-channel array for display.

    The output has shape ``(seg.shape[0], seg.shape[1], 4)`` and the same
    dtype as ``seg``. It is presumably meant to be rendered as RGBA (e.g. by
    ``imshow``), with the last channel acting as opacity.

    * 2D input: the map is written to channel 0 and to channel 3.
    * Otherwise: channel ``i`` of the last axis is copied to output channel
      ``i``, then channel 3 is set to the per-pixel sum over the last axis,
      clipped to ``[0, 255]`` for ``uint8`` input and to ``[0, 1]`` for any
      other dtype.

    :param seg: numpy array, either (H, W) or (H, W, C)
    :return: numpy array of shape (H, W, 4)
    """
    out_dtype = seg.dtype
    rgba = np.zeros([seg.shape[0], seg.shape[1], 4], dtype=out_dtype)

    if len(seg.shape) == 2:
        # Single-channel map: same values drive the first channel and the opacity.
        rgba[..., 0] = seg
        rgba[..., 3] = seg
        return rgba

    for channel, plane in enumerate(np.moveaxis(seg, -1, 0)):
        rgba[..., channel] = plane

    # Opacity is the total response over all classes, saturated at the
    # maximum representable display value for this dtype.
    upper_bound = 255 if out_dtype == np.uint8 else 1
    summed = np.sum(seg, axis=-1)
    rgba[..., 3] = np.clip(summed, 0, upper_bound)
    return rgba
def get_tensorboard_image_seg_display(image, seg, crossfield=None):
    """Blend a segmentation (and optionally a crossfield plot) over a batch of images.

    The segmentation is alpha-blended over the image, using the channel-wise
    sum of ``seg`` clamped to ``[0, 1]`` as opacity. Missing segmentation
    channels are treated as zeros. If ``crossfield`` is given, each sample is
    rendered with ``get_image_plot_crossfield`` (stride 10) and the rendering
    is blended on top using its own alpha channel.

    :param image: tensor of shape (N, 3, H, W)
    :param seg: tensor of shape (N, C<=3, H, W)
    :param crossfield: optional tensor of shape (N, 4, H, W)
    :return: CPU tensor of shape (N, 3, H, W)
    :raises AssertionError: if the shapes above are not respected
    """
    assert len(image.shape) == 4 and image.shape[1] == 3, f"image should be (N, 3, H, W), not {image.shape}."
    # The message used to say "image", which pointed at the wrong argument.
    assert len(seg.shape) == 4 and seg.shape[1] <= 3, f"seg should be (N, C<=3, H, W), not {seg.shape}."
    assert image.shape[0] == seg.shape[0], "image and seg should have the same batch size."
    assert image.shape[2] == seg.shape[2], "image and seg should have the same image height."
    assert image.shape[3] == seg.shape[3], "image and seg should have the same image width."
    if crossfield is not None:
        assert len(crossfield.shape) == 4 and crossfield.shape[1] == 4, \
            f"crossfield should be (N, 4, H, W), not {crossfield.shape}."
        assert image.shape[0] == crossfield.shape[0], "image and crossfield should have the same batch size."
        assert image.shape[2] == crossfield.shape[2], "image and crossfield should have the same image height."
        assert image.shape[3] == crossfield.shape[3], "image and crossfield should have the same image width."

    alpha = torch.clamp(torch.sum(seg, dim=1, keepdim=True), 0, 1)

    # Pad seg with zero channels so it has as many channels as the image.
    seg_display = torch.zeros_like(image)
    seg_display[:, :seg.shape[1], ...] = seg

    image_seg_display = (1 - alpha) * image + alpha * seg_display
    # The crossfield rendering below is produced on CPU, so blend there.
    image_seg_display = image_seg_display.cpu()

    if crossfield is not None:
        # Channels-last numpy layout, one (H, W, 4) array per sample.
        np_crossfield = crossfield.cpu().detach().numpy().transpose(0, 2, 3, 1)
        image_plot_crossfield_list = [
            get_image_plot_crossfield(_crossfield, crossfield_stride=10)
            for _crossfield in np_crossfield
        ]
        # NOTE(review): the "/ 255" assumes this project's to_tensor does not
        # already rescale uint8 input to [0, 1] - confirm against torch_lydorn.
        image_plot_crossfield_list = [
            torchvision.transforms.functional.to_tensor(image_plot_crossfield).float() / 255
            for image_plot_crossfield in image_plot_crossfield_list
        ]
        image_plot_crossfield = torch.stack(image_plot_crossfield_list, dim=0)
        alpha = image_plot_crossfield[:, 3:4, :, :]
        image_seg_display = (1 - alpha) * image_seg_display + alpha * image_plot_crossfield[:, :3, :, :]

    return image_seg_display
def plot_crossfield(axis, crossfield, crossfield_stride, alpha=0.5, width=0.5, add_scale=1, invert_y=True):
    """Draw a sub-sampled crossfield on ``axis`` as two sets of headless arrows.

    The field is sampled every ``crossfield_stride`` pixels along both image
    axes. The sampled values are converted to two direction fields ``u`` and
    ``v`` by ``math_utils.compute_crossfield_uv`` and each one is drawn with
    ``axis.quiver``.

    :param axis: object exposing a matplotlib-like ``quiver`` method
    :param crossfield: array indexed as ``[row, col, :]``; presumably the last
        axis holds the c0/c2 coefficients expected by
        ``math_utils.compute_crossfield_uv`` - confirm against that helper
    :param crossfield_stride: sampling step in pixels
    :param alpha: opacity of the (blue) arrows
    :param width: shaft width passed to ``quiver``
    :param add_scale: extra factor applied to the quiver ``scale``
    :param invert_y: if True, sample rows bottom-up (row ``H - 1 - y``) while
        still drawing at ``y``
    """
    n_rows = crossfield.shape[0]
    n_cols = crossfield.shape[1]
    x, y = np.meshgrid(
        np.arange(0, n_cols, crossfield_stride),
        np.arange(0, n_rows, crossfield_stride),
    )

    # Row indices used for sampling; drawing positions stay (x, y).
    rows = n_rows - 1 - y if invert_y else y
    cols = x

    quiver_kwargs = {
        "color": (0, 0, 1, alpha),
        "headaxislength": 0,
        "headlength": 0,
        "pivot": 'middle',
        "angles": "xy",
        "units": 'xy',
        "scale": add_scale * 1 / crossfield_stride,
        "width": width,
        "headwidth": 1,
    }

    sampled = crossfield[rows, cols, :]
    u, v = math_utils.compute_crossfield_uv(sampled)

    # Same component mapping for both directions: (imag, -real).
    for direction in (u, v):
        axis.quiver(x, y, direction.imag, -direction.real, **quiver_kwargs)
def get_image_plot_crossfield(crossfield, crossfield_stride):
    """Render a crossfield plot off-screen and return it as an RGBA uint8 image.

    A figure sized ``(W / 100, H / 100)`` inches at 100 dpi is created from
    the first two dimensions of ``crossfield``, the field is drawn with
    ``plot_crossfield`` and the Agg canvas is read back. The alpha channel of
    the result is then recomputed from the colour channels so that pixels
    close to white become transparent.

    :param crossfield: array whose first two dimensions are (H, W); passed
        unchanged to ``plot_crossfield``
    :param crossfield_stride: sampling step forwarded to ``plot_crossfield``
    :return: uint8 numpy array of shape (canvas_height, canvas_width, 4)
    """
    field_height, field_width = crossfield.shape[0], crossfield.shape[1]
    figure = Figure(figsize=(field_width / 100, field_height / 100), dpi=100)
    agg_canvas = FigureCanvasAgg(figure)
    axes = figure.gca()

    plot_crossfield(axes, crossfield, crossfield_stride, alpha=1.0, width=2.0, add_scale=1)

    # Strip every decoration so the drawing fills the whole canvas.
    axes.axis('off')
    figure.tight_layout(pad=0)
    axes.margins(0)

    agg_canvas.draw()

    canvas_width, canvas_height = agg_canvas.get_width_height()
    argb = np.frombuffer(agg_canvas.tostring_argb(), dtype=np.uint8)
    argb = argb.reshape((canvas_height, canvas_width) + (4,))
    # ARGB -> RGBA: shift the channel axis by one so alpha ends up last.
    rgba = np.roll(argb, -1, axis=-1)

    # Derive opacity from how far each pixel's darkest colour channel is from
    # white, offset by the global minimum of the image.
    darkest = rgba.min()
    rgba[:, :, 3] = np.max(255 - rgba[:, :, :3] + darkest, axis=2)
    return rgba
def plot_polygons(axis, polygons, polygon_probs=None, draw_vertices=True, linewidths=2, markersize=10, alpha=0.2,
                  color_choices=None):
    """Draw filled polygons (and optionally their vertices) on ``axis``.

    Every item of ``polygons`` is converted with ``shapely.geometry.Polygon``;
    empty polygons are skipped. Colours are drawn from ``color_choices`` with
    a fixed seed so the same input always gets the same colours.

    :param axis: matplotlib axes to draw on
    :param polygons: sequence of items accepted by ``shapely.geometry.Polygon``
        (shapely polygons or coordinate arrays)
    :param polygon_probs: optional per-polygon values, aligned with
        ``polygons``; the face opacity becomes ``alpha * prob + 0.1``
        (clipped to [0, 1])
    :param draw_vertices: if True, also plot exterior and interior rings with
        round markers
    :param linewidths: edge line width of the patches
    :param markersize: marker size used for the vertices
    :param alpha: face opacity (or opacity scale when ``polygon_probs`` is set)
    :param color_choices: optional list of RGBA colours to pick from
    :return: None
    """
    if len(polygons) == 0:
        return

    # Convert once and keep geometry / probability aligned, so that dropping
    # empty polygons cannot desynchronise patches, colours and probabilities.
    kept_polygons = []
    kept_probs = []
    for index, geometry in enumerate(polygons):
        polygon = shapely.geometry.Polygon(geometry)
        if polygon.is_empty:
            continue
        kept_polygons.append(polygon)
        if polygon_probs is not None:
            kept_probs.append(polygon_probs[index])

    if not kept_polygons:
        # Nothing drawable: avoid building empty colour arrays below.
        return

    patches = [PolygonPatch(polygon) for polygon in kept_polygons]

    if color_choices is None:
        color_choices = [
            [0, 0, 1, 1],
            [0, 1, 0, 1],
            [1, 0, 0, 1],
            [1, 1, 0, 1],
            [1, 0, 1, 1],
            [0, 1, 1, 1],
            [0.5, 1, 0, 1],
            [1, 0.5, 0, 1],
            [0.5, 0, 1, 1],
            [1, 0, 0.5, 1],
            [0, 0.5, 1, 1],
            [0, 1, 0.5, 1],
        ]

    # A private generator seeded with 1 yields the same colour sequence as
    # seeding the module-level generator, without resetting global random
    # state for the rest of the program.
    rng = random.Random(1)
    colors = rng.choices(color_choices, k=len(patches))

    # np.float was removed from NumPy; the builtin float is the same dtype.
    edgecolors = np.array(colors, dtype=float)
    facecolors = edgecolors.copy()
    if polygon_probs is not None:
        face_alpha = alpha * np.asarray(kept_probs, dtype=float) + 0.1
        facecolors[:, -1] = np.clip(face_alpha, 0.0, 1.0)
    else:
        facecolors[:, -1] = alpha

    collection = PatchCollection(patches, facecolors=facecolors, edgecolors=edgecolors, linewidths=linewidths)
    axis.add_collection(collection)

    if draw_vertices:
        # Use the converted polygons: raw inputs may be plain coordinate
        # arrays, which have no exterior/interiors attributes.
        for polygon, color in zip(kept_polygons, edgecolors):
            axis.plot(*polygon.exterior.xy, marker="o", color=color, markersize=markersize)
            for interior in polygon.interiors:
                axis.plot(*interior.xy, marker="o", color=color, markersize=markersize)
def plot_line_strings(axis, line_strings, draw_vertices=True, linewidths=2, markersize=5):
    """Plot each line string on ``axis`` and return the created artists.

    :param axis: object exposing a matplotlib-like ``plot`` method
    :param line_strings: iterable of objects with an ``xy`` attribute holding
        the x and y coordinate sequences (e.g. shapely LineStrings)
    :param draw_vertices: if True, vertices are drawn with round markers
    :param linewidths: line width applied to every line string
    :param markersize: marker size used when ``draw_vertices`` is True
    :return: list with one artist per line string, in input order
    """
    marker = "o" if draw_vertices else None
    artists = []
    for line_string in line_strings:
        # linewidths is forwarded here; it used to be accepted but ignored.
        artist, = axis.plot(*line_string.xy, marker=marker, markersize=markersize, linewidth=linewidths)
        artists.append(artist)
    return artists
def plot_geometries(axis, geometries, draw_vertices=True, linewidths=2, markersize=3):
    """Plot a mixed collection of shapely geometries on ``axis``.

    Polygons are drawn with ``plot_polygons``; LineStrings and the parts of
    MultiLineStrings are drawn with ``plot_line_strings``.

    :param axis: matplotlib axes to draw on
    :param geometries: iterable of shapely Polygon, LineString or
        MultiLineString objects
    :param draw_vertices: forwarded to the polygon / line-string plotters
    :param linewidths: forwarded to the polygon / line-string plotters
    :param markersize: forwarded to the polygon / line-string plotters
    :return: the list of artists created for the line strings, or None when
        ``geometries`` contains no line string
    :raises NotImplementedError: for any other geometry type
    """
    polygons = []
    line_strings = []
    for geometry in geometries:
        if isinstance(geometry, shapely.geometry.Polygon):
            polygons.append(geometry)
        elif isinstance(geometry, shapely.geometry.LineString):
            line_strings.append(geometry)
        elif isinstance(geometry, shapely.geometry.MultiLineString):
            # Multi-part geometries are not directly iterable in Shapely 2.x;
            # .geoms is available in both 1.x and 2.x.
            line_strings.extend(geometry.geoms)
        else:
            raise NotImplementedError(f"Geometry type {type(geometry)} not implemented")

    if polygons:
        plot_polygons(axis, polygons, draw_vertices=draw_vertices, linewidths=linewidths, markersize=markersize)

    if not line_strings:
        # Explicit "no line artists" result instead of an unbound local.
        return None
    return plot_line_strings(axis, line_strings, draw_vertices=draw_vertices, linewidths=linewidths,
                             markersize=markersize)
def save_poly_viz(image, polygons, out_filepath, linewidths=2, markersize=20, alpha=0.2, draw_vertices=True,
                  corners=None, crossfield=None, polygon_probs=None, seg=None, color_choices=None, dpi=10):
    """Render polygons (plus optional overlays) on top of an image and save the figure.

    The figure is laid out at 10 dpi with a size of ``(width / 10, height / 10)``
    inches, i.e. one figure pixel per image pixel; ``dpi`` only controls the
    resolution used when writing the file.

    :param image: array whose first two dimensions are (height, width)
    :param polygons: list of numpy arrays or shapely Polygons
    :param out_filepath: destination passed to ``savefig``
    :param linewidths: forwarded to ``plot_polygons``
    :param markersize: forwarded to ``plot_polygons``
    :param alpha: forwarded to ``plot_polygons``
    :param draw_vertices: forwarded to ``plot_polygons``
    :param corners: optional sequence of (K, 2) arrays drawn as red markers
    :param crossfield: optional crossfield array drawn with ``plot_crossfield``
    :param polygon_probs: optional list with one value per polygon
    :param seg: optional segmentation display array shown over the image at
        90% of its values; the caller's array is left untouched
    :param color_choices: forwarded to ``plot_polygons``
    :param dpi: resolution used by ``savefig``
    :raises AssertionError: if ``polygons`` / ``polygon_probs`` / ``corners``
        do not have the expected types or shapes
    """
    assert isinstance(polygons, list), f"polygons should be of type list, not {type(polygons)}"
    if len(polygons):
        assert isinstance(polygons[0], (np.ndarray, shapely.geometry.Polygon)), \
            f"Item of the polygons list should be of type ndarray or shapely Polygon, not {type(polygons[0])}"
    if polygon_probs is not None:
        assert isinstance(polygon_probs, list)
        assert len(polygons) == len(polygon_probs), \
            "len(polygons)={} should be equal to len(polygon_probs)={}".format(len(polygons), len(polygon_probs))

    height = image.shape[0]
    width = image.shape[1]
    f, axis = plt.subplots(1, 1, figsize=(width / 10, height / 10), dpi=10)

    # Work on the figure/axes objects explicitly and always close the figure,
    # so a failure while plotting or saving does not leak an open figure.
    try:
        axis.imshow(image)

        if seg is not None:
            # Scale a copy: an in-place "*=" would modify the caller's array
            # and fails outright for integer dtypes.
            axis.imshow(seg * 0.9)

        if crossfield is not None:
            plot_crossfield(axis, crossfield, crossfield_stride=1, alpha=0.5, width=0.1, add_scale=1.1,
                            invert_y=False)

        plot_polygons(axis, polygons, polygon_probs=polygon_probs, draw_vertices=draw_vertices,
                      linewidths=linewidths, markersize=markersize, alpha=alpha, color_choices=color_choices)

        if corners is not None and len(corners):
            assert len(corners[0].shape) == 2
            for corner_array in corners:
                axis.plot(corner_array[:, 0], corner_array[:, 1], marker="o", linewidth=0, markersize=20,
                          color="red")

        # Freeze the view set by imshow, then remove all decorations so the
        # image fills the whole figure.
        axis.autoscale(False)
        axis.axis('equal')
        axis.axis('off')
        f.subplots_adjust(left=0, right=1, top=1, bottom=0)

        f.savefig(out_filepath, transparent=True, dpi=dpi)
    finally:
        plt.close(f)
def main():
    """Small demo: build a synthetic batch, render it and save the first sample.

    Creates a grey image batch, a two-channel segmentation with one filled
    square in channel 0 and a constant crossfield, passes them through
    ``get_tensorboard_image_seg_display`` and writes the first resulting
    image to ``image_seg_display.png``.
    """
    batch_shape = (2, 512, 512)
    n, h, w = batch_shape

    image = torch.zeros((n, 3, h, w)) + 0.5

    seg = torch.zeros((n, 2, h, w))
    seg[:, 0, 100:200, 100:200] = 1

    # Two perpendicular unit directions, encoded as complex numbers.
    u_angle = 0.25
    v_angle = u_angle + np.pi / 2
    u = np.cos(u_angle) + 1j * np.sin(u_angle)
    v = np.cos(v_angle) + 1j * np.sin(v_angle)

    # Coefficients derived from u and v, stored as (real, imag) channel pairs.
    c0 = np.power(u, 2) * np.power(v, 2)
    c2 = - (np.power(u, 2) + np.power(v, 2))

    crossfield = torch.zeros((n, 4, h, w))
    for channel, value in enumerate((c0.real, c0.imag, c2.real, c2.imag)):
        crossfield[:, channel, :, :] = value

    display = get_tensorboard_image_seg_display(image, seg, crossfield=crossfield)
    # Back to channels-last numpy for saving.
    display = display.cpu().numpy().transpose(0, 2, 3, 1)
    skimage.io.imsave("image_seg_display.png", display[0])


if __name__ == "__main__":
    main()