java multithreading hash merkle-tree

How to Multithread Merkle Tree Hashing


I have a large list whose Merkle root I'd like to compute in Java. The list is large enough that multithreading the computation should speed it up significantly, so I've been trying to do that.

Here's my code so far:

/**
 * Computes the Merkle root of {@code temp} on the calling thread.
 *
 * <p>Each entry is hashed with {@code merkleHash} to form the leaves, and the leaves are
 * then reduced to a single value by {@code getRoot}.
 *
 * @param temp the raw entries; must contain at least one element
 * @return the leaf hash itself when there is exactly one entry, otherwise the value
 *         produced by {@code getRoot} for the hashed leaves
 * @throws IllegalArgumentException if {@code temp} is empty
 */
public static byte[] multiMerkleRoot(ArrayList<byte[]> temp) {
    if (temp.isEmpty()) {
        throw new IllegalArgumentException("Cannot compute the Merkle root of an empty list");
    }

    List<byte[]> hashList = new ArrayList<>(temp.size());
    for (byte[] o : temp) {
        hashList.add(merkleHash(o));
    }

    // A lone leaf has nothing to be paired with; handing it to getRoot would make it
    // index into an empty sub-list.
    if (hashList.size() == 1) {
        return hashList.get(0);
    }

    // getRoot already splits off the trailing leaf of an odd-sized list and combines it
    // with the root of the rest, so no separate odd/even handling is needed here.
    return getRoot(hashList);
}

/**
 * Recursively reduces a list of already-hashed leaves to a single root value.
 *
 * <ul>
 *   <li>one element: that element is returned unchanged;</li>
 *   <li>odd size: the root of all but the last element is combined with the last element;</li>
 *   <li>even size: the list is cut into two equal halves whose roots are combined.</li>
 * </ul>
 *
 * @param temp the leaf hashes; must not be empty
 * @return the root value for {@code temp}
 * @throws IllegalArgumentException if {@code temp} is empty
 */
private static byte[] getRoot(List<byte[]> temp) {
    int size = temp.size();
    if (size == 0) {
        throw new IllegalArgumentException("Cannot compute the root of an empty list");
    }
    if (size == 1) {
        // Base case: without it a single element would recurse into an empty sub-list.
        return temp.get(0);
    }
    if (size % 2 != 0) {
        byte[] rest = getRoot(temp.subList(0, size - 1));
        return merkleHash(concat(rest, temp.get(size - 1)));
    }
    if (size == 2) {
        return merkleHash(concat(temp.get(0), temp.get(1)));
    }

    // Two equal halves; subList returns views, so nothing is copied.
    int half = size / 2;
    byte[] left = getRoot(temp.subList(0, half));
    byte[] right = getRoot(temp.subList(half, size));
    return merkleHash(concat(left, right));
}

/**
 * Computes the Merkle root of {@code temp}, optionally reducing the leaves on two threads.
 *
 * <p>Each entry is hashed with {@code merkleHash}. For an odd number of leaves the last leaf
 * is set aside, the remaining even-sized prefix is reduced, and the result is then combined
 * with the set-aside leaf.
 *
 * @param temp    the raw entries
 * @param threads number of worker threads; only 1 and 2 are implemented
 * @return the root value; the leaf hash itself for a single entry; or {@code null} when
 *         {@code temp} is empty, {@code threads} is unsupported, or the computation fails
 */
public static byte[] trueMultiMerkleRoot(ArrayList<byte[]> temp, int threads) {
    try {
        List<byte[]> hashList = new ArrayList<>(temp.size());
        for (byte[] o : temp) {
            hashList.add(merkleHash(o));
        }

        if (threads != 1 && threads != 2) {
            System.out.println("You can only have the following threadcounts: 1, 2.");
            return null;
        }

        int size = hashList.size();
        if (size == 0) {
            // Nothing to hash; reported the same way as any other failure.
            return null;
        }
        if (size == 1) {
            return hashList.get(0);
        }

        boolean hasTrailingLeaf = size % 2 != 0;
        List<byte[]> evenPrefix = hasTrailingLeaf ? hashList.subList(0, size - 1) : hashList;

        byte[] root = (threads == 1) ? getRoot(evenPrefix) : twoThreadMerkle(evenPrefix);

        // The set-aside leaf must be folded back in, otherwise it is dropped from the result.
        return hasTrailingLeaf
                ? merkleHash(concat(root, hashList.get(size - 1)))
                : root;
    } catch (InterruptedException e) {
        // Keep the interrupt visible to the caller even though the result is reported as null.
        Thread.currentThread().interrupt();
        return null;
    } catch (Exception e) {
        // Failures are reported to the caller as a null result.
        return null;
    }
}

/**
 * Reduces {@code temp} to its root, computing the two halves of an even-sized list on
 * separate worker threads.
 *
 * <p>Lists that are odd-sized or have fewer than two elements cannot be split evenly and
 * are handed to {@code getRoot} on the calling thread.
 *
 * <p>NOTE: {@code getRoot} runs concurrently on both workers, so {@code merkleHash} and
 * {@code concat} must be safe to call from several threads at once (for example, they must
 * not share one {@code MessageDigest} instance).
 *
 * @param temp the leaf hashes
 * @return the root value for {@code temp}
 * @throws Exception if a worker fails or the calling thread is interrupted while waiting
 */
private static byte[] twoThreadMerkle(List<byte[]> temp) throws Exception {
    int size = temp.size();
    if (size < 2 || size % 2 != 0) {
        return getRoot(temp);
    }

    int half = size / 2;
    List<byte[]> left = temp.subList(0, half);
    List<byte[]> right = temp.subList(half, size);

    ExecutorService pool = Executors.newFixedThreadPool(2);
    try {
        Future<byte[]> fut1 = pool.submit(() -> getRoot(left));
        Future<byte[]> fut2 = pool.submit(() -> getRoot(right));

        // Future.get() blocks until the result is ready, so no polling loop is needed.
        return merkleHash(concat(fut1.get(), fut2.get()));
    } finally {
        // Always release the worker threads, including when a task fails.
        pool.shutdown();
    }
}

multiMerkleRoot is the single threaded version, trueMultiMerkleRoot is an attempt at the multithreaded version.

Here's my problem: No matter what size list I use (I've tried using exact powers of 2, odd numbers, even numbers, small and large) I always get two different answers from the two methods, and I can't for the life of me figure out how to address that.

In this implementation, merkleHash() is just a wrapper for Keccak 256, which I'm using to hash the two byte arrays that I'm concatenating.

If anyone could help me with this in any way, whether to show me where my code is going wrong and how to fix it, or to just set my code on fire and show me how to do it correctly, I'd really appreciate the help.

EDIT: I've attempted a different approach after I realized some problems with my previous one. However, this one still gives the wrong result when multithreaded, even though I think it's much closer.

Here's my new code:

package crypto;

import org.bouncycastle.util.encoders.Hex;

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.*;

import static crypto.Hash.keccak256;
import static util.ByteUtil.concat;

/**
 * Collects byte-array entries and reduces them to a single Merkle root.
 *
 * <p>Entries are combined level by level: neighbours are paired and replaced by
 * {@code keccak256(concat(a, b))}, and an unpaired trailing entry is carried up to the next
 * level unchanged. Computing a root consumes the stored entries, so the instance is empty
 * afterwards.
 */
public class Merkle {
    private final Queue<byte[]> data;

    /** Creates an empty tree. */
    public Merkle() {
        this.data = new LinkedList<>();
    }

    /**
     * Creates a tree pre-filled with the given entries, in list order.
     *
     * @param in the initial entries
     */
    public Merkle(ArrayList<byte[]> in) {
        this.data = new LinkedList<>(in);
    }

    /**
     * Appends all entries of {@code in}, in list order.
     *
     * @param in the entries to append
     */
    public void add(List<byte[]> in) {
        data.addAll(in);
    }

    /**
     * Appends a single entry.
     *
     * @param in the entry to append
     */
    public void add(byte[] in) {
        data.add(in);
    }

    /**
     * Reduces the stored entries to one root value, leaving this instance empty.
     *
     * @return the root; for a single stored entry, that entry itself
     * @throws java.util.NoSuchElementException if no entries are stored
     */
    public byte[] hash() {
        return hash(data);
    }

    /**
     * Reduces {@code level} in place to one root value and removes it from the queue.
     *
     * @param level the entries to reduce; empty when this method returns
     * @return the root; for a single entry, that entry itself
     * @throws java.util.NoSuchElementException if {@code level} is empty
     */
    private byte[] hash(Queue<byte[]> level) {
        Queue<byte[]> nextLevel = new LinkedList<>();

        while (level.size() > 1) {
            while (!level.isEmpty()) {
                byte[] left = level.remove();
                // An unpaired trailing entry moves up unchanged.
                nextLevel.add(level.isEmpty() ? left : merkleHash(left, level.remove()));
            }

            level.addAll(nextLevel);
            nextLevel.clear();
        }

        return level.remove();
    }

    /**
     * Computes the same root as {@link #hash()} by reducing two independent parts and
     * combining their results. The first part holds the largest power-of-two number of
     * leading entries, the second part holds the remainder. When the entry count is itself a
     * power of two (or below two) there is no remainder and {@link #hash()} is used directly.
     *
     * <p>Both parts are reduced one after the other on the calling thread.
     *
     * @return the root value
     * @throws Exception declared for compatibility with existing callers
     * @throws java.util.NoSuchElementException if no entries are stored
     */
    public byte[] dualHash() throws Exception {
        int n = data.size();
        if (n < 2 || Integer.bitCount(n) == 1) {
            return hash();
        }

        // Integer arithmetic throughout: no floating-point rounding and no int saturation.
        int leftSize = 1 << log2(n);

        Queue<byte[]> temp1 = new LinkedList<>();
        Queue<byte[]> temp2 = new LinkedList<>();
        for (int i = 0; i < leftSize; i++) {
            temp1.add(data.remove());
        }
        while (!data.isEmpty()) {
            temp2.add(data.remove());
        }

        byte[] tem1 = hash(temp1);
        byte[] tem2 = hash(temp2);

        return merkleHash(tem1, tem2);
    }

    /**
     * Returns the number of entries currently stored.
     *
     * @return the current entry count
     */
    public int size() {
        return data.size();
    }

    /** Hashes the concatenation of {@code a} and {@code b}. */
    private byte[] merkleHash(byte[] a, byte[] b) {
        return keccak256(concat(a, b));
    }

    /** Hashes a single value. */
    private byte[] merkleHash(byte[] a) {
        return keccak256(a);
    }

    /**
     * Returns floor(log2(x)) for positive {@code x}.
     *
     * @param x a positive value; the result is not meaningful for {@code x <= 0}
     *          (it is -1 for 0)
     * @return the index of the highest set bit of {@code x}
     */
    private int log2(int x) {
        return 31 - Integer.numberOfLeadingZeros(x);
    }
}

If we look specifically at the dualHash method, in this case, it works and gives me the same result as the hash method. However, when I try to delegate it to two threads, like so:

/**
 * Computes the same root as {@code hash()} by reducing two independent parts on separate
 * worker threads and combining their results. The first part holds the largest power-of-two
 * number of leading entries, the second part holds the remainder. When the entry count is a
 * power of two (or below two) there is no remainder and {@code hash()} is used directly.
 *
 * <p>NOTE: both workers hash at the same time, so {@code keccak256} must be safe for
 * concurrent use. A {@code MessageDigest} instance shared between the threads is not, and
 * produces wrong results.
 *
 * @return the root value
 * @throws Exception if a worker fails or the calling thread is interrupted while waiting
 */
public byte[] dualHash() throws Exception {
    int n = data.size();
    if (n < 2 || Integer.bitCount(n) == 1) {
        return hash();
    }

    // Largest power of two not exceeding n, computed without floating point.
    int leftSize = Integer.highestOneBit(n);

    Queue<byte[]> temp1 = new LinkedList<>();
    Queue<byte[]> temp2 = new LinkedList<>();
    for (int i = 0; i < leftSize; i++) {
        temp1.add(data.remove());
    }
    while (!data.isEmpty()) {
        temp2.add(data.remove());
    }

    ExecutorService pool = Executors.newFixedThreadPool(2);
    try {
        Future<byte[]> fut1 = pool.submit(() -> hash(temp1));
        Future<byte[]> fut2 = pool.submit(() -> hash(temp2));

        return merkleHash(fut1.get(), fut2.get());
    } finally {
        // Always release the worker threads, including when a task fails.
        pool.shutdown();
    }
}

It no longer gives me the expected result. Any idea as to why?

Thanks!


Solution

  • Solution found!

    It turns out that my code wasn't the issue (at least not the code from the edit; I'm 100% sure the first code chunk was completely wrong). The problem was that the two threads were hashing at the same time while sharing a single MessageDigest instance, and MessageDigest is not thread-safe. Now that I've forced each thread to use its own instance, the code runs just fine.