I have two NumPy arrays: a (smaller), consisting of int values, and b (larger), consisting of float values. The idea is that b contains float values that are close to some of the int values in a. As a toy example, consider the arrays below. The arrays aren't sorted, so I use np.sort() on both a and b:
import numpy as np
a = np.array([35, 11, 48, 20, 13, 31, 49])
b = np.array([34.78, 34.8, 35.1, 34.99, 11.3, 10.7, 11.289, 18.78, 19.1, 20.05, 12.32, 12.87, 13.5, 31.03, 31.15, 29.87, 48.1, 48.5, 49.2])
For each element in a, there are multiple float values in b and the goal is to get the closest value in b for each element in a.
To naively achieve this, I use a for loop:
for e in a:
    idx = np.abs(e - b).argsort()
    print(f"{e} has nearest match = {b[idx[0]]:.4f}")
'''
11 has nearest match = 11.2890
13 has nearest match = 12.8700
20 has nearest match = 20.0500
31 has nearest match = 31.0300
35 has nearest match = 34.9900
48 has nearest match = 48.1000
49 has nearest match = 49.2000
'''
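The loop above sorts every distance to b for each element of a, only to read off the first index. A minimal sketch of the same brute-force search using np.argmin, which finds the smallest distance in a single pass, is given here. The helper name nearest_bruteforce is hypothetical and introduced only for illustration; the work is still proportional to a.size * b.size.

```python
import numpy as np

def nearest_bruteforce(a, b):
    """For each element of a, return the closest value in b (brute force)."""
    # argmin gives the position of the smallest |e - b| without a full sort
    return np.array([b[np.abs(e - b).argmin()] for e in a])

a = np.sort(np.array([35, 11, 48, 20, 13, 31, 49]))
b = np.sort(np.array([34.78, 34.8, 35.1, 34.99, 11.3, 10.7, 11.289, 18.78,
                      19.1, 20.05, 12.32, 12.87, 13.5, 31.03, 31.15, 29.87,
                      48.1, 48.5, 49.2]))

matches = nearest_bruteforce(a, b)
for e, m in zip(a, matches):
    print(f"{e} has nearest match = {m:.4f}")
```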
There can be values in a that do not exist in b, and vice versa.
The real arrays have a.size = 2040 and b.size = 1041901.
To construct a KD-Tree:
from scipy.spatial import KDTree
# Construct a KD-Tree and query the nearest neighbor
kd_tree = KDTree(data = np.expand_dims(a, 1))
dist, idx = kd_tree.query(x = np.expand_dims(b, 1), k = [1])
dist.shape, idx.shape
# ((19, 1), (19, 1))
To get nearest neighbor in 'b' with respect to 'a', I do:
b[idx]
'''
array([[10.7 ],
[10.7 ],
[10.7 ],
[11.289],
[11.289],
[11.289],
[11.3 ],
[11.3 ],
[11.3 ],
[12.32 ],
[12.32 ],
[12.32 ],
[12.87 ],
[12.87 ],
[12.87 ],
[12.87 ],
[13.5 ],
[13.5 ],
[18.78 ]])
'''
Problem: b[idx] does not return the nearest match in b for each element of a.
What's going wrong?
If you want to get the closest element for each entry in a, you build your KD-Tree for b and then query a.
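import numpy as np
from scipy import spatial

# Build the tree on b (the values to search), then query it with a
kd = spatial.KDTree(b[:, np.newaxis])
distances, indices = kd.query(a[:, np.newaxis])
values = b[indices]  # nearest value in b for each element of a
for ai, bi in zip(a, values):
    print(f"{ai} has nearest match = {bi:.4f}")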
35 has nearest match = 34.9900
11 has nearest match = 11.2890
48 has nearest match = 48.1000
20 has nearest match = 20.0500
13 has nearest match = 12.8700
31 has nearest match = 31.0300
49 has nearest match = 49.2000
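As for what went wrong in the original attempt: query returns, for each query point, an index into the array the tree was built from. With the tree built on a and queried with b, you get one index per element of b, and each index refers to a position in a, so using it to index b picks unrelated elements. A small sketch on the toy data shows both directions; the variable names idx_into_a and idx_into_b are mine, chosen for illustration.

```python
import numpy as np
from scipy.spatial import KDTree

a = np.array([35, 11, 48, 20, 13, 31, 49])
b = np.array([34.78, 34.8, 35.1, 34.99, 11.3, 10.7, 11.289, 18.78, 19.1,
              20.05, 12.32, 12.87, 13.5, 31.03, 31.15, 29.87, 48.1, 48.5, 49.2])

# Tree built on a, queried with b: one result per element of b,
# and each index points into a (not b).
_, idx_into_a = KDTree(a[:, np.newaxis]).query(b[:, np.newaxis])
print(idx_into_a.shape)    # (19,)
print(a[idx_into_a][:4])   # nearest element of a for the first four b values

# Tree built on b, queried with a: one result per element of a,
# and each index points into b, which is what the question asks for.
_, idx_into_b = KDTree(b[:, np.newaxis]).query(a[:, np.newaxis])
print(b[idx_into_b])
```

Indexing a with idx_into_a is meaningful (it maps each float to its nearest int), whereas indexing b with those indices is not.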