import torch
import torch.nn as nn
from ldm.modules.x_transformer import Encoder, TransformerWrapper


class BERTTokenizer(nn.Module):
    """Wraps the pretrained HuggingFace "bert-base-uncased" tokenizer as an nn.Module."""

    def __init__(self, vq_interface=True, max_length=77):
        super().__init__()
        from transformers import BertTokenizerFast
        self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
        self.vq_interface = vq_interface
        self.max_length = max_length

    def forward(self, text, return_batch_encoding=False):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"]
        if return_batch_encoding:
            return tokens, batch_encoding
        return tokens

    @torch.no_grad()
    def encode(self, text):
        tokens = self(text)
        if not self.vq_interface:
            return tokens
        # Mimic the (quant, loss, info) return signature of a VQ-style encoder,
        # carrying the token ids in the last slot of the info list.
        return None, None, [None, None, tokens]

    def decode(self, text):
        return text


class BERTEmbedder(nn.Module):
    """Uses the BERT tokenizer model and adds some transformer encoder layers."""

    def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, use_tokenizer=True, embedding_dropout=0.0):
        super().__init__()
        self.use_tknz_fn = use_tokenizer
        if self.use_tknz_fn:
            self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
        self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
                                              attn_layers=Encoder(dim=n_embed, depth=n_layer),
                                              emb_dropout=embedding_dropout)

    def forward(self, cond, text):
        assert cond is None  # Not supported for now (LDM conditioning key == "concat")
        if self.use_tknz_fn:
            # Tokenize raw strings, then move the token ids to the same device as the transformer.
            tokens = self.tknz_fn(text)
            if next(self.transformer.parameters()).is_cuda:
                tokens = tokens.cuda()
        else:
            # Otherwise the input is expected to already be a tensor of token ids.
            tokens = text
        z = self.transformer(tokens, return_embeddings=True)  # Size: [batch_size, max_seq_len, n_embed]
        return z
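

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module), assuming the
    # "transformers" package is installed and "bert-base-uncased" can be downloaded.
    # n_embed and n_layer below are small illustrative values, not the ones used in
    # this repo's training configs.
    embedder = BERTEmbedder(n_embed=512, n_layer=4)
    with torch.no_grad():
        z = embedder(None, ["a photo of a cat sitting on a sofa"])
    print(z.shape)  # torch.Size([1, 77, 512])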