Python itertools 模块中的全排列算法,看起来简单却非常让人费解?

在刷 leetcode 时遇到了全排列问题,想来想去也只能想出递归的解法。然后突然想起 Python 的 itertools 模块中有全排列的函数,模块源码是用 C 写的,不过在官方文档中有提供 Python 版本的代码,代码看起来非常简单,但是我看了很久都看不懂算法的原理是什么。

接着在 SO 上发现了相关问题,根据“ Alex Martelli ”的回答,该算法涉及到了Cyclic permutation 理论。有兴趣的同学可以研究一下代码和 Cyclic permutation 理论。

附上代码

def permutations(iterable, r=None):
    # permutations('ABCD', 2) --> AB AC AD BA BC BD CA CB CD DA DB DC
    # permutations(range(3)) --> 012 021 102 120 201 210
    pool = tuple(iterable)
    n = len(pool)
    r = n if r is None else r
    if r > n:
        return
    indices = list(range(n))
    cycles = list(range(n, n-r, -1))
    yield tuple(pool[i] for i in indices[:r])
    while n:
        for i in reversed(range(r)):
            cycles[i] -= 1
            if cycles[i] == 0:
                indices[i:] = indices[i+1:] + indices[i:i+1]
                cycles[i] = n - i
            else:
                j = cycles[i]
                indices[i], indices[-j] = indices[-j], indices[i]
                yield tuple(pool[i] for i in indices[:r])
                break
        else:
            return

Python itertools 模块中的全排列算法,看起来简单却非常让人费解?

12 回复

谢谢分享,学习了


itertools.permutations 的原理其实很直接:它生成输入可迭代对象中所有可能的、长度为 r 的排列元组。如果 r 未指定,则默认使用可迭代对象的长度。关键在于,它基于元素的位置而非值来生成排列,并且会处理重复元素。

直接看代码最清楚。下面是一个简化版的 permutations 实现,它揭示了核心逻辑:

def permutations(iterable, r=None):
    # 将输入转换为元组,以便索引
    pool = tuple(iterable)
    n = len(pool)
    r = n if r is None else r

    if r > n:
        return

    # 初始化索引列表,代表当前排列中每个位置在 pool 中的索引
    indices = list(range(n))
    # cycles 列表用于控制回溯:cycles[i] 表示当 indices[i] 完成一轮循环后需要重置
    cycles = list(range(n, n - r, -1))

    # 生成第一个排列(按初始索引顺序)
    yield tuple(pool[i] for i in indices[:r])

    while True:
        # 从最右侧的位置开始回溯
        for i in reversed(range(r)):
            cycles[i] -= 1
            if cycles[i] == 0:
                # 如果当前位循环完毕,将其索引与后方未使用的索引进行“滚动”交换
                indices[i:] = indices[i+1:] + indices[i:i+1]
                cycles[i] = n - i  # 重置循环计数器
            else:
                # 还有可用的循环,交换 indices[i] 和它后方第 cycles[i] 个未使用的位置
                j = cycles[i]
                indices[i], indices[-j] = indices[-j], indices[i]
                # 生成新排列
                yield tuple(pool[i] for i in indices[:r])
                break
        else:
            # 所有位置都循环完毕,生成结束
            return

# 测试一下
print(list(permutations([1, 2, 3], 2)))
# 输出: [(1, 2), (1, 3), (2, 1), (2, 3), (3, 1), (3, 2)]

核心机制:

  1. 基于索引操作:算法并不直接操作元素值,而是维护一个 indices 列表,记录当前排列中每个位置对应原始序列(pool)中的索引。
  2. 回溯与交换:通过 cycles 列表控制每个位置的回溯深度。每次生成新排列时,从最右位尝试“进位”,通过交换 indices 中的索引来改变该位置对应的元素。
  3. 处理重复值:由于依赖索引,如果输入 pool 中有重复值(如 [1, 1, 2]),生成的排列元组在值上可能看起来是重复的,但它们在索引层面是不同的排列。

一句话总结:把它理解成一个在索引数组上运行的特殊“计数器”,每次“进位”都通过交换索引来产生新排列。

C++也有啊

stl 也一样, 全排列的标准算法

好像网上讨论这个算法的人比较少,我看了大都是递归和字典序算法

做缓存时蛮方便的

NOIP2004 普及组最后一题.我以为已经是常识了……

我一般用的是二进制来做的。
A->1000
B->0100
C->0010
D->0001

例如
0110 对应了 BC
1011 对应了 A CD

所以只要计算 1 ( 0001 )到 15 ( 1111 )就可以得到所有组合

4 个字母的全排列不是有 24 种情况吗?计算 1-15 怎么得到所有组合?

我说的是一个类似的题的做法,全排列因此稍稍改一下。

他算出的是组合,题主问的是排列

受教了! 9 年前大学毕业面试题就做过这题,我一直以为我的解法是对的,今天看到你的答案比我的好多了!

回到顶部