We are building an online service where users can try our model. Our model assigns a different embedding vector to each entity. Users want to query for the entity $t$ that minimizes the scoring function

$$ \| h M_r + r - t M_r \|_2^2 $$

where $h, r, t \in \mathbb{R}^{1024}$ are row vectors and $M_r$ is the projection matrix associated with $r$. In our case $M_r \in \mathbb{R}^{1024 \times 1024}$ is a diagonal matrix.
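For concreteness, here is a minimal NumPy sketch of the scoring function. It assumes $h$, $r$, $t$ are 1-D arrays treated as row vectors. Because $M_r$ is diagonal, multiplying a row vector by it is the same as an elementwise product with its diagonal, here called `m_r`. The function names and the random data are hypothetical stand-ins, not our actual model code.

```python
import numpy as np

def score(h, r, t, M_r):
    # Direct translation of ||h M_r + r - t M_r||_2^2 with a dense (d, d) M_r.
    diff = h @ M_r + r - t @ M_r
    return float(diff @ diff)

def score_diag(h, r, t, m_r):
    # Same score when M_r = diag(m_r): v @ diag(m_r) equals v * m_r elementwise,
    # so only the d diagonal entries need to be stored.
    diff = h * m_r + r - t * m_r
    return float(diff @ diff)

# Random stand-ins for real embeddings (hypothetical data).
rng = np.random.default_rng(0)
d = 1024
h, r, t, m_r = (rng.standard_normal(d) for _ in range(4))
M_r = np.diag(m_r)

print(np.isclose(score(h, r, t, M_r), score_diag(h, r, t, m_r)))  # True
```

Storing only the 1024-entry diagonal avoids keeping a dense 1024×1024 matrix per $M_r$.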
During training, we compute the score for every candidate entity and pick the ones with the lowest scores. This takes about 5 seconds per query on our machine, which is not acceptable for an online service that has to answer each request quickly.
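The brute-force search described above can be sketched in vectorized NumPy as follows. This is a minimal sketch under stated assumptions, not our training code: the entity embeddings are assumed to be stacked in a hypothetical `(N, 1024)` array `entity_emb`, `m_r` is the diagonal of $M_r$, `top_k_entities` is a hypothetical name, and the data is random.

```python
import numpy as np

def top_k_entities(h, r, m_r, entity_emb, k=10):
    """Score every entity with ||h M_r + r - t M_r||_2^2 (M_r = diag(m_r))
    and return the indices and scores of the k lowest-scoring entities."""
    query_part = h * m_r + r            # (d,): terms that do not depend on t
    projected = entity_emb * m_r        # (N, d): t M_r for every entity t
    diff = query_part - projected       # broadcast over the N rows
    scores = np.einsum("nd,nd->n", diff, diff)  # squared L2 norm per row
    k = min(k, len(scores))
    idx = np.argpartition(scores, k - 1)[:k]    # k smallest, unordered
    idx = idx[np.argsort(scores[idx])]          # order them by score
    return idx, scores[idx]

# Hypothetical example: 1,000 random entities with d = 1024.
rng = np.random.default_rng(0)
N, d = 1000, 1024
entity_emb = rng.standard_normal((N, d))
h, r, m_r = (rng.standard_normal(d) for _ in range(3))

idx, s = top_k_entities(h, r, m_r, entity_emb, k=5)
print(idx, s)
```

`np.argpartition` selects the k smallest scores in linear time before sorting only those k. Every request still has to touch all N×1024 values, and that full scan is the cost we want to avoid.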
Tools such as faiss can quickly find, in a large set of vectors, the top $k$ vectors with the smallest L2 distance to a given query vector. Is there a way to use such a tool to speed up our queries?