Light-amplification / net /HVI_transform.py
Fediory's picture
Add application file
3e648fb
raw
history blame
3.46 kB
import torch
import torch.nn as nn
# Pi written out as a Python float literal (same value as math.pi).
pi = 3.141592653589793
class RGB_HVI(nn.Module):
    """Convert image batches between RGB and the three-channel HVI form.

    ``HVIT`` maps a batch to channels ``[H, V, I]``: ``I`` is the per-pixel
    channel maximum, and ``(H, V)`` are the cosine / sine of the HSV hue angle
    scaled by the HSV saturation and by an intensity-dependent factor
    ``(sin(pi / 2 * I) + eps) ** k``.  ``PHVIT`` maps ``[H, V, I]`` back to
    RGB using the ``k`` captured by the most recent ``HVIT`` call.
    """

    def __init__(self):
        super(RGB_HVI, self).__init__()
        # Learnable exponent of the intensity-dependent factor.
        # k is reciprocal to the paper mentioned
        self.density_k = torch.nn.Parameter(torch.full([1], 0.2))
        # When True, PHVIT multiplies the recovered saturation by alpha_s.
        self.gated = False
        # When True, PHVIT multiplies the recovered RGB by alpha.
        self.gated2 = False
        self.alpha = 1.0
        self.alpha_s = 1.3
        # Python-float snapshot of density_k taken by the latest HVIT call and
        # read by PHVIT; it stays 0 (factor == 1) until HVIT has run once.
        self.this_k = 0

    def HVIT(self, img: torch.Tensor) -> torch.Tensor:
        """Transform an RGB batch into HVI.

        Args:
            img: 4-D tensor indexed as (batch, channel, height, width) whose
                channels 0, 1, 2 are read as R, G, B.  Presumably values lie
                in [0, 1] -- confirm against the caller.

        Returns:
            Tensor of shape (batch, 3, height, width) holding ``[H, V, I]``.

        Side effects:
            Stores ``self.density_k.item()`` in ``self.this_k`` so that a
            later ``PHVIT`` call uses the same exponent (as a plain float,
            i.e. without gradient).
        """
        eps = 1e-8
        red, green, blue = img[:, 0], img[:, 1], img[:, 2]
        value = img.max(1)[0]
        img_min = img.min(1)[0]
        chroma = value - img_min + eps

        # Zero-initialised on img's device/dtype so that any pixel matched by
        # no mask below (e.g. NaN input) is deterministic rather than garbage.
        hue = torch.zeros_like(value)
        blue_is_max = blue == value
        green_is_max = green == value
        red_is_max = red == value
        # Assignment order matters: on ties, later assignments win
        # (blue < green < red < grey).
        hue[blue_is_max] = 4.0 + ((red - green) / chroma)[blue_is_max]
        hue[green_is_max] = 2.0 + ((blue - red) / chroma)[green_is_max]
        hue[red_is_max] = (0.0 + ((green - blue) / chroma)[red_is_max]) % 6
        hue[img_min == value] = 0.0  # grey pixels have no hue
        hue = hue / 6.0

        saturation = (value - img_min) / (value + eps)
        saturation[value == 0] = 0

        hue = hue.unsqueeze(1)
        saturation = saturation.unsqueeze(1)
        value = value.unsqueeze(1)

        k = self.density_k
        self.this_k = k.item()

        color_sensitive = ((value * 0.5 * pi).sin() + eps).pow(k)
        angle = 2.0 * pi * hue
        H = color_sensitive * saturation * angle.cos()
        V = color_sensitive * saturation * angle.sin()
        I = value
        return torch.cat([H, V, I], dim=1)

    def PHVIT(self, img: torch.Tensor) -> torch.Tensor:
        """Transform an HVI batch back to RGB.

        Args:
            img: 4-D tensor indexed as (batch, channel, height, width) whose
                channels 0, 1, 2 are read as H, V, I.  H and V are clamped to
                [-1, 1] and I to [0, 1] before use.

        Returns:
            Tensor of shape (batch, 3, height, width) holding ``[R, G, B]``.
        """
        eps = 1e-8
        H = torch.clamp(img[:, 0, :, :], -1, 1)
        V = torch.clamp(img[:, 1, :, :], -1, 1)
        v = torch.clamp(img[:, 2, :, :], 0, 1)

        # Undo the intensity-dependent scaling applied by HVIT.
        k = self.this_k
        color_sensitive = ((v * 0.5 * pi).sin() + eps).pow(k)
        H = torch.clamp(H / (color_sensitive + eps), -1, 1)
        V = torch.clamp(V / (color_sensitive + eps), -1, 1)

        # Polar decomposition of (H, V): angle -> hue, radius -> saturation.
        h = torch.atan2(V + eps, H + eps) / (2 * pi)
        h = h % 1
        s = torch.sqrt(H ** 2 + V ** 2 + eps)

        if self.gated:
            s = s * self.alpha_s
        s = torch.clamp(s, 0, 1)

        # HSV -> RGB by hue sector.  For a tiny negative angle `h % 1` rounds
        # up to exactly 1.0, which makes floor(h * 6) == 6; wrap that back to
        # sector 0 (there f == 0) instead of leaving the pixel black.
        h6 = h * 6.0
        sector_floor = torch.floor(h6)
        f = h6 - sector_floor
        hi = sector_floor % 6

        p = v * (1. - s)
        q = v * (1. - (f * s))
        t = v * (1. - ((1. - f) * s))

        r = torch.zeros_like(h)
        g = torch.zeros_like(h)
        b = torch.zeros_like(h)
        # (r, g, b) source tensors for hue sectors 0..5.
        sector_sources = (
            (v, t, p),
            (q, v, p),
            (p, v, t),
            (p, q, v),
            (t, p, v),
            (v, p, q),
        )
        for index, (r_src, g_src, b_src) in enumerate(sector_sources):
            mask = hi == index
            r[mask] = r_src[mask]
            g[mask] = g_src[mask]
            b[mask] = b_src[mask]

        rgb = torch.cat([r.unsqueeze(1), g.unsqueeze(1), b.unsqueeze(1)], dim=1)
        if self.gated2:
            rgb = rgb * self.alpha
        return rgb