Einstein summation for the matricized tensor times Khatri-Rao product (mttkrp) in tensor decomposition


I am trying to compute the matricized tensor times Khatri-Rao product (mttkrp) as part of an optimization problem for tensor factorization, shown below:

$A = \mathcal{X}_{(1)}(B \otimes C)$

where $\mathcal{X} \in \mathbb{R}^{L \times M \times N}$ is a tensor, $B \in \mathbb{R}^{M \times R}$ and $C \in \mathbb{R}^{N \times R}$ are factor matrices, and $A \in \mathbb{R}^{L \times R}$. The subscript in $\mathcal{X}_{(1)}$ denotes the mode-1 unfolding of the tensor (the matricization in which the mode-1 fibers become the columns of the resulting matrix), and $\otimes$ here denotes the Khatri-Rao (column-wise Kronecker) product, which is more commonly written $\odot$. The mttkrp is the most expensive operation in my problem, so instead of this direct computation an mttkrp kernel is often used, which in index notation reads:
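To make the index ordering concrete, here is a minimal NumPy sketch of the two ingredients. The helper names `khatri_rao` and `unfold_mode1` are hypothetical (they are not the tensorly functions), and the sketch assumes a 3-way array of shape `(L, M, N)`. The mode-1 unfolding is shown in both row-major (`order='C'`) and column-major (`order='F'`) form, because the two differ only in how the columns are ordered:

```python
import numpy as np

def khatri_rao(B, C):
    """Column-wise Kronecker product (hypothetical helper, not a library API).

    Row k*N + l of the result is B[k, :] * C[l, :], i.e. the index of C varies fastest.
    """
    M, R = B.shape
    N, R2 = C.shape
    if R != R2:
        raise ValueError("factor matrices need the same number of columns")
    return np.einsum('kr,lr->klr', B, C).reshape(M * N, R)

def unfold_mode1(X, order='C'):
    """Mode-1 unfolding of a 3-way array X with shape (L, M, N) (hypothetical helper).

    order='C': column k*N + l holds the fiber X[:, k, l] (last index fastest).
    order='F': column k + l*M holds the fiber X[:, k, l] (middle index fastest).
    """
    return X.reshape(X.shape[0], -1, order=order)

rng = np.random.default_rng(0)
B = rng.standard_normal((4, 3))
C = rng.standard_normal((5, 3))
X = rng.standard_normal((2, 4, 5))

K = khatri_rao(B, C)
# Each column of K is the Kronecker product of the matching columns of B and C.
print(all(np.allclose(K[:, r], np.kron(B[:, r], C[:, r])) for r in range(3)))  # True
# The fiber X[:, 1, 3] sits in a different column depending on the unfolding order.
print(np.allclose(unfold_mode1(X, 'C')[:, 1 * 5 + 3], X[:, 1, 3]))  # True
print(np.allclose(unfold_mode1(X, 'F')[:, 1 + 3 * 4], X[:, 1, 3]))  # True
```

Both unfoldings contain the same mode-1 fibers as columns; only their order differs, and the row order of the Khatri-Rao product has to match it.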

$A_{i,j} = \sum_{k,l} \mathcal{X}_{i,k,l} B_{k,j} C_{l,j}$
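For reference, a small NumPy sketch checking that this index-notation kernel, a direct `np.einsum` call, and an explicit unfold-then-multiply computation agree. The explicit route assumes a row-major unfolding (column `k*N + l`) and builds the Khatri-Rao rows as `B[k] * C[l]` with the same ordering; the sizes are arbitrary:

```python
import numpy as np

def mttkrp_loops(X, B, C):
    """A[i, j] = sum_{k,l} X[i, k, l] * B[k, j] * C[l, j], written as explicit loops."""
    L, M, N = X.shape
    R = B.shape[1]
    A = np.zeros((L, R))
    for i in range(L):
        for j in range(R):
            for k in range(M):
                for l in range(N):
                    A[i, j] += X[i, k, l] * B[k, j] * C[l, j]
    return A

def mttkrp_einsum(X, B, C):
    # Same sum in Einstein notation; the (M*N, R) Khatri-Rao matrix is never formed.
    return np.einsum('ikl,kj,lj->ij', X, B, C)

rng = np.random.default_rng(0)
X = rng.standard_normal((3, 4, 5))
B = rng.standard_normal((4, 2))
C = rng.standard_normal((5, 2))

A_loops = mttkrp_loops(X, B, C)
A_einsum = mttkrp_einsum(X, B, C)
# Explicit route: row-major unfolding (column k*N + l) times rows k*N + l = B[k] * C[l].
KR = np.einsum('kj,lj->klj', B, C).reshape(-1, B.shape[1])
A_explicit = X.reshape(X.shape[0], -1) @ KR
print(np.allclose(A_loops, A_einsum), np.allclose(A_einsum, A_explicit))  # True True
```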

What I actually need to compute is the mttkrp with the factor matrices in reverse order:

$D = \mathcal{X}_{(1)}(C \otimes B)$.

where $D \in \mathbb{R}^{L \times R}$. Does anyone know how to formulate this in Einstein summation?
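The two Khatri-Rao products $B \otimes C$ and $C \otimes B$ contain exactly the same entries $B_{k,j} C_{l,j}$, only in a different row order. A minimal NumPy sketch of that fact; the `khatri_rao` helper is hypothetical, not a library function:

```python
import numpy as np

def khatri_rao(P, Q):
    # Hypothetical helper: row p*Q.shape[0] + q equals P[p, :] * Q[q, :].
    return np.einsum('pr,qr->pqr', P, Q).reshape(-1, P.shape[1])

rng = np.random.default_rng(1)
M, N, R = 4, 5, 3
B = rng.standard_normal((M, R))
C = rng.standard_normal((N, R))

BC = khatri_rao(B, C)  # row k*N + l holds B[k] * C[l]
CB = khatri_rao(C, B)  # row l*M + k holds C[l] * B[k]

# Same entries, different row order: swapping the two row indices maps one onto the other.
CB_reordered = CB.reshape(N, M, R).transpose(1, 0, 2).reshape(M * N, R)
print(np.allclose(BC, CB_reordered))  # True
print(np.allclose(BC, CB))            # False for generic B and C
```

So whether $\mathcal{X}_{(1)}(B \otimes C)$ or $\mathcal{X}_{(1)}(C \otimes B)$ equals the index-notation sum depends on which of the two row orders matches the column order of the unfolding.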

EDIT: I have seen it reported several times in the literature and in software documentation (e.g., Kjolstad et al.) that the index-notation expression for $D$ is:

$D_{i,j} = \sum_{k,l} \mathcal{X}_{i,k,l} C_{l,j} B_{k,j}$
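One convention under which this expression holds is the column-major mode-1 unfolding of Kolda and Bader ("Tensor Decompositions and Applications"), where entry $(i,k,l)$ lands in column $k + lM$ (0-based), so the index $k$ varies fastest. A minimal NumPy sketch, assuming that convention (emulated here with `reshape(..., order='F')`) and a hypothetical `khatri_rao` helper:

```python
import numpy as np

def khatri_rao(P, Q):
    # Hypothetical helper: row p*Q.shape[0] + q equals P[p, :] * Q[q, :].
    return np.einsum('pr,qr->pqr', P, Q).reshape(-1, P.shape[1])

rng = np.random.default_rng(2)
L, M, N, R = 6, 4, 5, 3
X = rng.standard_normal((L, M, N))
B = rng.standard_normal((M, R))
C = rng.standard_normal((N, R))

# Column-major mode-1 unfolding: column k + l*M holds the fiber X[:, k, l].
X1_F = X.reshape(L, M * N, order='F')
# Row l*M + k of khatri_rao(C, B) is C[l] * B[k], matching column k + l*M above.
D = X1_F @ khatri_rao(C, B)

D_einsum = np.einsum('ikl,lj,kj->ij', X, C, B)  # D_ij = sum_{k,l} X_ikl C_lj B_kj
print(np.allclose(D, D_einsum))  # True
```

The sums written for $A_{i,j}$ and $D_{i,j}$ are identical up to the order of the scalar factors; what differs between the two matrix expressions is only the column ordering of the unfolding each one assumes.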

However, I am unable to reproduce this in code. The Python code below compares direct matrix multiplication with the corresponding einsum kernels:

    import numpy as np
    from tensorly.tenalg import khatri_rao
    from tensorly.base import unfold
    
    L,M,N,R = 10,20,15,3
    X = np.random.randn(L,M,N)
    B = np.random.randn(M,R)
    C = np.random.randn(N,R)
    
    p1 = np.dot(unfold(X,0),khatri_rao([B,C])) # forward
    p2 = np.dot(unfold(X,0),khatri_rao([C,B])) # reverse
    p3 = np.einsum('ikl,kj,lj->ij',X,B,C) # forward
    print(np.allclose(p1,p3)) # True
    p4 = np.einsum('ilk,kj,lj->ij',X,C,B) # proposed kernel
    print(np.allclose(p2,p4)) # False
    print(np.allclose(p1,p4)) # True
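A possible explanation for these results, sketched with NumPy only (no tensorly). The sketch assumes, consistent with `p1 == p3` above, that `unfold(X, 0)` orders columns row-major (column `k*N + l`) and that `khatri_rao([B, C])` orders rows as `B[k] * C[l]` with `l` varying fastest; the helper `khatri_rao` below is hypothetical and only mimics that ordering. Because `M*N == N*M`, `p2` has compatible shapes but pairs mismatched indices, whereas transposing the last two modes of `X` before unfolding aligns its columns with the rows of $C \otimes B$:

```python
import numpy as np

def khatri_rao(P, Q):
    # Hypothetical helper mimicking the row order implied by p1 == p3:
    # row p*Q.shape[0] + q equals P[p, :] * Q[q, :].
    return np.einsum('pr,qr->pqr', P, Q).reshape(-1, P.shape[1])

rng = np.random.default_rng(3)
L, M, N, R = 10, 20, 15, 3
X = rng.standard_normal((L, M, N))
B = rng.standard_normal((M, R))
C = rng.standard_normal((N, R))

X1_C = X.reshape(L, -1)        # row-major unfolding: column k*N + l
p1 = X1_C @ khatri_rao(B, C)   # rows k*N + l: indices line up
p2 = X1_C @ khatri_rao(C, B)   # rows l*M + k: shapes agree, indices do not

# Reorder the unfolding's columns to l*M + k before multiplying by khatri_rao(C, B).
X1_swapped = X.transpose(0, 2, 1).reshape(L, -1)
p2_fixed = X1_swapped @ khatri_rao(C, B)

p3 = np.einsum('ikl,kj,lj->ij', X, B, C)
print(np.allclose(p1, p3), np.allclose(p2, p3), np.allclose(p2_fixed, p3))  # True False True
```

The same relabelling explains `p1 == p4`: renaming `l` to `k` and `k` to `l` in `'ilk,kj,lj->ij'` with operands `X, C, B` gives `'ikl,lj,kj->ij'` with `X, C, B`, which is the forward kernel with the factor matrices listed in a different order.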