Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import torch.nn as nn | |
pi = 3.141592653589793


class RGB_HVI(nn.Module):
    """Learnable conversion between RGB and the HVI colour space.

    ``HVIT`` maps an RGB batch to HVI and ``PHVIT`` maps an HVI batch back
    to RGB. Both index their input as a rank-4 ``(batch, channel, height,
    width)`` tensor, with channels 0/1/2 read as R/G/B (``HVIT``) or H/V/I
    (``PHVIT``). Inputs are presumably in ``[0, 1]``: ``PHVIT`` clamps its
    I channel to that range, while ``HVIT`` does not (TODO confirm with
    callers).

    Attributes:
        density_k: learnable exponent of the intensity-dependent colour
            density term ``(sin(I * pi / 2) + eps) ** k``. Per the original
            author, this k is the reciprocal of the k mentioned in the paper.
        gated: if True, ``PHVIT`` multiplies saturation by ``alpha_s``.
        gated2: if True, ``PHVIT`` multiplies the output RGB by ``alpha``.
        alpha: output RGB gain used when ``gated2`` is set.
        alpha_s: saturation gain used when ``gated`` is set.
        this_k: Python-float snapshot of ``density_k`` taken by the most
            recent ``HVIT`` call. ``PHVIT`` uses this float, so no gradient
            reaches ``density_k`` through the inverse transform. It is 0
            until ``HVIT`` has run once.
    """

    def __init__(self):
        super().__init__()
        # NOTE: k is the reciprocal of the k mentioned in the paper.
        self.density_k = torch.nn.Parameter(torch.full([1], 0.2))
        self.gated = False
        self.gated2 = False
        self.alpha = 1.0
        self.alpha_s = 1.3
        self.this_k = 0

    def HVIT(self, img):
        """Convert an RGB batch to HVI.

        Args:
            img: tensor indexed as ``(B, C, H, W)``. Channels 0, 1, 2 are
                read as R, G, B. The max/min are taken over all of dim 1,
                so this assumes exactly three channels.

        Returns:
            Tensor of shape ``(B, 3, H, W)`` holding (H, V, I):

            * I is the per-pixel channel maximum.
            * H and V are the cosine and sine of the HSV hue angle, scaled
              by the HSV saturation ``(max - min) / (max + eps)`` and by the
              density term ``(sin(I * pi / 2) + eps) ** density_k``.

        Side effect:
            Stores ``density_k.item()`` in ``self.this_k`` for ``PHVIT``.
        """
        eps = 1e-8
        red, green, blue = img[:, 0], img[:, 1], img[:, 2]
        value = img.max(1)[0]
        img_min = img.min(1)[0]
        delta = value - img_min + eps

        # Start from zeros rather than uninitialised memory, so any pixel no
        # mask matches (e.g. NaN input) gets a defined hue.
        hue = torch.zeros(
            img.shape[0], img.shape[2], img.shape[3],
            device=img.device, dtype=img.dtype,
        )
        # Later assignments win, so channel ties resolve R > G > B, as in the
        # standard RGB -> HSV hue formula (hue here is in units of 60 degrees).
        mask = blue == value
        hue[mask] = 4.0 + ((red - green) / delta)[mask]
        mask = green == value
        hue[mask] = 2.0 + ((blue - red) / delta)[mask]
        mask = red == value
        hue[mask] = ((green - blue) / delta)[mask] % 6
        # Achromatic pixels (max == min) have no meaningful hue.
        hue[img_min == value] = 0.0
        hue = hue / 6.0

        saturation = (value - img_min) / (value + eps)
        saturation[value == 0] = 0

        hue = hue.unsqueeze(1)
        saturation = saturation.unsqueeze(1)
        value = value.unsqueeze(1)

        k = self.density_k
        # Snapshot as a Python float; PHVIT reads this instead of the Parameter.
        self.this_k = k.item()
        # eps keeps the base positive so pow(k) is well defined at value == 0.
        color_sensitive = ((value * 0.5 * pi).sin() + eps).pow(k)
        angle = 2.0 * pi * hue
        h_axis = color_sensitive * saturation * angle.cos()
        v_axis = color_sensitive * saturation * angle.sin()
        return torch.cat([h_axis, v_axis, value], dim=1)

    def PHVIT(self, img):
        """Convert an HVI batch back to RGB.

        Args:
            img: tensor indexed as ``(B, C, H, W)``. Channels 0, 1, 2 are
                read as H, V, I. H and V are clamped to ``[-1, 1]`` and I
                to ``[0, 1]``.

        Returns:
            Tensor of shape ``(B, 3, H, W)`` with R, G, B in ``[0, 1]``,
            multiplied by ``self.alpha`` when ``self.gated2`` is set.

        Notes:
            The density term is undone with ``self.this_k``, i.e. the k
            recorded by the most recent ``HVIT`` call. If ``HVIT`` has never
            run, that is 0 and the density term is 1.
        """
        eps = 1e-8
        h_axis = torch.clamp(img[:, 0], -1, 1)
        v_axis = torch.clamp(img[:, 1], -1, 1)
        v = torch.clamp(img[:, 2], 0, 1)

        # Undo the intensity-dependent density scaling applied by HVIT.
        color_sensitive = ((v * 0.5 * pi).sin() + eps).pow(self.this_k)
        h_axis = torch.clamp(h_axis / (color_sensitive + eps), -1, 1)
        v_axis = torch.clamp(v_axis / (color_sensitive + eps), -1, 1)

        # Hue as a fraction of a full turn. The remainder of a tiny negative
        # angle can round up to exactly 1.0 in floating point; this is
        # handled below.
        h = (torch.atan2(v_axis + eps, h_axis + eps) / (2 * pi)) % 1
        s = torch.sqrt(h_axis ** 2 + v_axis ** 2 + eps)
        if self.gated:
            s = s * self.alpha_s
        s = torch.clamp(s, 0, 1)

        # Standard HSV -> RGB conversion: split hue into six sectors.
        sector = torch.floor(h * 6.0)
        f = h * 6.0 - sector
        # h == 1.0 gives sector 6, which previously matched no branch and
        # decoded as black. Hue 1.0 is hue 0, so wrap it (f is already 0).
        sector = sector % 6
        p = v * (1. - s)
        q = v * (1. - (f * s))
        t = v * (1. - ((1. - f) * s))

        r = torch.zeros_like(h)
        g = torch.zeros_like(h)
        b = torch.zeros_like(h)
        # (R, G, B) source tensors for sectors 0..5.
        sector_sources = (
            (v, t, p),
            (q, v, p),
            (p, v, t),
            (p, q, v),
            (t, p, v),
            (v, p, q),
        )
        for idx, (r_src, g_src, b_src) in enumerate(sector_sources):
            mask = sector == idx
            r[mask] = r_src[mask]
            g[mask] = g_src[mask]
            b[mask] = b_src[mask]

        rgb = torch.stack([r, g, b], dim=1)
        if self.gated2:
            rgb = rgb * self.alpha
        return rgb