EvanTHU commited on
Commit
75b59c5
·
verified ·
1 Parent(s): 51c06a5

Update models/unet.py

Browse files
Files changed (1) hide show
  1. models/unet.py +4 -4
models/unet.py CHANGED
@@ -857,10 +857,10 @@ class MotionCLR(nn.Module):
857
  edit_config=edit_config,
858
  )
859
 
860
- self.embed_text = self.embed_text.to(device)
861
- self.textTransEncoder = self.textTransEncoder.to(device)
862
- self.text_ln = self.text_ln.to(device)
863
- self.unet = self.unet.to(device)
864
 
865
  def encode_text(self, raw_text, device):
866
  with torch.no_grad():
 
857
  edit_config=edit_config,
858
  )
859
 
860
+ self.embed_text = self.embed_text.cuda()
861
+ self.textTransEncoder = self.textTransEncoder.cuda()
862
+ self.text_ln = self.text_ln.cuda()
863
+ self.unet = self.unet.cuda()
864
 
865
  def encode_text(self, raw_text, device):
866
  with torch.no_grad():