Python: min heap swap count


Although lots of questions about heap implementation in Python have already been asked and answered, I was unable to find any practical clarification about indexes. So allow me to ask one more heap-related question.

I'm trying to write code that transforms a list of values into a min-heap and records the pair of indexes swapped at each step. Here is what I have so far:

def mins(a, i, res):
    n = len(a)-1
    left = 2 * i + 1
    right = 2 * i + 2
    if not (i >= n//2 and i <= n):
        if (a[i] > a[left] or a[i] > a[right]):

            if a[left] < a[right]:
                res.append([i, left])
                a[i], a[left] = a[left], a[i]
                mins(a, left, res)
            
            else:
                res.append([i, right])
                a[i], a[right] = a[right], a[i]
                mins(a, right, res)

def heapify(a, res):
    n = len(a)
    for i in range(n//2, -1, -1):
        mins(a, i, res)
    return res


a = [7, 6, 5, 4, 3, 2]
res = heapify(a, [])

print(a)
print(res)
  

Expected output:

a = [2, 3, 5, 4, 6, 7]
res = [[2, 5], [1, 4], [0, 2], [2, 5]]

What I get:

a = [3, 4, 5, 7, 6, 2]
res = [[1, 4], [0, 1], [1, 3]]

It's clear that something is wrong with the indexing in the script above. It's probably something very obvious, but I just don't see it. Please help!
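To make the index arithmetic concrete, here is a minimal sketch that lists, for a 6-element list, which children each index actually has (left = 2*i + 1 and right = 2*i + 2, kept only if inside the list) next to what the question's leaf test not (i >= n//2 and i <= n), with n = len(a)-1, decides. The helper names children_of and original_test_visits are hypothetical, introduced only for this illustration.

```python
def children_of(i, size):
    """Return the in-range child indices of node i in an array-backed binary heap."""
    # Hypothetical helper for illustration only.
    return [c for c in (2 * i + 1, 2 * i + 2) if c < size]


def original_test_visits(i, size):
    """Mirror the question's check: with n = size - 1, node i is processed
    only if not (i >= n//2 and i <= n)."""
    n = size - 1
    return not (i >= n // 2 and i <= n)


size = 6  # len([7, 6, 5, 4, 3, 2])
for i in range(size):
    print(i, children_of(i, size), original_test_visits(i, size))
# 0 [1, 2] True
# 1 [3, 4] True
# 2 [5] False   <- has a child, but the test skips it
# 3 [] False
# 4 [] False
# 5 [] False
```

Index 2 has a child (index 5, which holds the smallest value, 2), yet the test never processes it, which is consistent with 2 staying at the end of the reported output.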


Solution

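  • You have some mistakes in your code:

    • The leaf test not (i >= n//2 and i <= n), with n = len(a)-1, treats index n//2 as a leaf even when it has a child. For the example list n is 5, so index 2 is skipped although its left child (index 5) holds the smallest value, 2. That is why 2 never moves.
    • a[right] is read without checking that right < len(a), so a node with only a left child would raise an IndexError once the leaf test is fixed.
    • In heapify, range(n//2, -1, -1) (where n = len(a)) starts at an index that is always a leaf. This is harmless, but the last index that actually has a child is (n - 2)//2.

    To avoid code duplication (the two blocks where you swap), introduce a new variable, child, that is set to either left or right depending on which one has the smaller value.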

    Here is a correction:

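    def mins(a, i, res):
        """Sift a[i] down to restore the min-heap property, recording swaps in res."""
        n = len(a)
        left = 2 * i + 1
        right = 2 * i + 2
        if left >= n:  # i has no children: it is a leaf
            return
        child = left
        if right < n and a[right] < a[left]:  # right child exists and is smaller
            child = right
        if a[child] < a[i]:  # need to swap
            res.append([i, child])
            a[i], a[child] = a[child], a[i]
            mins(a, child, res)  # keep sifting down from the child's new position

    def heapify(a, res):
        n = len(a)
        # (n - 2)//2 is the last index with a child; leaves need no sifting
        for i in range((n - 2)//2, -1, -1):
            mins(a, i, res)
        return res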
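For comparison, here is a sketch of the same bottom-up construction with the sift-down written as a loop instead of recursion, plus a check that replays the recorded swaps on a copy of the input. The names sift_down, build_min_heap, is_min_heap and replay are hypothetical. The swap list should match the recursive correction, since both pick the smaller child and swap only when it is strictly smaller than the parent.

```python
def sift_down(a, i, swaps):
    """Move a[i] down until neither child is smaller, recording each swap as [parent, child]."""
    n = len(a)
    while True:
        left, right = 2 * i + 1, 2 * i + 2
        if left >= n:          # i is a leaf
            return
        child = left
        if right < n and a[right] < a[left]:
            child = right      # pick the smaller child
        if a[child] >= a[i]:   # heap property already holds at i
            return
        swaps.append([i, child])
        a[i], a[child] = a[child], a[i]
        i = child              # continue from the child's position


def build_min_heap(a):
    """Heapify a in place, bottom-up from the last parent; return the list of swaps."""
    swaps = []
    for i in range((len(a) - 2) // 2, -1, -1):
        sift_down(a, i, swaps)
    return swaps


def is_min_heap(a):
    """Check that every parent is <= each of its children."""
    return all(a[(c - 1) // 2] <= a[c] for c in range(1, len(a)))


def replay(values, swaps):
    """Apply recorded swaps to a copy of values and return the result."""
    b = list(values)
    for i, j in swaps:
        b[i], b[j] = b[j], b[i]
    return b


original = [7, 6, 5, 4, 3, 2]
a = list(original)
swaps = build_min_heap(a)
print(a)           # [2, 3, 5, 4, 6, 7]
print(swaps)       # [[2, 5], [1, 4], [0, 2], [2, 5]]
print(len(swaps))  # 4 swaps in total
print(is_min_heap(a), replay(original, swaps) == a)  # True True
```

The resulting heap is [2, 3, 5, 4, 6, 7] rather than a sorted list: a min-heap only requires each parent to be no larger than its children, so [2, 3, 4, 5, 6, 7] would also be a valid heap, but it is not what these four swaps produce. Sift-down descends at most about log2(n) levels, so the recursive version is fine in practice; the loop simply avoids the extra function calls.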