# Author: Fediory. Commit 2d65752: "feat: huggingface_hub" (file size 4.41 kB).
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from net.HVI_transform import RGB_HVI
from net.transformer_utils import *
from net.LCA import *
from huggingface_hub import PyTorchModelHubMixin
class CIDNet(nn.Module, PyTorchModelHubMixin):
    """Two-branch encoder/decoder network that works in the HVI space of ``RGB_HVI``.

    ``forward`` does the following:

    1. Converts the input with ``self.trans.HVIT``.
    2. Runs two parallel encoder/decoder branches (3x ``NormDownsample``,
       then 3x ``NormUpsample`` with skip inputs). At every level they are
       coupled by paired ``I_LCA``/``HV_LCA`` calls, each of which receives
       one tensor from each branch.
    3. Adds the branches' outputs residually to the HVI tensor.
    4. Maps the result back with ``self.trans.PHVIT``.

    ``PyTorchModelHubMixin`` adds Hugging Face Hub save/load helpers.

    Branches:
      * HV branch (``HVE_*``, ``HVD_*``, ``HV_LCA*``): consumes all 3 channels
        of the HVI tensor (``HVE_block0`` is a 3-input-channel conv) and
        predicts 2 output channels.
      * I branch (``IE_*``, ``ID_*``, ``I_LCA*``): consumes only channel 2 of
        the HVI tensor and predicts 1 output channel.

    Args:
        channels: Four feature widths ``[ch1, ch2, ch3, ch4]``. ``ch1`` is the
            width after the stem convs (``*_block0``). ``ch2``, ``ch3`` and
            ``ch4`` are the widths after the 1st, 2nd and 3rd
            ``NormDownsample`` of each branch.
        heads: Four ints ``[head1, head2, head3, head4]``. ``head2``, ``head3``
            and ``head4`` are passed to the LCA modules at the ``ch2``, ``ch3``
            and ``ch4`` levels. ``head1`` is unpacked but never used.
        norm: Forwarded as ``use_norm`` to every ``NormDownsample`` and
            ``NormUpsample``.
    """
    def __init__(self,
                 channels=[36, 36, 72, 144],
                 heads=[1, 2, 4, 8],
                 norm=False
        ):
        """Create the HV and I branches, their LCA couplings and the HVI transform."""
        # The list defaults are shared between calls, but they are only
        # unpacked below and never mutated, so the sharing is harmless.
        super(CIDNet, self).__init__()
        [ch1, ch2, ch3, ch4] = channels
        [head1, head2, head3, head4] = heads  # head1 is unused
        # NOTE: the attribute names below are state_dict keys, and creation
        # order determines random-init order and parameters() order. Keep both
        # stable to stay compatible with saved checkpoints.
        # HV_ways: stem conv (3 -> ch1), three downsampling encoders, three
        # upsampling decoders, and an output conv (ch1 -> 2 channels).
        self.HVE_block0 = nn.Sequential(
            # Replication padding of 1 + 3x3 conv with padding=0 keeps H and W.
            nn.ReplicationPad2d(1),
            nn.Conv2d(3, ch1, 3, stride=1, padding=0,bias=False)
            )
        self.HVE_block1 = NormDownsample(ch1, ch2, use_norm = norm)
        self.HVE_block2 = NormDownsample(ch2, ch3, use_norm = norm)
        self.HVE_block3 = NormDownsample(ch3, ch4, use_norm = norm)
        self.HVD_block3 = NormUpsample(ch4, ch3, use_norm = norm)
        self.HVD_block2 = NormUpsample(ch3, ch2, use_norm = norm)
        self.HVD_block1 = NormUpsample(ch2, ch1, use_norm = norm)
        self.HVD_block0 = nn.Sequential(
            nn.ReplicationPad2d(1),
            nn.Conv2d(ch1, 2, 3, stride=1, padding=0,bias=False)
            )
        # I_ways: same layout as the HV branch, but 1 input channel and
        # 1 output channel.
        self.IE_block0 = nn.Sequential(
            nn.ReplicationPad2d(1),
            nn.Conv2d(1, ch1, 3, stride=1, padding=0,bias=False),
            )
        self.IE_block1 = NormDownsample(ch1, ch2, use_norm = norm)
        self.IE_block2 = NormDownsample(ch2, ch3, use_norm = norm)
        self.IE_block3 = NormDownsample(ch3, ch4, use_norm = norm)
        self.ID_block3 = NormUpsample(ch4, ch3, use_norm=norm)
        self.ID_block2 = NormUpsample(ch3, ch2, use_norm=norm)
        self.ID_block1 = NormUpsample(ch2, ch1, use_norm=norm)
        self.ID_block0 = nn.Sequential(
            nn.ReplicationPad2d(1),
            nn.Conv2d(ch1, 1, 3, stride=1, padding=0,bias=False),
            )
        # Cross-branch LCA modules, used in forward() in the order
        # ch2, ch3, ch4 (encoder), then ch4, ch3, ch2 (bottleneck/decoder).
        self.HV_LCA1 = HV_LCA(ch2, head2)
        self.HV_LCA2 = HV_LCA(ch3, head3)
        self.HV_LCA3 = HV_LCA(ch4, head4)
        self.HV_LCA4 = HV_LCA(ch4, head4)
        self.HV_LCA5 = HV_LCA(ch3, head3)
        self.HV_LCA6 = HV_LCA(ch2, head2)
        self.I_LCA1 = I_LCA(ch2, head2)
        self.I_LCA2 = I_LCA(ch3, head3)
        self.I_LCA3 = I_LCA(ch4, head4)
        self.I_LCA4 = I_LCA(ch4, head4)
        self.I_LCA5 = I_LCA(ch3, head3)
        self.I_LCA6 = I_LCA(ch2, head2)
        # Provides HVIT (input -> HVI) and PHVIT (HVI -> output) used in forward().
        self.trans = RGB_HVI()

    def forward(self, x):
        """Run the network on ``x`` and return ``self.trans.PHVIT`` of the result.

        Args:
            x: Input passed to ``self.trans.HVIT``. Its ``dtype`` is reused for
                the intensity channel. It is presumably an (N, 3, H, W) RGB
                batch. H and W presumably need to be divisible by 8 so the
                three downsample/upsample skip shapes line up. TODO: confirm
                against NormDownsample and NormUpsample.

        Returns:
            ``self.trans.PHVIT(output_hvi)``, where ``output_hvi`` is the HVI
            tensor plus the concatenated [HV (2 ch), I (1 ch)] branch outputs.
        """
        dtypes = x.dtype
        hvi = self.trans.HVIT(x)
        # Channel 2 of the HVI tensor feeds the I branch. Only this slice is
        # cast back to the input dtype; hvi itself is used as returned.
        i = hvi[:,2,:,:].unsqueeze(1).to(dtypes)
        # low
        # Stem convs (ch1, full resolution), then the first downsample (ch2).
        i_enc0 = self.IE_block0(i)
        i_enc1 = self.IE_block1(i_enc0)
        hv_0 = self.HVE_block0(hvi)
        hv_1 = self.HVE_block1(hv_0)
        # ch1 skips for the last decoder stage.
        i_jump0 = i_enc0
        hv_jump0 = hv_0
        # Level ch2: the LCA outputs are kept as skips AND fed to the next
        # downsample. ("v_jump*" are I-branch skips despite the "v" prefix.)
        i_enc2 = self.I_LCA1(i_enc1, hv_1)
        hv_2 = self.HV_LCA1(hv_1, i_enc1)
        v_jump1 = i_enc2
        hv_jump1 = hv_2
        i_enc2 = self.IE_block2(i_enc2)
        hv_2 = self.HVE_block2(hv_2)
        # Level ch3: the LCA outputs are kept as skips only (see NOTE below).
        i_enc3 = self.I_LCA2(i_enc2, hv_2)
        hv_3 = self.HV_LCA2(hv_2, i_enc2)
        v_jump2 = i_enc3
        hv_jump2 = hv_3
        # NOTE(review): unlike level ch2, these downsamples consume the
        # pre-LCA tensors (i_enc2, hv_2), so I_LCA2/HV_LCA2 reach the decoder
        # only through v_jump2/hv_jump2. This may be unintended, but any
        # trained weights presumably depend on this exact wiring. Do not
        # change it without retraining.
        i_enc3 = self.IE_block3(i_enc2)
        hv_3 = self.HVE_block3(hv_2)
        # Bottleneck (ch4): two rounds of cross-branch LCA.
        i_enc4 = self.I_LCA3(i_enc3, hv_3)
        hv_4 = self.HV_LCA3(hv_3, i_enc3)
        i_dec4 = self.I_LCA4(i_enc4,hv_4)
        hv_4 = self.HV_LCA4(hv_4, i_enc4)
        # Decoder ch4 -> ch3, with the level-ch3 skips as second input.
        hv_3 = self.HVD_block3(hv_4, hv_jump2)
        i_dec3 = self.ID_block3(i_dec4, v_jump2)
        # NOTE(review): this I_LCA5 result is overwritten by the ID_block2
        # line below before it is ever read. Barring side effects inside
        # I_LCA5, it does not affect the output and its compute is wasted.
        # ID_block2 consumes i_dec3 instead. HV_LCA5's result IS used.
        i_dec2 = self.I_LCA5(i_dec3, hv_3)
        hv_2 = self.HV_LCA5(hv_3, i_dec3)
        # Decoder ch3 -> ch2 with the level-ch2 skips, then LCA at ch2.
        hv_2 = self.HVD_block2(hv_2, hv_jump1)
        i_dec2 = self.ID_block2(i_dec3, v_jump1)
        i_dec1 = self.I_LCA6(i_dec2, hv_2)
        hv_1 = self.HV_LCA6(hv_2, i_dec2)
        # Decoder ch2 -> ch1 with the full-resolution skips, then the output
        # convs (I: 1 channel, HV: 2 channels).
        i_dec1 = self.ID_block1(i_dec1, i_jump0)
        i_dec0 = self.ID_block0(i_dec1)
        hv_1 = self.HVD_block1(hv_1, hv_jump0)
        hv_0 = self.HVD_block0(hv_1)
        # Residual in HVI space: [HV (2 ch), I (1 ch)] is added to the input HVI.
        output_hvi = torch.cat([hv_0, i_dec0], dim=1) + hvi
        output_rgb = self.trans.PHVIT(output_hvi)
        return output_rgb

    def HVIT(self,x):
        """Return ``self.trans.HVIT(x)``, the same transform ``forward`` starts with.

        None of the HVE/IE/HVD/ID/LCA submodules are run.
        """
        hvi = self.trans.HVIT(x)
        return hvi