itertools --- 建立產生高效率迴圈之疊代器的函式


這個模組實作了許多 疊代器 (iterator) 構建塊 (building block),其靈感來自 APL、Haskell 和 SML 的結構。每個構建塊都以適合 Python 的形式來重新設計。

這個模組標準化了快速且高效率利用記憶體的核心工具集,這些工具本身或組合使用都很有用。它們共同構成了一個「疊代器代數 (iterator algebra)」,使得在純 Python 中簡潔且高效地建構專用工具成為可能。

例如,SML 提供了一個造表工具:tabulate(f),它產生一個序列 f(0), f(1), ...。在 Python 中,可以透過結合 map()count() 組成 map(f, count()) 以達到同樣的效果。

一般疊代器:

疊代器

引數

結果

範例

accumulate()

p [,func]

p0, p0+p1, p0+p1+p2, ...

accumulate([1,2,3,4,5]) 1 3 6 10 15

batched()

p, n

(p0, p1, ..., p_n-1), ...

batched('ABCDEFG', n=3) ABC DEF G

chain()

p, q, ...

p0, p1, ... plast, q0, q1, ...

chain('ABC', 'DEF') A B C D E F

chain.from_iterable()

可疊代物件

p0, p1, ... plast, q0, q1, ...

chain.from_iterable(['ABC', 'DEF']) A B C D E F

compress()

data, selectors

(d[0] if s[0]), (d[1] if s[1]), ...

compress('ABCDEF', [1,0,1,0,1,1]) A C E F

count()

[start[, step]]

start, start+step, start+2*step, ...

count(10) 10 11 12 13 14 ...

cycle()

p

p0, p1, ... plast, p0, p1, ...

cycle('ABCD') A B C D A B C D ...

dropwhile()

predicate, seq

seq[n], seq[n+1],當 predicate 失敗時開始

dropwhile(lambda x: x<5, [1,4,6,3,8]) 6 3 8

filterfalse()

predicate, seq

當 predicate(elem) 失敗時 seq 的元素

filterfalse(lambda x: x<5, [1,4,6,3,8]) 6 8

groupby()

iterable[, key]

根據 key(v) 的值分組的子疊代器

groupby(['A','B','DEF'], len) (1, A B) (3, DEF)

islice()

seq, [start,] stop [, step]

seq[start:stop:step] 的元素

islice('ABCDEFG', 2, None) C D E F G

pairwise()

可疊代物件

(p[0], p[1]), (p[1], p[2])

pairwise('ABCDEFG') AB BC CD DE EF FG

repeat()

elem [,n]

elem, elem, elem,... 重複無限次或 n 次

repeat(10, 3) 10 10 10

starmap()

func, seq

func(*seq[0]), func(*seq[1]), ...

starmap(pow, [(2,5), (3,2), (10,3)]) 32 9 1000

takewhile()

predicate, seq

seq[0], seq[1],直到 predicate 失敗

takewhile(lambda x: x<5, [1,4,6,3,8]) 1 4

tee()

it, n

it1, it2, ... itn,將一個疊代器分成 n 個

tee('ABC', 2) A B C, A B C

zip_longest()

p, q, ...

(p[0], q[0]), (p[1], q[1]), ...

zip_longest('ABCD', 'xy', fillvalue='-') Ax By C- D-

組合疊代器:

疊代器

引數

結果

product()

p, q, ... [repeat=1]

笛卡爾乘積 (cartesian product),相當於巢狀的 for 迴圈

permutations()

p[, r]

長度為 r 的元組,所有可能的定序,無重複元素

combinations()

p, r

長度為 r 的元組,按照排序過後的定序,無重複元素

combinations_with_replacement()

p, r

長度為 r 的元組,按照排序過後的定序,有重複元素

範例

結果

product('ABCD', repeat=2)

AA AB AC AD BA BB BC BD CA CB CC CD DA DB DC DD

permutations('ABCD', 2)

AB AC AD BA BC BD CA CB CD DA DB DC

combinations('ABCD', 2)

AB AC AD BC BD CD

combinations_with_replacement('ABCD', 2)

AA AB AC AD BB BC BD CC CD DD

Itertool 函式

以下的函式都會建構並回傳疊代器。一些函式提供無限長度的串流 (stream),因此應僅由截斷串流的函式或迴圈來存取它們。

itertools.accumulate(iterable[, function, *, initial=None])

建立一個回傳累積和的疊代器,或其他二進位函式的累積結果。

function 預設為加法。function 應接受兩個引數,即累積總和和來自 iterable 的值。

如果提供了 initial 值,則累積將從該值開始,並且輸出的元素數將比輸入的可疊代物件多一個。

大致等價於:

def accumulate(iterable, function=operator.add, *, initial=None):
    'Return running totals'
    # accumulate([1,2,3,4,5]) → 1 3 6 10 15
    # accumulate([1,2,3,4,5], initial=100) → 100 101 103 106 110 115
    # accumulate([1,2,3,4,5], operator.mul) → 1 2 6 24 120

    iterator = iter(iterable)
    total = initial
    if initial is None:
        try:
            total = next(iterator)
        except StopIteration:
            return

    yield total
    for element in iterator:
        total = function(total, element)
        yield total

function 引數可以被設定為 min() 以得到連續的最小值,設定為 max() 以得到連續的最大值,或者設定為 operator.mul() 以得到連續的乘積。也可以透過累積利息和付款來建立攤銷表 (Amortization tables)

>>> data = [3, 4, 6, 2, 1, 9, 0, 7, 5, 8]
>>> list(accumulate(data, max))              # 運行最大值
[3, 4, 6, 6, 6, 9, 9, 9, 9, 9]
>>> list(accumulate(data, operator.mul))     # 運行乘積
[3, 12, 72, 144, 144, 1296, 0, 0, 0, 0]

# 攤銷一筆 1000 的 5% 貸款,分 10 年、每年支付 90
>>> update = lambda balance, payment: round(balance * 1.05) - payment
>>> list(accumulate(repeat(90, 10), update, initial=1_000))
[1000, 960, 918, 874, 828, 779, 728, 674, 618, 559, 497]

可參見 functools.reduce(),其是個類似的函式,但僅回傳最終的累積值。

在 3.2 版被加入.

在 3.3 版的變更: 新增可選的 function 參數。

在 3.8 版的變更: 新增可選的 initial 參數。

itertools.batched(iterable, n, *, strict=False)

將來自 iterable 的資料分批為長度為 n 的元組。最後一個批次可能比 n 短。

如果 strict 為真,則當最後一個批次比 n 短時,會引發 ValueError

對輸入的可疊代物件進行迴圈,並將資料累積到大小為 n 的元組中。輸入是惰性地被消耗 (consumed lazily) 的,會剛好足夠填充一批的資料。一旦批次填滿或輸入的可疊代物件耗盡,就會 yield 出結果:

>>> flattened_data = ['roses', 'red', 'violets', 'blue', 'sugar', 'sweet']
>>> unflattened = list(batched(flattened_data, 2))
>>> unflattened
[('roses', 'red'), ('violets', 'blue'), ('sugar', 'sweet')]

大致等價於:

def batched(iterable, n, *, strict=False):
    # batched('ABCDEFG', 3) → ABC DEF G
    if n < 1:
        raise ValueError('n must be at least one')
    iterator = iter(iterable)
    while batch := tuple(islice(iterator, n)):
        if strict and len(batch) != n:
            raise ValueError('batched(): incomplete batch')
        yield batch

在 3.12 版被加入.

在 3.13 版的變更: 新增 strict 選項。

itertools.chain(*iterables)

建立一個疊代器,從第一個可疊代物件回傳元素直到其耗盡,然後繼續處理下一個可疊代物件,直到所有可疊代物件都耗盡。這將多個資料來源結合為單一個疊代器。大致等價於:

def chain(*iterables):
    # chain('ABC', 'DEF') → A B C D E F
    for iterable in iterables:
        yield from iterable
classmethod chain.from_iterable(iterable)

chain() 的另一個建構函式。從單個可疊代的引數中得到鏈接的輸入,該引數是惰性計算的。大致等價於:

def from_iterable(iterables):
    # chain.from_iterable(['ABC', 'DEF']) → A B C D E F
    for iterable in iterables:
        yield from iterable
itertools.combinations(iterable, r)

從輸入 iterable 中回傳長度為 r 的元素的子序列。

輸出是 product() 的子序列,僅保留作為 iterable 子序列的條目。輸出的長度由 math.comb() 給定,當 0 r n 時,長度為 n! / r! / (n - r)!,當 r > n 時為零。

根據輸入值 iterable 的順序,組合的元組會按照字典順序輸出。如果輸入的 iterable 已經排序,則輸出的元組也將按排序的順序產生。

元素是根據它們的位置(而非值)來決定其唯一性。如果輸入的元素都是獨特的,則每個組合內將不會有重複的值。

大致等價於:

def combinations(iterable, r):
    # combinations('ABCD', 2) → AB AC AD BC BD CD
    # combinations(range(4), 3) → 012 013 023 123

    pool = tuple(iterable)
    n = len(pool)
    if r > n:
        return
    indices = list(range(r))

    yield tuple(pool[i] for i in indices)
    while True:
        for i in reversed(range(r)):
            if indices[i] != i + n - r:
                break
        else:
            return
        indices[i] += 1
        for j in range(i+1, r):
            indices[j] = indices[j-1] + 1
        yield tuple(pool[i] for i in indices)
itertools.combinations_with_replacement(iterable, r)

回傳來自輸入 iterable 的長度為 r 的子序列,且允許個別元素重複多次。

其輸出是一個 product() 的子序列,僅保留作為 iterable 子序列(可能有重複元素)的條目。當 n > 0 時,回傳的子序列數量為 (n + r - 1)! / r! / (n - 1)!

根據輸入值 iterable 的順序,組合的元組會按照字典順序輸出。如果輸入的 iterable 已經排序,則輸出的元組也將按排序的順序產生。

元素是根據它們的位置(而非值)來決定其唯一性。如果輸入的元素都是獨特的,生成的組合也將是獨特的。

大致等價於:

def combinations_with_replacement(iterable, r):
    # combinations_with_replacement('ABC', 2) → AA AB AC BB BC CC

    pool = tuple(iterable)
    n = len(pool)
    if not n and r:
        return
    indices = [0] * r

    yield tuple(pool[i] for i in indices)
    while True:
        for i in reversed(range(r)):
            if indices[i] != n - 1:
                break
        else:
            return
        indices[i:] = [indices[i] + 1] * (r - i)
        yield tuple(pool[i] for i in indices)

在 3.1 版被加入.

itertools.compress(data, selectors)

建立一個疊代器,回傳 data 中對應 selectors 的元素為 true 的元素。當 dataselectors 可疊代物件耗盡時停止。大致等價於:

def compress(data, selectors):
    # compress('ABCDEF', [1,0,1,0,1,1]) → A C E F
    return (datum for datum, selector in zip(data, selectors) if selector)

在 3.1 版被加入.

itertools.count(start=0, step=1)

建立一個疊代器,回傳從 start 開始的等差的值。可以與 map() 一起使用來產生連續的資料點,或與 zip() 一起使用來增加序列號。大致等價於:

def count(start=0, step=1):
    # count(10) → 10 11 12 13 14 ...
    # count(2.5, 0.5) → 2.5 3.0 3.5 ...
    n = start
    while True:
        yield n
        n += step

當用浮點數計數時,將上述程式碼替換為乘法有時可以獲得更好的精確度,例如:(start + step * i for i in count())

在 3.1 版的變更: 新增 step 引數並允許非整數引數。

itertools.cycle(iterable)

建立一個疊代器,回傳 iterable 中的元素並保存每個元素的副本。當可疊代物件耗盡時,從保存的副本中回傳元素。會無限次的重複。大致等價於:

def cycle(iterable):
    # cycle('ABCD') → A B C D A B C D A B C D ...

    saved = []
    for element in iterable:
        yield element
        saved.append(element)

    while saved:
        for element in saved:
            yield element

此 itertool 可能需要大量的輔助儲存空間(取決於可疊代物件的長度)。

itertools.dropwhile(predicate, iterable)

建立一個疊代器,在 predicate 為 true 時丟棄 iterable 中的元素,之後回傳每個元素。大致等價於:

def dropwhile(predicate, iterable):
    # dropwhile(lambda x: x<5, [1,4,6,3,8]) → 6 3 8

    iterator = iter(iterable)
    for x in iterator:
        if not predicate(x):
            yield x
            break

    for x in iterator:
        yield x

注意,在 predicate 首次變為 False 之前,這不會產生任何輸出,所以此 itertool 可能會有較長的啟動時間。

itertools.filterfalse(predicate, iterable)

建立一個疊代器,過濾 iterable 中的元素,僅回傳 predicate 為 False 值的元素。如果 predicateNone,則回傳為 False 的項目。大致等價於:

def filterfalse(predicate, iterable):
    # filterfalse(lambda x: x<5, [1,4,6,3,8]) → 6 8

    if predicate is None:
        predicate = bool

    for x in iterable:
        if not predicate(x):
            yield x
itertools.groupby(iterable, key=None)

建立一個疊代器,回傳 iterable 中連續的鍵和群組。key 是一個為每個元素計算鍵值的函式。如果其未指定或為 None,則 key 預設為一個識別性函式 (identity function),並回傳未被更改的元素。一般來說,可疊代物件需要已經用相同的鍵函式進行排序。

groupby() 的操作類似於 Unix 中的 uniq 過濾器。每當鍵函式的值發生變化時,它會產生一個 break 或新的群組(這就是為什麼通常需要使用相同的鍵函式對資料進行排序)。這種行為不同於 SQL 的 GROUP BY,其無論輸入順序如何都會聚合相同的元素。

回傳的群組本身是一個與 groupby() 共享底層可疊代物件的疊代器。由於來源是共享的,當 groupby() 物件前進時,前一個群組將不再可見。因此,如果之後需要該資料,應將其儲存為串列:

groups = []
uniquekeys = []
data = sorted(data, key=keyfunc)
for k, g in groupby(data, keyfunc):
    groups.append(list(g))      # 將群組疊代器儲存為串列
    uniquekeys.append(k)

groupby() 大致等價於:

def groupby(iterable, key=None):
    # [k for k, g in groupby('AAAABBBCCDAABBB')] → A B C D A B
    # [list(g) for k, g in groupby('AAAABBBCCD')] → AAAA BBB CC D

    keyfunc = (lambda x: x) if key is None else key
    iterator = iter(iterable)
    exhausted = False

    def _grouper(target_key):
        nonlocal curr_value, curr_key, exhausted
        yield curr_value
        for curr_value in iterator:
            curr_key = keyfunc(curr_value)
            if curr_key != target_key:
                return
            yield curr_value
        exhausted = True

    try:
        curr_value = next(iterator)
    except StopIteration:
        return
    curr_key = keyfunc(curr_value)

    while not exhausted:
        target_key = curr_key
        curr_group = _grouper(target_key)
        yield curr_key, curr_group
        if curr_key == target_key:
            for _ in curr_group:
                pass
itertools.islice(iterable, stop)
itertools.islice(iterable, start, stop[, step])

建立一個疊代器,回傳從 iterable 中選取的元素。其作用類似於序列切片 (sequence slicing),但不支援負數的 startstopstep 的值。

如果 start 為零或 None,則從零開始疊代。否則在達到 start 之前,會跳過 iterable 中的元素。

如果 stopNone,則疊代將繼續前進直到輸入耗盡。如果指定了 stop,則在達到指定位置時停止。

如果 stepNone,則步長 (step) 預設為一。元素會連續回傳,除非將 step 設定為大於一,這會導致一些項目被跳過。

大致等價於:

def islice(iterable, *args):
    # islice('ABCDEFG', 2) → A B
    # islice('ABCDEFG', 2, 4) → C D
    # islice('ABCDEFG', 2, None) → C D E F G
    # islice('ABCDEFG', 0, None, 2) → A C E G

    s = slice(*args)
    start = 0 if s.start is None else s.start
    stop = s.stop
    step = 1 if s.step is None else s.step
    if start < 0 or (stop is not None and stop < 0) or step <= 0:
        raise ValueError

    indices = count() if stop is None else range(max(start, stop))
    next_i = start
    for i, element in zip(indices, iterable):
        if i == next_i:
            yield element
            next_i += step

若輸入為疊代器,則完整耗盡 islice 會使輸入的疊代器向前移動 max(start, stop) 步,與 step 的值無關。

itertools.pairwise(iterable)

回傳從輸入的 iterable 中提取的連續重疊對。

輸出疊代器中的 2 元組數量將比輸入少一個。如果輸入的可疊代物件中的值少於兩個,則輸出將為空值。

大致等價於:

def pairwise(iterable):
    # pairwise('ABCDEFG') → AB BC CD DE EF FG

    iterator = iter(iterable)
    a = next(iterator, None)

    for b in iterator:
        yield a, b
        a = b

在 3.10 版被加入.

itertools.permutations(iterable, r=None)

回傳 iterable 中連續且長度為 r元素排列

如果未指定 r 或其值為 None,則 r 預設為 iterable 的長度,並產生所有可能的完整長度的排列。

輸出是 product() 的子序列,其中重複元素的條目已被濾除。輸出的長度由 math.perm() 給定,當 0 r n 時,長度為 n! / (n - r)!,當 r > n 時為零。

根據輸入值 iterable 的順序,排列的元組會按照字典順序輸出。如果輸入的 iterable 已排序,則輸出的元組也將按排序的順序產生。

元素是根據它們的位置(而非值)來決定其唯一性。如果輸入的元素都是獨特的,則排列中將不會有重複的值。

大致等價於:

def permutations(iterable, r=None):
    # permutations('ABCD', 2) → AB AC AD BA BC BD CA CB CD DA DB DC
    # permutations(range(3)) → 012 021 102 120 201 210

    pool = tuple(iterable)
    n = len(pool)
    r = n if r is None else r
    if r > n:
        return

    indices = list(range(n))
    cycles = list(range(n, n-r, -1))
    yield tuple(pool[i] for i in indices[:r])

    while n:
        for i in reversed(range(r)):
            cycles[i] -= 1
            if cycles[i] == 0:
                indices[i:] = indices[i+1:] + indices[i:i+1]
                cycles[i] = n - i
            else:
                j = cycles[i]
                indices[i], indices[-j] = indices[-j], indices[i]
                yield tuple(pool[i] for i in indices[:r])
                break
        else:
            return
itertools.product(*iterables, repeat=1)

輸入可疊代物的 笛卡爾乘積

大致等價於產生器運算式中的巢狀 for 迴圈。例如,product(A, B) 的回傳結果與 ((x,y) for x in A for y in B) 相同。

巢狀迴圈的循環類似於里程表,最右邊的元素在每次疊代時前進。這種模式會建立字典順序,因此如果輸入的 iterables 已排序,則輸出的乘積元組也將按排序的順序產生。

要計算可疊代物件自身的乘積,可以使用可選的 repeat 關鍵字引數來指定重複次數。例如,product(A, repeat=4)product(A, A, A, A) 相同。

此函式大致等價於以下的程式碼,不同之處在於真正的實作不會在記憶體中建立中間結果:

def product(*iterables, repeat=1):
    # product('ABCD', 'xy') → Ax Ay Bx By Cx Cy Dx Dy
    # product(range(2), repeat=3) → 000 001 010 011 100 101 110 111

    if repeat < 0:
        raise ValueError('repeat argument cannot be negative')
    pools = [tuple(pool) for pool in iterables] * repeat

    result = [[]]
    for pool in pools:
        result = [x+[y] for x in result for y in pool]

    for prod in result:
        yield tuple(prod)

product() 執行之前,它會完全消耗輸入的 iterables,並將值的池 (pools of values) 保存在記憶體中以產生乘積。因此,它僅對有限的輸入有用。

itertools.repeat(object[, times])

建立一個疊代器,反覆回傳 object。除非指定了 times 引數,否則會執行無限次。

大致等價於:

def repeat(object, times=None):
    # repeat(10, 3) → 10 10 10
    if times is None:
        while True:
            yield object
    else:
        for i in range(times):
            yield object

repeat 的常見用途是為 mapzip 提供定值的串流:

>>> list(map(pow, range(10), repeat(2)))
[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]