Iteratively reweighted least squares for LASSO problem


I'm trying to solve the following problem (simplified here to the 1D case with $x>0$):

$$ \arg \min_{x} \frac{1}{2} \left\| A x - b \right\|_{2}^{2} + \frac{1}{2} \left\| x \right\|_{1}$$

I need to solve the problem with IRLS, and both terms should share the common pre-factor $\frac{1}{2}$ so that no special handling is needed in the code. In the IRLS spirit the problem can be reformulated as:

$$ \arg \min_{x} \frac{1}{2} \rho_1\left(\left\| A x - b \right\|_{2}^{2}\right) + \frac{1}{2} \rho_2\left(\left\| x \right\|_{2}^2\right), \quad \rho_1(s)=s, \quad \rho_2(s)=\sqrt{s}$$

Setting the derivative to zero gives: $$ 0 = \rho_1' A^T (Ax -b) + \rho_2' x$$

Thus we have a weighted least squares problem that we can iteratively solve:

$$ \arg \min_{x} \frac{1}{2} \left\| A x - b \right\|_{2}^{2} + \frac{1}{2} \left\| x \right\|_{1} = \arg \min_{x} \frac{1}{2} w_1 \left\| A x - b \right\|_{2}^{2} + \frac{1}{2} w_2 \left\| x \right\|_{2}^2$$ with $w_1 =\rho_1' = 1,\quad w_2(\left\| x \right\|_{2}^2) =\rho_2'(\left\| x \right\|_{2}^2) = \frac{1}{2\sqrt{\left\| x \right\|_{2}^2}} = \frac{1}{2\left\| x \right\|_{1}}$
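Holding the weights fixed at the previous iterate, setting the gradient of the weighted objective to zero yields the closed-form update that each iteration solves (this is the update the code below implements):

$$ x^{(k+1)} = \left(w_1 A^T A + w_2^{(k)} I\right)^{-1} A^T b, \qquad w_2^{(k)} = \frac{1}{2\left\|x^{(k)}\right\|_{1}}$$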

Implementing this for $\arg \min_{x} \frac{1}{2} \left\|x - 2 \right\|_{2}^{2} + \frac{1}{2} \left\| x \right\|_{1}$, I get the correct solution $x=\frac{3}{2}$. Nevertheless, evaluating the weighted objective at the solution gives a different loss: $ \frac{1}{2} w_1 \left\| x - 2 \right\|_{2}^{2} + \frac{1}{2} w_2 \left\| x \right\|_{2}^2 = 0.5 \neq 0.875$. Dropping the $\frac{1}{2}$ on the second term, i.e. evaluating $ \frac{1}{2} w_1 \left\| x - 2 \right\|_{2}^{2} + w_2 \left\| x \right\|_{2}^2 = 0.875$, reproduces the original loss, which I don't understand. Which point in my derivation is incorrect, and how can I fix it?

Below is the implemented algorithm:

import numpy as np

# 1/2 ||Ax-y||^2 + 1/2 ||x||_1
# 1/2 ||x-2||^2  + 1/2 ||x||_1
A = np.ones([1, 1])
y = np.array([2])
x = np.ones([1]) * 0.5

W1 = lambda x: 1.0                # data-term weight, rho_1' = 1
W2 = lambda x: 1 / np.abs(2 * x)  # penalty weight, rho_2' = 1/(2|x|)

for _ in range(100):
    # weighted least-squares update: x = (w1 A^T A + w2 I)^{-1} A^T y
    x = np.linalg.inv(W1(x) * A.T @ A + W2(x) * np.eye(x.size)) @ A.T @ y

print("optimized x: " + str(x))
print("loss: " + str(0.5 * np.sum((A @ x - y) ** 2) + 0.5 * np.sum(np.abs(x))))
print("irls loss: " + str(0.5 * W1(x) * np.sum((A @ x - y) ** 2) + 0.5 * W2(x) * np.sum(x ** 2)))  # incorrect: 0.5
print("irls loss: " + str(0.5 * W1(x) * np.sum((A @ x - y) ** 2) + W2(x) * np.sum(x ** 2)))  # correct: 0.875
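For reference, the quoted numbers can be reproduced in a few lines, independently of the iteration, by evaluating both expressions at the fixed point $x = \frac{3}{2}$ (a minimal standalone check):

```python
x = 1.5                # fixed point of the iteration
w1 = 1.0               # rho_1' = 1
w2 = 1 / (2 * abs(x))  # rho_2'(x^2) = 1/(2|x|) = 1/3

true_loss = 0.5 * (x - 2) ** 2 + 0.5 * abs(x)              # original objective
irls_half = 0.5 * w1 * (x - 2) ** 2 + 0.5 * w2 * x ** 2    # with the 1/2 pre-factor
irls_full = 0.5 * w1 * (x - 2) ** 2 + w2 * x ** 2          # without it

print(true_loss, irls_half, irls_full)  # 0.875 0.5 0.875
```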

Note that Royi (https://stats.stackexchange.com/users/6244/royi) answers a similar problem in https://stats.stackexchange.com/a/299380, but without the common $\frac{1}{2}$ pre-factor on both terms.