Spaces:
Runtime error
Runtime error
Update ldm/modules/encoders/modules.py
Browse files
ldm/modules/encoders/modules.py
CHANGED
|
@@ -310,5 +310,42 @@ class FrozenFLANEmbedder(AbstractEncoder):
|
|
| 310 |
z = outputs.last_hidden_state
|
| 311 |
return z
|
| 312 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
def encode(self, text):
|
| 314 |
return self(text)
|
|
|
|
| 310 |
z = outputs.last_hidden_state
|
| 311 |
return z
|
| 312 |
|
| 313 |
+
def encode(self, text):
|
| 314 |
+
return self(text)
|
| 315 |
+
|
| 316 |
+
class FrozenGlobalNormOpenCLIPEmbedder(AbstractEncoder):
|
| 317 |
+
"""
|
| 318 |
+
Uses the OpenCLIP transformer encoder for text
|
| 319 |
+
"""
|
| 320 |
+
def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", freeze=True, delvisual=True):
|
| 321 |
+
super().__init__()
|
| 322 |
+
model, _, preprocess = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
|
| 323 |
+
if delvisual:
|
| 324 |
+
del model.visual
|
| 325 |
+
del preprocess
|
| 326 |
+
else:
|
| 327 |
+
self.preprocess = preprocess
|
| 328 |
+
self.model = model
|
| 329 |
+
|
| 330 |
+
self.device = device
|
| 331 |
+
if freeze:
|
| 332 |
+
self.freeze()
|
| 333 |
+
|
| 334 |
+
def freeze(self):
|
| 335 |
+
self.model = self.model.eval()
|
| 336 |
+
for param in self.parameters():
|
| 337 |
+
param.requires_grad = False
|
| 338 |
+
|
| 339 |
+
def forward(self, text):
|
| 340 |
+
tokens = open_clip.tokenize(text)
|
| 341 |
+
z = self.model.encode_text(tokens.to(self.device))
|
| 342 |
+
z /= z.norm(dim=-1, keepdim=True)
|
| 343 |
+
return z.unsqueeze(1)
|
| 344 |
+
|
| 345 |
+
def forward_img(self, image):
|
| 346 |
+
z = self.model.encode_image(image.to(self.device))
|
| 347 |
+
z /= z.norm(dim=-1, keepdim=True)
|
| 348 |
+
return z.unsqueeze(1)
|
| 349 |
+
|
| 350 |
def encode(self, text):
|
| 351 |
return self(text)
|