I have a database with many data points, each with an x, y, z coordinate. I want to count the number of points that are within a certain distance of neighboring points. Some points will have a partner within a radius R, others will not. I simply want to count the number of pairs within some distance. I could easily write an algorithm to do this, but it would not be efficient enough (I would have to iterate over every single data point).
This seems like something that must already exist in astropy, scipy, etc. but I cannot seem to find what I am looking for. Is there anything out there that accomplishes this?
As mentioned by @Davis Herring in the comments, an efficient option is a k-d tree.
The k-d tree is an algorithm that avoids the brute-force approach and allows for efficient distance computations* (see bottom of answer for background).
There are several Python implementations of this, one of which is in SciPy:
SciPy k-d tree in Cython (`cKDTree`; faster since it is implemented in C/Cython)
You can use this by first constructing a k-d tree for your xyz data:
import numpy as np  # for later code
from scipy.spatial import cKDTree
# xyzData: array of shape (numPoints, 3), one (x, y, z) row per point
kdtree = cKDTree(xyzData)
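`cKDTree` expects the coordinates as a single array with one row per point and one column per dimension, i.e. shape `(N, 3)` for xyz data. If your database hands you three separate coordinate columns, a minimal sketch is to stack them first (the arrays `x`, `y`, `z` below are hypothetical sample values, not your data):

```python
import numpy as np
from scipy.spatial import cKDTree

# Hypothetical: coordinates stored as three separate 1-D arrays (e.g. table columns)
x = np.array([0.0, 1.0, 5.0, 5.5])
y = np.array([0.0, 0.0, 5.0, 5.0])
z = np.array([0.0, 0.0, 5.0, 5.0])

# cKDTree wants one row per point and one column per dimension: shape (N, 3)
xyzData = np.column_stack((x, y, z))
kdtree = cKDTree(xyzData)

# kdtree.n is the number of points, kdtree.m the number of dimensions
print(xyzData.shape, kdtree.n, kdtree.m)  # (4, 3) 4 3
```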
Then, query the k-d tree with a point `point` to compute the distance between `point` and its nearest neighbor. The query returns that distance, `NN_dist`, and the index `NN_idx` of the neighbor. One way to compute this for all of your points is a for loop; thanks to the k-d tree, this is still much faster than a brute-force computation:
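numPoints = len(xyzData)
NN_dists = np.zeros(numPoints)  # pre-allocate an array to store distances
for i in range(numPoints):
    point = xyzData[i]
    # k=[1] asks for the 1st nearest neighbor only; the results are length-1 arrays
    NN_dist, NN_idx = kdtree.query(point, k=[1])
    # Note: 'k' specifies which kth-neighbor distance(s) to compute. Every point is
    # also stored in the tree, so it can find itself as its own "neighbor"
    # (distance 0); in that case, ask for the 2nd nearest neighbor instead:
    if NN_dist[0] == 0:
        NN_dist, NN_idx = kdtree.query(point, k=[2])
    NN_dists[i] = NN_dist[0]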
(see the `cKDTree.query` documentation for more details).
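The loop can also be replaced by a single vectorized call, since `query` accepts a whole array of points. Because every point is stored in the tree, its 1st neighbor is itself at distance 0, so asking for `k=2` and keeping the second column gives the nearest *other* point. A sketch, with randomly generated stand-in data in place of your database:

```python
import numpy as np
from scipy.spatial import cKDTree

rng = np.random.default_rng(42)
xyzData = rng.uniform(0, 100, size=(500, 3))  # hypothetical (N, 3) coordinates

kdtree = cKDTree(xyzData)

# Query every point at once; column 0 is the point itself (distance 0),
# column 1 is its nearest *other* point
dists, idxs = kdtree.query(xyzData, k=2)
NN_dists = dists[:, 1]  # distance to the nearest other point
NN_idx = idxs[:, 1]     # index of that point
```

If your data contains duplicate coordinates, the second column will be 0 for those points, just as in the loop version.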
Then, to find the distances that are below some threshold, you can use NumPy's element-wise comparison operators (like `<`) on arrays:
distanceThres = 10
goodIdx = NN_dists < distanceThres
goodPoints = xyzData[goodIdx]
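This gives you a boolean mask `goodIdx` (use `np.flatnonzero(goodIdx)` if you need integer indices) and the points `goodPoints` whose nearest neighbor lies within your specified distance threshold `distanceThres`; `goodIdx.sum()` is the number of such points. You may have to adapt this code depending on the shape/format of your xyz coordinate data.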
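If what you ultimately need is the number of *pairs* within a radius `R` (as asked in the question), `cKDTree` can count them directly. The sketch below uses `query_pairs` (the set of index pairs `(i, j)`, `i < j`, at distance at most `R`) and `count_neighbors` (which, when the tree is counted against itself, also counts each point with itself and each pair in both orders). The sample data and radius are hypothetical, and the brute-force `pdist` check is only there to confirm the counts on a small dataset:

```python
import numpy as np
from scipy.spatial import cKDTree
from scipy.spatial.distance import pdist

rng = np.random.default_rng(0)
xyzData = rng.uniform(0, 100, size=(1000, 3))  # hypothetical sample data
R = 5.0  # hypothetical search radius

kdtree = cKDTree(xyzData)

# All unordered pairs (i, j), i < j, with distance <= R
pairs = kdtree.query_pairs(R)
num_pairs = len(pairs)

# Self-vs-self count includes the N (i, i) pairs and each (i, j) twice
num_pairs_alt = (kdtree.count_neighbors(kdtree, R) - len(xyzData)) // 2

# Number of distinct points that have at least one partner within R
points_with_neighbor = len({idx for pair in pairs for idx in pair})

# Brute-force check: O(N^2) time and memory, only sensible for small N
num_pairs_brute = int((pdist(xyzData) <= R).sum())

print(num_pairs, num_pairs_alt, num_pairs_brute, points_with_neighbor)
```

`query_pairs` builds every pair in memory, so for very dense data `count_neighbors`, which only returns a count, may be the lighter choice.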
*A light background on k-d trees (glossing over fine details -- see references for more): the k-d tree method partitions a dataset in a way that avoids computing the distance between every pair of points (i.e., the brute-force method). It does this by dividing the dataset into binary space partitions to construct a k-d tree. These partitions are such that a distance computation (e.g., a nearest-neighbor search) can ignore data points that are in distant partitions. Additionally, the same k-d tree is built once and reused for every query point.
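To make the partitioning idea concrete, here is a minimal pure-Python k-d tree sketch: it splits on the median along x, y, z in turn, and the nearest-neighbor search skips the far side of a split whenever the splitting plane is already farther away than the best distance found so far. This only illustrates the technique; it is not how SciPy implements `cKDTree` internally, and the names `Node`, `build`, and `nearest` are hypothetical.

```python
import math
import random


class Node:
    """One split: a stored point, its original index, and the split axis."""
    def __init__(self, point, index, axis, left, right):
        self.point = point
        self.index = index
        self.axis = axis
        self.left = left    # points with coordinate <= point[axis]
        self.right = right  # points with coordinate >= point[axis]


def build(items, depth=0):
    """items: list of (index, point) pairs; split on the median of axis depth % k."""
    if not items:
        return None
    k = len(items[0][1])
    axis = depth % k
    items = sorted(items, key=lambda it: it[1][axis])
    mid = len(items) // 2
    idx, pt = items[mid]
    return Node(pt, idx, axis,
                build(items[:mid], depth + 1),
                build(items[mid + 1:], depth + 1))


def nearest(node, target, skip_index=None, best=(math.inf, -1)):
    """Return (distance, index) of the stored point nearest to target,
    ignoring the point whose index is skip_index (e.g. target itself)."""
    if node is None:
        return best
    if node.index != skip_index:
        d = math.dist(node.point, target)
        if d < best[0]:
            best = (d, node.index)
    diff = target[node.axis] - node.point[node.axis]
    near, far = (node.left, node.right) if diff < 0 else (node.right, node.left)
    best = nearest(near, target, skip_index, best)
    # Every point on the far side is at least |diff| away, so only
    # descend there if that could still beat the best distance so far
    if abs(diff) < best[0]:
        best = nearest(far, target, skip_index, best)
    return best


random.seed(1)
points = [tuple(random.uniform(0, 100) for _ in range(3)) for _ in range(200)]
root = build(list(enumerate(points)))

# Nearest other point to points[0]
dist0, idx0 = nearest(root, points[0], skip_index=0)
print(dist0, idx0)
```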
There are a lot of resources online about k-d trees in general. I found these references most helpful when I was learning about this algorithm: Stanford k-d trees or Princeton k-d trees.
Let me know if you have questions -- I had this exact problem myself during an astronomy project, so I may be able to help more.