File size: 676 Bytes
7f43945 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
from torch import nn
from text_net.encoder import CBDE
from text_net.DGRN import DGRN
class AirNet(nn.Module):
    """Encoder + restorer network conditioned on a text prompt.

    Wraps two sub-networks built from the same ``opt`` object:

    * ``self.R`` -- a ``DGRN`` restorer, called with the query image, the
      encoder's intermediate output and the text prompt.
    * ``self.E`` -- a ``CBDE`` encoder, called with a (query, key) pair.

    The return type of :meth:`forward` depends on ``self.training``.
    """

    def __init__(self, opt):
        """Build the restorer and encoder from the options object ``opt``.

        ``opt`` is forwarded unchanged to both ``DGRN`` and ``CBDE``; its
        required fields are defined by those classes, not here.
        """
        super().__init__()
        # NOTE: the restorer is created before the encoder on purpose. Keeping
        # this order keeps parameter registration / state_dict key order and
        # the RNG draws used for weight init the same as before.
        self.R = DGRN(opt)  # restorer
        self.E = CBDE(opt)  # encoder

    def forward(self, x_query, x_key, text_prompt):
        """Restore ``x_query`` using encoder features and ``text_prompt``.

        Args:
            x_query: input image batch to restore (shape/dtype are whatever
                ``CBDE`` / ``DGRN`` expect -- not fixed here).
            x_key: second input given to the encoder in training mode only;
                ignored in eval mode.
            text_prompt: conditioning passed straight through to the restorer.

        Returns:
            Eval mode: the restored output of ``self.R``.
            Training mode: a tuple ``(restored, logits, labels)``. ``logits``
            and ``labels`` are returned as-is from the encoder, presumably for
            an auxiliary (e.g. contrastive) loss -- confirm against the
            training script.
        """
        if not self.training:
            # Inference: no separate key is used; the query is fed to the
            # encoder in both slots, which then yields (features, inter).
            _, inter = self.E(x_query, x_query)
            return self.R(x_query, inter, text_prompt)

        # Training: the encoder also returns logits/labels alongside its
        # features and intermediate output; features themselves are unused.
        _, logits, labels, inter = self.E(x_query, x_key)
        restored = self.R(x_query, inter, text_prompt)
        return restored, logits, labels
|