Faster Matching
Introduction
Have you ever needed to find the closest point to a given point in a dataset, or all the points within a distance $r$ of it? I’ve faced this problem a few times while doing matching (“approximate matching” or “nearest neighbor matching”), where I basically need to find, for each point in one set, the closest corresponding point in the other set (treatment, control).
While estimating the distance between all the points in both groups seems like an easy task, implementing it via brute force is extremely inefficient: $O(n^{2})$.
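To see why, here is a minimal sketch of the brute-force approach on small synthetic data: it materializes the full pairwise distance matrix, so both time and memory scale with the product of the two set sizes.

```python
import numpy as np

# Two small point sets: for each point in b, find its nearest point in a.
rng = np.random.default_rng(0)
a = rng.uniform(-1, 1, size=(1000, 2))
b = rng.uniform(-1, 1, size=(50, 2))

# Brute force: compute the full (50, 1000) distance matrix, then take argmin.
diffs = b[:, None, :] - a[None, :, :]   # shape (50, 1000, 2)
dists = np.linalg.norm(diffs, axis=-1)  # shape (50, 1000)
nearest = dists.argmin(axis=1)          # index into a for each point in b
```

With sets of comparable size $n$, this is the $O(n^{2})$ cost mentioned above.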
Fortunately, there are faster ways of doing this. To speed up nearest neighbor search, we can use data structures called KD-Trees and Ball-Trees. These structures allow us to quickly find the closest point(s) to a query point, without having to compute the distance to every point in the dataset.
KD-Tree and Ball-Tree
KD-Tree and Ball-Tree are data structures used for nearest neighbor search. A KD-Tree recursively partitions the dataset into smaller regions with axis-aligned hyperplanes, while a Ball-Tree partitions it into nested hyperspheres. Either way, whole regions can be ruled out at once, so we can quickly find the closest point(s) to a query point without computing the distance to every point in the dataset.
When working with these structures, we typically find the following methods:
- Finding the k-closest neighbors of a query point: sklearn.neighbors.KDTree.query
- Finding all the neighbors within a radius $R$ of a query point: sklearn.neighbors.KDTree.query_radius

Both sklearn.KDTree and sklearn.BallTree provide similar methods, as do the scipy implementations.
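Both methods can be sketched in a few lines with sklearn (synthetic data; the query point and radius here are arbitrary choices for illustration):

```python
import numpy as np
from sklearn.neighbors import KDTree

rng = np.random.default_rng(0)
points = rng.uniform(-1, 1, size=(10_000, 2))
tree = KDTree(points)

query = np.array([[0.0, 0.0]])

# k-nearest neighbors: distances and indices of the 3 closest points
dist, ind = tree.query(query, k=3)

# radius search: indices of all points within r=0.05 of the query point
ind_r = tree.query_radius(query, r=0.05)
```

query returns a fixed number of neighbors per query point, while query_radius returns a variable-length array of indices per query point.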
Implementation
To test the efficiency of the sklearn.BallTree and scipy.cKDTree methods, we will generate a random dataset with two sets of points and find the nearest neighbor from one set to the other.
# This script compares two tree structures on k-nearest-neighbor search,
# a very common step in approximate matching
import numpy as np
from scipy.spatial import cKDTree
from sklearn.neighbors import BallTree
import time
# Generate random points for users_1 (7 million points)
users_1 = np.random.uniform(low=-1, high=1, size=(7_000_000, 2))
users_1[:, 1] = np.random.uniform(low=-2, high=2, size=7_000_000)
# Generate random points for users_2 (10,000 points)
users_2 = np.random.uniform(low=-1, high=1, size=(10_000, 2))
users_2[:, 1] = np.random.uniform(low=-2, high=2, size=10_000)
# Measure time taken by BallTree
# (note: haversine treats each row as [lat, lon] in radians; cKDTree below
# uses Euclidean distance, so this is not a strict like-for-like comparison)
start_time = time.time()
tree = BallTree(users_1, leaf_size=15, metric='haversine')
indices_balltree = tree.query(users_2, k=1, return_distance=False)
end_time = time.time()
balltree_time = end_time - start_time
print(f"BallTree took {balltree_time} seconds")
# Measure time taken by cKDTree
start_time = time.time()
tree = cKDTree(users_1)
# cKDTree.query returns a (distances, indices) tuple; keep only the indices
_, indices_ckdtree = tree.query(users_2, k=1)
end_time = time.time()
ckdtree_time = end_time - start_time
print(f"cKDTree took {ckdtree_time} seconds")
# Compare the results
print(f"Time difference: {abs(balltree_time - ckdtree_time)} seconds")
This outputs (on an Apple M1 Pro):
BallTree took 10.020148992538452 seconds
cKDTree took 3.3549211025238037 seconds
Time difference: 6.665227890014648 seconds
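Part of that gap likely comes from the metric mismatch noted in the code (haversine is more expensive to evaluate than Euclidean distance). As a sanity check, here is a hedged sketch on a smaller synthetic dataset with the same Euclidean metric on both sides, in which case the two structures should return identical neighbor indices:

```python
import numpy as np
from scipy.spatial import cKDTree
from sklearn.neighbors import BallTree

rng = np.random.default_rng(0)
pts = rng.uniform(-1, 1, size=(100_000, 2))
queries = rng.uniform(-1, 1, size=(1_000, 2))

# Same Euclidean metric for both structures, so the results should agree
ball = BallTree(pts, metric='euclidean')
ind_ball = ball.query(queries, k=1, return_distance=False).ravel()

kd = cKDTree(pts)
_, ind_kd = kd.query(queries, k=1)
```

With continuous random coordinates, exact distance ties are essentially impossible, so ind_ball and ind_kd should match element for element.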
Conclusion
In summary, KD-Trees and Ball-Trees are powerful data structures for speeding up nearest neighbor search, letting us find the closest point(s) to a query point without computing the distance to every point in the dataset. Efficient structures like these are crucial for approximate matching, bias correction, and even synthetic control.