Spaces:
Paused
Paused
from typing import List, Optional, Tuple

import torch
from transformers import AutoTokenizer, UMT5EncoderModel

from toolkit.models.loaders.comfy import get_comfy_path
def get_umt5_encoder(
    model_path: str,
    tokenizer_subfolder: Optional[str] = None,
    encoder_subfolder: Optional[str] = None,
    torch_dtype: torch.dtype = torch.bfloat16,
    comfy_files: Optional[List[str]] = None,
) -> Tuple[AutoTokenizer, UMT5EncoderModel]:
    """
    Load the UMT5 tokenizer and encoder model from the specified path.

    Args:
        model_path: Model identifier or directory handed to ``from_pretrained``
            for both the tokenizer and the encoder.
        tokenizer_subfolder: Optional subfolder of ``model_path`` that holds
            the tokenizer files. Omitted from the call when ``None``.
        encoder_subfolder: Optional subfolder of ``model_path`` that holds
            the encoder weights. Omitted from the call when ``None``.
        torch_dtype: dtype passed to the model loader.
        comfy_files: Candidate ComfyUI-relative checkpoint paths passed to
            ``get_comfy_path``. Defaults to the fp16 and fp8-scaled
            ``umt5_xxl`` text-encoder files.

    Returns:
        A ``(tokenizer, text_encoder)`` tuple. The tokenizer is whatever
        ``AutoTokenizer.from_pretrained`` returns for ``model_path``.
    """
    # Build the default per call so no list object is shared between calls.
    if comfy_files is None:
        comfy_files = [
            "text_encoders/umt5_xxl_fp16.safetensors",
            "text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors",
        ]

    # Only forward ``subfolder`` when the caller supplied one, so the
    # library's own default applies instead of an explicit ``None``.
    tokenizer_kwargs = {}
    if tokenizer_subfolder is not None:
        tokenizer_kwargs["subfolder"] = tokenizer_subfolder
    tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_kwargs)

    comfy_path = get_comfy_path(comfy_files)
    # NOTE(review): the lookup result is discarded on purpose here, which
    # forces the ``from_pretrained`` branch below on every call. Presumably
    # single-file loading is not working yet for this model class -- confirm
    # that ``UMT5EncoderModel.from_single_file`` exists before removing this
    # override.
    comfy_path = None

    if comfy_path is not None:
        text_encoder = UMT5EncoderModel.from_single_file(
            comfy_path, torch_dtype=torch_dtype
        )
    else:
        print(f"Using {model_path} for UMT5 encoder.")
        encoder_kwargs = {"torch_dtype": torch_dtype}
        if encoder_subfolder is not None:
            encoder_kwargs["subfolder"] = encoder_subfolder
        text_encoder = UMT5EncoderModel.from_pretrained(
            model_path, **encoder_kwargs
        )

    return tokenizer, text_encoder