python numpy huffman-code huffman-tree

Huffman implementation — reduce decompression time


I'm trying to improve the computational speed of this Huffman implementation. For small input hex strings it's fine, but as the input string grows the runtime increases considerably — with a large enough string the example below slows down by about 50× (1 ms vs 55 ms+).

import time
from concurrent.futures import ThreadPoolExecutor
from typing import List, Tuple, Optional
import numpy as np
from array import array
import ctypes

# BUGFIX: `byteorder` was previously imported from line_profiler's private
# C-extension module (line_profiler._line_profiler) — a fragile, undocumented
# third-party internal.  The value is simply the host byte order ('little' or
# 'big'), which the standard library exposes directly.
from sys import byteorder

class Node:
    """A Huffman tree node: leaves carry a character, internal nodes link two children."""

    __slots__ = ['char', 'freq', 'left', 'right']

    def __init__(self, char: str, freq: int, left=None, right=None):
        self.char, self.freq = char, freq
        self.left, self.right = left, right




class HybridLookupTable:
    """Hybrid decode table: dense array lookup for short codes, dict scan for long codes.

    Codes of at most ``max_short_bits`` bits are expanded into every slot of a
    dense table indexed by the next ``max_short_bits`` bits of the stream, so
    decoding a short code is a single list indexing operation.  Longer codes
    fall back to a dictionary keyed by the code's integer value.
    """

    __slots__ = ['short_table', 'long_codes', 'max_short_bits', 'codes']

    def __init__(self, max_short_bits: int = 8):
        self.max_short_bits = max_short_bits
        # One (char, code_len) entry per index; (None, 0) marks an unmapped slot.
        self.short_table = [(None, 0)] * (1 << max_short_bits)
        self.long_codes = {}
        # char -> bit-string code, kept so the encoder can look codes up.
        self.codes = {}

    def add_code(self, code: str, char: str) -> None:
        """Register a Huffman code (a bit string such as '101') for *char*."""
        self.codes[char] = code
        code_int = int(code, 2)
        code_len = len(code)

        if code_len <= self.max_short_bits:
            # Expand every possible suffix so that any max_short_bits-wide
            # window starting with this code maps to (char, code_len).
            fill = self.max_short_bits - code_len
            base_index = code_int << fill
            for i in range(1 << fill):
                self.short_table[base_index | i] = (char, code_len)
        else:
            self.long_codes[code_int] = (char, code_len)

    def get_code(self, char: str) -> str:
        """Return the bit-string code for *char* (raises KeyError if unknown).

        BUGFIX: this method was missing although encode() calls it.
        """
        return self.codes[char]

    def lookup(self, bits: int, length: int) -> Optional[Tuple[str, int]]:
        """Look up a bit pattern; return (char, code_len), or None if nothing matches.

        BUGFIX: the previous version returned the sentinel tuple (None, 0)
        for unmapped short patterns.  That tuple is truthy, so callers doing
        ``if char_info:`` consumed zero bits and appended None — a potential
        infinite loop.  Unmatched lookups now return None.
        """
        if length <= self.max_short_bits:
            entry = self.short_table[bits & ((1 << self.max_short_bits) - 1)]
            return entry if entry[0] is not None else None

        # Long codes: linear scan comparing the leading code_len bits.
        # (code_bits already fits in code_len bits, so no extra mask is needed.)
        for code_bits, (char, code_len) in self.long_codes.items():
            if code_len <= length and (bits >> (length - code_len)) == code_bits:
                return (char, code_len)
        return None
class BitBuffer:
    """Fast MSB-first bit buffer backed by a plain Python int.

    Perf fix: a plain int is faster than ctypes.c_uint64 here — every
    ``.value`` access on a ctypes object allocates a fresh Python int anyway,
    so the wrapper only added overhead.  Callers never buffer more than 64
    bits (they refill only while bits_in_buffer <= 56), so values stay small.
    """

    __slots__ = ['buffer', 'bits_in_buffer']

    def __init__(self):
        self.buffer = 0          # pending bits, most-significant bit first
        self.bits_in_buffer = 0  # how many low bits of `buffer` are valid

    def add_byte(self, byte: int) -> None:
        """Append one byte (8 bits) at the least-significant end."""
        self.buffer = (self.buffer << 8) | byte
        self.bits_in_buffer += 8

    def peek_bits(self, num_bits: int) -> int:
        """Return the next num_bits without consuming (requires num_bits <= bits_in_buffer)."""
        return (self.buffer >> (self.bits_in_buffer - num_bits)) & ((1 << num_bits) - 1)

    def consume_bits(self, num_bits: int) -> None:
        """Drop the next num_bits from the front of the buffer."""
        self.buffer &= (1 << (self.bits_in_buffer - num_bits)) - 1
        self.bits_in_buffer -= num_bits


class ChunkDecoder:
    """Decodes one bit range of the packed stream using the shared lookup table.

    NOTE(review): Huffman streams are not self-synchronizing, so a chunk
    whose start_bit does not fall exactly on a code boundary will decode
    garbage.  The caller is responsible for choosing valid start bits —
    TODO confirm the chunking strategy guarantees this.
    """

    __slots__ = ['lookup_table', 'tree', 'chunk_size']

    def __init__(self, lookup_table, tree, chunk_size=1024):
        self.lookup_table = lookup_table
        self.tree = tree
        self.chunk_size = chunk_size

    def decode_chunk(self, data: memoryview, start_bit: int, end_bit: int) -> Tuple[List[str], int]:
        """Decode bits [start_bit, end_bit) and return (decoded_chars, bits_consumed)."""
        result = []
        pos = start_bit
        buffer = BitBuffer()
        bytes_processed = start_bit >> 3
        bit_offset = start_bit & 7

        # Pre-fill the 64-bit buffer with up to 8 bytes.
        for _ in range(8):
            if bytes_processed < len(data):
                buffer.add_byte(data[bytes_processed])
                bytes_processed += 1

        # Discard the bits before start_bit inside the first byte.
        if bit_offset:
            buffer.consume_bits(bit_offset)

        while pos < end_bit and buffer.bits_in_buffer >= 8:
            # Fast path: one table probe resolves any code of <= 8 bits.
            lookup_bits = buffer.peek_bits(8)
            char_info = self.lookup_table.lookup(lookup_bits, 8)

            # BUGFIX: guard against the truthy (None, 0) sentinel the table
            # may return for unmapped patterns — it would consume zero bits
            # and append None forever.
            if char_info and char_info[0] is not None:
                char, code_len = char_info
                buffer.consume_bits(code_len)
                result.append(char)
                pos += code_len
            else:
                # Slow path: walk the tree bit by bit for long codes.
                node = self.tree
                while node.left and node.right and buffer.bits_in_buffer > 0:
                    bit = buffer.peek_bits(1)
                    buffer.consume_bits(1)
                    node = node.right if bit else node.left
                    pos += 1
                if not (node.left or node.right):
                    result.append(node.char)

            # Keep at least 8 bytes buffered while input remains.
            while buffer.bits_in_buffer <= 56 and bytes_processed < len(data):
                buffer.add_byte(data[bytes_processed])
                bytes_processed += 1
        return result, pos - start_bit


class OptimizedHuffmanDecoder:
    """Huffman decoder with a hybrid lookup table and an optional threaded path.

    Payload layout (native byte order, see module-level ``byteorder``):
      bytes 0-3   original text length
      bytes 4-7   always zero
      bytes 8-11  number of distinct characters
      per char:   4-byte frequency, 1-byte ASCII char, 3 bytes padding
      then:       4-byte packed bit count, 4-byte packed byte count, 4 unused
      then:       the packed bit stream.
    """

    def __init__(self, num_threads=4, chunk_size=1024):
        self.tree = None
        self.freqs = {}
        self.lookup_table = HybridLookupTable()
        self.num_threads = num_threads
        self.chunk_size = chunk_size  # bit-count threshold for the parallel path
        self._setup_lookup_tables()

    def _setup_lookup_tables(self):
        # Precomputed bit-twiddling tables.  NOTE(review): nothing below uses
        # them; kept only for attribute compatibility with external callers.
        self.bit_masks = array('Q', [(1 << i) - 1 for i in range(65)])
        self.bit_shifts = array('B', [x & 7 for x in range(8)])

    def _build_efficient_tree(self) -> None:
        """Build the Huffman tree from self.freqs and populate the lookup table.

        Perf fix: the previous version re-sorted the entire node list after
        every merge — O(n^2 log n) overall.  heapq keeps each push/pop at
        O(log n) while using the same (freq, container-size) tie-break keys,
        so the resulting tree shape is unchanged.
        """
        import heapq  # local import keeps this fix self-contained

        heap = [(freq, i, Node(char, freq))
                for i, (char, freq) in enumerate(self.freqs.items())]
        heapq.heapify(heap)

        while len(heap) > 1:
            freq1, _, node1 = heapq.heappop(heap)
            freq2, _, node2 = heapq.heappop(heap)
            parent = Node(node1.char + node2.char, freq1 + freq2, node1, node2)
            # Same tie-break value the original used: container size after pops.
            heapq.heappush(heap, (freq1 + freq2, len(heap), parent))

        self.tree = heap[0][2] if heap else None
        self._build_codes(self.tree)

    def _build_codes(self, node: Node, code: str = '') -> None:
        """Register every leaf's code in the lookup table (depth-first walk)."""
        if not node:
            return
        if not node.left and not node.right:
            if code:  # never store empty codes (single-node tree)
                self.lookup_table.add_code(code, node.char)
            return
        if node.left:
            self._build_codes(node.left, code + '0')
        if node.right:
            self._build_codes(node.right, code + '1')

    def _parse_header_fast(self, data: memoryview) -> int:
        """Parse the frequency-table header; return the offset just past it."""
        chars_count = int.from_bytes(data[8:12], byteorder)
        pos = 12  # skip file_len(4) + always0(4) + chars_count(4)

        self.freqs = {}
        for _ in range(chars_count):
            # Each entry: 4-byte frequency, 1-byte char, 3 bytes padding.
            count = int.from_bytes(data[pos:pos + 4], byteorder)
            char = chr(data[pos + 4])
            self.freqs[char] = count
            pos += 8
        return pos

    def _decode_bits_parallel(self, data: memoryview, total_bits: int) -> str:
        """Decode chunks of the bit stream on a thread pool and join the results.

        NOTE(review): Huffman codes are not self-synchronizing, so a chunk
        starting mid-code decodes garbage; this is only safe if the chunk
        boundaries happen to fall on code boundaries — TODO confirm.  Threads
        also share the GIL, so this pure-Python work gains little parallelism.
        """
        chunk_bits = (total_bits + self.num_threads - 1) // self.num_threads
        chunks = []
        for i in range(0, total_bits, chunk_bits):
            end_bit = min(i + chunk_bits, total_bits)
            if i > 0:
                # Pull the chunk start back to a byte boundary when possible.
                # NOTE(review): this makes chunks overlap their predecessor.
                while (i & 7) != 0 and i > 0:
                    i -= 1
            chunks.append((i, end_bit))

        decoders = [
            ChunkDecoder(self.lookup_table, self.tree, self.chunk_size)
            for _ in range(len(chunks))
        ]

        with ThreadPoolExecutor(max_workers=self.num_threads) as executor:
            futures = [
                executor.submit(decoder.decode_chunk, data, start, end)
                for decoder, (start, end) in zip(decoders, chunks)
            ]
            # Collect in submission order to preserve the output sequence.
            results = []
            for future in futures:
                chunk_result, _ = future.result()
                results.extend(chunk_result)

        return ''.join(results)

    def _decode_bits_optimized(self, data: memoryview, total_bits: int) -> str:
        """Single-threaded decode; delegates large inputs to the parallel path.

        BUGFIX: the previous version only decoded when >= 8 bits were
        buffered, so a final code shorter than 8 bits left the loop spinning
        forever (buffer non-empty, pos < total_bits, no progress).  The tail
        is now handled by direct tree traversal.
        """
        if total_bits > self.chunk_size:
            return self._decode_bits_parallel(data, total_bits)

        result = []
        buffer = BitBuffer()
        pos = 0
        bytes_processed = 0
        data_len = len(data)

        while pos < total_bits:
            # Keep the buffer topped up (at most 64 bits) before decoding.
            while buffer.bits_in_buffer <= 56 and bytes_processed < data_len:
                buffer.add_byte(data[bytes_processed])
                bytes_processed += 1
            if buffer.bits_in_buffer == 0:
                break  # ran out of input

            if buffer.bits_in_buffer >= 8:
                # Fast path: one table probe resolves any code of <= 8 bits.
                char_info = self.lookup_table.lookup(buffer.peek_bits(8), 8)
                # Guard against the truthy (None, 0) sentinel, which would
                # consume zero bits and loop forever.
                if char_info and char_info[0] is not None:
                    char, code_len = char_info
                    buffer.consume_bits(code_len)
                    result.append(char)
                    pos += code_len
                    continue

            # Slow path: tree traversal (long codes, or a tail of < 8 bits).
            node = self.tree
            while node.left and node.right and buffer.bits_in_buffer > 0:
                bit = buffer.peek_bits(1)
                buffer.consume_bits(1)
                node = node.right if bit else node.left
                pos += 1
            if not (node.left or node.right):
                result.append(node.char)

        return ''.join(result)

    def decode_hex(self, hex_string: str) -> str:
        """Decode a space-separated hex dump of a compressed payload.

        bytes.fromhex already runs in C; the previous numpy round-trip
        (frombuffer + tobytes) only added two extra copies of the data.
        """
        return self.decode_bytes(bytes.fromhex(hex_string.replace(' ', '')))

    def decode_bytes(self, data: bytes) -> str:
        """Decode a complete compressed payload (header + packed bit stream)."""
        view = memoryview(data)
        pos = self._parse_header_fast(view)
        self._build_efficient_tree()

        # Packed-data descriptor: bit count, byte count (third word unused).
        packed_bits = int.from_bytes(data[pos:pos + 4], byteorder)
        packed_bytes = int.from_bytes(data[pos + 4:pos + 8], byteorder)
        pos += 12

        # memoryview slice: zero-copy window over the packed payload.
        payload = view[pos:pos + packed_bytes]
        if packed_bits > self.chunk_size:
            return self._decode_bits_parallel(payload, packed_bits)
        return self._decode_bits_optimized(payload, packed_bits)

    def encode(self, text: str) -> bytes:
        """Encode ASCII *text* into the header + packed format decode_bytes reads.

        BUGFIX: the previous version (a) called lookup_table.get_code(),
        which did not exist, and (b) tested code characters for truthiness —
        both '0' and '1' are truthy strings, so every output bit was set.
        Codes are now gathered directly from the tree and packed as ints.
        """
        # Count character frequencies.
        self.freqs = {}
        for char in text:
            self.freqs[char] = self.freqs.get(char, 0) + 1
        self._build_efficient_tree()

        # Gather char -> bit-string codes straight from the tree.
        codes = {}
        stack = [(self.tree, '')] if self.tree else []
        while stack:
            node, prefix = stack.pop()
            if not node.left and not node.right:
                codes[node.char] = prefix or '0'  # single-symbol edge case
                continue
            if node.right:
                stack.append((node.right, prefix + '1'))
            if node.left:
                stack.append((node.left, prefix + '0'))

        # Translate the text into integer bits (0/1).
        bits = []
        for char in text:
            bits.extend(1 if b == '1' else 0 for b in codes[char])

        # Pack bits MSB-first into whole bytes.
        packed = bytearray()
        for i in range(0, len(bits), 8):
            byte = 0
            for j in range(min(8, len(bits) - i)):
                if bits[i + j]:
                    byte |= 1 << (7 - j)
            packed.append(byte)

        # Header: lengths, frequency table, packed-data descriptor.
        header = bytearray()
        header.extend(len(text).to_bytes(4, byteorder))
        header.extend(b'\x00' * 4)  # always0
        header.extend(len(self.freqs).to_bytes(4, byteorder))
        for char, freq in self.freqs.items():
            header.extend(freq.to_bytes(4, byteorder))
            header.extend(char.encode('ascii'))
            header.extend(b'\x00' * 3)  # padding to 8-byte entries
        header.extend(len(bits).to_bytes(4, byteorder))
        header.extend(len(packed).to_bytes(4, byteorder))
        header.extend(b'\x00' * 4)  # unpacked_bytes (unused by the decoder)

        return bytes(header + packed)

if __name__ == '__main__':
    # num_threads: size of the worker pool; chunk_size: bit-count threshold
    # above which decoding switches to the parallel path.
    decoder = OptimizedHuffmanDecoder(num_threads=4, chunk_size=1024)

    test_hex = 'A7 64 00 00 00 00 00 00 0C 00 00 00 38 25 00 00 2D 00 00 00 08 69 00 00 30 00 00 00 2E 13 00 00 31 00 00 00 D4 13 00 00 32 00 00 00 0F 0D 00 00 33 00 00 00 78 08 00 00 34 00 00 00 A4 0A 00 00 35 00 00 00 63 0E 00 00 36 00 00 00 AC 09 00 00 37 00 00 00 D0 07 00 00 38 00 00 00 4D 09 00 00 39 00 00 00 68 0C 00 00 7C 00 00 00 73 21 03 00 2F 64 00 00 01 0B 01 00 C9 63 2A C7 21 77 40 77 25 8D AB E9 E5 E7 80 77'

    # Time only the decode itself (hex parsing happens inside decode_hex).
    t0 = time.perf_counter()
    result = decoder.decode_hex(test_hex)
    execution_time_ms = (time.perf_counter() - t0) * 1000
    print(f"\nTotal execution time: {execution_time_ms:.2f} milliseconds")
    print(result)

expected output: Total execution time: 1.04 milliseconds 19101-0-418-220000000|19102-0-371-530000000

But if you try it with a bigger string it gets extremely slow. I'd like to improve the performance — I tried Cythonizing it, but that didn't improve it at all. Does anyone have an idea of what I might be doing wrong? With the second input hex below it takes 55 ms.

bigger hex input example text

I'm wondering if I'm doing anything badly and whether there's any way of speeding the process up. I tried for hours everything that came to mind, and I'm not sure how to improve it further.


Solution

  • I'd like to improve the performance

    For virtually every performance question, the answer is:

    Profile Your Program

    If you haven't profiled, you can't be sure where the slowness is coming from. If you don't know what's slow, you can only speculate about how to make it faster.

    Python has profiling tools built in. Give them a try, and/or use timeit for micro benchmarks.

    Also look at moving things you don't want to profile (eg. hex string conversion) outside the part you're timing.

    Finally, you may well be able to get much better performance in C, or C++, or Rust or some other compiled language - but you'll need to learn how to profile those too, to get the best out of them.