I am trying to create an efficient AVX2 implementation for a 17x17-bit truncated squarer that returns the 15 most significant bits of the result. This operation appears as a building block in transcendental function approximation, as described in the following publication:
Stuart F. Oberman and Michael Y. Siu, "A High-Performance Area-Efficient Multifunction Interpolator." In 17th IEEE Symposium on Computer Arithmetic, June 2005, pp. 272-279.
My assumption is that a truncated squarer is used to guarantee single-cycle execution, so that squaring proceeds in parallel with the necessary table lookup. Various implementations of truncated multipliers are described in the literature. I picked a minimal variant, based on clues in the paper, that I find fulfils the functional requirements in the given context. For more accurate (but potentially slower) implementations, see
Theo Drane, et al., "On the Systematic Creation of Faithfully-Rounded Commutative Truncated Booth Multipliers." In 31st IEEE Symposium on Computer Arithmetic, June 2024, pp. 108-115.
For the details of operation of the truncated squarer, I refer to the diagram in the comment for the reference implementation truncated_squarer_ref(). Here, letters 'a' through 'q' denote individual bits of source operand x ('a' being the least significant), x<i> denotes the i-th bit of x, '0' indicates bits that are ignored due to the truncating nature of the squarer, 'R' denotes the "round" bit, 'C' denotes a potential carry-out, while 'r' denotes any bit of the result.
For the reference implementation, I essentially used binary long-hand multiplication with certain source bits and partial-product bits disregarded as indicated by the truncating nature of the squarer. I did not exploit the fact that this is a squaring operation rather than a general multiplication, because no optimization along those lines readily occurred to me.
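To get a feel for what the truncation costs, the sketch below compares truncated_squarer_ref() (copied unchanged from the code further down) with the exact 15 most significant bits of the 34-bit square, i.e. (x*x) >> 19. The helpers exact_top15() and max_truncation_error() are hypothetical names introduced only for this illustration. For example, for x = 0x1ffff the truncated squarer returns 32759, while the exact value is 32767.

```c
#include <assert.h>
#include <stdint.h>

/* Copy of truncated_squarer_ref() from the question. */
static uint32_t truncated_squarer_ref(uint32_t x)
{
    uint32_t r = 0;
    for (int i = 16; i >= 2; i--) {
        r += (0 - ((x >> i) & 1)) & (x >> (18 - i));
    }
    return r >> 1;
}

/* Hypothetical helper: exact 15 most significant bits of the 34-bit square. */
static uint32_t exact_top15(uint32_t x)
{
    assert(x < 0x20000);              /* 17-bit operand */
    uint64_t sq = (uint64_t)x * x;    /* up to 34 bits */
    return (uint32_t)(sq >> 19);
}

/* Hypothetical helper: largest absolute difference between exact and
   truncated result over all 2^17 inputs. */
static int64_t max_truncation_error(void)
{
    int64_t max_err = 0;
    for (uint32_t x = 0; x < 0x20000; x++) {
        int64_t err = (int64_t)exact_top15(x) - (int64_t)truncated_squarer_ref(x);
        if (err < 0) err = -err;
        if (err > max_err) max_err = err;
    }
    return max_err;
}
```

Since the discarded partial products are all non-negative, the truncated result never exceeds the exact one; the scan simply reports how far below it can fall.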
An AVX2 implementation that fully exploits all available parallelism seemed trivial, given that the task basically comprises the conditional summation of fifteen 15-bit partial products, all of which fit into a single 256-bit register. A straightforward translation thus resulted in truncated_squarer_avx2() as shown below. Because I usually compile with /QxHOST (meaning: target the architecture of the build machine) and my build machine implements the skylake-avx512 architecture, I did not notice until considerably later that, through the wonders of non-orthogonal ISA design, _mm256_srlv_epi16() is not actually a thing in AVX2 (it requires AVX-512BW and AVX-512VL).
Emulating _mm256_srlv_epi16() via _mm256_srlv_epi32() would seem inefficient, and I do not use AVX2 frequently enough to readily think of an alternative way of addressing the problem. What is an efficient way of implementing the functionality of the truncated squarer using AVX2? Any fast and concise intrinsic-based implementation that passes the exhaustive functional test is welcome.
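For comparison, one straightforward way to emulate _mm256_srlv_epi16() is to shift the even and the odd 16-bit words separately with _mm256_srlv_epi32() and merge the two results. The sketch below plugs such an emulation into the original kernel; srlv_epi16_emul() and truncated_squarer_avx2_emul() are hypothetical names. Each emulated shift costs two variable shifts plus several logic operations, which illustrates the overhead in question. It assumes GCC or Clang, where __attribute__((target("avx2"))) allows compiling without a global -mavx2 flag, and a CPU that supports AVX2.

```c
#include <assert.h>
#include <stdint.h>
#include <immintrin.h>

/* Emulates _mm256_srlv_epi16() (AVX-512BW/VL) with AVX2 only.
   Words whose shift count is 16 or more become zero. */
static inline __attribute__((target("avx2")))
__m256i srlv_epi16_emul(__m256i a, __m256i count)
{
    const __m256i lo16 = _mm256_set1_epi32(0x0000ffff);
    /* even words: zero-extend to 32 bits, shift by their own count */
    __m256i lo = _mm256_srlv_epi32(_mm256_and_si256(a, lo16),
                                   _mm256_and_si256(count, lo16));
    /* odd words: shift the whole dword by the odd word's count;
       bits that spill into the lower half are masked off */
    __m256i hi = _mm256_srlv_epi32(a, _mm256_srli_epi32(count, 16));
    hi = _mm256_andnot_si256(lo16, hi);
    return _mm256_or_si256(lo, hi);
}

/* The question's kernel with the missing instruction emulated. */
static __attribute__((target("avx2")))
uint32_t truncated_squarer_avx2_emul(uint32_t x)
{
    assert(x < 0x20000);
    const __m256i shift_count1 = _mm256_setr_epi16(0, 1, 2, 3, 4, 5, 6, 7,
                                                   8, 9, 10, 11, 12, 13, 14, 15);
    const __m256i shift_count2 = _mm256_setr_epi16(14, 13, 12, 11, 10, 9, 8, 7,
                                                   6, 5, 4, 3, 2, 1, 0, 15);
    __m256i a = _mm256_set1_epi16((short)(x >> 2));
    __m256i b = srlv_epi16_emul(a, shift_count2);  /* bit (14-n) of x>>2 in word n */
    a = srlv_epi16_emul(a, shift_count1);          /* (x>>2) >> n in word n */
    b = _mm256_and_si256(b, _mm256_set1_epi16(1));
    a = _mm256_mullo_epi16(a, b);                  /* conditional partial products */
    a = _mm256_hadd_epi16(a, a);                   /* sums within each 128-bit lane */
    a = _mm256_hadd_epi16(a, a);
    a = _mm256_hadd_epi16(a, a);
    return (uint32_t)(_mm256_extract_epi16(a, 0) + _mm256_extract_epi16(a, 8)) >> 1;
}
```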
Addendum: It was asked in comments whether it is possible to perform a full multiplication first and then subtract from it the sum of the discarded partial products. This is indeed possible, as demonstrated by truncated_squarer_mul_sub() below.
Addendum 2: I found through further experimentation at Compiler Explorer that the auto-SIMD-ization of the Intel compiler icx 2025.1.1 can generate a 25-instruction solution for a code variant I include as truncated_squarer_alt() below. This came as a pleasant surprise.
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
#include <immintrin.h>
/* 17x17-bit truncated squarer that returns the most significant 15 bits
of the result. This uses a "round" bit to reduce the error introduced
by truncation.
     17 source bits
     _______^_______
    /               \
    q0000000000000000   ( (x >> 16)* x<2>
    qp000000000000000   + (x >> 15)* x<3>
    qpo00000000000000   + (x >> 14)* x<4>
    qpon0000000000000   + (x >> 13)* x<5>
    qponm000000000000   + (x >> 12)* x<6>
    qponml00000000000   + (x >> 11)* x<7>
    qponmlk0000000000   + (x >> 10)* x<8>
    qponmlkj000000000   + (x >> 9) * x<9>
    qponmlkji00000000   + (x >> 8) * x<10>
    qponmlkjih0000000   + (x >> 7) * x<11>
    qponmlkjihg000000   + (x >> 6) * x<12>
    qponmlkjihgf00000   + (x >> 5) * x<13>
    qponmlkjihgfe0000   + (x >> 4) * x<14>
    qponmlkjihgfed000   + (x >> 3) * x<15>
    qponmlkjihgfedc00   + (x >> 2) * x<16>) >> 1
   C              R
   rrrrrrrrrrrrrrr
   \______ ______/
          v
   15 result bits
*/
uint32_t truncated_squarer_ref (uint32_t x)
{
    uint32_t r = 0;
    for (int i = 16; i >= 2; i--) {
        r += (0 - ((x >> i) & 1)) & (x >> (18 - i));
    }
    return r >> 1;
}
/* As is, this does NOT actually work on AVX2 but runs fine on AVX512! */
uint32_t truncated_squarer_avx2 (uint32_t x)
{
    __m256i a, b, lsb, shift_count1_v, shift_count2_v;
    uint16_t shift_count1[16] = { 0,  1,  2,  3,  4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
    uint16_t shift_count2[16] = {14, 13, 12, 11, 10, 9, 8, 7, 6, 5,  4,  3,  2,  1,  0, 15};
    uint16_t xs = x >> 2;
    memcpy (&shift_count1_v, shift_count1, sizeof shift_count1_v);
    memcpy (&shift_count2_v, shift_count2, sizeof shift_count2_v);
    a = _mm256_set1_epi16 (xs);
    lsb = _mm256_set1_epi16 (1);
    b = _mm256_srlv_epi16 (a, shift_count2_v); // doesn't exist in AVX2!
    a = _mm256_srlv_epi16 (a, shift_count1_v); // doesn't exist in AVX2!
    b = _mm256_and_si256 (b, lsb);
    a = _mm256_mullo_epi16 (a, b);
    a = _mm256_hadd_epi16 (a, a);
    a = _mm256_hadd_epi16 (a, a);
    a = _mm256_hadd_epi16 (a, a);
    return (uint32_t)(_mm256_extract_epi16(a,0)+_mm256_extract_epi16(a,8)) >> 1;
}
// perform full multiply, then remove the sum of the discarded partial products
uint32_t truncated_squarer_mul_sub (uint32_t x)
{
    uint32_t r = 0;
    x = x >> 2;
    for (int i = 14; i >= 0; i--) {
        r += (0 - ((x >> i) & 1)) & ((x << i) & 0x3fff);
    }
    return (x * x - r) >> 15;
}
// compiles to 25 instructions with icx 2025.1.1 through auto-SIMD-ization
uint32_t truncated_squarer_alt (uint32_t x)
{
    uint32_t r = 0;
    x = x >> 2;
    for (int i = 0; i < 15; i++) {
        if ((x >> i) & 1) r = r + ((x << i) & ~0x3fff);
    }
    return r >> 15;
}
int main (void)
{
    uint32_t x, res, ref;
    for (x = 0; x < 0x20000; x++) {
        res = truncated_squarer_avx2 (x);
        ref = truncated_squarer_ref (x);
        if (res != ref) {
            printf ("x=%05x res=%04x ref=%04x\n", x, res, ref);
            return EXIT_FAILURE;
        }
    }
    return EXIT_SUCCESS;
}
Edit2: Final SSE/SWAR solution (the initial broadcast may be faster if AVX2 broadcasts are available, depending on context and target architecture).
The qnjfc²+lh² term sums up seven 1x1-bit products, which all lie within the bit range covered by the mask 0x1c000 (bits 14 to 16). The remaining product sums each add up 4 or 5 1xN products, possibly with one additional bit at the lowest position, which however gets masked away before the final reduction (similar to @Aki Suihkonen's idea).
I took the final reduction from @njuffa's truncated_squarer_sse_tuned_reduction, since that seems to work best in their context ...
uint32_t truncated_squarer_swar(uint32_t x)
{
    __m128i vx = _mm_set1_epi16(x >> 2);
    // qnjfc, lh, qpo, nm, lk, ji, hg, fed
    __m128i va = _mm_and_si128(vx, _mm_setr_epi16(0x4889, 0x0220, 0x7000, 0x0c00, 0x0300, 0x00c0, 0x0030, 0x000e));
    // qnjfc, lh, q..d, q..g, q..i, q..k, q..m, qpo
    __m128i vb = _mm_and_si128(vx, _mm_setr_epi16(0x4889, 0x0220, 0x7ffe, 0x7ff0, 0x7fc0, 0x7f00, 0x7c00, 0x7000));
    __m128i prod = _mm_madd_epi16(va, vb);
    __m128i hsum = _mm_and_si128(prod, _mm_setr_epi32(0x1c000, 0x7fffc000, 0x7fffc000, 0x7fffc000));
    // sum partial products (reduction)
    hsum = _mm_add_epi32(hsum, _mm_shuffle_epi32(hsum, 0xee));
    hsum = _mm_add_epi32(hsum, _mm_shuffle_epi32(hsum, 0x55));
    // extract result and scale to desired precision:
    return ((uint32_t)_mm_cvtsi128_si32(hsum)) >> 15;
}
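The claim about the first lane can be checked exhaustively in plain C. With y = x >> 2 (bit 14 is q, bit 0 is c), the masked madd result (qnjfc)² + (lh)² restricted to 0x1c000 should equal 2^14 times 2qc + 2nf + j + 2lh: all other cross terms either stay below 2^14 (their sum is at most 7505) or are multiples of 2^17, so they cannot disturb bits 14 to 16. bit_at() and lane0_claim_holds() are hypothetical helpers written for this check only.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical helper: bit k of y. */
static uint32_t bit_at(uint32_t y, int k) { return (y >> k) & 1u; }

/* Scalar check of lane 0 in truncated_squarer_swar(); y = x >> 2. */
static int lane0_claim_holds(uint32_t y)
{
    assert(y < 0x8000);
    uint32_t a = y & 0x4889;                     /* q, n, j, f, c */
    uint32_t b = y & 0x0220;                     /* l, h */
    uint32_t lane0 = (a * a + b * b) & 0x1c000;  /* what the masked madd lane keeps */
    /* the seven 1x1 products of weight 2^14: 2qc + 2nf + jj + 2lh */
    uint32_t s = 2 * bit_at(y, 14) * bit_at(y, 0) + 2 * bit_at(y, 11) * bit_at(y, 3)
               + bit_at(y, 7) + 2 * bit_at(y, 9) * bit_at(y, 5);
    return lane0 == (s << 14);
}
```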
Edit: Here is a straightforward AVX2 version, which should be better than the original SSSE3 version (assuming adequate hardware support):
First of all, I borrowed a 4x32-bit horizontal sum from one of Peter's answers. (Depending on surrounding port usage, there may be room for improvement, e.g., extracting a 64-bit register before the last reduction step; Peter made a comment about that below this answer.)
// https://stackoverflow.com/a/35270026/6870253
int hsum_epi32_sse2(__m128i x) {
#ifdef __AVX__
    __m128i hi64 = _mm_unpackhi_epi64(x, x); // 3-operand non-destructive AVX lets us save a byte without needing a mov
#else
    __m128i hi64 = _mm_shuffle_epi32(x, _MM_SHUFFLE(1, 0, 3, 2));
#endif
    __m128i sum64 = _mm_add_epi32(hi64, x);
    __m128i hi32 = _mm_shufflelo_epi16(sum64, _MM_SHUFFLE(1, 0, 3, 2)); // Swap the low two elements
    __m128i sum32 = _mm_add_epi32(sum64, hi32);
    return _mm_cvtsi128_si32(sum32); // SSE2 movd
}
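The idea of moving the data to a 64-bit general-purpose register before the last reduction step could look roughly like the following sketch. hsum_epi32_gpr() is a hypothetical name, and _mm_cvtsi128_si64 (movq) is only available when targeting x86-64. Whether this beats the all-SIMD version depends on the surrounding port pressure.

```c
#include <assert.h>
#include <stdint.h>
#include <emmintrin.h>

/* Hypothetical variant of hsum_epi32_sse2(): after the first 64-bit fold,
   move the two remaining 32-bit partial sums to a general-purpose register
   (movq, x86-64 only) and finish the reduction with a scalar add. */
static inline uint32_t hsum_epi32_gpr(__m128i x)
{
    __m128i hi64  = _mm_unpackhi_epi64(x, x);          /* [x2, x3, x2, x3] */
    __m128i sum64 = _mm_add_epi32(hi64, x);            /* low dwords: x0+x2, x1+x3 */
    uint64_t s = (uint64_t)_mm_cvtsi128_si64(sum64);   /* both partial sums at once */
    return (uint32_t)s + (uint32_t)(s >> 32);
}
```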
The actual method just broadcasts the 15-bit input into all 16 words of an AVX register. From that, it masks bit n into word n of one register and bits [14:14-n] into word n of a second register (word 15 is zero in both). The two registers are multiplied and pairwise added using pmaddwd into 8 32-bit integers, which are afterwards reduced to a single integer that gets shifted to the desired precision.
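A scalar model of what the word lanes compute may help when checking the masks: lane n multiplies bit n of y = x >> 2 by bits [14:14-n] of y, so exactly those partial products whose bit positions sum to at least 14 are kept, and the total is shifted right by 15. truncated_squarer_avx2_model() is a hypothetical plain-C illustration, not a replacement for the vector code.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical scalar model of the AVX2 version: word lane n contributes
   (y & mask1[n]) * (y & mask2[n]); pmaddwd plus the reduction sum all lanes. */
static uint32_t truncated_squarer_avx2_model(uint32_t x)
{
    assert(x < 0x20000);
    uint32_t y = x >> 2;                                       /* 15 significant bits */
    uint32_t sum = 0;
    for (int n = 0; n < 15; n++) {                             /* lane 15 is all zero */
        uint32_t mask1 = 1u << n;                              /* bit n */
        uint32_t mask2 = 0x7fffu & ~((1u << (14 - n)) - 1u);   /* bits [14:14-n] */
        sum += (y & mask1) * (y & mask2);
    }
    return sum >> 15;
}
```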
uint32_t truncated_squarer_avx2(uint32_t x)
{
    x >>= 2;
    __m256i mask1 = _mm256_setr_epi16(0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80, 0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0);
    __m256i mask2 = _mm256_setr_epi16(0x4000, 0x6000, 0x7000, 0x7800, 0x7c00, 0x7e00, 0x7f00, 0x7f80, 0x7fc0, 0x7fe0, 0x7ff0, 0x7ff8, 0x7ffc, 0x7ffe, 0x7fff, 0);
    __m256i vx = _mm256_set1_epi16(x);
    __m256i a = _mm256_and_si256(vx, mask1);
    __m256i b = _mm256_and_si256(vx, mask2);
    __m256i hsum_256 = _mm256_madd_epi16(a, b);
    __m128i hsum = _mm_add_epi32(_mm256_castsi256_si128(hsum_256), _mm256_extracti128_si256(hsum_256, 1));
    return hsum_epi32_sse2(hsum) >> 15;
}
This solution could very easily be generalized to an x * y product with the same truncation.
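A minimal sketch of that generalization, assuming both operands are 17-bit values and the same truncation rule (keep the partial products of x >> 2 and y >> 2 whose bit positions sum to at least 14): one operand supplies the single-bit masks, the other the bits [14:14-n]. truncated_multiplier_ref() and truncated_multiplier_avx2() are hypothetical names, and the GCC/Clang target attribute is an assumption used to compile without -mavx2.

```c
#include <assert.h>
#include <stdint.h>
#include <immintrin.h>

/* Hypothetical scalar reference, modeled on truncated_squarer_alt(). */
static uint32_t truncated_multiplier_ref(uint32_t x, uint32_t y)
{
    uint32_t xs = x >> 2, ys = y >> 2, r = 0;
    for (int i = 0; i < 15; i++) {
        if ((xs >> i) & 1u) r += (ys << i) & ~0x3fffu;
    }
    return r >> 15;
}

static __attribute__((target("avx2")))
uint32_t truncated_multiplier_avx2(uint32_t x, uint32_t y)
{
    assert(x < 0x20000 && y < 0x20000);
    const __m256i mask1 = _mm256_setr_epi16(0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80,
        0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0);
    const __m256i mask2 = _mm256_setr_epi16(0x4000, 0x6000, 0x7000, 0x7800, 0x7c00, 0x7e00, 0x7f00, 0x7f80,
        0x7fc0, 0x7fe0, 0x7ff0, 0x7ff8, 0x7ffc, 0x7ffe, 0x7fff, 0);
    /* x supplies bit n, y supplies bits [14:14-n] */
    __m256i a = _mm256_and_si256(_mm256_set1_epi16((short)(x >> 2)), mask1);
    __m256i b = _mm256_and_si256(_mm256_set1_epi16((short)(y >> 2)), mask2);
    __m256i p = _mm256_madd_epi16(a, b);                /* 8 x int32 partial sums */
    __m128i s = _mm_add_epi32(_mm256_castsi256_si128(p), _mm256_extracti128_si256(p, 1));
    s = _mm_add_epi32(s, _mm_unpackhi_epi64(s, s));     /* fold upper 64 bits */
    s = _mm_add_epi32(s, _mm_shuffle_epi32(s, 0x55));   /* add element 1 to element 0 */
    return (uint32_t)_mm_cvtsi128_si32(s) >> 15;
}
```

With x == y this reduces to the squarer above.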
For reference, here is my original solution using only SSSE3 (for pshufb and pmaddubsw). It splits the product into one 7x7-bit square product (the q..k bits in OP's notation), one 1x1 product (the j bit), and 8 1x[1..7]-bit products (one bit of j..c times the upper bits of q..k), which due to symmetry later get doubled. These products can be computed efficiently using pmaddubsw and afterwards scaled and accumulated into 4 int32 by pmaddwd (some care has to be taken because pmaddubsw treats one input as unsigned and the other as signed).
Afterwards, the result is reduced to a single int32, extracted, and shifted to the desired precision.
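The decomposition just described can be modeled in plain C to confirm that it reproduces the reference exactly. Writing y = x >> 2 = H * 2^8 + (low byte), with H = q..k and bit 7 = j, the kept terms are H² * 2^16, j * 2^14, and twice each low bit times those bits of H that reach weight 2^14. sse_decomposition_model() is a hypothetical helper that mirrors the scaling factors used in the vector code.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical scalar model of the split used by truncated_squarer_sse(). */
static uint32_t sse_decomposition_model(uint32_t x)
{
    assert(x < 0x20000);
    uint32_t y = x >> 2;
    uint32_t H = y >> 8;                   /* q..k, 7 bits */
    uint32_t j = (y >> 7) & 1u;            /* j */
    uint32_t t = ((H * H) << 16)           /* 7x7-bit square, all terms kept */
               + (j << 14);                /* 1x1 product j*j */
    for (int b = 0; b < 8; b++) {          /* one bit of j..c times upper bits of H */
        if ((y >> b) & 1u) {
            int m0 = (6 - b > 0) ? 6 - b : 0;          /* lowest kept bit of H */
            uint32_t Hm = H & (0x7fu << m0) & 0x7fu;   /* drop terms below 2^14 */
            t += 2u * ((Hm << 8) << b);                /* doubled by symmetry */
        }
    }
    return t >> 15;
}
```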
uint32_t truncated_squarer_sse(uint32_t x)
{
    x >>= 2; // lower two bits do not contribute to result
    __m128i vx = _mm_cvtsi32_si128(x), va, vb;
    // qk qk qk qk | j  0  0  0 | qk qk h  g  | f  e  qp q
    va = _mm_shuffle_epi8(vx, _mm_setr_epi8( 1, 1, 1, 1, 0, -1, -1, -1, 1, 1, 0, 0, 0, 0, 1, 1));
    va = _mm_and_si128(va, _mm_setr_epi8(0x7f, 0x7f, 0x7f, 0x7f, 0x80, 0, 0, 0, 0x7f, 0x7f, 0x20, 0x10, 0x08, 0x04, 0x60, 0x40));
    // qk qk qk qk | j  0  0  0 | j  i  ql qm | qn qo d  c
    vb = _mm_shuffle_epi8(vx, _mm_setr_epi8(1, 1, 1, 1, 0, -1, -1, -1, 0, 0, 1, 1, 1, 1, 0, 0));
    vb = _mm_and_si128(vb, _mm_setr_epi8(0x7f, 0x7f, 0x7f, 0x7f, 0x80, 0, 0, 0, 0x80, 0x40, 0x7e, 0x7c, 0x78, 0x70, 0x02, 0x01));
    __m128i prod = _mm_maddubs_epi16(vb, va);
    // upper square needs to be scaled by 0x10000=4*0x4000
    // j-square needs to be scaled by -1, since one factor is considered as -0x80 by pmaddubsw
    // the remaining triangle gets scaled by 0x100 and added twice, i.e., scaled by 0x200
    __m128i hsum = _mm_madd_epi16(prod, _mm_setr_epi16(0x4000, 0x4000, -1, 0, 0x200, 0x200, 0x200, 0x200));
    // standard 4xint32 reduction:
    hsum = _mm_add_epi32(hsum, _mm_srli_epi64(hsum, 32));
    hsum = _mm_add_epi32(hsum, _mm_srli_si128(hsum, 8));
    // extract result and scale to desired precision:
    return (_mm_cvtsi128_si32(hsum)) >> 15;
}
N.B. vb could be permuted from va (before masking), adding one cycle of latency but saving one shuffle mask (in case constant register usage is limited).
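Comparing the two byte-shuffle index vectors, the unmasked vb differs from the unmasked va only in that dwords 2 and 3 are swapped, so the permutation can be done with a pshufd immediate instead of a second pshufb index vector. The sketch below assumes that reading of the note; truncated_squarer_sse_pshufd() is a hypothetical name, and the GCC/Clang target attribute is used to enable SSSE3 without a global compiler flag.

```c
#include <assert.h>
#include <stdint.h>
#include <immintrin.h>

static __attribute__((target("ssse3")))
uint32_t truncated_squarer_sse_pshufd(uint32_t x)
{
    assert(x < 0x20000);
    x >>= 2;                                   /* lower two bits do not contribute */
    __m128i vx = _mm_cvtsi32_si128((int)x);
    /* unmasked va by dword (b0/b1 = low/high byte of x):
       [b1 b1 b1 b1 | b0 0 0 0 | b1 b1 b0 b0 | b0 b0 b1 b1] */
    __m128i t = _mm_shuffle_epi8(vx, _mm_setr_epi8(1, 1, 1, 1, 0, -1, -1, -1,
                                                   1, 1, 0, 0, 0, 0, 1, 1));
    /* swapping dwords 2 and 3 yields the unmasked vb */
    __m128i u = _mm_shuffle_epi32(t, _MM_SHUFFLE(2, 3, 1, 0));
    __m128i va = _mm_and_si128(t, _mm_setr_epi8(0x7f, 0x7f, 0x7f, 0x7f, (char)0x80, 0, 0, 0,
                                                0x7f, 0x7f, 0x20, 0x10, 0x08, 0x04, 0x60, 0x40));
    __m128i vb = _mm_and_si128(u, _mm_setr_epi8(0x7f, 0x7f, 0x7f, 0x7f, (char)0x80, 0, 0, 0,
                                                (char)0x80, 0x40, 0x7e, 0x7c, 0x78, 0x70, 0x02, 0x01));
    __m128i prod = _mm_maddubs_epi16(vb, va);  /* vb unsigned, va signed */
    __m128i hsum = _mm_madd_epi16(prod, _mm_setr_epi16(0x4000, 0x4000, -1, 0,
                                                       0x200, 0x200, 0x200, 0x200));
    hsum = _mm_add_epi32(hsum, _mm_srli_epi64(hsum, 32));
    hsum = _mm_add_epi32(hsum, _mm_srli_si128(hsum, 8));
    return (uint32_t)_mm_cvtsi128_si32(hsum) >> 15;
}
```

Apart from the pshufd replacing the second pshufb, this is the same computation as truncated_squarer_sse().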