I'm trying to implement an extended Hamming code encoder after viewing 3Blue1Brown's excellent videos on the subject, and I can't seem to figure out what I'm doing wrong. I have the following code:
import math
def get_bitstring(n, bit_length) -> str:
    """Render ``n`` in binary, zero-padded on the left to ``bit_length`` digits.

    Values that need more than ``bit_length`` digits are not truncated.
    """
    return f"{n:0{bit_length}b}"
def get_bit(n: int, bit_index: int) -> int:
    """Return the bit of ``n`` at ``bit_index`` (0 is the least significant bit)."""
    shifted = n >> bit_index
    return shifted & 1
def set_bit(n: int, bit_index: int, value: int) -> int:
    """Return ``n`` with the bit at ``bit_index`` forced to 1 or 0.

    The bit becomes 1 when ``value`` is truthy and 0 otherwise; every other
    bit of ``n`` is left as it was.
    """
    selector = 1 << bit_index
    return (n | selector) if value else (n & ~selector)
def get_on_bits(bits, bit_length):
    """List the positions of the truthy entries of ``bits``.

    Positions are counted from the right-hand end: entry ``i`` of ``bits``
    is reported as ``bit_length - i - 1``, so the last of ``bit_length``
    entries is position 0 (where the global parity bit lives).
    """
    positions = []
    for offset, flag in enumerate(bits):
        if flag:
            positions.append(bit_length - offset - 1)
    return positions
class HammingCode:
def __init__(self, bits_per_block: int) -> None:
self.raw_message: int = 0
self.encoded_message: int = 0
self.bits_per_block: int = bits_per_block
self.parity_bits_per_block: int = int(math.log2(self.bits_per_block)) + 1
self.data_bits_per_block: int = self.bits_per_block - self.parity_bits_per_block
def encode(self, message: bytes | bytearray) -> int:
self.raw_message = int.from_bytes(message, 'big')
self.copy_bits()
self.compute_parity()
return self.encoded_message
def copy_bits(self) -> int:
self.encoded_message = 0
bits_used = 0
# Copies bits into encoded message
for i in range(self.bits_per_block):
if not self.is_parity_bit(i):
bit_value = get_bit(self.raw_message, bits_used)
self.encoded_message = set_bit(self.encoded_message, i, bit_value)
bits_used += 1
assert(bits_used == self.data_bits_per_block)
return self.encoded_message
def compute_parity(self):
from functools import reduce
bits = [int(bit) for bit in self.get_bitstring()]
parity = reduce(lambda x, y: x ^ y, get_on_bits(bits, self.bits_per_block))
for i in range(1, self.parity_bits_per_block):
parity_index = 1 << i
if parity & parity_index != 0:
bit_value = 0 if get_bit(self.encoded_message, parity_index) else 1
self.encoded_message = set_bit(self.encoded_message, i + 1, bit_value)
bits = [int(bit) for bit in self.get_bitstring()]
parity = len(get_on_bits(bits, self.bits_per_block))
self.encoded_message = set_bit(self.encoded_message, 0, 1 if parity & 1 == 1 else 0)
return self.encoded_message
def is_parity_bit(self, bit_index: int) -> int:
assert(bit_index < self.bits_per_block)
return (bit_index & (bit_index - 1)) == 0
def get_bitstring(self) -> str:
return get_bitstring(self.encoded_message, self.bits_per_block)
# Demo: encode the message 0x38c into a 16-bit extended Hamming block.
message = b'\x03\x8c'
# The byte order is passed explicitly: it only became optional in Python 3.11,
# and 'big' is what HammingCode.encode uses for the same bytes.
message_value = int.from_bytes(message, 'big')
print(f"Running with {hex(message_value)} {get_bitstring(message_value, 16)}")
h = HammingCode(16)
n = h.encode(message)
print(f"Encoded message: {get_bitstring(message_value, h.data_bits_per_block)} => {get_bitstring(n, h.bits_per_block)}")
print(f"Encoded block in hex is {hex(n)}")
Any tips or ideas on what I'm doing wrong are much appreciated!
I get the correct output for some values, but not others. For example, I get the expected 0xf if I give 0x1 as input, and I get 0x801f with 0x400 as input. If I give 0x38c as input, I should get 0x71d4 (I've found this by doing it on paper), but I instead get 0x70ce.
I've tested and retested each function on its own, such as `copy_bits`, which inserts a 0 at each index that is a power of 2, and the functions above the `HammingCode` class, and they all return what I expect. With the 0x38c example, `copy_bits` returns 0b0111000011000000, which is what I want (it has the 5 parity bits inserted in 0b01110001100).
I believe the problem is in my `compute_parity` method, but I can't for the life of me figure it out. The logic seems to hold. After xor-ing the indices of the on bits, I want to set the parity bits to ensure that if I call `reduce(...)` again, it would return 0. For each bit that's set in `parity`, I want to flip that bit in the encoded message. After which, I xor all the bits and set the 0th bit to ensure even parity.
The problem was indeed in `compute_parity`. I was looping over all the bits again, rather than just setting the parity bits. Even my check to see if I should set the bit was wrong: I should have been checking `parity_index` against `i`, not the `parity` value returned by `reduce`. Here is the corrected code:
def compute_parity(self) -> int:
    """Fill in the check bits and the overall parity bit of the block.

    Reads ``self.encoded_message``, ``self.bits_per_block`` and
    ``self.parity_bits_per_block``. The check bits (at positions 1, 2, 4, ...)
    are chosen so that XOR-ing the positions of all set bits yields 0, and
    bit 0 is chosen so that the whole block has an even number of set bits.

    Everything is derived from the data bits alone, so calling this on a
    block that already carries parity bits returns the same block instead of
    wiping them. Stores the result in ``self.encoded_message`` and returns it.
    """
    check_positions = [1 << i for i in range(self.parity_bits_per_block - 1)]

    # Drop any parity bits already present so stale values cannot leak into
    # the computation.
    block = self.encoded_message & ~1
    for position in check_positions:
        block &= ~(1 << position)

    # XOR of the positions of the set data bits; 0 when no data bit is set.
    syndrome = 0
    for position in range(self.bits_per_block):
        if (block >> position) & 1:
            syndrome ^= position

    # Bit i of the syndrome is exactly the value the check bit at position
    # 2**i must take to cancel it out.
    for i, position in enumerate(check_positions):
        if (syndrome >> i) & 1:
            block |= 1 << position

    # Bit 0 is still clear here; set it when the rest of the block has an odd
    # number of ones.
    block |= bin(block).count("1") & 1

    self.encoded_message = block
    return self.encoded_message
I also had to make sure that the bit value I was passing to `set_bit` was not a string.