Chain rule in DARTS – Differentiable Architecture Search


For https://arxiv.org/pdf/1806.09055.pdf#page=4, could anyone help me see how equation (7) follows from equation (6) by the chain rule?


To show that $\color{green}{\text{equation (7)}}$ follows from $\color{red}{\text{equation (6)}}$ by the chain rule, denote $w' = w-\xi \nabla_w \mathcal{L}_{train}(w, \alpha)$.
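The substitution treats $w'$ as the weights after one gradient step on the training loss, with the current weights $w$ fixed, so $w'$ is a function of $\alpha$ alone. A minimal sketch, assuming hypothetical toy losses (not the DARTS network losses) in which $w$ and $\alpha$ are both 2-vectors and $\xi = 0.1$ is an arbitrary choice:

```python
XI = 0.1  # hypothetical inner learning rate xi


def grad_w_train(w, a):
    # Analytic grad_w of a hypothetical L_train(w, a) = 0.5*sum((w_i - a_i)^2) + a0*w0*w1
    return [w[0] - a[0] + a[0] * w[1], w[1] - a[1] + a[0] * w[0]]


def unrolled_weights(w, a, xi=XI):
    """w' = w - xi * grad_w L_train(w, a): one gradient step from fixed weights w."""
    g = grad_w_train(w, a)
    return [wi - xi * gi for wi, gi in zip(w, g)]


w = [0.3, -0.7]
print(unrolled_weights(w, [1.2, 0.5]))  # w' for one architecture alpha
print(unrolled_weights(w, [0.0, 0.0]))  # a different alpha gives a different w'
```

Because $w'$ changes when $\alpha$ changes, differentiating $\mathcal{L}_{val}(w', \alpha)$ with respect to $\alpha$ must account for the path through $w'$.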

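Because $w'$ depends on $\alpha$, the left-hand side of equation (6) is a total derivative. The chain rule splits it into a term that passes through $w'$ plus the partial derivative with $w'$ held fixed, which is what $\nabla_\alpha \mathcal{L}_{val}(w', \alpha)$ denotes in equation (7):

$$\begin{equation} \begin{split} \color{red}{\nabla_\alpha \mathcal{L}_{val}(w-\xi \nabla_w \mathcal{L}_{train}(w, \alpha), \alpha)} &= \frac{d}{d\alpha} \mathcal{L}_{val}(w'(\alpha), \alpha) \\ &= \frac{\partial w'}{\partial \alpha} \nabla_{w'} \mathcal{L}_{val}(w', \alpha) + \frac{\partial \mathcal{L}_{val}(w', \alpha)}{\partial \alpha} \\ &= -\xi \color{blue}{\frac{\partial (\nabla_w \mathcal{L}_{train} (w, \alpha))}{\partial \alpha}} \nabla_{w'} \mathcal{L}_{val} (w', \alpha) + \nabla_\alpha \mathcal{L}_{val}(w', \alpha) \\ \end{split} \end{equation}$$

Here $\frac{\partial w'}{\partial \alpha}$ is the matrix with entries $\partial w'_j / \partial \alpha_i$ (rows indexed by $\alpha$, columns by $w$), and since the current weights $w$ do not depend on $\alpha$, $\frac{\partial w'}{\partial \alpha} = -\xi \frac{\partial (\nabla_w \mathcal{L}_{train}(w, \alpha))}{\partial \alpha}$.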
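The chain-rule step can be checked numerically without knowing any second derivatives: estimate the Jacobian $\partial w'/\partial \alpha$ by finite differences, combine it with $\nabla_{w'} \mathcal{L}_{val}$ and the partial $\alpha$-gradient, and compare with a finite-difference estimate of the total derivative on the left of equation (6). This is a sketch under assumptions: both losses below are hypothetical toy functions, and $\xi = 0.1$ is arbitrary.

```python
XI = 0.1  # hypothetical inner learning rate xi


def grad_w_train(w, a):
    # grad_w of a hypothetical L_train(w, a) = 0.5*sum((w_i - a_i)^2) + a0*w0*w1
    return [w[0] - a[0] + a[0] * w[1], w[1] - a[1] + a[0] * w[0]]


def unrolled_weights(w, a):
    # w' = w - xi * grad_w L_train(w, a); w is held fixed, so w' depends only on a
    return [wi - XI * gi for wi, gi in zip(w, grad_w_train(w, a))]


def val_loss(w, a):
    # hypothetical L_val(w, a) = 0.5*(w0^2 + 2*w1^2) + a1*w0 + 0.5*a0^2
    return 0.5 * (w[0] ** 2 + 2 * w[1] ** 2) + a[1] * w[0] + 0.5 * a[0] ** 2


def grad_w_val(w, a):
    return [w[0] + a[1], 2 * w[1]]


def grad_a_val_partial(w, a):
    # partial derivative in a with the weight argument held fixed
    return [a[0], w[0]]


def shifted(x, k, d):
    # copy of x with component k shifted by d
    y = list(x)
    y[k] += d
    return y


def jacobian_unrolled(w, a, h=1e-6):
    # J[i][j] = d w'_j / d a_i, estimated by central differences
    J = []
    for i in range(len(a)):
        wp = unrolled_weights(w, shifted(a, i, h))
        wm = unrolled_weights(w, shifted(a, i, -h))
        J.append([(p - m) / (2 * h) for p, m in zip(wp, wm)])
    return J


def chain_rule_gradient(w, a):
    # (dw'/da) grad_{w'} L_val(w', a) + partial_a L_val(w', a)
    wp = unrolled_weights(w, a)
    J = jacobian_unrolled(w, a)
    v = grad_w_val(wp, a)
    direct = grad_a_val_partial(wp, a)
    return [sum(J[i][j] * v[j] for j in range(len(v))) + direct[i]
            for i in range(len(a))]


def total_derivative_numeric(w, a, h=1e-6):
    # central differences of a -> L_val(w'(a), a), the left-hand side of equation (6)
    def f(x):
        return val_loss(unrolled_weights(w, x), x)
    return [(f(shifted(a, i, h)) - f(shifted(a, i, -h))) / (2 * h)
            for i in range(len(a))]


w = [0.3, -0.7]
a = [1.2, 0.5]
print(chain_rule_gradient(w, a))
print(total_derivative_numeric(w, a))
print(grad_a_val_partial(unrolled_weights(w, a), a))  # partial term alone
```

The first two printed gradients agree, while the partial term alone does not, which is why the extra term through $w'$ cannot be dropped.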

$$\begin{equation} \begin{split} \color{blue}{\frac{\partial (\nabla_w \mathcal{L}_{train} (w, \alpha))}{\partial \alpha}} = \frac{\partial \frac{\partial \mathcal{L}_{train} (w, \alpha)}{\partial w}}{\partial \alpha} = \frac{\partial^2 \mathcal{L}_{train} (w, \alpha)}{\partial \alpha \; \partial w} = \nabla_\alpha (\nabla_w \mathcal{L}_{train} (w, \alpha) ) = \nabla_{\alpha, w}^2 \mathcal{L}_{train} (w, \alpha) \end{split} \end{equation}$$
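The blue factor is the matrix of mixed second partial derivatives of $\mathcal{L}_{train}$, with rows indexed by $\alpha$ and columns by $w$; for a twice continuously differentiable loss the order of differentiation does not matter, only the layout. The sketch below, assuming a hypothetical toy $\mathcal{L}_{train}$, compares the analytic matrix with a four-point finite-difference estimate computed from the scalar loss alone.

```python
def train_loss(w, a):
    # Hypothetical L_train(w, a) = 0.5*sum((w_i - a_i)^2) + a0*w0*w1
    return 0.5 * ((w[0] - a[0]) ** 2 + (w[1] - a[1]) ** 2) + a[0] * w[0] * w[1]


def mixed_hessian_analytic(w, a):
    # Entry [i][j] = d^2 L_train / (d a_i d w_j) = d(grad_w L_train)_j / d a_i
    return [[-1.0 + w[1], w[0]],
            [0.0, -1.0]]


def shifted(x, k, d):
    # copy of x with component k shifted by d
    y = list(x)
    y[k] += d
    return y


def mixed_hessian_numeric(w, a, h=1e-4):
    # Four-point central difference for d^2 L_train / (d a_i d w_j)
    H = []
    for i in range(len(a)):
        row = []
        for j in range(len(w)):
            fpp = train_loss(shifted(w, j, h), shifted(a, i, h))
            fpm = train_loss(shifted(w, j, h), shifted(a, i, -h))
            fmp = train_loss(shifted(w, j, -h), shifted(a, i, h))
            fmm = train_loss(shifted(w, j, -h), shifted(a, i, -h))
            row.append((fpp - fpm - fmp + fmm) / (4 * h * h))
        H.append(row)
    return H


w = [0.3, -0.7]
a = [1.2, 0.5]
print(mixed_hessian_analytic(w, a))
print(mixed_hessian_numeric(w, a))
```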

$$\begin{equation} \begin{split} \color{red}{\nabla_\alpha \mathcal{L}_{val}(w-\xi \nabla_w \mathcal{L}_{train}(w, \alpha), \alpha)} &= -\xi \color{blue}{\frac{\partial (\nabla_w \mathcal{L}_{train} (w, \alpha))}{\partial \alpha}} \nabla_{w'} \mathcal{L}_{val} (w', \alpha) + \nabla_\alpha \mathcal{L}_{val}(w', \alpha) \\ &= -\xi \nabla_{\alpha, w}^2 \mathcal{L}_{train} (w, \alpha) \nabla_{w'} \mathcal{L}_{val} (w', \alpha) + \nabla_\alpha \mathcal{L}_{val}(w', \alpha) \\ &= \color{green}{\nabla_\alpha \mathcal{L}_{val}(w', \alpha) - \xi \nabla_{\alpha, w}^2 \mathcal{L}_{train} (w, \alpha) \nabla_{w'} \mathcal{L}_{val} (w', \alpha)} \end{split} \end{equation}$$
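As an end-to-end sanity check, the sketch below evaluates the right-hand side of equation (7) with an analytic mixed Hessian-vector product and compares it with a finite-difference estimate of the left-hand side of equation (6). The losses are hypothetical toy functions chosen so every derivative can be written by hand, and $\xi = 0.1$ is arbitrary; this is not the DARTS implementation.

```python
XI = 0.1  # hypothetical inner learning rate xi


def grad_w_train(w, a):
    # analytic grad_w of hypothetical L_train(w, a) = 0.5*sum((w_i - a_i)^2) + a0*w0*w1
    return [w[0] - a[0] + a[0] * w[1], w[1] - a[1] + a[0] * w[0]]


def mixed_hessian_train(w, a):
    # rows index a_i, columns index w_j: d(grad_w L_train)_j / d a_i
    return [[-1.0 + w[1], w[0]],
            [0.0, -1.0]]


def val_loss(w, a):
    # hypothetical L_val(w, a) = 0.5*(w0^2 + 2*w1^2) + a1*w0 + 0.5*a0^2
    return 0.5 * (w[0] ** 2 + 2 * w[1] ** 2) + a[1] * w[0] + 0.5 * a[0] ** 2


def grad_w_val(w, a):
    return [w[0] + a[1], 2 * w[1]]


def grad_a_val_partial(w, a):
    # partial gradient in a with the weight argument held fixed
    return [a[0], w[0]]


def unrolled_weights(w, a):
    # w' = w - xi * grad_w L_train(w, a)
    return [wi - XI * gi for wi, gi in zip(w, grad_w_train(w, a))]


def eq7_gradient(w, a):
    # grad_a L_val(w', a) - xi * H_{a,w} L_train(w, a) grad_{w'} L_val(w', a)
    wp = unrolled_weights(w, a)
    direct = grad_a_val_partial(wp, a)
    v = grad_w_val(wp, a)
    H = mixed_hessian_train(w, a)
    hv = [sum(H[i][j] * v[j] for j in range(len(v))) for i in range(len(a))]
    return [direct[i] - XI * hv[i] for i in range(len(a))]


def eq6_numeric(w, a, h=1e-6):
    # central differences of a -> L_val(w'(a), a), i.e. the total derivative
    out = []
    for i in range(len(a)):
        ap = list(a)
        ap[i] += h
        am = list(a)
        am[i] -= h
        fp = val_loss(unrolled_weights(w, ap), ap)
        fm = val_loss(unrolled_weights(w, am), am)
        out.append((fp - fm) / (2 * h))
    return out


w = [0.3, -0.7]
a = [1.2, 0.5]
print(eq7_gradient(w, a))
print(eq6_numeric(w, a))
```

The two printed vectors agree to within finite-difference rounding, which is what the derivation above predicts.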