Python的递归函数

递归(英语:Recursion),又译为递回,在数学与计算机科学中,是指在函数的定义中使用函数自身的方法。递归一词还较常用于描述以自相似方法重复事物的过程。例如,当两面镜子相互之间近似平行时,镜中嵌套的图像是以无限递归的形式出现的。也可以理解为自我复制的过程。

我们先来看一个简单的例子:

阶乘

1
2
3
4
def fact(n):
if n == 1:
return 1
return n * fact(n - 1) # 不要忘记return

计算过程为:

1
2
3
4
5
fact(5)
5 * fact(4)
5 * (4 * fact(3))
5 * (4 * (3 * fact(2)))
5 * (4 * (3 * (2 * fact(1))))

上面的例子存在一个容易错误的地方,可能受函数调用的习惯,第4行程序return n * fact(n - 1)容易丢掉return,这样的程序会返回None,而不是想要的值。

理论上,所有的递归函数都可以写成循环的方式,但循环的逻辑不如递归清晰。使用递归函数需要注意防止栈溢出。在计算机中,函数调用是通过栈(stack)这种数据结构实现的,每当进入一个函数调用,栈就会加一层栈帧,每当函数返回,栈就会减一层栈帧。由于栈的大小不是无限的,所以,递归调用的次数过多,会导致栈溢出。

解决递归调用栈溢出的方法是通过尾递归优化,尾递归是指 在函数返回的时候,调用自身本身,并且,return语句不能包含表达式。这样,编译器或者解释器就可以把尾递归做优化,使递归本身无论调用多少次,都只占用一个栈帧,不会出现栈溢出的情况。

1
2
3
4
5
6
7
def fact(n):
return fact_iter(n, 1)

def fact_iter(num, product):
if num == 1:
return product
return fact_iter(num - 1, num * product)

但是:

Python标准的解释器没有针对尾递归做优化,任何递归函数都存在栈溢出的问题。

上面的例子还存在一个问题,如果同一层存在多个需要递归调用的情况,要怎么处理呢?看下面的例子!

莱文斯坦距离

莱文斯坦距离是衡量两个字符串编辑距离的一种方式。编辑距离值的是,将一个字符串转化成另一个字符串,需要的最少编辑操作次数(比如增加一个字符,删除一个字符,替换一个字符)。编辑距离越大,说明两个字符串的相似程度越小;相反,编辑距离就越小,说明两个字符串的相似程度越大。两个完全相同的字符串,其编辑距离就是0。莱文斯坦距离允许增加、删除、替换这三种编辑方式,从两个字符串差异的大小角度出发。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
stra = 'a'
strb = 'b'
min_edit_num = 999
def distance(i,j,edit_num):
global min_edit_num
if i == len(stra) or j == len(strb):
print(i,j)
if i < len(stra):
edit_num += len(stra) - i
if j < len(strb):
edit_num += len(strb) - j
if min_edit_num > edit_num:
edit_num = min_edit_num
return edit_num

if stra[i] == strb[j]:
return distance(i+1, j+1, edit_num)
else:
return distance(i+1, j, edit_num+1)
return distance(i, j+1, edit_num+1)
return distance(i+1, j+1, edit_num+1)

print(distance(0,0,0))

注意,上面这个程序执行的结果是错误的!因为在执行stra[i] != strb[j]时,永远只会返回第一个return distance(i+1, j, edit_num+1),程序就结束了,而后续的两个return语句不会被执行。那要怎么处理呢?我采用的方式是不添加ruturn语句,用一个list来存储每次计算的结果。

修改后的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
stra = 'a'
strb = 'bc'
result = []

def distance(i,j,edit_num):
if i == len(stra) or j == len(strb):
if i < len(stra):
edit_num += len(stra) - i
if j < len(strb):
edit_num += len(strb) - j
result.append(edit_num)
return # 这个ruturn不可省略

if stra[i] == strb[j]:
distance(i+1, j+1, edit_num)
else:
distance(i+1, j, edit_num+1)
distance(i, j+1, edit_num+1)
distance(i+1, j+1, edit_num+1)

distance(0,0,0)
print(min(result))

或者,当不需要记录每次执行修改的次数可以用全局变量的方式。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
stra = 'a'
strb = 'c'
result = 999

def distance(i,j,edit_num):
global result
if i == len(stra) or j == len(strb):
if i < len(stra):
edit_num += len(stra) - i
if j < len(strb):
edit_num += len(strb) - j
if edit_num < result:
result = edit_num
return # 这个ruturn不可省略

if stra[i] == strb[j]:
distance(i+1, j+1, edit_num)
else:
distance(i+1, j, edit_num+1)
distance(i, j+1, edit_num+1)
distance(i+1, j+1, edit_num+1)

distance(0,0,0)
print(result)

参考资料