import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from net.HVI_transform import RGB_HVI
from net.transformer_utils import *
from net.LCA import *
from huggingface_hub import PyTorchModelHubMixin


class CIDNet(nn.Module, PyTorchModelHubMixin):
    def __init__(self,
                 channels=[36, 36, 72, 144],
                 heads=[1, 2, 4, 8],
                 norm=False
                 ):
        super(CIDNet, self).__init__()
        [ch1, ch2, ch3, ch4] = channels
        [head1, head2, head3, head4] = heads

        # HV branch: encodes the full 3-channel HVI map, decodes the 2 chromatic (HV) planes
        self.HVE_block0 = nn.Sequential(
            nn.ReplicationPad2d(1),
            nn.Conv2d(3, ch1, 3, stride=1, padding=0, bias=False)
        )
        self.HVE_block1 = NormDownsample(ch1, ch2, use_norm=norm)
        self.HVE_block2 = NormDownsample(ch2, ch3, use_norm=norm)
        self.HVE_block3 = NormDownsample(ch3, ch4, use_norm=norm)

        self.HVD_block3 = NormUpsample(ch4, ch3, use_norm=norm)
        self.HVD_block2 = NormUpsample(ch3, ch2, use_norm=norm)
        self.HVD_block1 = NormUpsample(ch2, ch1, use_norm=norm)
        self.HVD_block0 = nn.Sequential(
            nn.ReplicationPad2d(1),
            nn.Conv2d(ch1, 2, 3, stride=1, padding=0, bias=False)
        )

        # I branch: encodes/decodes the single intensity (I) plane
        self.IE_block0 = nn.Sequential(
            nn.ReplicationPad2d(1),
            nn.Conv2d(1, ch1, 3, stride=1, padding=0, bias=False),
        )
        self.IE_block1 = NormDownsample(ch1, ch2, use_norm=norm)
        self.IE_block2 = NormDownsample(ch2, ch3, use_norm=norm)
        self.IE_block3 = NormDownsample(ch3, ch4, use_norm=norm)

        self.ID_block3 = NormUpsample(ch4, ch3, use_norm=norm)
        self.ID_block2 = NormUpsample(ch3, ch2, use_norm=norm)
        self.ID_block1 = NormUpsample(ch2, ch1, use_norm=norm)
        self.ID_block0 = nn.Sequential(
            nn.ReplicationPad2d(1),
            nn.Conv2d(ch1, 1, 3, stride=1, padding=0, bias=False),
        )

        self.HV_LCA1 = HV_LCA(ch2, head2)
        self.HV_LCA2 = HV_LCA(ch3, head3)
        self.HV_LCA3 = HV_LCA(ch4, head4)
        self.HV_LCA4 = HV_LCA(ch4, head4)
        self.HV_LCA5 = HV_LCA(ch3, head3)
        self.HV_LCA6 = HV_LCA(ch2, head2)

        self.I_LCA1 = I_LCA(ch2, head2)
        self.I_LCA2 = I_LCA(ch3, head3)
        self.I_LCA3 = I_LCA(ch4, head4)
        self.I_LCA4 = I_LCA(ch4, head4)
        self.I_LCA5 = I_LCA(ch3, head3)
        self.I_LCA6 = I_LCA(ch2, head2)
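
        # These LCA blocks couple the I and HV branches, and forward() consumes
        # them in a U-shape: *_LCA1/*_LCA2 interleave with the encoder
        # downsampling stages, *_LCA3/*_LCA4 form the bottleneck, and
        # *_LCA5/*_LCA6 interleave with the decoder upsampling stages, so the
        # two branches exchange information at every scale.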

        # RGB <-> HVI color-space transform (provides HVIT and its inverse PHVIT)
        self.trans = RGB_HVI()

    def forward(self, x):
        dtypes = x.dtype
        hvi = self.trans.HVIT(x)                       # RGB -> HVI color space
        i = hvi[:, 2, :, :].unsqueeze(1).to(dtypes)    # intensity (I) plane only

        # encoder, first two stages
        i_enc0 = self.IE_block0(i)
        i_enc1 = self.IE_block1(i_enc0)
        hv_0 = self.HVE_block0(hvi)
        hv_1 = self.HVE_block1(hv_0)
        i_jump0 = i_enc0
        hv_jump0 = hv_0

        i_enc2 = self.I_LCA1(i_enc1, hv_1)
        hv_2 = self.HV_LCA1(hv_1, i_enc1)
        v_jump1 = i_enc2
        hv_jump1 = hv_2
        i_enc2 = self.IE_block2(i_enc2)
        hv_2 = self.HVE_block2(hv_2)

        i_enc3 = self.I_LCA2(i_enc2, hv_2)
        hv_3 = self.HV_LCA2(hv_2, i_enc2)
        v_jump2 = i_enc3
        hv_jump2 = hv_3
        # NB: the deepest downsample consumes the pre-LCA features (i_enc2, hv_2);
        # the LCA2 outputs reach the decoder only via the v_jump2/hv_jump2 skips
        i_enc3 = self.IE_block3(i_enc2)
        hv_3 = self.HVE_block3(hv_2)

        # bottleneck: two rounds of cross-attention at the deepest scale
        i_enc4 = self.I_LCA3(i_enc3, hv_3)
        hv_4 = self.HV_LCA3(hv_3, i_enc3)

        i_dec4 = self.I_LCA4(i_enc4, hv_4)
        hv_4 = self.HV_LCA4(hv_4, i_enc4)

        # decoder with skip connections
        hv_3 = self.HVD_block3(hv_4, hv_jump2)
        i_dec3 = self.ID_block3(i_dec4, v_jump2)
        i_dec2 = self.I_LCA5(i_dec3, hv_3)
        hv_2 = self.HV_LCA5(hv_3, i_dec3)

        hv_2 = self.HVD_block2(hv_2, hv_jump1)
        i_dec2 = self.ID_block2(i_dec3, v_jump1)   # NB: overwrites the I_LCA5 output above

        i_dec1 = self.I_LCA6(i_dec2, hv_2)
        hv_1 = self.HV_LCA6(hv_2, i_dec2)

        i_dec1 = self.ID_block1(i_dec1, i_jump0)
        i_dec0 = self.ID_block0(i_dec1)
        hv_1 = self.HVD_block1(hv_1, hv_jump0)
        hv_0 = self.HVD_block0(hv_1)

        # residual connection in HVI space, then map back to RGB
        output_hvi = torch.cat([hv_0, i_dec0], dim=1) + hvi
        output_rgb = self.trans.PHVIT(output_hvi)
        return output_rgb

    def HVIT(self, x):
        # helper exposing the RGB -> HVI transform used by forward()
        hvi = self.trans.HVIT(x)
        return hvi
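

# ---------------------------------------------------------------------------
# Minimal usage sketch, added for illustration (not part of the module).
# Assumptions: the `net/` package from this repo is on the import path, and
# the weights here are randomly initialized. Because CIDNet inherits from
# PyTorchModelHubMixin, pretrained weights can be fetched with
# `CIDNet.from_pretrained(repo_id)`, where `repo_id` is a placeholder for the
# actual Hub repository name. Spatial dims should be divisible by 8 to
# survive the three downsampling stages.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = CIDNet().eval()
    x = torch.rand(1, 3, 256, 256)      # dummy low-light RGB batch in [0, 1]
    with torch.no_grad():
        y = model(x)                    # enhanced RGB, same shape as x
        hvi = model.HVIT(x)             # intermediate 3-channel HVI map
    print(y.shape, hvi.shape)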