pythonpython-3.xkdtree

Construction of KD-Tree


I'm trying to construct a KD-Tree, but I'm getting an error where the root node doesn't have any children which it should have.

I think it's a problem with the recursion, but I can't figure out why.

A minimal reproducible sample.

import numpy as np

class Node():
    def __init__(self, point, min_ax, max_ax, axis, left=None, right=None):
        self.point = point
        self.min_ax = min_ax
        self.max_ax = max_ax
        self.axis = axis

        self.left = None
        self.right = None

def construct_kd_tree(points, axis=0):
    if len(points) == 0:
        return None

    # sort triangles with the first axis of the center
    vals = list(sorted(points, key=lambda x: x[axis]))

    median = len(points) // 2

    left = construct_kd_tree(vals[:median])
    right = construct_kd_tree(vals[median+1:])

    # print(left,right)

    return Node(vals[median],
                vals[0][axis],
                vals[-1][axis],
                axis,
                left=left,
                right=right)


points = np.random.rand(10,3)
node = construct_kd_tree(points)

print(node.left)  # None
print(node.right) # None

Solution

  • See if this solves your problem:

    class Node():
        def __init__(self, point, min_ax, max_ax, axis, left, right):
            self.point = point
            self.min_ax = min_ax
            self.max_ax = max_ax
            self.axis = axis
            self.left = left
            self.right = right
            
    def construct_kd_tree(points, axis=0):
        if len(points) == 0:
            return None
    
        vals = sorted(points, key=lambda x: x[axis])
        median = len(points) // 2
    
        return Node(vals[median], vals[0][axis], vals[-1][axis], axis,
                    left=construct_kd_tree(vals[:median], (axis + 1) % len(points[0])),
                    right=construct_kd_tree(vals[median+1:], (axis + 1) % len(points[0])))
    
    # Example usage
    points = [(1, 2), (5, 3), (8, 1), (3, 6)]
    kd_tree = construct_kd_tree(points)