python algorithm sorting complexity-theory pypy

Radix sort slower than expected compared to standard sort


I've implemented two versions of radix sort (the version that allows sorting integers whose values go up to n² where n is the size of the list to sort) in Python for benchmarking against the standard sort (Timsort). I'm using PyPy for a fairer comparison.
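To make the variant concrete, here is a minimal sketch of the idea it relies on: every value below n² has exactly two base-n digits, so two bucket passes (low digit, then high digit) are enough. The helper name `two_digits` and the sample values are only illustrative and are not part of the benchmark code.

```python
def two_digits(x, n):
    """Split x (0 <= x < n*n) into its two base-n digits (high, low)."""
    high, low = divmod(x, n)
    assert 0 <= high < n  # holds only because x < n*n
    return high, low

n = 5
values = [24, 7, 13, 0, 19]  # all below n*n = 25
for x in values:
    high, low = two_digits(x, n)
    # Recombining the digits gives back the original value.
    print(x, "->", (high, low), "recombined:", high * n + low)
```

Ordering the values by the pair (high, low) is the same as ordering them by value, which is why a stable pass on the low digit followed by a stable pass on the high digit sorts the list.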

Surprisingly, my radix sort implementation is slower than the standard sort even for larger input sizes, and that holds even for the version that uses a direct-access array instead of a hashmap. I assume I'm missing some micro-optimisation, since O(n) should eventually beat O(n log n). I'm seeking advice on how to achieve better performance. I'm doing this for learning purposes, so I'm not looking for built-ins, libraries, or C-compiled code to call from Python.

Are there micro-optimisations I could apply? Is my code really O(n)?


My code takes up to 10 seconds to run on an AMD Ryzen 9 7950X CPU:

import matplotlib.pyplot as plt
import random
import time
from collections import defaultdict

def radix_sort(arr, size):
    least_sig_digit = defaultdict(list)
    for num in arr:
        q, r = divmod(num, size)
        least_sig_digit[r].append(q)
    highest_sig_digit = defaultdict(list)
    for k in range(size):  # k goes in order of lowest significant digit
        for q in least_sig_digit[k]:
            highest_sig_digit[q].append(q*size+k)
    i: int = 0
    for k in range(size):
        for num in highest_sig_digit[k]:
            arr[i] = num
            i += 1
    return arr

def radix_sort_no_hashmap(arr, size):
    least_sig_digit = [[] for _ in range(size)]
    for num in arr:
        q, r = divmod(num, size)
        least_sig_digit[r].append(q)
    highest_sig_digit = [[] for _ in range(size)]
    for k in range(size):  # k goes in order of lowest significant digit
        for q in least_sig_digit[k]:
            highest_sig_digit[q].append(q*size+k)
    i: int = 0
    for k in range(size):
        for num in highest_sig_digit[k]:
            arr[i] = num
            i += 1
    return arr


def benchmark_sorting_algorithms():
    sizes = [1000, 10000, 100000, 200000, 1000000, 2000000, 3000000, 4000000, 5000000, 6000000, 10000000]
    radix_times = []
    radix_sort_no_hashmap_times = []
    std_sort_times = []

    for size in sizes:
        array = random.sample(range(1, size**2), size)

        new_arr = array.copy()
        start_time = time.time()
        a = radix_sort(new_arr, size)
        radix_times.append(time.time() - start_time)

        new_arr = array.copy()
        start_time = time.time()
        b = radix_sort_no_hashmap(new_arr, size)
        radix_sort_no_hashmap_times.append(time.time() - start_time)

        new_arr = array.copy()
        start_time = time.time()
        c = sorted(new_arr)
        std_sort_times.append(time.time() - start_time)

        for k in range(len(array)):
            assert a[k] == b[k] == c[k]

    return sizes, radix_times, std_sort_times, radix_sort_no_hashmap_times


sizes, radix_times, std_sort_times, radix_sort_no_hashmap_times = benchmark_sorting_algorithms()

plt.figure(figsize=(12, 6))
plt.plot(sizes, radix_times, label='Radix Sort (O(n))')
plt.plot(sizes, std_sort_times, label='Standard Sort (O(nlogn))')
plt.plot(sizes, radix_sort_no_hashmap_times, label='Radix Sort (O(n)) - No Hashmap')
plt.xlabel('Input size (n)')
plt.xscale('log')
plt.ylabel('Time (seconds)')
plt.yscale('log')
plt.title('Radix Sort vs Standard Sort')
plt.legend()
plt.grid(True)
plt.show()

The same question was also posted for C++, following the advice of @user24714692.


Solution

  • Here's a version that seems to be about 1.5 to 2.5 times as fast as yours, both in current CPython and in an older PyPy (I hope you can include it in your benchmark/plot). I imagine the approach might be even more beneficial in C++.

    It avoids creating so many list objects. Instead, it uses linked lists built from arr indices stored in three long lists, with one linked list per remainder: first[r] and last[r] give the two ends of the linked list for remainder r, and next[i] gives the index of the element that follows i in its list. First, Last and Next play the same roles for the quotient q.

    def radix_sorted(arr):
        n = len(arr)
    
        # Sort by x%n
        first = [-1] * n
        last = [-1] * n
        next = [-1] * n
        for i, x in enumerate(arr):
            r = x % n
            if first[r] == -1:
                first[r] = i
            else:
                next[last[r]] = i
            last[r] = i
    
        # Sort by x//n
        First = [-1] * n
        Last = [-1] * n
        Next = [-1] * n
        for r in range(n):
            i = first[r]
            while i != -1:
                x = arr[i]
                q = x // n
                if First[q] == -1:
                    First[q] = i
                else:
                    Next[Last[q]] = i
                Last[q] = i
                i = next[i]
    
        # Output
        out = []
        for q in range(n):
            i = First[q]
            while i != -1:
                out.append(arr[i])
                i = Next[i]
        return out
    
    
    # Demo / testing
    import random
    n = 10**6
    arr = random.sample(range(n**2), n)
    print(sorted(arr) == radix_sorted(arr))
    

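The index-based linked lists can be easier to follow on a tiny input. The following is a minimal sketch of the first pass only (grouping by remainder); the helper names `chains_by_remainder` and `walk` are illustrative, and `nxt` is used in place of `next` purely to avoid shadowing the built-in.

```python
def chains_by_remainder(arr):
    """Group the indices of arr by arr[i] % n using index-based linked lists."""
    n = len(arr)
    first = [-1] * n  # first[r]: index of the first element with remainder r
    last = [-1] * n   # last[r]: index of the last element with remainder r
    nxt = [-1] * n    # nxt[i]: index of the element after i in its chain
    for i, x in enumerate(arr):
        r = x % n
        if first[r] == -1:
            first[r] = i          # start a new chain
        else:
            nxt[last[r]] = i      # append to the end of the existing chain
        last[r] = i
    return first, nxt

def walk(first, nxt, arr, r):
    """Collect the values in the chain for remainder r, in insertion order."""
    out = []
    i = first[r]
    while i != -1:
        out.append(arr[i])
        i = nxt[i]
    return out

arr = [7, 2, 12, 5, 10]  # n = 5, all values below n*n
first, nxt = chains_by_remainder(arr)
for r in range(len(arr)):
    print(r, walk(first, nxt, arr, r))
```

Each chain keeps its elements in the order they were seen, so the pass is stable, and only three fixed-size lists are allocated regardless of how the values are distributed.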

    On the OP's plot: [image]
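For anyone who wants to reproduce a speed ratio like the one quoted above, here is one possible timing sketch. It is not the harness used for the plot: the helper name `best_time`, the choice of `time.perf_counter`, and the best-of-three policy are all assumptions made for illustration.

```python
import random
import time

def best_time(sort_fn, data, repeats=3):
    """Return the fastest of `repeats` runs of sort_fn, each on a fresh copy of data."""
    best = float("inf")
    for _ in range(repeats):
        copy = data.copy()  # copy outside the timed region so in-place sorts stay fair
        start = time.perf_counter()
        sort_fn(copy)
        best = min(best, time.perf_counter() - start)
    return best

n = 10**5
data = random.sample(range(n**2), n)
print("sorted:", best_time(sorted, data))
```

Any function that takes a list, such as `radix_sorted`, can be passed in place of `sorted`, and dividing the two results gives the ratio. Taking the minimum of a few runs reduces noise from JIT warm-up and other processes.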