I'm trying to improve the computational speed of this Huffman decoder. For small input hex strings it's fine, but the time grows considerably with the input size: with a large enough string (example below) the runtime goes up roughly 50x, from about 1 ms to 55 ms+.
import time
from concurrent.futures import ThreadPoolExecutor
from typing import List, Tuple, Optional
import numpy as np
from array import array
import ctypes
from sys import byteorder
class Node:
__slots__ = ['char', 'freq', 'left', 'right']
def __init__(self, char: str, freq: int, left=None, right=None):
self.char = char
self.freq = freq
self.left = left
self.right = right
class HybridLookupTable:
"""Hybrid approach combining direct lookup for short codes and binary search for long codes"""
__slots__ = ['short_table', 'long_codes', 'max_short_bits']
def __init__(self, max_short_bits: int = 8):
self.max_short_bits = max_short_bits
        self.short_table = [None] * (1 << max_short_bits)  # empty slots stay None so lookup() misses fall back to the tree
self.long_codes = {}
def add_code(self, code: str, char: str) -> None:
code_int = int(code, 2)
code_len = len(code)
if code_len <= self.max_short_bits:
# For short codes, use lookup table with limited prefix expansion
prefix_mask = (1 << (self.max_short_bits - code_len)) - 1
base_index = code_int << (self.max_short_bits - code_len)
for i in range(prefix_mask + 1):
self.short_table[base_index | i] = (char, code_len)
else:
# For long codes, store in dictionary
self.long_codes[code_int] = (char, code_len)
def lookup(self, bits: int, length: int) -> Optional[Tuple[str, int]]:
"""Look up a bit pattern and return (character, code length) if found"""
if length <= self.max_short_bits:
return self.short_table[bits & ((1 << self.max_short_bits) - 1)]
# Try matching long codes
for code_bits, (char, code_len) in self.long_codes.items():
if code_len <= length:
mask = (1 << code_len) - 1
if (bits >> (length - code_len)) == (code_bits & mask):
return (char, code_len)
return None
class BitBuffer:
"""Fast bit buffer implementation using ctypes"""
__slots__ = ['buffer', 'bits_in_buffer']
def __init__(self):
self.buffer = ctypes.c_uint64(0)
self.bits_in_buffer = 0
def add_byte(self, byte: int) -> None:
self.buffer.value = (self.buffer.value << 8) | byte
self.bits_in_buffer += 8
def peek_bits(self, num_bits: int) -> int:
return (self.buffer.value >> (self.bits_in_buffer - num_bits)) & ((1 << num_bits) - 1)
def consume_bits(self, num_bits: int) -> None:
self.buffer.value &= (1 << (self.bits_in_buffer - num_bits)) - 1
self.bits_in_buffer -= num_bits
class ChunkDecoder:
"""Decoder for a chunk of compressed data"""
__slots__ = ['lookup_table', 'tree', 'chunk_size']
def __init__(self, lookup_table, tree, chunk_size=1024):
self.lookup_table = lookup_table
self.tree = tree
self.chunk_size = chunk_size
def decode_chunk(self, data: memoryview, start_bit: int, end_bit: int) -> Tuple[List[str], int]:
"""Decode a chunk of bits and return (decoded_chars, bits_consumed)"""
result = []
pos = start_bit
buffer = BitBuffer()
bytes_processed = start_bit >> 3
bit_offset = start_bit & 7
# Pre-fill buffer
for _ in range(8):
if bytes_processed < len(data):
buffer.add_byte(data[bytes_processed])
bytes_processed += 1
# Skip initial bit offset
if bit_offset:
buffer.consume_bits(bit_offset)
while pos < end_bit and buffer.bits_in_buffer >= 8:
# Try lookup table first (optimized for 8-bit codes)
lookup_bits = buffer.peek_bits(8)
char_info = self.lookup_table.lookup(lookup_bits, 8)
if char_info:
char, code_len = char_info
buffer.consume_bits(code_len)
result.append(char)
pos += code_len
else:
# Fall back to tree traversal
node = self.tree
while node.left and node.right and buffer.bits_in_buffer > 0:
bit = buffer.peek_bits(1)
buffer.consume_bits(1)
node = node.right if bit else node.left
pos += 1
if not (node.left or node.right):
result.append(node.char)
# Refill buffer if needed
while buffer.bits_in_buffer <= 56 and bytes_processed < len(data):
buffer.add_byte(data[bytes_processed])
bytes_processed += 1
return result, pos - start_bit
class OptimizedHuffmanDecoder:
def __init__(self, num_threads=4, chunk_size=1024):
        self.tree = None
        self.freqs = {}
        self.codes = {}  # char -> bit-string code, used by encode()
        self.lookup_table = HybridLookupTable()
self.num_threads = num_threads
self.chunk_size = chunk_size
self._setup_lookup_tables()
def _setup_lookup_tables(self):
        # Pre-calculate bit manipulation tables (note: these are currently not used by the decode paths)
self.bit_masks = array('Q', [(1 << i) - 1 for i in range(65)])
self.bit_shifts = array('B', [x & 7 for x in range(8)])
def _build_efficient_tree(self) -> None:
        # Keep the nodes in a list sorted in descending order and pop the two
        # smallest frequencies from its tail on each iteration
        nodes = [(freq, i, Node(char, freq)) for i, (char, freq) in enumerate(self.freqs.items())]
        nodes.sort(reverse=True)
while len(nodes) > 1:
freq1, _, node1 = nodes.pop()
freq2, _, node2 = nodes.pop()
# Create parent node
parent = Node(node1.char + node2.char, freq1 + freq2, node1, node2)
nodes.append((freq1 + freq2, len(nodes), parent))
nodes.sort(reverse=True)
self.tree = nodes[0][2] if nodes else None
self._build_codes(self.tree)
def _build_codes(self, node: Node, code: str = '') -> None:
"""Build lookup table using depth-first traversal"""
if not node:
return
        if not node.left and not node.right:
            if code:  # Never store empty codes
                self.lookup_table.add_code(code, node.char)
                self.codes[node.char] = code  # keep the string form for encode()
            return
if node.left:
self._build_codes(node.left, code + '0')
if node.right:
self._build_codes(node.right, code + '1')
def _parse_header_fast(self, data: memoryview) -> int:
"""Optimized header parsing"""
pos = 12 # Skip first 12 bytes (file_len, always0, chars_count)
chars_count = int.from_bytes(data[8:12], byteorder)
        self.freqs = {}
# Process all characters in a single loop
for _ in range(chars_count):
count = int.from_bytes(data[pos:pos + 4], byteorder)
char = chr(data[pos + 4]) # Faster than decode
self.freqs[char] = count
pos += 8
return pos
def _decode_bits_parallel(self, data: memoryview, total_bits: int) -> str:
"""Parallel decoding using multiple threads"""
chunk_bits = (total_bits + self.num_threads - 1) // self.num_threads
chunks = []
# Create chunks ensuring they align with byte boundaries when possible
        for i in range(0, total_bits, chunk_bits):
            end_bit = min(i + chunk_bits, total_bits)
            # Align the chunk start down to a byte boundary (note: adjacent
            # chunks can overlap by up to 7 bits as a result)
            start_bit = i & ~7
            chunks.append((start_bit, end_bit))
# Create decoders for each thread
decoders = [
ChunkDecoder(self.lookup_table, self.tree, self.chunk_size)
for _ in range(len(chunks))
]
# Process chunks in parallel
with ThreadPoolExecutor(max_workers=self.num_threads) as executor:
futures = [
executor.submit(decoder.decode_chunk, data, start, end)
for decoder, (start, end) in zip(decoders, chunks)
]
# Collect results
results = []
for future in futures:
chunk_result, _ = future.result()
results.extend(chunk_result)
return ''.join(results)
def _decode_bits_optimized(self, data: memoryview, total_bits: int) -> str:
"""Optimized single-threaded decoding for small inputs"""
if total_bits > self.chunk_size:
return self._decode_bits_parallel(data, total_bits)
result = []
buffer = BitBuffer()
pos = 0
bytes_processed = 0
# Pre-fill buffer
while bytes_processed < min(8, len(data)):
buffer.add_byte(data[bytes_processed])
bytes_processed += 1
while pos < total_bits:
# Use lookup table for common patterns
if buffer.bits_in_buffer >= 8:
lookup_bits = buffer.peek_bits(8)
char_info = self.lookup_table.lookup(lookup_bits, 8)
if char_info:
char, code_len = char_info
buffer.consume_bits(code_len)
result.append(char)
pos += code_len
else:
# Tree traversal for uncommon patterns
node = self.tree
while node.left and node.right and buffer.bits_in_buffer > 0:
bit = buffer.peek_bits(1)
buffer.consume_bits(1)
node = node.right if bit else node.left
pos += 1
if not (node.left or node.right):
result.append(node.char)
# Refill buffer
while buffer.bits_in_buffer <= 56 and bytes_processed < len(data):
buffer.add_byte(data[bytes_processed])
bytes_processed += 1
if buffer.bits_in_buffer == 0:
break
return ''.join(result)
def decode_hex(self, hex_string: str) -> str:
# Use numpy for faster hex decoding
clean_hex = hex_string.replace(' ', '')
data = np.frombuffer(bytes.fromhex(clean_hex), dtype=np.uint8)
return self.decode_bytes(data.tobytes())
def decode_bytes(self, data: bytes) -> str:
view = memoryview(data)
pos = self._parse_header_fast(view)
self._build_efficient_tree()
# Get packed data info using numpy for faster parsing
header = np.frombuffer(data[pos:pos + 12], dtype=np.uint32)
packed_bits = int(header[0])
packed_bytes = int(header[1])
pos += 12
# Choose decoding method based on size
if packed_bits > self.chunk_size:
return self._decode_bits_parallel(view[pos:pos + packed_bytes], packed_bits)
else:
return self._decode_bits_optimized(view[pos:pos + packed_bytes], packed_bits)
def encode(self, text: str) -> bytes:
"""Encode text using Huffman coding - for testing purposes"""
# Count frequencies
self.freqs = {}
for char in text:
self.freqs[char] = self.freqs.get(char, 0) + 1
# Build tree and codes
self._build_efficient_tree()
        # Convert text to a bit string using the codes recorded in _build_codes
        bits = ''.join(self.codes[char] for char in text)
# Pack bits into bytes
packed_bytes = []
for i in range(0, len(bits), 8):
byte = 0
for j in range(min(8, len(bits) - i)):
                    if bits[i + j] == '1':
byte |= 1 << (7 - j)
packed_bytes.append(byte)
# Create header
header = bytearray()
header.extend(len(text).to_bytes(4, byteorder))
header.extend(b'\x00' * 4) # always0
header.extend(len(self.freqs).to_bytes(4, byteorder))
# Add frequency table
for char, freq in self.freqs.items():
header.extend(freq.to_bytes(4, byteorder))
header.extend(char.encode('ascii'))
header.extend(b'\x00' * 3) # padding
# Add packed data info
header.extend(len(bits).to_bytes(4, byteorder))
header.extend(len(packed_bytes).to_bytes(4, byteorder))
header.extend(b'\x00' * 4) # unpacked_bytes
# Combine header and packed data
return bytes(header + bytes(packed_bytes))
if __name__ == '__main__':
# Create decoder with custom settings
decoder = OptimizedHuffmanDecoder(
num_threads=4, # Number of threads for parallel processing
chunk_size=1024 # Minimum size for parallel processing
)
test_hex = 'A7 64 00 00 00 00 00 00 0C 00 00 00 38 25 00 00 2D 00 00 00 08 69 00 00 30 00 00 00 2E 13 00 00 31 00 00 00 D4 13 00 00 32 00 00 00 0F 0D 00 00 33 00 00 00 78 08 00 00 34 00 00 00 A4 0A 00 00 35 00 00 00 63 0E 00 00 36 00 00 00 AC 09 00 00 37 00 00 00 D0 07 00 00 38 00 00 00 4D 09 00 00 39 00 00 00 68 0C 00 00 7C 00 00 00 73 21 03 00 2F 64 00 00 01 0B 01 00 C9 63 2A C7 21 77 40 77 25 8D AB E9 E5 E7 80 77'
start_time = time.perf_counter()
# Decode data
result = decoder.decode_hex(test_hex)
execution_time_ms = (time.perf_counter() - start_time) * 1000 # Convert to milliseconds
print(f"\nTotal execution time: {execution_time_ms:.2f} milliseconds")
print(result)
Expected output:
Total execution time: 1.04 milliseconds
19101-0-418-220000000|19102-0-371-530000000
But if you try it with a bigger string it gets extremely slow. I tried Cythonizing it, but that didn't improve things at all. Does anyone have an idea of what I might be doing wrong? With this second input hex it takes 55 ms:
bigger hex input example text
I'm wondering if I'm doing anything badly wrong and whether there's any way of speeding the process up. I've spent hours trying everything that came to mind and I'm not sure how to improve it further.
I'd like to improve the performance
For virtually every performance question, the answer is:
If you haven't profiled, you can't be sure where the slowness is coming from. If you don't know what's slow, you can only speculate about how to make it faster.
Python has profiling tools built in. Give them a try, and/or use timeit for micro-benchmarks.
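For example, here is a minimal profiling sketch using the standard cProfile, pstats, and timeit modules (big_hex is a hypothetical placeholder for the larger hex input, which isn't reproduced here):

import cProfile
import pstats
import timeit

decoder = OptimizedHuffmanDecoder(num_threads=4, chunk_size=1024)
big_hex = '...'  # placeholder: paste the larger hex input here

# Profile one decode call and list the functions with the highest cumulative time
profiler = cProfile.Profile()
profiler.enable()
decoder.decode_hex(big_hex)
profiler.disable()
pstats.Stats(profiler).sort_stats('cumulative').print_stats(15)

# Micro-benchmark the same call with timeit (average over 10 runs)
print(timeit.timeit(lambda: decoder.decode_hex(big_hex), number=10) / 10)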
Also look at moving things you don't want to profile (e.g. the hex string conversion) outside the part you're timing.
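For instance, a sketch that does the hex-to-bytes conversion up front and times only the decode itself (again assuming big_hex holds the larger input):

import time

decoder = OptimizedHuffmanDecoder(num_threads=4, chunk_size=1024)
raw = bytes.fromhex(big_hex.replace(' ', ''))  # conversion done outside the timed region

start = time.perf_counter()
result = decoder.decode_bytes(raw)
elapsed_ms = (time.perf_counter() - start) * 1000
print(f"decode_bytes only: {elapsed_ms:.2f} ms")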
Finally, you may well be able to get much better performance in C, C++, Rust, or some other compiled language, but you'll need to learn how to profile those too to get the best out of them.