
Optimizing find_first_not_of with SSE4.2 or earlier


I am writing a textual packet analyzer for a protocol, and while optimizing it I found that a major bottleneck is the find_first_not_of call.

In essence, I need to check whether a packet is valid, i.e. whether it contains only allowed characters, and I need to do it faster than the standard C++ function.

For instance, if the allowed characters are f, h, o, t and w, in C++ I would just call s.find_first_not_of("fhotw"), but with SSEx I have no idea how to proceed after loading the string into a set of __m128i variables.

The documentation for the _mm_cmpXstrY functions (e.g. _mm_cmpistri) isn't really helping me here. I could first subtract with _mm_sub_epi8, but I don't think that would be a good idea.

Moreover, I am stuck with SSE (any version).


Solution

  • This article by Wojciech Muła describes an SSSE3 algorithm to accept/reject any given byte value. (Contrary to the article, saturated arithmetic should be used to conduct range checks, but we don't have ranges here.)

    SSE4.2 string instructions are often slower** than hand-crafted alternatives. For example, pcmpistri, the fastest of the SSE4.2 string instructions, is 3 uops with a throughput of one per 3 cycles on Skylake. The approach below needs only 1 shuffle (pshufb) and 1 pcmpeqb per 16 bytes of input, plus a SIMD AND and a movemask to combine the results. Add some load and register-copy instructions, and it is still very likely faster than 1 vector per 3 cycles. It doesn't handle short 0-terminated strings as easily, though; SSE4.2 is worth considering if you also need to deal with those, rather than known-size blocks whose length is a multiple of the vector width.
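    For comparison, here is a minimal sketch of the SSE4.2 route (not taken from the article). It uses the explicit-length form _mm_cmpestri with _SIDD_CMP_EQUAL_ANY and _SIDD_MASKED_NEGATIVE_POLARITY, so the returned index is that of the first byte that is not in the set. Explicit lengths are used because packet data may contain 0 bytes, which would end an implicit-length (_mm_cmpistri) comparison early. The function name is hypothetical, the set is assumed to have at most 16 bytes, and __attribute__((target("sse4.2"))) is a GCC/Clang extension (compiling with -msse4.2 works too). Expect it to be limited by pcmpestri's throughput, which is no better than pcmpistri's.

```cpp
#include <nmmintrin.h>  // SSE4.2: _mm_cmpestri and the _SIDD_* mode flags
#include <cstddef>
#include <cstring>

// Hypothetical helper: index of the first byte of s[0..len) that is not in
// set[0..set_len), or len if every byte is allowed. Assumes set_len <= 16.
__attribute__((target("sse4.2")))
std::size_t find_first_not_of_sse42(const char* s, std::size_t len,
                                    const char* set, int set_len) {
    // Copy the set into a zero-padded 16-byte buffer; only set_len bytes count.
    alignas(16) char set_buf[16] = {};
    std::memcpy(set_buf, set, static_cast<std::size_t>(set_len));
    const __m128i needles =
        _mm_load_si128(reinterpret_cast<const __m128i*>(set_buf));

    // "Equal any", inverted only for valid bytes, lowest matching index first.
    constexpr int mode = _SIDD_UBYTE_OPS | _SIDD_CMP_EQUAL_ANY |
                         _SIDD_MASKED_NEGATIVE_POLARITY | _SIDD_LEAST_SIGNIFICANT;

    std::size_t i = 0;
    for (; i + 16 <= len; i += 16) {
        const __m128i chunk =
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(s + i));
        const int idx = _mm_cmpestri(needles, set_len, chunk, 16, mode);
        if (idx < 16) return i + static_cast<std::size_t>(idx);  // 16 = none
    }
    if (i < len) {
        // Copy the tail so we never read past the end of the buffer.
        alignas(16) char tail[16] = {};
        const int rem = static_cast<int>(len - i);
        std::memcpy(tail, s + i, static_cast<std::size_t>(rem));
        const __m128i chunk =
            _mm_load_si128(reinterpret_cast<const __m128i*>(tail));
        const int idx = _mm_cmpestri(needles, set_len, chunk, rem, mode);
        if (idx < rem) return i + static_cast<std::size_t>(idx);
    }
    return len;
}
```

    Unlike std::string::find_first_not_of, this sketch returns len rather than npos when every byte is allowed.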

    For "fhotw" specifically, try:

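    #include <stdint.h>     // uint8_t
    #include <tmmintrin.h>  // SSSE3: _mm_shuffle_epi8 (pshufb)
    
    // Returns true if all 64 bytes at src are in the set {f, h, o, t, w}.
    bool is_valid_64bytes(const uint8_t* src) {
        // Slot i holds the allowed byte whose low nibble is i. '_' (0x5F) marks
        // unused slots; it can never match, since it would index slot 15 ('o').
        const __m128i tab = _mm_set_epi8('o','_','_','_','_','_','_','h',
                                         'w','f','_','t','_','_','_','_');
        
        __m128i src0 = _mm_loadu_si128((const __m128i*)&src[0]);
        __m128i src1 = _mm_loadu_si128((const __m128i*)&src[16]);
        __m128i src2 = _mm_loadu_si128((const __m128i*)&src[32]);
        __m128i src3 = _mm_loadu_si128((const __m128i*)&src[48]);
        __m128i acc;
    
        // pshufb looks up tab[byte & 0x0F] for each byte (0 if bit 7 is set);
        // a byte is valid iff the looked-up value equals the byte itself.
        acc = _mm_cmpeq_epi8(_mm_shuffle_epi8(tab, src0), src0);
        acc = _mm_and_si128(acc, _mm_cmpeq_epi8(_mm_shuffle_epi8(tab, src1), src1));
        acc = _mm_and_si128(acc, _mm_cmpeq_epi8(_mm_shuffle_epi8(tab, src2), src2));
        acc = _mm_and_si128(acc, _mm_cmpeq_epi8(_mm_shuffle_epi8(tab, src3), src3));
        // All 16 lanes must have matched in every one of the four vectors.
        return _mm_movemask_epi8(acc) == 0xFFFF;
    }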
    

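    pshufb uses the low 4 bits of each input byte as an index into the table, so each byte is compared against the one allowed character with the same low nibble. For example, 'o' (0x6f) goes in the high byte of the table, so input bytes of the form 0x?f are matched against it; that is why it is the first argument to _mm_set_epi8, which lists elements from high to low. Input bytes with the high bit set make pshufb output 0, so they never match. This simple form only works when no two allowed characters share a low nibble, as is the case for f, h, o, t and w (0x66, 0x68, 0x6f, 0x74, 0x77).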
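    The table above is hard-coded for "fhotw". A minimal sketch that builds it at run time for any set meeting the distinct-low-nibble condition, and turns the check into a find_first_not_of-style search, could look like the following. Both function names are hypothetical. 0x80 is used as the filler for unused slots (an input byte equal to 0x80 has its high bit set, so pshufb yields 0 and it can never match). __builtin_ctz and __attribute__((target("ssse3"))) are GCC/Clang extensions, and the bytes that don't fill a whole vector are checked with scalar code that applies the same rule, to avoid reading past the buffer.

```cpp
#include <tmmintrin.h>  // SSSE3: _mm_shuffle_epi8 (also pulls in SSE2)
#include <cstddef>
#include <cstdint>

// Hypothetical helper: table[i] = the allowed byte whose low nibble is i.
// Fails if a byte is >= 0x80 or two distinct bytes share a low nibble.
bool make_nibble_table(const char* set, std::size_t n, std::uint8_t table[16]) {
    for (int i = 0; i < 16; ++i)
        table[i] = 0x80;  // filler: can never equal an input byte after pshufb
    for (std::size_t k = 0; k < n; ++k) {
        const std::uint8_t c = static_cast<std::uint8_t>(set[k]);
        const int slot = c & 0x0F;
        if (c >= 0x80) return false;             // pshufb can't accept these
        if (table[slot] == c) continue;          // duplicate in the set
        if (table[slot] != 0x80) return false;   // low-nibble collision
        table[slot] = c;
    }
    return true;
}

// Hypothetical helper: index of the first byte not in the set, or len.
__attribute__((target("ssse3")))
std::size_t find_first_not_of_nibble(const char* s, std::size_t len,
                                     const std::uint8_t table[16]) {
    const __m128i tab = _mm_loadu_si128(reinterpret_cast<const __m128i*>(table));
    std::size_t i = 0;
    for (; i + 16 <= len; i += 16) {
        const __m128i v = _mm_loadu_si128(reinterpret_cast<const __m128i*>(s + i));
        const __m128i ok = _mm_cmpeq_epi8(_mm_shuffle_epi8(tab, v), v);
        // Bit j of `bad` is set iff byte i + j is not in the set.
        const unsigned bad = ~static_cast<unsigned>(_mm_movemask_epi8(ok)) & 0xFFFFu;
        if (bad != 0) return i + static_cast<std::size_t>(__builtin_ctz(bad));
    }
    // Scalar tail: same rule as pshufb + pcmpeqb, one byte at a time.
    for (; i < len; ++i) {
        const std::uint8_t c = static_cast<std::uint8_t>(s[i]);
        if (c >= 0x80 || table[c & 0x0F] != c) return i;
    }
    return len;
}
```

    For large packets, the 64-byte unrolling with a single AND-ed accumulator shown above is likely cheaper than testing every vector separately; the early-exit loop here trades some throughput for reporting the exact position, as find_first_not_of does.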

    See the full article for variations on this technique for other special / more general cases.

    **If the search is very simple (doesn't need the functionality of string instructions) or very complex (needs at least two string instructions) then it doesn't make much sense to use the string functions. Also the string instructions don't scale to the 256-bit width of AVX2.