forked from scikit-learn/scikit-learn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_ball_tree.py
76 lines (59 loc) · 2.43 KB
/
test_ball_tree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import itertools
import numpy as np
import pytest
from numpy.testing import assert_array_almost_equal
from sklearn.neighbors._ball_tree import BallTree
from sklearn.neighbors import DistanceMetric
from sklearn.utils import check_random_state
# Module-wide seeded generator so the randomly drawn metric parameters below
# are reproducible from run to run.
rng = np.random.RandomState(10)

# Mahalanobis parameter: a random 3x3 matrix multiplied by its own transpose,
# which makes the result symmetric.
V_mahalanobis = rng.rand(3, 3)
V_mahalanobis = V_mahalanobis.dot(V_mahalanobis.T)

# Length of the per-feature parameter vectors drawn for the metrics below.
DIMENSION = 3

# Metric name -> keyword arguments handed to the metric.
# NOTE: entries are evaluated top to bottom, so the 'seuclidean' vector is
# drawn from ``rng`` before the 'wminkowski' one; keep this order to keep
# the generated values stable.
METRICS = {
    'euclidean': {},
    'manhattan': {},
    'minkowski': {'p': 3},
    'chebyshev': {},
    'seuclidean': {'V': rng.random_sample(DIMENSION)},
    'wminkowski': {'p': 3, 'w': rng.random_sample(DIMENSION)},
    'mahalanobis': {'V': V_mahalanobis},
}

# Metrics exercised on small non-negative integer-valued data.
DISCRETE_METRICS = [
    'hamming',
    'canberra',
    'braycurtis',
]

# Metrics exercised on 0/1-valued data.
BOOLEAN_METRICS = [
    'matching',
    'jaccard',
    'dice',
    'kulsinski',
    'rogerstanimoto',
    'russellrao',
    'sokalmichener',
    'sokalsneath',
]
def brute_force_neighbors(X, Y, k, metric, **kwargs):
    """Reference k-nearest-neighbour search via a full pairwise distance matrix.

    Parameters
    ----------
    X : array-like
        Points that are searched.
    Y : ndarray
        Query points; one output row is produced per row of ``Y``.
    k : int
        Number of neighbours kept per query point.
    metric : str
        Metric name passed to ``DistanceMetric.get_metric``.
    **kwargs
        Extra keyword arguments forwarded to ``DistanceMetric.get_metric``.

    Returns
    -------
    dist, ind : tuple of ndarray
        For each row of ``Y``, the ``k`` smallest distances to ``X`` in
        ascending order and the matching column indices into ``X``.
    """
    distance_metric = DistanceMetric.get_metric(metric, **kwargs)
    pairwise = distance_metric.pairwise(Y, X)

    # Sort every row by distance and keep the first k column indices.
    order = np.argsort(pairwise, axis=1)
    nearest = order[:, :k]

    # Column vector of row numbers so that it broadcasts against ``nearest``
    # and picks, for each query row, the distances at the selected columns.
    row_ids = np.arange(Y.shape[0])[:, None]
    return pairwise[row_ids, nearest], nearest
@pytest.mark.parametrize('metric',
                         itertools.chain(BOOLEAN_METRICS, DISCRETE_METRICS))
def test_ball_tree_query_metrics(metric):
    """Check BallTree query distances against the brute-force reference.

    Boolean metrics are fed 0/1-valued data, discrete metrics are fed
    integer-valued data in the range 0..4. Only distances are compared;
    the returned neighbour indices are not checked here.
    """
    rng = check_random_state(0)

    # Scale applied before rounding: 1 keeps values in {0, 1},
    # 4 spreads them over {0, 1, 2, 3, 4}.
    if metric in BOOLEAN_METRICS:
        scale = 1
    elif metric in DISCRETE_METRICS:
        scale = 4

    # X is drawn before Y so the random stream is consumed in a fixed order.
    X = (scale * rng.random_sample((40, 10))).round(0)
    Y = (scale * rng.random_sample((10, 10))).round(0)

    n_neighbors = 5
    tree = BallTree(X, leaf_size=1, metric=metric)
    tree_dist, _ = tree.query(Y, n_neighbors)
    expected_dist, _ = brute_force_neighbors(X, Y, n_neighbors, metric)

    assert_array_almost_equal(tree_dist, expected_dist)
def test_query_haversine():
    """Check haversine BallTree queries against the brute-force reference.

    The tree is built from 40 random 2-column points scaled by ``2 * pi``
    and queried with those same points; both the distances and the
    neighbour indices must agree with ``brute_force_neighbors``.
    """
    metric = 'haversine'
    points = 2 * np.pi * check_random_state(0).random_sample((40, 2))

    tree = BallTree(points, leaf_size=1, metric=metric)
    from_tree = tree.query(points, k=5)
    from_brute = brute_force_neighbors(points, points, k=5, metric=metric)

    # Each result is a (dist, ind) pair: distances are compared first,
    # then indices.
    for actual, expected in zip(from_tree, from_brute):
        assert_array_almost_equal(actual, expected)
def test_different_dimension_size():
    """Check that BallTree warns when input rows have differing lengths.

    The input is a ragged collection of three tuples (lengths 3, 2 and 4).
    Building a ``BallTree`` from it is expected to emit a ``UserWarning``
    whose message matches ``msg`` below.
    """
    X = [(1, 2, 3), (2, 5), (5, 5, 1, 2)]

    # Ragged sequences must be converted with an explicit ``dtype=object``:
    # NumPy >= 1.24 raises ValueError for the implicit conversion (and
    # 1.19-1.23 emit a VisibleDeprecationWarning). The explicit form yields
    # the same 1-D object array of tuples that older NumPy built implicitly,
    # so the test reaches the BallTree call on every NumPy version.
    Y = np.array(X, dtype=object)

    msg = ("Not all elements had the same number of dimensions"
           " - proceeding after extending those with zeros")
    with pytest.warns(UserWarning, match=msg):
        BallTree(Y)