I am trying to implement Introsort in Python.
The pseudo-code given is:
1 n ←|A|
2 if n ≤ 1
3 return
4 elseif d = 0
5 Heap-Sort(A)
6 else
7 p ← Partition(A) // Partitions A and returns pivot position
8 Intro-Sort(A[0:p],d−1)
9 Intro-Sort(A[p+1:n],d−1)
My source code is:
import math
def introSort(a, d):
    """Sort the list ``a`` in place using introsort.

    Quicksort-style partitioning is applied recursively; once the remaining
    depth budget ``d`` reaches 0 the current list is handed to ``heapSort``.

    Parameters:
        a: mutable list of mutually comparable items; modified in place.
        d: remaining recursion depth before falling back to heap sort.

    Returns:
        None.
    """
    n = len(a)
    if n <= 1:
        return
    if d == 0:
        heapSort(a)
        return

    p = partition(a)
    # Slicing produces copies, so the recursive calls sort the copies and
    # leave ``a`` untouched.
    left = a[:p]
    right = a[p + 1:]
    introSort(left, d - 1)
    introSort(right, d - 1)
    # Write the sorted halves back into the caller's list. A plain
    # ``a = left + [a[p]] + right`` would only rebind the local name and the
    # caller would never see the result.
    a[:p] = left
    a[p + 1:] = right
def heapSort(a):
    """Sort the list ``a`` in place into ascending order with heap sort."""
    size = len(a)

    # Phase 1: build a max-heap by sifting down every parent node, starting
    # from the last parent and working back to the root.
    for parent in range(size // 2 - 1, -1, -1):
        heapify(a, size, parent)

    # Phase 2: move the current maximum (the root) to the end of the heap
    # region, shrink the region by one, and restore the heap property.
    for last in range(size - 1, 0, -1):
        swap(a, 0, last)
        heapify(a, last, 0)
def partition(a):
    """Partition ``a`` in place around its last element (Lomuto scheme).

    After the call, every element left of the returned index is ``<=`` the
    pivot and every element right of it is ``>`` the pivot.

    Parameters:
        a: non-empty mutable list of mutually comparable items.

    Returns:
        The final index of the pivot.

    Raises:
        IndexError: if ``a`` is empty.
    """
    last = len(a) - 1
    pivot = a[last]
    boundary = 0  # next slot for an element that is <= pivot
    # Scan every element except the pivot itself. The upper bound must be
    # ``last`` (exclusive); stopping one earlier would leave the element just
    # before the pivot uncompared.
    for j in range(last):
        if a[j] <= pivot:
            a[boundary], a[j] = a[j], a[boundary]
            boundary += 1
    # Drop the pivot between the two regions.
    a[boundary], a[last] = a[last], a[boundary]
    return boundary
def swap(a, i, j):
    """Exchange the elements of ``a`` at positions ``i`` and ``j`` in place."""
    a[i], a[j] = a[j], a[i]
def heapify(a, iEnd, iRoot):
    """Sift ``a[iRoot]`` down until the max-heap property holds below it.

    Only indices strictly below ``iEnd`` are treated as part of the heap.

    Parameters:
        a: mutable list holding the heap; modified in place.
        iEnd: exclusive upper bound of the heap region within ``a``.
        iRoot: index of the node to sift down.
    """
    node = iRoot
    while True:
        left = 2 * node + 1
        right = 2 * node + 2
        if right < iEnd:
            # Two children: stop if the node already dominates both,
            # otherwise descend toward the larger child.
            if a[node] >= a[left] and a[node] >= a[right]:
                return
            child = left if a[left] > a[right] else right
        elif left < iEnd:
            # Only a left child exists.
            if a[node] >= a[left]:
                return
            child = left
        else:
            # Leaf node: nothing to do.
            return
        a[node], a[child] = a[child], a[node]
        node = child
# Demo run: sort a sample list with a recursion-depth budget of 2 and print it.
a = [
    3, 5, 6, 1, 23, 521, 6243, 632, 123,
    53, 62, 421, 15, 672, 7, 435, 21,
]
introSort(a, 2)
print(a)
The result given was wrong:
>python introsort.py
[3, 5, 6, 1, 15, 7, 21, 632, 123, 53, 62, 421, 23, 672, 521, 435, 6243]
It seems that it stopped straight after partition and the sorting on sublists was not working. It was clear that 21 was the pivot and the partition worked perfectly.
Can anybody point out my mistake? Thank you very much!
I eventually fixed this by passing index bounds through the recursion and adding a helper function that runs heapsort on a piece of the list. Instead of modifying the heapsort function itself, the helper sorts a separate copy of the slice and then writes the sorted elements back into the original list. The code is shown below:
import math
def introSort(a, d, start, end):
    """Sort ``a[start:end]`` in place using introsort.

    Parameters:
        a: mutable list of mutually comparable items; modified in place.
        d: remaining recursion depth; at 0 the range is heap-sorted instead
           of being partitioned further.
        start: inclusive lower index of the range to sort.
        end: exclusive upper index of the range to sort.
    """
    # Ranges of zero or one element are already sorted.
    if end - start <= 1:
        return

    # Depth budget exhausted: fall back to heap sort on this range.
    if d == 0:
        introHS(a, start, end)
        return

    pivot_index = partition(a, start, end)
    introSort(a, d - 1, start, pivot_index)
    introSort(a, d - 1, pivot_index + 1, end)
def introHS(a, start, end):
    """Heap-sort the range ``a[start:end]`` and store the result back in ``a``.

    The range is copied out, sorted with ``heapSort``, and then written back
    element by element starting at index ``start``.
    """
    piece = a[start:end]
    heapSort(piece)
    for offset, value in enumerate(piece):
        a[start + offset] = value
def heapSort(a):
    """Sort the list ``a`` in place into ascending order with heap sort."""
    size = len(a)

    # Phase 1: build a max-heap by sifting down every parent node, starting
    # from the last parent and working back to the root.
    for parent in range(size // 2 - 1, -1, -1):
        heapify(a, size, parent)

    # Phase 2: move the current maximum (the root) to the end of the heap
    # region, shrink the region by one, and restore the heap property.
    for last in range(size - 1, 0, -1):
        swap(a, 0, last)
        heapify(a, last, 0)
def partition(a, start, end):
    """Partition ``a[start:end]`` in place around its last element.

    Elements ``<=`` the pivot end up to the left of the returned index and
    elements ``>`` the pivot end up to its right.

    Parameters:
        a: mutable list of mutually comparable items.
        start: inclusive lower index of the range.
        end: exclusive upper index of the range; ``a[end - 1]`` is the pivot.

    Returns:
        The final index of the pivot within ``a``.
    """
    pivot_pos = end - 1
    pivot = a[pivot_pos]
    boundary = start  # next slot for an element that is <= pivot
    for scan in range(start, pivot_pos):
        if a[scan] <= pivot:
            a[boundary], a[scan] = a[scan], a[boundary]
            boundary += 1
    # Place the pivot between the two regions.
    a[boundary], a[pivot_pos] = a[pivot_pos], a[boundary]
    return boundary
def swap(a, i, j):
    """Exchange the elements of ``a`` at positions ``i`` and ``j`` in place."""
    a[i], a[j] = a[j], a[i]
def heapify(a, iEnd, iRoot):
    """Sift ``a[iRoot]`` down until the max-heap property holds below it.

    Only indices strictly below ``iEnd`` are treated as part of the heap.

    Parameters:
        a: mutable list holding the heap; modified in place.
        iEnd: exclusive upper bound of the heap region within ``a``.
        iRoot: index of the node to sift down.
    """
    node = iRoot
    while True:
        left = 2 * node + 1
        right = 2 * node + 2
        if right < iEnd:
            # Two children: stop if the node already dominates both,
            # otherwise descend toward the larger child.
            if a[node] >= a[left] and a[node] >= a[right]:
                return
            child = left if a[left] > a[right] else right
        elif left < iEnd:
            # Only a left child exists.
            if a[node] >= a[left]:
                return
            child = left
        else:
            # Leaf node: nothing to do.
            return
        a[node], a[child] = a[child], a[node]
        node = child
It worked correctly when tested with:
# Test harness: sort the same unsorted sample with every depth budget 0..14.
a = [
    3, 5, 6, 1, 23, 521, 6243, 632, 123, 53,
    62, 421, 15, 672, 7, 435, 21, 123, 41, 52,
    6234, 11, 55, 6345, 324, 58, 46, 2, 123, 152,
    6156, 46, 34, 3426, 5341, 16, 3314, 34, 73416, 345,
]
print("Original:")
print(a)
for i in range(0, 15):
    # Sort a fresh copy each time. Sorting ``a`` itself would leave it sorted
    # after the first pass (d=0), so every later depth would only ever be
    # exercised on already-sorted input.
    trial = list(a)
    introSort(trial, i, 0, len(trial))
    print("d=" + str(i))
    print(trial)
Giving:
>Original:
>[3, 5, 6, 1, 23, 521, 6243, 632, 123, 53, 62, 421, 15, 672, 7, 435, 21, 123, 41, 52, 6234, 11, 55, 6345, 324, 58, 46, 2, 123, 152, 6156, 46, 34, 3426, 5341, 16, 3314, 34, 73416, 345]
>d=0
>[1, 2, 3, 5, 6, 7, 11, 15, 16, 21, 23, 34, 34, 41, 46, 46, 52, 53, 55, 58, 62, 123, 123, 123, 152, 324, 345, 421, 435, 521, 632, 672, 3314, 3426, 5341, 6156, 6234, 6243, 6345, 73416]
>d=1
>[1, 2, 3, 5, 6, 7, 11, 15, 16, 21, 23, 34, 34, 41, 46, 46, 52, 53, 55, 58, 62, 123, 123, 123, 152, 324, 345, 421, 435, 521, 632, 672, 3314, 3426, 5341, 6156, 6234, 6243, 6345, 73416]
>d=2
>[1, 2, 3, 5, 6, 7, 11, 15, 16, 21, 23, 34, 34, 41, 46, 46, 52, 53, 55, 58, 62, 123, 123, 123, 152, 324, 345, 421, 435, 521, 632, 672, 3314, 3426, 5341, 6156, 6234, 6243, 6345, 73416]
>d=3
>[1, 2, 3, 5, 6, 7, 11, 15, 16, 21, 23, 34, 34, 41, 46, 46, 52, 53, 55, 58, 62, 123, 123, 123, 152, 324, 345, 421, 435, 521, 632, 672, 3314, 3426, 5341, 6156, 6234, 6243, 6345, 73416]
>d=4
>[1, 2, 3, 5, 6, 7, 11, 15, 16, 21, 23, 34, 34, 41, 46, 46, 52, 53, 55, 58, 62, 123, 123, 123, 152, 324, 345, 421, 435, 521, 632, 672, 3314, 3426, 5341, 6156, 6234, 6243, 6345, 73416]
>d=5
>[1, 2, 3, 5, 6, 7, 11, 15, 16, 21, 23, 34, 34, 41, 46, 46, 52, 53, 55, 58, 62, 123, 123, 123, 152, 324, 345, 421, 435, 521, 632, 672, 3314, 3426, 5341, 6156, 6234, 6243, 6345, 73416]
>d=6
>[1, 2, 3, 5, 6, 7, 11, 15, 16, 21, 23, 34, 34, 41, 46, 46, 52, 53, 55, 58, 62, 123, 123, 123, 152, 324, 345, 421, 435, 521, 632, 672, 3314, 3426, 5341, 6156, 6234, 6243, 6345, 73416]
>d=7
>[1, 2, 3, 5, 6, 7, 11, 15, 16, 21, 23, 34, 34, 41, 46, 46, 52, 53, 55, 58, 62, 123, 123, 123, 152, 324, 345, 421, 435, 521, 632, 672, 3314, 3426, 5341, 6156, 6234, 6243, 6345, 73416]
>d=8
>[1, 2, 3, 5, 6, 7, 11, 15, 16, 21, 23, 34, 34, 41, 46, 46, 52, 53, 55, 58, 62, 123, 123, 123, 152, 324, 345, 421, 435, 521, 632, 672, 3314, 3426, 5341, 6156, 6234, 6243, 6345, 73416]
>d=9
>[1, 2, 3, 5, 6, 7, 11, 15, 16, 21, 23, 34, 34, 41, 46, 46, 52, 53, 55, 58, 62, 123, 123, 123, 152, 324, 345, 421, 435, 521, 632, 672, 3314, 3426, 5341, 6156, 6234, 6243, 6345, 73416]
>d=10
>[1, 2, 3, 5, 6, 7, 11, 15, 16, 21, 23, 34, 34, 41, 46, 46, 52, 53, 55, 58, 62, 123, 123, 123, 152, 324, 345, 421, 435, 521, 632, 672, 3314, 3426, 5341, 6156, 6234, 6243, 6345, 73416]
>d=11
>[1, 2, 3, 5, 6, 7, 11, 15, 16, 21, 23, 34, 34, 41, 46, 46, 52, 53, 55, 58, 62, 123, 123, 123, 152, 324, 345, 421, 435, 521, 632, 672, 3314, 3426, 5341, 6156, 6234, 6243, 6345, 73416]
>d=12
>[1, 2, 3, 5, 6, 7, 11, 15, 16, 21, 23, 34, 34, 41, 46, 46, 52, 53, 55, 58, 62, 123, 123, 123, 152, 324, 345, 421, 435, 521, 632, 672, 3314, 3426, 5341, 6156, 6234, 6243, 6345, 73416]
>d=13
>[1, 2, 3, 5, 6, 7, 11, 15, 16, 21, 23, 34, 34, 41, 46, 46, 52, 53, 55, 58, 62, 123, 123, 123, 152, 324, 345, 421, 435, 521, 632, 672, 3314, 3426, 5341, 6156, 6234, 6243, 6345, 73416]
>d=14
>[1, 2, 3, 5, 6, 7, 11, 15, 16, 21, 23, 34, 34, 41, 46, 46, 52, 53, 55, 58, 62, 123, 123, 123, 152, 324, 345, 421, 435, 521, 632, 672, 3314, 3426, 5341, 6156, 6234, 6243, 6345, 73416]