Problems in deriving the gradients of the batch normalization layer


I'm working on understanding the math behind the batch normalization layer in CNNs, and I found the original paper introducing the technique: *Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift*.

The backpropagation pass in the paper goes like this:
$$
\begin{aligned}
\frac{\partial l}{\partial \hat x_i} &= \frac{\partial l}{\partial y_i}\cdot \gamma \\
\frac{\partial l}{\partial \sigma_{\cal B}^2} &= \sum_{i=1}^m\frac{\partial l}{\partial \hat x_i}\cdot (x_i - \mu_{\cal B})\cdot \frac{-1}{2}\left(\sigma_{\cal B}^2 + \epsilon\right)^{-3/2} \\
\frac{\partial l}{\partial \mu_{\cal B}} &= \left(\sum_{i = 1}^m\frac{\partial l}{\partial \hat x_i}\cdot \frac{-1}{\sqrt{\sigma_{\cal B}^2 + \epsilon}}\right) + \frac{\partial l}{\partial \sigma^2_{\cal B}}\cdot\frac{\sum_{i=1}^m -2(x_i - \mu_{\cal B})}{m} \\
\frac{\partial l}{\partial x_i} &= \frac{\partial l}{\partial \hat x_i} \cdot \frac{1}{\sqrt{\sigma^2_{\cal B} + \epsilon}} + \frac{\partial l}{\partial \sigma^2_{\cal B}}\cdot \frac{2(x_i - \mu_{\cal B})}{m} + \frac{\partial l}{\partial \mu_{\cal B}}\cdot \frac{1}{m}
\end{aligned}
$$
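To make the setup concrete, here is a minimal NumPy sketch of these equations. The names `bn_forward`/`bn_backward` and the toy loss $l = \frac12\sum_i y_i^2$ are my own choices, not from the paper; a finite-difference check confirms the formulas numerically:

```python
import numpy as np

def bn_forward(x, gamma, beta, eps=1e-5):
    """Batch norm over a batch of m scalars, following the paper's forward pass."""
    mu = x.mean()
    var = x.var()                          # biased variance (divide by m), as in the paper
    x_hat = (x - mu) / np.sqrt(var + eps)
    y = gamma * x_hat + beta
    return y, x_hat, mu, var

def bn_backward(dl_dy, x, mu, var, gamma, eps=1e-5):
    """The four gradient equations quoted above, transcribed directly."""
    m = x.shape[0]
    dl_dxhat = dl_dy * gamma
    dl_dvar = np.sum(dl_dxhat * (x - mu)) * (-0.5) * (var + eps) ** (-1.5)
    dl_dmu = np.sum(dl_dxhat * (-1.0 / np.sqrt(var + eps))) \
             + dl_dvar * np.sum(-2.0 * (x - mu)) / m
    dl_dx = dl_dxhat / np.sqrt(var + eps) \
            + dl_dvar * 2.0 * (x - mu) / m \
            + dl_dmu / m
    return dl_dx

rng = np.random.default_rng(0)
x = rng.normal(size=8)
gamma, beta, eps = 1.5, 0.3, 1e-5

y, x_hat, mu, var = bn_forward(x, gamma, beta, eps)
dl_dy = y                                  # dl/dy for the toy loss l = 0.5 * sum(y**2)
dl_dx = bn_backward(dl_dy, x, mu, var, gamma, eps)

# Finite-difference check of dl/dx against the paper's formulas.
num, h = np.zeros_like(x), 1e-6
for i in range(len(x)):
    xp, xm = x.copy(), x.copy()
    xp[i] += h
    xm[i] -= h
    lp = 0.5 * np.sum(bn_forward(xp, gamma, beta, eps)[0] ** 2)
    lm = 0.5 * np.sum(bn_forward(xm, gamma, beta, eps)[0] ** 2)
    num[i] = (lp - lm) / (2 * h)

print(np.max(np.abs(dl_dx - num)))         # tiny (~1e-8): the quoted formulas check out
```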

My question is: since
$$
\hat x_i = \frac{x_i - \mu_{\cal B}}{\sqrt{\sigma_{\cal B}^2 + \epsilon}} = \frac{x_i - \mu_{\cal B}(x_1, x_2, \cdots, x_i, \cdots, x_m)}{\sqrt{\sigma_{\cal B}^2(x_1, x_2, \cdots, x_i, \cdots, x_m) + \epsilon}},
\qquad
\frac{\partial \mu_{\cal B}}{\partial x_i} = \frac{1}{m},
\qquad
\frac{\partial \sigma_{\cal B}^2}{\partial x_i} = \frac{2(x_i - \mu_{\cal B})}{m},
$$
why couldn't I compute $\frac{\partial l}{\partial x_i}$ directly, treating $\mu_{\cal B}$ and $\sigma_{\cal B}^2$ as functions of $x_i$ and applying the quotient rule:
$$
\frac{\partial l}{\partial x_i} = \frac{\partial l}{\partial \hat x_i}\cdot \frac{\partial \hat x_i}{\partial x_i},
\qquad
\frac{\partial \hat x_i}{\partial x_i} = \frac{\left(1 - \frac{1}{m}\right)\sqrt{\sigma^2_{\cal B} + \epsilon} - (x_i - \mu_{\cal B})\cdot \frac{1}{2}\cdot\frac{1}{\sqrt{\sigma_{\cal B}^2+ \epsilon}}\cdot \frac{2(x_i - \mu_{\cal B})}{m}}{\sigma^2_{\cal B} + \epsilon}\,?
$$

This way, the summation terms in the partial derivatives with respect to the mean and the variance are lost. What am I doing wrong? (The snippet below confirms numerically that the two computations disagree.)
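Continuing the sketch above (same hypothetical names; `dl_dx_direct` is also my own), the quotient-rule expression gives only the per-sample derivative $\partial \hat x_i/\partial x_i$, and chaining it with $\partial l/\partial \hat x_i$ yields a gradient different from the one produced by the paper's formulas:

```python
# Evaluate the quotient-rule expression for d(x_hat_i)/d(x_i), i.e. the
# diagonal of the Jacobian of x_hat w.r.t. x, and chain it with dl/dx_hat_i.
m = x.shape[0]
std = np.sqrt(var + eps)
dxhat_dx_diag = ((1 - 1 / m) * std
                 - (x - mu) * 0.5 * (1 / std) * (2 * (x - mu) / m)) / (var + eps)
dl_dx_direct = (dl_dy * gamma) * dxhat_dx_diag

print(np.max(np.abs(dl_dx_direct - dl_dx)))  # clearly nonzero: the two disagree
```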