Update utils.py
Browse files
utils.py
CHANGED
|
@@ -7,11 +7,14 @@ from constants import (
|
|
| 7 |
HF_TOKEN,
|
| 8 |
MODEL_TYPE_CLASS,
|
| 9 |
DIRECTORY_LORAS,
|
|
|
|
| 10 |
)
|
| 11 |
from huggingface_hub import HfApi
|
|
|
|
| 12 |
from diffusers import DiffusionPipeline
|
| 13 |
from huggingface_hub import model_info as model_info_data
|
| 14 |
from diffusers.pipelines.pipeline_loading_utils import variant_compatible_siblings
|
|
|
|
| 15 |
from pathlib import PosixPath
|
| 16 |
from unidecode import unidecode
|
| 17 |
import urllib.parse
|
|
@@ -283,10 +286,15 @@ def get_model_type(repo_id: str):
|
|
| 283 |
api = HfApi(token=os.environ.get("HF_TOKEN")) # if use private or gated model
|
| 284 |
default = "SD 1.5"
|
| 285 |
try:
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
except Exception:
|
| 291 |
return default
|
| 292 |
return default
|
|
@@ -371,17 +379,23 @@ def download_diffuser_repo(repo_name: str, model_type: str, revision: str = "mai
|
|
| 371 |
if len(variant_filenames):
|
| 372 |
variant = "fp16"
|
| 373 |
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 385 |
|
| 386 |
if isinstance(cached_folder, PosixPath):
|
| 387 |
cached_folder = cached_folder.as_posix()
|
|
|
|
| 7 |
HF_TOKEN,
|
| 8 |
MODEL_TYPE_CLASS,
|
| 9 |
DIRECTORY_LORAS,
|
| 10 |
+
DIFFUSECRAFT_CHECKPOINT_NAME,
|
| 11 |
)
|
| 12 |
from huggingface_hub import HfApi
|
| 13 |
+
from huggingface_hub import snapshot_download
|
| 14 |
from diffusers import DiffusionPipeline
|
| 15 |
from huggingface_hub import model_info as model_info_data
|
| 16 |
from diffusers.pipelines.pipeline_loading_utils import variant_compatible_siblings
|
| 17 |
+
from stablepy.diffusers_vanilla.utils import checkpoint_model_type
|
| 18 |
from pathlib import PosixPath
|
| 19 |
from unidecode import unidecode
|
| 20 |
import urllib.parse
|
|
|
|
| 286 |
api = HfApi(token=os.environ.get("HF_TOKEN")) # if use private or gated model
|
| 287 |
default = "SD 1.5"
|
| 288 |
try:
|
| 289 |
+
if os.path.exists(repo_id):
|
| 290 |
+
tag = checkpoint_model_type(repo_id)
|
| 291 |
+
return DIFFUSECRAFT_CHECKPOINT_NAME[tag]
|
| 292 |
+
else:
|
| 293 |
+
model = api.model_info(repo_id=repo_id, timeout=5.0)
|
| 294 |
+
tags = model.tags
|
| 295 |
+
for tag in tags:
|
| 296 |
+
if tag in MODEL_TYPE_CLASS.keys(): return MODEL_TYPE_CLASS.get(tag, default)
|
| 297 |
+
|
| 298 |
except Exception:
|
| 299 |
return default
|
| 300 |
return default
|
|
|
|
| 379 |
if len(variant_filenames):
|
| 380 |
variant = "fp16"
|
| 381 |
|
| 382 |
+
if model_type == "FLUX":
|
| 383 |
+
cached_folder = snapshot_download(
|
| 384 |
+
repo_id=repo_name,
|
| 385 |
+
allow_patterns="transformer/*"
|
| 386 |
+
)
|
| 387 |
+
else:
|
| 388 |
+
cached_folder = DiffusionPipeline.download(
|
| 389 |
+
pretrained_model_name=repo_name,
|
| 390 |
+
force_download=False,
|
| 391 |
+
token=token,
|
| 392 |
+
revision=revision,
|
| 393 |
+
# mirror="https://hf-mirror.com",
|
| 394 |
+
variant=variant,
|
| 395 |
+
use_safetensors=True,
|
| 396 |
+
trust_remote_code=False,
|
| 397 |
+
timeout=5.0,
|
| 398 |
+
)
|
| 399 |
|
| 400 |
if isinstance(cached_folder, PosixPath):
|
| 401 |
cached_folder = cached_folder.as_posix()
|