跳到主要内容

Fork/Join 框架

Fork/Join 框架是 Java 7 引入的并行计算框架,专为可以递归分解的任务设计。它采用"分而治之"的策略,将大任务拆分成小任务并行执行,最后合并结果。这是 Java 中实现并行计算的核心工具之一。

什么是 Fork/Join?

Fork/Join 的核心思想非常简单:

  • Fork(分叉):将一个大任务拆分成多个小任务
  • Join(合并):等待所有小任务完成,合并它们的结果

这种模式特别适合可以递归分解的问题,比如:

  • 大数组的排序、搜索
  • 矩阵运算
  • 图像处理
  • 大规模数据聚合

与传统线程池的区别

传统线程池(ThreadPoolExecutor)和 Fork/Join 框架都用于并行任务执行,但它们的设计目标不同:

特性ThreadPoolExecutorForkJoinPool
任务类型独立任务可递归分解的任务
任务依赖任务之间无依赖子任务依赖父任务
线程调度任务由队列分配工作窃取算法
适用场景IO 密集型、独立任务CPU 密集型、可分解任务

传统线程池中,如果一个任务创建了子任务并等待其完成,可能导致线程饥饿死锁。Fork/Join 通过工作窃取算法解决了这个问题。

工作窃取算法

工作窃取(Work-Stealing)是 Fork/Join 框架的核心算法。

传统线程池的问题

在传统线程池中,每个线程从一个共享队列中获取任务。如果一个任务创建了子任务并等待其完成,该线程会被阻塞,无法执行其他任务。如果所有线程都在等待子任务完成,就没有线程来执行这些子任务,导致死锁。

工作窃取的解决方案

Fork/Join 框架为每个线程维护一个独立的双端队列:

  • 每个线程从自己队列的尾部获取任务(LIFO,后进先出)
  • 当线程的队列为空时,从其他线程队列的头部窃取任务(FIFO,先进先出)

这种设计的优点:

  1. 避免竞争:大多数情况下线程只操作自己的队列
  2. 充分利用资源:空闲线程不会闲置,而是帮助忙碌的线程
  3. 避免死锁:即使线程在等待子任务,其他线程也能执行这些子任务

核心类

Fork/Join 框架包含以下核心类:

ForkJoinPool

ForkJoinPool 是执行 Fork/Join 任务的线程池。它实现了 ExecutorService 接口,专门用于执行 ForkJoinTask。

// 创建线程池
ForkJoinPool pool = new ForkJoinPool(); // 使用当前处理器核心数
ForkJoinPool pool = new ForkJoinPool(4); // 指定并行度

// 使用公共池
ForkJoinPool commonPool = ForkJoinPool.commonPool();

公共池:从 Java 8 开始,可以使用 ForkJoinPool.commonPool() 获取一个公共的 ForkJoinPool。这个池在 parallelStream() 等并行操作中自动使用。

ForkJoinTask

ForkJoinTask 是 Fork/Join 框架中任务的抽象基类。它比标准线程更轻量,可以在 ForkJoinPool 中高效执行。

常用方法:

fork();      // 异步执行任务
join(); // 等待任务完成并返回结果
invoke(); // 执行任务并等待结果(相当于 fork + join)

通常不直接继承 ForkJoinTask,而是使用它的两个子类:

RecursiveTask

RecursiveTask 是有返回值的递归任务:

public abstract class RecursiveTask<V> extends ForkJoinTask<V> {
protected abstract V compute();
}

RecursiveAction

RecursiveAction 是无返回值的递归任务:

public abstract class RecursiveAction extends ForkJoinTask<Void> {
protected abstract void compute();
}

快速开始

计算 1 到 n 的和

这是一个经典的 Fork/Join 示例,展示如何将大任务分解成小任务:

import java.util.concurrent.*;

public class SumTask extends RecursiveTask<Long> {
private static final int THRESHOLD = 10000; // 阈值

private final long[] array;
private final int start;
private final int end;

public SumTask(long[] array, int start, int end) {
this.array = array;
this.start = start;
this.end = end;
}

@Override
protected Long compute() {
int length = end - start;

// 如果任务足够小,直接计算
if (length <= THRESHOLD) {
return computeDirectly();
}

// 否则拆分成两个子任务
int mid = start + length / 2;
SumTask left = new SumTask(array, start, mid);
SumTask right = new SumTask(array, mid, end);

// 执行子任务
left.fork();
long rightResult = right.compute(); // 直接计算右半部分(复用当前线程)
long leftResult = left.join();

return leftResult + rightResult;
}

private long computeDirectly() {
long sum = 0;
for (int i = start; i < end; i++) {
sum += array[i];
}
return sum;
}

public static void main(String[] args) {
// 准备数据
long[] array = new long[1_000_000];
for (int i = 0; i < array.length; i++) {
array[i] = i + 1;
}

// 创建任务
SumTask task = new SumTask(array, 0, array.length);

// 创建线程池并执行
ForkJoinPool pool = new ForkJoinPool();
long result = pool.invoke(task);

System.out.println("结果: " + result); // 500000500000

pool.shutdown();
}
}

关键点

  1. 定义阈值(THRESHOLD),小于阈值时直接计算
  2. 大于阈值时拆分成两个子任务
  3. 使用 fork() 异步执行左子任务,直接计算右子任务(复用当前线程)
  4. 使用 join() 等待左子任务结果并合并

invokeAll 方法

也可以使用 invokeAll 方法一次性提交多个任务:

@Override
protected Long compute() {
int length = end - start;

if (length <= THRESHOLD) {
return computeDirectly();
}

int mid = start + length / 2;
SumTask left = new SumTask(array, start, mid);
SumTask right = new SumTask(array, mid, end);

// 使用 invokeAll 同时执行两个任务
invokeAll(left, right);

return left.join() + right.join();
}

invokeAll 会自动处理任务的执行和等待,代码更简洁。

实战示例

并行数组排序

import java.util.concurrent.*;

public class ParallelMergeSort {
private static final int THRESHOLD = 1000;

public static void sort(int[] array) {
ForkJoinPool pool = new ForkJoinPool();
pool.invoke(new SortTask(array, 0, array.length));
pool.shutdown();
}

private static class SortTask extends RecursiveAction {
private final int[] array;
private final int start;
private final int end;

SortTask(int[] array, int start, int end) {
this.array = array;
this.start = start;
this.end = end;
}

@Override
protected void compute() {
if (end - start <= THRESHOLD) {
// 小数组直接排序
sequentialSort(array, start, end);
return;
}

int mid = start + (end - start) / 2;

// 并行排序两个子数组
invokeAll(
new SortTask(array, start, mid),
new SortTask(array, mid, end)
);

// 合并两个已排序的子数组
merge(array, start, mid, end);
}
}

private static void sequentialSort(int[] array, int start, int end) {
for (int i = start + 1; i < end; i++) {
int key = array[i];
int j = i - 1;
while (j >= start && array[j] > key) {
array[j + 1] = array[j];
j--;
}
array[j + 1] = key;
}
}

private static void merge(int[] array, int start, int mid, int end) {
int[] temp = new int[end - start];
int i = start, j = mid, k = 0;

while (i < mid && j < end) {
if (array[i] <= array[j]) {
temp[k++] = array[i++];
} else {
temp[k++] = array[j++];
}
}

while (i < mid) temp[k++] = array[i++];
while (j < end) temp[k++] = array[j++];

System.arraycopy(temp, 0, array, start, temp.length);
}

public static void main(String[] args) {
int[] array = new int[100000];
for (int i = 0; i < array.length; i++) {
array[i] = (int) (Math.random() * 1000000);
}

long start = System.currentTimeMillis();
sort(array);
long end = System.currentTimeMillis();

System.out.println("排序耗时: " + (end - start) + "ms");
System.out.println("是否有序: " + isSorted(array));
}

private static boolean isSorted(int[] array) {
for (int i = 1; i < array.length; i++) {
if (array[i] < array[i - 1]) return false;
}
return true;
}
}

并行搜索

import java.util.concurrent.*;

public class ParallelSearch {
private static final int THRESHOLD = 10000;

public static int search(int[] array, int target) {
ForkJoinPool pool = new ForkJoinPool();
int result = pool.invoke(new SearchTask(array, target, 0, array.length));
pool.shutdown();
return result;
}

private static class SearchTask extends RecursiveTask<Integer> {
private final int[] array;
private final int target;
private final int start;
private final int end;

SearchTask(int[] array, int target, int start, int end) {
this.array = array;
this.target = target;
this.start = start;
this.end = end;
}

@Override
protected Integer compute() {
if (end - start <= THRESHOLD) {
return searchDirectly();
}

int mid = start + (end - start) / 2;
SearchTask left = new SearchTask(array, target, start, mid);
SearchTask right = new SearchTask(array, target, mid, end);

left.fork();
int rightResult = right.compute();
int leftResult = left.join();

// 返回最先找到的位置
if (leftResult != -1 && rightResult != -1) {
return Math.min(leftResult, rightResult);
}
return leftResult != -1 ? leftResult : rightResult;
}

private int searchDirectly() {
for (int i = start; i < end; i++) {
if (array[i] == target) {
return i;
}
}
return -1;
}
}

public static void main(String[] args) {
int[] array = new int[1_000_000];
for (int i = 0; i < array.length; i++) {
array[i] = i * 2;
}

int target = 999998;
int index = search(array, target);

System.out.println("目标值 " + target + " 的索引: " + index);
}
}

斐波那契数列

import java.util.concurrent.*;
import java.util.HashMap;
import java.util.Map;

public class FibonacciTask extends RecursiveTask<Long> {
private static final int THRESHOLD = 20;
private final int n;
private static final Map<Integer, Long> cache = new ConcurrentHashMap<>();

public FibonacciTask(int n) {
this.n = n;
}

@Override
protected Long compute() {
// 使用缓存避免重复计算
if (cache.containsKey(n)) {
return cache.get(n);
}

if (n <= THRESHOLD) {
return computeSequentially(n);
}

FibonacciTask f1 = new FibonacciTask(n - 1);
FibonacciTask f2 = new FibonacciTask(n - 2);

f1.fork();
long result = f2.compute() + f1.join();

cache.put(n, result);
return result;
}

private static long computeSequentially(int n) {
if (n <= 1) return n;
long a = 0, b = 1;
for (int i = 2; i <= n; i++) {
long temp = a + b;
a = b;
b = temp;
}
return b;
}

public static void main(String[] args) {
ForkJoinPool pool = new ForkJoinPool();

long start = System.currentTimeMillis();
long result = pool.invoke(new FibonacciTask(50));
long end = System.currentTimeMillis();

System.out.println("Fibonacci(50) = " + result);
System.out.println("耗时: " + (end - start) + "ms");

pool.shutdown();
}
}

性能优化

选择合适的阈值

阈值是影响 Fork/Join 性能的关键参数。太小会导致过多的任务拆分和调度开销,太大会降低并行度。

一般建议:

  • 阈值应保证每个任务有足够的工作量
  • 通过基准测试找到最优阈值
  • 典型值在 1000-10000 之间

避免任务拆分过细

// 不推荐:每次只拆分成两个任务
@Override
protected Long compute() {
if (length <= 1) { // 阈值太小
return array[start];
}
// ...
}

// 推荐:设置合理的阈值
@Override
protected Long compute() {
if (length <= THRESHOLD) { // THRESHOLD = 10000
return computeDirectly();
}
// ...
}

使用公共池

Java 8 引入了公共 ForkJoinPool,避免创建新线程池的开销:

// 使用公共池
ForkJoinPool pool = ForkJoinPool.commonPool();
long result = pool.invoke(task);

// 或者直接提交给公共池
long result = task.invoke(); // 内部使用公共池

避免阻塞操作

Fork/Join 框架设计用于 CPU 密集型任务,不要在任务中执行阻塞操作:

// 不推荐:任务中执行 IO 操作
@Override
protected Long compute() {
// 不要这样!
String data = readFromFile(); // 阻塞 IO
return process(data);
}

// 推荐:IO 操作放在任务之外
String data = readFromFile();
long result = pool.invoke(new ProcessTask(data));

线程数配置

// 获取处理器核心数
int processors = Runtime.getRuntime().availableProcessors();

// ForkJoinPool 默认使用 processors - 1 个工作线程
// 可以手动指定
ForkJoinPool pool = new ForkJoinPool(processors);

// 查看当前并行度
System.out.println("并行度: " + pool.getParallelism());

常见陷阱

死锁问题

虽然工作窃取算法避免了大多数死锁,但如果任务之间有循环依赖,仍可能死锁:

// 危险!可能导致死锁
class BadTask extends RecursiveTask<Integer> {
@Override
protected Integer compute() {
BadTask task = new BadTask();
task.fork();
task.join(); // 等待子任务,但子任务又在等待...
return 0;
}
}

正确做法是确保任务可以独立完成或只依赖已完成的子任务。

内存消耗

递归创建大量任务可能导致栈溢出或内存不足:

// 危险!可能创建过多任务
class DeepRecursion extends RecursiveTask<Integer> {
@Override
protected Integer compute() {
if (baseCase) return result;

// 递归深度太大
DeepRecursion d = new DeepRecursion();
d.fork();
return d.join() + 1;
}
}

解决方法是确保阈值足够大,限制递归深度。

异常处理

Fork/Join 任务中的异常会被封装在 ExecutionException 中:

try {
long result = pool.invoke(task);
} catch (ExecutionException e) {
Throwable cause = e.getCause(); // 获取原始异常
cause.printStackTrace();
}

任务内部也可以手动处理异常:

@Override
protected Long compute() {
try {
return computeInternal();
} catch (Exception e) {
completeExceptionally(e); // 标记任务异常完成
return null;
}
}

与并行流的关系

Java 8 的并行流(parallelStream)底层使用 Fork/Join 框架:

// 并行流内部使用 ForkJoinPool.commonPool()
long sum = Arrays.stream(array)
.parallel()
.sum();

// 这等价于
long sum = LongStream.of(array)
.parallel()
.reduce(0, Long::sum);

自定义并行流使用的线程池:

ForkJoinPool customPool = new ForkJoinPool(4);
long sum = customPool.submit(() ->
Arrays.stream(array)
.parallel()
.sum()
).get();

小结

Fork/Join 框架是 Java 并行计算的重要工具:

  1. 核心概念:Fork(分叉)拆分任务,Join(合并)等待结果
  2. 工作窃取:空闲线程从忙碌线程窃取任务,提高资源利用率
  3. 核心类:ForkJoinPool、RecursiveTask、RecursiveAction
  4. 适用场景:CPU 密集型、可递归分解的任务
  5. 性能优化:选择合适的阈值、避免阻塞操作、使用公共池

在实际开发中,对于简单的并行任务,可以优先考虑并行流;对于复杂的递归分解任务,Fork/Join 框架提供了更精细的控制。