goassemblyavxavx512

AVX512 assembly breaks when called concurrently from different goroutines


I have a custom piece of golang (1.23.0) assembly which performs AVX512 operations to speed up a very common code path. The function checks to see if a group of players are holding a poker hand by representing hands as int64 bitsets. The code looks like this (a CardSet is simply an int64):

// func SubsetAVX512(cs []CardSet, hs []CardSet) int
// Returns 1 if any card set in cards contains any hand in hands, 0 otherwise

#include "textflag.h"

#define cs_data 0(FP)
#define cs_len  8(FP)
#define cs_cap  16(FP)
#define hs_data 24(FP)
#define hs_len  32(FP)
#define hs_cap  40(FP)
#define ret_off 48(FP)

// Define the function
TEXT ·SubsetAVX512(SB), NOSPLIT, $0-56

// Start of the function
    // Load parameters into registers
    MOVQ cs+cs_data, R8         // R8 = cards_ptr
    MOVQ cs+cs_len, R9          // R9 = cards_len

    MOVQ hs+hs_data, R10        // R10 = hands_ptr
    MOVQ hs+hs_len, R11         // R11 = hands_len

    // Check if hands_len == 0
    TESTQ R11, R11
    JE return_false

    // Check if cards_len == 0
    TESTQ R9, R9
    JE return_false

    // Initialize loop counters
    XORQ R12, R12                 // R12 = i = 0 (hands index)

    // Main loop over hands
outer_loop:
    CMPQ R12, R11                 // Compare i (R12) with hands_len (R11)
    JGE return_false              // If i >= hands_len, no match found

    // Load 8 hands into Z0 (512-bit register)
    LEAQ (R10)(R12*8), R13        // R13 = &hands[i]
    VMOVDQU64 0(R13), Z0          // Load 8 int64s from [R13] into Z0

    // Inner loop over cards
    XORQ R14, R14                 // R14 = j = 0 (cards index)
inner_loop:
    CMPQ R14, R9                  // Compare j (R14) with cards_len (R9)
    JGE next_hands_block          // If j >= cards_len, move to next hands block

    // Load cs from cards[j]
    LEAQ (R8)(R14*8), R15         // R15 = &cards[j]
    MOVQ 0(R15), AX               // AX = cards[j]

    // Broadcast cs into Z1
    VPBROADCASTQ AX, Z1           // Broadcast RAX into all lanes of Z1

    // Compute cs_vec & h_vec
    VPANDQ Z0, Z1, Z2             // Z2 = Z0 & Z1

    // Compare (cs_vec & h_vec) == h_vec
    VPCMPEQQ Z0, Z2, K1           // Compare Z0 == Z2, store result in mask K1

    // Check if any comparison is true
    KORTESTW K1, K1               // Test if any bits in K1 are set
    JNZ found_match               // If so, a match is found

    // Increment card index
    INCQ R14                      // j++
    JMP inner_loop                // Repeat inner loop

next_hands_block:
    // Increment hands index by 8
    ADDQ $8, R12                  // i += 8
    JMP outer_loop                // Repeat outer loop

found_match:
    // Match found, return 1
    MOVQ $1, AX                   // Set return value to 1 (true)
    RET

return_false:
    // No match found, return 0
    XORQ AX, AX                   // Set return value to 0 (false)
    RET

This code works great as long as it's not called concurrently, this works:

type CardSet int64
func SubsetAVX512(cs, hs []CardSet) bool
func TestSubsetAVX512(t *testing.T) {
    cs := []CardSet{3, 1}
    hs := []CardSet{3, 0}
    var count int64
    for i := 0; i < 5; i++ {
        if SubsetAVX512(cs, hs) {
            atomic.AddInt64(&count, 1)
        }
    }
    require.Equal(t, int64(5), count)
}

however, this fails:

type CardSet int64
func SubsetAVX512(cs, hs []CardSet) bool
func TestSubsetAVX512(t *testing.T) {
    cs := []CardSet{3, 1}
    hs := []CardSet{3, 0}
    var count int64
    wg := sync.WaitGroup{}
    for i := 0; i < 5; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            if SubsetAVX512(cs, hs) {
                atomic.AddInt64(&count, 1)
            }
        }()
    }
    wg.Wait()
    require.Equal(t, int64(5), count)
}

I believe that the issue has to do with some of the registers I'm using being overwritten by concurrent goroutines. My guess is it's the mask register K1 but that's just a slightly educated guess.


Solution

  • Your problem is that you try to return a result in AX when the Go calling convention requires you to return results on the stack. Change the return to use

    MOVQ $1, ret+ret_off
    

    to properly return a result and you'll see your problems disappear.