int SWAR(unsigned int i)
{
    i = i - ((i >> 1) & 0x55555555);
    i = (i & 0x33333333) + ((i >> 2) & 0x33333333);
    return (((i + (i >> 4)) & 0x0F0F0F0F) * 0x01010101) >> 24;
}
I came across this code, which counts the number of bits equal to 1 in a 32-bit integer, and I noticed that it performs better than __builtin_popcount, but I can't understand how it works.
Can someone give a detailed explanation of how this code works?
OK, let's go through the code line by line:
i = i - ((i >> 1) & 0x55555555);
First of all, the significance of the constant 0x55555555 is that, written in Java / GCC style binary literal notation,
0x55555555 = 0b01010101010101010101010101010101
That is, all its odd-numbered bits (counting the lowest bit as bit 1 = odd) are 1, and all the even-numbered bits are 0.
The expression ((i >> 1) & 0x55555555) thus shifts the bits of i right by one, and then sets all the even-numbered bits to zero. (Equivalently, we could've first set all the odd-numbered bits of i to zero with & 0xAAAAAAAA and then shifted the result right by one bit.) For convenience, let's call this intermediate value j.
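The equivalence mentioned in the parenthesis can be spot-checked with a small sketch. The function names here are made up for illustration, and the code assumes a 32-bit value held in a uint32_t:

```c
#include <assert.h>
#include <stdint.h>

/* j as in the original line: shift right first, then keep the odd-numbered bits. */
uint32_t j_shift_then_mask(uint32_t i)
{
    return (i >> 1) & 0x55555555u;
}

/* Alternative order: keep only the even-numbered bits first, then shift right. */
uint32_t j_mask_then_shift(uint32_t i)
{
    return (i & 0xAAAAAAAAu) >> 1;
}

/* Spot-check that both orders agree on a few sample inputs. */
void check_j_equivalence(void)
{
    const uint32_t samples[] = { 0u, 1u, 2u, 0xB4u, 0xDEADBEEFu, 0xFFFFFFFFu };
    for (unsigned n = 0; n < sizeof samples / sizeof samples[0]; n++)
        assert(j_shift_then_mask(samples[n]) == j_mask_then_shift(samples[n]));
}
```

Calling check_j_equivalence() from a main function should run silently if the two forms agree on the sampled inputs.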
What happens when we subtract this j from the original i? Well, let's see what would happen if i had only two bits:
    i           j         i - j
----------------------------------
0 = 0b00    0 = 0b00    0 = 0b00
1 = 0b01    0 = 0b00    1 = 0b01
2 = 0b10    1 = 0b01    1 = 0b01
3 = 0b11    1 = 0b01    2 = 0b10
Hey! We've managed to count the bits of our two-bit number!
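If you'd rather not take the table on faith, here is a minimal sketch that recomputes it. The helper names are hypothetical, and the mask is simply 0x55555555 cut down to two bits:

```c
#include <assert.h>

/* The first SWAR step restricted to a two-bit value (0..3). */
unsigned two_bit_count(unsigned i)
{
    unsigned j = (i >> 1) & 0x1u;  /* 0x55555555 truncated to two bits is 0b01 */
    return i - j;
}

/* Reference implementation: count the set bits one at a time. */
unsigned naive_count(unsigned i)
{
    unsigned n = 0;
    while (i) {
        n += i & 1u;
        i >>= 1;
    }
    return n;
}

/* Compare both functions on all four two-bit inputs. */
void check_two_bit_table(void)
{
    for (unsigned i = 0; i < 4; i++)
        assert(two_bit_count(i) == naive_count(i));
}
```

The check covers every possible two-bit input, so for this step it is exhaustive rather than a sample.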
OK, but what if i has more than two bits? In fact, it's pretty easy to check that the lowest two bits of i - j will still be given by the table above, and so will the third and fourth bits, and the fifth and sixth bits, and so on. In particular:
- despite the >> 1, the lowest two bits of i - j are not affected by the third or higher bits of i, since they'll be masked out of j by the & 0x55555555; and
- since the lowest two bits of j can never have a greater numerical value than those of i, the subtraction will never borrow from the third bit of i: thus, the lowest two bits of i also cannot affect the third or higher bits of i - j.
In fact, by repeating the same argument, we can see that the calculation on this line, in effect, applies the table above to each of the 16 two-bit blocks in i in parallel. That is, after executing this line, the lowest two bits of the new value of i will now contain the number of bits set among the corresponding bits in the original value of i, and so will the next two bits, and so on.
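To see the parallel two-bit counts on a concrete value, here is a small sketch. The names pair_counts and pair_field are made up for this example, and 0xB4 = 0b10110100 is just an arbitrary test input:

```c
#include <assert.h>
#include <stdint.h>

/* First line of the function: each two-bit field ends up holding its own bit count. */
uint32_t pair_counts(uint32_t i)
{
    return i - ((i >> 1) & 0x55555555u);
}

/* Read the n-th two-bit field of v (n = 0 is the lowest field). */
unsigned pair_field(uint32_t v, unsigned n)
{
    return (v >> (2 * n)) & 0x3u;
}

/* 0xB4 = 0b10110100 has the two-bit blocks 10, 11, 01, 00 (highest first). */
void check_pair_counts(void)
{
    uint32_t c = pair_counts(0xB4u);
    assert(pair_field(c, 0) == 0);  /* block 00 has no bits set  */
    assert(pair_field(c, 1) == 1);  /* block 01 has one bit set  */
    assert(pair_field(c, 2) == 2);  /* block 11 has two bits set */
    assert(pair_field(c, 3) == 1);  /* block 10 has one bit set  */
}
```

For this input the whole intermediate value is 0x64 = 0b01100100, i.e. the counts 1, 2, 1, 0 packed side by side.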
i = (i & 0x33333333) + ((i >> 2) & 0x33333333);
Compared to the first line, this one's quite simple. First, note that
0x33333333 = 0b00110011001100110011001100110011
Thus, i & 0x33333333 takes the two-bit counts calculated above and throws away every second one of them, while (i >> 2) & 0x33333333 does the same after shifting i right by two bits. Then we add the results together.
Thus, in effect, what this line does is take the bitcounts of the lowest two and the second-lowest two bits of the original input, computed on the previous line, and add them together to give the bitcount of the lowest four bits of the input. And, again, it does this in parallel for all the 8 four-bit blocks (= hex digits) of the input.
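As a rough illustration of this step in isolation, the sketch below takes a value that already holds two-bit counts and merges them into four-bit counts. The function name is hypothetical, and the input 0x64 is what the first line produces for the arbitrary example value 0xB4:

```c
#include <assert.h>
#include <stdint.h>

/* Second line of the function: add each pair of adjacent two-bit counts
   into a single four-bit count. */
uint32_t nibble_counts(uint32_t pairs)
{
    return (pairs & 0x33333333u) + ((pairs >> 2) & 0x33333333u);
}

/* 0x64 holds the two-bit counts 1, 2, 1, 0 (highest first) for the input 0xB4. */
void check_nibble_counts(void)
{
    uint32_t c = nibble_counts(0x64u);
    assert((c & 0xFu) == 1);         /* low hex digit 0x4 = 0b0100 has 1 bit set   */
    assert(((c >> 4) & 0xFu) == 3);  /* high hex digit 0xB = 0b1011 has 3 bits set */
}
```

Here both operands are masked before the addition, so each four-bit field receives only its own two counts and nothing from its neighbours.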
return (((i + (i >> 4)) & 0x0F0F0F0F) * 0x01010101) >> 24;
OK, what's going on here?
Well, first of all, (i + (i >> 4)) & 0x0F0F0F0F does exactly the same as the previous line, except it adds the adjacent four-bit bitcounts together to give the bitcounts of each eight-bit block (i.e. byte) of the input. (Here, unlike on the previous line, we can get away with moving the & outside the addition, since we know that the eight-bit bitcount can never exceed 8, and therefore will fit inside four bits without overflowing.)
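Here is a minimal sketch of just this sub-expression, assuming the input already holds four-bit counts (each at most 4). The function name byte_counts is invented for the example:

```c
#include <assert.h>
#include <stdint.h>

/* Add adjacent four-bit counts; the single mask afterwards clears the
   leftover upper half of every byte. */
uint32_t byte_counts(uint32_t nibbles)
{
    return (nibbles + (nibbles >> 4)) & 0x0F0F0F0Fu;
}

void check_byte_counts(void)
{
    /* 0x31: four-bit counts 3 and 1 (from the byte 0xB4) add up to 4. */
    assert(byte_counts(0x31u) == 0x04u);
    /* Worst case: every four-bit count is 4, so every byte count is 8. */
    assert(byte_counts(0x44444444u) == 0x08080808u);
}
```

The worst case shows why one mask is enough: 4 + 4 = 8 still fits in the low four bits of each byte, so no sum spills into the next byte before the mask is applied.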
Now we have a 32-bit number consisting of four 8-bit bytes, each byte holding the number of 1-bits in that byte of the original input. (Let's call these bytes A, B, C and D.) So what happens when we multiply this value (let's call it k) by 0x01010101?
Well, since 0x01010101 = (1 << 24) + (1 << 16) + (1 << 8) + 1, we have:
k * 0x01010101 = (k << 24) + (k << 16) + (k << 8) + k
Thus, the highest byte of the result ends up being the sum of:
k
term, plusk << 8
term, plusk << 16
term, plusk << 24
term.(In general, there could also be carries from lower bytes, but since we know the value of each byte is at most 8, we know the addition will never overflow and create a carry.)
That is, the highest byte of k * 0x01010101 ends up being the sum of the bitcounts of all the bytes of the input, i.e. the total bitcount of the 32-bit input number. The final >> 24 then simply shifts this value down from the highest byte to the lowest.
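The multiplication trick can be checked against a plain byte-by-byte sum. This is only a sketch: the function names are made up, and it assumes k is a uint32_t whose bytes each hold a value of at most 8, as they do at this point in the function:

```c
#include <assert.h>
#include <stdint.h>

/* What the last line does: multiply, then take the highest byte. */
uint32_t sum_bytes_by_multiply(uint32_t k)
{
    return (k * 0x01010101u) >> 24;
}

/* The same product written out as shifted copies of k. */
uint32_t sum_bytes_by_shifts(uint32_t k)
{
    return ((k << 24) + (k << 16) + (k << 8) + k) >> 24;
}

/* Reference: extract the four bytes and add them one by one. */
uint32_t sum_bytes_directly(uint32_t k)
{
    return (k >> 24) + ((k >> 16) & 0xFFu) + ((k >> 8) & 0xFFu) + (k & 0xFFu);
}

void check_byte_sum(void)
{
    const uint32_t samples[] = { 0u, 0x00000004u, 0x01020304u, 0x08080808u };
    for (unsigned n = 0; n < sizeof samples / sizeof samples[0]; n++) {
        assert(sum_bytes_by_multiply(samples[n]) == sum_bytes_directly(samples[n]));
        assert(sum_bytes_by_shifts(samples[n]) == sum_bytes_directly(samples[n]));
    }
}
```

For k = 0x08080808, the largest value that can occur here, all three functions give 32, which is the bit count of 0xFFFFFFFF.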
P.S. This code could easily be extended to 64-bit integers by using a 64-bit type, widening the three masks to 64 bits (0x5555555555555555, 0x3333333333333333 and 0x0F0F0F0F0F0F0F0F), and changing the 0x01010101 to 0x0101010101010101 and the >> 24 to >> 56. Indeed, the same method would even work for 128-bit integers; 256 bits would require adding one extra shift / add / mask step, however, since the number 256 no longer quite fits into an 8-bit byte.
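For reference, a 64-bit variant might look something like the sketch below. The name popcount64_swar is made up, the code assumes uint64_t from <stdint.h> is available, and note that the three masks have to be widened to 64 bits along with the multiplier and the final shift:

```c
#include <assert.h>
#include <stdint.h>

/* Same three steps as the 32-bit version, with every constant widened to 64 bits. */
unsigned popcount64_swar(uint64_t i)
{
    i = i - ((i >> 1) & UINT64_C(0x5555555555555555));
    i = (i & UINT64_C(0x3333333333333333)) + ((i >> 2) & UINT64_C(0x3333333333333333));
    i = (i + (i >> 4)) & UINT64_C(0x0F0F0F0F0F0F0F0F);
    /* Eight byte counts of at most 8 each sum to at most 64, which fits in one byte. */
    return (unsigned)((i * UINT64_C(0x0101010101010101)) >> 56);
}

void check_popcount64(void)
{
    assert(popcount64_swar(0) == 0);
    assert(popcount64_swar(UINT64_C(0xB4)) == 4);
    assert(popcount64_swar(UINT64_C(0x8000000000000001)) == 2);
    assert(popcount64_swar(UINT64_MAX) == 64);
}
```

The all-ones input is the worst case for the final multiplication: the top byte ends up as 8 * 8 = 64, and every lower byte of the product is smaller still, so no carries occur.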