Tags: python-3.x, matplotlib, scikit-learn, shapely, kdtree

What exactly is stored in the node bounds returned by get_arrays() in scikit-learn's KDTree?


I would like to visualise the boundaries of a KD-tree created by scikit-learn: https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KDTree.html
As far as I can tell, there is no built-in method to plot this.
What I need are the boundaries of the rectangles that a kd-tree computes.
The class has a method get_arrays() that apparently returns the tree data, index, node data and node bounds.
So node data must be the points inside a rectangle (a node), and node bounds, I guess, is its boundary?
I tried to plot the data in node bounds: I reshaped the returned array into a 2D array of 2D points (two columns, one for the x's and one for the y's)
and made a scatter plot of it.
I am not sure what I am seeing; probably points on the boundaries? I also tried to use Shapely polygons to turn these boundaries into rectangles and plot them, but I did not get what I expected, so I am wondering whether I understand correctly what the bounds returned by get_arrays() are.
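One way to see why scattering node bounds as points is confusing is to print the shape and dtype of each array that get_arrays() returns. The sketch below uses synthetic data; make_blobs, random_state=42 and leaf_size=10 are arbitrary choices, not from the question. In the scikit-learn versions I know of, tree_data is the training data in its original row order, index is a permutation of row indices, node_data is a structured array with one record per node (fields idx_start, idx_end, is_leaf, radius) rather than a set of points, and node_bounds is a 3-D array. Exact dtypes may differ between versions.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.neighbors import KDTree

# Toy data: 100 points in 2-D (make_blobs defaults); leaf_size=10 is arbitrary
X, _ = make_blobs(random_state=42)
kdt = KDTree(X, leaf_size=10)
tree_data, index, node_data, node_bounds = kdt.get_arrays()

# Print the shape and dtype of each returned array
for name, arr in [("tree_data", tree_data), ("index", index),
                  ("node_data", node_data), ("node_bounds", node_bounds)]:
    print(f"{name:<12} shape={arr.shape} dtype={arr.dtype}")

# node_data holds per-node bookkeeping (index ranges, leaf flag, radius),
# not point coordinates
print(node_data.dtype.names)
```

The 3-D shape of node_bounds is the key point: flattening it into two columns mixes minima and maxima of different nodes into one cloud of "points".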


Solution

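  • You're right that the fourth array, node_bounds, is the one you're after, and the source code gives a little more detail: its shape is (2, n_nodes, n_features), where the first axis holds the minimum (index 0) and maximum (index 1) of each feature for each node. Note that each node's bounds are the tight bounding box of the points that node contains (the per-feature min and max of those points), not the splitting planes, so the rectangles do not tile the plane, and every parent's rectangle encloses its children's.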
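To check the "tight bounding box" reading on toy data, the sketch below recomputes each node's min and max from the points it owns and compares them with node_bounds. It assumes, based on my reading of scikit-learn's KD-tree source, that the points of node i are tree_data[index[idx_start:idx_end]] using the idx_start/idx_end fields of node_data; node_points and mismatched are hypothetical helper names introduced only for this check.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.neighbors import KDTree

X, _ = make_blobs(random_state=42)
kdt = KDTree(X, leaf_size=10)
tree_data, index, node_data, node_bounds = kdt.get_arrays()
n_nodes = node_data.shape[0]


def node_points(i_node):
    """Hypothetical helper: rows of tree_data owned by node i_node."""
    start, end = node_data["idx_start"][i_node], node_data["idx_end"][i_node]
    return tree_data[index[start:end]]


# Nodes whose stored bounds differ from the min/max of their own points
mismatched = [
    i for i in range(n_nodes)
    if not (np.allclose(node_bounds[0, i], node_points(i).min(axis=0))
            and np.allclose(node_bounds[1, i], node_points(i).max(axis=0)))
]
print(node_bounds.shape, "mismatched nodes:", mismatched)
```

An empty mismatched list means every stored rectangle is exactly the bounding box of that node's points. Node 0 is the root, so its bounds span the whole data set.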

    from sklearn.datasets import make_blobs
    from sklearn.neighbors import KDTree
    from matplotlib.patches import Rectangle
    from matplotlib import pyplot as plt
    import numpy as np
    import pandas as pd
    
    X, y = make_blobs(random_state=42)
    kdt = KDTree(X, leaf_size=10)
    
    tree_data, index, node_data, node_bounds = kdt.get_arrays()
    # node_bounds: (2, n_nodes, n_features) -> (n_nodes, n_features, 2),
    # so rearranged_bounds[i] is [[x_min, x_max], [y_min, y_max]] for node i
    rearranged_bounds = np.transpose(node_bounds, axes=[1, 2, 0])
    df = pd.DataFrame({
        'x_min': rearranged_bounds[:, 0, 0],
        'x_max': rearranged_bounds[:, 0, 1],
        'y_min': rearranged_bounds[:, 1, 0],
        'y_max': rearranged_bounds[:, 1, 1],
    })
    
    fig, ax = plt.subplots()
    ax.scatter(X[:, 0], X[:, 1])
    # One translucent rectangle per node (root, internal and leaf nodes alike);
    # nested nodes overlap, so regions covered by deeper nodes appear darker
    for _, row in df.iterrows():
        x_min, x_max, y_min, y_max = row
        rect = Rectangle((x_min, y_min), x_max - x_min, y_max - y_min, alpha=0.1)
        ax.add_patch(rect)
    plt.show()
    

[Image: scatter plot plus rectangles representing the kd-tree nodes]
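Because every node is drawn, parent rectangles cover their children. One variation, not part of the answer above, is to draw only the leaf nodes, selected with the is_leaf field of node_data (assumed to be an integer flag, as in the scikit-learn versions I know of). The Agg backend and the output filename kdtree_leaves.png are arbitrary choices so the sketch also runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line to use plt.show()
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from sklearn.datasets import make_blobs
from sklearn.neighbors import KDTree

X, _ = make_blobs(random_state=42)
kdt = KDTree(X, leaf_size=10)
_, _, node_data, node_bounds = kdt.get_arrays()

# Boolean mask selecting leaf nodes (is_leaf is stored as an integer flag)
leaf_mask = node_data["is_leaf"].astype(bool)
lower = node_bounds[0][leaf_mask]  # (n_leaves, 2): x_min, y_min per leaf
upper = node_bounds[1][leaf_mask]  # (n_leaves, 2): x_max, y_max per leaf

fig, ax = plt.subplots()
ax.scatter(X[:, 0], X[:, 1], s=10)
for (x_min, y_min), (x_max, y_max) in zip(lower, upper):
    # Outline only, so neighbouring leaf boxes stay distinguishable
    ax.add_patch(Rectangle((x_min, y_min), x_max - x_min, y_max - y_min,
                           fill=False, edgecolor="red"))
fig.savefig("kdtree_leaves.png")
```

Each point belongs to exactly one leaf, so this view shows how the tree groups the data, while the full plot above also shows the nesting of internal nodes.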