# ICDR / text_net / model.py
# Provenance: uploaded by Siwon123, commit 7f43945 ("q"), 676 bytes.
from torch import nn
from text_net.encoder import CBDE
from text_net.DGRN import DGRN
class AirNet(nn.Module):
    """Encoder/restorer pair whose restorer is additionally conditioned on a text prompt.

    Sub-modules (both built from the same ``opt`` options object):
        R: a ``DGRN`` instance that produces the restored output.
        E: a ``CBDE`` instance whose intermediate output is fed to ``R``.
    """

    def __init__(self, opt):
        super().__init__()
        # Registration order (R before E) is kept as-is: it fixes the order of
        # parameters/state_dict entries and of any random initialisation.
        self.R = DGRN(opt)  # restorer
        self.E = CBDE(opt)  # encoder

    def forward(self, x_query, x_key, text_prompt):
        """Run the encoder, then the text-conditioned restorer.

        Args:
            x_query: input passed to both the encoder and the restorer.
            x_key: second encoder input; only used in training mode.
            text_prompt: extra conditioning forwarded unchanged to ``R``.

        Returns:
            Training mode: ``(restored, logits, labels)``, where ``logits`` and
            ``labels`` come straight from the encoder (presumably for an
            auxiliary training loss -- TODO confirm against CBDE).
            Eval mode: ``restored`` only.
        """
        if not self.training:
            # At inference the query doubles as the key, so x_key is ignored.
            # The encoder is expected to return exactly two values in this mode.
            _enc_feat, enc_inter = self.E(x_query, x_query)
            return self.R(x_query, enc_inter, text_prompt)

        # In training mode the encoder is expected to return exactly four values;
        # the first one is not used here.
        _enc_feat, enc_logits, enc_labels, enc_inter = self.E(x_query, x_key)
        out = self.R(x_query, enc_inter, text_prompt)
        return out, enc_logits, enc_labels