import torch
import torch.nn as nn

pi = 3.141592653589793


class RGB_HVI(nn.Module):
    """Convert between RGB and the learnable HVI colour space.

    ``HVIT`` maps an RGB batch to HVI and ``PHVIT`` maps HVI back to RGB.
    Both index their input as ``(batch, channel, height, width)`` and use
    channels 0, 1 and 2. Values are presumably in ``[0, 1]``; ``PHVIT``
    clamps its intensity channel to that range (TODO confirm with callers).
    """

    def __init__(self):
        super().__init__()
        # Learnable density exponent; k here is the reciprocal of the k
        # mentioned in the paper.
        self.density_k = nn.Parameter(torch.full([1], 0.2))
        # When True, PHVIT multiplies the recovered saturation by alpha_s.
        self.gated = False
        # When True, PHVIT multiplies the final RGB output by alpha.
        self.gated2 = False
        self.alpha = 1.0
        self.alpha_s = 1.3
        # Python-float snapshot of density_k taken by the latest HVIT call and
        # reused by PHVIT. It is 0 until HVIT runs; with k == 0, PHVIT applies
        # no intensity scaling.
        self.this_k = 0

    def HVIT(self, img):
        """Convert an RGB batch to HVI.

        Args:
            img: Tensor indexed as ``(N, C, H, W)`` whose channels 0/1/2 are
                treated as R/G/B.

        Returns:
            Tensor of shape ``(N, 3, H, W)`` holding the H, V and I planes.

        Side effect:
            Stores ``density_k`` as a Python float in ``self.this_k`` so that
            ``PHVIT`` inverts with the same exponent. Because that snapshot
            is a plain float, gradients w.r.t. ``density_k`` flow only
            through this method.
        """
        eps = 1e-8
        red, green, blue = img[:, 0], img[:, 1], img[:, 2]
        value = img.max(dim=1).values
        img_min = img.min(dim=1).values
        delta = value - img_min + eps

        # Zero-initialise (instead of torch.Tensor's uninitialised memory) so
        # a pixel selected by none of the masks below (e.g. NaN input) cannot
        # carry garbage into the output.
        hue = torch.zeros_like(value)

        # Standard HSV hue in sextant units. Later assignments overwrite
        # earlier ones, so ties between channels resolve with priority
        # R > G > B.
        mask = blue == value
        hue[mask] = 4.0 + ((red - green) / delta)[mask]
        mask = green == value
        hue[mask] = 2.0 + ((blue - red) / delta)[mask]
        mask = red == value
        hue[mask] = ((green - blue) / delta)[mask] % 6

        # Achromatic pixels (max == min) get hue 0.
        hue[img_min == value] = 0.0
        hue = hue / 6.0

        saturation = (value - img_min) / (value + eps)
        saturation[value == 0] = 0

        hue = hue.unsqueeze(1)
        saturation = saturation.unsqueeze(1)
        value = value.unsqueeze(1)

        k = self.density_k
        self.this_k = k.item()

        # Intensity-dependent factor. For positive k it is small where
        # `value` is near 0, collapsing the chroma of dark pixels.
        color_sensitive = ((value * 0.5 * pi).sin() + eps).pow(k)
        angle = 2.0 * pi * hue
        H = color_sensitive * saturation * angle.cos()
        V = color_sensitive * saturation * angle.sin()
        I = value
        return torch.cat([H, V, I], dim=1)

    def PHVIT(self, img):
        """Convert an HVI batch (as produced by ``HVIT``) back to RGB.

        Args:
            img: Tensor indexed as ``(N, C, H, W)`` whose channels 0/1/2 are
                H, V and I.

        Returns:
            RGB tensor of shape ``(N, 3, H, W)``.

        Uses ``self.this_k`` (set by the most recent ``HVIT`` call) as the
        density exponent. Before any ``HVIT`` call it is 0, i.e. no scaling.
        """
        eps = 1e-8
        H, V, I = img[:, 0, :, :], img[:, 1, :, :], img[:, 2, :, :]

        # Clip to the valid HVI range.
        H = torch.clamp(H, -1, 1)
        V = torch.clamp(V, -1, 1)
        v = torch.clamp(I, 0, 1)

        # Undo the intensity-dependent collapse applied in HVIT.
        k = self.this_k
        color_sensitive = ((v * 0.5 * pi).sin() + eps).pow(k)
        H = torch.clamp(H / (color_sensitive + eps), -1, 1)
        V = torch.clamp(V / (color_sensitive + eps), -1, 1)

        h = torch.atan2(V + eps, H + eps) / (2 * pi)
        # Wrap into [0, 1]. In float arithmetic a tiny negative angle wraps to
        # exactly 1.0; the sector computation below folds that back to 0.
        h = h % 1
        s = torch.sqrt(H ** 2 + V ** 2 + eps)
        if self.gated:
            s = s * self.alpha_s
        s = torch.clamp(s, 0, 1)

        # HSV -> RGB. Compute the fractional part from the unwrapped floor,
        # then fold sector 6 (h == 1.0) onto sector 0. Without the fold such
        # pixels matched no sector and came out black instead of red.
        h6 = h * 6.0
        sector = torch.floor(h6)
        f = h6 - sector
        sector = sector % 6
        p = v * (1.0 - s)
        q = v * (1.0 - f * s)
        t = v * (1.0 - (1.0 - f) * s)

        # (r, g, b) sources for each of the six hue sectors.
        table = (
            (v, t, p),
            (q, v, p),
            (p, v, t),
            (p, q, v),
            (t, p, v),
            (v, p, q),
        )
        r = torch.zeros_like(h)
        g = torch.zeros_like(h)
        b = torch.zeros_like(h)
        for idx, (r_src, g_src, b_src) in enumerate(table):
            mask = sector == idx
            r[mask] = r_src[mask]
            g[mask] = g_src[mask]
            b[mask] = b_src[mask]

        rgb = torch.stack([r, g, b], dim=1)
        if self.gated2:
            rgb = rgb * self.alpha
        return rgb