cocktailpeanut committed
Commit a7bed20 · 1 Parent(s): 8611f7d
stable_diffusion/ldm/modules/encoders/modules.py CHANGED
@@ -5,9 +5,11 @@ import clip
 from einops import rearrange, repeat
 from transformers import CLIPTokenizer, CLIPTextModel, CLIPVisionModel, CLIPModel
 import kornia
+import devicetorch
 
 from ldm.modules.x_transformer import Encoder, TransformerWrapper  # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
 
+DEVICE = devicetorch.get(torch)
 
 class AbstractEncoder(nn.Module):
     def __init__(self):
@@ -35,7 +37,7 @@ class ClassEmbedder(nn.Module):
 
 class TransformerEmbedder(AbstractEncoder):
     """Some transformer encoder layers"""
-    def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
+    def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device=DEVICE):
         super().__init__()
         self.device = device
         self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
@@ -52,7 +54,7 @@ class TransformerEmbedder(AbstractEncoder):
 
 class BERTTokenizer(AbstractEncoder):
     """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
-    def __init__(self, device="cuda", vq_interface=True, max_length=77):
+    def __init__(self, device=DEVICE, vq_interface=True, max_length=77):
         super().__init__()
         from transformers import BertTokenizerFast  # TODO: add to reuquirements
         self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
@@ -80,7 +82,7 @@ class BERTTokenizer(AbstractEncoder):
 class BERTEmbedder(AbstractEncoder):
     """Uses the BERT tokenizr model and add some transformer encoder layers"""
     def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
-                 device="cuda",use_tokenizer=True, embedding_dropout=0.0):
+                 device=DEVICE,use_tokenizer=True, embedding_dropout=0.0):
         super().__init__()
         self.use_tknz_fn = use_tokenizer
         if self.use_tknz_fn:
@@ -136,7 +138,7 @@ class SpatialRescaler(nn.Module):
 
 class FrozenCLIPEmbedder(AbstractEncoder):
     """Uses the CLIP transformer encoder for text (from Hugging Face)"""
-    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
+    def __init__(self, version="openai/clip-vit-large-patch14", device=DEVICE, max_length=77):
         super().__init__()
         self.tokenizer = CLIPTokenizer.from_pretrained(version)
         self.transformer = CLIPTextModel.from_pretrained(version)
@@ -163,7 +165,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
 
 class FrozenCLIPEmbedderBoth(AbstractEncoder):
     """Uses the CLIP transformer encoder for text (from Hugging Face)"""
-    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, antialias=False,):
+    def __init__(self, version="openai/clip-vit-large-patch14", device=DEVICE, max_length=77, antialias=False,):
         super().__init__()
         self.tokenizer = CLIPTokenizer.from_pretrained(version)
         self.text_transformer = CLIPTextModel.from_pretrained(version)
@@ -217,7 +219,7 @@ class FrozenCLIPEmbedderBoth(AbstractEncoder):
 
 class CLIPEmbedderWithLearnableTokens(AbstractEncoder):
     """Uses the CLIP transformer encoder for text (from Hugging Face)"""
-    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, num_learnable_tokens=3):
+    def __init__(self, version="openai/clip-vit-large-patch14", device=DEVICE, max_length=77, num_learnable_tokens=3):
         super().__init__()
         self.tokenizer = CLIPTokenizer.from_pretrained(version)
         self.transformer = CLIPTextModel.from_pretrained(version)
@@ -253,7 +255,7 @@ class FrozenCLIPTextEmbedder(nn.Module):
     """
     Uses the CLIP transformer encoder for text.
    """
-    def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
+    def __init__(self, version='ViT-L/14', device=DEVICE, max_length=77, n_repeat=1, normalize=True):
         super().__init__()
         self.model, _ = clip.load(version, jit=False, device="cpu")
         self.device = device
@@ -289,7 +291,7 @@ class FrozenClipImageEmbedder(nn.Module):
             self,
            model,
            jit=False,
-            device='cuda' if torch.cuda.is_available() else 'cpu',
+            device=DEVICE,
            antialias=False,
        ):
         super().__init__()
@@ -319,4 +321,4 @@ if __name__ == "__main__":
     from ldm.util import count_params
     model = FrozenCLIPEmbedderBoth()
     breakpoint()
-    count_params(model, verbose=True)
+    count_params(model, verbose=True)
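The change is uniform across the file: every hard-coded `device="cuda"` default (and, in FrozenClipImageEmbedder, the inline `'cuda' if torch.cuda.is_available() else 'cpu'` expression) now points at a single module-level DEVICE resolved once by `devicetorch.get(torch)`, so the encoders no longer assume an NVIDIA GPU. The diff does not show what devicetorch.get actually returns; the sketch below is a rough, hypothetical equivalent (the helper name `resolve_device` is ours), assuming the lookup prefers CUDA, then Apple MPS, then CPU.

import torch

def resolve_device(torch_module=torch):
    # Hypothetical stand-in for devicetorch.get(torch): pick the best
    # available backend. The real helper's behaviour may differ in detail.
    if torch_module.cuda.is_available():
        return "cuda"
    mps = getattr(torch_module.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"

DEVICE = resolve_device()
print(DEVICE)  # "cuda", "mps", or "cpu" depending on the machine

With a default like this in place, constructing an encoder such as FrozenCLIPEmbedder() or BERTTokenizer() without an explicit device argument lands on whatever accelerator the machine actually has, instead of failing on hosts without CUDA.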