Sinkhorn and optimal transport


*************************************** EDIT *********************************************

I'm trying to code the Sinkhorn algorithm. In particular, I want to see whether I can recover the optimal transport plan between two measures as the strength of the entropic regularization goes to 0.
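
For reference, this is the formulation I have in mind, written with the names used in my code (a sketch assuming the standard entropic formulation; $\lambda$ stands for `lamb`, $K$ for `GammaB`, $\oslash$ for entrywise division, and $\Pi(\mu,\nu)$ for the set of plans with marginals $\mu$ and $\nu$):

$$
\gamma_\lambda = \operatorname*{arg\,min}_{\gamma \in \Pi(\mu,\nu)} \; \sum_{i,j} C_{ij}\,\gamma_{ij} + \frac{1}{\lambda} \sum_{i,j} \gamma_{ij}\left(\log \gamma_{ij} - 1\right),
\qquad
\gamma_\lambda = \operatorname{diag}(a)\, K \operatorname{diag}(b), \quad K_{ij} = e^{-\lambda C_{ij}},
$$

$$
a \leftarrow \mu \oslash (K b), \qquad b \leftarrow \nu \oslash (K^{\top} a).
$$

With this convention the regularization strength is $1/\lambda$, so the unregularized problem corresponds to $\lambda \to +\infty$.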

For example, let's transport the uniform measure $U$ on $[0,1]$ to the uniform measure $V$ on $[1,2]$. The optimal plan for the quadratic cost is $(x \mapsto (x,x+1))_{\#} U$.

Let's discretize $[0,1]$, $[1,2]$ and the measures $U$ and $V$. With Sinkhorn I expect to get a measure whose support lies on the graph of the line $y = x+1$. That is not what I get, so I'm trying to find out what the problem is. Here are my code and my results; maybe someone can help me.

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older matplotlib

#Parameters
N = 10                        # Number of grid points in the discretization of [0,1]
Niter = 10**5                 # Number of Sinkhorn iterations

def Sinkhorn(C, mu, nu, lamb):
    # lamb : inverse of the strength of the entropic regularization
    #        (the unregularized problem corresponds to lamb -> +infinity)

    # Initialization of the scaling vectors
    a2 = np.ones(N)
    b2 = np.ones(N)

    # Gibbs kernel
    GammaB = np.exp(-lamb*C)

    # Sinkhorn: alternately rescale to match the two marginals
    for _ in range(Niter):
        a2 = mu/(np.dot(GammaB, b2))
        b2 = nu/(np.dot(GammaB.T, a2))

    # Compute gamma_star[i, j] = a2[i]*b2[j]*GammaB[i, j]
    Gamma = np.outer(a2, b2)*GammaB
    Gamma /= Gamma.sum()

    return Gamma

## Test between uniform([0,1]) and uniform([1,2])

S = np.linspace(0, 1, N, False)  # discretization of [0,1]
T = np.linspace(1, 2, N, False)  # discretization of [1,2]

# Discretization of uniform([0,1])
U01 = np.ones(N)
Mass = np.sum(U01)
U01 = U01/Mass

# Discretization of uniform([1,2])
U12 = np.ones(N)
Mass = np.sum(U12)
U12 = U12/Mass

# Cost function
X, Y = np.meshgrid(S, T, indexing='ij')  # X[i,j] = S[i], Y[i,j] = T[j]
C = (X-Y)**2               # Matrix of c[i,j] = (xi-yj)²

def plot_Sinkhorn_U01_U12():

    # Plot the optimal measure for increasing lamb (3D scatter)
    fig = plt.figure()
    for i in range(4):
        ax = fig.add_subplot(2, 2, i+1, projection='3d')
        Gamma_star = Sinkhorn(C, U01, U12, 5**i)
        ax.scatter(X, Y, Gamma_star, c=Gamma_star.ravel(), cmap='viridis', linewidth=0.5)
        plt.title("Gamma_bar({}) between uniform([0,1]) and uniform([1,2])".format(5**i))
    plt.show()

    # Same measures shown as images
    plt.figure()
    for i in range(4):
        plt.subplot(2, 2, i+1)
        Gamma_star = Sinkhorn(C, U01, U12, 5**i)
        plt.imshow(Gamma_star, interpolation='none')
        plt.title("Gamma_bar({}) between uniform([0,1]) and uniform([1,2])".format(5**i))

    plt.show()

# The transport between U01 and U12 is x -> x+1, so the support of gamma^* is contained in the graph of x -> x+1, which is the line y = x+1

plot_Sinkhorn_U01_U12()

This is what I get:

[Figure: 3D scatter plots of Gamma_bar(lamb) for lamb = 1, 5, 25, 125]

[Figure: image plots of Gamma_bar(lamb) for lamb = 1, 5, 25, 125]

Here is the matrix I get for gamma_star(125):

[[0.08 0.02 0.   0.   0.   0.   0.   0.   0.   0.  ]
 [0.02 0.06 0.02 0.   0.   0.   0.   0.   0.   0.  ]
 [0.   0.02 0.06 0.02 0.   0.   0.   0.   0.   0.  ]
 [0.   0.   0.02 0.06 0.02 0.   0.   0.   0.   0.  ]
 [0.   0.   0.   0.02 0.06 0.02 0.   0.   0.   0.  ]
 [0.   0.   0.   0.   0.02 0.06 0.02 0.   0.   0.  ]
 [0.   0.   0.   0.   0.   0.02 0.06 0.02 0.   0.  ]
 [0.   0.   0.   0.   0.   0.   0.02 0.06 0.02 0.  ]
 [0.   0.   0.   0.   0.   0.   0.   0.02 0.06 0.02]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.02 0.08]]

It's closer to what I expect, which is a diagonal matrix.
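
A possible reading of the remaining blur (a sketch, assuming the grids $x_i = i/N$ and $y_j = 1 + j/N$ from the code and the cost oriented as $C_{ij} = (x_i - y_j)^2$, with $\lambda$ standing for `lamb`). Since $x_i - y_j = (i-j)/N - 1$, the kernel factorizes as

$$
e^{-\lambda (x_i - y_j)^2} = e^{-\lambda}\; e^{2\lambda i/N}\; e^{-2\lambda j/N}\; e^{-\lambda (i-j)^2/N^2}.
$$

The first three factors depend on neither index, on $i$ alone, or on $j$ alone, so they can be absorbed into the scaling vectors. The plan is then a diagonal rescaling of the Gaussian band $e^{-\lambda (i-j)^2/N^2}$. For $N = 10$ and $\lambda = 125$, a first neighbor of the diagonal has kernel weight $e^{-1.25} \approx 0.29$ relative to the diagonal, which is roughly consistent with the $0.02$ next to $0.06$ in the matrix above. On this reading, the plan should only look diagonal on this grid once $\lambda$ is well above $N^2$.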

Another problem, maybe the main one, is that I can't push this code very far. For example, if I try to compute gamma_star(300), here is the result:

Gamma_star = [[nan nan nan nan nan nan nan nan nan nan]
 [nan nan nan nan nan nan nan nan nan nan]
 [nan nan nan nan nan nan nan nan nan nan]
 [nan nan nan nan nan nan nan nan nan nan]
 [nan nan nan nan nan nan nan nan nan nan]
 [nan nan nan nan nan nan nan nan nan nan]
 [nan nan nan nan nan nan nan nan nan nan]
 [nan nan nan nan nan nan nan nan nan nan]
 [nan nan nan nan nan nan nan nan nan nan]
 [nan nan nan nan nan nan nan nan nan nan]]

This comes from the fact that the entries of $a_{2}$ and $b_{2}$ blow up towards $+\infty$:

b2 : [4.22070260e+171 3.28948016e+146 8.44373864e+121 7.78758450e+097
 5.12009123e+074 6.52968858e+051 8.65981779e+028 5.72200901e+005
 4.20378081e-019 5.74237308e-044] with norm: inf
a2 = [4.58777657e-043 3.29463705e-018 4.23493266e+006 5.99938220e+029
 4.32904674e+052 3.25254593e+075 4.66447084e+098 4.81016120e+122
 1.83818231e+147 2.33362605e+172]
b2 : [6.07800116e+171 4.63605357e+146 1.14536591e+122 9.62865515e+097
 5.61133076e+074 6.57044337e+051 7.97952189e+028 4.60805921e+005
 3.04310458e-019 4.00061372e-044] with norm: inf
a2 = [3.18546735e-043 2.37716997e-018 3.41731335e+006 5.52999193e+029
 4.35727949e+052 3.56568492e+075 5.75762196e+098 6.54678638e+122
 2.59987570e+147 3.37185830e+172]
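
Looking at these values, the largest entries of a2 and b2 multiply to about 1e344, which is above the float64 maximum (about 1.8e308) and so becomes inf. The matching kernel entry exp(-300*3.61) underflows to 0, and inf*0 gives nan, which the final normalization spreads over the whole matrix. One commonly suggested workaround is to iterate on log(a) and log(b) with a log-sum-exp instead of on a and b. Below is a minimal sketch under that assumption; the names `logsumexp`, `sinkhorn_log` and `n_iter` are mine, chosen for illustration, and the cost is built so that `C[i, j] = (S[i] - T[j])**2`.

```python
import numpy as np

def logsumexp(M, axis):
    """Stable log(sum(exp(M))) along one axis (hypothetical helper)."""
    m = np.max(M, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(M - m), axis=axis))

def sinkhorn_log(C, mu, nu, lamb, n_iter=1000):
    """Sinkhorn iterations on f = log(a) and g = log(b) instead of a and b."""
    log_mu, log_nu = np.log(mu), np.log(nu)
    f = np.zeros_like(mu)
    g = np.zeros_like(nu)
    for _ in range(n_iter):
        # a = mu / (K b)    <=>  f[i] = log(mu[i]) - logsumexp_j(-lamb*C[i, j] + g[j])
        f = log_mu - logsumexp(-lamb * C + g[None, :], axis=1)
        # b = nu / (K^T a)  <=>  g[j] = log(nu[j]) - logsumexp_i(-lamb*C[i, j] + f[i])
        g = log_nu - logsumexp(-lamb * C + f[:, None], axis=0)
    # gamma[i, j] = a[i] * K[i, j] * b[j], assembled inside a single exponential
    return np.exp(f[:, None] - lamb * C + g[None, :])

N = 10
S = np.linspace(0, 1, N, False)      # grid on [0,1]
T = np.linspace(1, 2, N, False)      # grid on [1,2]
C = (S[:, None] - T[None, :]) ** 2   # C[i, j] = (S[i] - T[j])**2
mu = np.ones(N) / N
nu = np.ones(N) / N

Gamma = sinkhorn_log(C, mu, nu, lamb=300, n_iter=10**4)
print(np.round(Gamma, 2))
# The column marginal is matched by construction; the row marginal error
# indicates whether n_iter was large enough for this lamb.
print("row marginal error:", np.abs(Gamma.sum(axis=1) - mu).max())
```

The plan is never formed from the separate factors a, b and K, so the huge and tiny numbers cancel inside the exponent. This only addresses the overflow: for large lamb the iterations may still need many steps, so the printed row marginal error should be checked before trusting the result.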

Thanks and regards.