AIPlane2 / local_response_norm.py
PrakhAI's picture
Duplicate from PrakhAI/AIPlane
c84c172
raw
history blame contribute delete
313 Bytes
from flax import linen as nn
import jax
import jax.numpy as jnp
class LocalResponseNorm(nn.Module):
    """Scale each vector along the last axis to unit root-mean-square.

    Despite the name, this is not the cross-channel "local response
    normalization" from AlexNet: there is no neighbourhood window and no
    learnable parameter. For every slice ``value[..., :]`` it computes

        value / sqrt(epsilon + mean(value ** 2, axis=-1))

    i.e. an RMS normalization without a learned scale.

    Attributes:
        epsilon: Added to the mean square before the square root so that an
            all-zero slice does not divide by zero. Defaults to ``1e-8``,
            the constant this module previously hard-coded.
    """

    epsilon: float = 1e-8

    @nn.compact
    def __call__(
        self,
        value: jax.Array,
    ) -> jax.Array:
        """Return ``value`` divided by its RMS over the last axis.

        Args:
            value: Array of rank >= 1; normalization is over ``axis=-1``.

        Returns:
            Array with the same shape as ``value``.
        """
        # keepdims=True keeps a size-1 trailing axis so the division below
        # broadcasts; this avoids materializing a full-size copy of the
        # denominator with jnp.repeat.
        mean_square = jnp.mean(jnp.square(value), axis=-1, keepdims=True)
        # NOTE(review): in float16 an epsilon of 1e-8 rounds to 0, so an
        # all-zero slice would yield 0/0 = nan. Consider a larger epsilon or
        # computing the norm in float32 for half-precision inputs.
        return value / jnp.sqrt(self.epsilon + mean_square)