python astronomy astropy

Efficient method for counting number of data points inside sphere of fixed radius centered on each data point


I have a database with many data points, each with an x, y, z coordinate. I want to count the number of points that lie within a certain distance of a neighboring point. Some points will have a neighbor within a radius R, others will not; I simply want to count the number of pairs within that distance. I could easily write an algorithm to do this, but it would not be efficient enough, since for every data point I would have to iterate over every other data point.

This seems like something that must already exist in astropy, scipy, etc. but I cannot seem to find what I am looking for. Is there anything out there that accomplishes this?


Solution

  • As mentioned by @Davis Herring in the comments, an efficient option is a k-d tree.

    The k-d tree is a data structure that avoids the brute-force approach and allows for efficient distance computations* (see bottom of answer for background).

    There are several Python implementations of this; SciPy provides two:

    scipy.spatial.cKDTree, the SciPy k-d tree in Cython (faster since it uses C/Cython)

    scipy.spatial.KDTree, the SciPy k-d tree in pure Python

    You can use this by first constructing a k-d tree for your xyz data:

    import numpy as np  #for later code
    from scipy.spatial import cKDTree
    
    #xyzData is assumed to be an array of shape (numPoints, 3) holding x,y,z
    kdtree = cKDTree(xyzData)
    

    Then, query the k-d tree with each point (called point below) to find its nearest neighbor. The query returns the distance NN_dist between point and its nearest neighbor, along with the index NN_idx of that neighbor. Doing this for all of your points takes a for loop, but because every query goes through the k-d tree, this is much faster than a brute-force computation:

    numPoints = len(xyzData)
    NN_dists = np.zeros(numPoints)  #pre-allocate an array to store distances
    for i in range(numPoints):
        point = xyzData[i]
    
        #k=[1] asks for the 1st-nearest neighbor only; results are length-1 arrays
        NN_dist, NN_idx = kdtree.query(point,k=[1])
    
        #Note: 'k' specifies the kth neighbor distance to compute, 
        #so set k=2 if you end up finding the point as its own "neighbor"
        #(expected here, since the tree was built from these same points):
        if NN_dist[0] == 0:
            NN_dist, NN_idx = kdtree.query(point,k=[2])
        
        NN_dists[i] = NN_dist[0]
    

    (see the SciPy documentation for the k-d tree query method for more details).
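
    The loop above can also be replaced by a single call, since query accepts an array of points. The following is a minimal sketch, not part of the original answer: the random xyzData is a hypothetical stand-in for your data, and it assumes there are no duplicate points, so that each point's closest match (column 0) is the point itself.

```python
import numpy as np
from scipy.spatial import cKDTree

# Hypothetical stand-in for the real data: 1000 random points in a 100-unit cube.
rng = np.random.default_rng(0)
xyzData = rng.uniform(0.0, 100.0, size=(1000, 3))

kdtree = cKDTree(xyzData)

# Query every point at once. With k=2 each row holds the two closest tree
# points: column 0 is the query point itself (distance 0, assuming no
# duplicates) and column 1 is its nearest *other* point.
dists, idxs = kdtree.query(xyzData, k=2)
NN_dists = dists[:, 1]  # distance to the nearest other point
NN_idxs = idxs[:, 1]    # index of that point in xyzData
```

    NN_dists here plays the same role as the array filled in by the loop, so the thresholding step works unchanged.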

    Then, to find the distances that are below some threshold, you can use NumPy's element-wise comparison operators (like <), which return a boolean mask:

    distanceThres = 10
    goodIdx = NN_dists < distanceThres
    goodPoints = xyzData[goodIdx]
    

    This will give you a boolean mask goodIdx and the points goodPoints whose nearest neighbor lies within your specified distance threshold distanceThres (though you may have to change this code depending on the shape/format of your xyz coordinate data).
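
    If what you need is the count itself (pairs within R, or the number of points inside the sphere around each point), the same tree can report that directly. This is a hedged sketch rather than part of the original answer: xyzData and R are hypothetical values, and the return_length option of query_ball_point needs SciPy 1.3 or newer.

```python
import numpy as np
from scipy.spatial import cKDTree

# Hypothetical data and radius, for illustration only.
rng = np.random.default_rng(0)
xyzData = rng.uniform(0.0, 100.0, size=(1000, 3))
R = 10.0

kdtree = cKDTree(xyzData)

# Set of unique index pairs (i, j), i < j, whose separation is at most R.
pairs = kdtree.query_pairs(r=R)
numPairs = len(pairs)

# Number of points inside the sphere of radius R centered on each point.
# The raw count includes the center point itself, hence the "- 1".
neighborCounts = kdtree.query_ball_point(xyzData, r=R, return_length=True) - 1
```

    Each pair is seen once from each of its two end points, so neighborCounts.sum() should equal 2 * numPairs, which makes for an easy sanity check.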


    *A light background on k-d trees (glossing over fine details -- see references for more): the k-d tree method partitions a dataset so that a query does not have to compute the distance to every single point (i.e., the brute-force method). It does this by recursively splitting the dataset in two along one coordinate axis at a time (a binary space partition), which builds the k-d tree. These partitions are such that a distance computation (e.g., a nearest-neighbor search) can ignore data points that are in distant partitions. Additionally, the tree is built only once and then reused for the query at each point.
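
    To make the partition-and-prune idea concrete, here is a small teaching sketch in pure Python. It is my own illustration, not how SciPy implements its tree; the function names build and nearest, the dictionary layout, and the random test points are all hypothetical.

```python
import numpy as np

def build(points, indices=None, depth=0):
    """Recursively split the points at the median, cycling through the axes."""
    if indices is None:
        indices = np.arange(len(points))
    if len(indices) == 0:
        return None
    axis = depth % points.shape[1]
    order = indices[np.argsort(points[indices, axis])]
    mid = len(order) // 2
    return {
        "index": order[mid],  # point stored at this node
        "axis": axis,         # axis of the splitting plane
        "left": build(points, order[:mid], depth + 1),
        "right": build(points, order[mid + 1:], depth + 1),
    }

def nearest(node, points, target, skip=-1, best=(np.inf, -1)):
    """Return (distance, index) of the point closest to target, ignoring index skip."""
    if node is None:
        return best
    i, axis = node["index"], node["axis"]
    if i != skip:
        d = np.linalg.norm(points[i] - target)
        if d < best[0]:
            best = (d, i)
    diff = target[axis] - points[i, axis]
    near, far = (node["left"], node["right"]) if diff < 0 else (node["right"], node["left"])
    best = nearest(near, points, target, skip, best)
    # Every point in the far partition is at least |diff| away, so search it
    # only when the splitting plane is closer than the best match so far.
    if abs(diff) < best[0]:
        best = nearest(far, points, target, skip, best)
    return best

# Usage: nearest neighbor of point 0, excluding point 0 itself.
rng = np.random.default_rng(0)
pts = rng.uniform(0.0, 100.0, size=(200, 3))
tree = build(pts)
dist, idx = nearest(tree, pts, pts[0], skip=0)
```

    The pruning test near the end of nearest is the step that lets the search skip whole partitions, which is where the speed-up over brute force comes from.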

    There are a lot of resources online about k-d trees in general. I found these references most helpful when I was learning about this algorithm: Stanford k-d trees or Princeton k-d trees.

    Let me know if you have questions -- I had this exact problem myself during an astronomy project, so I may be able to help more.