Learning rotations in SO(3) and translations on signed distance functions


Background

A signed distance function (SDF) is a function $f(x): \mathbb{R}^3 \to \mathbb{R}$ that gives the signed distance from $x \in \mathbb{R}^3$ to a 3D surface $S$. I managed to approximate an SDF with an implicit neural field $f_\phi(x)$, following Implicit Neural Representations with Periodic Activation Functions (SIREN).

Now $f_\phi(x)$ can be evaluated at any point in $\mathbb{R}^3$, and it is $C^2$ differentiable (at least in my experiments).

Question

Given a signed distance function $f(x): \mathbb{R}^3 \to \mathbb{R}$ and a set of 3D points $X = \{x_0, x_1, \dots, x_n\}$, how can I find a rotation matrix $R^* \in SO(3)$ and a translation vector $T^* \in \mathbb{R}^3$ such that $$R^*, T^* = \operatorname*{argmin}_{R,\,T} \sum_i f^2(R x_i + T)?$$ In other words, I want all points to lie on the surface $S$ after the rotation and translation.
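For concreteness, here is the objective written out for a toy implicit surface. An axis-aligned ellipsoid level set $f(x) = \lVert x / r \rVert - 1$ stands in for the learned $f_\phi$ (it is not an exact SDF, but its zero set is the surface, which is all the objective uses), and the rotation is parameterized by a single angle about the x-axis; all names in this sketch are hypothetical:

```python
import numpy as np

R_AXES = np.array([1.0, 0.7, 0.4])  # ellipsoid semi-axes (hypothetical surface)

def f(p):
    # Level-set function of an axis-aligned ellipsoid: zero exactly on the surface
    return np.linalg.norm(p / R_AXES, axis=-1) - 1.0

def rot_x(theta):
    # Rotation by theta about the x-axis
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[1.0, 0.0, 0.0],
                     [0.0, c, -s],
                     [0.0, s, c]])

def objective(theta, T, X):
    # sum_i f(R x_i + T)^2  -- the quantity to be minimized over (theta, T)
    return np.sum(f(X @ rot_x(theta).T + T) ** 2)

# Build X as in the question: take surface points Y, rotate them, translate them
rng = np.random.default_rng(0)
U = rng.normal(size=(100, 3))
Y = R_AXES * (U / np.linalg.norm(U, axis=1, keepdims=True))  # f(Y) = 0
theta_gt, T_gt = -0.8, np.array([0.3, -0.2, 0.5])
X = Y @ rot_x(-theta_gt).T + T_gt
```

The objective vanishes at the inverse transform, $R = R(\theta_{gt})$ with $T = -R(\theta_{gt})\,T_{gt}$, since that maps every $x_i$ back onto the surface.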

The method I tried (which failed)

Set $\theta, T$ as learnable parameters, where $\theta$ is the rotation angle about a fixed axis.

Set $\sum_i f^2(Rx_i+T)$ as the optimization objective and use Adam to learn $\theta, T$ by gradient descent.

def learning_rotations(self):
    gt_y = get_surf_pcl(self.sdf, 2500)            # sample 2500 points on the zero level set
    gt_theta = -2                                  # ground-truth rotation angle (radians)
    gt_translation = torch.randn((1, 3)).cuda()    # ground-truth translation T
    x = rotation(gt_y, -gt_theta, dim=0)           # rotate by -theta about the x axis
    x = x + gt_translation                         # observed point cloud X

    # Plain tensors with requires_grad; torch.autograd.Variable is deprecated.
    # All tensors must live on the same device as gt_translation.
    theta = torch.randn(1, device='cuda').requires_grad_(True)
    translation = torch.randn(1, 3, device='cuda').requires_grad_(True)

    opt = torch.optim.Adam([theta, translation], lr=1e-2)
    sch = torch.optim.lr_scheduler.StepLR(opt, step_size=100, gamma=0.5)
    for iteration in range(2000):
        opt.zero_grad()
        y_hat = rotation(x - translation, theta, dim=0)  # undo translation, then rotate by theta
        loss = self.sdf(y_hat).abs().mean()              # L1 variant of the stated squared objective
        loss.backward()
        opt.step()
        sch.step()                                       # step the scheduler after the optimizer
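As a sanity check that plain gradient descent on this objective can recover the transform at all, here is a self-contained NumPy sketch with hand-derived gradients via the chain rule. The ellipsoid level set, the x-axis rotation parameterization, the sample counts, and the step size are all hypothetical stand-ins for the learned SDF and the setup in the snippet above:

```python
import numpy as np

R_AXES = np.array([1.0, 0.7, 0.4])  # ellipsoid semi-axes (hypothetical surface)

def f(p):
    # Ellipsoid level set: zero on the surface
    return np.linalg.norm(p / R_AXES, axis=-1) - 1.0

def grad_f(p):
    # Spatial gradient of f
    q = p / R_AXES
    return (q / R_AXES) / np.linalg.norm(q, axis=-1, keepdims=True)

def rot_x(theta):
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[1.0, 0, 0], [0, c, -s], [0, s, c]])

def drot_x(theta):
    # Derivative of rot_x with respect to theta
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[0.0, 0, 0], [0, -s, -c], [0, c, -s]])

# Data generated as in the snippet above: surface points, rotated and shifted
rng = np.random.default_rng(0)
U = rng.normal(size=(200, 3))
Y = R_AXES * (U / np.linalg.norm(U, axis=1, keepdims=True))  # f(Y) = 0
theta_gt, T_gt = -0.6, np.array([0.2, -0.1, 0.3])
X = Y @ rot_x(-theta_gt).T + T_gt

theta, T, lr = 0.0, np.zeros(3), 1e-2
init_loss = np.mean(f(X) ** 2)                # loss before optimization
for _ in range(5000):
    P = (X - T) @ rot_x(theta).T              # candidate alignment R(theta)(x - T)
    fv, g = f(P), grad_f(P)
    # dL/dtheta = mean 2 f * <grad_f, R'(theta)(x - T)>
    dtheta = np.mean(2 * fv * np.einsum('ij,ij->i', g, (X - T) @ drot_x(theta).T))
    # dL/dT = -R(theta)^T mean 2 f grad_f   (T enters through x - T)
    dT = -rot_x(theta).T @ (2 * fv[:, None] * g).mean(axis=0)
    theta -= lr * dtheta
    T -= lr * dT

final_loss = np.mean(f((X - T) @ rot_x(theta).T) ** 2)
```

With a non-symmetric surface and a moderate ground-truth offset, this plain loop drives the loss toward zero and recovers $T$ (and $\theta$ up to the ellipsoid's half-turn symmetry about the x-axis); note that with a rotationally symmetric surface such as a sphere, $\theta$ would be unidentifiable, which is worth checking for in the learned SDF as well.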