ForkJoinPool was introduced in Java 7 to handle a very particular set of problems that are hard to solve efficiently with any other thread pool implementation.
The class is designed to work with divide-and-conquer algorithms: those in which a task can be recursively broken into smaller sub-tasks.
On the surface it looks just like any other thread pool, e.g., ThreadPoolExecutor: it implements the Executor and ExecutorService interfaces. Tasks are kept in unbounded queues and run by a configurable number of worker threads; by default, the number of workers (the parallelism level) equals the number of available CPUs.
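A minimal sketch of the points above, assuming Java 8 or later for the lambda: the no-arg constructor sizes the pool to the number of available processors, the int constructor sets the parallelism explicitly, and the pool accepts ordinary tasks through the plain ExecutorService interface. The class name PoolBasics and the value 4 are illustrative.

```java
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;

public class PoolBasics {

    // Returns {default parallelism, explicit parallelism, result of a submitted task}
    static int[] inspectPools() {
        // No-arg constructor: parallelism equals the number of available processors
        ForkJoinPool defaultPool = new ForkJoinPool();
        // Explicit parallelism of 4 (an arbitrary example value)
        ForkJoinPool fourWorkers = new ForkJoinPool(4);
        try {
            // Used through the ExecutorService interface, like any other pool
            ExecutorService executor = fourWorkers;
            Future<Integer> answer = executor.submit(() -> 6 * 7);
            return new int[] {
                defaultPool.getParallelism(),
                fourWorkers.getParallelism(),
                answer.get()
            };
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            defaultPool.shutdown();
            fourWorkers.shutdown();
        }
    }

    public static void main(String[] args) {
        int[] r = inspectPools();
        System.out.println("default parallelism:   " + r[0]);
        System.out.println("explicit parallelism:  " + r[1]);
        System.out.println("submitted task result: " + r[2]);
    }
}
```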
Example: Parallel Merge Sort
Consider sorting an array of 1 million elements. The work breaks down into 3 main sub-tasks:
- Sort the first half of the array
- Sort the second half of the array
- Merge the two sorted sub-arrays
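The three sub-tasks map directly onto a plain sequential recursion. The sketch below keeps everything on one thread so the shape of the algorithm is visible before any parallelism is added; the class and method names are illustrative, and the insertion-sort cutoff of 10 elements is an assumed value.

```java
import java.util.Arrays;

public class SequentialMergeSort {

    // Assumed cutoff: ranges this small are sorted directly with insertion sort
    static final int THRESHOLD = 10;

    // Sorts a[first..last] (inclusive bounds)
    static void sort(int[] a, int first, int last) {
        if (last - first + 1 <= THRESHOLD) {
            insertionSort(a, first, last);   // base case
            return;
        }
        int mid = (first + last) >>> 1;
        sort(a, first, mid);                 // 1. sort the first half
        sort(a, mid + 1, last);              // 2. sort the second half
        merge(a, first, mid, last);          // 3. merge the two sorted halves
    }

    static void insertionSort(int[] a, int first, int last) {
        for (int i = first + 1; i <= last; i++) {
            int key = a[i];
            int j = i - 1;
            // Shift larger elements one slot to the right
            while (j >= first && a[j] > key) {
                a[j + 1] = a[j];
                j--;
            }
            a[j + 1] = key;
        }
    }

    // Merges the sorted runs a[first..mid] and a[mid+1..last]
    static void merge(int[] a, int first, int mid, int last) {
        int[] tmp = new int[last - first + 1];
        int left = first, right = mid + 1, i = 0;
        while (left <= mid && right <= last) {
            tmp[i++] = (a[left] <= a[right]) ? a[left++] : a[right++];
        }
        while (left <= mid) tmp[i++] = a[left++];
        while (right <= last) tmp[i++] = a[right++];
        System.arraycopy(tmp, 0, a, first, tmp.length);
    }

    public static void main(String[] args) {
        int[] data = {38, 27, 43, 3, 9, 82, 10, 5, 71, 64, 19, 2, 55, 31};
        sort(data, 0, data.length - 1);
        System.out.println(Arrays.toString(data));
    }
}
```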
The base case is reached when a sub-array is small enough (let's assume 10 elements) that sorting it directly with insertion sort is faster than splitting it further with parallel merge sort. In the end there will be on the order of 100,000 tasks to sort the leaf arrays (1 million elements in pieces of about 10), roughly half as many tasks to merge those sorted sub-arrays, half as many again to merge the next level up, and so on.
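To get a feel for those task counts, the hypothetical helper below walks the same split-at-the-midpoint recursion without sorting anything, assuming that ranges of at most 10 elements become leaves. For 1,000,000 elements every leaf ends up with 7 or 8 elements, which gives 2^17 = 131,072 leaf sort tasks and 131,071 merge tasks; in any such binary split there is exactly one merge fewer than there are leaves.

```java
public class TaskCount {

    // Assumed cutoff: ranges of at most 10 elements are sorted directly (leaf tasks)
    static final int THRESHOLD = 10;

    // Returns {leaf sort tasks, merge tasks} for an array of n elements
    static long[] countTasks(int n) {
        long[] counts = new long[2];
        count(0, n - 1, counts);
        return counts;
    }

    // Walks the same recursion as merge sort, counting tasks instead of sorting
    private static void count(int first, int last, long[] counts) {
        if (last - first + 1 <= THRESHOLD) {
            counts[0]++;              // one leaf sort task
            return;
        }
        int mid = (first + last) >>> 1;
        count(first, mid, counts);
        count(mid + 1, last, counts);
        counts[1]++;                  // one merge task for this range
    }

    public static void main(String[] args) {
        long[] c = countTasks(1_000_000);
        // prints "leaf sort tasks: 131072, merge tasks: 131071"
        System.out.println("leaf sort tasks: " + c[0] + ", merge tasks: " + c[1]);
    }
}
```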
The most important point to notice is that none of the tasks can complete until the tasks they have spawned have also completed. This is where ForkJoinPool comes in very handy. Of course it's doable with a ThreadPoolExecutor, but not as efficiently as with ForkJoinPool, and the implementation is much more complex.
With a ThreadPoolExecutor, a parent task that must wait for its child tasks to complete ties up its worker thread: the thread cannot add a sub-task to the queue and then wait for it, because once a thread is waiting it can't be used to run one of the sub-tasks. If every worker ends up waiting on children that are still queued, the pool deadlocks (thread starvation).
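The starvation problem is easy to reproduce with a ThreadPoolExecutor that has a single worker thread (created here via Executors.newFixedThreadPool(1)). The parent task submits a child to the same pool and waits for it; since the parent occupies the only thread, the child never starts. A timeout is used so the sketch terminates instead of hanging; the class name and the 500 ms value are illustrative.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

public class StarvationDemo {

    // Returns true if the parent task gave up waiting for its child
    static boolean parentTimesOut() {
        // One worker thread, so the parent occupies the only thread in the pool
        ExecutorService pool = Executors.newFixedThreadPool(1);
        try {
            Future<Boolean> parent = pool.submit(() -> {
                // The child is queued, but no free thread is left to run it
                Future<Integer> child = pool.submit(() -> 42);
                try {
                    child.get(500, TimeUnit.MILLISECONDS); // blocks the only worker
                    return false;
                } catch (TimeoutException e) {
                    return true;
                }
            });
            return parent.get();
        } catch (Exception e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdownNow(); // cancels the child that never ran
        }
    }

    public static void main(String[] args) {
        System.out.println("parent timed out waiting for child: " + parentTimesOut());
    }
}
```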
Bla bla bla... show me the code.
```java
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;

import java.util.Arrays;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.IntStream;

import lombok.AllArgsConstructor;

public class ForkJoinPoolSample {

    private static int[] elements;

    /** Sorts elements[first..last] (inclusive) and returns the number of elements sorted. */
    @AllArgsConstructor
    private static class MergeSortTask extends RecursiveTask<Integer> {
        private int first;
        private int last;

        @Override
        protected Integer compute() {
            int len = 0;
            if (last - first <= 10) {
                // Base case: at most 11 elements, sort them directly
                // (for ranges this small, Arrays.sort typically falls back to insertion sort)
                Arrays.sort(elements, first, last + 1);
                len = last - first + 1;
            } else {
                /* Sort two sub-arrays */
                int mid = (first + last) >>> 1;
                MergeSortTask leftSubtask = new MergeSortTask(first, mid);
                MergeSortTask rightSubtask = new MergeSortTask(mid + 1, last);
                leftSubtask.fork();             // schedule the left half asynchronously
                len += rightSubtask.compute();  // sort the right half in the current thread
                len += leftSubtask.join();      // wait for the left half to finish

                /* Merge two sorted sub-arrays (run directly in the current thread) */
                new MergeTask(first, mid, last).invoke();
            }
            return len;
        }
    }

    /** Merges the sorted ranges elements[first..mid] and elements[mid+1..last]. */
    @AllArgsConstructor
    private static class MergeTask extends RecursiveTask<Integer> {
        private int first;
        private int mid;
        private int last;

        @Override
        protected Integer compute() {
            int[] tmp = new int[last - first + 1];
            int left = first, right = mid + 1, indx = 0;
            while (left <= mid && right <= last) {
                if (elements[left] <= elements[right]) {
                    tmp[indx++] = elements[left++];
                } else {
                    tmp[indx++] = elements[right++];
                }
            }
            // Copy whichever half still has elements left
            while (left <= mid) {
                tmp[indx++] = elements[left++];
            }
            while (right <= last) {
                tmp[indx++] = elements[right++];
            }
            // Write the merged run back into place
            for (indx = 0; indx < tmp.length; indx++) {
                elements[first + indx] = tmp[indx];
            }
            return tmp.length;
        }
    }

    private static void createRandomInts() {
        elements = new int[100_000];
        final ThreadLocalRandom random = ThreadLocalRandom.current();
        IntStream.range(0, elements.length)
                 .forEach(i -> elements[i] = random.nextInt());
    }

    public static void main(String[] args) {
        createRandomInts();

        long before = System.currentTimeMillis();
        int n = new ForkJoinPool().invoke(new MergeSortTask(0, elements.length - 1));
        long after = System.currentTimeMillis();
        System.out.println("Sorted " + n + " Elements in " + (after - before) + " ms.");

        // Verify that every element is <= its successor
        boolean sorted = IntStream.range(0, elements.length - 1)
                                  .allMatch(i -> elements[i] <= elements[i + 1]);
        assertThat(sorted, is(true));
    }
}
```
From the Javadoc:
- fork(): Arranges to asynchronously execute this task in the pool the current task is running in.
- join(): Returns the result of the computation when it is done.
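A minimal sketch of the fork()/join() pattern, summing a long[] with a hypothetical SumTask and an assumed cutoff of 1,000 elements. Following the usual idiom, the task forks one half, computes the other half in the current thread, and then joins the forked half, rather than forking both halves and leaving the current thread with nothing to do.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

public class ForkJoinSum {

    // Assumed cutoff below which a plain loop is cheaper than splitting
    static final int THRESHOLD = 1_000;

    static class SumTask extends RecursiveTask<Long> {
        private final long[] values;
        private final int from, to; // half-open range [from, to)

        SumTask(long[] values, int from, int to) {
            this.values = values;
            this.from = from;
            this.to = to;
        }

        @Override
        protected Long compute() {
            if (to - from <= THRESHOLD) {
                long sum = 0;
                for (int i = from; i < to; i++) sum += values[i];
                return sum;
            }
            int mid = (from + to) >>> 1;
            SumTask left = new SumTask(values, from, mid);
            SumTask right = new SumTask(values, mid, to);
            left.fork();                      // schedule the left half asynchronously
            long rightSum = right.compute();  // work on the right half in this thread
            long leftSum = left.join();       // result of the left half once it is done
            return leftSum + rightSum;
        }
    }

    static long sum(long[] values) {
        ForkJoinPool pool = new ForkJoinPool();
        try {
            return pool.invoke(new SumTask(values, 0, values.length));
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        long[] values = new long[100_000];
        for (int i = 0; i < values.length; i++) values[i] = i + 1;
        System.out.println("sum = " + sum(values)); // sum of 1..100000
    }
}
```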
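These methods rely on internal, per-thread double-ended queues (deques): fork() pushes the task onto the current worker's deque, idle workers steal tasks from the other end of busy workers' deques (work stealing), and a thread that calls join() on an unfinished task tries to run other pending tasks rather than simply blocking. That is how a thread switches from executing one task to executing another. Of course, all of that is transparent to the developer.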
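The effect of this design shows up with a pool of just one worker thread. In the sketch below (a naive Fibonacci, chosen only because every call forks a child and then joins it), the single worker can run the pending child itself when it reaches join(), so the computation completes instead of starving the way a single-threaded ThreadPoolExecutor would. The class and method names are illustrative; the exact internal scheduling may differ between JDK versions, but the result does not.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

public class SingleWorkerJoin {

    // Naive recursive Fibonacci: each call forks one child and joins it
    static class Fib extends RecursiveTask<Integer> {
        private final int n;

        Fib(int n) {
            this.n = n;
        }

        @Override
        protected Integer compute() {
            if (n <= 1) return n;
            Fib f1 = new Fib(n - 1);
            f1.fork();                          // pushed onto this worker's deque
            int f2 = new Fib(n - 2).compute();  // computed directly in this thread
            return f1.join() + f2;              // the worker can run f1 itself here
        }
    }

    static int fibWithOneWorker(int n) {
        ForkJoinPool pool = new ForkJoinPool(1); // a single worker thread
        try {
            return pool.invoke(new Fib(n));
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println("fib(20) = " + fibWithOneWorker(20)); // prints "fib(20) = 6765"
    }
}
```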
References