Batch Normalization equation derivation


In the paper "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift," the authors derive the partial derivatives of the loss function with respect to the batch statistics. According to Kevin Zakka's blog post (https://kevinzakka.github.io/2016/09/14/batch_normalization/), the derivative with respect to $\mu$ can be reduced to:

$$\begin{eqnarray} \frac{\partial f}{\partial \mu} &=& \bigg(\sum\limits_{i=1}^m \frac{\partial f}{\partial \hat{x}_i} \cdot \frac{-1}{\sqrt{\sigma^2 + \epsilon}} \bigg) + \bigg( \frac{\partial f}{\partial \sigma^2} \cdot \frac{1}{m} \sum\limits_{i=1}^m -2(x_i - \mu) \bigg) \\ &=& \bigg(\sum\limits_{i=1}^m \frac{\partial f}{\partial \hat{x}_i} \cdot \frac{-1}{\sqrt{\sigma^2 + \epsilon}} \bigg) + \bigg( \frac{\partial f}{\partial \sigma^2} \cdot (-2) \cdot \bigg[ \frac{1}{m} \sum\limits_{i=1}^m x_i - \frac{1}{m} \sum\limits_{i=1}^m \mu \bigg] \bigg) \\ &=& \bigg(\sum\limits_{i=1}^m \frac{\partial f}{\partial \hat{x}_i} \cdot \frac{-1}{\sqrt{\sigma^2 + \epsilon}} \bigg) + \bigg( \frac{\partial f}{\partial \sigma^2} \cdot (-2) \cdot \bigg[ \mu - \frac{m \cdot \mu}{m} \bigg] \bigg) \\ &=& \sum\limits_{i=1}^m \frac{\partial f}{\partial \hat{x}_i} \cdot \frac{-1}{\sqrt{\sigma^2 + \epsilon}} \end{eqnarray}$$

I just don't understand how he gets rid of the second component of the equation while keeping the first. Is the second component supposed to equal zero somehow? Thank you in advance for your expertise.


BEST ANSWER

There's an abuse of notation here. Batch normalization defines $\mu:=\frac{1}{m}\sum_{i=1}^mx_i$, so the second term works out as:

$$\frac{\partial f}{\partial \sigma^2} \cdot \frac{1}{m} \sum\limits_{i=1}^m -2(x_i - \mu)=\frac{\partial f}{\partial \sigma^2} (-2)\left[\frac{1}{m}\sum_{i=1}^mx_i-\frac{1}{m}\sum_{i=1}^m\mu\right]=\frac{\partial f}{\partial \sigma^2} (-2)\left[\frac{1}{m}m\mu-\frac{1}{m}m\mu\right]=0$$
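The cancellation can also be checked numerically. The sketch below (with an arbitrary random batch; the batch size `m` and the values of $\partial f/\partial\sigma^2$ are assumed for illustration) shows that $\frac{1}{m}\sum_{i=1}^m -2(x_i - \mu)$ is zero up to floating-point error, so the whole second term vanishes regardless of what $\partial f/\partial\sigma^2$ is:

```python
import numpy as np

# Toy batch of m values (arbitrary; the identity holds for any batch).
rng = np.random.default_rng(0)
m = 8
x = rng.normal(size=m)

# Batch mean, as defined by batch normalization: mu = (1/m) * sum_i x_i
mu = x.mean()

# Inner factor of the second term: (1/m) * sum_i -2*(x_i - mu).
# Since sum_i x_i = m*mu by the definition of mu, this is exactly zero
# in exact arithmetic, and ~0 in floating point.
second_inner = (-2.0 * (x - mu)).mean()
print(second_inner)

# Any value of df/dsigma^2 multiplied by second_inner is therefore ~0,
# which is why only the first term survives in df/dmu.
df_dsigma2 = 3.7  # arbitrary stand-in value
second_term = df_dsigma2 * second_inner
print(second_term)
```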