Find the maximum-entropy Dirichlet distribution with a given mean

Given a categorical distribution p, any Dirichlet distribution Dir(k*p) with k > 0 has mean p. I want to find the k that maximizes the (differential) entropy of Dir(k*p). The following code does this with an iterative search:

import numpy as np
from scipy.optimize import minimize
from scipy.special import gammaln, psi

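def dirichlet_entropy(alpha):
    # Differential entropy of Dir(alpha):
    #   log B(alpha) + (alpha_0 - K) * psi(alpha_0) - sum_i (alpha_i - 1) * psi(alpha_i)
    alpha = np.asarray(alpha, dtype=float)
    K = len(alpha)  # number of categories (not the scale k from the text)
    alpha_0 = np.sum(alpha)
    log_beta = np.sum(gammaln(alpha)) - gammaln(alpha_0)
    return log_beta + (alpha_0 - K) * psi(alpha_0) - np.sum((alpha - 1) * psi(alpha))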

def optimize_alpha(p):
    p = np.asarray(p, dtype=float)  # allow plain lists such as [0.5, 0.3, 0.2]

    # Objective to be minimized: the negative Dirichlet entropy for alpha = s * p
    def objective(s, p):
        alpha = s[0] * p
        return -dirichlet_entropy(alpha)

    # Initial guess for the scale s
    s0 = [1.0]

    # Call the optimizer; args must be a tuple, and the bound keeps s strictly positive
    result = minimize(objective, s0, args=(p,), method='L-BFGS-B',
                      bounds=[(1e-8, None)], tol=1e-6)
    s_opt = result.x[0]

    # Return the optimal alpha
    return s_opt * p

For example, optimize_alpha([1/3, 1/3, 1/3]) returns approximately [1, 1, 1] (the uniform distribution on the simplex), and optimize_alpha([0.5, 0.3, 0.2]) returns approximately [1.86442577, 1.11865546, 0.74577031].
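As a sanity check on these numbers, the hand-written entropy formula can be compared against scipy.stats.dirichlet.entropy, and the optimum can be located roughly by scanning a grid of scales. This is only a sketch: the names entropy_by_formula and entropy_on_grid and the grid from 0.5 to 10 in steps of 0.1 are my own choices for illustration, not part of the search code above.

```python
import numpy as np
from scipy.special import gammaln, psi
from scipy.stats import dirichlet


def entropy_by_formula(alpha):
    # Closed-form differential entropy of Dir(alpha), written out by hand.
    alpha = np.asarray(alpha, dtype=float)
    alpha_0 = alpha.sum()
    log_beta = gammaln(alpha).sum() - gammaln(alpha_0)
    return (log_beta + (alpha_0 - alpha.size) * psi(alpha_0)
            - np.sum((alpha - 1) * psi(alpha)))


def entropy_on_grid(p, scales):
    # Entropy of Dir(k * p) for each k in scales, using SciPy's implementation.
    p = np.asarray(p, dtype=float)
    return np.array([dirichlet.entropy(k * p) for k in scales])


p = np.array([0.5, 0.3, 0.2])
scales = np.linspace(0.5, 10.0, 96)  # assumed grid, step 0.1
values = entropy_on_grid(p, scales)
best_k = scales[np.argmax(values)]   # grid point with the largest entropy
print(best_k, best_k * p)
```

The grid maximum should sit next to the scale implied by the optimizer output above (the three returned values divided by p all give roughly 3.73), which suggests the iterative search is finding the right point.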

Is there some clever way to do this analytically, or at least approximate it?
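
One partial simplification, offered as a sketch and not as a closed form: differentiating the entropy of Dir(k*p) with respect to k, the digamma terms cancel and only trigamma terms remain,

    dH/dk = (k - K) * psi'(k) - sum_i p_i * (k*p_i - 1) * psi'(k*p_i)

where K is the number of categories and psi' is the trigamma function. For uniform p this vanishes at k = K, matching the [1, 1, 1] result. So the problem reduces to one-dimensional root finding. The code below solves it with scipy.optimize.brentq; the function names and the bracket [1e-3, 1e3] are assumptions of mine, and I have not shown that the root is unique for general p.

```python
import numpy as np
from scipy.optimize import brentq
from scipy.special import polygamma


def entropy_slope(k, p):
    # d/dk of the entropy of Dir(k * p); polygamma(1, x) is the trigamma function.
    alpha = k * p
    slope = ((k - p.size) * polygamma(1, k)
             - np.sum(p * (alpha - 1.0) * polygamma(1, alpha)))
    return float(slope)


def stationary_scale(p, lo=1e-3, hi=1e3):
    # lo and hi are an assumed bracket; brentq raises ValueError if the
    # slope does not change sign between them.
    p = np.asarray(p, dtype=float)
    return brentq(entropy_slope, lo, hi, args=(p,))


p = np.array([0.5, 0.3, 0.2])
k_opt = stationary_scale(p)
print(k_opt, k_opt * p)
```

For p = [0.5, 0.3, 0.2] this should land on the same scale as the iterative search (about 3.73), so it is a faster numerical route, but it still does not give k in closed form.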