Post

Faster Matching

Introduction

Have you ever needed to find the closest point to a given point in a dataset, or all the points within a distance $r$ of it? I’ve faced this problem a few times while doing matching (“approximate matching” or “nearest neighbor matching”), where I basically need to find, for every point in one group, its closest counterpart in the other group (treatment, control).

While estimating the distance between all the points in both groups seems like an easy task, implementing it via “brute force” can be extremely inefficient, roughly $O(n^{2})$.
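To make the cost concrete, here is a minimal brute-force sketch (the helper name and toy sizes are illustrative, not part of the benchmark below): with n treatment and m control points it builds the full n x m distance matrix, which is exactly what becomes infeasible at scale.

import numpy as np
from scipy.spatial.distance import cdist

def brute_force_nearest(treatment, control):
    # full (n, m) matrix of Euclidean distances between every pair of points
    distances = cdist(treatment, control)
    # index of the closest control point for each treatment point
    return distances.argmin(axis=1)

rng = np.random.default_rng(0)
treatment = rng.uniform(low=-1, high=1, size=(1_000, 2))
control = rng.uniform(low=-1, high=1, size=(5_000, 2))
matches = brute_force_nearest(treatment, control)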

Fortunately, there are faster ways of doing this. To speed up nearest neighbor search, we can use data structures called KD-Trees and Ball-Trees. These structures allow us to quickly find the closest point(s) to a query point, without having to compute the distance to every point in the dataset.

KD-Tree and Ball-Tree

KD-Tree and Ball-Tree are data structures designed for nearest neighbor search. A KD-Tree recursively partitions the dataset with axis-aligned hyperplanes, while a Ball-Tree partitions it into nested hyperspheres. Both let us discard large portions of the search space, so we rarely need to compute the distance to every point in the dataset.

When working with these structures, we typically find the following methods:

  1. Finding the k-closest neighbors of a query point: sklearn.neighbors.KDTree.query
  2. Finding all the neighbors within a radius $R$ of a query point: sklearn.neighbors.KDTree.query_radius

Both sklearn.neighbors.KDTree and sklearn.neighbors.BallTree provide these methods, as do the scipy implementations.
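Here is a minimal sketch of both calls with sklearn's KDTree (the toy data, radius, and variable names are illustrative only):

import numpy as np
from sklearn.neighbors import KDTree

rng = np.random.default_rng(42)
points = rng.uniform(low=-1, high=1, size=(100_000, 2))
tree = KDTree(points, leaf_size=40)

query_point = np.array([[0.0, 0.0]])

# 1. the k closest neighbors of the query point
distances, indices = tree.query(query_point, k=5)

# 2. all neighbors within radius r of the query point
neighbors_within_r = tree.query_radius(query_point, r=0.01)

print(indices)
print(neighbors_within_r)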

Implementation

To test the efficiency of sklearn.neighbors.BallTree and scipy.spatial.cKDTree, we will generate two random sets of points and, for each point in the smaller set, find its nearest neighbor in the larger one.

# this script compares two tree structures on the search of the k-closest neighbors, a very traditional problem on approximated matching 

import numpy as np
from scipy.spatial import cKDTree
from sklearn.neighbors import BallTree
import time

# Generate random points for users_1 (7 million points)
users_1 = np.random.uniform(low=-1, high=1, size=(7_000_000, 2))  
users_1[:, 1] = np.random.uniform(low=-2, high=2, size=7_000_000)  

# Generate random points for users_2 (10,000 points)
users_2 = np.random.uniform(low=-1, high=1, size=(10_000, 2))  
users_2[:, 1] = np.random.uniform(low=-2, high=2, size=10_000)  


# Measure time taken by BallTree
start_time = time.time()
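# note: the 'haversine' metric treats each row as (latitude, longitude) in
# radians; the simulated ranges above happen to fall inside valid lat/lon bounds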
tree = BallTree(users_1, leaf_size=15, metric='haversine')
indices_balltree = tree.query(users_2, k=1, return_distance=False)
end_time = time.time()
balltree_time = end_time - start_time
print(f"BallTree took {balltree_time} seconds")

# Measure time taken by cKDTree
start_time = time.time()
tree = cKDTree(users_1)
# cKDTree.query returns a (distances, indices) tuple
distances_ckdtree, indices_ckdtree = tree.query(users_2, k=1)
end_time = time.time()
ckdtree_time = end_time - start_time
print(f"cKDTree took {ckdtree_time} seconds")

# Compare the results
print(f"Time difference: {abs(balltree_time - ckdtree_time)} seconds")

This outputs (on an Apple M1 Pro):

BallTree took 10.020148992538452 seconds
cKDTree took 3.3549211025238037 seconds
Time difference: 6.665227890014648 seconds
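As a quick sanity check (continuing from the variables defined in the script above, and only a sketch): the two timings are not strictly apples-to-apples, since the BallTree call uses the haversine metric while cKDTree uses Euclidean distance. Rebuilding the BallTree with its default Euclidean metric makes the matches directly comparable:

# sketch: use the default (Euclidean) metric for BallTree so its matches
# can be compared against cKDTree's, which also uses Euclidean distance
idx_ball = BallTree(users_1, leaf_size=15).query(users_2, k=1, return_distance=False).ravel()
_, idx_ckd = cKDTree(users_1).query(users_2, k=1)
print(f"Both structures agree on {np.mean(idx_ball == idx_ckd):.1%} of the matches")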

Conclusion

In summary, KD-Trees and Ball-Trees are powerful data structures for speeding up nearest neighbor search, letting us find the closest point(s) to a query point without computing the distance to every point in the dataset. Efficient structures like these are crucial for approximate matching, bias correction, and even synthetic control.

This post is licensed under CC BY 4.0 by the author.