How to estimate the gradient of an argmin function.


Suppose we have a family of functions $\{f_{\theta}(x)\}_{\theta}$ parameterized by $\theta$.

Now, let $x^*(\theta):=\arg\min_x f_\theta(x)$ be the minimum point for $f_\theta$.

Suppose the function $x^*(\theta)$ is differentiable (as David reminded me). My question is: how can one estimate the gradient $\frac{d x^*(\theta)}{d \theta}$?

I did some searching and found that this gradient (in fact, any gradient) can be estimated by the finite-difference technique. However, this is rather slow when $\theta$ contains many parameters, e.g., when $f_\theta$ is a neural network.
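To make the cost issue concrete, here is a minimal sketch of the finite-difference approach on a toy family $f_\theta(x) = (x - \theta^2)^2$ of my own choosing (the inner solver and all constants are illustrative assumptions, not from any reference):

```python
import numpy as np

def solve_argmin(theta, steps=500, lr=0.1):
    # Toy inner problem: f_theta(x) = (x - theta**2)**2, minimized here
    # by plain gradient descent as a stand-in for any black-box solver.
    x = 0.0
    for _ in range(steps):
        grad = 2.0 * (x - theta**2)
        x -= lr * grad
    return x

def fd_grad(theta, h=1e-4):
    # Central finite difference over theta: each evaluation requires a
    # full argmin solve, so the cost grows linearly with the number of
    # parameters in theta -- the slowness noted above.
    return (solve_argmin(theta + h) - solve_argmin(theta - h)) / (2 * h)

theta = 1.5
print(fd_grad(theta))   # ≈ 3.0, matching the analytic d(theta**2)/dtheta = 2*theta
```

For a $\theta$ with $p$ components this requires $2p$ inner solves per gradient, which is exactly why the implicit-function-theorem route below is attractive.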

Any possible approach is welcome.


Update: some supplementary details for Hyperplane's answer, in case anyone else is interested in this question.

  • IFT (Implicit Function Theorem): here and here.

  • Regularity conditions under which $x^*(\theta)$ is differentiable: here and Corollary 2 in here.

  • The Hessian vector product trick: here. Why this trick is helpful for iterative solvers: here.

  • Common methods to solve a linear equation (system): here.

  • Use LU decomposition to solve a linear equation: here.

  • Details on iterative methods, e.g. CG, MINRES, and GMRES: here and here.
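As a small illustration of the direct route in the bullets above (the matrix and right-hand side are arbitrary toy values of my own, assuming NumPy is available):

```python
import numpy as np

# Toy 3x3 system A y = b. np.linalg.solve uses an LU factorization
# (LAPACK's gesv) internally, i.e. the direct method mentioned above.
# For very large systems the iterative methods (CG, MINRES, GMRES) are
# preferred, since they only require matrix-vector products.
A = np.array([[4.0, 1.0, 0.0],
              [1.0, 3.0, 1.0],
              [0.0, 1.0, 2.0]])
b = np.array([1.0, 2.0, 3.0])

y = np.linalg.solve(A, b)
print(np.allclose(A @ y, b))   # True
```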

Best answer:

This is typically done via the implicit function theorem. Given suitable regularity conditions, assume $x^*(\theta)$ minimizes $f(x, \theta)$. Then

$$\begin{aligned} 0 &= \frac{∂f(x^*(\theta), \theta)}{∂x} \\ ⟹ 0 &= \frac{d}{d\theta}\,\frac{∂f(x^*(\theta), \theta)}{∂x} \\ &= \frac{∂^2 f}{∂x^2}\,\frac{d x^*}{d\theta} + \frac{∂^2 f}{∂x\,∂\theta} \\ ⟹ &\;\boxed{\frac{∂^2 f}{∂x^2}\,\frac{d x^*}{d\theta} = -\frac{∂^2 f}{∂x\,∂\theta}} \end{aligned}$$

This linear system can be solved with iterative techniques without ever forming the full Hessian $\frac{∂^2 f}{∂x^2}$ explicitly, by combining Hessian-vector products, $\frac{∂^2 f}{∂x^2}v = \frac{∂}{∂x}\Big(\frac{∂f}{∂x}\cdot v\Big)$, with an iterative solver such as GMRES (or CG when the Hessian is positive definite).
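A minimal sketch of this matrix-free recipe, on a quadratic toy problem of my own ($f(x,\theta) = \tfrac12 x^\top A x - \theta^\top x$, so $x^*(\theta) = A^{-1}\theta$ and $\frac{dx^*}{d\theta} = A^{-1}$); the Hessian-vector product is formed by finite-differencing the gradient, and the system is solved with a hand-rolled CG since this Hessian is positive definite:

```python
import numpy as np

# Toy setup: f(x, theta) = 0.5 * x @ A @ x - theta @ x, whose minimizer
# is x*(theta) = A^{-1} theta, so dx*/dtheta = A^{-1} (known ground truth).
rng = np.random.default_rng(0)
M = rng.normal(size=(5, 5))
A = M @ M.T + 5 * np.eye(5)          # symmetric positive definite Hessian

def grad_x(x, theta):
    return A @ x - theta             # df/dx for the toy objective

def hvp(x, theta, v, eps=1e-5):
    # Hessian-vector product via finite differences of the gradient;
    # the full Hessian matrix is never formed.
    return (grad_x(x + eps * v, theta) - grad_x(x - eps * v, theta)) / (2 * eps)

def cg(matvec, b, tol=1e-10, maxiter=100):
    # Plain conjugate gradient for the SPD system matvec(y) = b.
    y = np.zeros_like(b)
    r = b - matvec(y)
    p = r.copy()
    rs = r @ r
    for _ in range(maxiter):
        Ap = matvec(p)
        alpha = rs / (p @ Ap)
        y += alpha * p
        r -= alpha * Ap
        rs_new = r @ r
        if np.sqrt(rs_new) < tol:
            break
        p = r + (rs_new / rs) * p
        rs = rs_new
    return y

theta = rng.normal(size=5)
x_star = np.linalg.solve(A, theta)   # stand-in for an inner optimizer

# IFT system: (d^2f/dx^2) (dx*/dtheta) = -(d^2f/dx dtheta). Here the
# cross-derivative is -I, so column j of dx*/dtheta solves A y = e_j.
e0 = np.zeros(5); e0[0] = 1.0
col0 = cg(lambda v: hvp(x_star, theta, v), e0)
print(np.allclose(col0, np.linalg.solve(A, e0), atol=1e-6))   # True
```

In practice the gradient (and hence the HVP) would come from automatic differentiation rather than finite differences, but the structure of the solve is the same.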

See for instance: