很久很久以前,当计算机还不普及的时候,对一个复杂函数的求导一定是众多学者的噩梦之一。试想,也许面对一个三元多项式函数时,你还可以游刃有余地对每一个变量求导数,但当函数被拓展到成千上万元、成千上万项的时候,你还有信心求的出来导吗?于是,随着计算机的普及,自动求导算法的提出帮助众多学者从大型函数的求导中解放了出来。

深度学习的反向传播也是一个典型的自动求导过程,而作为 pytorch 魔法力量的核心 Autograd,一定也被很多人好奇过其实现的原理。在本篇文章中,我将向大家解释一个最为简单的自动求道例子,希望可以抛砖引玉,启发大家。

自动求道的原理

每一个人在大一学习高等数学时,一定都听说过一个名词:链式法则。链式法则的数学表示如下:

\frac {du} {dx}=\frac {du} {dy}\cdot \frac {dy} {dx}

可以看到, u 是包含了 y 的函数, y 是包含了 x 的函数,那么为了求到 u x 的导数, 我们可以先求 u y 的导数 \frac{du}{dy} , 然后求 y x 的导数 \frac{dy}{dx} ,最后把两个导数相乘就算出了我们想要的结果。

有人说,反向传播就是链式法则的另外一个花哨的名字。确实,这其实就是自动求导的核心。链式法则告诉我们,面对一个复杂的函数,我们若是可以把它拆分为一节一节简单的函数原子复合起来的结果,那我们对每一个简单的函数原子求导,最后把导数乘起来就可以得到复杂函数的结果。那么,深度学习最基础的神经元, y = wx + b 是不是就是一个简单的原子?

所以,自动求导和思路也就清晰了。我们通过前向传播记录每一个变量到最终函数的路径,然后我们沿着这条路径从函数返回,就可以得到函数对该变量的导数了。

利用二叉树设计一个自动求导机

在本部分我们将利用二叉树设计一个最简单的自动求导机。我们假设对一个简单的函数求导:

z=3x^{2}+5x+\lg y

将算式转换为后缀表达式

上述函数为一个中缀表达式。中缀表达式是一种方便人类认知的表达式,但其不方便计算机读取。为了让计算机得到这个函数的解析树,我们需要先使用调度场算法将函数从中缀表达式转换为后缀表达式。

将后缀表达式转换为解析树

将后缀表达式转换为解析树就是一个非常简单的过程了,我们只需要对后缀表达式作一次遍历即可。遍历的规则如下:

  • 当读取到的不为运算符时,将读取的字符作为解析树的一个节点存入栈内。
  • 当读取到的为运算符时,将读取的字符作为解析树的一个节点,取出栈顶的两个节点作为其的左右子节点,然后存入栈内。

例如我们有一个表达式为 3+4\times \frac {2} {\left( 1-5\right) ^{2^{3}}} ,他所对应的后缀表达式为:

1
3 4 2 * 1 5 − 2 3 ^ ^ / + 

于是我们构建解析树的过程如下:

输入
3 (3)
4 (3),(4)
2 (3),(4),(2)
x (3),(4 x 2)
1 (3),(4 x 2),(1)
5 (3),(4 x 2),(1),(5)
- (3),(4 x 2),(1 - 5)
2 (3),(4 x 2),(1 - 5),(2)
3 (3),(4 x 2),(1 - 5),(2),(3)
^ (3),(4 x 2),(1 - 5),(2 ^ 3)
^ (3),(4 x 2),((1 - 5) ^ (2 ^ 3))
/ (3),((4 x 2) / ((1 - 5) ^ (2 ^ 3)))
+ (3 + (4 x 2) / ((1 - 5) ^ (2 ^ 3)))

代码大概是这样的

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
struct ParseTree{
	string data; // 存的数据
	int result = -1; // 子树计算的结果
	int diff = -1; // 自动求导值
	struct ParseTree* left = nullptr;
	struct ParseTree* right = nullptr;
};

ParseTree* parse(){
	for (int I = 0; I < o_vector.size(); I++){
		if (isOperator(o_vector[I])) {
			ParseTree* s1 = parse_stack.top();
			parse_stack.pop();
			ParseTree* s2 = parse_stack.top();
			parse_stack.pop();
			ParseTree* parseNode = new ParseTree();
			parseNode->data = o_vector[I];
			parseNode->left = s1;
			parseNode->right = s2;
			parse_stack.push(parseNode);
		} else {
			ParseTree* parseLeaf = new ParseTree();
			parseLeaf->data = o_vector[I];
			parse_stack.push(parseLeaf);
		}
	}
	return parse_stack.top();
}

正向传播

我们得到的解析树之后,根节点就相当于是函数本身,内节点相当于是每一个计算符号,叶节点相当于每一个变量。正向传播的过程,实际就是从叶节点开始,经过每一个内节点最后计算到根节点的过程。由于二叉树的特点,每一个节点只有两个字节点,也就是说每一次内节点只会涉及一个符号,两个变量。所以整个正向传播过程就是从根节点的一个递归过程。

1
2
3
4
5
6
7
int calculate(ParseTree* head) {
	if (head->left != nullptr || head->right != nullptr) { // head is a node
		head->result =node_calculate((head->data)[0], calculate(head->left), calculate(head->right)); // node_calculate() is a function to calculate the result
		return head->result;
	}
	return stoi(head->data);
}

反向传播

我们正向传播后,每一个内节点,尽管其可能包含的是一个运算符号,它也存储了以它为根节点的子树的运算结果。所以当对它的父节点作运算时,可以将它看作一个数字。所以方向传播的过程就变成了从根节点开始的,每一次算字节点的导数,最后导数累加的过程。这个过程希望读者可以自己来实现代码。

当读者实现完这个部分的代码后,一段完整的自动求导代码也就完成了。



发现存在错别字或者事实错误?请麻烦您点击 这里 汇报。谢谢您!