from flax import linen as nn import jax import jax.numpy as jnp class LocalResponseNorm(nn.Module): @nn.compact def __call__( self, value: jax.Array ) -> jax.Array: return value / jnp.repeat(jnp.expand_dims((1e-8 + (value**2).mean(axis=-1))**0.5, axis=-1), repeats=value.shape[-1], axis=-1)