# Source listing scraped from a Hugging Face Space page (status: "Running on Zero").
# File size: 4,406 bytes; commit hashes shown alongside the listing: 3e648fb, 2d65752, 0f856ef.
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from net.HVI_transform import RGB_HVI
from net.transformer_utils import *
from net.LCA import *
from huggingface_hub import PyTorchModelHubMixin
class CIDNet(nn.Module, PyTorchModelHubMixin):
    """Dual-branch U-Net that operates on the tensor produced by ``RGB_HVI.HVIT``.

    ``forward`` converts its input with ``self.trans.HVIT``. It then runs two
    encoder/decoder branches:

    * the "HV" branch, fed the full 3-channel HVI tensor and decoded to
      2 channels;
    * the "I" branch, fed channel 2 of the HVI tensor and decoded to 1 channel.

    At every level the branches exchange features through paired
    ``HV_LCA`` / ``I_LCA`` blocks. The concatenated 3-channel decoder output is
    added back onto the HVI tensor as a residual and converted with
    ``self.trans.PHVIT``.

    Args:
        channels: Feature widths ``(ch1, ch2, ch3, ch4)`` for the four levels.
            Any 4-element sequence (list or tuple) is accepted.
        heads: Head counts ``(head1, head2, head3, head4)`` passed to the LCA
            blocks. ``head1`` is unpacked but not used by any block.
        norm: Forwarded as ``use_norm`` to every ``NormDownsample`` /
            ``NormUpsample`` block.

    Raises:
        ValueError: if ``channels`` or ``heads`` does not have exactly four
            entries (raised by the tuple unpacking in ``__init__``).
    """

    def __init__(self,
                 channels=(36, 36, 72, 144),
                 heads=(1, 2, 4, 8),
                 norm=False
                 ):
        super().__init__()
        # Tuples (not lists) as defaults avoid shared mutable default objects;
        # the values are only unpacked, so lists passed by callers still work.
        [ch1, ch2, ch3, ch4] = channels
        [_, head2, head3, head4] = heads

        # HV branch: 3-channel HVI input -> ... -> 2-channel output.
        self.HVE_block0 = nn.Sequential(
            nn.ReplicationPad2d(1),
            nn.Conv2d(3, ch1, 3, stride=1, padding=0, bias=False)
        )
        self.HVE_block1 = NormDownsample(ch1, ch2, use_norm=norm)
        self.HVE_block2 = NormDownsample(ch2, ch3, use_norm=norm)
        self.HVE_block3 = NormDownsample(ch3, ch4, use_norm=norm)

        self.HVD_block3 = NormUpsample(ch4, ch3, use_norm=norm)
        self.HVD_block2 = NormUpsample(ch3, ch2, use_norm=norm)
        self.HVD_block1 = NormUpsample(ch2, ch1, use_norm=norm)
        self.HVD_block0 = nn.Sequential(
            nn.ReplicationPad2d(1),
            nn.Conv2d(ch1, 2, 3, stride=1, padding=0, bias=False)
        )

        # I branch: 1-channel input (HVI channel 2) -> ... -> 1-channel output.
        self.IE_block0 = nn.Sequential(
            nn.ReplicationPad2d(1),
            nn.Conv2d(1, ch1, 3, stride=1, padding=0, bias=False),
        )
        self.IE_block1 = NormDownsample(ch1, ch2, use_norm=norm)
        self.IE_block2 = NormDownsample(ch2, ch3, use_norm=norm)
        self.IE_block3 = NormDownsample(ch3, ch4, use_norm=norm)

        self.ID_block3 = NormUpsample(ch4, ch3, use_norm=norm)
        self.ID_block2 = NormUpsample(ch3, ch2, use_norm=norm)
        self.ID_block1 = NormUpsample(ch2, ch1, use_norm=norm)
        self.ID_block0 = nn.Sequential(
            nn.ReplicationPad2d(1),
            nn.Conv2d(ch1, 1, 3, stride=1, padding=0, bias=False),
        )

        # Cross-branch blocks, numbered in the order forward() visits them.
        self.HV_LCA1 = HV_LCA(ch2, head2)
        self.HV_LCA2 = HV_LCA(ch3, head3)
        self.HV_LCA3 = HV_LCA(ch4, head4)
        self.HV_LCA4 = HV_LCA(ch4, head4)
        self.HV_LCA5 = HV_LCA(ch3, head3)
        self.HV_LCA6 = HV_LCA(ch2, head2)

        self.I_LCA1 = I_LCA(ch2, head2)
        self.I_LCA2 = I_LCA(ch3, head3)
        self.I_LCA3 = I_LCA(ch4, head4)
        self.I_LCA4 = I_LCA(ch4, head4)
        # Registered so existing state_dicts load strictly, but NOT called in
        # forward() -- see the note there.
        self.I_LCA5 = I_LCA(ch3, head3)
        self.I_LCA6 = I_LCA(ch2, head2)

        self.trans = RGB_HVI()

    def forward(self, x):
        """Run both branches and return ``self.trans.PHVIT`` of the result.

        Args:
            x: Tensor accepted by ``RGB_HVI.HVIT``. Presumably an RGB batch of
                shape (B, 3, H, W) -- TODO confirm. With three
                ``NormDownsample`` stages, H and W likely need to be divisible
                by 8 -- verify.

        Returns:
            ``self.trans.PHVIT(torch.cat([hv_0, i_dec0], dim=1) + hvi)``, i.e.
            the 2-channel HV output and 1-channel I output added as a residual
            onto the HVI tensor, then converted back.
        """
        dtype = x.dtype
        hvi = self.trans.HVIT(x)
        # I-branch input: HVI channel 2 as a 4-D (N, 1, ...) tensor, cast back
        # to the input dtype.
        i = hvi[:, 2, :, :].unsqueeze(1).to(dtype)

        # Encoder stem + first downsample; keep level-0 features as skips.
        i_enc0 = self.IE_block0(i)
        i_enc1 = self.IE_block1(i_enc0)
        hv_0 = self.HVE_block0(hvi)
        hv_1 = self.HVE_block1(hv_0)
        i_jump0 = i_enc0
        hv_jump0 = hv_0

        # Level 1: cross-branch blocks; their outputs are both the skips and
        # the input to the next downsample.
        i_enc2 = self.I_LCA1(i_enc1, hv_1)
        hv_2 = self.HV_LCA1(hv_1, i_enc1)
        v_jump1 = i_enc2
        hv_jump1 = hv_2
        i_enc2 = self.IE_block2(i_enc2)
        hv_2 = self.HVE_block2(hv_2)

        # Level 2: the LCA outputs are used ONLY as skips. Unlike level 1, the
        # downsample consumes the pre-LCA tensors (i_enc2 / hv_2). Kept as-is
        # because changing it would alter the model's outputs.
        i_enc3 = self.I_LCA2(i_enc2, hv_2)
        hv_3 = self.HV_LCA2(hv_2, i_enc2)
        v_jump2 = i_enc3
        hv_jump2 = hv_3
        i_enc3 = self.IE_block3(i_enc2)
        hv_3 = self.HVE_block3(hv_2)

        # Bottleneck: two rounds of cross-branch blocks.
        i_enc4 = self.I_LCA3(i_enc3, hv_3)
        hv_4 = self.HV_LCA3(hv_3, i_enc3)
        i_dec4 = self.I_LCA4(i_enc4, hv_4)
        hv_4 = self.HV_LCA4(hv_4, i_enc4)

        # Decoder, first upsample.
        hv_3 = self.HVD_block3(hv_4, hv_jump2)
        i_dec3 = self.ID_block3(i_dec4, v_jump2)
        # I_LCA5 is deliberately not invoked. ID_block2 below consumes i_dec3,
        # not an I_LCA5 output, and HV_LCA5 does not take one either, so an
        # I_LCA5(i_dec3, hv_3) call here would produce a value nothing reads.
        # It would only cost compute and autograd memory.
        # NOTE(review): feeding I_LCA5's output into ID_block2 would change
        # results for any checkpoint trained with this wiring -- confirm first.
        hv_2 = self.HV_LCA5(hv_3, i_dec3)

        # Decoder, second upsample + cross-branch blocks.
        hv_2 = self.HVD_block2(hv_2, hv_jump1)
        i_dec2 = self.ID_block2(i_dec3, v_jump1)
        i_dec1 = self.I_LCA6(i_dec2, hv_2)
        hv_1 = self.HV_LCA6(hv_2, i_dec2)

        # Final upsample and per-branch output heads.
        i_dec1 = self.ID_block1(i_dec1, i_jump0)
        i_dec0 = self.ID_block0(i_dec1)
        hv_1 = self.HVD_block1(hv_1, hv_jump0)
        hv_0 = self.HVD_block0(hv_1)

        # 2 HV channels + 1 I channel, residual onto the input HVI tensor.
        output_hvi = torch.cat([hv_0, i_dec0], dim=1) + hvi
        return self.trans.PHVIT(output_hvi)

    def HVIT(self, x):
        """Return ``self.trans.HVIT(x)``, the tensor ``forward`` starts from."""
        return self.trans.HVIT(x)
|