Batch
normalization is a recent idea proposed in [1] to improve training
procedures in deep neural networks (and related models). I’m a huge fan
of this idea because (1) it's ridiculously simple, yet incredibly
powerful and sheds a lot of light on the difficulties of training deep
nets, and (2) the improvements in training it induces are quite
outstanding.
Like I said, BN is embarrassingly simple: normalize the inputs to the nonlinearities in every hidden layer. That's it. Seriously.
To go into a little more detail: (almost) every hidden layer in a network is of the form $h = f(Wx)$, where $f$ is some nonlinear function (say a ReLU), $W$ is the weight matrix associated with the layer, $x$ is the layer's input, and $h$ is the output of that layer (dropping bias terms for simplicity). Let's call the pre-activation $a = Wx$. BN proposes normalizing $a$ as $\hat{a} = (a - \mu)/\sigma$, where $\mu$ and $\sigma$ are the first and second moments of $a$,
respectively. During training we use the empirical moments for every
training batch. There are some extra elements used in practice to
improve the expressiveness of a normalized batch and to allow this to
work during test time, but the above is the core of BN.
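To make the transform concrete, here is a minimal NumPy sketch of the normalization applied to a mini-batch of pre-activations. The function and parameter names (batch_norm, gamma, beta, eps) are my own; gamma and beta stand in for the learnable scale and shift from [1] alluded to above, and the running statistics needed at test time are omitted.

    import numpy as np

    def batch_norm(a, gamma=1.0, beta=0.0, eps=1e-5):
        # a: mini-batch of pre-activations, shape (batch_size, num_units)
        mu = a.mean(axis=0)                    # per-unit empirical mean
        var = a.var(axis=0)                    # per-unit empirical variance
        a_hat = (a - mu) / np.sqrt(var + eps)  # standardized pre-activations
        return gamma * a_hat + beta            # learnable scale and shift

    # toy check: a badly scaled mini-batch comes out roughly standard normal
    a = 5.0 * np.random.randn(64, 100) + 3.0
    print(batch_norm(a).mean(), batch_norm(a).std())  # ~0 and ~1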
The
general intuition as to why BN is so effective is as follows. Training
deep models is almost always done with gradient based procedures. Every
weight is being adjusted according to its gradient under the assumption
that all other weights will not change. In practice we change all
weights in every iteration. Importantly, changing the weights of layer $i$ changes the distribution of the inputs to layer $i+1$, making the assumptions behind the gradient step for the weights of layer $i+1$ pretty weak, especially at the beginning of training where changes can
be dramatic. This greatly complicates training, and makes convergence
difficult and slow. By applying BN between layers, we are in a sense
enforcing that the inputs to every layer always be close to a standard
normal distribution. This eases the dependencies between the layers, and
means that the changes made are not counteracted by changes at previous
layers. That is a very rough intuition, mind you; for more in-depth
explanations I definitely recommend reading [1].
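To illustrate that point, here is a hypothetical two-layer forward pass with the normalization inserted between each linear map and its nonlinearity (it reuses the batch_norm sketch above; the shapes and the names relu, W1, W2 and x are arbitrary choices of mine):

    import numpy as np

    def relu(z):
        return np.maximum(z, 0.0)

    rng = np.random.default_rng(0)
    x = rng.normal(size=(64, 32))    # a mini-batch of inputs
    W1 = rng.normal(size=(32, 128))  # layer-1 weights
    W2 = rng.normal(size=(128, 10))  # layer-2 weights

    h1 = relu(batch_norm(x @ W1))    # layer 1: normalized pre-activations
    h2 = relu(batch_norm(h1 @ W2))   # layer 2 sees inputs with stable statistics

Without the batch_norm calls, any update to W1 would shift the distribution of x @ W1 (and hence of h1), partly invalidating the gradient step just taken for W2.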
This works unbelievably well. In practice, training is an order of magnitude faster (measured in training epochs), much smoother, and can be done with much larger learning rates. It's really remarkable how well this works. My own research has to do with deep generative models, where we typically train a number of dependent deep (Bayesian) nets jointly. In my experience, BN has been key to enabling us to train the bigger, more complex models.
Normalize the activations of the previous layer at each batch, i.e. apply a transformation that maintains the mean activation close to 0 and the activation standard deviation close to 1.
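That one-sentence description is essentially how deep learning libraries document their batch-normalization layers; Keras's BatchNormalization layer, for instance, is described in almost exactly these words. As a rough usage sketch rather than a recommendation of any particular architecture, a Keras model with BN inserted between a dense layer and its nonlinearity might look like this (layer sizes are arbitrary):

    from tensorflow import keras
    from tensorflow.keras import layers

    model = keras.Sequential([
        keras.Input(shape=(32,)),
        layers.Dense(128, use_bias=False),  # bias is redundant with BN's shift
        layers.BatchNormalization(),        # normalize pre-activations per batch
        layers.Activation("relu"),
        layers.Dense(10, activation="softmax"),
    ])
    model.compile(optimizer="sgd", loss="categorical_crossentropy")

Dropping the bias with use_bias=False is a common convention rather than a requirement: BN's learnable shift plays the role of the bias term.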