
Java 21 structured concurrency, need predictable subtask exception ordering


I'm rather new to parallel code, and I tried to convert some code based on executors to structured concurrency, but I lost an important property that I must somehow keep.

Given the following code using structured concurrency with Java 21 preview:

try (var scope = new StructuredTaskScope.ShutdownOnFailure()) {
    Subtask<Data1> d1Subtask = scope.fork(() -> getData1(input));
    Subtask<Data2> d2Subtask = scope.fork(() -> getData2(input));

    scope.join().throwIfFailed(); // [1]

    var data1 = d1Subtask.get(); // [2]
    var data2 = d2Subtask.get();

    return new Response(data1, data2);
}

At [1], the exception of whichever subtask fails first is thrown, and I don't want that. I need to run both tasks in parallel, but if d1Subtask fails, its exception is the one that must be propagated, even if d2Subtask failed earlier.

If I change it to a plain scope.join(), then [2] can fail if d1Subtask is not done, because ShutdownOnFailure shuts the scope down as soon as d2Subtask fails. There is d1Subtask.state(), but waiting for it to leave the State.UNAVAILABLE state seems to go against the idea of structured concurrency.

This can be achieved with executors and a plain StructuredTaskScope, but that potentially means running d2Subtask to completion even in cases where the scope could be shut down and that task aborted.
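
For reference, the executor-based variant mentioned above could look roughly like the following sketch. It is only an illustration under assumptions: the class and the names fetchBoth and failAfter are hypothetical and not taken from the original code, and cancelling the second future is a best-effort addition that merely requests an interrupt.

```java
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

// Hypothetical sketch: all names here are illustrative, not from the question.
class ExecutorOrderingSketch {

    /** Runs both tasks in parallel; a failure of the first task always takes precedence. */
    static <A, B> Map.Entry<A, B> fetchBoth(Callable<A> first, Callable<B> second)
            throws ExecutionException, InterruptedException {
        ExecutorService executor = Executors.newFixedThreadPool(2);
        try {
            Future<A> f1 = executor.submit(first);
            Future<B> f2 = executor.submit(second);
            A a;
            try {
                a = f1.get();      // wait for the first task, whatever the second one does
            } catch (ExecutionException e) {
                f2.cancel(true);   // best effort: ask the second task to stop
                throw e;
            }
            B b = f2.get();        // only now can a failure of the second task surface
            return Map.entry(a, b);
        } finally {
            executor.shutdownNow();
        }
    }

    /** Mock task that fails with the given message after a delay. */
    static <T> Callable<T> failAfter(long millis, String message) {
        return () -> { Thread.sleep(millis); throw new Exception(message); };
    }

    public static void main(String[] args) throws Exception {
        try {
            // The second task fails long before the first one.
            fetchBoth(failAfter(300, "Data1 Fail"), failAfter(50, "Data2 Fail"));
        } catch (ExecutionException e) {
            System.out.println(e.getCause().getMessage()); // prints "Data1 Fail"
        }
    }
}
```

The ordering is predictable because the caller blocks on the first future before looking at the second one. What is lost is the structure: the lifetime of the two tasks is managed by hand.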

Given that, is it possible to modify the code above to wait for the result of d1Subtask in a clean, readable way? I imagined that something like scope.join(d1Subtask) or d1Subtask.join() would be the way to do it, if that API existed, or maybe a different policy.


Edit: clearer explanation of the desired logic with each possible outcome.


Solution

  • You can use StructuredTaskScope directly, without ShutdownOnFailure, to wait for all jobs to complete. Then you can check the results and failures in the intended order, e.g.

    static Response simpleApproach() throws ExecutionException, InterruptedException {
        try(var scope = new StructuredTaskScope<>()) {
            Subtask<Data1> d1Subtask = scope.fork(() -> getData1(input));
            Subtask<Data2> d2Subtask = scope.fork(() -> getData2(input));
    
            scope.join(); // waits for both subtasks; a failure does not shut the scope down
    
            var data1 = get(d1Subtask); // checked first, so a Data1 failure takes precedence
            var data2 = get(d2Subtask);
    
            return new Response(data1, data2);
        }
    }
    
    // Returns the result of a completed subtask or rethrows its failure.
    static <T> T get(Subtask<T> task) throws ExecutionException {
        if(task.state() == State.FAILED)
            throw new ExecutionException(task.exception());
        return task.get();
    }
    

    This is the simplest approach. It ensures that if both jobs fail, the exception of “data1” is propagated to the caller. The only disadvantage is that if “data1” fails before “data2” completes, the scope still waits for “data2” without attempting to interrupt it. This may be acceptable, though, as we’re usually not trying (too hard) to optimize the exceptional case.


    But you can also implement your own policy. Here’s an example of a policy with a “primary job”. When another job fails, the scope waits for the primary job to complete, so that the primary job’s exception is preferred if it fails too. But when the primary job fails, the scope shuts down immediately, trying to interrupt all other jobs without waiting for their completion:

    static Response customPolicy() throws ExecutionException, InterruptedException {
        try(var scope = new ShutdownOnPrimaryFailure<>()) {
            Subtask<Data1> d1Subtask = scope.forkPrimary(() -> getData1(input));
            Subtask<Data2> d2Subtask = scope.fork(() -> getData2(input));
    
            scope.join().throwIfFailed();
    
            var data1 = d1Subtask.get();
            var data2 = d2Subtask.get();
    
            return new Response(data1, data2);
        }
    }
    
    class ShutdownOnPrimaryFailure<T> extends StructuredTaskScope<T> {
        private final AtomicReference<Throwable> failure = new AtomicReference<>();
        // Set by the owner before forking and read by subtask threads, hence volatile.
        private volatile Callable<?> primaryTask;
    
        public <U extends T> Subtask<U> forkPrimary(Callable<? extends U> task) {
            ensureOwnerAndJoined();
            // Remember the task before forking, so that even a primary job that
            // fails immediately is recognized in handleComplete.
            primaryTask = task;
            return super.fork(task);
        }
    
        @Override
        protected void handleComplete(Subtask<? extends T> subtask) {
            super.handleComplete(subtask);
            if(subtask.state() == State.FAILED) {
                if(subtask.task() == primaryTask) {
                    // The primary job's failure always wins and ends the scope.
                    failure.set(subtask.exception());
                    shutdown();
                }
                else {
                    // Other failures are kept only if nothing was recorded before.
                    failure.compareAndSet(null, subtask.exception());
                }
            }
        }
    
        @Override
        public ShutdownOnPrimaryFailure<T> join() throws InterruptedException {
            super.join();
            primaryTask = null;
            return this;
        }
    
        @Override
        public ShutdownOnPrimaryFailure<T> joinUntil(Instant deadline)
            throws InterruptedException, TimeoutException {
    
            super.joinUntil(deadline);
            primaryTask = null;
            return this;
        }
    
        public void throwIfFailed() throws ExecutionException {
            ensureOwnerAndJoined();
            Throwable t = failure.get();
            if(t != null) throw new ExecutionException(t);
        }
    }
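
The precedence rule inside handleComplete can be looked at in isolation. The following is a minimal sketch, assuming a hypothetical helper class FailurePrecedenceSketch that is not part of the policy above. It only shows how set and compareAndSet on an AtomicReference combine so that the primary job's failure wins regardless of the order in which the failures arrive.

```java
import java.util.concurrent.atomic.AtomicReference;

// Hypothetical helper isolating the failure-recording rule; not part of the policy class.
class FailurePrecedenceSketch {
    private final AtomicReference<Throwable> failure = new AtomicReference<>();

    /** The primary job's failure unconditionally replaces whatever was recorded. */
    void recordPrimary(Throwable t) { failure.set(t); }

    /** Another job's failure is kept only if nothing has been recorded yet. */
    void recordOther(Throwable t) { failure.compareAndSet(null, t); }

    Throwable recorded() { return failure.get(); }

    public static void main(String[] args) {
        FailurePrecedenceSketch otherFirst = new FailurePrecedenceSketch();
        otherFirst.recordOther(new Exception("Data2 Fail"));
        otherFirst.recordPrimary(new Exception("Data1 Fail"));
        System.out.println(otherFirst.recorded().getMessage());   // prints "Data1 Fail"

        FailurePrecedenceSketch primaryFirst = new FailurePrecedenceSketch();
        primaryFirst.recordPrimary(new Exception("Data1 Fail"));
        primaryFirst.recordOther(new Exception("Data2 Fail"));
        System.out.println(primaryFirst.recorded().getMessage()); // prints "Data1 Fail"
    }
}
```

Both operations are atomic, so the rule holds even when two subtasks report their failures from different threads at the same time.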
    

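    For completeness, I provide code for testing all scenarios at the end of this answer. It checks all combinations of success and failure, except the two in which both jobs fail after the same delay (shown as “-”). Since StructuredTaskScope and ScopedValue are preview APIs in Java 21, it has to be compiled and run with --enable-preview.

    With the implemented approaches, it prints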

      *** Original
    D1 ↓  D2 →  SUCCESS      D1 D2   FAIL_FAST    D1 D2   FAIL_SLOW    D1 D2
    SUCCESS:    Success       F  F   Data2 Fail    F  F   Data2 Fail    F  F
    FAIL_FAST:  Data1 Fail    F  F   -             F  F   Data1 Fail    F  I
    FAIL_SLOW:  Data1 Fail    F  F   Data2 Fail    I  F   -             I  F
    
      *** Simple
    D1 ↓  D2 →  SUCCESS      D1 D2   FAIL_FAST    D1 D2   FAIL_SLOW    D1 D2
    SUCCESS:    Success       F  F   Data2 Fail    F  F   Data2 Fail    F  F
    FAIL_FAST:  Data1 Fail    F  F   -             F  F   Data1 Fail    F  F
    FAIL_SLOW:  Data1 Fail    F  F   Data1 Fail    F  F   -             F  F
    
      *** Custom Policy
    D1 ↓  D2 →  SUCCESS      D1 D2   FAIL_FAST    D1 D2   FAIL_SLOW    D1 D2
    SUCCESS:    Success       F  F   Data2 Fail    F  F   Data2 Fail    F  F
    FAIL_FAST:  Data1 Fail    F  F   -             F  F   Data1 Fail    F  I
    FAIL_SLOW:  Data1 Fail    F  F   Data1 Fail    F  F   -             F  F
    

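    Status abbreviations: F = Finished, I = Interrupted, R = Running. For the skipped “-” cells, the status columns merely repeat the values left over from the previous scenario.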

    The issue was the scenario of D1 failing slowly and D2 failing fast, in the middle of the third row. There, ShutdownOnFailure aborted D1 (D1 status Interrupted) and propagated D2’s failure. The simple approach clearly fixes this but loses the ability to fail fast when D1 fails first (the last scenario in the second row, where D2’s status is now Finished instead of Interrupted). The custom policy solves the original issue while retaining the fail-fast support.
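
The rule each approach applies when deciding which result reaches the caller can also be condensed into a small decision function. This is only a hedged model of the tables above: the class OutcomeModelSketch and its names are hypothetical, and it neither runs threads nor covers the interruption status columns.

```java
// Hypothetical model of the scenarios above; it runs no threads and no real tasks.
class OutcomeModelSketch {
    enum Mode { SUCCESS, FAIL_FAST, FAIL_SLOW }
    enum Policy { ORIGINAL, SIMPLE, CUSTOM }

    /** Result seen by the caller, or null where the outcome depends on a race. */
    static String reported(Policy policy, Mode d1, Mode d2) {
        boolean d1Fails = d1 != Mode.SUCCESS;
        boolean d2Fails = d2 != Mode.SUCCESS;
        if (!d1Fails && !d2Fails) return "Success";
        if (!d2Fails) return "Data1 Fail";
        if (!d1Fails) return "Data2 Fail";
        // Both fail: the simple approach and the custom policy always prefer D1.
        if (policy != Policy.ORIGINAL) return "Data1 Fail";
        // ShutdownOnFailure reports whichever failure happens first.
        if (d1 == d2) return null;
        return d1 == Mode.FAIL_FAST ? "Data1 Fail" : "Data2 Fail";
    }

    public static void main(String[] args) {
        // The problematic scenario: D1 fails slowly, D2 fails fast.
        for (Policy policy : Policy.values()) {
            System.out.println(policy + ": "
                + reported(policy, Mode.FAIL_SLOW, Mode.FAIL_FAST));
        }
        // prints:
        // ORIGINAL: Data2 Fail
        // SIMPLE: Data1 Fail
        // CUSTOM: Data1 Fail
    }
}
```

The simple approach and the custom policy report the same result in every scenario. They differ only in whether D2 is interrupted when D1 fails first.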

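    import java.time.Instant;
    import java.util.List;
    import java.util.concurrent.Callable;
    import java.util.concurrent.ExecutionException;
    import java.util.concurrent.StructuredTaskScope;
    import java.util.concurrent.StructuredTaskScope.Subtask;
    import java.util.concurrent.StructuredTaskScope.Subtask.State;
    import java.util.concurrent.TimeoutException;
    import java.util.concurrent.atomic.AtomicReference;
    import java.util.function.Consumer;
    import java.util.function.Supplier;
    
    public class StructuredExample {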
        public static void main(String[] args) {
            record Approach(String name, Callable<?> method) {}
            List<Approach> approaches = List.of(
                new Approach("Original", StructuredExample::originalApproach),
                new Approach("Simple", StructuredExample::simpleApproach),
                new Approach("Custom Policy", StructuredExample::customPolicy));
    
            for(var approach: approaches) {
                System.out.println("  *** " + approach.name());
                System.out.printf("%-12s", "D1 \u2193  D2 \u2192");
                for(Mode d2Mode: Mode.values()) System.out.printf("%-12s D1 D2   ", d2Mode);
                System.out.println();
                for(Mode d1Mode: Mode.values()) {
                    System.out.printf("%-12s", d1Mode + ":");
                    for(Mode d2Mode: Mode.values()) {
                        String result = "-";
                        if(d2Mode == Mode.SUCCESS || d1Mode != d2Mode) try {
                            ScopedValue.where(data1Mode, d1Mode)
                                .where(data2Mode, d2Mode)
                                .call(() -> approach.method().call());
                            result = "Success";
                        }
                        catch(ExecutionException ex) { result = ex.getCause().getMessage(); }
                        catch(Exception ex) { result = ex.getMessage(); }
                        System.out.printf("%-12s%3s%3s   ", result, d1Running.name().charAt(0), d2Running.name().charAt(0));
                    }
                    System.out.println();
                }
                System.out.println();
            }
        }
    
        // mock for the getData1 and getData2 operations, producing success or failure and recording running state
    
        enum Mode { SUCCESS, FAIL_FAST, FAIL_SLOW }
        enum StateDebug { RUNNING, FINISHED, INTERRUPTED; }
    
        static final ScopedValue<Mode> data1Mode = ScopedValue.newInstance();
        static final ScopedValue<Mode> data2Mode = ScopedValue.newInstance();
    
        static volatile StateDebug d1Running, d2Running;
    
        static Data1 getData1(Object input) throws Exception {
            return getDataImpl("Data1", data1Mode, Data1::new, s -> d1Running = s);
        }
    
        static Data2 getData2(Object input) throws Exception {
            return getDataImpl("Data2", data2Mode, Data2::new, s -> d2Running = s);
        }
    
        static <T> T getDataImpl(String which, ScopedValue<Mode> mode, Supplier<T> s, Consumer<StateDebug> c) throws Exception {
            c.accept(StateDebug.RUNNING);
            boolean interrupted = false;
            try {
                Thread.sleep(500);
                switch(mode.get()) {
                    case SUCCESS: return s.get();
                    case FAIL_SLOW: Thread.sleep(500);
                }
                throw new Exception(which + " Fail");
            }
            catch(InterruptedException ex) {
                interrupted = true;
                c.accept(StateDebug.INTERRUPTED);
                throw ex;
            }
            finally {
                if(!interrupted) c.accept(StateDebug.FINISHED);
            }
        }
    
        // dummy data and types
    
        record Data1() {}
        record Data2() {}
    
        record Response(Data1 data1, Data2 data2)  {}
    
        static Object input;
    
        // the implementations
    
        static Response originalApproach() throws ExecutionException, InterruptedException {
            try (var scope = new StructuredTaskScope.ShutdownOnFailure()) {
                Subtask<Data1> d1Subtask = scope.fork(() -> getData1(input));
                Subtask<Data2> d2Subtask = scope.fork(() -> getData2(input));
        
                scope.join().throwIfFailed(); // [1]
        
                var data1 = d1Subtask.get(); // [2]
                var data2 = d2Subtask.get();
        
                return new Response(data1, data2);
            }
        }
    
        static Response simpleApproach() throws ExecutionException, InterruptedException {
            try(var scope = new StructuredTaskScope<>()) {
                Subtask<Data1> d1Subtask = scope.fork(() -> getData1(input));
                Subtask<Data2> d2Subtask = scope.fork(() -> getData2(input));
    
                scope.join();
    
                var data1 = get(d1Subtask);
                var data2 = get(d2Subtask);
    
                return new Response(data1, data2);
            }
        }
    
        static <T> T get(Subtask<T> task) throws ExecutionException {
            if(task.state() == State.FAILED)
                throw new ExecutionException(task.exception());
            return task.get();
        }
    
        static Response customPolicy() throws ExecutionException, InterruptedException {
            try(var scope = new ShutdownOnPrimaryFailure<>()) {
                Subtask<Data1> d1Subtask = scope.forkPrimary(() -> getData1(input));
                Subtask<Data2> d2Subtask = scope.fork(() -> getData2(input));
    
                scope.join().throwIfFailed();
    
                var data1 = d1Subtask.get();
                var data2 = d2Subtask.get();
    
                return new Response(data1, data2);
            }
        }
    }
    
    class ShutdownOnPrimaryFailure<T> extends StructuredTaskScope<T> {
        private final AtomicReference<Throwable> failure = new AtomicReference<>();
        // Set by the owner before forking and read by subtask threads, hence volatile.
        private volatile Callable<?> primaryTask;
    
        public <U extends T> Subtask<U> forkPrimary(Callable<? extends U> task) {
            ensureOwnerAndJoined();
            // Remember the task before forking, so that even a primary job that
            // fails immediately is recognized in handleComplete.
            primaryTask = task;
            return super.fork(task);
        }
    
        @Override
        protected void handleComplete(Subtask<? extends T> subtask) {
            super.handleComplete(subtask);
            if(subtask.state() == State.FAILED) {
                if(subtask.task() == primaryTask) {
                    // The primary job's failure always wins and ends the scope.
                    failure.set(subtask.exception());
                    shutdown();
                }
                else {
                    // Other failures are kept only if nothing was recorded before.
                    failure.compareAndSet(null, subtask.exception());
                }
            }
        }
    
        @Override
        public ShutdownOnPrimaryFailure<T> join() throws InterruptedException {
            super.join();
            primaryTask = null;
            return this;
        }
    
        @Override
        public ShutdownOnPrimaryFailure<T> joinUntil(Instant deadline) throws InterruptedException, TimeoutException {
            super.joinUntil(deadline);
            primaryTask = null;
            return this;
        }
    
        public void throwIfFailed() throws ExecutionException {
            ensureOwnerAndJoined();
            Throwable t = failure.get();
            if(t != null) throw new ExecutionException(t);
        }
    }