Tags: java, optimization, vectorization, simd

Optimizing the Calculation of the Dot Product of int16 Vectors in Java using Vector API


TL;DR: Optimizing a dot product of 16-bit integer arrays with Java's Vector API without overflowing.

I am trying to optimize a performance-critical loop that applies an activation function and calculates the dot product of two int16 arrays using Java's (incubating) Vector API. Here's my current scalar implementation:

for (int i = 0; i < HIDDEN_SIZE; i++)
{
    result += screlu(us.values[i]) * network.L1Weights[i]
        + screlu(them.values[i]) * network.L1Weights[i + HIDDEN_SIZE];
}

where

private static int screlu(short i)
{
    // Squared clipped ReLU: clamp the input to [0, QA], then square it.
    int v = Math.max(0, Math.min(i, QA));
    return v * v;
}
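
To see why 16-bit lanes are out of the question here: assuming QA = 255 (a common quantization constant for networks like this; the actual value is not shown above), the squared term alone already exceeds the short range.

// The squared activation no longer fits in a short, which is what forces
// the 32-bit lanes in the vectorized version below.
int maxScrelu = 255 * 255;                  // 65_025
assert maxScrelu > Short.MAX_VALUE;         // 32_767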

I tried to vectorize this like so:

// Widen the int16 inputs to int32 up front so the vector products cannot overflow.
int[] usValues = new int[HIDDEN_SIZE];
int[] themValues = new int[HIDDEN_SIZE];

for (int i = 0; i < HIDDEN_SIZE; i++)
{
    usValues[i] = (int) us.values[i];
    themValues[i] = (int) them.values[i];
}

IntVector sum = IntVector.zero(INT_SPECIES);

// Round HIDDEN_SIZE down to the largest multiple of the species length.
int upperBound = INT_SPECIES.loopBound(HIDDEN_SIZE);

for (int i = 0; i < upperBound; i += INT_SPECIES.length())
{
    IntVector va = IntVector.fromArray(INT_SPECIES, usValues, i);
    IntVector vb = IntVector.fromArray(INT_SPECIES, themValues, i);
    IntVector vc = IntVector.fromArray(INT_SPECIES, network.L1Weights, i);
    IntVector vd = IntVector.fromArray(INT_SPECIES, network.L1Weights, i + HIDDEN_SIZE);

    // screlu: clamp to [0, QA], square, then apply the weights.
    va = va.max(0).min(QA);
    va = va.mul(va).mul(vc);

    vb = vb.max(0).min(QA);
    vb = vb.mul(vb).mul(vd);

    sum = sum.add(va).add(vb);
}

int result = sum.reduceLanes(VectorOperators.ADD);

Due to overflow, I had to resort to 32-bit-wide lanes, halving the throughput, so performance is only slightly better than the scalar version. After some research, I found that an intrinsic such as _mm256_madd_epi16 (vpmaddwd) solves exactly this problem, but I could not find any mention of it in the documentation. Does an equivalent operation exist within the Vector API, and if not, are there other solutions to this problem?
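
For reference, _mm256_madd_epi16 multiplies vertically adjacent pairs of signed 16-bit lanes, widens each product to 32 bits, and sums each pair into a single 32-bit lane. A scalar sketch of what one output lane computes (maddPair is a hypothetical name, not a real API):

// One 32-bit output lane of vpmaddwd: two adjacent int16 products, each
// widened to 32 bits before the addition, so the multiplies cannot wrap.
static int maddPair(short a0, short b0, short a1, short b1)
{
    return a0 * b0 + a1 * b1;
}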


Solution

  • With some help I eventually figured out an implementation using the S2I conversion operator. It is faster than the scalar implementation, but probably slower than what would have been possible had vpmaddwd been available.

    import static jdk.incubator.vector.VectorOperators.S2I;

    // Assumed setup, not shown in the original snippet: the short and int
    // species must share the same vector shape so the S2I halves line up.
    VectorSpecies<Short> SHORT_SPECIES = ShortVector.SPECIES_PREFERRED;
    int UPPERBOUND = SHORT_SPECIES.loopBound(HIDDEN_SIZE);
    IntVector sum = IntVector.zero(IntVector.SPECIES_PREFERRED);

    for (int i = 0; i < UPPERBOUND; i += SHORT_SPECIES.length())
    {
        ShortVector usInputs = ShortVector.fromArray(SHORT_SPECIES, us.values, i);
        ShortVector themInputs = ShortVector.fromArray(SHORT_SPECIES, them.values, i);
        ShortVector usWeights = ShortVector.fromArray(SHORT_SPECIES, network.L2Weights[chosenBucket], i);
        ShortVector themWeights = ShortVector.fromArray(SHORT_SPECIES, network.L2Weights[chosenBucket],
                i + HIDDEN_SIZE);
    
        // Clamp the inputs to [0, QA]: the "clipped" part of screlu.
        usInputs = usInputs.max(ShortVector.zero(SHORT_SPECIES)).min(ShortVector.broadcast(SHORT_SPECIES, QA));
        themInputs = themInputs.max(ShortVector.zero(SHORT_SPECIES)).min(ShortVector.broadcast(SHORT_SPECIES, QA));

        // 16-bit products clamp(x) * w; these must not wrap, so QA times the
        // largest weight magnitude has to fit in a short.
        ShortVector usWeightedTerms = usInputs.mul(usWeights);
        ShortVector themWeightedTerms = themInputs.mul(themWeights);
    
        // Widen each short vector to two int vectors: part 0 is the lower
        // half of the lanes, part 1 the upper half.
        Vector<Integer> usInputsLo = usInputs.convert(S2I, 0);
        Vector<Integer> usInputsHi = usInputs.convert(S2I, 1);
        Vector<Integer> themInputsLo = themInputs.convert(S2I, 0);
        Vector<Integer> themInputsHi = themInputs.convert(S2I, 1);
    
        Vector<Integer> usWeightedTermsLo = usWeightedTerms.convert(S2I, 0);
        Vector<Integer> usWeightedTermsHi = usWeightedTerms.convert(S2I, 1);
        Vector<Integer> themWeightedTermsLo = themWeightedTerms.convert(S2I, 0);
        Vector<Integer> themWeightedTermsHi = themWeightedTerms.convert(S2I, 1);
    
        // Outer multiply clamp(x) * (clamp(x) * w), done in 32-bit lanes,
        // then accumulated.
        sum = sum.add(usInputsLo.mul(usWeightedTermsLo))
                .add(usInputsHi.mul(usWeightedTermsHi))
                .add(themInputsLo.mul(themWeightedTermsLo))
                .add(themInputsHi.mul(themWeightedTermsHi));
    }
    
    int result = sum.reduceLanes(VectorOperators.ADD);
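
    Why this stays exact: screlu(x) * w = clamp(x) * (clamp(x) * w), and the inner 16-bit product clamp(x) * w cannot wrap as long as QA times the largest weight magnitude fits in a short, which the quantization scheme has to guarantee. One piece the snippet leaves out is the tail when HIDDEN_SIZE is not a multiple of the species length; a minimal sketch, assuming the same screlu helper as in the question and the short[] weight layout used above:

        // Scalar tail for the elements past the vector loop bound.
        for (int i = UPPERBOUND; i < HIDDEN_SIZE; i++)
        {
            result += screlu(us.values[i]) * network.L2Weights[chosenBucket][i]
                + screlu(them.values[i]) * network.L2Weights[chosenBucket][i + HIDDEN_SIZE];
        }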