|
from torch import nn |
|
|
|
from text_net.encoder import CBDE |
|
from text_net.DGRN import DGRN |
|
|
|
|
|
class AirNet(nn.Module): |
|
def __init__(self, opt): |
|
super(AirNet, self).__init__() |
|
|
|
|
|
self.R = DGRN(opt) |
|
|
|
|
|
self.E = CBDE(opt) |
|
|
|
def forward(self, x_query, x_key, text_prompt): |
|
if self.training: |
|
fea, logits, labels, inter = self.E(x_query, x_key) |
|
|
|
restored = self.R(x_query, inter, text_prompt) |
|
|
|
return restored, logits, labels |
|
else: |
|
fea, inter = self.E(x_query, x_query) |
|
|
|
restored = self.R(x_query, inter, text_prompt) |
|
|
|
return restored |
|
|