import torch
import torch.nn as nn

from ldm.modules.x_transformer import Encoder, TransformerWrapper


class BERTTokenizer(nn.Module):
    """Wraps the pretrained Hugging Face BERT tokenizer and returns padded token ids."""

    def __init__(self, vq_interface=True, max_length=77):
        super().__init__()
        from transformers import BertTokenizerFast  # local import keeps transformers an optional dependency
        self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
        self.vq_interface = vq_interface
        self.max_length = max_length

    def forward(self, text, return_batch_encoding=False):
        # Tokenize, truncate/pad to max_length, and return a LongTensor of token ids.
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length,
                                        return_length=True, return_overflowing_tokens=False,
                                        padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"]
        if return_batch_encoding:
            return tokens, batch_encoding
        return tokens

    def encode(self, text):
        tokens = self(text)
        if not self.vq_interface:
            return tokens
        # Mimic the (quant, diff, info) return signature of a VQ model so the
        # tokenizer can stand in for one; only info[2] (the token ids) is used.
        return None, None, [None, None, tokens]

    def decode(self, text):
        # Identity: token ids are passed through unchanged.
        return text
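

# Usage sketch (illustrative, not part of the original module; assumes the
# `transformers` package is installed and the "bert-base-uncased" vocab can
# be downloaded):
#
#   tokenizer = BERTTokenizer(vq_interface=False, max_length=77)
#   tokens = tokenizer(["a photograph of an astronaut riding a horse"])
#   tokens.shape  # -> torch.Size([1, 77]); padded/truncated LongTensor of ids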


class BERTEmbedder(nn.Module):
    """Uses the BERT tokenizer and adds some transformer encoder layers."""

    def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
                 use_tokenizer=True, embedding_dropout=0.0):
        super().__init__()
        self.use_tknz_fn = use_tokenizer
        if self.use_tknz_fn:
            self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
        self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
                                              attn_layers=Encoder(dim=n_embed, depth=n_layer),
                                              emb_dropout=embedding_dropout)

    def forward(self, cond, text):
        assert cond is None  # not supported for now (LDM conditioning key == "concat")
        if self.use_tknz_fn:
            tokens = self.tknz_fn(text)
            # Move the token ids to whatever device holds the transformer weights
            # (generalizes the original CUDA-only check).
            tokens = tokens.to(next(self.transformer.parameters()).device)
        else:
            tokens = text  # caller passes pre-tokenized ids
        z = self.transformer(tokens, return_embeddings=True)  # [batch_size, max_seq_len, n_embed]
        return z
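

# Minimal smoke test (a sketch: the dims below are illustrative, not the LDM
# defaults; assumes ldm.modules.x_transformer is importable and the BERT vocab
# can be fetched).
if __name__ == "__main__":
    embedder = BERTEmbedder(n_embed=512, n_layer=4, max_seq_len=77)
    z = embedder(None, ["a painting of a fox in the snow"])
    print(z.shape)  # expected: torch.Size([1, 77, 512])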