|
import torch |
|
import numpy as np |
|
|
|
|
|
def generate_meshgrid_2d(h: int, w: int, device=None) -> torch.Tensor:
    """Return an ``(h, w, 2)`` grid of coordinates evenly spaced in ``[-1, 1]``.

    ``grid[i, j] == (rows[i], cols[j])`` where ``rows`` has ``h`` points and
    ``cols`` has ``w`` points, both from ``torch.linspace(-1, 1, ...)``.
    Channel 0 therefore varies along the height axis and channel 1 along the
    width axis ("ij" indexing).

    NOTE(review): ``torch.nn.functional.grid_sample`` expects the last axis
    ordered as (x=width, y=height); if this grid feeds ``grid_sample`` the
    caller may need to flip the last axis — confirm against call sites.

    Args:
        h: number of samples along the first (height) axis.
        w: number of samples along the second (width) axis.
        device: torch device for the result; ``None`` uses torch's default.

    Returns:
        Tensor of shape ``(h, w, 2)``.
    """
    rows = torch.linspace(-1, 1, h, device=device)
    cols = torch.linspace(-1, 1, w, device=device)
    # Explicit "ij" matches the historical default and avoids the
    # deprecation warning emitted when `indexing` is omitted.
    grid_rows, grid_cols = torch.meshgrid(rows, cols, indexing="ij")
    return torch.stack([grid_rows, grid_cols], dim=2)
|
|
|
|
|
def _unit_to_uint8(img):
    """Quantise a float image scaled to [0, 1] into uint8 levels 0..255.

    Rounds to the nearest level instead of truncating, so values such as
    0.999 or ``k / 255`` carrying float error do not drop a level. Clips
    before the cast so out-of-range inputs saturate rather than wrap.
    """
    scaled = np.rint(np.asarray(img) * 255.0)
    return np.clip(scaled, 0, 255).astype(np.uint8)


def his_match(src, dst):
    """Remap ``dst`` so each channel's histogram matches the same channel of ``src``.

    Both images are treated as floats scaled to [0, 1]. They are quantised to
    256 levels, a per-channel lookup table is built from the two cumulative
    histograms, and every pixel of ``dst`` is passed through that table.

    Args:
        src: reference image, shape ``(H, W, C)`` or ``(H, W)``. It must have
            at least as many channels as ``dst``.
        dst: image to remap, shape ``(H', W', C)`` or ``(H', W')``.

    Returns:
        float64 array with the same shape as ``dst``. Values lie in [0, 1]
        and are multiples of 1/255.
    """
    src_u8 = np.atleast_3d(_unit_to_uint8(src))
    dst_u8 = _unit_to_uint8(dst)
    was_2d = dst_u8.ndim == 2
    dst_u8 = np.atleast_3d(dst_u8)

    res = np.zeros_like(dst_u8)
    # 256 unit-width bins over [0, 256) with density=True make each histogram
    # sum to 1, so its cumulative sum is the normalised CDF.
    kw = dict(bins=256, range=(0, 256), density=True)
    for ch in range(dst_u8.shape[2]):
        hist_src, _ = np.histogram(src_u8[:, :, ch], **kw)
        hist_dst, _ = np.histogram(dst_u8[:, :, ch], **kw)
        cdf_src = np.cumsum(hist_src)
        cdf_dst = np.cumsum(hist_dst)
        # lut[v] = smallest src level whose CDF reaches dst's CDF at level v.
        lut = np.searchsorted(cdf_src, cdf_dst, side="left")
        # searchsorted returns 256 when float error leaves cdf_src's last
        # entry slightly below a cdf_dst value; clamp to a valid level.
        np.clip(lut, 0, 255, out=lut)
        res[:, :, ch] = lut[dst_u8[:, :, ch]]

    if was_2d:
        res = res[:, :, 0]
    return res / 255.0
|
|