Calculating the gradient of a scalar with vector-jacobian products


I am trying to calculate the gradient with respect to $\phi$ of the scalar \begin{align*} F = \sum_{i=0}^nf^T\dfrac{\partial h_i}{\partial \theta}^T\lambda_i \tag{1} \end{align*} Here $f, \theta \in \mathcal{R}^{m}$ and $h_i \in \mathcal{R}^k$, where $h_i$ is shorthand for $h_i(x_i, \theta)$ with $x_i \in \mathcal{R}^k$. The $\lambda_i \in \mathcal{R}^k$ are defined by the backward recursion \begin{align*} &\lambda_{n} = 0 \\ &\lambda_{i-1} = \dfrac{\partial h_i}{\partial x_i}^T\lambda_i + \dfrac{\partial p_i}{\partial x_i}^T\beta_i(\phi) \end{align*} where $\beta_i = \beta_i(\phi)$ and $p_i = p_i(x_i)$ are scalars.

Supposedly, using vector-jacobian products (i.e. reverse-mode differentiation), the gradient $\nabla_{\phi}F$ can be calculated at a cost that is a constant multiple of the cost of evaluating $F$. However, after much trying I am unable to express the gradient using only vector-jacobian products; I can only get it down to either matrix-jacobian products or jacobian-vector products. Either way, calculating the gradient would then not cost a constant multiple of evaluating $F$, but would scale with the dimension $k$ of $x$. This leaves me confused as to whether I am doing something wrong, or whether there is some caveat to reverse differentiation that I'm not aware of. I strongly suspect the former, but I'm not sure what to do differently.

Gradient calculation

Method A

Using the recursion \begin{align*} \dfrac{\partial \lambda_{i-1}}{\partial \phi} = \dfrac{\partial h_i}{\partial x_i}^T\dfrac{\partial \lambda_i}{\partial \phi} + \dfrac{\partial p_i}{\partial x_i}^T\dfrac{\partial \beta_i}{\partial \phi}, \end{align*} the gradient is \begin{align*} \nabla_\phi F = \sum_{i=0}^nf^T\dfrac{\partial h_i}{\partial \theta}^T\dfrac{\partial \lambda_i }{\partial \phi} . \end{align*} You can massage this further, but you always end up having to form some matrix in $\mathcal{R}^{k\times k}$.
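A quick way to see the scaling problem in Method A is to count shapes: the recursion propagates the full matrix $\partial\lambda_i/\partial\phi \in \mathcal{R}^{k\times \dim\phi}$, so every step is a matrix-jacobian product. A minimal sketch with hypothetical random stand-ins for the Jacobians (all names and dimensions are made up):

```python
import numpy as np

# Why Method A scales with k: the recursion propagates the full (k x p)
# matrix dlambda_i/dphi, so each step is a matrix-Jacobian product.
rng = np.random.default_rng(4)
k, p = 4, 3                          # k = dim(x), p = dim(phi)
Hx_T = rng.standard_normal((k, k))   # stand-in for (dh_i/dx_i)^T
px_T = rng.standard_normal((k, 1))   # stand-in for (dp_i/dx_i)^T
db   = rng.standard_normal((1, p))   # stand-in for dbeta_i/dphi

dlam = np.zeros((k, p))              # dlambda_n/dphi = 0
# One recursion step: O(k^2 p) work, not a single vector-Jacobian product
dlam_prev = Hx_T @ dlam + px_T @ db
print(dlam_prev.shape)  # (4, 3)
```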

Method B

The gradient is \begin{align*} \sum_{i=0}^{n-1}\xi_i^T\dfrac{\partial p_{i+1}}{\partial x_{i+1}}^T\dfrac{\partial \beta_{i+1}}{\partial \phi} \end{align*} where the $\xi$ are given by \begin{align*} &\xi_0^T = f^T\dfrac{\partial h_0}{\partial \theta}^T \\ & \xi_{i+1}^T = f^T\dfrac{\partial h_{i+1}}{\partial \theta}^T + \xi_i^T\dfrac{\partial h_{i+1}}{\partial x_{i+1}}^T \end{align*} In this case you have to calculate $\partial h_{i+1}/\partial x_{i+1}$.


Best answer

TLDR: We can calculate $\dfrac{\partial F}{\partial \phi}$ efficiently, either by using jacobian-vector products or by using nested vector-jacobian products (which are equivalent to jacobian-vector products). Also, jacobian-vector products aren't necessarily bad when computing gradients of compositions of functions.

To explain this, first I'm going to cite two definitions/results from [1].

VJP vs JVP

For a function $f(x) : \mathcal{R}^m \rightarrow \mathcal{R}^n$, the vector-jacobian product is defined as \begin{align*} \text{vjp}(u, f, x) := u^T\dfrac{\partial f}{\partial x} \end{align*} for a constant vector $u \in \mathcal{R}^n$. The time complexity of this vjp satisfies \begin{align*} \operatorname{TIME}[f, \text{vjp}(u, f, x)] \le 4 \operatorname{TIME}[f] \end{align*} That is to say, the cost of evaluating $f$ together with the vjp is bounded by a constant factor of $4$ times the cost of evaluating $f$ alone.
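As a minimal numerical illustration of the definition (the toy function, its analytic Jacobian, and all values here are made up; real reverse mode computes $u^T\partial f/\partial x$ without ever materializing the Jacobian):

```python
import numpy as np

# Toy f: R^2 -> R^2 with its analytic Jacobian (illustrative, not from the text)
def f(x):
    return np.array([x[0] * x[1], np.sin(x[0])])

def jacobian_f(x):
    return np.array([[x[1], x[0]],
                     [np.cos(x[0]), 0.0]])

def vjp(u, jac, x):
    # vjp(u, f, x) = u^T (df/dx); here the Jacobian is supplied explicitly
    # only so the product is easy to inspect
    return u @ jac(x)

x = np.array([0.5, 2.0])
u = np.array([1.0, -1.0])
row = vjp(u, jacobian_f, x)
print(row)  # the row vector u^T J
```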

For the same function $f(x)$, the jacobian-vector product is defined as \begin{align*} \text{jvp}(f, x, v) := \dfrac{\partial f}{\partial x}v \end{align*} for a constant vector $v \in \mathcal{R}^m$. The time complexity of this jvp satisfies \begin{align*} \operatorname{TIME}[f, \text{jvp}(f, x, v)] \le \frac{5}{2} \operatorname{TIME}[f] \end{align*} At this point, if you were me when I asked this question, you're asking yourself: "wait, they both have constant computational complexity?" You might have thought that jacobian-vector products, also known as forward-mode differentiation, are to be avoided, and that we only ever want to use vector-jacobian products.
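A jvp can be sanity-checked against a directional finite difference, since $\frac{\partial f}{\partial x}v \approx \frac{f(x+\epsilon v)-f(x)}{\epsilon}$ (toy function and values again made up for this sketch):

```python
import numpy as np

# Toy f: R^2 -> R^2 with its analytic Jacobian (illustrative, not from the text)
def f(x):
    return np.array([x[0] * x[1], np.sin(x[0])])

def jacobian_f(x):
    return np.array([[x[1], x[0]],
                     [np.cos(x[0]), 0.0]])

def jvp(jac, x, v):
    # jvp(f, x, v) = (df/dx) v; real forward mode computes this without
    # ever forming the Jacobian
    return jac(x) @ v

# Check against a forward finite difference: (f(x + eps*v) - f(x)) / eps
x = np.array([0.5, 2.0])
v = np.array([1.0, 3.0])
eps = 1e-6
err = np.max(np.abs(jvp(jacobian_f, x, v) - (f(x + eps * v) - f(x)) / eps))
print(err)  # small, on the order of eps
```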

Scalar objectives and when JVPs go bad

To understand why jvps get a bad rap, let's consider a function $h(\theta) : \mathcal{R}^m \rightarrow \mathcal{R}$. In a common use case, we want to compute the gradient $\dfrac{\partial h}{\partial \theta}$. With a vjp, it's simply \begin{align*} \dfrac{\partial h}{\partial \theta} = \text{vjp}(1, h, \theta) \end{align*} where $1$ is a scalar. Since we can compute this gradient with a single vjp, it follows that we can compute it with the equivalent of (at most) 4 evaluations of $h$. Now let's say we try to use jvps instead: \begin{align*} \dfrac{\partial h}{\partial \theta} = \dfrac{\partial h}{\partial \theta}I = \begin{bmatrix} \text{jvp}(h, \theta, e_1) & \text{jvp}(h, \theta, e_2) & \ldots & \text{jvp}(h, \theta, e_m) \end{bmatrix} \end{align*} where $I$ is the $m \times m$ identity matrix and $e_i$ is the vector with $1$ at index $i$ and $0$s everywhere else. Now it's clear that to evaluate the gradient we don't need a single jvp; we actually need $m$ jvps. So in this context jvps go very wrong.
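The count of passes can be made concrete. In this sketch (toy scalar objective, made-up values) the single vjp is played by the analytic gradient, while the jvp route needs one forward pass per basis vector, here approximated by finite differences:

```python
import numpy as np

# Toy scalar objective h: R^3 -> R (illustrative choice, not from the text)
def h(theta):
    return theta[0] * theta[1] + np.sin(theta[2])

def grad_h(theta):
    # Analytic gradient; reverse mode delivers this with one vjp(1, h, theta)
    return np.array([theta[1], theta[0], np.cos(theta[2])])

theta = np.array([1.0, 2.0, 0.3])
m = theta.size

# Reverse mode: a single vjp with scalar seed u = 1 gives the whole gradient
g_vjp = 1.0 * grad_h(theta)

# Forward mode: one jvp per basis vector e_i, i.e. m separate passes
# (each jvp approximated here by a forward finite difference)
eps = 1e-6
g_jvp = np.array([(h(theta + eps * e) - h(theta)) / eps for e in np.eye(m)])

print(np.max(np.abs(g_vjp - g_jvp)))  # small: both recover the gradient
```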

Relevance to the original question

As stated in the question, we have \begin{align*} \dfrac{\partial F}{\partial \phi} = \sum_{i=0}^{n-1}\xi_i^T\dfrac{\partial p_{i+1}}{\partial x_{i+1}}^T\dfrac{\partial \beta_{i+1}}{\partial \phi} \end{align*} where the $\xi$ are given by \begin{align*} &\xi_0^T = f^T\dfrac{\partial h_0}{\partial \theta}^T \\ & \xi_{i+1}^T = f^T\dfrac{\partial h_{i+1}}{\partial \theta}^T + \xi_i^T\dfrac{\partial h_{i+1}}{\partial x_{i+1}}^T \end{align*} Now the $\xi$ can be calculated efficiently using jvps: \begin{align*} & \xi_{i} = \text{jvp}(h_i, \theta, f) + \text{jvp}(h_i, x_i, \xi_{i-1}) \end{align*} Then the gradient is simply \begin{align*} \sum_{i=0}^{n-1} \langle\text{vjp}(1, p_{i+1}, x_{i+1}), \xi_i\rangle\,\text{vjp}(1, \beta_{i+1}, \phi) \end{align*} So in this case, it's simple to calculate the $\xi$ with jvps, but not with vjps.
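The recursion and the final contraction can be sketched numerically. Here the jvp/vjp primitives are replaced by explicit small random matrices standing in for the Jacobians (`H_theta[i]` for $\partial h_i/\partial\theta$, `H_x[i]` for $\partial h_i/\partial x_i$, `grad_p[i]` for $(\partial p_{i+1}/\partial x_{i+1})^T$, `grad_b[i]` for $(\partial\beta_{i+1}/\partial\phi)^T$; all names and dimensions are hypothetical):

```python
import numpy as np

rng = np.random.default_rng(0)
n, k, m, p = 4, 3, 2, 5          # p = dim(phi); all sizes are made up
H_theta = [rng.standard_normal((k, m)) for _ in range(n + 1)]
H_x     = [rng.standard_normal((k, k)) for _ in range(n + 1)]
grad_p  = [rng.standard_normal(k) for _ in range(n)]
grad_b  = [rng.standard_normal(p) for _ in range(n)]
f = rng.standard_normal(m)

# xi_0 = (dh_0/dtheta) f ;  xi_i = (dh_i/dtheta) f + (dh_i/dx_i) xi_{i-1}
# -- each term is a jvp, i.e. matrix @ vector, never matrix @ matrix
xi = [H_theta[0] @ f]
for i in range(1, n + 1):
    xi.append(H_theta[i] @ f + H_x[i] @ xi[-1])

# grad F = sum_i <grad p_{i+1}, xi_i> * grad beta_{i+1}
grad_F = sum((grad_p[i] @ xi[i]) * grad_b[i] for i in range(n))
print(grad_F.shape)  # (5,), i.e. dim(phi)
```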

As a final point, I'll mention that Hyperplane's answer could also work - this is because two nested vjps are equivalent to a jvp: \begin{align*} \text{vjp}(c, \text{vjp}(y, g, \phi), y) = \text{jvp}(g, \phi, c)^T \end{align*} So as an alternative to using jvps, you can use nested vjps instead.
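The identity is easy to check numerically. Since $\text{vjp}(y, g, \phi) = y^T\partial g/\partial\phi$ is linear in $y$ with Jacobian $(\partial g/\partial\phi)^T$, the outer vjp produces $c^T(\partial g/\partial\phi)^T = \text{jvp}(g,\phi,c)^T$. A sketch with a random matrix standing in for $\partial g/\partial\phi$:

```python
import numpy as np

rng = np.random.default_rng(1)
k, p = 3, 4
J = rng.standard_normal((k, p))   # stand-in for dg/dphi at some point
c = rng.standard_normal(p)

inner = lambda y: y @ J           # y -> vjp(y, g, phi), linear in y

# Jacobian of `inner` w.r.t. y, built column-by-column; equals J^T
jac_inner = np.column_stack([inner(e) for e in np.eye(k)])

nested_vjp = c @ jac_inner        # vjp(c, inner, y) = c^T J^T
jvp_result = J @ c                # jvp(g, phi, c)

print(np.allclose(nested_vjp, jvp_result))  # True: nested vjp = jvp^T
```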

Reference

[1] Griewank, A. and Walther, A., "Evaluating Derivatives: Principles and Techniques of Algorithmic Differentiation", 2nd ed., SIAM, 2008. Chapter 3.

Answer

There is nothing special about $F$: it is a fact that computing the gradient of any scalar function in reverse mode costs a small constant multiple of evaluating the function. You are probably just getting tripped up by the reversed indexing, the additional derivatives present, and the non-standard notation.

To simplify notation, let $A_i = \big(\frac{\partial h_i}{\partial\theta}\big)^T$, $B_i = \big(\frac{\partial h_{i}}{\partial x_{i}}\big)^T$, $v_i =\big(\frac{\partial p_{i}}{\partial x_{i}}\big)^T$. So we want

$$ \frac{\partial}{\partial \phi} \sum_{i=0}^n F_i, \qquad F_i = f^TA_i \lambda_i , \qquad \lambda_{i-1} = B_i \lambda_i + v_i \beta_i $$

Then

$$\begin{aligned} \frac{\partial}{\partial \phi} F_i &= \frac{\partial f^TA_i \lambda_i}{\partial \phi} \\&= \frac{\partial f^TA_i \lambda_i}{\partial \lambda_i}\frac{\partial \lambda_i}{\partial \phi} \\&= \frac{\partial f^TA_i \lambda_i}{\partial \lambda_i}\frac{\partial (B_{i+1} \lambda_{i+1} + v_{i+1} \beta_{i+1})}{\partial \phi} \\&= \frac{\partial f^TA_i \lambda_i}{\partial \lambda_i}\Big(\frac{\partial B_{i+1}\lambda_{i+1}}{\partial \lambda_{i+1}}\frac{\partial\lambda_{i+1}}{\partial \phi}+ \frac{\partial v_{i+1}\beta_{i+1}}{\partial \beta_{i+1}}\frac{\partial \beta_{i+1}}{\partial \phi}\Big) \\&= \frac{\partial f^TA_i \lambda_i}{\partial \lambda_i}\Big(\frac{\partial B_{i+1}\lambda_{i+1}}{\partial \lambda_{i+1}}\Big(\frac{\partial B_{i+2}\lambda_{i+2}}{\partial \lambda_{i+2}}\frac{\partial\lambda_{i+2}}{\partial \phi}+ \frac{\partial v_{i+2}\beta_{i+2}}{\partial \beta_{i+2}}\frac{\partial \beta_{i+2}}{\partial \phi}\Big)+ \frac{\partial v_{i+1}\beta_{i+1}}{\partial \beta_{i+1}}\frac{\partial \beta_{i+1}}{\partial \phi}\Big) \\&=\ldots \end{aligned}$$

And then we just accumulate gradients from the left: $z^T_0 = \frac{\partial f^TA_i \lambda_i}{\partial \lambda_i}$ is a row vector. Then $z_1^T = z_0^T\frac{\partial B_{i+1}\lambda_{i+1}}{\partial \lambda_{i+1}}$ is a VJP, $z_2^T = z_1^T\frac{\partial B_{i+2}\lambda_{i+2}}{\partial \lambda_{i+2}}$ is a VJP, and so on.
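Left-to-right accumulation keeps every intermediate a row vector, so each step is a vector-matrix product rather than a matrix-matrix product. A sketch with hypothetical random matrices standing in for the Jacobians in the chain:

```python
import numpy as np

rng = np.random.default_rng(2)
k = 4
Bs = [rng.standard_normal((k, k)) for _ in range(3)]  # stand-ins for the B's
z = rng.standard_normal(k)        # z_0^T, the initial row vector

acc = z.copy()
for B in Bs:                      # z_{j+1}^T = z_j^T B : a VJP per step, O(k^2)
    acc = acc @ B

# Same result as forming the full matrix product first (O(k^3) per step):
full = z @ (Bs[0] @ Bs[1] @ Bs[2])
print(np.allclose(acc, full))  # True
```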

Answer

Reading your post on https://ai.stackexchange.com made the problem much clearer. Given

$$ F = c^T \dfrac{\partial g}{\partial \phi}^T y = \langle c,\,\text{vjp}(y, g, \phi)\rangle $$

where $y=h(x, \theta)$, and you do not see how to express $\frac{\partial }{\partial \theta} F$ as VJPs. But again, it's really straightforward:

$$ \frac{\partial }{\partial \theta} F = \frac{\partial }{\partial \theta} \langle c,\,\text{vjp}(y, g, \phi)\rangle = c^T \frac{\partial \text{vjp}(y, g, \phi)}{\partial y} \frac{\partial y}{\partial \theta} $$

where

  • $z_0 ^T = c^T \frac{\partial \text{vjp}(y, g, \phi)}{\partial y} = \text{vjp}(c, \text{vjp}(y,g,\phi), y)$ is the VJP of $c$ with the jacobian of the function $y\mapsto \text{vjp}(y, g, \phi)$
  • $z_1^T = z_0^T \frac{\partial y}{\partial \theta} = \text{vjp}(z_0, y, \theta)$ is the VJP of $z_0$ with the jacobian of the function $y\colon \theta \mapsto h(x, \theta)$.

You can treat $y\mapsto \text{vjp}(y, g, \phi)$ as any other vector-to-vector function in this context.
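The two VJP steps above can be sketched with explicit small matrices standing in for the real Jacobians (`J_g` for $\partial g/\partial\phi$, `J_h` for $\partial y/\partial\theta$; all names and sizes are hypothetical):

```python
import numpy as np

rng = np.random.default_rng(3)
k, p, m = 3, 4, 2
J_g = rng.standard_normal((k, p))     # stand-in for dg/dphi
J_h = rng.standard_normal((k, m))     # stand-in for dy/dtheta, y = h(x, theta)
c = rng.standard_normal(p)

# Step 1: z_0^T = vjp(c, y -> vjp(y, g, phi), y) = c^T J_g^T
z0 = J_g @ c                          # stored as a plain vector

# Step 2: z_1^T = vjp(z_0, y, theta) = z_0^T (dy/dtheta)
z1 = z0 @ J_h

print(z1.shape)  # (2,), i.e. dim(theta): the gradient dF/dtheta
```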