Simplification of differentiation of multiple variables with summations


I've attempted to differentiate a function to verify a result from some research. The function is based on Student's $t$-distribution, which assigns soft memberships of data points $z_i$ to cluster centroids $\mu_j$. The final formula measures the difference between the current assignment $q$ and an improved clustering $p$, which is obtained by repeatedly normalising the membership values along two axes:

\begin{align} \hat{q}_{ij} &= (1 + ||z_i-\mu_j||^2 / \nu) ^ {-\frac{\nu+1}{2}} \\ q_{ij} &= \hat{q}_{ij}/\sum_{j'}\hat{q}_{ij'} \\ \hat{p}_{ij} &= q_{ij}^2 / \sum_{i'} q_{i'j} \\ p_{ij} &= \hat{p}_{ij}/\sum_{j'}\hat{p}_{ij'} \\ \text{KL}(P||Q) &= \sum_{i,j}p_{ij}\log\frac{p_{ij}}{q_{ij}} \end{align}
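To make the pipeline concrete, here is a minimal pure-Python sketch of the forward computation, evaluating the five equations in order. The toy points, centroids and $\nu = 1$ are hypothetical values chosen only for illustration, and `forward`/`sq_dist` are illustrative helper names.

```python
import math

def sq_dist(a, b):
    """Squared Euclidean distance ||a - b||^2 between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def forward(z, mu, nu):
    """Return (q, p, KL(P||Q)) for points z and centroids mu, following the equations above."""
    n, k = len(z), len(mu)
    # \hat{q}_{ij}: Student's t kernel
    qhat = [[(1 + sq_dist(z[i], mu[j]) / nu) ** (-(nu + 1) / 2) for j in range(k)]
            for i in range(n)]
    # q_{ij}: normalise each row (sum over j')
    q = [[v / sum(row) for v in row] for row in qhat]
    # \hat{p}_{ij}: square q and divide by the column sum over i'
    col = [sum(q[i][j] for i in range(n)) for j in range(k)]
    phat = [[q[i][j] ** 2 / col[j] for j in range(k)] for i in range(n)]
    # p_{ij}: normalise each row again
    p = [[v / sum(row) for v in row] for row in phat]
    kl = sum(p[i][j] * math.log(p[i][j] / q[i][j]) for i in range(n) for j in range(k))
    return q, p, kl

# Hypothetical toy data: three 2-D points, two centroids, nu = 1
z = [[0.0, 0.0], [1.0, 0.5], [3.0, 3.0]]
mu = [[0.0, 0.5], [2.5, 3.0]]
q, p, kl_value = forward(z, mu, nu=1.0)
print(q, p, kl_value)
```

Every row of both $q$ and $p$ sums to one, so the KL term is a sum of per-row divergences and is non-negative.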

I'm mainly interested in the partial derivatives with respect to $z_i$ and $\mu_j$, because for practical purposes $\nu$ is constant. Now, it looks like the partials will be identical except for the sign (and the index being summed over), so differentiating with respect to one of them should be enough.

The authors give shockingly simple results for the derivatives:

\begin{align} \frac{\partial \text{KL}}{\partial z_i} &= \frac{\nu + 1}{\nu}\sum_j(1 + ||z_i-\mu_j||^2 / \nu)^{-1}(p_{ij} - q_{ij})(z_i - \mu_j) \\ \frac{\partial \text{KL}}{\partial \mu_j} &= -\frac{\nu + 1}{\nu}\sum_i(1 + ||z_i-\mu_j||^2 / \nu)^{-1}(p_{ij} - q_{ij})(z_i - \mu_j) \end{align}
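One way to probe these formulas is to compare them with central finite differences of KL on toy data. The sketch below rests on one loudly stated assumption: $p$ is computed once from the current $q$ and then held constant while differentiating (`p_fixed`), so only $q$ responds to the perturbation. Under that assumption the closed forms and the finite differences agree; the derivation further down instead differentiates through $p$ as well. The data and all helper names are hypothetical.

```python
import math

def soft_assign(z, mu, nu):
    # q_ij: Student's t kernel, then row normalisation
    qhat = [[(1 + sum((a - b) ** 2 for a, b in zip(zi, mj)) / nu) ** (-(nu + 1) / 2)
             for mj in mu] for zi in z]
    return [[v / sum(row) for v in row] for row in qhat]

def target(q):
    # p_ij: square q, divide by column sums, normalise rows
    col = [sum(row[j] for row in q) for j in range(len(q[0]))]
    phat = [[v * v / col[j] for j, v in enumerate(row)] for row in q]
    return [[v / sum(row) for v in row] for row in phat]

def kl(p, q):
    return sum(a * math.log(a / b) for pr, qr in zip(p, q) for a, b in zip(pr, qr))

def grad_z(z, mu, nu, p, q, i):
    # (nu+1)/nu * sum_j N_ij^{-1} (p_ij - q_ij)(z_i - mu_j)
    g = [0.0] * len(z[i])
    for j, mj in enumerate(mu):
        n_ij = 1 + sum((a - b) ** 2 for a, b in zip(z[i], mj)) / nu
        coef = (nu + 1) / nu / n_ij * (p[i][j] - q[i][j])
        for d in range(len(g)):
            g[d] += coef * (z[i][d] - mj[d])
    return g

def grad_mu(z, mu, nu, p, q, j):
    # -(nu+1)/nu * sum_i N_ij^{-1} (p_ij - q_ij)(z_i - mu_j)
    g = [0.0] * len(mu[j])
    for i, zi in enumerate(z):
        n_ij = 1 + sum((a - b) ** 2 for a, b in zip(zi, mu[j])) / nu
        coef = -(nu + 1) / nu / n_ij * (p[i][j] - q[i][j])
        for d in range(len(g)):
            g[d] += coef * (zi[d] - mu[j][d])
    return g

def fd_grad(f, x, h=1e-6):
    # central differences of f() w.r.t. the list x, perturbed in place and then restored
    g = []
    for d in range(len(x)):
        old = x[d]
        x[d] = old + h
        fp = f()
        x[d] = old - h
        fm = f()
        x[d] = old
        g.append((fp - fm) / (2 * h))
    return g

z = [[0.0, 0.0], [1.0, 0.5], [3.0, 3.0]]   # hypothetical points
mu = [[0.0, 0.5], [2.5, 3.0]]              # hypothetical centroids
nu = 1.0
q = soft_assign(z, mu, nu)
p_fixed = target(q)  # ASSUMPTION: p is frozen at the current q

def loss():
    # KL(P||Q) with P frozen; only Q is recomputed from the current z and mu
    return kl(p_fixed, soft_assign(z, mu, nu))

gz_fd, gz_cf = fd_grad(loss, z[1]), grad_z(z, mu, nu, p_fixed, q, 1)
gmu_fd, gmu_cf = fd_grad(loss, mu[0]), grad_mu(z, mu, nu, p_fixed, q, 0)
print(gz_fd, gz_cf)
print(gmu_fd, gmu_cf)
```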

Me, I haven't been so lucky. Applying the chain rule to $\hat{q}_{ij}$ was simple enough:

$$\frac{\partial \hat{q}_{ij}}{\partial z_i} = -\frac{\nu+1}{\nu} (1 + ||z_i - \mu_j||^2 / \nu) ^ {-\frac{\nu+3}{2}} \cdot (z_i - \mu_j)$$
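A single step like this can be sanity-checked against central finite differences. The sketch below uses hypothetical values for $z_i$, $\mu_j$ and $\nu$; `qhat` and `dqhat_dz` are illustrative helper names, not from the paper.

```python
def qhat(zi, mj, nu):
    # \hat{q}_{ij} = (1 + ||z_i - mu_j||^2 / nu)^{-(nu+1)/2}
    return (1 + sum((a - b) ** 2 for a, b in zip(zi, mj)) / nu) ** (-(nu + 1) / 2)

def dqhat_dz(zi, mj, nu):
    # -(nu+1)/nu * N^{-(nu+3)/2} * (z_i - mu_j), with N = 1 + ||z_i - mu_j||^2 / nu
    n = 1 + sum((a - b) ** 2 for a, b in zip(zi, mj)) / nu
    return [-(nu + 1) / nu * n ** (-(nu + 3) / 2) * (a - b) for a, b in zip(zi, mj)]

zi, mj, nu, h = [1.0, 0.5], [2.5, 3.0], 2.0, 1e-6  # hypothetical values
fd = []
for d in range(len(zi)):
    up = list(zi)
    up[d] += h
    dn = list(zi)
    dn[d] -= h
    fd.append((qhat(up, mj, nu) - qhat(dn, mj, nu)) / (2 * h))
analytic = dqhat_dz(zi, mj, nu)
print(fd, analytic)
```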

Likewise, differentiating each of the later formulas on its own gives (I think):

\begin{align}
\frac{\partial q_{ij}}{\partial z_i} &= \frac{\frac{\partial \hat{q}_{ij}}{\partial z_i} \cdot \sum_{j'}\hat{q}_{ij'} - \hat{q}_{ij} \cdot \sum_{j'}\frac{\partial \hat{q}_{ij'}}{\partial z_i}}{(\sum_{j'}\hat{q}_{ij'})^2} \\
% p-hat
\frac{\partial \hat{p}_{ij}}{\partial z_i} &= \frac{2q_{ij}\frac{\partial q_{ij}}{\partial z_i} \cdot \sum_{i'} q_{i'j} - q_{ij}^2 \cdot \frac{\partial q_{ij}}{\partial z_i}}{(\sum_{i'}q_{i'j})^2} \\
% p
\frac{\partial p_{ij}}{\partial z_i} &= \frac{\frac{\partial \hat{p}_{ij}}{\partial z_i} \cdot \sum_{j'}\hat{p}_{ij'} - \hat{p}_{ij} \cdot \sum_{j'}\frac{\partial \hat{p}_{ij'}}{\partial z_i}}{(\sum_{j'}\hat{p}_{ij'})^2} \\
% KL (one term of the sum over i, j)
\frac{\partial}{\partial z_i}\left(p_{ij}\log\frac{p_{ij}}{q_{ij}}\right) &= \frac{\partial p_{ij}}{\partial z_i}\log\frac{p_{ij}}{q_{ij}} + p_{ij}\frac{\partial}{\partial z_i}\log\frac{p_{ij}}{q_{ij}} \\
&= \frac{\partial p_{ij}}{\partial z_i}\log\frac{p_{ij}}{q_{ij}} + p_{ij}\frac{q_{ij}}{p_{ij}}\frac{\partial}{\partial z_i}\frac{p_{ij}}{q_{ij}} \\
&= \frac{\partial p_{ij}}{\partial z_i}\log\frac{p_{ij}}{q_{ij}} + p_{ij}\frac{q_{ij}}{p_{ij}}\left(\frac{\partial p_{ij}}{\partial z_i}q_{ij} - p_{ij}\frac{\partial q_{ij}}{\partial z_i}\right) / q_{ij}^2 \\
&= \frac{\partial p_{ij}}{\partial z_i}\log\frac{p_{ij}}{q_{ij}} + \left(\frac{\partial p_{ij}}{\partial z_i}q_{ij} - p_{ij}\frac{\partial q_{ij}}{\partial z_i}\right) / q_{ij}
\end{align}
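Before substituting by hand, the chain can at least be checked numerically. This minimal sketch writes it as a directional derivative: it perturbs $z_i$ along a direction `v` (a hypothetical choice), pushes the perturbation through each normalisation with the quotient rules above, and sums the per-term KL derivative over all $i, j$. One detail worth stating: because $\sum_{i'} q_{i'j}$ depends on $z_i$, the code applies the $\hat{p}$ quotient rule to every row, not only row $i$. The result is compared with central finite differences of the full KL, with $p$ recomputed at each perturbed point. Data and helper names are hypothetical.

```python
import math

def forward(z, mu, nu):
    n, k = len(z), len(mu)
    N = [[1 + sum((a - b) ** 2 for a, b in zip(z[i], mu[j])) / nu for j in range(k)]
         for i in range(n)]
    qhat = [[N[i][j] ** (-(nu + 1) / 2) for j in range(k)] for i in range(n)]
    q = [[v / sum(r) for v in r] for r in qhat]
    col = [sum(q[i][j] for i in range(n)) for j in range(k)]
    phat = [[q[i][j] ** 2 / col[j] for j in range(k)] for i in range(n)]
    p = [[v / sum(r) for v in r] for r in phat]
    return N, qhat, q, col, phat, p

def kl(p, q):
    return sum(a * math.log(a / b) for pr, qr in zip(p, q) for a, b in zip(pr, qr))

def row_quotient(top, dtop):
    # derivative of top[j] / sum(top) for one row, given the row's derivatives dtop
    s, ds = sum(top), sum(dtop)
    return [(dt * s - t * ds) / s ** 2 for t, dt in zip(top, dtop)]

def directional_kl(z, mu, nu, i, v):
    """d KL / d z_i along direction v, pushed through each normalisation step."""
    N, qhat, q, col, phat, p = forward(z, mu, nu)
    n, k = len(z), len(mu)
    # d qhat: only row i depends on z_i
    dqhat = [[0.0] * k for _ in range(n)]
    for j in range(k):
        dqhat[i][j] = sum(-(nu + 1) / nu * N[i][j] ** (-(nu + 3) / 2) * (a - b) * c
                          for a, b, c in zip(z[i], mu[j], v))
    dq = [row_quotient(qhat[r], dqhat[r]) for r in range(n)]
    # phat = q^2 / sum_i' q_i'j; the column sum couples every row r to z_i
    dcol = [sum(dq[r][j] for r in range(n)) for j in range(k)]
    dphat = [[(2 * q[r][j] * dq[r][j] * col[j] - q[r][j] ** 2 * dcol[j]) / col[j] ** 2
              for j in range(k)] for r in range(n)]
    dp = [row_quotient(phat[r], dphat[r]) for r in range(n)]
    # last line of the derivation, summed over every term of KL
    return sum(dp[r][j] * math.log(p[r][j] / q[r][j]) + dp[r][j] - p[r][j] / q[r][j] * dq[r][j]
               for r in range(n) for j in range(k))

z = [[0.0, 0.0], [1.0, 0.5], [3.0, 3.0]]
mu = [[0.0, 0.5], [2.5, 3.0]]
nu, i, v, h = 1.0, 1, [0.6, -0.8], 1e-6

analytic = directional_kl(z, mu, nu, i, v)
zp = [list(r) for r in z]
zm = [list(r) for r in z]
zp[i] = [a + h * c for a, c in zip(z[i], v)]
zm[i] = [a - h * c for a, c in zip(z[i], v)]
_, _, qp, _, _, pp = forward(zp, mu, nu)
_, _, qm, _, _, pm = forward(zm, mu, nu)
numeric = (kl(pp, qp) - kl(pm, qm)) / (2 * h)
print(analytic, numeric)
```

Agreement only says the individual steps are consistent; it does not by itself suggest a simpler closed form.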

But this is where my trust broke down. I see no immediate way of simplifying the expression, and substituting everything back until I reach the original variables seems like a huge waste of paper. $\partial \hat{q}_{ij}$ is quite close to the provided formulas, but $\partial \text{KL}$ still contains a logarithmic term.

Am I missing some important concept for differentiating summations or normalisations? I know that when only one term of a sum depends on the variable, the derivative of the sum reduces to the derivative of that term, but I see no such easy reductions here beyond the sum over $i'$ in $\partial \hat{p}_{ij}$, which I have already reduced. I did find a couple of common factors when substituting into $\partial q_{ij}$, but no real simplification.

\begin{align} N_{ij} &= 1 + ||z_i - \mu_j||^2/\nu \\ \frac{\partial q_{ij}}{\partial z_i} &= \frac{-\frac{\nu+1}{\nu}N_{ij}^{-\frac{\nu+1}{2}}\cdot\left((z_i-\mu_j)\cdot\sum_{j'}N_{ij'}^{-\frac{\nu+1}{2}}\cdot N_{ij}^{-1} + \sum_{j'}\left(-(z_i-\mu_{j'})N_{ij'}^{-\frac{\nu+3}{2}}\right)\right)}{\left(\sum_{j'}N_{ij'}^{-\frac{\nu+1}{2}}\right)^2} \end{align}
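The $N_{ij}$ form can be checked the same way: the sketch below evaluates the expression term by term and compares it with central finite differences of $q_{ij}$. The point, the three centroids and $\nu = 1.5$ are hypothetical, as are the helper names `q_row` and `dq_dz_N_form`.

```python
def q_row(zi, mu, nu):
    # row i of q: Student's t kernel normalised over j'
    qhat = [(1 + sum((a - b) ** 2 for a, b in zip(zi, mj)) / nu) ** (-(nu + 1) / 2) for mj in mu]
    s = sum(qhat)
    return [v / s for v in qhat]

def dq_dz_N_form(zi, mu, nu, j):
    """Gradient of q_ij w.r.t. z_i, written with N_ij exactly as in the expression above."""
    N = [1 + sum((a - b) ** 2 for a, b in zip(zi, mj)) / nu for mj in mu]
    e = -(nu + 1) / 2                # exponent -(nu+1)/2; e - 1 equals -(nu+3)/2
    S = sum(n ** e for n in N)       # sum_j' N_ij'^{-(nu+1)/2}
    grad = []
    for d in range(len(zi)):
        inner = (zi[d] - mu[j][d]) * S / N[j] + sum(-(zi[d] - mj[d]) * n ** (e - 1)
                                                   for mj, n in zip(mu, N))
        grad.append(-(nu + 1) / nu * N[j] ** e * inner / S ** 2)
    return grad

zi, mu, nu, j, h = [1.0, 0.5], [[0.0, 0.5], [2.5, 3.0], [-1.0, 2.0]], 1.5, 1, 1e-6
numeric = []
for d in range(len(zi)):
    up = list(zi)
    up[d] += h
    dn = list(zi)
    dn[d] -= h
    numeric.append((q_row(up, mu, nu)[j] - q_row(dn, mu, nu)[j]) / (2 * h))
analytic = dq_dz_N_form(zi, mu, nu, j)
print(numeric, analytic)
```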

I'd appreciate any pointers towards a reasonable approach! Or if substituting and seeing it play out is the only way to go, then so be it.