Update models/unet.py
Browse files- models/unet.py +5 -6
models/unet.py
CHANGED
@@ -856,14 +856,13 @@ class MotionCLR(nn.Module):
|
|
856 |
log_attn=self.vis_attn,
|
857 |
edit_config=edit_config,
|
858 |
)
|
|
|
|
|
|
|
|
|
|
|
859 |
|
860 |
def encode_text(self, raw_text, device):
|
861 |
-
print("00000000")
|
862 |
-
print(device)
|
863 |
-
print(next(self.clip_model.parameters()).device)
|
864 |
-
print("00000000")
|
865 |
-
self.clip_model = self.clip_model.to(device)
|
866 |
-
print("00000000")
|
867 |
with torch.no_grad():
|
868 |
texts = clip.tokenize(raw_text, truncate=True).to(
|
869 |
device
|
|
|
856 |
log_attn=self.vis_attn,
|
857 |
edit_config=edit_config,
|
858 |
)
|
859 |
+
|
860 |
+
self.embed_text = self.embed_text.to(device)
|
861 |
+
self.textTransEncoder = self.textTransEncoder.to(device)
|
862 |
+
self.text_ln = self.text_ln.to(device)
|
863 |
+
self.unet = self.unet.to(device)
|
864 |
|
865 |
def encode_text(self, raw_text, device):
|
|
|
|
|
|
|
|
|
|
|
|
|
866 |
with torch.no_grad():
|
867 |
texts = clip.tokenize(raw_text, truncate=True).to(
|
868 |
device
|