import copy
import random
from typing import Callable, Dict, List, Tuple, Union

import torch
from transformers import CLIPTokenizer, PreTrainedModel, PreTrainedTokenizer

from compel.embeddings_provider import BaseTextualInversionManager


class TextualInversionLoaderMixin(BaseTextualInversionManager):
    r"""
    Mixin class for adding textual inversion tokens and embeddings to the tokenizer and text encoder with the
    methods:

    - [`~TextualInversionLoaderMixin.load_textual_inversion_embeddings`]
    - [`~TextualInversionLoaderMixin.add_textual_inversion_embedding`]
    """

    def load_textual_inversion_embeddings(
        self,
        embedding_path_dict_or_list: Union[Dict[str, str], List[str]],
        allow_replacement: bool = False,
    ):
r""" |
|
Loads textual inversion embeddings and adds them to the tokenizer's vocabulary and the text encoder's embeddings. |
|
Arguments: |
|
embeddings_path_dict_or_list (`Dict[str, str]` or `List[str]`): |
|
Dictionary of token to embedding path or List of embedding paths to embedding dictionaries. |
|
The dictionary must have the following keys: |
|
- `token`: name of the token to be added to the tokenizers' vocabulary |
|
- `embedding`: path to the embedding of the token to be added to the text encoder's embedding matrix |
|
The list must contain paths to embedding dictionaries where the keys are the tokens and the |
|
values are the embeddings (same as above dictionary definition). |
|
allow_replacement (`bool`, *optional*, defaults to `False`): |
|
Whether to allow replacement of existing tokens in the tokenizer's vocabulary. If `False` |
|
and a token is already in the vocabulary, an error will be raised. |
|
Returns: |
|
None |
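
        Example:

        A minimal usage sketch; the pipeline class, checkpoint name, token name, and file paths are
        illustrative placeholders for any class that mixes this in and exposes `tokenizer` and `text_encoder`.

        ```py
        pipe = MyPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

        # Dictionary form: token name -> path to its embedding file.
        pipe.load_textual_inversion_embeddings({"<my-token>": "./learned_embeds.bin"})

        # List form: the token name is read from each embedding file.
        pipe.load_textual_inversion_embeddings(["./embeds_a.bin", "./embeds_b.bin"])
        ```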
        """

        self._validate_method_call(self.load_textual_inversion_embeddings)

        if isinstance(embedding_path_dict_or_list, dict):
            for token, embedding_path in embedding_path_dict_or_list.items():
                embedding_dict = torch.load(embedding_path, map_location=self.text_encoder.device)
                embedding, is_multi_vec_token = self._extract_embedding_from_dict(embedding_dict)

                self._validate_token_update(token, allow_replacement, is_multi_vec_token)
                self.add_textual_inversion_embedding(token, embedding)
        elif isinstance(embedding_path_dict_or_list, list):
            for embedding_path in embedding_path_dict_or_list:
                embedding_dict = torch.load(embedding_path, map_location=self.text_encoder.device)
                token = self._extract_token_from_dict(embedding_dict)
                embedding, is_multi_vec_token = self._extract_embedding_from_dict(embedding_dict)

                self._validate_token_update(token, allow_replacement, is_multi_vec_token)
                self.add_textual_inversion_embedding(token, embedding)
        else:
            raise ValueError(
                f"Type {type(embedding_path_dict_or_list)} is invalid. The value passed to `embedding_path_dict_or_list` "
                "must be a dictionary that maps a token to its embedding file path "
                "or a list of paths to embedding files containing embedding dictionaries."
            )

    def add_textual_inversion_embedding(self, token: str, embedding: torch.Tensor):
        r"""
        Adds a token to the tokenizer's vocabulary and its embedding to the text encoder's embedding matrix.

        Arguments:
            token (`str`):
                The token to be added to the tokenizer's vocabulary
            embedding (`torch.Tensor`):
                The embedding of the token to be added to the text encoder's embedding matrix

        Returns:
            None
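
        Example:

        A short sketch, assuming the mixin is attached to a pipeline-like object `pipe`; the token name and the
        random tensor are placeholders (a real embedding would come from textual inversion training).

        ```py
        import torch

        embedding = torch.randn(768)  # 768 matches CLIP ViT-L/14; use your text encoder's hidden size
        pipe.add_textual_inversion_embedding("<my-token>", embedding)
        ```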
        """

        self._validate_method_call(self.add_textual_inversion_embedding)

        embedding = embedding.to(self.text_encoder.dtype)

        if not isinstance(self.tokenizer, MultiTokenCLIPTokenizer):
            if token in self.tokenizer.get_vocab():
                # The token already exists, so only overwrite its embedding.
                token_id = self.tokenizer.convert_tokens_to_ids(token)
                self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
            else:
                # Register the new token, grow the embedding matrix, and write the embedding.
                self.tokenizer.add_tokens([token])

                token_id = self.tokenizer.convert_tokens_to_ids(token)
                self.text_encoder.resize_token_embeddings(len(self.tokenizer))
                self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
        else:
            if token in self.tokenizer.token_map:
                # The placeholder token is being replaced: drop its constituent tokens
                # from the tokenizer and their rows from the embedding matrix.
                indices_to_remove = []
                for token_to_remove in self.tokenizer.token_map[token]:
                    indices_to_remove.append(self.tokenizer.get_added_vocab()[token_to_remove])
                    self.tokenizer.added_tokens_encoder.pop(token_to_remove)

                indices_to_remove = torch.LongTensor(indices_to_remove)

                # Keep every embedding row except the ones belonging to the removed tokens.
                token_embeds = self.text_encoder.get_input_embeddings().weight.data
                keep_mask = ~torch.isin(torch.arange(token_embeds.shape[0]), indices_to_remove)
                token_embeds = token_embeds[keep_mask]

                # Shrink the embedding matrix and copy the kept rows back into it.
                self.text_encoder.resize_token_embeddings(len(self.tokenizer))
                self.text_encoder.get_input_embeddings().weight.data = token_embeds

                self.tokenizer.token_map.pop(token)

            # A 1-D embedding is a single vector; a 2-D embedding holds one vector per row.
            embedding_dims = len(embedding.shape)
            num_vec_per_token = 1 if embedding_dims == 1 else embedding.shape[0]

            self.tokenizer.add_placeholder_tokens(token, num_vec_per_token=num_vec_per_token)
            self.text_encoder.resize_token_embeddings(len(self.tokenizer))
            token_ids = self.tokenizer.encode(token, add_special_tokens=False)

            if embedding_dims > 1:
                for i, token_id in enumerate(token_ids):
                    self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding[i]
            else:
                self.text_encoder.get_input_embeddings().weight.data[token_ids] = embedding

    def _extract_embedding_from_dict(self, embedding_dict: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, bool]:
        r"""
        Extracts the embedding from the embedding dictionary.

        Arguments:
            embedding_dict (`Dict[str, torch.Tensor]`):
                The embedding dictionary loaded from the embedding path

        Returns:
            embedding (`torch.Tensor`):
                The embedding to be added to the text encoder's embedding matrix
            is_multi_vec_token (`bool`):
                Whether the embedding is a multi-vector token or not
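
        Example:

        The two dictionary layouts this method understands (values are illustrative):

        ```py
        {"<my-token>": torch.randn(768)}  # token-to-tensor mapping
        {"name": "<my-token>", "string_to_param": {"*": torch.randn(2, 768)}}  # `string_to_param` layout
        ```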
        """
        is_multi_vec_token = False

        if "string_to_param" in embedding_dict:
            # `string_to_param` layout: the tensor lives under the "*" key.
            embedding_dict = embedding_dict["string_to_param"]
            embedding = embedding_dict["*"]
        else:
            # Token-to-tensor layout: take the single stored tensor.
            embedding = list(embedding_dict.values())[0]

        if len(embedding.shape) > 1:
            # Multi-vector embeddings need a tokenizer that can map one
            # placeholder token to several token ids.
            if not isinstance(self.tokenizer, MultiTokenCLIPTokenizer):
                raise ValueError(
                    f"{self.__class__.__name__} requires `self.tokenizer` of type `MultiTokenCLIPTokenizer` for loading embeddings with more than one dimension."
                )
            is_multi_vec_token = True

        return embedding, is_multi_vec_token

    def _extract_token_from_dict(self, embedding_dict: Dict[str, torch.Tensor]) -> str:
        r"""
        Extracts the token from the embedding dictionary.

        Arguments:
            embedding_dict (`Dict[str, torch.Tensor]`):
                The embedding dictionary loaded from the embedding path

        Returns:
            token (`str`):
                The token to be added to the tokenizer's vocabulary
        """
        if "string_to_param" in embedding_dict:
            # `string_to_param` layout: the token name is stored under "name".
            return embedding_dict["name"]

        # Token-to-tensor layout: the token is the single dictionary key.
        return list(embedding_dict.keys())[0]

    def _validate_method_call(self, method: Callable):
        r"""
        Validates that the method is being called from a class instance that has the required attributes.

        Arguments:
            method (`function`):
                The class's method being called

        Raises:
            ValueError:
                If the class instance is missing the required `tokenizer` or `text_encoder` attributes.

        Returns:
            None
        """
        if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer):
            raise ValueError(
                f"{self.__class__.__name__} requires `self.tokenizer` of type `PreTrainedTokenizer` for calling `{method.__name__}`"
            )

        if not hasattr(self, "text_encoder") or not isinstance(self.text_encoder, PreTrainedModel):
            raise ValueError(
                f"{self.__class__.__name__} requires `self.text_encoder` of type `PreTrainedModel` for calling `{method.__name__}`"
            )

    def _validate_token_update(self, token, allow_replacement=False, is_multi_vec_token=False):
        r"""Validates that the token is not already in the tokenizer's vocabulary.

        Arguments:
            token (`str`):
                The token to be added to the tokenizer's vocabulary
            allow_replacement (`bool`):
                Whether to allow replacement of the token if it already exists in the tokenizer's vocabulary
            is_multi_vec_token (`bool`):
                Whether the embedding is a multi-vector token or not

        Raises:
            ValueError:
                If the token is already in the tokenizer's vocabulary and `allow_replacement` is False.

        Returns:
            None
        """
        if (not is_multi_vec_token and token in self.tokenizer.get_vocab()) or (
            is_multi_vec_token and token in self.tokenizer.token_map
        ):
            if allow_replacement:
                print(
                    f"Token {token} already in tokenizer vocabulary. Overwriting existing token and embedding with the new one."
                )
            else:
                raise ValueError(
                    f"Token {token} already in tokenizer vocabulary. Please choose a different token name."
                )

    def expand_textual_inversion_token_ids_if_necessary(self, token_ids: List[int]) -> List[int]:
        # Placeholder tokens are already expanded into their underlying ids by
        # `MultiTokenCLIPTokenizer.__call__`/`encode`, so the ids pass through unchanged.
        return token_ids


class MultiTokenCLIPTokenizer(CLIPTokenizer):
    """Tokenizer for CLIP models that have multi-vector tokens."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Maps each placeholder token to the list of tokens that back it.
        self.token_map = {}

    def add_placeholder_tokens(self, placeholder_token, *args, num_vec_per_token=1, **kwargs):
        r"""Adds placeholder tokens to the tokenizer's vocabulary and token map.

        Arguments:
            placeholder_token (`str`):
                The placeholder token to be added to the tokenizer's vocabulary and token map.
            num_vec_per_token (`int`):
                The number of vectors per token. Defaults to 1.
            *args:
                The arguments to be passed to the tokenizer's `add_tokens` method.
            **kwargs:
                The keyword arguments to be passed to the tokenizer's `add_tokens` method.

        Returns:
            None
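
        Example:

        A sketch with a hypothetical `<cat>` placeholder; two sub-tokens are registered and mapped:

        ```py
        tokenizer.add_placeholder_tokens("<cat>", num_vec_per_token=2)
        print(tokenizer.token_map["<cat>"])  # ["<cat>_0", "<cat>_1"]
        ```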
        """
        output = []
        if num_vec_per_token == 1:
            self.add_tokens(placeholder_token, *args, **kwargs)
            output.append(placeholder_token)
        else:
            # Register one `<token>_i` entry per vector and remember all of them.
            for i in range(num_vec_per_token):
                ith_token = placeholder_token + f"_{i}"
                self.add_tokens(ith_token, *args, **kwargs)
                output.append(ith_token)

        for token in self.token_map:
            if token in placeholder_token:
                raise ValueError(
                    f"The tokenizer already has placeholder token {token} that can get confused with"
                    f" {placeholder_token}. Keep placeholder tokens independent."
                )
        self.token_map[placeholder_token] = output

    def replace_placeholder_tokens_in_text(self, text, vector_shuffle=False, prop_tokens_to_load=1.0):
        r"""Replaces placeholder tokens in text with the tokens in the token map.

        Optionally implements:

        a) vector shuffling (https://github.com/rinongal/textual_inversion/pull/119), where shuffling tokens
           was found to force the model to learn the concepts more descriptively.
        b) proportional token loading, so that not every token in the token map is loaded on each call; used
           as part of progressive token loading during training, which can improve generalization during
           inference.

        Arguments:
            text (`str`):
                The text to be processed.
            vector_shuffle (`bool`):
                Whether to shuffle the vectors in the token map. Defaults to False.
            prop_tokens_to_load (`float`):
                The proportion of tokens to load from the token map. Defaults to 1.0.

        Returns:
            `str`: The processed text.
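
        Example:

        A sketch continuing the hypothetical `<cat>` placeholder registered with two vectors:

        ```py
        tokenizer.replace_placeholder_tokens_in_text("a photo of <cat>")
        # -> "a photo of <cat>_0 <cat>_1"
        ```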
        """
        if isinstance(text, list):
            # Process each prompt in the batch independently.
            output = []
            for i in range(len(text)):
                output.append(
                    self.replace_placeholder_tokens_in_text(
                        text[i], vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
                    )
                )
            return output

        for placeholder_token in self.token_map:
            if placeholder_token in text:
                tokens = self.token_map[placeholder_token]
                tokens = tokens[: 1 + int(len(tokens) * prop_tokens_to_load)]
                if vector_shuffle:
                    tokens = copy.copy(tokens)
                    random.shuffle(tokens)
                text = text.replace(placeholder_token, " ".join(tokens))
        return text

    def __call__(self, text, *args, vector_shuffle=False, prop_tokens_to_load=1.0, **kwargs):
        """Wrapper around the [`~transformers.tokenization_utils_base.PreTrainedTokenizerBase.__call__`] method
        that first replaces placeholder tokens in text with the tokens in the token map.

        Returns:
            [`~transformers.tokenization_utils_base.BatchEncoding`]
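
        Example:

        A sketch tokenizing a prompt that contains the hypothetical `<cat>` placeholder:

        ```py
        batch = tokenizer("a photo of <cat>", return_tensors="pt")
        ```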
        """
        return super().__call__(
            self.replace_placeholder_tokens_in_text(
                text,
                vector_shuffle=vector_shuffle,
                prop_tokens_to_load=prop_tokens_to_load,
            ),
            *args,
            **kwargs,
        )

    def encode(self, text, *args, vector_shuffle=False, prop_tokens_to_load=1.0, **kwargs):
        """Wrapper around the [`~transformers.tokenization_utils_base.PreTrainedTokenizerBase.encode`] method
        that first replaces placeholder tokens in text with the tokens in the token map.

        Arguments:
            text (`str`):
                The text to be encoded.
            *args:
                The arguments to be passed to the tokenizer's `encode` method.
            vector_shuffle (`bool`):
                Whether to shuffle the vectors in the token map. Defaults to False.
            prop_tokens_to_load (`float`):
                The proportion of tokens to load from the token map. Defaults to 1.0.
            **kwargs:
                The keyword arguments to be passed to the tokenizer's `encode` method.

        Returns:
            List[`int`]: sequence of ids (integer)
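
        Example:

        A sketch, again with the hypothetical `<cat>` placeholder backed by two vectors:

        ```py
        ids = tokenizer.encode("<cat>", add_special_tokens=False)
        # `ids` holds one id per vector backing the placeholder
        ```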
        """
        return super().encode(
            self.replace_placeholder_tokens_in_text(
                text,
                vector_shuffle=vector_shuffle,
                prop_tokens_to_load=prop_tokens_to_load,
            ),
            *args,
            **kwargs,
        )