|
import os |
|
from typing import Optional, Literal |
|
from types import ModuleType |
|
import enum |
|
from packaging import version |
|
|
|
import torch |
|
|
|
|
|
# Whether PyTorch's fused scaled-dot-product attention (SDP) op can be used.
# Feature-detect the function instead of comparing version strings: a
# pre-release/source build such as "2.0.0a0+gitXXXX" parses as *older* than
# "2.0.0" under PEP 440 even though it ships the op, so a version comparison
# would wrongly disable SDP on such builds.
SDP_IS_AVAILABLE = hasattr(torch.nn.functional, "scaled_dot_product_attention")
|
|
|
# Probe for the optional xformers package and record whether it is usable.
# NOTE: the misspelled name XFORMERS_IS_AVAILBLE is kept as-is because other
# modules may import it.
try:
    import xformers
    import xformers.ops

    XFORMERS_IS_AVAILBLE = True
except Exception:
    # ImportError covers a missing package; a broken install may raise other
    # exception types at import time, so stay best-effort. Unlike a bare
    # `except:`, this still lets KeyboardInterrupt / SystemExit propagate.
    XFORMERS_IS_AVAILBLE = False
|
|
|
|
|
@enum.unique
class AttnMode(enum.Enum):
    """Selectable attention backends.

    The integer values are fixed (0, 1, 2) and declared in preference order
    used by this module's default selection: SDP, then xformers, then vanilla.
    """

    SDP = 0  # torch scaled-dot-product attention backend
    XFORMERS = 1  # xformers backend
    VANILLA = 2  # plain fallback when neither SDP nor xformers is available
|
|
|
|
|
class Config:
    """Module-wide attention settings held as class attributes.

    This module uses the class as a namespace, assigning ``Config.<attr>``
    directly rather than creating instances.
    """

    # The imported xformers module, or None when it has not been set.
    xformers: Optional[ModuleType] = None
    # Active attention backend; VANILLA until this module's selection logic
    # (or the ATTN_MODE override) replaces it.
    attn_mode: AttnMode = AttnMode.VANILLA
|
|
|
|
|
|
|
# Expose the xformers module through Config whenever its import succeeded,
# independently of which backend ends up as the default below.
if XFORMERS_IS_AVAILBLE:
    Config.xformers = xformers

# Default backend preference: SDP first, then xformers, otherwise vanilla.
# In the vanilla case Config.attn_mode keeps its class-level default.
if SDP_IS_AVAILABLE:
    Config.attn_mode = AttnMode.SDP
    print("use sdp attention as default")
elif XFORMERS_IS_AVAILBLE:
    Config.attn_mode = AttnMode.XFORMERS
    print("use xformers attention as default")
else:
    print(
        "both sdp attention and xformers are not available, "
        "use vanilla attention (very expensive) as default"
    )
|
|
|
|
|
|
|
# Optional override of the default attention backend via the ATTN_MODE
# environment variable. Accepted values: "vanilla", "sdp", "xformers".
ATTN_MODE = os.environ.get("ATTN_MODE", None)

if ATTN_MODE is not None:
    # Validate with explicit raises rather than `assert`: asserts are stripped
    # under `python -O`. A typo would then silently fall through to vanilla,
    # and an unavailable backend would be selected without complaint.
    if ATTN_MODE not in ("vanilla", "sdp", "xformers"):
        raise ValueError(
            f"invalid ATTN_MODE {ATTN_MODE!r}; "
            "expected one of 'vanilla', 'sdp', 'xformers'"
        )

    if ATTN_MODE == "sdp":
        if not SDP_IS_AVAILABLE:
            raise RuntimeError(
                "ATTN_MODE=sdp was requested, but SDP attention is not "
                "available in this torch installation"
            )
        Config.attn_mode = AttnMode.SDP
    elif ATTN_MODE == "xformers":
        if not XFORMERS_IS_AVAILBLE:
            raise RuntimeError(
                "ATTN_MODE=xformers was requested, but xformers could not be imported"
            )
        Config.attn_mode = AttnMode.XFORMERS
    else:
        Config.attn_mode = AttnMode.VANILLA

    print(f"set attention mode to {ATTN_MODE}")
else:
    print("keep default attention mode")
|
|