Tags: c, intrinsics, avx512, half-precision-float

AVX-512 BF16: load bf16 values directly instead of converting from fp32


On CPUs with AVX-512 and BF16 support, you can use the 512-bit vector registers to store 32 16-bit floats.

I have found intrinsics to convert FP32 values to BF16 values (for example _mm512_cvtne2ps_pbh), but I have not found any intrinsics to load BF16 values directly from memory. It seems a bit wasteful to always load values as FP32 if I'm then always going to convert them to BF16. Are direct BF16 loads not supported, or have I just not found the right intrinsic yet?


Solution

  • Strange oversight in the intrinsics. There isn't a special vmov instruction for BF16 in asm because you don't need one: you'd just use vmovups, since asm doesn't care about types. (Except sometimes for integer vs. FP domain, so probably prefer an FP load or store instruction: the integer vmovdqu16 might have an extra cycle of latency when forwarding from a load to an FP ALU on some CPUs.)

    If aligned load/store works for your use-case, just point a __m512bh* at your data and deref it. (Is `reinterpret_cast`ing between hardware SIMD vector pointer and the corresponding type an undefined behavior? - it's well-defined as being equivalent to an aligned load or store intrinsic, and is allowed to alias any other data).
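
    For example, a minimal sketch of that aligned-deref approach (the function and variable names here are mine, for illustration):

     // Compile with -mavx512bf16 on GCC/clang.  src is assumed 64-byte aligned.
     #include <immintrin.h>

     __m512bh load_bf16_aligned(const void *src)
     {
         return *(const __m512bh*)src;  // equivalent to an aligned load intrinsic
     }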

    If not, then as @chtz points out, you can memcpy to/from a __m512bh variable. Modern compilers know how to inline and optimize away a small fixed-size memcpy, especially one of exactly the size of a variable. @chtz's demo on Godbolt shows it optimizing the way we want with GCC and clang at -O1: the same as a deref of a __m512bh*, but working for unaligned data.
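
    A sketch along the lines of that demo (assuming the same pattern, not a verbatim copy):

     #include <immintrin.h>
     #include <string.h>

     __m512bh load_bf16_unaligned(const void *src)
     {
         __m512bh res;
         memcpy(&res, src, sizeof(res));  // GCC/clang -O1 inline this to one vmovups
         return res;
     }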

    But it's not so good with MSVC: the code works correctly, but the memcpy to a local var actually reserves stack space and stores the value to it, as well as leaving it in ZMM0 as the return value. (It doesn't reload the copy, but it doesn't optimize away the storage and the dead store to res.)


    With intrinsics, there isn't even a cast intrinsic to __m512bh from __m512, __m512d, or __m512i. (Or for any narrower vector width.)

    But most compilers do also let you use a C-style cast on the vector type, like this to reinterpret (type-pun) the bits as a different vector type:

     __m512bh vec = (__m512bh) _mm512_loadu_ps( ptr );  // Not supported by MSVC
    

    This is not a standard thing defined by Intel's intrinsics guide, but GCC and clang at least implement C-style casts (and C++ std::bit_cast and probably static_cast) the same way as the intrinsics API's functions like _mm512_castsi512_ps or _mm512_castps_ph (the FP16 intrinsic that we wish existed for BF16).

    The AVX-512 load intrinsics take void*, making it clear that it's fine to use them on any type of data. So this just works with no cast on the pointer, only on the vector value.
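
    Putting that together, a sketch of a bitwise FP32 load plus a vector-value cast (the cast is the GCC/clang extension discussed above, not an Intel intrinsic):

     #include <immintrin.h>

     __m512bh load_bf16_via_ps(const void *src)
     {
         __m512 v = _mm512_loadu_ps(src);  // bitwise 512-bit load, no conversion
         return (__m512bh) v;              // reinterpret as 32 x bf16
     }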

    The 256-bit and 128-bit integer loads / stores take __m256i* or __m128i* pointers respectively, and the FP loads take float*. But it's still strict-aliasing safe to do _mm_loadu_ps( (float*)&int_vector[i] ), because the load intrinsics are allowed to alias anything. Anyway, once you have a __m256 or __m128, (__m256bh) vec will work in most compilers.
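
    For instance, a 256-bit sketch under those assumptions (a hypothetical array of __m256i holding packed bf16 data):

     #include <immintrin.h>
     #include <stddef.h>

     __m256bh load_bf16_from_ints(const __m256i *int_vector, size_t i)
     {
         __m256 v = _mm256_loadu_ps((const float*)&int_vector[i]);  // aliasing-safe
         return (__m256bh) v;  // GCC/clang vector cast, rejected by MSVC
     }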

    MSVC chokes on this cast. You might get away with a C++20 std::bit_cast<__m512bh>( vec ) on MSVC if you're using C++. But if you want to write portable C that compiles efficiently on MSVC as well as GCC/Clang, your only option might be to deref an aligned pointer: memcpy compiles to a dead store on MSVC, casting the value doesn't work, and deref of a vector pointer requires alignment on GCC/Clang. MSVC always avoids alignment-checking versions of instructions, so if you're willing to #ifdef, it might be safe to deref an unaligned __m512bh* on MSVC.
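
    A hedged sketch of that #ifdef approach (load_bf16 is a hypothetical helper name, and the MSVC branch relies on the "might be safe" assumption above):

     #include <immintrin.h>
     #include <string.h>

     static inline __m512bh load_bf16(const void *p)
     {
     #ifdef _MSC_VER
         return *(const __m512bh*)p;    // assumed safe: MSVC emits vmovups
     #else
         __m512bh res;
         memcpy(&res, p, sizeof(res));  // optimizes away on GCC/clang
         return res;
     #endif
     }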

    (It's not safe to deref a __m128* without AVX, because the load could fold into a memory source operand like addps xmm0, [rdi], which does require alignment; but that's only a problem for legacy-SSE encodings. VEX / EVEX encodings allow unaligned by default. A raw deref won't invent vmovntps stores, which only come in an alignment-required flavour; and if a plain vmov load or store is required, MSVC will use vmovups instead of vmovaps even if the pointer is known to be aligned. GCC and clang will use alignment-enforcing instructions when they can prove it's safe, unlike MSVC and classic ICC.)