jbilcke-hf's picture
jbilcke-hf HF Staff
Upload 430 files
3cc1e25 verified
raw
history blame
1.1 kB
from typing import List
import torch
from transformers import AutoTokenizer, UMT5EncoderModel
from toolkit.models.loaders.comfy import get_comfy_path
def get_umt5_encoder(
model_path: str,
tokenizer_subfolder: str = None,
encoder_subfolder: str = None,
torch_dtype: str = torch.bfloat16,
comfy_files: List[str] = [
"text_encoders/umt5_xxl_fp16.safetensors",
"text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors",
],
) -> UMT5EncoderModel:
"""
Load the UMT5 encoder model from the specified path.
"""
tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder=tokenizer_subfolder)
comfy_path = get_comfy_path(comfy_files)
comfy_path = None
if comfy_path is not None:
text_encoder = UMT5EncoderModel.from_single_file(
comfy_path, torch_dtype=torch_dtype
)
else:
print(f"Using {model_path} for UMT5 encoder.")
text_encoder = UMT5EncoderModel.from_pretrained(
model_path, subfolder=encoder_subfolder, torch_dtype=torch_dtype
)
return tokenizer, text_encoder