
jax/stax BatchNorm: running average on the training set and l2 regularisation

Hi all,

Thanks for releasing this package to the research community! I have two specific questions about the BatchNorm layer in stax.

1/ The mean is supposed to be a running average over the training set, collecting feature-map statistics during learning, and so is the variance. It doesn't seem to work like this in jax's implementation, right?

2/ When applying l2 regularisation to a neural network with BatchNorm layers using the built-in "l2_norm" function, it also aggregates the learnable parameters beta and gamma, which doesn't seem appropriate to me. Am I right?
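For context, what I would expect to do instead is something along these lines (a rough sketch; the helper name and the ndim heuristic are mine, not part of jax), penalising only multi-dimensional weight leaves so that biases and BatchNorm's beta/gamma are skipped:

```python
import jax.numpy as jnp
from jax import tree_util

# Hypothetical helper: sum of squared entries over weight tensors only,
# skipping 1-D leaves such as biases and BatchNorm's beta/gamma.
def l2_penalty_weights_only(params):
    leaves = tree_util.tree_leaves(params)
    return sum(jnp.sum(leaf ** 2) for leaf in leaves if jnp.ndim(leaf) > 1)

# e.g. loss = data_loss + weight_decay * l2_penalty_weights_only(params)
```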

Thanks again!

Answer from hawkinsp:

Thanks for the issue report!

There's an existing bug open for the batch norm moving averages: https://github.com/google/jax/issues/139
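For reference, the current apply step boils down to roughly the following (a paraphrased sketch, not the actual stax source): normalisation uses only the current batch's statistics, and nothing is carried across steps.

```python
import jax.numpy as jnp

# Rough sketch of the current behaviour: mean/var come from the batch itself,
# and there is no running-average state. beta and gamma are assumed to
# broadcast against x (e.g. shape (C,) for NHWC inputs).
def batchnorm_apply_batch_stats(params, x, axis=(0, 1, 2), eps=1e-5):
    beta, gamma = params
    mean = jnp.mean(x, axis, keepdims=True)
    var = jnp.var(x, axis, keepdims=True)
    return gamma * (x - mean) / jnp.sqrt(var + eps) + beta
```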

In general, stax lives in jax.experimental, and experimental means just that: it's an experiment in how to write neural network layers on top of JAX. It isn't intended to be the last word on the topic. We expect to replace it with a "Stax 2" before long, incorporating some of the things we have learnt from the current stax, and this is definitely on the list of things to fix.

Note, however, that stax is only a tiny amount of code: https://github.com/google/jax/blob/master/jax/experimental/stax.py If it doesn't do what you want, you should feel free to copy, remix, and adapt it to your needs. We want to get away from monolithic frameworks that researchers can't change. As an example, the Trax project in tensor2tensor has built its own variant of layers that work on top of JAX: https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/trax/layers I'm not saying it is the final word in layer libraries either, but you should not feel limited to stax as it stands today.
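As one illustration of the kind of adaptation I mean, here is a sketch (the names, the momentum value, and the explicit-state design are assumptions of mine, not stax's API) of a BatchNorm variant that keeps exponential moving averages in a state value threaded through the training loop:

```python
import jax.numpy as jnp

# Sketch of a BatchNorm variant with moving averages kept in explicit `state`.
def batchnorm_init(rng, input_shape, axis=(0, 1, 2)):
    # rng is unused here; kept only to mirror stax's init_fun convention.
    shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
    params = (jnp.zeros(shape), jnp.ones(shape))   # (beta, gamma)
    state = (jnp.zeros(shape), jnp.ones(shape))    # (running_mean, running_var)
    return input_shape, params, state

def batchnorm_apply(params, state, x, train, axis=(0, 1, 2),
                    momentum=0.99, eps=1e-5):
    # `train` is a plain Python bool here, so mark it static under jit.
    beta, gamma = params
    running_mean, running_var = state
    if train:
        mean = jnp.mean(x, axis)
        var = jnp.var(x, axis)
        running_mean = momentum * running_mean + (1 - momentum) * mean
        running_var = momentum * running_var + (1 - momentum) * var
    else:
        mean, var = running_mean, running_var
    z = (x - mean) / jnp.sqrt(var + eps)
    return gamma * z + beta, (running_mean, running_var)
```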

Hope that helps!
