I need the fastest (i.e. branchless, minimize uops) AVX2 code equivalent to this one:
prevlen = 0
for i = 0..7:
    len = matched_bytes(target[i], src)
    if len > prevlen:
        prevlen = len
        index = i
where target[i] and src are 4-byte strings and matched_bytes returns 0..4 - number of the equal lower bytes:
def matched_bytes(target, src):
    return tzcnt(target ^ src) / 8
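For reference, matched_bytes can be written directly in portable C (a sketch; note that tzcnt(0) must count as 32 so a full match returns 4, which __builtin_ctz alone does not guarantee):

```c
#include <stdint.h>

/* Number of equal low-order bytes of two little-endian 4-byte values:
   the XOR is zero across exactly the matching low bytes, so the count
   of trailing zero bits divided by 8 gives the match length 0..4. */
static unsigned matched_bytes(uint32_t target, uint32_t src) {
    uint32_t diff = target ^ src;
    if (diff == 0)
        return 4;  /* tzcnt(0) == 32 with BMI1; guarded here for portability */
    return (unsigned)__builtin_ctz(diff) / 8;
}
```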
The code below takes 15 instructions. I can live without the best length; the index alone is enough for some of my use cases.
Can it be done in fewer instructions? I care less about latencies or uneven ALU port usage, since it's part of larger code.
byte_eq = pmovmskb( pcmpeqb( broadcast(src), targets))
// bit 4*i of byte_eq1..4 set if byte 1..4 of target[i] is equal
byte_eq1 = byte_eq
byte_eq2 = byte_eq >> 1
byte_eq3 = byte_eq >> 2
byte_eq4 = byte_eq >> 3
// bit 4*i set if at least 1..4 bytes are equal
len1 = byte_eq1 & 0x11111111
len2 = len1 & byte_eq2
len3 = len2 & byte_eq3
len4 = len3 & byte_eq4
// Just one CMOV after the corresponding assignment, interleaved with the previous block
if(len2==0) len2 = len1
if(len3==0) len3 = len2
if(len4==0) len4 = len3
index = lzcnt(len4) / 4
// if len4==0 then no match was found
Here are two strategies:

A. Pack the compare mask down to 16 bits, then use phminposuw.
B. Transpose the bits from pmovmskb such that tzcnt yields the index and length of the best match.

Method B is probably better. However, it requires extra loads from memory for the shuffle control indices.
Both methods use 'trailing bit manipulation' on the comparison mask to ignore bits after a mismatch.
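As a scalar illustration of that trailing-bit manipulation (the names t1mskc and tzmsk are mine, after the TBM/BMI-style mnemonics):

```c
#include <stdint.h>

/* Complement of the trailing one bits: the run of 1s at the bottom becomes 0
   and everything above the first 0 becomes 1, so any stray equal bytes after
   the first mismatch are ignored (the step used per 16-bit lane in method A). */
static uint32_t t1mskc(uint32_t x) { return ~x | (x + 1); }

/* Mask of the trailing zero bits: sets exactly the bits below the lowest 1
   (all ones when x == 0), so only the contiguous match from the bottom counts
   (the step used in method B on the XOR difference). */
static uint32_t tzmsk(uint32_t x) { return ~x & (x - 1); }
```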
#include <smmintrin.h> // SSE4.1 intrinsics
#include <stdint.h>
#include <stdio.h>
void method_A (uint32_t* arr, uint32_t val) {
    const __m128i neg1 = _mm_set1_epi32(-1);
    __m128i search_value = _mm_set1_epi32(val);
    __m128i row0 = _mm_loadu_si128((__m128i*)&arr[0]);
    __m128i row1 = _mm_loadu_si128((__m128i*)&arr[4]);
    __m128i matched_bytes0 = _mm_cmpeq_epi8(row0, search_value);
    __m128i matched_bytes1 = _mm_cmpeq_epi8(row1, search_value);
    __m128i packed = _mm_packs_epi16(matched_bytes0, matched_bytes1);
    __m128i t1mskc = _mm_or_si128(_mm_xor_si128(packed, neg1), _mm_sub_epi16(packed, neg1));
    __m128i best_match = _mm_minpos_epu16(t1mskc);
    uint32_t match_desc = (uint32_t)_mm_cvtsi128_si32(best_match);
    uint32_t match_index = match_desc >> 16;
    uint32_t match_length = __builtin_ctz(match_desc | 0x10000) >> 2;
    printf("index: %u, length: %u\n", match_index, match_length);
}
The 16-bit input element to the packing step is a pair of 0 or -1 compare results. These get interpreted as signed 16-bit integers and saturated to the signed 8-bit range -128 (0x80) to +127 (0x7f).
input vpacksswb result
0xffff 0xff (-1)
0xff00 0x80 (large negative: saturates)
0x00ff 0x7f (large positive: saturates)
0x0000 0x00 (0)
This step preserves the ordering when interpreting the result byte as unsigned.
With further processing, we get the complement of the trailing one bits in each 16-bit lane. This prepares the input for phminposuw such that a 4-byte match maps to the smallest unsigned 16-bit value, while no match gives the highest, with the other three possibilities also in order.
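That ordering can be checked with a scalar model of the pack and complement-of-trailing-ones steps (a sketch; pack_byte and lane_for_length are hypothetical helpers mirroring the saturation table above):

```c
#include <stdint.h>

/* Signed-saturating pack of one 16-bit compare pair to a byte,
   following the vpacksswb saturation table above. */
static uint8_t pack_byte(uint16_t pair) {
    int16_t v = (int16_t)pair;
    if (v > 127)  return 0x7f;   /* large positive: saturates */
    if (v < -128) return 0x80;   /* large negative: saturates */
    return (uint8_t)v;
}

/* 16-bit lane produced for a dword whose low `len` bytes (0..4) compare equal. */
static uint16_t lane_for_length(int len) {
    uint8_t cmp[4];
    for (int i = 0; i < 4; i++)
        cmp[i] = (i < len) ? 0xff : 0x00;
    uint16_t lo = (uint16_t)(cmp[0] | (cmp[1] << 8));
    uint16_t hi = (uint16_t)(cmp[2] | (cmp[3] << 8));
    return (uint16_t)(pack_byte(lo) | (pack_byte(hi) << 8));
}

/* Complement of the trailing one bits, per 16-bit lane. */
static uint16_t t1mskc16(uint16_t x) { return (uint16_t)(~x | (x + 1)); }
```

The five possible keys come out as 0x0000 < 0x8000 < 0xff00 < 0xff80 < 0xffff for lengths 4 down to 0, so phminposuw picks the longest match and ctz(key | 0x10000) >> 2 recovers the length.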
#include <immintrin.h>
#include <stdint.h>
#include <stdio.h>
void method_B (uint32_t* arr, uint32_t val) {
    const __m256i shuf_bytes = _mm256_set_epi8(
        12,8,4,0, 13,9,5,1, 14,10,6,2, 15,11,7,3,
        12,8,4,0, 13,9,5,1, 14,10,6,2, 15,11,7,3
    );
    const __m256i shuf_dwords = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
    __m256i vec = _mm256_loadu_si256((__m256i*)arr);
    __m256i diff = _mm256_xor_si256(vec, _mm256_set1_epi32(val));
    __m256i tzmsk = _mm256_andnot_si256(diff, _mm256_add_epi32(_mm256_set1_epi32(-1), diff));
    __m256i t0 = _mm256_shuffle_epi8(tzmsk, shuf_bytes);
    __m256i t1 = _mm256_permutevar8x32_epi32(t0, shuf_dwords);
    uint32_t mask = (uint32_t)_mm256_movemask_epi8(t1);
    uint32_t n = (uint32_t)_tzcnt_u32(mask);
    size_t len = 4 - (n >> 3);
    size_t idx = n & 7;
    printf("index: %d, length: %d\n", (int)idx, (int)len);
}
The mask bits are shuffled around such that tzcnt gets both the index and length of the best match:
bit_0  = dword_0 : bit_31
bit_1  = dword_1 : bit_31
bit_2  = dword_2 : bit_31
...
bit_7  = dword_7 : bit_31
bit_8  = dword_0 : bit_23
bit_9  = dword_1 : bit_23
...
bit_30 = dword_6 : bit_7
bit_31 = dword_7 : bit_7
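A scalar model of this layout (a sketch; build_mask is a hypothetical helper taking per-dword match lengths) shows how a single tzcnt recovers both values, with ties broken toward the lowest index:

```c
#include <stdint.h>

/* Bit (8*g + i) is set when dword i matches at least (4 - g) leading bytes,
   mirroring the shuffled movemask layout above: the lowest set bit then
   belongs to the longest match and, within a group, to the lowest index. */
static uint32_t build_mask(const int len[8]) {
    uint32_t mask = 0;
    for (int g = 0; g < 4; g++)
        for (int i = 0; i < 8; i++)
            if (len[i] >= 4 - g)
                mask |= 1u << (8 * g + i);
    return mask;
}
```

With n = tzcnt(mask), the index is n & 7 and the length is 4 - (n >> 3), exactly as in method B.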