# Source: repository file-viewer export (file size: 10,866 bytes; revision abd2a81; 266 lines).
import random
import skimage.io
from descartes import PolygonPatch
from matplotlib.collections import PatchCollection
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure
import matplotlib.pyplot as plt
import numpy as np
import torch
import shapely.geometry
from lydorn_utils import math_utils
from torch_lydorn import torchvision
def get_seg_display(seg):
    """Build an RGBA display image from a segmentation map.

    Args:
        seg: NumPy array of shape (H, W) or (H, W, C).

    Returns:
        Array of shape (H, W, 4) with the same dtype as ``seg``.
        For a 2-D ``seg`` the values are copied into both the first (red)
        channel and the alpha channel. For a 3-D ``seg`` up to the first three
        channels are copied into the colour channels and alpha is the sum over
        *all* channels, clipped to 255 for ``uint8`` input and to 1 otherwise.
    """
    dtype = seg.dtype
    seg_display = np.zeros([seg.shape[0], seg.shape[1], 4], dtype=dtype)
    if len(seg.shape) == 2:
        seg_display[..., 0] = seg
        seg_display[..., 3] = seg
    else:
        # Only three colour slots exist in an RGBA image. Extra channels cannot
        # be coloured (indexing them used to raise IndexError) but they still
        # contribute to the alpha channel below.
        color_channel_count = min(seg.shape[-1], 3)
        seg_display[..., :color_channel_count] = seg[..., :color_channel_count]
        clip_max = 255 if dtype == np.uint8 else 1
        seg_display[..., 3] = np.clip(np.sum(seg, axis=-1), 0, clip_max)
    return seg_display
def get_tensorboard_image_seg_display(image, seg, crossfield=None, crossfield_stride=10):
    """Blend a segmentation (and optionally a crossfield plot) over a batch of images.

    Args:
        image: Tensor of shape (N, 3, H, W).
        seg: Tensor of shape (N, C, H, W) with C <= 3. Its channels fill the
            first C colour channels of the overlay; the per-pixel channel sum,
            clamped to [0, 1], is used as the blending weight.
        crossfield: Optional tensor of shape (N, 4, H, W). When given, each
            sample is rendered with ``get_image_plot_crossfield`` and
            composited on top using the rendered image's alpha channel.
        crossfield_stride: Sampling stride (in pixels) passed to the crossfield
            renderer. Defaults to 10, the value previously hard-coded.

    Returns:
        CPU tensor of shape (N, 3, H, W) holding the blended images.

    Raises:
        AssertionError: If the input shapes are inconsistent.
    """
    assert len(image.shape) == 4 and image.shape[1] == 3, f"image should be (N, 3, H, W), not {image.shape}."
    # The message used to say "image" here although the check is about seg.
    assert len(seg.shape) == 4 and seg.shape[1] <= 3, f"seg should be (N, C<=3, H, W), not {seg.shape}."
    assert image.shape[0] == seg.shape[0], "image and seg should have the same batch size."
    assert image.shape[2] == seg.shape[2], "image and seg should have the same image height."
    assert image.shape[3] == seg.shape[3], "image and seg should have the same image width."
    if crossfield is not None:
        assert len(crossfield.shape) == 4 and crossfield.shape[
            1] == 4, f"crossfield should be (N, 4, H, W), not {crossfield.shape}."
        assert image.shape[0] == crossfield.shape[0], "image and crossfield should have the same batch size."
        assert image.shape[2] == crossfield.shape[2], "image and crossfield should have the same image height."
        assert image.shape[3] == crossfield.shape[3], "image and crossfield should have the same image width."

    alpha = torch.clamp(torch.sum(seg, dim=1, keepdim=True), 0, 1)

    # Add missing seg channels so the overlay has the same 3 channels as the image
    seg_display = torch.zeros_like(image)
    seg_display[:, :seg.shape[1], ...] = seg

    image_seg_display = (1 - alpha) * image + alpha * seg_display
    image_seg_display = image_seg_display.cpu()

    if crossfield is not None:
        # The renderer works on NumPy arrays laid out as (H, W, 4) per sample
        np_crossfield = crossfield.cpu().detach().numpy().transpose(0, 2, 3, 1)
        image_plot_crossfield_list = [
            get_image_plot_crossfield(_crossfield, crossfield_stride=crossfield_stride)
            for _crossfield in np_crossfield
        ]
        image_plot_crossfield_list = [
            torchvision.transforms.functional.to_tensor(image_plot_crossfield).float() / 255
            for image_plot_crossfield in image_plot_crossfield_list
        ]
        image_plot_crossfield = torch.stack(image_plot_crossfield_list, dim=0)
        # Channel 3 of the rendered plot is its alpha; composite the RGB part with it
        alpha = image_plot_crossfield[:, 3:4, :, :]
        image_seg_display = (1 - alpha) * image_seg_display + alpha * image_plot_crossfield[:, :3, :, :]

    return image_seg_display
def plot_crossfield(axis, crossfield, crossfield_stride, alpha=0.5, width=0.5, add_scale=1, invert_y=True):
    """Draw the two crossfield directions as headless quiver segments on ``axis``.

    Args:
        axis: Object exposing a matplotlib-style ``quiver`` method.
        crossfield: Array indexed as ``crossfield[row, col, :]``; presumably
            (H, W, 4) holding the c0/c2 coefficients expected by
            ``math_utils.compute_crossfield_uv`` -- TODO confirm.
        crossfield_stride: Step between sampled rows and columns.
        alpha: Opacity of the (blue) segments.
        width: Segment width, in data units.
        add_scale: Multiplier applied to the quiver scale ``1 / crossfield_stride``.
        invert_y: When true, samples are read from vertically flipped rows
            while still being drawn at the unflipped grid positions.
    """
    row_count = crossfield.shape[0]
    col_count = crossfield.shape[1]
    grid_x, grid_y = np.meshgrid(
        np.arange(0, col_count, crossfield_stride),
        np.arange(0, row_count, crossfield_stride),
    )

    # Rows to sample from (optionally mirrored); columns are used unchanged
    sample_rows = row_count - 1 - grid_y if invert_y else grid_y
    sample_cols = grid_x

    quiver_scale = add_scale * 1 / crossfield_stride
    sampled_coefficients = crossfield[sample_rows, sample_cols, :]
    u, v = math_utils.compute_crossfield_uv(sampled_coefficients)

    quiver_kwargs = {
        "color": (0, 0, 1, alpha),
        "headaxislength": 0,
        "headlength": 0,
        "pivot": 'middle',
        "angles": "xy",
        "units": 'xy',
        "scale": quiver_scale,
        "width": width,
        "headwidth": 1,
    }
    # Each complex direction is drawn with x = imaginary part, y = -real part
    for direction in (u, v):
        axis.quiver(grid_x, grid_y, direction.imag, -direction.real, **quiver_kwargs)
def get_image_plot_crossfield(crossfield, crossfield_stride):
    """Render a crossfield to an RGBA image using an off-screen Agg canvas.

    Args:
        crossfield: Array whose first two dimensions are (height, width); it
            is forwarded unchanged to ``plot_crossfield``.
        crossfield_stride: Sampling stride forwarded to ``plot_crossfield``.

    Returns:
        ``uint8`` array of shape (canvas_height, canvas_width, 4) in RGBA
        order. The alpha channel is recomputed from the colours so that light
        pixels become transparent.
    """
    field_height = crossfield.shape[0]
    field_width = crossfield.shape[1]
    # figsize is in inches; at dpi=100 this targets one canvas pixel per field pixel
    figure = Figure(figsize=(field_width / 100, field_height / 100), dpi=100)
    agg_canvas = FigureCanvasAgg(figure)
    axes = figure.gca()

    plot_crossfield(axes, crossfield, crossfield_stride, alpha=1.0, width=2.0, add_scale=1)

    axes.axis('off')
    figure.tight_layout(pad=0)
    # To remove the huge white borders
    axes.margins(0)

    agg_canvas.draw()
    canvas_width, canvas_height = agg_canvas.get_width_height()
    argb_pixels = np.frombuffer(agg_canvas.tostring_argb(), dtype=np.uint8)
    argb_pixels = argb_pixels.reshape((canvas_height, canvas_width) + (4,))
    # Rolling the channel axis by -1 turns ARGB into RGBA (and yields a writable copy)
    rgba_pixels = np.roll(argb_pixels, -1, axis=-1)

    # Derive alpha from colour: the further a pixel is from white, the more opaque it is.
    # Adding the global minimum keeps the uint8 arithmetic within [0, 255].
    global_min = rgba_pixels.min()
    rgba_pixels[:, :, 3] = np.max(255 - rgba_pixels[:, :, :3] + global_min, axis=2)
    return rgba_pixels
def plot_polygons(axis, polygons, polygon_probs=None, draw_vertices=True, linewidths=2, markersize=10, alpha=0.2,
                  color_choices=None):
    """Draw filled polygons (and optionally their vertices) on ``axis``.

    Args:
        axis: Matplotlib-style axes to draw on.
        polygons: Sequence of items accepted by ``shapely.geometry.Polygon``
            (e.g. shapely polygons or coordinate arrays). Items that convert
            to an empty polygon are skipped.
        polygon_probs: Optional per-polygon values, aligned with ``polygons``.
            When given, the face opacity is ``alpha * prob + 0.1``.
        draw_vertices: Whether to also plot exterior and interior rings with
            round vertex markers.
        linewidths: Edge line width of the polygon patches.
        markersize: Size of the vertex markers.
        alpha: Face opacity (or its scale when ``polygon_probs`` is given).
        color_choices: Optional list of RGBA colours to draw from. A built-in
            12-colour palette is used when omitted.

    Returns:
        None. Nothing is drawn when there is no non-empty polygon.
    """
    if len(polygons) == 0:
        return

    # Convert once and keep polygons, probabilities and (later) colours aligned:
    # dropping empty polygons must not shift the indices of the remaining ones.
    kept_polygons = []
    kept_probs = []
    for index, geometry in enumerate(polygons):
        polygon = shapely.geometry.Polygon(geometry)
        if polygon.is_empty:
            continue
        kept_polygons.append(polygon)
        if polygon_probs is not None:
            kept_probs.append(polygon_probs[index])
    if not kept_polygons:
        return
    patches = [PolygonPatch(polygon) for polygon in kept_polygons]

    if color_choices is None:
        color_choices = [
            [0, 0, 1, 1],
            [0, 1, 0, 1],
            [1, 0, 0, 1],
            [1, 1, 0, 1],
            [1, 0, 1, 1],
            [0, 1, 1, 1],
            [0.5, 1, 0, 1],
            [1, 0.5, 0, 1],
            [0.5, 0, 1, 1],
            [1, 0, 0.5, 1],
            [0, 0.5, 1, 1],
            [0, 1, 0.5, 1],
        ]
    # A private generator seeded with 1 yields the same colour sequence as
    # seeding the module-level generator, without resetting global random state.
    color_rng = random.Random(1)
    colors = color_rng.choices(color_choices, k=len(patches))
    # np.float was removed from NumPy; the builtin float is the same dtype.
    edgecolors = np.array(colors, dtype=float)
    facecolors = edgecolors.copy()
    if polygon_probs is not None:
        facecolors[:, -1] = alpha * np.array(kept_probs) + 0.1
    else:
        facecolors[:, -1] = alpha
    p = PatchCollection(patches, facecolors=facecolors, edgecolors=edgecolors, linewidths=linewidths)
    axis.add_collection(p)

    if draw_vertices:
        # Iterate the converted polygons so coordinate-array inputs work too
        for polygon, edgecolor in zip(kept_polygons, edgecolors):
            axis.plot(*polygon.exterior.xy, marker="o", color=edgecolor, markersize=markersize)
            for interior in polygon.interiors:
                axis.plot(*interior.xy, marker="o", color=edgecolor, markersize=markersize)
def plot_line_strings(axis, line_strings, draw_vertices=True, linewidths=2, markersize=5):
    """Plot each line string on ``axis`` and collect the created artists.

    Args:
        axis: Matplotlib-style axes whose ``plot`` is called once per line string.
        line_strings: Iterable of objects exposing an ``xy`` coordinate pair.
        draw_vertices: When true, vertices are marked with ``"o"``; otherwise
            no marker is drawn.
        linewidths: Accepted for signature parity with the other plot helpers;
            it is not forwarded to ``axis.plot``.
        markersize: Size of the vertex markers.

    Returns:
        List with one artist per line string, in input order.
    """
    vertex_marker = None
    if draw_vertices:
        vertex_marker = "o"

    plotted = []
    for geometry in line_strings:
        # axis.plot returns a list; exactly one line is expected per call
        (line_artist,) = axis.plot(*geometry.xy, marker=vertex_marker, markersize=markersize)
        plotted.append(line_artist)
    return plotted
def plot_geometries(axis, geometries, draw_vertices=True, linewidths=2, markersize=3):
    """Plot a mixed collection of shapely geometries on ``axis``.

    Polygons are drawn with ``plot_polygons``; LineStrings and the parts of
    MultiLineStrings are drawn with ``plot_line_strings``.

    Args:
        axis: Matplotlib-style axes to draw on.
        geometries: Iterable of shapely Polygon, LineString or
            MultiLineString objects.
        draw_vertices: Forwarded to the plotting helpers.
        linewidths: Forwarded to the plotting helpers.
        markersize: Forwarded to the plotting helpers.

    Returns:
        List of artists created for the line strings; an empty list when
        there are none.

    Raises:
        NotImplementedError: If a geometry of any other type is encountered.
    """
    polygons = []
    line_strings = []
    for geometry in geometries:
        if isinstance(geometry, shapely.geometry.Polygon):
            polygons.append(geometry)
        elif isinstance(geometry, shapely.geometry.LineString):
            line_strings.append(geometry)
        elif isinstance(geometry, shapely.geometry.MultiLineString):
            # Multi-part geometries are not directly iterable in Shapely 2.0;
            # ``.geoms`` works in both old and new versions.
            line_strings.extend(geometry.geoms)
        else:
            raise NotImplementedError(f"Geometry type {type(geometry)} not implemented")

    if polygons:
        plot_polygons(axis, polygons, draw_vertices=draw_vertices, linewidths=linewidths, markersize=markersize)

    # Always bind the result so inputs without line strings return cleanly
    artists = []
    if line_strings:
        artists = plot_line_strings(axis, line_strings, draw_vertices=draw_vertices, linewidths=linewidths,
                                    markersize=markersize)
    return artists
def save_poly_viz(image, polygons, out_filepath, linewidths=2, markersize=20, alpha=0.2, draw_vertices=True,
                  corners=None, crossfield=None, polygon_probs=None, seg=None, color_choices=None, dpi=10):
    """Render polygons (plus optional overlays) over an image and save the figure.

    Args:
        image: Array whose first two dimensions are (height, width); shown
            with ``imshow``.
        polygons: List of polygons (ndarray or shapely Polygon).
        out_filepath: Destination passed to ``savefig``.
        linewidths: Polygon edge width, forwarded to ``plot_polygons``.
        markersize: Vertex marker size, forwarded to ``plot_polygons``.
        alpha: Polygon face opacity, forwarded to ``plot_polygons``.
        draw_vertices: Whether polygon vertices are drawn.
        corners: Optional list of 2-D arrays; columns 0 and 1 are plotted as
            red dots.
        crossfield: Optional crossfield drawn with ``plot_crossfield`` at
            stride 1.
        polygon_probs: Optional list with one value per polygon.
        seg: Optional segmentation image drawn over ``image``, scaled by 0.9.
            The caller's array is left unmodified.
        color_choices: Optional colour palette forwarded to ``plot_polygons``.
        dpi: Resolution used when saving. The figure itself is laid out at
            10 dpi with a size of (width / 10, height / 10) inches.

    Raises:
        AssertionError: If ``polygons`` / ``polygon_probs`` / ``corners`` do
            not have the expected types or lengths.
    """
    assert isinstance(polygons, list), f"polygons should be of type list, not {type(polygons)}"
    if len(polygons):
        assert isinstance(polygons[0], (np.ndarray, shapely.geometry.Polygon)), \
            f"Item of the polygons list should be of type ndarray or shapely Polygon, not {type(polygons[0])}"
    if polygon_probs is not None:
        assert isinstance(polygon_probs, list)
        assert len(polygons) == len(polygon_probs), \
            "len(polygons)={} should be equal to len(polygon_probs)={}".format(len(polygons), len(polygon_probs))

    # Setup plot
    height = image.shape[0]
    width = image.shape[1]
    f, axis = plt.subplots(1, 1, figsize=(width / 10, height / 10), dpi=10)
    try:
        axis.imshow(image)

        if seg is not None:
            # Scale a copy: an in-place multiply would alter the caller's data
            # and fails outright for integer dtypes.
            axis.imshow(seg * 0.9)

        if crossfield is not None:
            plot_crossfield(axis, crossfield, crossfield_stride=1, alpha=0.5, width=0.1, add_scale=1.1,
                            invert_y=False)

        plot_polygons(axis, polygons, polygon_probs=polygon_probs, draw_vertices=draw_vertices,
                      linewidths=linewidths, markersize=markersize, alpha=alpha, color_choices=color_choices)

        if corners is not None and len(corners):
            assert len(corners[0].shape) == 2
            for corner_array in corners:
                # Draw on this figure's axes explicitly instead of pyplot's "current" axes
                axis.plot(corner_array[:, 0], corner_array[:, 1], marker="o", linewidth=0, markersize=20,
                          color="red")

        axis.autoscale(False)
        axis.axis('equal')
        axis.axis('off')
        f.subplots_adjust(left=0, right=1, top=1, bottom=0)  # Plot without margins
        f.savefig(out_filepath, transparent=True, dpi=dpi)
    finally:
        # Release the figure even when plotting or saving fails
        plt.close(f)
def main():
    """Demo: render a constant crossfield over a grey image and save the first sample.

    Builds a batch of two 512x512 images with a square segmentation patch and
    a uniform crossfield, blends them with
    ``get_tensorboard_image_seg_display`` and writes the first result to
    ``image_seg_display.png``.
    """
    image = torch.zeros((2, 3, 512, 512)) + 0.5

    seg = torch.zeros((2, 2, 512, 512))
    seg[:, 0, 100:200, 100:200] = 1

    # Two perpendicular unit directions, expressed as complex numbers
    u_angle = 0.25
    v_angle = u_angle + np.pi / 2
    u = np.cos(u_angle) + 1j * np.sin(u_angle)
    v = np.cos(v_angle) + 1j * np.sin(v_angle)

    # Coefficients c0 = u^2 * v^2 and c2 = -(u^2 + v^2)
    u_squared = np.power(u, 2)
    v_squared = np.power(v, 2)
    c0 = u_squared * v_squared
    c2 = - (u_squared + v_squared)

    # Channel layout: (c0.real, c0.imag, c2.real, c2.imag), constant over the image
    crossfield = torch.zeros((2, 4, 512, 512))
    for channel, value in enumerate((c0.real, c0.imag, c2.real, c2.imag)):
        crossfield[:, channel, :, :] = value

    blended = get_tensorboard_image_seg_display(image, seg, crossfield=crossfield)
    blended = blended.cpu().numpy().transpose(0, 2, 3, 1)  # (N, H, W, 3) for saving
    skimage.io.imsave("image_seg_display.png", blended[0])


if __name__ == "__main__":
    main()
|