I have a custom piece of golang (1.23.0) assembly which performs AVX512 operations to speed up a very common code path. The function checks to see if a group of players are holding a poker hand by representing hands as int64 bitsets. The code looks like this (a CardSet is simply an int64):
// func SubsetAVX512(cs []CardSet, hs []CardSet) int
// Returns 1 if any card set in cards contains any hand in hands, 0 otherwise
#include "textflag.h"
#define cs_data 0(FP)
#define cs_len 8(FP)
#define cs_cap 16(FP)
#define hs_data 24(FP)
#define hs_len 32(FP)
#define hs_cap 40(FP)
#define ret_off 48(FP)
// Define the function
TEXT ·SubsetAVX512(SB), NOSPLIT, $0-56
// Start of the function
// Load parameters into registers
MOVQ cs+cs_data, R8 // R8 = cards_ptr
MOVQ cs+cs_len, R9 // R9 = cards_len
MOVQ hs+hs_data, R10 // R10 = hands_ptr
MOVQ hs+hs_len, R11 // R11 = hands_len
// Check if hands_len == 0
TESTQ R11, R11
JE return_false
// Check if cards_len == 0
TESTQ R9, R9
JE return_false
// Initialize loop counters
XORQ R12, R12 // R12 = i = 0 (hands index)
// Main loop over hands
outer_loop:
CMPQ R12, R11 // Compare i (R12) with hands_len (R11)
JGE return_false // If i >= hands_len, no match found
// Load 8 hands into Z0 (512-bit register)
LEAQ (R10)(R12*8), R13 // R13 = &hands[i]
VMOVDQU64 0(R13), Z0 // Load 8 int64s from [R13] into Z0
// Inner loop over cards
XORQ R14, R14 // R14 = j = 0 (cards index)
inner_loop:
CMPQ R14, R9 // Compare j (R14) with cards_len (R9)
JGE next_hands_block // If j >= cards_len, move to next hands block
// Load cs from cards[j]
LEAQ (R8)(R14*8), R15 // R15 = &cards[j]
MOVQ 0(R15), AX // AX = cards[j]
// Broadcast cs into Z1
VPBROADCASTQ AX, Z1 // Broadcast RAX into all lanes of Z1
// Compute cs_vec & h_vec
VPANDQ Z0, Z1, Z2 // Z2 = Z0 & Z1
// Compare (cs_vec & h_vec) == h_vec
VPCMPEQQ Z0, Z2, K1 // Compare Z0 == Z2, store result in mask K1
// Check if any comparison is true
KORTESTW K1, K1 // Test if any bits in K1 are set
JNZ found_match // If so, a match is found
// Increment card index
INCQ R14 // j++
JMP inner_loop // Repeat inner loop
next_hands_block:
// Increment hands index by 8
ADDQ $8, R12 // i += 8
JMP outer_loop // Repeat outer loop
found_match:
// Match found, return 1
MOVQ $1, AX // Set return value to 1 (true)
RET
return_false:
// No match found, return 0
XORQ AX, AX // Set return value to 0 (false)
RET
This code works great as long as it's not called concurrently, this works:
type CardSet int64
func SubsetAVX512(cs, hs []CardSet) bool
func TestSubsetAVX512(t *testing.T) {
cs := []CardSet{3, 1}
hs := []CardSet{3, 0}
var count int64
for i := 0; i < 5; i++ {
if SubsetAVX512(cs, hs) {
atomic.AddInt64(&count, 1)
}
}
require.Equal(t, int64(5), count)
}
however, this fails:
type CardSet int64
func SubsetAVX512(cs, hs []CardSet) bool
func TestSubsetAVX512(t *testing.T) {
cs := []CardSet{3, 1}
hs := []CardSet{3, 0}
var count int64
wg := sync.WaitGroup{}
for i := 0; i < 5; i++ {
wg.Add(1)
go func() {
defer wg.Done()
if SubsetAVX512(cs, hs) {
atomic.AddInt64(&count, 1)
}
}()
}
wg.Wait()
require.Equal(t, int64(5), count)
}
I believe that the issue has to do with some of the registers I'm using being overwritten by concurrent goroutines. My guess is it's the mask register K1
but that's just a slightly educated guess.
Your problem is that you try to return a result in AX
when the Go calling convention requires you to return results on the stack. Change the return to use
MOVQ $1, ret+ret_off
to properly return a result and you'll see your problems disappear.