Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import torch.nn as nn | |
from einops import rearrange | |
from net.transformer_utils import * | |
# Cross Attention Block
class CAB(nn.Module):
    """Cross-attention block: queries are taken from ``x``, keys/values from ``y``.

    Each head attends across its own channels: after the head split, the
    attention map is ``c x c`` per head (``c`` = channels per head), and the
    flattened spatial positions act as the feature axis of the dot product.

    Args:
        dim: Channel count of both inputs and of the output.
        num_heads: Number of heads the channel axis is split into.
        bias: Whether every convolution in the block carries a bias term.
    """

    def __init__(self, dim, num_heads, bias):
        super().__init__()
        self.num_heads = num_heads
        # One learnable scale per head, multiplied into the logits before softmax.
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        # Submodules are created in the original order so parameter ordering
        # and seeded initialisation are unchanged.
        kv_dim = dim * 2
        # Query path: pointwise conv, then 3x3 depthwise conv.
        self.q = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        self.q_dwconv = nn.Conv2d(
            dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias
        )
        # Key/value path: both are produced together and split in forward().
        self.kv = nn.Conv2d(dim, kv_dim, kernel_size=1, bias=bias)
        self.kv_dwconv = nn.Conv2d(
            kv_dim, kv_dim, kernel_size=3, stride=1, padding=1, groups=kv_dim, bias=bias
        )
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x, y):
        """Attend from ``x`` to ``y`` and return a tensor shaped like ``x``.

        Args:
            x: 4-D tensor ``(b, dim, h, w)`` that supplies the queries.
            y: Tensor that supplies keys and values. It passes through a
                ``Conv2d`` with ``dim`` input channels and must have as many
                spatial positions as ``x`` for the matrix products to line up.

        Returns:
            Tensor of shape ``(b, dim, h, w)``.
        """
        _, _, height, width = x.shape
        heads = self.num_heads
        split_heads = 'b (head c) h w -> b head c (h w)'

        query = self.q_dwconv(self.q(x))
        key_value = self.kv_dwconv(self.kv(y))
        key, value = key_value.chunk(2, dim=1)

        query = rearrange(query, split_heads, head=heads)
        key = rearrange(key, split_heads, head=heads)
        value = rearrange(value, split_heads, head=heads)

        # L2-normalise along the flattened spatial axis so the logits are
        # cosine similarities between channels, scaled by the temperature.
        query = torch.nn.functional.normalize(query, dim=-1)
        key = torch.nn.functional.normalize(key, dim=-1)

        logits = (query @ key.transpose(-2, -1)) * self.temperature
        weights = logits.softmax(dim=-1)
        attended = weights @ value

        attended = rearrange(
            attended,
            'b head c (h w) -> b (head c) h w',
            head=heads,
            h=height,
            w=width,
        )
        return self.project_out(attended)
# Intensity Enhancement Layer
class IEL(nn.Module):
    """Gated convolutional feed-forward layer.

    The input is expanded to two branches of ``hidden_features`` channels
    each. Every branch is refined by a depthwise conv followed by ``tanh``
    with a skip connection, the branches are multiplied element-wise, and the
    product is projected back to ``dim`` channels.

    Args:
        dim: Number of input and output channels.
        ffn_expansion_factor: ``hidden_features = int(dim * ffn_expansion_factor)``.
        bias: Whether the convolutions carry a bias term.
    """

    def __init__(self, dim, ffn_expansion_factor=2.66, bias=False):
        super().__init__()
        hidden = int(dim * ffn_expansion_factor)
        doubled = hidden * 2

        def depthwise3x3(channels):
            # One 3x3 filter per channel; padding keeps the spatial size.
            return nn.Conv2d(
                channels,
                channels,
                kernel_size=3,
                stride=1,
                padding=1,
                groups=channels,
                bias=bias,
            )

        # Submodules are created in the original order so parameter ordering
        # and seeded initialisation are unchanged.
        self.project_in = nn.Conv2d(dim, doubled, kernel_size=1, bias=bias)
        self.dwconv = depthwise3x3(doubled)
        self.dwconv1 = depthwise3x3(hidden)
        self.dwconv2 = depthwise3x3(hidden)
        self.project_out = nn.Conv2d(hidden, dim, kernel_size=1, bias=bias)
        self.Tanh = nn.Tanh()

    def forward(self, x):
        """Apply the gated feed-forward transform; channel count stays ``dim``."""
        expanded = self.dwconv(self.project_in(x))
        first, second = expanded.chunk(2, dim=1)

        # Each half gets its own depthwise conv + tanh, added back onto itself.
        first = self.Tanh(self.dwconv1(first)) + first
        second = self.Tanh(self.dwconv2(second)) + second

        gated = first * second
        return self.project_out(gated)
# Lightweight Cross Attention
class HV_LCA(nn.Module):
    """Lightweight cross-attention unit: residual cross-attention, then an IEL.

    Despite the attribute names, ``ffn`` holds the cross-attention block
    (``CAB``) and ``gdfn`` holds the feed-forward layer (``IEL``). A single
    ``LayerNorm`` instance is shared by every normalisation in ``forward``.

    Args:
        dim: Channel count of the inputs and the output.
        num_heads: Number of attention heads passed to ``CAB``.
        bias: Whether the ``CAB`` convolutions carry a bias term.
    """

    def __init__(self, dim, num_heads, bias=False):
        super().__init__()
        # Submodules are created in the original order so parameter ordering
        # and seeded initialisation are unchanged.
        self.gdfn = IEL(dim)  # IEL and CDL have same structure
        self.norm = LayerNorm(dim)
        self.ffn = CAB(dim, num_heads, bias)

    def forward(self, x, y):
        """Attend from ``x`` to ``y`` and pass the result through the IEL.

        Args:
            x: Tensor that supplies the queries and the residual path.
            y: Tensor that supplies the keys and values.

        Returns:
            The IEL output of the normalised, attended features.
        """
        query_src = self.norm(x)
        context_src = self.norm(y)
        attended = x + self.ffn(query_src, context_src)
        # NOTE: unlike I_LCA, the IEL output is returned directly, with no
        # residual add around this second stage.
        return self.gdfn(self.norm(attended))
class I_LCA(nn.Module):
    """Lightweight cross-attention unit with residuals around both stages.

    Despite the attribute names, ``ffn`` holds the cross-attention block
    (``CAB``) and ``gdfn`` holds the feed-forward layer (``IEL``). A single
    ``LayerNorm`` instance is shared by every normalisation in ``forward``.

    Args:
        dim: Channel count of the inputs and the output.
        num_heads: Number of attention heads passed to ``CAB``.
        bias: Whether the ``CAB`` convolutions carry a bias term.
    """

    def __init__(self, dim, num_heads, bias=False):
        super().__init__()
        # Submodules are created in the original order so parameter ordering
        # and seeded initialisation are unchanged.
        self.norm = LayerNorm(dim)
        self.gdfn = IEL(dim)
        self.ffn = CAB(dim, num_heads, bias=bias)

    def forward(self, x, y):
        """Attend from ``x`` to ``y``, then refine with the IEL.

        Args:
            x: Tensor that supplies the queries and the residual path.
            y: Tensor that supplies the keys and values.

        Returns:
            ``x`` after the residual cross-attention and residual IEL stages.
        """
        query_src = self.norm(x)
        context_src = self.norm(y)
        attended = x + self.ffn(query_src, context_src)
        refined = self.gdfn(self.norm(attended))
        return attended + refined