跳到主要内容

闭包和迭代器

闭包和迭代器是 Rust 函数式编程的核心特性。闭包是匿名函数,可以捕获环境变量;迭代器提供惰性求值和链式操作。

闭包

什么是闭包?

闭包是可以保存在变量中或作为参数传递的匿名函数。与普通函数不同,闭包可以捕获其定义环境中的变量。

fn main() {
// 普通函数
fn add_one(x: i32) -> i32 {
x + 1
}

// 闭包语法
let add_one_closure = |x: i32| x + 1;

println!("函数: {}", add_one(5));
println!("闭包: {}", add_one_closure(5));
}

闭包语法

fn main() {
// 完整语法
let add = |x: i32| -> i32 { x + 1 };

// 省略返回类型
let add = |x: i32| { x + 1 };

// 省略大括号(单表达式)
let add = |x: i32| x + 1;

// 省略参数类型(编译器推断)
let add = |x| x + 1;

// 无参数闭包
let say_hello = || println!("你好!");
say_hello();

// 多参数闭包
let multiply = |a, b| a * b;
println!("乘积: {}", multiply(3, 4));
}

捕获环境变量

闭包可以捕获其定义环境中的变量:

fn main() {
let x = 10;

// 闭包捕获 x
let add_x = |y| x + y;

println!("结果: {}", add_x(5)); // 15

// 可以访问多次
println!("结果: {}", add_x(10)); // 20
}

捕获方式

闭包有三种捕获方式:

方式关键字说明
不可变借用Fn只读取,不修改
可变借用FnMut可以修改
获取所有权FnOnce消耗变量
fn main() {
// Fn:不可变借用
let list = vec![1, 2, 3];
let print_list = || {
println!("列表: {:?}", list); // 不可变借用
};
print_list();
println!("列表仍然可用: {:?}", list);

// FnMut:可变借用
let mut count = 0;
let mut increment = || {
count += 1; // 可变借用
println!("计数: {}", count);
};
increment();
increment();
println!("最终计数: {}", count);

// FnOnce:获取所有权
let name = String::from("Rust");
let consume = move || {
println!("名字: {}", name);
// name 被消耗
};
consume();
// println!("{}", name); // 错误!name 已被移动
}

move 关键字

使用 move 强制闭包获取所有权:

fn main() {
let data = vec![1, 2, 3];

// 使用 move 获取所有权
let closure = move || {
println!("数据: {:?}", data);
};

// data 已被移动
// println!("{:?}", data); // 错误!

closure();
}

// 常用于创建线程
use std::thread;

fn main() {
let v = vec![1, 2, 3];

let handle = thread::spawn(move || {
println!("线程中的向量: {:?}", v);
});

handle.join().unwrap();
}

闭包作为参数

// 使用泛型接受闭包
fn apply<F>(f: F)
where
F: FnOnce(),
{
f();
}

// 使用 Fn trait
fn apply_to_value<F>(value: i32, f: F) -> i32
where
F: Fn(i32) -> i32,
{
f(value)
}

// 使用 FnMut trait
fn apply_mut<F>(list: &mut Vec<i32>, f: F)
where
F: FnMut(&mut i32),
{
for item in list.iter_mut() {
f(item);
}
}

fn main() {
let x = 10;

// 作为参数传递
apply(|| println!("x = {}", x));

// 带返回值
let doubled = apply_to_value(5, |x| x * 2);
println!("加倍: {}", doubled);

// 修改值
let mut numbers = vec![1, 2, 3];
apply_mut(&mut numbers, |x| *x *= 2);
println!("修改后: {:?}", numbers);
}

闭包作为返回值

fn create_adder(x: i32) -> impl Fn(i32) -> i32 {
move |y| x + y
}

fn create_multiplier(x: i32) -> Box<dyn Fn(i32) -> i32> {
Box::new(move |y| x * y)
}

fn main() {
let adder = create_adder(10);
println!("结果: {}", adder(5)); // 15

let multiplier = create_multiplier(3);
println!("结果: {}", multiplier(4)); // 12
}

实际应用:缓存计算

use std::collections::HashMap;
use std::hash::Hash;

struct Cacher<T, K, V>
where
T: Fn(K) -> V,
K: Eq + Hash + Clone,
V: Clone,
{
calculation: T,
cache: HashMap<K, V>,
}

impl<T, K, V> Cacher<T, K, V>
where
T: Fn(K) -> V,
K: Eq + Hash + Clone,
V: Clone,
{
fn new(calculation: T) -> Cacher<T, K, V> {
Cacher {
calculation,
cache: HashMap::new(),
}
}

fn value(&mut self, arg: K) -> V {
if let Some(v) = self.cache.get(&arg) {
v.clone()
} else {
let v = (self.calculation)(arg.clone());
self.cache.insert(arg, v.clone());
v
}
}
}

fn main() {
let mut cache = Cacher::new(|n| {
println!("计算 {} 的平方", n);
n * n
});

println!("结果: {}", cache.value(5)); // 会计算
println!("结果: {}", cache.value(5)); // 使用缓存
println!("结果: {}", cache.value(10)); // 会计算
}

迭代器

什么是迭代器?

迭代器是实现了 Iterator trait 的类型,提供遍历集合元素的方法。

pub trait Iterator {
type Item;

fn next(&mut self) -> Option<Self::Item>;

// 其他方法有默认实现...
}

基本使用

fn main() {
let v = vec![1, 2, 3];

// 创建迭代器
let mut iter = v.iter();

// 手动调用 next
println!("{:?}", iter.next()); // Some(1)
println!("{:?}", iter.next()); // Some(2)
println!("{:?}", iter.next()); // Some(3)
println!("{:?}", iter.next()); // None
}

三种迭代器

fn main() {
let v = vec![1, 2, 3];

// iter():不可变借用
let iter: Iterator<Item = &i32> = v.iter();
for val in v.iter() {
println!("不可变: {}", val);
}

// iter_mut():可变借用
let mut v_mut = vec![1, 2, 3];
for val in v_mut.iter_mut() {
*val *= 2;
}
println!("修改后: {:?}", v_mut);

// into_iter():获取所有权
let v_own = vec![1, 2, 3];
for val in v_own.into_iter() {
println!("获取所有权: {}", val);
}
// v_own 不再可用
}

消费适配器

消费适配器会消耗迭代器并返回结果:

fn main() {
let v = vec![1, 2, 3, 4, 5];

// sum:求和
let sum: i32 = v.iter().sum();
println!("求和: {}", sum);

// product:求积
let product: i32 = v.iter().product();
println!("求积: {}", product);

// collect:收集成集合
let doubled: Vec<i32> = v.iter().map(|x| x * 2).collect();
println!("加倍: {:?}", doubled);

// count:计数
let count = v.iter().count();
println!("数量: {}", count);

// max/min:最大/最小值
let max = v.iter().max();
let min = v.iter().min();
println!("最大: {:?}, 最小: {:?}", max, min);

// find:查找
let found = v.iter().find(|&&x| x > 3);
println!("找到: {:?}", found);

// position:位置
let pos = v.iter().position(|&x| x == 3);
println!("位置: {:?}", pos);

// any/all:任意/全部
let any_gt_3 = v.iter().any(|&x| x > 3);
let all_positive = v.iter().all(|&x| x > 0);
println!("有大于3的: {}, 全部大于0: {}", any_gt_3, all_positive);

// fold:累积
let sum_fold = v.iter().fold(0, |acc, &x| acc + x);
println!("fold求和: {}", sum_fold);

// reduce:简化版 fold
let max_reduce = v.iter().reduce(|a, b| if a > b { a } else { b });
println!("reduce最大值: {:?}", max_reduce);
}

迭代适配器

迭代适配器返回新的迭代器,支持链式调用:

fn main() {
let v = vec![1, 2, 3, 4, 5];

// map:映射
let doubled: Vec<i32> = v.iter().map(|x| x * 2).collect();
println!("加倍: {:?}", doubled);

// filter:过滤
let even: Vec<&i32> = v.iter().filter(|x| *x % 2 == 0).collect();
println!("偶数: {:?}", even);

// filter_map:过滤并映射
let processed: Vec<i32> = v.iter()
.filter_map(|x| if *x > 2 { Some(x * 2) } else { None })
.collect();
println!("处理后: {:?}", processed);

// take:取前 n 个
let first_three: Vec<&i32> = v.iter().take(3).collect();
println!("前三个: {:?}", first_three);

// skip:跳过前 n 个
let skip_two: Vec<&i32> = v.iter().skip(2).collect();
println!("跳过两个: {:?}", skip_two);

// take_while:条件取值
let less_than_four: Vec<&i32> = v.iter()
.take_while(|x| **x < 4)
.collect();
println!("小于4: {:?}", less_than_four);

// skip_while:条件跳过
let after_one: Vec<&i32> = v.iter()
.skip_while(|x| **x < 3)
.collect();
println!("跳过小于3: {:?}", after_one);

// zip:配对
let names = vec!["Alice", "Bob", "Charlie"];
let ages = vec![25, 30, 35];
let pairs: Vec<_> = names.iter().zip(ages.iter()).collect();
println!("配对: {:?}", pairs);

// chain:链接
let v1 = vec![1, 2];
let v2 = vec![3, 4];
let chained: Vec<&i32> = v1.iter().chain(v2.iter()).collect();
println!("链接: {:?}", chained);

// enumerate:枚举
for (i, v) in v.iter().enumerate() {
println!("索引 {}: 值 {}", i, v);
}

// rev:反转
let reversed: Vec<&i32> = v.iter().rev().collect();
println!("反转: {:?}", reversed);

// inspect:检查(用于调试)
let result: Vec<i32> = v.iter()
.inspect(|x| println!("处理: {}", x))
.map(|x| x * 2)
.inspect(|x| println!("结果: {}", x))
.collect();

// flatten:展平
let nested = vec![vec![1, 2], vec![3, 4], vec![5]];
let flat: Vec<&i32> = nested.iter().flatten().collect();
println!("展平: {:?}", flat);

// flat_map:映射后展平
let words = vec!["hello", "world"];
let chars: Vec<char> = words.iter()
.flat_map(|s| s.chars())
.collect();
println!("字符: {:?}", chars);
}

// sorted:排序(需要引入 itertools crate 或自己实现)
fn sorted_demo() {
let mut v = vec![3, 1, 4, 1, 5, 9];
v.sort();
println!("排序: {:?}", v);
}

链式操作

迭代器的强大之处在于链式调用:

fn main() {
let numbers = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];

// 复杂链式操作
let result: i32 = numbers.iter()
.filter(|x| *x % 2 == 0) // 过滤偶数
.map(|x| x * x) // 平方
.take(3) // 取前3个
.sum(); // 求和

println!("结果: {}", result); // 4 + 16 + 36 = 56

// 另一个示例
let text = "hello world rust programming";
let words_with_e: Vec<&str> = text.split_whitespace()
.filter(|word| word.contains('e'))
.map(|word| word.to_uppercase())
.map(|word| format!("[{}]", word))
.collect();

println!("包含e的词: {:?}", words_with_e);
}

创建自定义迭代器

// 自定义迭代器:计数器
struct Counter {
count: u32,
max: u32,
}

impl Counter {
fn new(max: u32) -> Counter {
Counter { count: 0, max }
}
}

impl Iterator for Counter {
type Item = u32;

fn next(&mut self) -> Option<Self::Item> {
if self.count < self.max {
self.count += 1;
Some(self.count)
} else {
None
}
}
}

fn main() {
let counter = Counter::new(5);

// 使用自定义迭代器
let sum: u32 = counter.sum();
println!("求和: {}", sum);

// 链式操作
let counter2 = Counter::new(10);
let result: Vec<u32> = counter2
.filter(|x| x % 2 == 0)
.map(|x| x * 2)
.collect();
println!("处理结果: {:?}", result);
}

惰性求值

迭代器是惰性的,只有在消费时才会执行:

fn main() {
let v = vec![1, 2, 3, 4, 5];

// 这行不会执行任何操作
let iter = v.iter()
.inspect(|x| println!("过滤前: {}", x))
.filter(|x| *x % 2 == 0)
.inspect(|x| println!("过滤后: {}", x));

println!("迭代器创建完毕,但还未执行");

// 消费时才执行
let result: Vec<&i32> = iter.collect();
println!("结果: {:?}", result);
}

性能

迭代器的性能与手写循环相当:

fn main() {
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];

// 手写循环
let mut sum1 = 0;
for x in &data {
if x % 2 == 0 {
sum1 += x * x;
}
}

// 迭代器
let sum2: i32 = data.iter()
.filter(|x| *x % 2 == 0)
.map(|x| x * x)
.sum();

assert_eq!(sum1, sum2);
println!("两种方式结果相同: {}", sum1);
}

零成本抽象:迭代器在编译时会被优化,不会产生运行时开销。

实际应用示例

文本处理

fn main() {
let text = "The quick brown fox jumps over the lazy dog";

// 统计单词
let word_count = text.split_whitespace().count();
println!("单词数: {}", word_count);

// 找出最长的单词
let longest = text.split_whitespace()
.max_by_key(|word| word.len());
println!("最长单词: {:?}", longest);

// 转换为首字母大写
let capitalized: Vec<String> = text.split_whitespace()
.map(|word| {
let mut chars = word.chars();
match chars.next() {
Some(first) => first.to_uppercase().chain(chars).collect(),
None => String::new(),
}
})
.collect();
println!("首字母大写: {}", capitalized.join(" "));
}

数据处理管道

#[derive(Debug)]
struct Student {
name: String,
score: u32,
}

fn main() {
let students = vec![
Student { name: String::from("Alice"), score: 85 },
Student { name: String::from("Bob"), score: 92 },
Student { name: String::from("Charlie"), score: 78 },
Student { name: String::from("David"), score: 95 },
Student { name: String::from("Eve"), score: 88 },
];

// 找出分数最高的学生
let top_student = students.iter()
.max_by_key(|s| s.score);
println!("最高分: {:?}", top_student);

// 计算平均分
let avg_score: f64 = students.iter()
.map(|s| s.score)
.sum::<u32>() as f64 / students.len() as f64;
println!("平均分: {:.2}", avg_score);

// 筛选优秀学生(分数 >= 90)
let excellent: Vec<&Student> = students.iter()
.filter(|s| s.score >= 90)
.collect();
println!("优秀学生: {:?}", excellent);

// 按分数分组
let (passed, failed): (Vec<_>, Vec<_>) = students.iter()
.partition(|s| s.score >= 60);
println!("通过: {:?}, 不通过: {:?}",
passed.iter().map(|s| &s.name).collect::<Vec<_>>(),
failed.iter().map(|s| &s.name).collect::<Vec<_>>());
}

小结

本章我们学习了:

  1. 闭包:语法、捕获环境变量、三种 Fn trait
  2. 迭代器:基本使用、消费适配器、迭代适配器
  3. 链式操作:组合多个迭代器方法
  4. 自定义迭代器:实现 Iterator trait
  5. 性能:零成本抽象

练习

  1. 使用闭包和迭代器实现一个简单的计算器
  2. 编写一个函数,使用迭代器找出字符串中所有元音字母
  3. 实现一个自定义迭代器,生成斐波那契数列
  4. 使用 fold 实现一个简单的词频统计程序