I would like to visualise the boundaries of the KD-tree built by scikit-learn: https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KDTree.html
As far as I can tell, there is no built-in method to plot this.
What I need is the boundaries of the rectangles that the KD-tree computes.
The library has a method get_arrays() that apparently returns the tree data, index, node data and node bounds.
So node data must be the points inside a rectangle (a node), and node bounds, I guess, is its boundary?
I tried to plot the data in node bounds. I basically turned the array I got back into a 2D array of 2D points (two columns, one for the x's and one for the y's).
I did a scatter plot.
I am not sure what I see, probably points on the boundaries?
I tried to use shapely polygons to turn these boundaries into rectangles and plot them, but I did not get what I expected, so I am wondering whether I understand correctly what the bounds returned by get_arrays() are.
You're right that the fourth array, node bounds, is what you're after, and the source code provides a little more detail: the shape is (2, n_nodes, n_features), with the first dimension containing the min and max for the given feature and node.
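As a quick sanity check of that layout, here is a minimal sketch that confirms the shape and verifies that every point assigned to a node lies inside that node's min/max box. It assumes that node_data is a NumPy structured array with "idx_start" and "idx_end" fields (end exclusive), which is what I see in the scikit-learn source but is not part of the documented API, so treat those field names as an assumption.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.neighbors import KDTree

X, _ = make_blobs(random_state=42)
kdt = KDTree(X, leaf_size=10)
tree_data, index, node_data, node_bounds = kdt.get_arrays()

n_nodes = node_bounds.shape[1]
# First axis: 0 -> per-feature minimum, 1 -> per-feature maximum
lower, upper = node_bounds[0], node_bounds[1]  # each (n_nodes, n_features)


def node_points(i):
    """Points belonging to node i (assumed fields: idx_start, idx_end)."""
    start, end = node_data["idx_start"][i], node_data["idx_end"][i]
    return tree_data[index[start:end]]


# Every node's points should fall within that node's bounds
all_inside = all(
    np.all((node_points(i) >= lower[i]) & (node_points(i) <= upper[i]))
    for i in range(n_nodes)
)
print(node_bounds.shape, all_inside)
```

If the check passes, node data holds index ranges into the index array rather than the points themselves, and node bounds holds the box around those points.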
from sklearn.datasets import make_blobs
from sklearn.neighbors import KDTree
from matplotlib.patches import Rectangle
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd

X, y = make_blobs(random_state=42)
kdt = KDTree(X, leaf_size=10)
tree_data, index, node_data, node_bounds = kdt.get_arrays()

# Reorder (2, n_nodes, n_features) to (n_nodes, n_features, 2): one entry per node
rearranged_bounds = np.transpose(node_bounds, axes=[1, 2, 0])
df = pd.DataFrame({
    'x_min': rearranged_bounds[:, 0, 0],
    'x_max': rearranged_bounds[:, 0, 1],
    'y_min': rearranged_bounds[:, 1, 0],
    'y_max': rearranged_bounds[:, 1, 1],
})

fig, ax = plt.subplots()
ax.scatter(X[:, 0], X[:, 1])
# Draw one translucent rectangle per node (internal nodes and leaves alike)
for _, row in df.iterrows():
    x_min, x_max, y_min, y_max = row
    rect = Rectangle((x_min, y_min), x_max - x_min, y_max - y_min, alpha=0.1)
    ax.add_patch(rect)
plt.show()
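The plot above overlays every node, so parent boxes sit on top of their children. If only the finest level is wanted, a variation is to keep just the leaf nodes and draw them as outlines. This is a sketch under one assumption: that node_data is a structured array with an "is_leaf" field, as in the scikit-learn source I looked at; that field is not documented, so it may differ between versions.

```python
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle
from sklearn.datasets import make_blobs
from sklearn.neighbors import KDTree

X, _ = make_blobs(random_state=42)
kdt = KDTree(X, leaf_size=10)
_, _, node_data, node_bounds = kdt.get_arrays()

# Assumption: node_data exposes an "is_leaf" flag per node
is_leaf = node_data["is_leaf"].astype(bool)
leaf_min = node_bounds[0][is_leaf]  # (n_leaves, n_features)
leaf_max = node_bounds[1][is_leaf]

fig, ax = plt.subplots()
ax.scatter(X[:, 0], X[:, 1], s=10)
# One unfilled rectangle per leaf node
for (x_min, y_min), (x_max, y_max) in zip(leaf_min, leaf_max):
    ax.add_patch(
        Rectangle((x_min, y_min), x_max - x_min, y_max - y_min,
                  fill=False, edgecolor="tab:red")
    )
# Call plt.show() or fig.savefig(...) to display the result
```

Note that these are bounding boxes around the points in each node, so the leaf rectangles generally will not tile the whole plane; gaps between neighbouring boxes are expected.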