File size: 313 Bytes
c84c172
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
from flax import linen as nn
import jax
import jax.numpy as jnp

class LocalResponseNorm(nn.Module):
  """Scale each vector along the last axis to (approximately) unit RMS.

  Every element is divided by ``sqrt(epsilon + mean(value**2, axis=-1))``,
  i.e. by the root-mean-square of the whole last axis it belongs to. Despite
  the class name, this is not a windowed cross-channel local response
  normalization. The module creates no parameters.

  Attributes:
    epsilon: Constant added to the mean square before the square root so an
      all-zero vector does not divide by zero. Defaults to ``1e-8``, the
      value that was previously hard-coded.
  """

  # NOTE(review): 1e-8 is below the smallest float16 subnormal (~6e-8), so
  # for float16 inputs the default epsilon effectively rounds to 0 and an
  # all-zero vector yields NaN — pick a larger epsilon if float16 is used.
  epsilon: float = 1e-8

  @nn.compact
  def __call__(
      self,
      value: jax.Array
  ) -> jax.Array:
    """Normalize ``value`` by the RMS of its last axis.

    Args:
      value: Array with at least one dimension; the reduction is over
        axis -1 and the result is broadcast back over that axis.

    Returns:
      Array with the same shape as ``value``.
    """
    # keepdims=True keeps a size-1 trailing axis, so the RMS broadcasts
    # against ``value`` directly instead of materializing a repeated copy.
    mean_square = jnp.mean(jnp.square(value), axis=-1, keepdims=True)
    return value / jnp.sqrt(self.epsilon + mean_square)