Lemma: Let $A \in \mathbb{R}^{m\times n}$ and $B \in \mathbb{R}^{n\times k}$ be matrices, and let $f:\mathbb{R}^{m\times k} \to\mathbb{R}$ be a differentiable function. Let $X = AB$. Then $$ \frac{\partial f(X)}{\partial A} = \Big(\frac{\partial f(X)}{\partial X}\Big)B^T, $$ where the right-hand side is the matrix product, and we are writing $\frac{\partial f(X)}{\partial X}$ for the matrix with $(i,j)$-entry $\frac{\partial f(X)}{\partial X_{ij}}$.
Proof: One way to see this is to use the multivariable chain rule to get $$ \frac{\partial f(AB)}{\partial A_{ij}} = \sum_{p,q}\frac{\partial f(X)}{\partial X_{pq}}\cdot\frac{\partial X_{pq}}{\partial A_{ij}}, $$ and then to observe that, since $X_{pq} = \sum_{r} A_{pr}B_{rq}$, $$ \frac{\partial X_{pq}}{\partial A_{ij}} = \mathbb{1}_{i=p}B_{jq}, $$ which implies that $$ \sum_{p,q}\frac{\partial f(X)}{\partial X_{pq}}\cdot\frac{\partial X_{pq}}{\partial A_{ij}} = \sum_{q}\frac{\partial f(X)}{\partial X_{iq}}B_{jq} = \Big(\Big(\frac{\partial f(X)}{\partial X}\Big)B^T\Big)_{ij}. $$ $$\tag*{$\blacksquare$}$$
My question: The statement looks almost exactly like the single-variable chain rule, namely that for real numbers with $x = ab$, we have $$ \frac{\partial f(x)}{\partial a} = \Big(\frac{d f(x)}{dx} \Big)\cdot b, $$ so I'm wondering whether there is an appropriate "matrix chain rule" that makes the lemma immediate, without expanding everything in terms of scalar derivatives.
I've searched for terms like "matrix/tensor chain rule" and "matrix/tensor calculus", but I haven't found exactly what I'm looking for. I suspect that the answer lies in tensor calculus, but the references I've found are quite abstract and seem intended for a much more general setting (e.g. general relativity).
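As a quick sanity check of the lemma, here is a pure-Python sketch (not part of the original question) that compares $(\partial f/\partial X)B^T$ with a brute-force gradient of $A \mapsto f(AB)$ for non-square $A$ and $B$. The choice $f(X)=\sum_{ij}X_{ij}^2$ (whose gradient is $2X$) and the helper names `matmul`, `fd_grad_A` are assumptions made for illustration. Because $f(AB)$ is quadratic in each entry of $A$, central differences with exact `Fraction` arithmetic give the derivative exactly.

```python
from fractions import Fraction


def matmul(P, Q):
    """Plain matrix product of nested lists."""
    return [[sum(P[i][r] * Q[r][j] for r in range(len(Q)))
             for j in range(len(Q[0]))] for i in range(len(P))]


def transpose(P):
    return [list(row) for row in zip(*P)]


def f(X):
    # Illustrative choice (assumption): f(X) = sum of squared entries.
    return sum(x * x for row in X for x in row)


def grad_f(X):
    # Gradient of f in the lemma's layout: (i, j) entry is df/dX_ij = 2 X_ij.
    return [[2 * x for x in row] for row in X]


def fd_grad_A(A, B, h=Fraction(1)):
    """Central differences of A -> f(AB); exact since f(AB) is quadratic in A."""
    m, n = len(A), len(A[0])
    G = [[Fraction(0)] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            A_plus = [row[:] for row in A]
            A_minus = [row[:] for row in A]
            A_plus[i][j] += h
            A_minus[i][j] -= h
            G[i][j] = (f(matmul(A_plus, B)) - f(matmul(A_minus, B))) / (2 * h)
    return G


A = [[Fraction(v) for v in row] for row in [[1, -2, 3], [0, 4, -1]]]  # 2x3
B = [[Fraction(v) for v in row]
     for row in [[2, 1, 0, -1], [1, 3, -2, 0], [0, 1, 1, 5]]]         # 3x4

lemma_rhs = matmul(grad_f(matmul(A, B)), transpose(B))  # (df/dX) B^T, 2x3
brute = fd_grad_A(A, B)                                  # df/dA_ij, 2x3
print(lemma_rhs == brute)  # prints True
```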
This is a long comment.
An easy way to check this is with a simple example. Take $$ A = \begin{bmatrix} x & y \\ z & t \end{bmatrix}, \qquad B=\begin{bmatrix} a & b \\ c & d \end{bmatrix}, $$ $$ f(AB)= (AB)_{11} (AB)_{22} + (AB)_{12} (AB)_{21} =X_{11} X_{22} + X_{12} X_{21}. $$ We have $$ X_{11}=(AB)_{11}=ax+cy, $$ $$ X_{12}=(AB)_{12}=bx+dy, $$ $$ X_{21}=(AB)_{21}=az+ct, $$ $$ X_{22}=(AB)_{22}=bz+dt. $$ Let's adopt the convention that $\partial f / \partial X$ is laid out transposed, with $\partial f/\partial X_{ji}$ in position $(i,j)$: $$ \frac{\partial f}{\partial X} = \begin{bmatrix} \frac{\partial f}{\partial X_{11}} & \frac{\partial f}{\partial X_{21}} \\ \frac{\partial f}{\partial X_{12}} & \frac{\partial f}{\partial X_{22}} \\ \end{bmatrix} = \begin{bmatrix} X_{22} & X_{12} \\ X_{21} & X_{11} \\ \end{bmatrix} $$
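To make this layout convention concrete, the sketch below (my own; the sample $X$ and the helper `grad_transposed_layout` are hypothetical) fills position $(r,c)$ with a central-difference value of $\partial f/\partial X_{cr}$. Since $f$ is linear in each single entry of $X$, the difference quotient is exact with `Fraction` arithmetic, and the result should reproduce $\begin{bmatrix} X_{22} & X_{12} \\ X_{21} & X_{11}\end{bmatrix}$.

```python
from fractions import Fraction


def f(X):
    # f(X) = X11*X22 + X12*X21, written with 0-based indices.
    return X[0][0] * X[1][1] + X[0][1] * X[1][0]


def grad_transposed_layout(func, X, h=Fraction(1)):
    """Position (r, c) holds df/dX_{cr}, i.e. the transposed layout."""
    n = len(X)
    D = [[Fraction(0)] * n for _ in range(n)]
    for r in range(n):
        for c in range(n):
            X_plus = [row[:] for row in X]
            X_minus = [row[:] for row in X]
            X_plus[c][r] += h   # perturb X_{cr} ...
            X_minus[c][r] -= h
            D[r][c] = (func(X_plus) - func(X_minus)) / (2 * h)  # ... store at (r, c)
    return D


# Sample values: X11 = 3, X12 = 5, X21 = 7, X22 = 11.
X = [[Fraction(3), Fraction(5)], [Fraction(7), Fraction(11)]]
D = grad_transposed_layout(f, X)
print([[int(v) for v in row] for row in D])  # [[11, 5], [7, 3]], i.e. [[X22, X12], [X21, X11]]
```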
I will assume that $$ {\left(\frac{\partial f}{\partial A}\right)^j}_i={B^k}_{i} {\left(\frac{\partial f}{\partial X}\right)^j}_k $$ that is, using the same transposed layout for $\partial f / \partial A$, \begin{align*} \frac{\partial f}{\partial A}= \begin{bmatrix} a & b \\ c & d \\ \end{bmatrix} \begin{bmatrix} X_{22} & X_{12} \\ X_{21} & X_{11} \\ \end{bmatrix} & = \begin{bmatrix} aX_{22} + b X_{21}& aX_{12} + b X_{11}\\ cX_{22}+ d X_{21} & cX_{12} + d X_{11}\\ \end{bmatrix} \\ &= \begin{bmatrix} 2abz+(ad+bc)t& 2abx+(ad+bc)y\\ 2cdt+(ad+bc)z & 2cdy+(ad+bc)x\\ \end{bmatrix} \\ \end{align*}
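The expansion of $B\,(\partial f/\partial X)$ into the closed-form matrix can be spot-checked with plain integer arithmetic. This is a sketch that assumes random integer inputs are an adequate test; it checks only the algebra of the expansion, not yet the claim about $\partial f/\partial A$. The function name `check_expansion` is my own.

```python
import random


def matmul(P, Q):
    """Plain matrix product of nested lists."""
    return [[sum(P[i][r] * Q[r][j] for r in range(len(Q)))
             for j in range(len(Q[0]))] for i in range(len(P))]


def check_expansion(a, b, c, d, x, y, z, t):
    # Entries of X = AB for A = [[x, y], [z, t]] and B = [[a, b], [c, d]].
    X11, X12 = a * x + c * y, b * x + d * y
    X21, X22 = a * z + c * t, b * z + d * t
    B = [[a, b], [c, d]]
    dfdX = [[X22, X12], [X21, X11]]  # transposed layout from the example
    lhs = matmul(B, dfdX)
    s = a * d + b * c
    closed_form = [[2 * a * b * z + s * t, 2 * a * b * x + s * y],
                   [2 * c * d * t + s * z, 2 * c * d * y + s * x]]
    return lhs == closed_form


# Integer inputs keep the arithmetic exact.
rng = random.Random(0)
result = all(check_expansion(*(rng.randint(-9, 9) for _ in range(8)))
             for _ in range(1000))
print(result)  # prints True
```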
The brute-force approach computes $\partial f / \partial A$ directly: $$ \frac{\partial f }{ \partial A_{11}} = \frac{\partial f }{ \partial x} = \frac{\partial X_{11} }{ \partial x} X_{22} + \frac{\partial X_{12} }{ \partial x} X_{21} = a (bz+dt)+b(az+ct)=2abz+(ad+bc)t $$ $$ \frac{\partial f }{ \partial A_{21}} = \frac{\partial f }{ \partial z} = \frac{\partial X_{22} }{ \partial z} X_{11} + \frac{\partial X_{21} }{ \partial z} X_{12} = b (ax+cy)+a(bx+dy)=2abx+(ad+bc)y $$ $$ \frac{\partial f }{ \partial A_{12}} = \frac{\partial f }{ \partial y} = \frac{\partial X_{11} }{ \partial y} X_{22} + \frac{\partial X_{12} }{ \partial y} X_{21} = c (bz+dt)+d(az+ct)=2cdt+(ad+bc)z $$ $$ \frac{\partial f }{ \partial A_{22}} = \frac{\partial f }{ \partial t} = \frac{\partial X_{22} }{ \partial t} X_{11} + \frac{\partial X_{21} }{ \partial t} X_{12} = d (ax+cy)+c(bx+dy)=2cdy+(ad+bc)x $$ These are exactly the entries of the matrix above, placed according to the layout convention adopted above, so $\partial f / \partial A = B\,\partial f / \partial X$ holds in that layout. Transposing both sides turns it into the lemma's statement $\partial f/\partial A = (\partial f/\partial X)B^T$ in the lemma's (untransposed) layout. This does not depend on the specific example; the same result can be derived by defining linear maps between tensor spaces.
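The same comparison can be run numerically. The sketch below (sample values and helper names are my own) computes $\partial f/\partial A_{ij}$ by central differences in the usual layout, then checks it against $B\,\partial f/\partial X$ in the transposed layout (transposed back) and against the lemma's $(\partial f/\partial X)B^T$. Because $f(AB)$ is linear in each single entry of $A$ here, the difference quotients are exact with `Fraction`.

```python
from fractions import Fraction


def matmul(P, Q):
    """Plain matrix product of nested lists."""
    return [[sum(P[i][r] * Q[r][j] for r in range(len(Q)))
             for j in range(len(Q[0]))] for i in range(len(P))]


def transpose(P):
    return [list(row) for row in zip(*P)]


def f_of_A(A, B):
    """f(AB) with f(X) = X11*X22 + X12*X21 (0-based indices below)."""
    X = matmul(A, B)
    return X[0][0] * X[1][1] + X[0][1] * X[1][0]


# Arbitrary sample values: A = [[x, y], [z, t]], B = [[a, b], [c, d]].
A = [[Fraction(2), Fraction(-1)], [Fraction(3), Fraction(5)]]
B = [[Fraction(1), Fraction(4)], [Fraction(-2), Fraction(3)]]
h = Fraction(1)

# Brute force: brute[i][j] = df/dA_{ij} in the usual (row i, column j) layout.
brute = [[Fraction(0)] * 2 for _ in range(2)]
for i in range(2):
    for j in range(2):
        A_plus = [row[:] for row in A]
        A_minus = [row[:] for row in A]
        A_plus[i][j] += h
        A_minus[i][j] -= h
        brute[i][j] = (f_of_A(A_plus, B) - f_of_A(A_minus, B)) / (2 * h)

X = matmul(A, B)
dfdX_T = [[X[1][1], X[0][1]], [X[1][0], X[0][0]]]    # transposed layout
answer_form = matmul(B, dfdX_T)                       # B (df/dX), transposed layout
lemma_form = matmul(transpose(dfdX_T), transpose(B))  # (df/dX) B^T, lemma's layout

print(transpose(answer_form) == brute, lemma_form == brute)  # prints True True
```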
Note what happens if you define $\partial X / \partial X$ (let's assume $X$ is a $2\times 2$ matrix to save space) as \begin{equation*} \frac{\partial X} { \partial X} = \begin{bmatrix} \frac{\partial X}{\partial X_{11}} & \frac{\partial X}{\partial X_{12}} \\ \frac{\partial X}{\partial X_{21}} & \frac{\partial X}{\partial X_{22}} \\ \end{bmatrix} = \begin{bmatrix} \begin{bmatrix} 1 & 0 \\ 0 & 0 \\ \end{bmatrix} & \begin{bmatrix} 0 & 1 \\ 0 & 0 \\ \end{bmatrix} \\ & \\ \begin{bmatrix} 0 & 0 \\ 1 & 0 \\ \end{bmatrix} & \begin{bmatrix} 0 & 0 \\ 0 & 1 \\ \end{bmatrix} \\ \end{bmatrix} \end{equation*} We should have ${(\partial X/\partial X)^{jl}}_{ik} {X^k}_l = {X^j}_i$ in the same way that $(dx/dx)x=x$. Instead we get \begin{align*} {{{\left(\frac{\partial X}{\partial X}\right)}^{jl}}}_{ik} {X^k}_l & = \begin{bmatrix} \mathrm{tr} \left( \begin{bmatrix} 1 & 0 \\ 0 & 0 \\ \end{bmatrix} \begin{bmatrix} X_{11} & X_{12} \\ X_{21} & X_{22} \end{bmatrix} \right) & \mathrm{tr} \left( \begin{bmatrix} 0 & 1 \\ 0 & 0 \\ \end{bmatrix} \begin{bmatrix} X_{11} & X_{12} \\ X_{21} & X_{22} \end{bmatrix} \right) \\ & \\ \mathrm{tr} \left( \begin{bmatrix} 0 & 0 \\ 1 & 0 \\ \end{bmatrix} \begin{bmatrix} X_{11} & X_{12} \\ X_{21} & X_{22} \end{bmatrix} \right) & \mathrm{tr} \left( \begin{bmatrix} 0 & 0 \\ 0 & 1 \\ \end{bmatrix} \begin{bmatrix} X_{11} & X_{12} \\ X_{21} & X_{22} \end{bmatrix} \right) \end{bmatrix} \\ & = \begin{bmatrix} X_{11} & X_{21} \\ X_{12} & X_{22} \end{bmatrix} \end{align*} which is $X^T$, not $X$.
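The contraction above is easy to reproduce for any size. In this sketch (helper names are assumptions, not from the answer), block $(i,j)$ of $\partial X/\partial X$ is the unit matrix $E_{ij}$, and replacing each block by $\mathrm{tr}(E_{ij}X) = X_{ji}$ yields $X^T$ rather than $X$.

```python
def unit_matrix(n, i, j):
    """E_ij: n x n matrix with a 1 at position (i, j) and zeros elsewhere."""
    return [[1 if (r, c) == (i, j) else 0 for c in range(n)] for r in range(n)]


def matmul(P, Q):
    """Plain matrix product of nested lists."""
    return [[sum(P[i][r] * Q[r][j] for r in range(len(Q)))
             for j in range(len(Q[0]))] for i in range(len(P))]


def trace(P):
    return sum(P[i][i] for i in range(len(P)))


def contract_with_traces(X):
    """Replace block (i, j) of dX/dX, namely E_ij, by tr(E_ij X) = X_ji."""
    n = len(X)
    return [[trace(matmul(unit_matrix(n, i, j), X)) for j in range(n)]
            for i in range(n)]


X = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
print(contract_with_traces(X))  # [[1, 4, 7], [2, 5, 8], [3, 6, 9]], i.e. X^T
```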