矩阵求导(Matrix Derivative)也称作矩阵微分(Matrix Differential),在机器学习、图像处理、最优化等领域的公式推导中经常用到。矩阵求导实际上是多元变量的微积分问题,只是应用在矩阵空间上而已,即为标量求导的一个推广,他的定义为将自变量中的每一个数与因变量中的每一个数求导。

具体地,假设存在 A_{m \times n} B_{p \times q} ,则 \frac{\partial A}{\partial B} 会将 A 中的每一个值对 B 中的每一个值求导,最后一共会得到 m \times n \times p \times q 个导数值。这么多的导数值,最后是排布成一个 m \times (n \times p \times q) 的矩阵还是一个 (m \times n \times p) \times q 的矩阵呢?矩阵求导的关键就在于规定如何排布这么多的导数值。

以分布布局为例子,一共有以下几个矩阵求导法则。分母布局是什么意思呢?简单的说就是以分母为一个基准,希望求导出来的结果和分母的维度相同。除了分母布局以外还有分子布局。分子布局和分母布局的求导结果通常相差一个转置。

基本法则

法则 0 :标量对标量求导

略。详细的请参考高等数学。

法则 1 :标量对向量求导

考虑我们有 f 是一个标量, % <![CDATA[ x = \begin{bmatrix} x_1 & x_2 & \cdots & x_p \end{bmatrix}^{T} %]]> 是一个 p \times 1 的列向量。则有:

% <![CDATA[ \frac{\partial f}{\partial x}=\begin{bmatrix}\frac{\partial f}{\partial x_1} & \frac{\partial f}{\partial x_2} & \cdots & \frac{\partial f}{\partial x_p}\end{bmatrix}^{T} %]]>

可以看得出,求导出来的结果维度是和分母 x 相同的。若 x 为行向量同理。

法则 2 :向量对标量求导

考虑我们有 % <![CDATA[ f = \begin{bmatrix} f_1 & f_2 & \cdots & f_m \end{bmatrix}^{T} %]]> 是一个 m \times 1 的列向量, x 是一个标量。则有:

% <![CDATA[ \frac{\partial f}{\partial x}=\begin{bmatrix}\frac{\partial f_1}{\partial x} & \frac{\partial f_2}{\partial x} & \cdots & \frac{\partial f_m}{\partial x}\end{bmatrix} %]]>

可以看得出,这个时候求导出来的结果维度和分子 f 是相反的。若 f 为行向量同理。

法则 3 :向量对向量求导

考虑我们有 % <![CDATA[ f = \begin{bmatrix} f_1 & f_2 & \cdots & f_m \end{bmatrix}^{T} %]]> 是一个 m \times 1 的列向量, % <![CDATA[ x = \begin{bmatrix} x_1 & x_2 & \cdots & x_p \end{bmatrix}^{T} %]]> 是一个 p \times 1 的列向量。则有:

% <![CDATA[ \frac{\partial f}{\partial x}=\begin{bmatrix}\frac{\partial f_1}{\partial x_1} & \frac{\partial f_2}{\partial x_1} & \cdots & \frac{\partial f_m}{\partial x_1} \\ \frac{\partial f_1}{\partial x_2} & \frac{\partial f_2}{\partial x_2} & \cdots & \frac{\partial f_m}{\partial x_2} \\ \vdots & \vdots & \ddots & \vdots \\ \frac{\partial f_1}{\partial x_p} & \frac{\partial f_2}{\partial x_p} & \cdots & \frac{\partial f_m}{\partial x_p} \end{bmatrix} %]]>

这时求导结果的维度为 p \times m

法则 4 :标量对矩阵求导

考虑我们有 f 是一个标量, x_{p \times q} 是一个矩阵。则有:

% <![CDATA[ \frac{\partial f}{\partial x}=\begin{bmatrix}\frac{\partial f}{\partial x_{11}} & \frac{\partial f}{\partial x_{12}} & \cdots & \frac{\partial f}{\partial x_{1q}} \\ \frac{\partial f}{\partial x_{21}} & \frac{\partial f}{\partial x_{22}} & \cdots & \frac{\partial f}{\partial x_{2q}} \\ \vdots & \vdots & \ddots & \vdots \\ \frac{\partial f}{\partial x_{p1}} & \frac{\partial f}{\partial x_{p2}} & \cdots & \frac{\partial f}{\partial x_{pq}} \end{bmatrix} %]]>

同样,我们求导结果和分母 x 的维度一致,是 p \times q

法则 5 :矩阵对向量求导

考虑我们有 f_{m \times n} 是一个矩阵, x 是一个标量。则有:

% <![CDATA[ \frac{\partial f}{\partial x}=\begin{bmatrix}\frac{\partial f_{11}}{\partial x} & \frac{\partial f_{21}}{\partial x} & \cdots & \frac{\partial f_{m1}}{\partial x} \\ \frac{\partial f_{21}}{\partial x} & \frac{\partial f_{22}}{\partial x} & \cdots & \frac{\partial f_{m2}}{\partial x_{2q}} \\ \vdots & \vdots & \ddots & \vdots \\ \frac{\partial f_{n1}}{\partial x} & \frac{\partial f_{n2}}{\partial x} & \cdots & \frac{\partial f_{nm}}{\partial x} \end{bmatrix} %]]>

我们求导的结果与分子相反,为 n \times m

其余:向量与矩阵之间以及矩阵与矩阵之间的求导

当我们的自变量与因变量都为不为标量时,根据我们对矩阵求导实质的讨论,势必会得出大量的导数难以被排列。例如,一般情况下,假设我们有 f_{m \times n} 以及 x_{p \times q} ,则求导后我们会得到 m \times n \times p \times q 个导数结果。这时对这些导数一般有两种定义方法。

第一种定义

我们按照之前的法则,将 \frac{\partial f}{\partial x} 理解为对每一个 f 中的标量,使其对 x 求导,然后将其放回矩阵 f 中的原位。即我们使用 \frac{\partial f_{ij}}{\partial x} 替换 f_{ij} ,最后会得到一个 mp \times nq 的导数矩阵。

第二种定义(主流)

这种定义是将矩阵对矩阵求导问题归约到向量对向量求导。即对矩阵先做向量化处理,然后再求导:

\frac{\partial f}{\partial x}=\frac{\partial vec(f)}{\partial vec(x)}

其中,向量化的实现方法分为列向量化和行向量化。我们以列向量化为例,将 f_{m \times n} x_{p \times q} 向量化为 f_{mn \times 1} x_{pq \times 1} ,然后利用法则 3 求导得到维度为 pq \times mn 的导数结果。

有用的公式

下列公式中, A_{m \times 1} x_{m \times 1} 是列向量, B_{m \times m} 是矩阵。下面 3 个公式在文末有证明。

编号 公式
1 \frac{\partial{x^{T}A}}{\partial{x}} = \frac{\partial{A^{T}x}}{\partial{x}} = A
2 \frac{\partial{x^{T}x}}{\partial{x}} = x
3 \frac{\partial{x^{T}Bx}}{\partial{x}} = (B + B^{T})x

下列公式是一些关于矩阵迹的公式。其中, a 是一个标量, A , B , C 分为三个矩阵。

编号 公式
1 tr(a) = a
2 tr(A) = tr(A^T)
3 tr(AB) = tr(BA)
4 tr(ABC) = tr(CAB) = tr(BCA)
5 \frac{\partial{tr(AB)}}{\partial{A}} = B^T
6 \frac{\partial{tr(ABA^{T}C)}}{\partial{A}} = CAB + C^{T}AB^{T}

一些公式的证明

令:

% <![CDATA[ A_{m \times 1} = \begin{bmatrix} A_1 & A_2 & \cdots & A_m \end{bmatrix} ^ {T} %]]> % <![CDATA[ B_{m \times m} = B_{m \times m} = \begin{bmatrix} B_{11} & B_{12} & \cdots & B_{1m} \\ B_{21} & B_{22} & \cdots & B_{2m} \\ \vdots & \vdots & \ddots & \vdots \\ B_{m1} & B_{m2} & \cdots & B_{mm} \end{bmatrix} %]]> % <![CDATA[ x_{m \times 1} = \begin{bmatrix} x_1 & x_2 & \cdots & x_m \end{bmatrix} ^ {T} %]]>

公式 1

\frac{\partial{x^{T}A}}{\partial{x}} = \frac{\partial{A^{T}x}}{\partial{x}} = A

因为 A_{m \times 1} x_{m \times 1} 是列向量,所以 x^{T}A = A^{T}x = \sum_{i=1}^{m}{A_{i}x_{i}} 为一个标量,所以可以用法则 1 进行计算。

\frac{\partial{x^{T}A}}{\partial{x}} = \frac{\partial{A^{T}x}}{\partial{x}} = \begin{bmatrix} \frac{\partial{\sum_{i=1}^{m}{A_{i}x_{i}}}}{\partial{x_1}} \\ \frac{\partial{\sum_{i=1}^{m}{A_{i}x_{i}}}}{\partial{x_2}} \\ \cdots \\ \frac{\partial{\sum_{i=1}^{m}{A_{i}x_{i}}}}{\partial{x_m}} \end{bmatrix} = \begin{bmatrix} A_1 \\ A_2 \\ \cdots \\ A_m \end{bmatrix} = A

公式 2

同理 公式 1

公式 3

\frac{\partial{x^{T}Bx}}{\partial{x}} = (B + B^{T})x

由题意可得, x^{T}Bx 为标量,则原式为标量对列向量求导,可以用法则 1 进行计算。

\frac{\partial{x^{T}Bx}}{\partial{x}} = \begin{bmatrix} \frac{\partial{\sum_{i=1}^{m}{\sum_{j=1}^{m}{B_{ij}x_{i}x_{j}}}}}{\partial{x_1}} \\ \frac{\partial{\sum_{i=1}^{m}{\sum_{j=1}^{m}{B_{ij}x_{i}x_{j}}}}}{\partial{x_2}} \\ \cdots \\ \frac{\partial{\sum_{i=1}^{m}{\sum_{j=1}^{m}{B_{ij}x_{i}x_{j}}}}}{\partial{x_m}} \end{bmatrix}

由导数法则有:

\frac{\partial{f(x)g(x)}}{\partial{x}} = \frac{\partial{f(x)}}{x}g(x) + f(x)\frac{\partial{g(x)}}{\partial{x}}

于是,原式继续有:

= \begin{bmatrix} \sum_{i=1}^{m}{B_{i1}x_i} + \sum_{j=1}^{m}{B_{1j}x_j} \\ \sum_{i=1}^{m}{B_{i2}x_i} + \sum_{j=1}^{m}{B_{2j}x_j} \\ \cdots \\ \sum_{i=1}^{m}{B_{im}x_i} + \sum_{j=1}^{m}{B_{mj}x_j} \end{bmatrix} = \begin{bmatrix} \sum_{i=1}^{m}{B_{i1}x_i} \\ \sum_{i=1}^{m}{B_{i2}x_i} \\ \cdots \\ \sum_{i=1}^{m}{B_{im}x_i} \end{bmatrix} + \begin{bmatrix} \sum_{j=1}^{m}{B_{1j}x_j} \\ \sum_{j=1}^{m}{B_{2j}x_j} \\ \cdots \\ \sum_{j=1}^{m}{B_{mj}x_j} \end{bmatrix}

% <![CDATA[ = \begin{bmatrix} B_{11} & B{21} & \cdots & B{m1} \\ B_{12} & B{22} & \cdots & B{m2} \\ \vdots & \vdots & \ddots & \vdots \\ B_{1m} & B{2m} & \cdots & B{mm} \end{bmatrix} \begin{bmatrix} x_1 \\ x_2 \\ \vdots \\ x_m \end{bmatrix} %]]> % <![CDATA[ + \begin{bmatrix} B_{11} & B{12} & \cdots & B{1m} \\ B_{21} & B{22} & \cdots & B{2m} \\ \vdots & \vdots & \ddots & \vdots \\ B_{m1} & B{m2} & \cdots & B{mm} \end{bmatrix} \begin{bmatrix} x_1 \\ x_2 \\ \vdots \\ x_m \end{bmatrix} %]]>

= (A^{T} + A)x = (A + A^{T})x

发现存在错别字或者事实错误?请麻烦您点击 这里 汇报。谢谢您!