Python 生成器和迭代器
生成器和迭代器是 Python 中处理数据流的核心概念。它们提供了一种优雅的方式来处理大数据集、无限序列和惰性计算,是 Python "简洁胜于复杂"哲学的完美体现。
理解生成器和迭代器不仅能让你写出更高效的代码,还能帮助你深入理解 Python 的 for 循环、推导式等语言特性的底层实现。
迭代器协议:Python 循环的基石
什么是迭代器协议?
迭代器协议(Iterator Protocol)是 Python 中实现迭代的标准接口。它定义了对象如何参与 for 循环和其他迭代上下文。
迭代器协议包含两个核心方法:
| 方法 | 说明 |
|---|---|
__iter__() | 返回迭代器对象本身 |
__next__() | 返回下一个元素,没有更多元素时抛出 StopIteration |
可迭代对象(Iterable):实现了 __iter__() 方法的对象,返回一个迭代器。
迭代器(Iterator):同时实现了 __iter__() 和 __next__() 方法的对象。
迭代器协议的工作流程
for 循环的底层实现
当你写 for item in iterable: 时,Python 实际上执行了以下操作:
# for item in iterable:
# do_something(item)
# 等价于:
iterator = iter(iterable) # 调用 __iter__()
while True:
try:
item = next(iterator) # 调用 __next__()
do_something(item)
except StopIteration:
break # 捕获 StopIteration,结束循环
这解释了为什么 for 循环能够优雅地处理各种可迭代对象,而无需关心它们的具体类型。
手动创建迭代器
理解迭代器协议后,我们可以创建自定义迭代器:
class Counter:
"""一个简单的计数器迭代器"""
def __init__(self, start, end):
self.current = start
self.end = end
def __iter__(self):
# 迭代器必须返回自身
return self
def __next__(self):
# 返回下一个值,或抛出 StopIteration
if self.current >= self.end:
raise StopIteration
value = self.current
self.current += 1
return value
# 使用自定义迭代器
counter = Counter(1, 5)
for num in counter:
print(num) # 1, 2, 3, 4
# 迭代器是一次性的
for num in counter:
print(num) # 不会打印任何内容,迭代器已耗尽
可迭代对象 vs 迭代器
一个常见的混淆点:可迭代对象不是迭代器。
# 列表是可迭代对象,但不是迭代器
my_list = [1, 2, 3]
# 每次调用 iter() 都返回新的迭代器
it1 = iter(my_list)
it2 = iter(my_list)
print(it1 is it2) # False,两个不同的迭代器
# 列表本身没有 __next__ 方法
# my_list.__next__() # AttributeError
# 迭代器必须同时有 __iter__ 和 __next__
print(hasattr(my_list, '__next__')) # False
print(hasattr(it1, '__next__')) # True
关键区别:
| 特性 | 可迭代对象 | 迭代器 |
|---|---|---|
__iter__() | 返回新的迭代器 | 返回自身 |
__next__() | 不存在 | 返回下一个元素 |
| 可重复遍历 | 是 | 否(一次性) |
| 示例 | 列表、字符串、字典 | 文件对象、生成器 |
创建可重复遍历的可迭代对象
如果你需要一个可以重复遍历的自定义容器,应该让 __iter__() 返回新的迭代器:
class Fibonacci:
"""可重复遍历的斐波那契数列"""
def __init__(self, n):
self.n = n
def __iter__(self):
# 每次调用都返回新的迭代器
return FibonacciIterator(self.n)
class FibonacciIterator:
"""斐波那契迭代器"""
def __init__(self, n):
self.n = n
self.a, self.b = 0, 1
self.count = 0
def __iter__(self):
return self
def __next__(self):
if self.count >= self.n:
raise StopIteration
value = self.a
self.a, self.b = self.b, self.a + self.b
self.count += 1
return value
# 可以重复遍历
fib = Fibonacci(5)
print(list(fib)) # [0, 1, 1, 2, 3]
print(list(fib)) # [0, 1, 1, 2, 3] - 可以再次遍历
生成器:优雅的迭代器
什么是生成器?
生成器(Generator)是一种简化迭代器创建的语法糖。它使用 yield 关键字,让 Python 自动为你实现迭代器协议。
def count_up_to(n):
"""一个简单的生成器函数"""
i = 1
while i <= n:
yield i # 暂停并返回值
i += 1
# 创建生成器对象
counter = count_up_to(5)
print(type(counter)) # <class 'generator'>
# 使用生成器
for num in counter:
print(num) # 1, 2, 3, 4, 5
yield 的工作原理
yield 是生成器的核心。当函数中包含 yield 时,它就变成了生成器函数。
yield 的执行过程:
- 调用生成器函数时,函数体不会立即执行,而是返回一个生成器对象
- 调用
next()时,函数执行到第一个yield,暂停并返回值 - 再次调用
next()时,从上次暂停的位置继续执行 - 函数结束或遇到
return时,抛出StopIteration
def demonstrate_yield():
"""演示 yield 的执行过程"""
print("第一步:开始")
yield 1
print("第二步:继续")
yield 2
print("第三步:结束")
yield 3
print("函数执行完毕")
gen = demonstrate_yield()
print("生成器已创建,函数体尚未执行")
print(next(gen)) # 输出:第一步:开始 \n 1
print(next(gen)) # 输出:第二步:继续 \n 2
print(next(gen)) # 输出:第三步:结束 \n 3
# next(gen) # 输出:函数执行完毕,然后抛出 StopIteration
关键洞察:生成器函数是一种"可暂停"的函数。每次 yield 都会保存当前的执行状态(局部变量、指令指针等),下次调用时从暂停点继续。
生成器 vs 普通函数
# 普通函数:一次性计算所有结果
def get_squares_list(n):
result = []
for i in range(n):
result.append(i * i)
return result
# 生成器函数:按需计算,惰性求值
def get_squares_generator(n):
for i in range(n):
yield i * i
# 内存使用对比
import sys
# 列表:一次性占用所有内存
squares_list = get_squares_list(1000000)
print(f"列表大小: {sys.getsizeof(squares_list)} bytes") # 约 8MB+
# 生成器:几乎不占用内存
squares_gen = get_squares_generator(1000000)
print(f"生成器大小: {sys.getsizeof(squares_gen)} bytes") # 约 112 bytes
生成器表达式
生成器表达式是创建简单生成器的简洁语法,类似于列表推导式,但使用圆括号:
# 列表推导式:立即创建列表
squares_list = [x * x for x in range(10)]
# 生成器表达式:创建生成器
squares_gen = (x * x for x in range(10))
print(type(squares_list)) # <class 'list'>
print(type(squares_gen)) # <class 'generator'>
# 生成器表达式可以用于任何需要迭代器的地方
sum_of_squares = sum(x * x for x in range(1000000)) # 内存友好
max_value = max(x * x for x in range(100))
# 带条件的生成器表达式
even_squares = (x * x for x in range(20) if x % 2 == 0)
print(list(even_squares)) # [0, 4, 16, 36, 64, 100, 144, 196, 256, 324]
列表推导式 vs 生成器表达式:
| 特性 | 列表推导式 [] | 生成器表达式 () |
|---|---|---|
| 返回类型 | 列表 | 生成器 |
| 内存占用 | 高(存储所有元素) | 低(惰性计算) |
| 可重复遍历 | 是 | 否 |
| 适用场景 | 需要多次访问、索引操作 | 单次遍历、大数据集 |
嵌套生成器表达式
生成器表达式支持嵌套,从左到右阅读:
# 等价于嵌套 for 循环
matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
# 展平矩阵
flat = (item for row in matrix for item in row)
print(list(flat)) # [1, 2, 3, 4, 5, 6, 7, 8, 9]
# 带条件的嵌套
# 只选择偶数
evens = (item for row in matrix for item in row if item % 2 == 0)
print(list(evens)) # [2, 4, 6, 8]
生成器的高级特性
yield from:委托给子生成器
yield from 用于将生成操作委托给另一个生成器,简化嵌套生成器的代码:
# 不使用 yield from
def chain_old(*iterables):
for iterable in iterables:
for item in iterable:
yield item
# 使用 yield from(更简洁)
def chain(*iterables):
for iterable in iterables:
yield from iterable
result = chain([1, 2], [3, 4], [5, 6])
print(list(result)) # [1, 2, 3, 4, 5, 6]
# 递归使用 yield from
def flatten(nested):
"""展平任意嵌套的序列"""
for item in nested:
if isinstance(item, (list, tuple)):
yield from flatten(item)
else:
yield item
nested = [1, [2, 3], [4, [5, 6]], 7]
print(list(flatten(nested))) # [1, 2, 3, 4, 5, 6, 7]
yield from 的返回值:
yield from 表达式有返回值,它是子生成器终止时由 StopIteration 异常携带的值,或者是子生成器的 return 值:
def accumulator():
total = 0
value = 0
while True:
value = yield total
if value is None:
break
total += value
return total # 生成器的返回值
def wrapper():
# result 会获取 accumulator() 的返回值
result = yield from accumulator()
print(f"累积结果: {result}")
gen = wrapper()
next(gen) # 启动生成器
print(gen.send(10)) # 10
print(gen.send(20)) # 30
gen.send(None) # 累积结果: 30, 然后抛出 StopIteration
send():向生成器发送数据
生成器不仅可以产出数据,还可以接收数据。send() 方法允许我们向生成器发送值,这个值会成为 yield 表达式的结果:
def accumulator():
"""累加器生成器"""
total = 0
while True:
# yield 返回累计值,同时接收外部发送的新值
value = yield total
if value is not None:
total += value
acc = accumulator()
next(acc) # 启动生成器,必须先调用 next() 或 send(None)
print(acc.send(10)) # 10
print(acc.send(20)) # 30
print(acc.send(5)) # 35
# 关闭生成器
acc.close()
# acc.send(10) # StopIteration
协程风格的生成器:
def averager():
"""计算移动平均值的协程"""
total = 0
count = 0
average = None
while True:
value = yield average
if value is None:
break
total += value
count += 1
average = total / count
return total, count, average
avg = averager()
next(avg) # 启动
print(avg.send(10)) # 10.0
print(avg.send(20)) # 15.0
print(avg.send(30)) # 20.0
# 获取最终结果
try:
avg.send(None)
except StopIteration as e:
print(f"总计: {e.value}") # 总计: (60, 3, 20.0)
throw():向生成器抛出异常
throw() 方法允许在生成器暂停的位置抛出异常:
def robust_processor():
"""带有异常处理的生成器"""
while True:
try:
value = yield
print(f"处理: {value}")
except ValueError as e:
print(f"捕获 ValueError: {e}")
except Exception as e:
print(f"捕获异常: {e}")
raise # 重新抛出未处理的异常
proc = robust_processor()
next(proc) # 启动
proc.send(10) # 处理: 10
proc.throw(ValueError, "无效值") # 捕获 ValueError: 无效值
proc.send(20) # 处理: 20(生成器继续运行)
close():关闭生成器
close() 方法用于提前终止生成器,它会在暂停点抛出 GeneratorExit 异常:
def resource_user():
"""使用资源的生成器"""
print("获取资源")
try:
for i in range(100):
yield i
finally:
print("释放资源")
gen = resource_user()
print(next(gen)) # 获取资源 \n 0
print(next(gen)) # 1
gen.close() # 释放资源
# next(gen) # StopIteration
生成器的状态
生成器有四种状态,可以通过 inspect.getgeneratorstate() 查看:
import inspect
def my_gen():
yield 1
yield 2
gen = my_gen()
print(inspect.getgeneratorstate(gen)) # GEN_CREATED(已创建,未启动)
next(gen)
print(inspect.getgeneratorstate(gen)) # GEN_SUSPENDED(暂停中)
next(gen)
print(inspect.getgeneratorstate(gen)) # GEN_SUSPENDED
next(gen) # StopIteration
print(inspect.getgeneratorstate(gen)) # GEN_CLOSED(已关闭)
# 另一种状态:手动关闭后
gen2 = my_gen()
next(gen2)
gen2.close()
print(inspect.getgeneratorstate(gen2)) # GEN_CLOSED
| 状态 | 说明 |
|---|---|
GEN_CREATED | 生成器已创建,尚未启动 |
GEN_RUNNING | 生成器正在执行(多线程环境下可见) |
GEN_SUSPENDED | 生成器在 yield 处暂停 |
GEN_CLOSED | 生成器已关闭或执行完毕 |
生成器的应用场景
1. 处理大文件
生成器是处理大文件的理想选择,因为它不需要一次性加载整个文件到内存:
def read_large_file(file_path):
"""逐行读取大文件"""
with open(file_path, 'r', encoding='utf-8') as f:
for line in f:
yield line.strip()
def count_words(file_path):
"""统计大文件的单词数"""
word_count = 0
for line in read_large_file(file_path):
word_count += len(line.split())
return word_count
# 处理 GB 级别的文件也不会耗尽内存
# word_count = count_words('huge_file.txt')
2. 无限序列
生成器可以表示无限序列,因为它是惰性求值的:
def fibonacci():
"""无限斐波那契数列"""
a, b = 0, 1
while True:
yield a
a, b = b, a + b
def primes():
"""无限质数序列"""
def is_prime(n):
if n < 2:
return False
for i in range(2, int(n ** 0.5) + 1):
if n % i == 0:
return False
return True
n = 2
while True:
if is_prime(n):
yield n
n += 1
# 使用 itertools.islice 截取
from itertools import islice
# 获取前 10 个斐波那契数
fib = fibonacci()
print(list(islice(fib, 10))) # [0, 1, 1, 2, 3, 5, 8, 13, 21, 34]
# 获取前 10 个质数
prime_gen = primes()
print(list(islice(prime_gen, 10))) # [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
3. 数据管道
生成器非常适合构建数据处理管道,每个生成器处理一个步骤:
def read_lines(file_path):
"""读取文件行"""
with open(file_path, 'r') as f:
for line in f:
yield line
def filter_comments(lines):
"""过滤注释行"""
for line in lines:
if not line.strip().startswith('#'):
yield line
def strip_whitespace(lines):
"""去除空白"""
for line in lines:
yield line.strip()
def to_uppercase(lines):
"""转换为大写"""
for line in lines:
yield line.upper()
# 构建管道
pipeline = to_uppercase(
strip_whitespace(
filter_comments(
read_lines('config.txt')
)
)
)
for line in pipeline:
print(line)
4. 生成器实现树遍历
class TreeNode:
def __init__(self, value, left=None, right=None):
self.value = value
self.left = left
self.right = right
def inorder_traversal(node):
"""中序遍历生成器"""
if node is not None:
yield from inorder_traversal(node.left)
yield node.value
yield from inorder_traversal(node.right)
def preorder_traversal(node):
"""前序遍历生成器"""
if node is not None:
yield node.value
yield from preorder_traversal(node.left)
yield from preorder_traversal(node.right)
def level_order_traversal(root):
"""层序遍历生成器"""
from collections import deque
queue = deque([root])
while queue:
node = queue.popleft()
if node:
yield node.value
queue.append(node.left)
queue.append(node.right)
# 构建树
# 1
# / \
# 2 3
# / \
# 4 5
tree = TreeNode(1,
TreeNode(2, TreeNode(4), TreeNode(5)),
TreeNode(3)
)
print(list(inorder_traversal(tree))) # [4, 2, 5, 1, 3]
print(list(preorder_traversal(tree))) # [1, 2, 4, 5, 3]
5. 批量处理
def batch(iterable, size):
"""将可迭代对象分成批次"""
batch = []
for item in iterable:
batch.append(item)
if len(batch) == size:
yield batch
batch = []
if batch:
yield batch
# 处理大量数据
data = range(1000)
for chunk in batch(data, 100):
process_chunk(chunk) # 每次处理 100 个元素
6. 圆环迭代
from itertools import cycle
def round_robin(*iterables):
"""轮询迭代多个可迭代对象"""
pending = [iter(it) for it in iterables]
nexts = cycle(lambda: next(p) for p in pending)
while pending:
try:
for next_func in nexts:
yield next_func()
except StopIteration:
pending.remove(next_func.__closure__[0].cell_contents)
# 更简单的实现
def round_robin_simple(*iterables):
"""轮询迭代"""
from itertools import zip_longest
sentinel = object()
for tuple_item in zip_longest(*iterables, fillvalue=sentinel):
for item in tuple_item:
if item is not sentinel:
yield item
list1 = [1, 2, 3]
list2 = ['a', 'b', 'c', 'd', 'e']
list3 = ['X', 'Y']
print(list(round_robin_simple(list1, list2, list3)))
# [1, 'a', 'X', 2, 'b', 'Y', 3, 'c', 'd', 'e']
itertools 模块:迭代器工具箱
Python 的 itertools 模块提供了一系列高效的迭代器工具:
无限迭代器
from itertools import count, cycle, repeat
# count:无限计数
for i in count(10, 2): # 从 10 开始,步长 2
if i > 20:
break
print(i) # 10, 12, 14, 16, 18, 20
# cycle:无限循环
colors = cycle(['red', 'green', 'blue'])
for _ in range(5):
print(next(colors)) # red, green, blue, red, green
# repeat:重复元素
for item in repeat('hello', 3):
print(item) # hello, hello, hello
终止迭代器
from itertools import (
accumulate, chain, compress, dropwhile,
takewhile, filterfalse, islice, starmap
)
# accumulate:累积运算
print(list(accumulate([1, 2, 3, 4]))) # [1, 3, 6, 10]
print(list(accumulate([1, 2, 3, 4], initial=0))) # [0, 1, 3, 6, 10]
print(list(accumulate([1, 2, 3, 4], lambda x, y: x * y))) # [1, 2, 6, 24]
# chain:连接多个可迭代对象
print(list(chain([1, 2], [3, 4], [5]))) # [1, 2, 3, 4, 5]
print(list(chain.from_iterable([[1, 2], [3, 4]]))) # [1, 2, 3, 4]
# compress:根据选择器过滤
print(list(compress([1, 2, 3, 4], [1, 0, 1, 0]))) # [1, 3]
# dropwhile:丢弃直到条件为假
print(list(dropwhile(lambda x: x < 3, [1, 2, 3, 4, 5]))) # [3, 4, 5]
# takewhile:取值直到条件为假
print(list(takewhile(lambda x: x < 3, [1, 2, 3, 4, 5]))) # [1, 2]
# filterfalse:过滤掉满足条件的元素
print(list(filterfalse(lambda x: x % 2, [1, 2, 3, 4]))) # [2, 4]
# islice:迭代器切片
print(list(islice(range(10), 2, 8, 2))) # [2, 4, 6]
# starmap:应用函数到解包后的参数
print(list(starmap(pow, [(2, 3), (3, 2), (10, 0)]))) # [8, 9, 1]
组合迭代器
from itertools import product, permutations, combinations, combinations_with_replacement
# product:笛卡尔积
print(list(product([1, 2], ['a', 'b'])))
# [(1, 'a'), (1, 'b'), (2, 'a'), (2, 'b')]
# permutations:排列
print(list(permutations([1, 2, 3], 2)))
# [(1, 2), (1, 3), (2, 1), (2, 3), (3, 1), (3, 2)]
# combinations:组合
print(list(combinations([1, 2, 3], 2)))
# [(1, 2), (1, 3), (2, 3)]
# combinations_with_replacement:带重复的组合
print(list(combinations_with_replacement([1, 2], 2)))
# [(1, 1), (1, 2), (2, 2)]
常见陷阱与最佳实践
陷阱一:生成器是一次性的
def get_numbers():
yield from [1, 2, 3]
gen = get_numbers()
print(list(gen)) # [1, 2, 3]
print(list(gen)) # [] - 生成器已耗尽!
# 解决方案1:重新创建生成器
gen = get_numbers()
print(list(gen)) # [1, 2, 3]
gen = get_numbers()
print(list(gen)) # [1, 2, 3]
# 解决方案2:如果需要多次遍历,使用 list
numbers = list(get_numbers())
print(numbers) # [1, 2, 3]
print(numbers) # [1, 2, 3]
陷阱二:生成器表达式中的变量绑定
# 问题:闭包中的变量绑定
funcs = [lambda: i for i in range(3)]
print([f() for f in funcs]) # [2, 2, 2] - 不是 [0, 1, 2]!
# 生成器表达式也有同样的问题
funcs = list((lambda: i) for i in range(3))
print([f() for f in funcs]) # [2, 2, 2]
# 解决方案:使用默认参数捕获当前值
funcs = [lambda i=i: i for i in range(3)]
print([f() for f in funcs]) # [0, 1, 2]
陷阱三:部分消费生成器
def get_data():
yield from range(10)
gen = get_data()
# 消费部分元素
first_three = [next(gen) for _ in range(3)]
print(first_three) # [0, 1, 2]
# 剩余元素
print(list(gen)) # [3, 4, 5, 6, 7, 8, 9]
# 如果需要保留剩余元素,使用 tee
from itertools import tee
gen = get_data()
gen1, gen2 = tee(gen)
print(list(islice(gen1, 3))) # [0, 1, 2]
print(list(gen2)) # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
陷阱四:在生成器中修改正在迭代的序列
def problematic_generator():
data = [1, 2, 3, 4, 5]
for item in data:
if item == 3:
data.remove(item) # 危险!在迭代时修改序列
yield item
# 输出可能不是预期的
print(list(problematic_generator())) # [1, 2, 3, 5] 或其他结果
# 正确做法:迭代副本或使用列表推导式
def safe_generator():
data = [1, 2, 3, 4, 5]
for item in data[:]: # 迭代副本
if item != 3:
yield item
最佳实践
- 使用
yield from简化嵌套生成器 - 使用生成器表达式进行简单转换
- 使用
itertools模块而不是自己实现 - 明确区分可迭代对象和迭代器
- 在文档中注明生成器是一次性的
from typing import Generator, Iterator, Iterable
# 使用类型注解提高代码可读性
def count_up_to(n: int) -> Generator[int, None, None]:
"""生成 1 到 n 的整数序列。
这是一个生成器函数,返回的生成器是一次性的。
Args:
n: 上限(包含)
Yields:
从 1 到 n 的整数
"""
for i in range(1, n + 1):
yield i
def get_iterator(items: list) -> Iterator:
"""返回列表的迭代器。"""
return iter(items)
def process_items(items: Iterable[int]) -> list:
"""处理任何可迭代对象。"""
return [x * 2 for x in items]
小结
本章深入学习了 Python 生成器和迭代器:
迭代器协议:
__iter__()返回迭代器__next__()返回下一个元素或抛出StopIteration- 可迭代对象 ≠ 迭代器
生成器基础:
- 使用
yield创建生成器函数 - 生成器表达式:
(x for x in iterable) - 惰性求值,节省内存
生成器高级特性:
yield from委托给子生成器send()向生成器发送数据throw()抛出异常close()关闭生成器
应用场景:
- 处理大文件
- 无限序列
- 数据管道
- 树遍历
最佳实践:
- 使用
itertools模块 - 注意生成器是一次性的
- 使用类型注解
练习
- 实现一个生成器,产出指定范围内的所有质数
- 实现一个生成器,按层遍历树结构
- 使用生成器实现一个简单的日志分析管道(读取 → 过滤 → 格式化 → 输出)
- 实现一个带缓存的生成器,可以重复遍历
- 使用
send()实现一个简单的计算器生成器