from geffnet import config
from geffnet.activations.activations_me import *
from geffnet.activations.activations_jit import *
from geffnet.activations.activations import *
import torch
from torch import nn
from torch.nn import functional as F

# Native SiLU (aka Swish) was added in PyTorch 1.7; prefer it when available.
_has_silu = 'silu' in dir(torch.nn.functional)

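# The lookup tables below come in three tiers: the *_DEFAULT tables hold plain
# Python/torch builtin implementations, the *_JIT tables hold TorchScript-scripted
# versions, and the *_ME tables hold memory-efficient versions that use custom
# autograd functions to reduce activation memory during training.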
_ACT_FN_DEFAULT = dict(
    silu=F.silu if _has_silu else swish,
    swish=F.silu if _has_silu else swish,
    mish=mish,
    relu=F.relu,
    relu6=F.relu6,
    sigmoid=sigmoid,
    tanh=tanh,
    hard_sigmoid=hard_sigmoid,
    hard_swish=hard_swish,
)

_ACT_FN_JIT = dict(
    silu=F.silu if _has_silu else swish_jit,
    swish=F.silu if _has_silu else swish_jit,
    mish=mish_jit,
)

_ACT_FN_ME = dict(
    silu=F.silu if _has_silu else swish_me,
    swish=F.silu if _has_silu else swish_me,
    mish=mish_me,
    hard_swish=hard_swish_me,
    hard_sigmoid=hard_sigmoid_me,
)

_ACT_LAYER_DEFAULT = dict(
    silu=nn.SiLU if _has_silu else Swish,
    swish=nn.SiLU if _has_silu else Swish,
    mish=Mish,
    relu=nn.ReLU,
    relu6=nn.ReLU6,
    sigmoid=Sigmoid,
    tanh=Tanh,
    hard_sigmoid=HardSigmoid,
    hard_swish=HardSwish,
)

_ACT_LAYER_JIT = dict(
    silu=nn.SiLU if _has_silu else SwishJit,
    swish=nn.SiLU if _has_silu else SwishJit,
    mish=MishJit,
)

_ACT_LAYER_ME = dict(
    silu=nn.SiLU if _has_silu else SwishMe,
    swish=nn.SiLU if _has_silu else SwishMe,
    mish=MishMe,
    hard_swish=HardSwishMe,
    hard_sigmoid=HardSigmoidMe,
)

# User-registered overrides; these take precedence over all built-in tables
# in the factory functions below.
_OVERRIDE_FN = dict()
_OVERRIDE_LAYER = dict()


def add_override_act_fn(name, fn):
    _OVERRIDE_FN[name] = fn


def update_override_act_fn(overrides):
    assert isinstance(overrides, dict)
    _OVERRIDE_FN.update(overrides)


def clear_override_act_fn():
    global _OVERRIDE_FN
    _OVERRIDE_FN = dict()


def add_override_act_layer(name, fn):
    _OVERRIDE_LAYER[name] = fn


def update_override_act_layer(overrides):
    assert isinstance(overrides, dict)
    _OVERRIDE_LAYER.update(overrides)


def clear_override_act_layer():
    global _OVERRIDE_LAYER
    _OVERRIDE_LAYER = dict()


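# Usage sketch (illustrative): a registered override takes precedence over every
# built-in table, so new names can be added or existing ones replaced at runtime.
#
#   def my_act(x):  # hypothetical user-defined activation
#       return x * torch.sigmoid(2 * x)
#
#   add_override_act_fn('my_act', my_act)
#   assert get_act_fn('my_act') is my_act

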
def get_act_fn(name='relu'):
    """ Activation Function Factory
    Fetching activation fns by name with this function allows export or TorchScript
    friendly functions to be returned dynamically based on the current config.
    """
    if name in _OVERRIDE_FN:
        return _OVERRIDE_FN[name]
    use_me = not (config.is_exportable() or config.is_scriptable() or config.is_no_jit())
    if use_me and name in _ACT_FN_ME:
        # When not exporting or scripting, prefer the memory-efficient version with
        # a custom autograd, then fall back to JIT-scripted, then plain Python/torch.
        return _ACT_FN_ME[name]
    if config.is_exportable() and name in ('silu', 'swish'):
        # The builtin and JIT silu/swish variants don't export reliably to ONNX,
        # so fall back to the plain Python swish when exporting.
        return swish
    use_jit = not (config.is_exportable() or config.is_no_jit())
    if use_jit and name in _ACT_FN_JIT:
        return _ACT_FN_JIT[name]
    return _ACT_FN_DEFAULT[name]


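# Usage sketch (illustrative):
#
#   act_fn = get_act_fn('swish')      # memory-efficient swish under the default config
#   y = act_fn(torch.randn(2, 8))

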
def get_act_layer(name='relu'):
    """ Activation Layer Factory
    Fetching activation layers by name with this function allows export or TorchScript
    friendly layers to be returned dynamically based on the current config.
    """
    if name in _OVERRIDE_LAYER:
        return _OVERRIDE_LAYER[name]
    use_me = not (config.is_exportable() or config.is_scriptable() or config.is_no_jit())
    if use_me and name in _ACT_LAYER_ME:
        return _ACT_LAYER_ME[name]
    if config.is_exportable() and name in ('silu', 'swish'):
        # As in get_act_fn, fall back to the plain Python Swish module for export.
        return Swish
    use_jit = not (config.is_exportable() or config.is_no_jit())
    if use_jit and name in _ACT_LAYER_JIT:
        return _ACT_LAYER_JIT[name]
    return _ACT_LAYER_DEFAULT[name]
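

# Usage sketch (illustrative): the layer factory returns a class, not an instance.
# The activation modules in this package accept an `inplace` flag like their
# torch.nn counterparts (assumed here from the geffnet.activations definitions).
#
#   act_layer = get_act_layer('hard_swish')
#   act = act_layer(inplace=True)
#   y = act(torch.randn(2, 8))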