|
import random |
|
|
|
import skimage.io |
|
from descartes import PolygonPatch |
|
from matplotlib.collections import PatchCollection |
|
from matplotlib.backends.backend_agg import FigureCanvasAgg |
|
from matplotlib.figure import Figure |
|
import matplotlib.pyplot as plt |
|
|
|
import numpy as np |
|
import torch |
|
import shapely.geometry |
|
|
|
from lydorn_utils import math_utils |
|
from torch_lydorn import torchvision |
|
|
|
|
|
def get_seg_display(seg):
    """Build an RGBA overlay image from a segmentation map.

    Args:
        seg: Either a 2-D array (H, W) or a 3-D array (H, W, C). Channel k of a 3-D input is copied
            into colour channel k of the output.

    Returns:
        Array of shape (H, W, 4) with the same dtype as ``seg``.
        - 2-D input: the map goes into the red channel and is also used as alpha.
        - 3-D input: alpha is the per-pixel channel sum, clipped to 255 for uint8 input and to 1 otherwise.
    """
    rgba = np.zeros([seg.shape[0], seg.shape[1], 4], dtype=seg.dtype)

    if len(seg.shape) != 2:
        # Multi-channel map: copy channels one by one, then derive opacity from their total.
        for channel_index in range(seg.shape[-1]):
            rgba[..., channel_index] = seg[..., channel_index]
        upper = 255 if seg.dtype == np.uint8 else 1
        rgba[..., 3] = np.clip(np.sum(seg, axis=-1), 0, upper)
        return rgba

    # Single-channel map: render it in red with matching opacity.
    rgba[..., 0] = seg
    rgba[..., 3] = seg
    return rgba
|
|
|
|
|
def get_tensorboard_image_seg_display(image, seg, crossfield=None):
    """Blend a segmentation, and optionally a rendered crossfield, over a batch of images.

    Args:
        image: Tensor of shape (N, 3, H, W).
        seg: Tensor of shape (N, C, H, W) with C <= 3. Channel k is painted into colour channel k.
            Per-pixel opacity is the channel sum, clamped to [0, 1].
        crossfield: Optional tensor of shape (N, 4, H, W). Each item is rendered with
            ``get_image_plot_crossfield`` (stride 10) and alpha-composited on top.

    Returns:
        CPU tensor of shape (N, 3, H, W).

    Raises:
        AssertionError: If the shapes of the inputs are inconsistent.
    """
    assert len(image.shape) == 4 and image.shape[1] == 3, f"image should be (N, 3, H, W), not {image.shape}."
    assert len(seg.shape) == 4 and seg.shape[1] <= 3, f"seg should be (N, C<=3, H, W), not {seg.shape}."
    assert image.shape[0] == seg.shape[0], "image and seg should have the same batch size."
    assert image.shape[2] == seg.shape[2], "image and seg should have the same image height."
    assert image.shape[3] == seg.shape[3], "image and seg should have the same image width."
    if crossfield is not None:
        assert len(crossfield.shape) == 4 and crossfield.shape[
            1] == 4, f"crossfield should be (N, 4, H, W), not {crossfield.shape}."
        assert image.shape[0] == crossfield.shape[0], "image and crossfield should have the same batch size."
        assert image.shape[2] == crossfield.shape[2], "image and crossfield should have the same image height."
        assert image.shape[3] == crossfield.shape[3], "image and crossfield should have the same image width."

    # Overlay opacity: total segmentation activation across channels, capped at 1.
    alpha = torch.clamp(torch.sum(seg, dim=1, keepdim=True), 0, 1)

    # Pad seg up to 3 colour channels so it can be blended with the RGB image.
    seg_display = torch.zeros_like(image)
    seg_display[:, :seg.shape[1], ...] = seg

    image_seg_display = (1 - alpha) * image + alpha * seg_display
    # The crossfield renders are built on the CPU, so move the blend there before compositing.
    image_seg_display = image_seg_display.cpu()

    if crossfield is not None:
        np_crossfield = crossfield.cpu().detach().numpy().transpose(0, 2, 3, 1)
        image_plot_crossfield_list = [get_image_plot_crossfield(_crossfield, crossfield_stride=10)
                                      for _crossfield in np_crossfield]
        image_plot_crossfield_list = [torchvision.transforms.functional.to_tensor(image_plot_crossfield).float() / 255
                                      for image_plot_crossfield in image_plot_crossfield_list]
        image_plot_crossfield = torch.stack(image_plot_crossfield_list, dim=0)
        # The 4th channel of the RGBA render is its opacity.
        alpha = image_plot_crossfield[:, 3:4, :, :]
        image_seg_display = (1 - alpha) * image_seg_display + alpha * image_plot_crossfield[:, :3, :, :]

    return image_seg_display
|
|
|
|
|
def plot_crossfield(axis, crossfield, crossfield_stride, alpha=0.5, width=0.5, add_scale=1, invert_y=True):
    """Draw a crossfield on a matplotlib axis as two families of headless, centred segments.

    The field is sampled every ``crossfield_stride`` pixels along both spatial axes. Each sample's
    coefficients are converted by ``math_utils.compute_crossfield_uv`` into two complex directions u and v.
    Each direction is drawn with ``axis.quiver`` in blue.

    Args:
        axis: Matplotlib axis to draw on.
        crossfield: Array indexed as ``crossfield[row, col, :]``. The last axis holds the coefficients given
            to ``compute_crossfield_uv`` (presumably (H, W, 4) — TODO confirm against callers).
        crossfield_stride: Sampling step in pixels.
        alpha: Opacity of the segments.
        width: Segment shaft width in data units (``units='xy'``).
        add_scale: Multiplier applied to the quiver scale ``1 / crossfield_stride``.
        invert_y: If True, the segment plotted at row y uses the coefficients of row ``H - 1 - y``.
    """
    n_rows, n_cols = crossfield.shape[0], crossfield.shape[1]
    cols, rows = np.meshgrid(np.arange(0, n_cols, crossfield_stride),
                             np.arange(0, n_rows, crossfield_stride))
    sample_rows = n_rows - 1 - rows if invert_y else rows

    u, v = math_utils.compute_crossfield_uv(crossfield[sample_rows, cols, :])

    quiver_kwargs = {
        "color": (0, 0, 1, alpha),
        "headaxislength": 0,
        "headlength": 0,
        "pivot": 'middle',
        "angles": "xy",
        "units": 'xy',
        "scale": add_scale / crossfield_stride,
        "width": width,
        "headwidth": 1,
    }
    # Draw u first, then v; each complex direction (re, im) is plotted as vector (im, -re).
    for direction in (u, v):
        axis.quiver(cols, rows, direction.imag, -direction.real, **quiver_kwargs)
|
|
|
|
|
def get_image_plot_crossfield(crossfield, crossfield_stride):
    """Render a crossfield into an RGBA uint8 image with the same height and width as the crossfield.

    The field is drawn by ``plot_crossfield`` as opaque blue segments on an off-screen Agg canvas.
    The alpha channel is then rebuilt from colour, so the white background becomes transparent and
    drawn pixels become opaque.

    Args:
        crossfield: Array indexed as ``crossfield[row, col, :]`` (presumably (H, W, 4) — TODO confirm).
        crossfield_stride: Sampling step in pixels, forwarded to ``plot_crossfield``.

    Returns:
        np.uint8 array of shape (H, W, 4), in RGBA channel order.
    """
    dpi = 100
    height, width = crossfield.shape[0], crossfield.shape[1]
    # Agg truncates the figure's pixel size to int, and e.g. 29 / 100 * 100 == 28.999999999999996.
    # That would lose a row or column and break the blend with the source image, so the size is
    # nudged by a small fraction of a pixel to make truncation land exactly on (width, height).
    nudge = 0.01
    fig = Figure(figsize=((width + nudge) / dpi, (height + nudge) / dpi), dpi=dpi)
    canvas = FigureCanvasAgg(fig)
    ax = fig.gca()

    plot_crossfield(ax, crossfield, crossfield_stride, alpha=1.0, width=2.0, add_scale=1)

    ax.axis('off')
    fig.tight_layout(pad=0)
    ax.margins(0)

    canvas.draw()
    # buffer_rgba() exposes the rendered pixels as (H, W, 4) RGBA. Copy them so the alpha channel can be
    # written without touching the renderer's buffer.
    image_from_plot = np.array(canvas.buffer_rgba(), dtype=np.uint8)

    # Rebuild alpha from colour: the brightest (white) pixels get alpha == mini, drawn pixels approach 255.
    # Every value is >= mini (the global minimum), so 255 - value + mini stays within [mini, 255] and
    # cannot overflow uint8.
    mini = image_from_plot.min()
    image_from_plot[:, :, 3] = np.max(255 - image_from_plot[:, :, :3] + mini, axis=2)

    return image_from_plot
|
|
|
|
|
def plot_polygons(axis, polygons, polygon_probs=None, draw_vertices=True, linewidths=2, markersize=10, alpha=0.2,
                  color_choices=None):
    """Draw filled polygons on a matplotlib axis, each in a colour picked from ``color_choices``.

    Args:
        axis: Matplotlib axis to draw on.
        polygons: Sequence of shapely Polygons or coordinate arrays accepted by ``shapely.geometry.Polygon``.
            Empty polygons are skipped.
        polygon_probs: Optional per-polygon values, aligned with ``polygons``. When given, the face
            opacity is ``alpha * prob + 0.1``; otherwise it is ``alpha``.
        draw_vertices: If True, also plot exterior and interior rings with "o" markers.
        linewidths: Edge line width of the polygon patches.
        markersize: Size of the vertex markers.
        alpha: Base face opacity.
        color_choices: List of RGBA colours to sample from. Defaults to a fixed palette of 12 colours.

    Colours are sampled with a private RNG seeded with 1, so a given polygon count always gets the
    same colours. The global ``random`` state is left untouched.
    """
    if len(polygons) == 0:
        return

    # Convert inputs to shapely polygons, dropping empty ones. Keep probs aligned with what is drawn.
    shapely_polygons = []
    kept_probs = []
    for index, geometry in enumerate(polygons):
        polygon = shapely.geometry.Polygon(geometry)
        if polygon.is_empty:
            continue
        shapely_polygons.append(polygon)
        if polygon_probs is not None:
            kept_probs.append(polygon_probs[index])
    if not shapely_polygons:
        return

    patches = [PolygonPatch(polygon) for polygon in shapely_polygons]

    if color_choices is None:
        color_choices = [
            [0, 0, 1, 1],
            [0, 1, 0, 1],
            [1, 0, 0, 1],
            [1, 1, 0, 1],
            [1, 0, 1, 1],
            [0, 1, 1, 1],
            [0.5, 1, 0, 1],
            [1, 0.5, 0, 1],
            [0.5, 0, 1, 1],
            [1, 0, 0.5, 1],
            [0, 0.5, 1, 1],
            [0, 1, 0.5, 1],
        ]
    # Same sequence as random.seed(1); random.choices(...), without reseeding the global RNG.
    colors = random.Random(1).choices(color_choices, k=len(patches))
    edgecolors = np.array(colors, dtype=float)
    facecolors = edgecolors.copy()
    if polygon_probs is not None:
        facecolors[:, -1] = alpha * np.array(kept_probs) + 0.1
    else:
        facecolors[:, -1] = alpha
    p = PatchCollection(patches, facecolors=facecolors, edgecolors=edgecolors, linewidths=linewidths)
    axis.add_collection(p)

    if draw_vertices:
        # Iterate the converted polygons so ndarray inputs work and colours line up with the patches.
        for polygon, edgecolor in zip(shapely_polygons, edgecolors):
            axis.plot(*polygon.exterior.xy, marker="o", color=edgecolor, markersize=markersize)
            for interior in polygon.interiors:
                axis.plot(*interior.xy, marker="o", color=edgecolor, markersize=markersize)
|
|
|
|
|
def plot_line_strings(axis, line_strings, draw_vertices=True, linewidths=2, markersize=5):
    """Plot each line string on ``axis`` as its own line, optionally marking its vertices.

    Args:
        axis: Matplotlib axis to draw on.
        line_strings: Iterable of objects exposing ``.xy`` (e.g. shapely LineString).
        draw_vertices: If True, mark every vertex with an "o" marker.
        linewidths: Line width forwarded to ``axis.plot``.
        markersize: Marker size forwarded to ``axis.plot``.

    Returns:
        List of the artists created, one per line string.
    """
    marker = "o" if draw_vertices else None
    artists = []
    for line_string in line_strings:
        artist, = axis.plot(*line_string.xy, marker=marker, markersize=markersize, linewidth=linewidths)
        artists.append(artist)
    return artists
|
|
|
|
|
def plot_geometries(axis, geometries, draw_vertices=True, linewidths=2, markersize=3):
    """Plot a mix of shapely polygonal and linear geometries on ``axis``.

    Polygons (including the parts of MultiPolygons) are drawn with ``plot_polygons``. LineStrings
    (including the parts of MultiLineStrings) are drawn with ``plot_line_strings``.

    Args:
        axis: Matplotlib axis to draw on.
        geometries: Iterable of shapely Polygon, MultiPolygon, LineString or MultiLineString.
        draw_vertices: Forwarded to both plotting helpers.
        linewidths: Forwarded to both plotting helpers.
        markersize: Forwarded to both plotting helpers.

    Returns:
        List of line artists from ``plot_line_strings``; empty when there are no line strings.

    Raises:
        NotImplementedError: For any other geometry type.
    """
    polygons = []
    line_strings = []
    for geometry in geometries:
        if isinstance(geometry, shapely.geometry.Polygon):
            polygons.append(geometry)
        elif isinstance(geometry, shapely.geometry.MultiPolygon):
            polygons.extend(geometry.geoms)
        elif isinstance(geometry, shapely.geometry.LineString):
            line_strings.append(geometry)
        elif isinstance(geometry, shapely.geometry.MultiLineString):
            # Use .geoms: shapely 2.0 no longer allows iterating multi-part geometries directly.
            line_strings.extend(geometry.geoms)
        else:
            raise NotImplementedError(f"Geometry type {type(geometry)} not implemented")

    if polygons:
        plot_polygons(axis, polygons, draw_vertices=draw_vertices, linewidths=linewidths, markersize=markersize)

    artists = []
    if line_strings:
        artists = plot_line_strings(axis, line_strings, draw_vertices=draw_vertices, linewidths=linewidths,
                                    markersize=markersize)
    return artists
|
|
|
|
|
def save_poly_viz(image, polygons, out_filepath, linewidths=2, markersize=20, alpha=0.2, draw_vertices=True,
                  corners=None, crossfield=None, polygon_probs=None, seg=None, color_choices=None, dpi=10):
    """Render an image with polygons (and optional overlays) and save it to ``out_filepath``.

    Args:
        image: Array shown with ``imshow``. ``image.shape[0]`` is the height and ``image.shape[1]`` the width.
        polygons: List of np.ndarray coordinate arrays or shapely Polygons.
        out_filepath: Destination passed to ``savefig``.
        linewidths, markersize, alpha, draw_vertices, color_choices: Forwarded to ``plot_polygons``.
        corners: Optional list of 2-D arrays. Column 0 is plotted as x and column 1 as y, as red dots.
        crossfield: Optional crossfield drawn with ``plot_crossfield`` (stride 1, no y inversion).
        polygon_probs: Optional list with one value per polygon, forwarded to ``plot_polygons``.
        seg: Optional array shown on top of the image at 90% intensity. The caller's array is not modified.
        dpi: Resolution passed to ``savefig``. The figure is width/10 x height/10 inches, so dpi=10 gives
            one output pixel per image pixel.

    Raises:
        AssertionError: If ``polygons`` / ``polygon_probs`` / ``corners`` have unexpected types or shapes.
    """
    assert isinstance(polygons, list), f"polygons should be of type list, not {type(polygons)}"
    if len(polygons):
        assert isinstance(polygons[0], (np.ndarray, shapely.geometry.Polygon)), \
            f"Item of the polygons list should be of type ndarray or shapely Polygon, not {type(polygons[0])}"
    if polygon_probs is not None:
        assert isinstance(polygon_probs, list)
        assert len(polygons) == len(polygon_probs), \
            "len(polygons)={} should be equal to len(polygon_probs)={}".format(len(polygons), len(polygon_probs))

    height = image.shape[0]
    width = image.shape[1]
    fig, axis = plt.subplots(1, 1, figsize=(width / 10, height / 10), dpi=10)
    try:
        axis.imshow(image)

        if seg is not None:
            # Dim the overlay slightly. Multiply out of place so the caller's array is left unchanged
            # (and integer arrays do not fail an in-place float multiply).
            axis.imshow(seg * 0.9)

        if crossfield is not None:
            plot_crossfield(axis, crossfield, crossfield_stride=1, alpha=0.5, width=0.1, add_scale=1.1,
                            invert_y=False)

        plot_polygons(axis, polygons, polygon_probs=polygon_probs, draw_vertices=draw_vertices,
                      linewidths=linewidths, markersize=markersize, alpha=alpha, color_choices=color_choices)

        if corners is not None and len(corners):
            assert len(corners[0].shape) == 2
            for corner_array in corners:
                axis.plot(corner_array[:, 0], corner_array[:, 1], marker="o", linewidth=0, markersize=20,
                          color="red")

        axis.autoscale(False)
        axis.axis('equal')
        axis.axis('off')
        fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
        fig.savefig(out_filepath, transparent=True, dpi=dpi)
    finally:
        # Always release the pyplot figure, even if drawing or saving raised.
        plt.close(fig)
|
|
|
|
|
def main():
    """Render a synthetic image, segmentation and constant crossfield to "image_seg_display.png"."""
    batch_size, side = 2, 512
    image = torch.zeros((batch_size, 3, side, side)) + 0.5
    seg = torch.zeros((batch_size, 2, side, side))
    seg[:, 0, 100:200, 100:200] = 1
    crossfield = torch.zeros((batch_size, 4, side, side))

    # Constant frame: u at 0.25 rad, v rotated a quarter turn from u.
    u_angle = 0.25
    v_angle = u_angle + np.pi / 2
    u = np.cos(u_angle) + 1j * np.sin(u_angle)
    v = np.cos(v_angle) + 1j * np.sin(v_angle)
    # c0 and c2 are the coefficients of (z^2 - u^2)(z^2 - v^2) = z^4 + c2 z^2 + c0.
    u_sq = np.power(u, 2)
    v_sq = np.power(v, 2)
    c0 = u_sq * v_sq
    c2 = - (u_sq + v_sq)

    # Channel layout written here: [Re(c0), Im(c0), Re(c2), Im(c2)].
    for channel, value in enumerate((c0.real, c0.imag, c2.real, c2.imag)):
        crossfield[:, channel, :, :] = value

    display = get_tensorboard_image_seg_display(image, seg, crossfield=crossfield)
    display_hwc = display.cpu().numpy().transpose(0, 2, 3, 1)
    skimage.io.imsave("image_seg_display.png", display_hwc[0])
|
|
|
|
|
# Run the synthetic rendering demo only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|