Can't reproduce virtual thread pinning (JDK 21), yet MySQL is not as parallel as it should be


The JDK 21 documentation on virtual threads (hereafter "VT") (https://docs.oracle.com/en/java/javase/21/core/virtual-threads.html) is pretty clear that running code inside a synchronized block or method, or calling a native method or foreign function, pins a virtual thread to its carrier.

It all started with another SO question about MySQL and virtual threads (Can Virtual Threads improve querying a database in Java?), which showed that VTs do not run MySQL statements in parallel as they should (with drivers prior to 9.0.0, such as 8.4). I thought I could surely reproduce that in a kata... Let's try!

I created 12 tasks with one thread per task, each task performing 3 blocking operations of 1000 ms inside a synchronized block. Each task should therefore take 3 x 1000 ms = 3 seconds, at least on platform threads, and platform threads never fail to do so predictably.
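As a point of reference, the platform-thread baseline described above can be sketched as follows. This is only a minimal illustration, not the original test: the class name `PlatformBaseline` and its `run` helper are hypothetical, and it assumes each task locks its own monitor so that tasks never contend with each other.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical helper, for illustration only.
final class PlatformBaseline {

    /** Starts one platform thread per task; each sleeps `reps` times inside its own monitor. Returns elapsed ms. */
    static long run(int tasks, int reps, int blockMs, AtomicInteger completed) {
        List<Thread> threads = new ArrayList<>();
        long t0 = System.nanoTime();
        for (int i = 0; i < tasks; i++) {
            Object monitor = new Object(); // one monitor per task: no contention between tasks
            Thread t = new Thread(() -> {
                try {
                    for (int r = 0; r < reps; r++) {
                        synchronized (monitor) {
                            Thread.sleep(blockMs); // blocks while holding the monitor
                        }
                    }
                    completed.incrementAndGet();
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            });
            threads.add(t);
            t.start();
        }
        for (Thread t : threads) {
            try {
                t.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException(e);
            }
        }
        return (System.nanoTime() - t0) / 1_000_000;
    }

    public static void main(String[] args) {
        AtomicInteger completed = new AtomicInteger();
        long ms = run(12, 3, 1000, completed);
        System.out.println(completed.get() + " tasks finished in " + ms + " ms");
    }
}
```

With one platform thread per task, all tasks block at the same time, so the total should stay close to the duration of a single task.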

I wrote variations of the blocking operation: Object.wait() inside a synchronized block, a plain Thread.sleep() (synchronized and not), a ReentrantLock condition's await(), and a socket InputStream read() (synchronized and not). My money was on the network read, since that is what the MySQL driver does. Also, the JDK source shows that Object.wait() and Thread.sleep() have been made VT-friendly. I even tried a piped input/output stream setup but deleted it, because it relies on the same in-memory synchronization as the earlier variants.

On VTs, according to the doc, I should expect the synchronized blocks and native methods, but not the ReentrantLocks, to take more time. How much more would depend on how many carrier threads are around, and the doc doesn't say.

(UPDATE: I found later that, by default, there are as many carriers as cores, but extra ones can be created, up to 256 in total.)

If there are only as many carrier threads as there are cores, then I have 4 carriers. If I really pinned the VTs, I should expect only 4 tasks to run at a time: 12 tasks would run in 3 waves of 3 seconds each, so the whole set would finish around the 9-second mark.
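The arithmetic behind that expectation can be written down as a small sketch. The class and method names here are hypothetical. The two system properties are the scheduler settings described in the `java.lang.Thread` javadoc, and they read as unset unless passed explicitly on the command line. The estimate assumes that a pinned task occupies its carrier for its whole duration and that no extra carriers are added.

```java
// Hypothetical helper, for illustration only.
final class PinnedEstimate {

    /** Seconds for all tasks if each task holds a carrier for its whole duration (no compensation). */
    static double pinnedSeconds(int tasks, int carriers, int reps, int blockMs) {
        int waves = (tasks + carriers - 1) / carriers; // ceil(tasks / carriers)
        return waves * reps * blockMs / 1000.0;
    }

    /** Seconds if every task can block without holding a carrier, so all tasks overlap. */
    static double unpinnedSeconds(int reps, int blockMs) {
        return reps * blockMs / 1000.0;
    }

    public static void main(String[] args) {
        int cores = Runtime.getRuntime().availableProcessors();
        // Both properties are null unless set with -D on the command line.
        String parallelism = System.getProperty("jdk.virtualThreadScheduler.parallelism");
        String maxPoolSize = System.getProperty("jdk.virtualThreadScheduler.maxPoolSize");
        System.out.println("cores = " + cores
            + ", parallelism = " + (parallelism != null ? parallelism : "(default)")
            + ", maxPoolSize = " + (maxPoolSize != null ? maxPoolSize : "(default)"));

        int tasks = 3 * cores; // same ratio as in the question: 12 tasks on 4 cores
        System.out.println("pinned estimate:   " + pinnedSeconds(tasks, cores, 3, 1000) + " s");
        System.out.println("unpinned estimate: " + unpinnedSeconds(3, 1000) + " s");
    }
}
```

For the numbers in the question, `pinnedSeconds(12, 4, 3, 1000)` gives 9.0 and `unpinnedSeconds(3, 1000)` gives 3.0.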

I wanted to prove that, so I wrote the rather complex test below. And I failed to pin any carrier thread (at first; see my own answer): the 12 tasks ended nearly together at the 3-second mark.

What's even more interesting is that I started 12 threads, but 14 carriers were used at some point.

(UPDATE: these extra carriers are expected, as I learned after finding better documentation: "The capture of an OS thread is compensated by temporarily adding a carrier thread to the scheduler." Ref: https://liakh-aliaksandr.medium.com/concurrent-programming-in-java-with-virtual-threads-8f66bccc6460)

To see the carrier thread id, I first used JNA to call kernel32's GetCurrentThreadId(). Later I used the Foreign Function & Memory API (FFI), as suggested by a commenter. I have since removed the JNA code.
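A different, non-native way to peek at the carrier is to parse the virtual thread's `toString()`, which on current JDKs appends the carrier's name after an `@` while the thread is mounted. This is not what the test below does, and it relies on an implementation-specific format rather than a documented API, so treat it as a rough sketch. The class and method names are hypothetical.

```java
// Hypothetical helper, for illustration only. Needs JDK 21+ for Thread.ofVirtual().
final class CarrierName {

    /** Best effort: the text after the last '@' in a thread description, or "" if there is none. */
    static String carrierOf(String threadDescription) {
        int at = threadDescription.lastIndexOf('@');
        return at >= 0 ? threadDescription.substring(at + 1) : "";
    }

    public static void main(String[] args) throws InterruptedException {
        Thread vt = Thread.ofVirtual().start(() -> {
            String description = Thread.currentThread().toString();
            System.out.println(description + " -> carrier: " + carrierOf(description));
        });
        vt.join();
    }
}
```

Unlike the native thread id, this yields a carrier name rather than an OS-level id, and it only works while the virtual thread is mounted.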

I will not repeat the original question's code here, since it is obsolete and largely superseded by the code in the answer.

I really wanted to know how the older mysql driver 8.4 can pin a VT carrier to the point of reducing concurrency, and particularly how they would fix it in driver 9.0.x, but I didn't feel like reading mysql code to do that. It seemed like a single-file-sized kata was in order.

(UPDATE: so I found the way to reproduce pinning, see answer.)

(UPDATE: since the answer, commenters have noted that Java 24 will implement synchronized blocks and methods in a way that lets VTs unmount from their platform threads.)


Solution

  • I figured out how to reproduce it. I was very close with my synchronized wait. I did need a synchronized block, but with a Thread.sleep() inside instead of a wait(), so that the monitor is not released while blocking.

    Here is the new code, without JNA (thanks for the FFI idea; it uses the Foreign Function & Memory API). It is now a single file too.

    I thought I could shave off a lot to reduce it to the bare reproduction, but showing the variants and how they behave is a great complement for learning.

    package vttests;
    
    import java.io.Closeable;
    import java.io.DataInputStream;
    import java.io.DataOutputStream;
    import java.io.IOException;
    import java.io.InputStream;
    import java.io.OutputStream;
    import java.lang.foreign.FunctionDescriptor;
    import java.lang.foreign.Linker;
    import java.lang.foreign.SymbolLookup;
    import java.lang.foreign.ValueLayout;
    import java.lang.invoke.MethodHandle;
    import java.lang.reflect.InvocationTargetException;
    import java.net.ServerSocket;
    import java.net.Socket;
    import java.util.Collections;
    import java.util.List;
    import java.util.Map;
    import java.util.Set;
    import java.util.TreeSet;
    import java.util.concurrent.ConcurrentHashMap;
    import java.util.concurrent.ExecutionException;
    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;
    import java.util.concurrent.Future;
    import java.util.concurrent.TimeUnit;
    import java.util.concurrent.atomic.AtomicInteger;
    import java.util.concurrent.locks.Condition;
    import java.util.concurrent.locks.ReentrantLock;
    import java.util.function.BiConsumer;
    import java.util.function.Supplier;
    import java.util.stream.IntStream;
    
    public class TestCustomBlockageWithVTs {
        static final boolean DEBUGGING = java.lang.management.ManagementFactory.getRuntimeMXBean().getInputArguments().toString().contains("-agentlib:jdwp");
        static final boolean TRACELOG = false;
        
        static final int PORT = 42000;
        static final int BLOCKAGE_DURATION_MS = 1000;
        static final int NUM_REPETITION = 3;
        
        static MethodHandle getTidMH;
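        // Windows-only: binds kernel32's GetCurrentThreadId via the Foreign Function & Memory API,
        // which is still a preview API in JDK 21 (compile and run with --enable-preview).
        static {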
            System.loadLibrary("kernel32");
            SymbolLookup.loaderLookup().find("GetCurrentThreadId").ifPresent(mseg -> {
                FunctionDescriptor mdesc = FunctionDescriptor.of(ValueLayout.JAVA_INT);
                getTidMH = Linker.nativeLinker().downcallHandle(mseg, mdesc);
            });
        }
        
        static Integer tid() {
            try {
                return getTidMH != null ? (int)getTidMH.invokeExact() : null;
            } catch (Throwable e) {
                throw new RuntimeException(e);
            }
        }
        
        static void p(Object msg) {
            System.out.println("["
                + (Thread.currentThread().isVirtual() ? "V" : "P")
                + "/"+ Thread.currentThread().getName()
                + "("+tid()+")]: " 
                + msg
            );
        }
        
        public static void main(String[] args) throws Exception {
            //NOTE: jdk.tracePinnedThreads hangs the JVM in debug mode
            if(!DEBUGGING)
                System.setProperty("jdk.tracePinnedThreads", "short"); //or "full"
            
            boolean useVirtual = true;
            
            runTest("synchronized waits", useVirtual, SyncWaitTask.class);
            runTest("synchronized sleep", useVirtual, SyncSleepTask.class);
            runTest("reentrant lock waits", useVirtual, RLWaitTask.class);
            runTest("reentrant lock sleep", useVirtual, RLSleepTask.class);
            runTest("sleep", useVirtual, SleepTask.class);
            
            try(Closeable startServer = startServer()) {
                runTest("net read", useVirtual, NetTask.class);
                runTest("sync net read", useVirtual, SyncNetTask.class);
            }           
        }
        
        static void runTest(String title, boolean useVT, Class<? extends AbstractBlockingTask> clazz) throws InterruptedException, ExecutionException {
            Map<Integer,Set<Integer>> tid2id = new ConcurrentHashMap<>();
            Map<Integer,Set<Integer>> id2tid = new ConcurrentHashMap<>();
            BiConsumer<Integer,Integer> recorder = (tid,taskId) -> {
                tid2id.computeIfAbsent(tid, _tid -> Collections.synchronizedSet(new TreeSet<>())).add(taskId);
                id2tid.computeIfAbsent(taskId, _id -> Collections.synchronizedSet(new TreeSet<>())).add(tid);
            };
            
            AtomicInteger nextTaskId = new AtomicInteger(0);
            Supplier<Runnable> taskFactory = () -> newTask(clazz, nextTaskId.getAndIncrement(), BLOCKAGE_DURATION_MS, recorder);
            
            long t = System.nanoTime();
            runTasks(title, useVT, taskFactory);
            double d = 1e-9*(System.nanoTime() - t);
            p("Whole test took " + d + " sec");
            
            double expectedLow = 1e-3 * BLOCKAGE_DURATION_MS * NUM_REPETITION;
            double expectedHigh = 1.10 * expectedLow;
            if(d<expectedLow || d>expectedHigh) {
                p("\n\n>>>>>>>>>>>> "+title+": test lasted outside of expected time of ~"+expectedLow+" sec : "+d+" sec <<<<<<<<<<\n\n");
            }
            
            p("thread-taskid mappings recorded "+tid2id.size()+" distinct native threads");
            tid2id.values().removeIf(set -> set.size()<=1);
            tid2id.forEach( (tid,id) -> p("\tThread "+tid+" was shared among tasks ids "+id));
            
            p("task-thread mappings recorded "+id2tid.size()+" distinct tasks");
            id2tid.values().removeIf(set -> set.size()<=1);
            id2tid.forEach( (id,tid) -> p("\tTask "+id+" was carried with thread ids "+tid));
        }
        
        static AbstractBlockingTask newTask(Class<? extends AbstractBlockingTask> clazz, int taskId, int delayms, BiConsumer<Integer, Integer> recorder) {
            AbstractBlockingTask t;
            try {
                t = clazz.getDeclaredConstructor().newInstance();
            } catch (InstantiationException | IllegalAccessException | IllegalArgumentException | InvocationTargetException | NoSuchMethodException | SecurityException e) {
                throw new RuntimeException(e);
            }
            t.taskId = taskId;
            t.indent = "\t".repeat(taskId) +"#";
            t.delayMs = delayms;
            t.recorder = recorder;
            return t;
        }
        
        static void runTasks(String title, boolean useVT, Supplier<Runnable> taskFactory) throws InterruptedException, ExecutionException {
            int cores = Runtime.getRuntime().availableProcessors();
            int N = 3 * cores;
            
            p("\n\n=========== "+title+" "+(useVT?"using VT":"using platform")+" threads ===============\n");
            p("Processors count  = " + cores+"; threads to create = " + N);
            
            long t0;
            long d;
            
            ExecutorService es = useVT ? Executors.newVirtualThreadPerTaskExecutor() : Executors.newThreadPerTaskExecutor(Thread.ofPlatform().factory());
            try (es) {
                t0 = System.nanoTime();
                List<? extends Future<?>> futures = IntStream
                    .range(0, N)
                    .mapToObj(i -> taskFactory.get())
                    .map(es::submit)
                    .toList();
                
                //--------
                
                if(TRACELOG) p("waiting for all futures...");
                for(Future<?> future: futures) {
                    future.get();
                }
                d = System.nanoTime()-t0;
                if(TRACELOG) p("waiting futures took "+1e-9*d+" sec");
                
                //--------
                
                if(TRACELOG) p("shutting down es and await termination...");
                t0 = System.nanoTime();
                es.shutdown();
                es.awaitTermination(1, TimeUnit.SECONDS);
                d = System.nanoTime()-t0;
                if(TRACELOG) p("shut down took "+1e-9*d+" sec");
                
                //--------
                
                if(TRACELOG) p("closing executor...");
                t0 = System.nanoTime();
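            }
            // measure the implicit close() itself; without this, the stale shutdown duration is printed
            d = System.nanoTime()-t0;
            if(TRACELOG) p("closing executor took " + 1e-9*d + " sec");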
        }
        
        
        static abstract class AbstractBlockingTask implements Runnable {
            int taskId;
            private int delayMs;
            private String indent;
            BiConsumer<Integer,Integer> recorder;
            
            AbstractBlockingTask() {
            }
            
            @Override
            public void run() {
                if(TRACELOG) p(indent + "task " + taskId + " started.");
                work();
                if(TRACELOG) p(indent + "task " + taskId + " ended.");
            }
            
            void work() {
                for(int i=1; i<=NUM_REPETITION; i++) {
                    if(TRACELOG) p(indent + "call "+i+" starting ... ");
                    try {
                        block(delayMs);
                    } catch (InterruptedException e) {
                        p("call "+i+" interrupted!");
                        Thread.currentThread().interrupt(); // restore the interrupt flag and stop working
                        return;
                    }
                    if(TRACELOG) p(indent + "call "+i+" finished.");
                }
            }
            
            abstract void block(int delayms) throws InterruptedException;
        }
        
        
        //======= blocking tasks variants ====================
        
        static class SyncWaitTask extends AbstractBlockingTask {
            @Override
            void block(int delayms) throws InterruptedException {
                synchronized(this) {
                    recorder.accept(tid(), taskId);
                    this.wait(delayms);
                    recorder.accept(tid(), taskId);
                }
            }
        }
        
        static class SyncSleepTask extends AbstractBlockingTask {
            @Override
            void block(int delayms) throws InterruptedException {
                synchronized(this) {
                    recorder.accept(tid(), taskId);
                    Thread.sleep(delayms);
                    recorder.accept(tid(), taskId);
                }
            }
        }
        
        static class RLWaitTask extends AbstractBlockingTask {
            ReentrantLock rl = new ReentrantLock();
            Condition cond = rl.newCondition();
            @Override
            void block(int delayms) throws InterruptedException {
                rl.lock();
                try {
                    recorder.accept(tid(), taskId);
                    cond.await(delayms, TimeUnit.MILLISECONDS);
                    recorder.accept(tid(), taskId);
                } finally {
                    rl.unlock();
                }
            }
        }
        
        static class RLSleepTask extends AbstractBlockingTask {
            ReentrantLock rl = new ReentrantLock();
            Condition cond = rl.newCondition();
            @Override
            void block(int delayms) throws InterruptedException {
                rl.lock();
                try {
                    recorder.accept(tid(), taskId);
                    Thread.sleep(delayms);
                    recorder.accept(tid(), taskId);
                } finally {
                    rl.unlock();
                }
            }
        }
        
        static class SleepTask extends AbstractBlockingTask {
            @Override
            void block(int delayms) throws InterruptedException {
                recorder.accept(tid(), taskId);
                Thread.sleep(delayms);
                recorder.accept(tid(), taskId);
            }
        }
        
        static class NetTask extends AbstractBlockingTask {
            @Override
            void block(int delayms) throws InterruptedException {
                recorder.accept(tid(), taskId);
                connectAndRead(PORT, delayms);
                recorder.accept(tid(), taskId);
            }
        }
        
        static class SyncNetTask extends AbstractBlockingTask {
            @Override
            void block(int delayms) throws InterruptedException {
                synchronized(this) {
                    recorder.accept(tid(), taskId);
                    connectAndRead(PORT, delayms);
                    recorder.accept(tid(), taskId);
                }
            }
        }
        
        static Closeable startServer() throws IOException {
            ExecutorService exsvc = Executors.newThreadPerTaskExecutor(Thread.ofPlatform().factory());
            
            ServerSocket ss = new ServerSocket(PORT);
            Future<?> f = exsvc.submit(() -> {
                try {
                    while(!Thread.interrupted() && !ss.isClosed()) {
                        Socket s = ss.accept();
                        exsvc.submit(() -> serviceOne(s));
                    }
                } catch (IOException e) {
                    //e.printStackTrace();
                }
            });
            
            return () -> {
                f.cancel(true);
                exsvc.shutdownNow();
                ss.close();
            };
        }
        
        static void serviceOne(Socket s) {
            try {
                try(s) {
                    InputStream is = s.getInputStream();
                    DataInputStream dis = new DataInputStream(is);
                    int delayms = dis.readInt();
                    
                    OutputStream os = s.getOutputStream();
                    Thread.sleep(delayms);
                    os.write(42);
                    os.flush();
                    os.close();
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
        
        static void connectAndRead(int port, int delayms) {
            try (Socket s = new Socket("localhost", port)) {
                OutputStream os = s.getOutputStream();
                DataOutputStream dos = new DataOutputStream(os);
                dos.writeInt(delayms);
                dos.flush();
                
                InputStream is = s.getInputStream();
                is.read();//should be delayed
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
        
    }
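For readers who only want the reproduction, the synchronized-sleep variant can be reduced to the following sketch. It is a hedged condensation of the idea, not a replacement for the full test: the class name `MinimalPinning` and its `run` helper are hypothetical, it needs JDK 21 or later, and the timing you observe will depend on the JDK version and the number of cores.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical helper, for illustration only.
final class MinimalPinning {

    /** Runs `tasks` virtual threads that each sleep `reps` times while holding a monitor. Returns elapsed ms. */
    static long run(int tasks, int reps, int blockMs, AtomicInteger completed) {
        long t0 = System.nanoTime();
        try (ExecutorService es = Executors.newVirtualThreadPerTaskExecutor()) {
            for (int i = 0; i < tasks; i++) {
                Object monitor = new Object(); // one monitor per task: no contention between tasks
                es.submit(() -> {
                    for (int r = 0; r < reps; r++) {
                        synchronized (monitor) {
                            Thread.sleep(blockMs); // sleeps without releasing the monitor
                        }
                    }
                    return completed.incrementAndGet();
                });
            }
        } // close() waits for all submitted tasks to finish
        return (System.nanoTime() - t0) / 1_000_000;
    }

    public static void main(String[] args) {
        int cores = Runtime.getRuntime().availableProcessors();
        AtomicInteger completed = new AtomicInteger();
        long ms = run(3 * cores, 3, 1000, completed);
        System.out.println(completed.get() + " tasks finished in " + ms + " ms on " + cores + " cores");
    }
}
```

If the virtual threads are pinned, the elapsed time should drift toward three times the duration of a single task (three waves of tasks over the available carriers). If they are not, it should stay close to the 3 seconds a single task needs.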