Most efficient way of determining best output from pairwise neural network

In a recent introductory AI class I took, we were asked to write a system of neural networks to classify numbers from the MNIST database, a database of pictures of handwritten digits commonly used for testing machine learning classification. In our system, there were $\binom{b}{2}$ (with $b = 10$) "pairwise" neural networks, one for each pair of digits.

Let the neural network $N_{i, j} = N_{j, i}$ be the network for testing between digit $i$ and digit $j$. Each network would output a value $x_{i, j}$ drawn from a probability distribution with density $f(x)$ (which is $0$ for $x < 0$ and $x > 1$ and is symmetric about $x = 1/2$). The score for digit $j$ would go up by $x_{i, j}$ and the score for digit $i$ would go up by $1 - x_{i, j}$. The task was then to select as the prediction the digit that gets the highest score.

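Equivalently, the outputs can be centered: let $x_{i, j}$ be drawn from $g(x) = f\left(x+\frac{1}{2}\right)$, which is $0$ outside $\left[-\frac{1}{2}, \frac{1}{2}\right]$ and symmetric about $x = 0$. The score of digit $j$ then goes up by $x_{i, j}$ and the score of digit $i$ goes down by $x_{i, j}$. This is a bit easier to work with, so I'll use this form from here on.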
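
To spell out why the two scoring rules agree, here is a short derivation. The symbol $y_{i, j}$ for the centered output is introduced only for this illustration and is not part of the original setup.

$$
y_{i, j} = x_{i, j} - \tfrac{1}{2}
\quad\Longrightarrow\quad
\underbrace{x_{i, j}}_{\text{gain of } j} = y_{i, j} + \tfrac{1}{2},
\qquad
\underbrace{1 - x_{i, j}}_{\text{gain of } i} = -y_{i, j} + \tfrac{1}{2}.
$$

Each digit takes part in exactly $b - 1$ networks, so once every network has been run, the original score of each digit exceeds its centered score by the same constant $\frac{b - 1}{2}$, and the $\text{argmax}$ is unchanged.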

I am looking for an algorithm that runs the fewest pairwise neural networks on average. The key decision is choosing which two digits to run the next network on (i.e. step $2$ in the following general algorithm).

$1$. Initialize scores to [0, 0, ..., 0, 0] (length $b$).
$2$. Choose two digits somehow.
$3$. Run the neural network between those two digits and update scores.
$4$. If it is certain that a digit is going to win (i.e. it has an insurmountable lead in scores), return that digit and end.
$5$. Repeat steps $2 - 4$ until the process has ended or all networks have been run.
$6$. Return $\text{argmax} ( \text{scores} )$ .
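
The steps above can be sketched in Python as follows. This is only a minimal sketch using the centered scores: `run_network`, `choose_pair` and `is_decided` are hypothetical callables standing in for $N_{i, j}$, step $2$ and step $4$, and the default pair choice (smallest remaining pair) is just a placeholder, not a proposed strategy.

```python
from itertools import combinations

def run_pairwise(b, run_network, choose_pair=None, is_decided=None):
    """Steps 1-6 with centered outputs in [-1/2, 1/2].

    run_network(i, j) -> float is a hypothetical stand-in for N_{i,j}
    (called with i < j); a positive value favours j, a negative one i.
    choose_pair(scores, remaining) -> (i, j) implements step 2.
    is_decided(scores, remaining) -> digit or None implements step 4.
    Returns (predicted digit, number of networks run).
    """
    scores = [0.0] * b                              # step 1
    remaining = set(combinations(range(b), 2))      # unordered pairs, i < j
    runs = 0
    while remaining:                                # step 5
        if choose_pair is None:
            pair = min(remaining)                   # placeholder for step 2
        else:
            pair = choose_pair(scores, remaining)
        i, j = sorted(pair)
        remaining.remove((i, j))
        x = run_network(i, j)                       # step 3
        scores[j] += x
        scores[i] -= x
        runs += 1
        if is_decided is not None:                  # step 4
            winner = is_decided(scores, remaining)
            if winner is not None:
                return winner, runs
    return max(range(b), key=scores.__getitem__), runs   # step 6
```

With no early-stopping rule supplied, the loop always runs all $\binom{b}{2}$ networks, which gives a baseline to compare any choice rule against.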

Step $4$ is pretty easy to check: Let $n$ be the digit with the highest score. Since each remaining network can change a score by at most $\frac{1}{2}$, simply check whether $\text{scores}_n - \frac{k_n}{2} > \max_{i \neq n}\left(\text{scores}_i + \frac{k_i}{2}\right)$, where $i$ ranges over $0, \dots, b-1$ and $k_i$ is the number of networks involving digit $i$ that have not yet been run. The hard part is step $2$: how can the digits be chosen optimally at each iteration?
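
A minimal sketch of this check, assuming centered scores and a collection `remaining` of the unrun pairs (both names are illustrative, not from the class assignment):

```python
def certain_winner(scores, remaining):
    """Return the leading digit if its lead is insurmountable, else None.

    scores: centered scores, one per digit.
    remaining: iterable of pairs (i, j) whose network has not been run.
    """
    b = len(scores)
    k = [0] * b                       # k[i] = unrun networks involving i
    for i, j in remaining:
        k[i] += 1
        k[j] += 1
    n = max(range(b), key=scores.__getitem__)
    # Leader loses 1/2 in each of its remaining networks (worst case);
    # every rival gains 1/2 in each of its remaining networks (best case).
    leader_floor = scores[n] - k[n] / 2
    rival_ceiling = max(scores[i] + k[i] / 2 for i in range(b) if i != n)
    return n if leader_floor > rival_ceiling else None
```

For example, with $b = 3$, if $N_{0, 1}$ and $N_{0, 2}$ both returned $-0.45$, the scores are `[0.9, -0.45, -0.45]` and only the pair `(1, 2)` is left; the function reports digit $0$ as decided without running the last network.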

For $b = 3$, it is simple. Run $N_{0, 1}$ first. Then run the network between digit $2$ and whichever digit "won" $N_{0, 1}$ (i.e. ended up with a score above $0$). Depending on the scores, the process may end here; otherwise the last network has to be run.

One possible approach for general $b$ is to choose the digit $d_1$ that has the highest probability of being the final prediction, among the digits for which not all networks have been run. Then choose the digit $d_2$ with the next-highest such probability for which $N_{d_1, d_2}$ has not been run.
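
A rough sketch of this selection rule is below. Since I don't know how to compute the probability of being selected, the sketch assumes the current centered score can serve as a proxy for it (a higher score is treated as more likely to win); that substitution is an assumption, not something I can justify. Ties are broken toward the lower digit.

```python
def choose_by_score(scores, remaining):
    """Pick the next pair, using the current score as a stand-in for the
    (unknown) probability of being the final prediction.

    scores: centered scores, one per digit.
    remaining: non-empty collection of unrun pairs (i, j).
    Returns the chosen pair as (smaller digit, larger digit).
    """
    partners = {}                     # digit -> digits it has not yet met
    for i, j in remaining:
        partners.setdefault(i, set()).add(j)
        partners.setdefault(j, set()).add(i)
    # d1: best-scoring digit that still has unrun networks.
    d1 = max(partners, key=lambda d: (scores[d], -d))
    # d2: best-scoring digit whose network with d1 has not been run.
    d2 = max(partners[d1], key=lambda d: (scores[d], -d))
    return (min(d1, d2), max(d1, d2))
```

For $b = 3$ this reproduces the procedure described earlier: from all-zero scores it picks `(0, 1)`, and if digit $1$ then wins that network, the next pick is `(1, 2)`.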

I have a couple of questions on this topic:

$1$. Is the above approach optimal? If not, what would be the best algorithm for choosing the digits?
$2$. If it is optimal, how would I calculate the probability of being selected?
$3$. What would be the expected number of networks to run with the optimal approach?