A proof involving Batch-Normalization and SGD in Neural Networks


I am trying to understand a proof from this paper. Consider the following setting: we train a neural network layer with SGD, i.e., by updating the weights according to $$w_{t+1} = w_{t} - \eta \nabla L(w_t)$$ where $\eta > 0$ is a step size and $\nabla L(w_t)$ is the gradient of the loss $L$ w.r.t. the parameters $w_t$. It can be shown that a batch-normalization layer $BN$ is rescaling invariant, i.e., if $\hat w = \frac{w}{\lVert w \rVert_2}$ is the normalized version of $w$, then we have

$$BN(\lVert w\rVert_2 \hat w x) = BN(\hat w x).$$ As a consequence, one can derive that $$\frac{dBN(\lVert w\rVert_2 \hat w x)}{d\lVert w\rVert_2 \hat w} = \frac{1}{\lVert w\rVert_2} \frac{dBN(\hat w x)}{d\hat w}$$
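The rescaling invariance is easy to check numerically. Below is a minimal sketch, assuming a plain batch-norm that subtracts the batch mean and divides by the batch standard deviation (no learned affine parameters, no $\epsilon$); the scale $\lVert w \rVert_2 > 0$ cancels because it multiplies the mean and the standard deviation alike:

```python
import numpy as np

def bn(z):
    # plain batch normalization over the batch dimension:
    # subtract the batch mean, divide by the batch standard deviation
    return (z - z.mean(axis=0)) / z.std(axis=0)

rng = np.random.default_rng(0)
w = rng.normal(size=3)              # weights of a single linear unit
x = rng.normal(size=(8, 3))         # batch of 8 inputs
w_hat = w / np.linalg.norm(w)       # unit-norm version of w

out_full = bn(x @ w)                # BN(||w||_2 * w_hat * x)
out_hat = bn(x @ w_hat)             # BN(w_hat * x)
print(np.allclose(out_full, out_hat))  # the two outputs agree
```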

I am trying to verify the following equation, which is where I am stuck in the proof. Let $\rho_t = \lVert w_t \rVert_2$. Then: $$\rho_t \sqrt{1-2\eta\rho_t^{-2}\hat w_t^T\nabla L(\hat w_t) + \eta^2 \rho_t^{-4} \lVert \nabla L(\hat w_t)\rVert_2^2} = \rho_t - \eta \rho_t^{-1}\hat w_t^T \nabla L(\hat w_t) + \mathcal{O}(\eta^2)$$

I have a hard time understanding how the left side can be transformed into the right side. I feel like there must be some terms vanishing due to the $\mathcal O$ asymptotic notation, however I cannot figure out how. Any help is greatly appreciated! If there are more things to be clarified, please let me know in the comments. Thanks a lot!

Answer (accepted):

Start from the SGD update $\mathbf{w}_{t+1} = \mathbf{w}_{t} - \eta \nabla L(\mathbf{w}_t)$.

Taking the squared norm of both sides gives $$ \|\mathbf{w}_{t+1}\|^2 = \|\mathbf{w}_{t}\|^2 -2 \eta \mathbf{w}_{t}^T \nabla L(\mathbf{w}_t) + \eta^2 \| \nabla L(\mathbf{w}_t) \|^2 $$ Writing $\rho_t = \|\mathbf{w}_t\|$ then yields \begin{eqnarray} \rho_{t+1}^2 &=& \rho_{t}^2 - 2 \eta \mathbf{w}_{t}^T \nabla L(\mathbf{w}_t) + \eta^2 \| \nabla L(\mathbf{w}_t) \|^2 \\ &=& \rho_{t}^2 \left[1 - 2 \eta \rho_{t}^{-2} \mathbf{w}_{t}^T \nabla L(\mathbf{w}_t) + \eta^2 \rho_{t}^{-2} \| \nabla L(\mathbf{w}_t) \|^2 \right] \end{eqnarray}
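The squared-norm step is just the identity $\|a - b\|^2 = \|a\|^2 - 2a^T b + \|b\|^2$, which holds exactly (no approximation yet). A quick numerical check, with a random vector standing in for the gradient:

```python
import numpy as np

rng = np.random.default_rng(1)
w = rng.normal(size=4)     # plays the role of w_t
g = rng.normal(size=4)     # stands in for grad L(w_t)
eta = 0.1

lhs = np.linalg.norm(w - eta * g) ** 2
rhs = np.linalg.norm(w) ** 2 - 2 * eta * w @ g + eta ** 2 * np.linalg.norm(g) ** 2
print(np.isclose(lhs, rhs))  # the expansion is exact, not approximate
```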

Note: the power $-4$ in the paper comes from the gradients there being taken w.r.t. $\hat{\mathbf{w}}_t$ rather than $\mathbf{w}_t$. By the derivative identity in the question, $\nabla L(\mathbf{w}_t) = \rho_t^{-1} \nabla L(\hat{\mathbf{w}}_t)$, so $\rho_t^{-2} \|\nabla L(\mathbf{w}_t)\|^2 = \rho_t^{-4} \|\nabla L(\hat{\mathbf{w}}_t)\|^2$. This does not affect the argument below.

The rest is a first-order Taylor approximation $(1+x)^{1/2} \simeq 1+\frac12 x$, in which all terms of order $\eta^2$ are absorbed into the $\mathcal{O}(\eta^2)$ remainder:

\begin{eqnarray} \rho_{t+1} &\simeq& \rho_{t} \left[1 - \frac12 \cdot 2 \eta \rho_{t}^{-2} \mathbf{w}_{t}^T \nabla L(\mathbf{w}_t) \right] = \rho_t - \eta \rho_t^{-1} \mathbf{w}_{t}^T \nabla L(\mathbf{w}_t) \end{eqnarray} Finally, substituting $\mathbf{w}_t = \rho_t \hat{\mathbf{w}}_t$ and $\nabla L(\mathbf{w}_t) = \rho_t^{-1} \nabla L(\hat{\mathbf{w}}_t)$ gives $\rho_t^{-1}\mathbf{w}_t^T \nabla L(\mathbf{w}_t) = \rho_t^{-1}\hat{\mathbf{w}}_t^T \nabla L(\hat{\mathbf{w}}_t)$, so $\rho_{t+1} = \rho_t - \eta \rho_t^{-1} \hat{\mathbf{w}}_t^T \nabla L(\hat{\mathbf{w}}_t) + \mathcal{O}(\eta^2)$, which is exactly the right-hand side in the question.
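The $\mathcal{O}(\eta^2)$ behaviour can also be seen numerically: if the remainder really is of order $\eta^2$, shrinking $\eta$ by a factor of 10 should shrink the error of the first-order approximation by roughly a factor of 100. A small sketch (random vectors stand in for $\mathbf{w}_t$ and the gradient; this illustrates the expansion, not the paper's actual setup):

```python
import numpy as np

rng = np.random.default_rng(2)
w = rng.normal(size=5)          # plays the role of w_t
g = rng.normal(size=5)          # stands in for grad L(w_t)
rho = np.linalg.norm(w)

errs = []
for eta in [1e-1, 1e-2, 1e-3]:
    exact = np.linalg.norm(w - eta * g)     # rho_{t+1} computed exactly
    approx = rho - eta * (w @ g) / rho      # first-order approximation
    errs.append(abs(exact - approx))

# each 10x reduction in eta shrinks the error by roughly 100x,
# consistent with a remainder of order eta^2
print(errs)
```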