Uniform sampling of points on a simplex


I have this problem: I'm trying to sample points satisfying the relation $$ \sum_{i=1}^N x_i = 1 $$ in the domain where $x_i>0\ \forall i$. Right now I'm just drawing $N$ random numbers $u_i$ from a uniform distribution on $[0,1]$ and then transforming them into $x_i$ using $$ x_i = \frac{u_i}{\sum_{j=1}^N{u_j}}. $$

This works, but it requires $N$ values $u_i$ (three, for $N=3$) to produce one sample of the relation.
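For reference, here is a minimal NumPy sketch of the normalization described above (the helper name `normalize_uniforms` and the seeds are arbitrary choices for illustration). It always satisfies the constraint, but note that it does not give uniformly distributed points: already for $N=2$, $P(x_1 \le 1/4) = P(3u_1 \le u_2) = 1/6$ rather than $1/4$.

```python
import numpy as np


def normalize_uniforms(n, n_samples, rng=None):
    """Normalize n uniform draws per point so that each point sums to 1.

    Hypothetical helper reproducing the construction in the question:
    x_i = u_i / sum_j u_j with u_i ~ U[0, 1).
    """
    rng = np.random.default_rng() if rng is None else rng
    u = rng.random((n_samples, n))           # n draws per sample
    return u / u.sum(axis=1, keepdims=True)  # divide each row by its own sum


x = normalize_uniforms(3, 5, rng=np.random.default_rng(0))
print(x.sum(axis=1))  # every row sums to 1 up to floating-point rounding

# Empirical check of the N = 2 case: the fraction of samples with x_1 <= 1/4
# is close to 1/6, not 1/4, so the points are not uniform on x_1 + x_2 = 1.
x2 = normalize_uniforms(2, 200_000, rng=np.random.default_rng(1))
print((x2[:, 0] <= 0.25).mean())
```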

If $N=2$, the relation simplifies to $x_1 + x_2 = 1$, so I can draw a single $u_1$, set $x_1 = u_1$ and compute $x_2 = 1 - x_1$, using only one draw. But if I try the same with $N=3$, I might get $u_1 = 0.8$ and $u_2 = 0.7$: I can't just use the first relation to compute the remaining $x_i$ (since $u_1 + u_2 > 1$, $x_3 = 1 - u_1 - u_2$ would be negative), and I can't combine it with the normalization either, since I don't know $\sum u_i$.

How can I sample $x_i$ that satisfy the first relation while drawing only $N-1$ values from a random distribution (preferably the uniform one)? And what kind of mathematical problem is this?

PS: I'm not a mathematician and this is my first question here. I've searched for an answer, but I think I haven't defined the problem well enough, so I don't really know what to search for. I feel the solution might be close to this, but I can't figure out how. So I would appreciate any help in better defining the problem from a mathematical point of view, as well as in finding its solution (of course xD).

PS2: help choosing better tags and a better title would also be appreciated.


Answer 1 (score 6)

Mathematically, you are trying to uniformly sample points on a simplex, which is equivalent to sampling from a Dirichlet distribution with all concentration parameters equal to $1$ (a Dirichlet sample already sums to one, so no extra normalization step is needed).
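As a quick illustration of the Dirichlet connection, NumPy's `Generator.dirichlet` can draw such points directly when all concentration parameters are set to 1. This is a minimal sketch; `N`, the seed and the sample size are arbitrary choices:

```python
import numpy as np

rng = np.random.default_rng(42)
N = 3

# Dirichlet(1, ..., 1) is the uniform distribution on {x : sum(x) = 1, x_i >= 0}
x = rng.dirichlet(np.ones(N), size=1000)  # shape (1000, N)

print(x[:3])              # three sample points
print(x.sum(axis=1)[:3])  # rows already sum to 1, no extra normalization
```

Each coordinate of the flat Dirichlet has mean $1/N$, which gives a cheap sanity check on the output.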

Here's how to see all this: take the unit interval and pick $N-1$ points independently and uniformly at random. These points partition the unit interval into $N$ segments, so take your $x_i$ to be the lengths of the segments. To implement this, you draw the $N-1$ points at random and sort them in increasing order; the $k$-th smallest point then follows a $\mathrm{Beta}(k, N-k)$ distribution.
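A minimal sketch of this sorted-points construction, which uses exactly $N-1$ uniform draws per sample, as the question asks (the helper name `sample_simplex_spacings` and the seeds are hypothetical choices for illustration):

```python
import numpy as np


def sample_simplex_spacings(n, rng=None):
    """Uniform point on {x : sum(x) = 1, x_i >= 0} from n - 1 uniform draws.

    Sorts n - 1 uniform points in [0, 1) and returns the lengths of the n
    segments they cut the unit interval into.
    """
    rng = np.random.default_rng() if rng is None else rng
    cuts = np.sort(rng.random(n - 1))             # n - 1 draws, sorted
    edges = np.concatenate(([0.0], cuts, [1.0]))  # 0 = c_0 <= c_1 <= ... <= c_n = 1
    return np.diff(edges)                         # x_i = c_i - c_{i-1}, n segments


x = sample_simplex_spacings(3, rng=np.random.default_rng(0))
print(x, x.sum())  # three non-negative lengths summing to 1
```

For a uniform point on the simplex, $P(x_1 > t) = (1-t)^{N-1}$; here $x_1$ is the smallest of the $N-1$ cuts, which gives exactly that tail.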

If you need to program this, here's one efficient way of doing it without sorting:

1) Generate $N$ independent samples $E_i$ from an exponential distribution by drawing $U_i$ from a uniform distribution on $(0,1]$ and computing the negative logarithm: $E_i:=-\log(U_i)$.

2) Sum all samples to get $S:=\sum_{i=1}^{N}E_i$.

3) The vector $x:=(x_1,\ldots,x_{N})$, where $x_i=E_i/S$, is uniformly distributed on the simplex.
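Steps 1) to 3) can be sketched in NumPy as follows (the helper name `sample_simplex_exponential` is hypothetical; `1 - rng.random()` is used because `rng.random()` returns values in $[0,1)$, so the shifted value lies in $(0,1]$ and the logarithm is always finite):

```python
import numpy as np


def sample_simplex_exponential(n, n_samples=1, rng=None):
    """Uniform points on {x : sum(x) = 1, x_i >= 0} via normalized exponentials.

    Hypothetical helper following steps 1) to 3); returns shape (n_samples, n).
    """
    rng = np.random.default_rng() if rng is None else rng
    # Step 1: U_i in (0, 1], then E_i = -log(U_i) is Exponential(1)
    e = -np.log(1.0 - rng.random((n_samples, n)))
    # Step 2: S = sum_i E_i for each sample (one row per sample)
    s = e.sum(axis=1, keepdims=True)
    # Step 3: x_i = E_i / S
    return e / s


x = sample_simplex_exponential(4, 3, rng=np.random.default_rng(0))
print(x)
print(x.sum(axis=1))  # each row sums to 1 up to rounding
```

NumPy's `rng.standard_exponential(size)` could replace the explicit `-log` step; the version above just mirrors the steps as written.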

Answer 2 (score 0)

There is also an analytic expression you can use to map uniformly distributed points inside a $d$-dimensional unit hypercube to uniformly distributed points inside a simplex of the same dimension with nodes $[\boldsymbol\xi_0, \boldsymbol\xi_1,\cdots,\boldsymbol\xi_{d}]$, where $\boldsymbol\xi_i\in\mathbb{R}^{d}$:

\begin{align}
M_{d}(r_1, \cdots,r_d) = \boldsymbol\xi_0 + \sum_{i=1}^{d}\prod_{j=1}^i r_{d-j+1}^{1/(d-j+1)}(\boldsymbol\xi_i - \boldsymbol\xi_{i-1}),
\end{align}

where the $r_{i}$ are the coordinates of a point inside the unit hypercube, each distributed as $\mathcal{U}[0,1]$; the resulting points $M_d\in\mathbb{R}^d$ are then uniformly distributed inside the simplex. This formula is derived in an appendix of this article: https://doi.org/10.1016/j.jcp.2015.12.034
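To see what the map does, here is the expansion for $d=2$ (a triangle with nodes $\boldsymbol\xi_0,\boldsymbol\xi_1,\boldsymbol\xi_2$); this is a worked check of the formula, not taken from the article:

\begin{align}
M_2(r_1, r_2) &= \boldsymbol\xi_0 + r_2^{1/2}(\boldsymbol\xi_1 - \boldsymbol\xi_0) + r_2^{1/2}\, r_1\,(\boldsymbol\xi_2 - \boldsymbol\xi_1) \\
&= \bigl(1 - \sqrt{r_2}\bigr)\boldsymbol\xi_0 + \sqrt{r_2}\,(1 - r_1)\,\boldsymbol\xi_1 + \sqrt{r_2}\, r_1\, \boldsymbol\xi_2 .
\end{align}

The three weights are non-negative and sum to one, so $M_2$ always lies in the triangle. The square root is what makes the result uniform: $s=\sqrt{r_2}$ has density $2s$ on $[0,1]$, matching the fact that the set of points with a given $s$ is a segment of length proportional to $s$.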

When I derived this formula I wasn't thinking of sampling on the (unit) simplex, but it also works for that case: simply set $r_d=1$ and use the nodes of the unit simplex, $\boldsymbol\xi_0=\mathbf 0$ and $\boldsymbol\xi_i=\mathbf e_i$ (the $i$-th unit vector). Every sample then satisfies $\sum_i x_i = 1$.
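Why $r_d=1$ gives uniform samples on the face (a short check of the formula, not from the article): the $j=1$ factor of every product is $r_d^{1/d}$, which becomes $1$, so

\begin{align}
M_d(r_1,\ldots,r_{d-1},1) = \boldsymbol\xi_1 + \sum_{i=2}^{d}\prod_{j=2}^{i} r_{d-j+1}^{1/(d-j+1)}(\boldsymbol\xi_i - \boldsymbol\xi_{i-1}).
\end{align}

Shifting both $i$ and $j$ down by one shows this is exactly $M_{d-1}(r_1,\ldots,r_{d-1})$ for the $(d-1)$-simplex with nodes $\boldsymbol\xi_1,\ldots,\boldsymbol\xi_d$. For the unit simplex that face is $\{x : \sum_i x_i = 1,\ x_i\ge 0\}$, and only the $d-1$ values $r_1,\ldots,r_{d-1}$ are random, which is the $N-1$ draws the question asks for.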

Below you'll find a Python implementation for sampling inside or on the (unit) simplex, with 2D, 3D, 4D and 100D examples:

```python
import numpy as np
import matplotlib.pyplot as plt
# registers the '3d' projection on matplotlib < 3.2 (automatic on newer versions)
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401


def sample_simplex(xi, n_mc=1, on_simplex=False):
    """
    Use an analytical function to map points in the d-dimensional unit hypercube
    to a d-dimensional simplex with nodes xi.

    Parameters
    ----------
    xi : array of floats, shape (d + 1, d)
        The nodes of the d-dimensional simplex.
    n_mc : int, default is 1
        Number of samples to draw from inside the simplex.
    on_simplex : bool, default is False
        If True, sample on the simplex (the face spanned by xi[1], ..., xi[d])
        rather than inside it.

    Returns
    -------
    P : array, shape (n_mc, d)
        n_mc uniformly distributed points inside the d-dimensional simplex
        with nodes xi.
    """
    xi = np.asarray(xi, dtype=float)  # float nodes, so `sample +=` works for int input
    d = xi.shape[1]
    P = np.zeros([n_mc, d])
    for k in range(n_mc):
        # random point inside the unit hypercube
        r = np.random.rand(d)
        # sample on, instead of inside, the simplex
        # (for the unit simplex, all samples will sum to one)
        if on_simplex:
            r[-1] = 1

        # the first term of the map is \xi_0
        sample = np.copy(xi[0])
        for i in range(1, d + 1):
            prod_r = 1.
            # compute the product of r-terms: prod(r_{d-j+1}^{1/(d-j+1)})
            for j in range(1, i + 1):
                prod_r *= r[d - j]**(1. / (d - j + 1))
            # compute the ith term of the sum: prod_r*(\xi_i-\xi_{i-1})
            sample += prod_r * (xi[i] - xi[i - 1])
        P[k, :] = sample

    return P


plt.close('all')

# triangle (2D simplex) nodes
xi = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
samples = sample_simplex(xi, 1000, on_simplex=True)
fig = plt.figure()
ax = fig.add_subplot(111)
# plot random samples
ax.plot(samples[:, 0], samples[:, 1], '.')
# plot the nodes (vertices) of the simplex
ax.plot(xi[:, 0], xi[:, 1], 'o')
plt.tight_layout()

# 3D simplex
xi = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
samples = sample_simplex(xi, 1000, on_simplex=True)
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
# plot random samples; the marker is passed by keyword because the 4th
# positional argument of Axes3D.scatter is zdir, not a marker
ax.scatter(samples[:, 0], samples[:, 1], samples[:, 2], marker='.')
# plot the nodes (vertices) of the simplex
ax.scatter(xi[:, 0], xi[:, 1], xi[:, 2], marker='o')
plt.tight_layout()

# 4D simplex
xi = np.array([[0, 0, 0, 0], [1.0, 0.0, 0.0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
samples = sample_simplex(xi, 1000, on_simplex=True)
# These should all sum to one
print(np.sum(samples, axis=1))

# 100D simplex
xi = np.zeros([1, 100])
xi = np.append(xi, np.eye(100), axis=0)
samples = sample_simplex(xi, 1000, on_simplex=True)
# These should all sum to one
print(np.sum(samples, axis=1))

plt.show()
```
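The loops above evaluate the map one sample and one term at a time. As an optional sketch, the same formula can be vectorized with `np.cumprod` and a matrix product. `sample_simplex_vectorized` is a hypothetical name, and it uses NumPy's `Generator` API instead of `np.random.rand`, so it reproduces the same distribution but not the same random numbers as the loop version:

```python
import numpy as np


def sample_simplex_vectorized(xi, n_mc=1, on_simplex=False, rng=None):
    """Vectorized sketch of the map M_d (hypothetical helper).

    xi : array, shape (d + 1, d), the nodes xi_0, ..., xi_d of the simplex.
    Returns an array of shape (n_mc, d).
    """
    xi = np.asarray(xi, dtype=float)
    d = xi.shape[1]
    rng = np.random.default_rng() if rng is None else rng
    r = rng.random((n_mc, d))             # column m holds r_{m+1}
    if on_simplex:
        r[:, -1] = 1.0                    # r_d = 1: sample on the face xi_1, ..., xi_d

    k = np.arange(d, 0, -1)               # k = d - j + 1 for j = 1, ..., d
    factors = r[:, k - 1] ** (1.0 / k)    # column j-1: r_{d-j+1}^{1/(d-j+1)}
    prods = np.cumprod(factors, axis=1)   # column i-1: product over j = 1..i
    steps = np.diff(xi, axis=0)           # row i-1: xi_i - xi_{i-1}
    return xi[0] + prods @ steps          # xi_0 + sum_i prod_i * (xi_i - xi_{i-1})


rng = np.random.default_rng(0)
xi = np.vstack([np.zeros((1, 100)), np.eye(100)])  # nodes of the unit 100D simplex
samples = sample_simplex_vectorized(xi, 1000, on_simplex=True, rng=rng)
print(np.abs(samples.sum(axis=1) - 1).max())  # deviation from 1 is only rounding error
```

The triple loop becomes two array operations, so the 100D example no longer needs on the order of $d^2$ Python-level multiplications per sample.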