jax/stax BatchNorm: running average on the training set and l2 regularisation
Thanks for releasing this package to the research community! I have two specific questions about the BatchNorm layer in stax.
1/ The mean is supposed to be a running average over the training set, so that statistics of the feature maps are collected during learning, and the same goes for the variance. That doesn't seem to be the case in JAX's implementation, right? (See the first sketch below for the behaviour I'd expect.)
2/ When applying l2 regularisation to a neural network with BatchNorm layers using the built-in `l2_norm` function, the penalty also aggregates the learnable parameters beta and gamma, which doesn't seem appropriate to me. Am I right?
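For concreteness, the behaviour I'd expect from question 1 is the usual exponential-moving-average bookkeeping over training batches. A minimal sketch of that idea (the momentum value and the name `update_running_stats` are illustrative, not part of stax):

```python
import jax.numpy as jnp

def update_running_stats(running_mean, running_var, batch, momentum=0.9):
    # Standard BatchNorm bookkeeping: per-feature statistics of the
    # current batch are folded into running estimates, which are then
    # used in place of batch statistics at test time.
    batch_mean = jnp.mean(batch, axis=0)
    batch_var = jnp.var(batch, axis=0)
    running_mean = momentum * running_mean + (1.0 - momentum) * batch_mean
    running_var = momentum * running_var + (1.0 - momentum) * batch_var
    return running_mean, running_var
```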
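For question 2, what I'd want instead is to skip those parameters when computing the penalty. One common heuristic is to regularise only multi-dimensional weight arrays, since beta, gamma, and biases are all 1-D. A sketch under that assumption (`l2_penalty` is my own name, not a stax function):

```python
import jax.numpy as jnp
from jax import tree_util

def l2_penalty(params):
    # Sum of squared entries over weight matrices/kernels only.
    # 1-D leaves (BatchNorm beta/gamma, biases) are left unregularised.
    leaves = tree_util.tree_leaves(params)
    return sum(jnp.sum(x ** 2) for x in leaves if x.ndim > 1)
```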
Answer (hawkinsp)
Thanks for the issue report!
There's an existing bug open for the batch norm moving averages: https://github.com/google/jax/issues/139
stax is in `jax.experimental`, and experimental means just that: it's an experiment in how to write neural network layers on top of JAX. It isn't intended to be the last word on the topic. We expect to replace it with a Stax 2 before long, incorporating some of the things we have learnt from the current Stax, and this is definitely on the list of things to fix.
Note, however, that stax is only a tiny amount of code.
If it doesn't do what you want, you should feel free to copy, remix, and adapt it to your needs. We want to get away from having monolithic frameworks that researchers can't change. As an example, the Trax project in tensor2tensor has built its own variant of layers that works on top of JAX: https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/trax/layers
I'm not saying it is the final word in layer libraries either, but you should not feel limited to stax as it stands today.
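To make that concrete: a stax layer is just a pair of `(init_fun, apply_fun)` functions, so rolling your own variant takes only a few lines. A toy sketch (the signatures follow the general stax convention, which may differ slightly across versions):

```python
from jax import random
import jax.numpy as jnp

def Scale():
    """Toy stax-style layer: multiply inputs by one learnable scalar."""
    def init_fun(rng, input_shape):
        # A layer returns its output shape and an arbitrary pytree of params.
        return input_shape, (jnp.ones(()),)
    def apply_fun(params, inputs, **kwargs):
        (scale,) = params
        return scale * inputs
    return init_fun, apply_fun

# Usage: initialise, then apply like any other stax layer.
init_fun, apply_fun = Scale()
out_shape, params = init_fun(random.PRNGKey(0), (-1, 128))
y = apply_fun(params, jnp.ones((4, 128)))
```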
Hope that helps!