ForkJoinPool
什么是Fork/Join
Fork/Join 是一个并行计算的框架,主要就是用来支持分治任务模型的,这个计算框架里的 Fork 对应的是分治任务模型里的任务分解,Join 对应的是结果合并。
它的核心思想是将一个大任务分成许多小任务,然后并行执行这些小任务,最终将它们的结果合并成一个大的结果。
一句话总价:多人合作完成一件事。
深入了解 Fork/Join
在Java 中,Fork/Join 框架的主要组成部分是 ForkJoinPool 和 ForkJoinTask。
ForkJoinPool 是一个线程池,它用于管理 ForkJoin 任务的执行。
ForkJoinTask 是一个抽象类,用于表示可以被分割成更小部分的任务。
ForkJoinPool
ForkJoinPool 是 ExecutorService 的实现类,因此是一种特殊的线程池。
我们在使用 Executors 通过调用 newWorkStealingPool 方法即可创建一个线程池为 ForkJoinPool 的 ExecutorService 对象。
ForkJoinPool 核心参数
parallelism
:并行度级别。默认为 CPU 数,最小为 1。factory
:工作线程工厂。handler
:指定异常处理器,当任务在运行时出错时,将有设定的处理器处理。asyncMode
:设置工作队列工作模式。true 表示先进先出,false 表示后进先出。默认:false。
任务提交方式
submit 方法描述:任务提交后,会返回一个 ForkJoinTask。 提交参数:Callable、ForkJoinTask、Runnable
execute 方法描述:任务提交后,不会有返回值。 提交参数:ForkJoinTask、Runnable
invoke 方法描述:任务提交后,会返回泛型结果。 提交参数:ForkJoinTask、Callable集合
TIP
这些提交方式中的参数有 3 种,但是他们最后都会变成 ForkJoinTask 的任务类型。
ForkJoinTask
ForkJoinTask 是 ForkJoinPool 的核心,它是任务的实际载体,定义了任务执行时的具体逻辑和拆分逻辑。
ForkJoinTask 是一个抽象类,它还有两个抽象子类:RecusiveAction 和 RecusiveTask。其中 RecusiveTask 代表有返回值的任务,而 RecusiveAction 代表没有返回值的任务。
核心方法
fork() 用于向当前任务所运行的线程池中提交任务。如果当前线程是 ForkJoinWorkerThread 类型,将会放入该线程的工作队列,否则放入 common 线程池的工作队列中。
join() 用于获取任务的执行结果。调用 join() 时,将阻塞当前线程直到对应的子任务完成运行并返回结果。
工作窃取算法
为了平衡各个工作线程之间的工作负载,ForkJoinPool 采用了工作窃取算法。
每个工作线程都有自己的任务队列,当某个线程完成了自己队列中的所有任务时,它会尝试从其他线程的队列中窃取任务来执行。
工作窃取算法的实现基于双端队列(Deque)。每个工作线程都有一个双端队列来存储待处理的任务。当线程需要执行新任务时,它会将任务放入队列的头部(top),并以 LIFO(后进先出)的顺序处理队列中的任务。这样,最近添加的任务会优先被执行。
同时,当某个线程尝试窃取其他线程的任务时,它会从目标线程的队列的尾部(base)窃取任务。这种窃取方式是 FIFO(先进先出)的,也就是说被窃取的任务是队列中等待时间最长的任务。这种机制有助于减少线程间的竞争,提高CPU的利用率。
任务的拆分与合并
在 ForkJoinPool 中,任务的拆分和合并是通过继承 RecursiveAction 或 RecursiveTask 自己去实现 compute 方法去解决的。
也就是说,我们需要在 compute 中解决如下几个问题:
- 如何拆
- 如何装
在这个过程中,其实是处于递归的,也就是说每个小任务还可以继续被拆分成更小的任务并行执行。这种递归拆分和合并的方式使得 ForkJoinPool 能够处理非常复杂和庞大的任务。
注意
递归请注意深度,小心 OOM!
实战
我们就计算从 1 加到 10,000,000。
创建任务
首先我们创建一个类,用于实现 RecusiveTask,泛型我们使用 Long 即可。
class Task extends RecursiveTask<Long> {
private static final long serialVersionUID = 1L;
@Override
protected Long compute() {
return null;
}
}
设定参数
我们需要知道几个点。
- 什么时候进行拆分,拆分肯定得有一个阈值。
- 从 1 加到 10,000,000,这个值怎么来?
接下来,我们去改造这个 Task 类。
class Task extends RecursiveTask<Long> {
private Integer start = null;
private Integer end = null;
private Integer DISASSEMBLE = 1000000;
public Task(Integer start, Integer end) {
this.start = start;
this.end = end;
}
private static final long serialVersionUID = 1L;
@Override
protected Long compute() {
Long result = 0l;
if((end - start) <= DISASSEMBLE) {
for(int i = start; i <= end; i++) {
result += i;
}
} else {
}
System.out.println(start + "+" + end + "=" + result);
return result;
}
}
解释:
- 我们设定了阈值
DISASSEMBLE
,当超过该阈值时,走compute()
方法里的 else,否则就直接走 if。判断依据就是根据我们 end - start 的差值进行判断。 - 加到多少,我们可以通过构造方法传递进来。如果值太大,我们就走 else 进行拆分任务即可。
任务拆分 & 合并
接下来,我们就需要向 else
里编写任务的拆分与合并了。
class Task extends RecursiveTask<Long> {
private Integer start = null;
private Integer end = null;
private Integer DISASSEMBLE = 1000000;
public Task(Integer start, Integer end) {
this.start = start;
this.end = end;
}
private static final long serialVersionUID = 1L;
@Override
protected Long compute() {
Long result = 0l;
if((end - start) <= DISASSEMBLE) {
for(int i = start; i <= end; i++) {
result += i;
}
} else {
int middle = (start + end) / 2;
Task leftTask = new Task(start, middle);
Task rightTask = new Task(middle + 1, end);
leftTask.fork();
rightTask.fork();
result = leftTask.join() + rightTask.join();
}
System.out.println(start + "+" + end + "=" + result);
return result;
}
}
解释:
- 拆分,我们就利用二分法进行拆分就好,从开始到中间,中间 + 1 至 结束。拆分好任务后,调用
fork()
方法提交任务。 - 接着我们调用
join()
方法等待返回即可。
测试
接着我们弄个 main 方法测试一下。
public static void main(String[] args) throws Exception {
ForkJoinPool forkjoinPool = new ForkJoinPool();
Task task = new Task(0, 10000000);
Future<Long> result = forkjoinPool.submit(task);
try {
System.out.println("result:" + result.get());
} catch (Exception e) {
e.printStackTrace();
}
}
运行结果:
0+625000=195312812500
625001+1250000=585937812500
0+1250000=781250625000
1250001+1875000=976562812500
1875001+2500000=1367187812500
1250001+2500000=2343750625000
0+2500000=3125001250000
2500001+3125000=1757812812500
3125001+3750000=2148437812500
2500001+3750000=3906250625000
3750001+4375000=2539062812500
4375001+5000000=2929687812500
3750001+5000000=5468750625000
2500001+5000000=9375001250000
0+5000000=12500002500000
5000001+5625000=3320312812500
8125001+8750000=5273437812500
6875001+7500000=4492187812500
7500001+8125000=4882812812500
7500001+8750000=10156250625000
6250001+6875000=4101562812500
5625001+6250000=3710937812500
9375001+10000000=6054687812500
6250001+7500000=8593750625000
8750001+9375000=5664062812500
5000001+6250000=7031250625000
8750001+10000000=11718750625000
5000001+7500000=15625001250000
7500001+10000000=21875001250000
5000001+10000000=37500002500000
0+10000000=50000005000000
result:50000005000000
ForkJoinPool 的优势
高效: ForkJoinPool 通过工作窃取算法和并行处理机制,能够充分利用多核处理器的性能,提高程序的并发处理能力。
简化: 使用 ForkJoinPool 可以简化并发编程的复杂性,开发者只需要关注任务的拆分和合并逻辑,而无需关心线程的创建、管理和调度等细节。
分治: ForkJoinPool 特别适合于处理可以被拆分成较小子任务的大任务,如递归算法、排序算法、图算法等。
注意
千万小心递归算法,尽量多测试算法,保证不会出现深度递归的情况。