|
import inspect |
|
from typing import Any, Dict, List, Optional, Set, Tuple, Union |
|
|
|
from .activation_checkpoint import apply_activation_checkpointing |
|
from .data import determine_batch_size, should_perform_precomputation |
|
from .diffusion import ( |
|
_enable_vae_memory_optimizations, |
|
default_flow_shift, |
|
get_scheduler_alphas, |
|
get_scheduler_sigmas, |
|
prepare_loss_weights, |
|
prepare_sigmas, |
|
prepare_target, |
|
resolution_dependent_timestep_flow_shift, |
|
) |
|
from .file import delete_files, find_files, string_to_filename |
|
from .hub import save_model_card |
|
from .memory import bytes_to_gigabytes, free_memory, get_memory_statistics, make_contiguous |
|
from .model import resolve_component_cls |
|
from .state_checkpoint import PTDCheckpointManager |
|
from .torch import ( |
|
align_device_and_dtype, |
|
clip_grad_norm_, |
|
enable_determinism, |
|
expand_tensor_dims, |
|
get_device_info, |
|
set_requires_grad, |
|
synchronize_device, |
|
unwrap_model, |
|
) |
|
|
|
|
|
def get_parameter_names(obj: Any, method_name: Optional[str] = None) -> Set[str]:
    """Collect the parameter names declared in a callable's signature.

    Args:
        obj: The callable to inspect, or an object holding the callable as an
            attribute when ``method_name`` is given.
        method_name: Optional attribute name looked up on ``obj`` with
            ``getattr``; the resulting attribute is inspected instead of
            ``obj`` itself.

    Returns:
        The set of parameter names reported by ``inspect.signature`` for the
        inspected callable.
    """
    target = obj if method_name is None else getattr(obj, method_name)
    # Iterating the signature's parameter mapping yields the names directly.
    return set(inspect.signature(target).parameters)
|
|
|
|
|
def get_non_null_items(
    x: Union[List[Any], Tuple[Any], Dict[str, Any]]
) -> Union[List[Any], Tuple[Any], Dict[str, Any]]:
    """Return a copy of a container with its ``None`` entries removed.

    Only entries that are exactly ``None`` are dropped; other falsy values
    such as ``0`` or ``""`` are kept.

    Args:
        x: A dict, list or tuple to filter.

    Returns:
        For a dict, a new plain ``dict`` without the keys whose value is
        ``None``. For a list or tuple, a new container built with
        ``type(x)`` holding the remaining items in their original order.
        Any other input type falls through and yields ``None``.
    """
    if isinstance(x, dict):
        kept = {}
        for key, value in x.items():
            if value is not None:
                kept[key] = value
        return kept

    if isinstance(x, (list, tuple)):
        # Rebuild with the input's own type so a list stays a list and a
        # tuple stays a tuple.
        container_cls = type(x)
        remaining = (item for item in x if item is not None)
        return container_cls(remaining)

    # Unsupported input types are not filtered.
    return None
|
|