From 02224ea8fe03a8676a47c661a20a8b86f9719708 Mon Sep 17 00:00:00 2001 From: arka204 Date: Sun, 31 May 2020 19:14:19 +0200 Subject: [PATCH 1/2] Fixing error when input has different dimensions. --- sklearn/neighbors/_binary_tree.pxi | 18 +++++++++++++++++- sklearn/neighbors/tests/test_ball_tree.py | 9 +++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/sklearn/neighbors/_binary_tree.pxi b/sklearn/neighbors/_binary_tree.pxi index 599a4e9cc6426..cf69ced4c05d6 100755 --- a/sklearn/neighbors/_binary_tree.pxi +++ b/sklearn/neighbors/_binary_tree.pxi @@ -1053,9 +1053,25 @@ cdef class BinaryTree: if leaf_size < 1: raise ValueError("leaf_size must be greater than or equal to 1") + longest_data = max(len(item) for item in data) + padded_data = [] + padding = False + + for item in data: + if len(item) < longest_data: + item = np.asarray(item) + padded_item = np.zeros(longest_data) + padded_item[:item.shape[0]] = item + padded_data.append(padded_item) + padding = True + else: + padded_data.append(np.asarray(item)) + if padding: + warnings.warn("Not all elements had the same number of dimensions" + " - proceeding after extending those with zeros") + data = np.asarray(padded_data) n_samples = data.shape[0] n_features = data.shape[1] - self.data_arr = np.asarray(data, dtype=DTYPE, order='C') self.leaf_size = leaf_size self.dist_metric = DistanceMetric.get_metric(metric, **kwargs) diff --git a/sklearn/neighbors/tests/test_ball_tree.py b/sklearn/neighbors/tests/test_ball_tree.py index 8da703dbe207d..e8d90ca36709e 100644 --- a/sklearn/neighbors/tests/test_ball_tree.py +++ b/sklearn/neighbors/tests/test_ball_tree.py @@ -65,3 +65,12 @@ def test_query_haversine(): assert_array_almost_equal(dist1, dist2) assert_array_almost_equal(ind1, ind2) + + +def test_different_dimension_size(): + X = [(1, 2, 3), (2, 5), (5, 5, 1, 2)] + Y = np.array(X) + msg = ("Not all elements had the same number of dimensions" + " - proceeding after extending those with zeros") + with pytest.warns(UserWarning, match=msg): + BallTree(Y) From ef6b1e5aebc8363df9600d5ae1cbd58d7048488a Mon Sep 17 00:00:00 2001 From: arka204 Date: Mon, 8 Jun 2020 19:17:45 +0200 Subject: [PATCH 2/2] Changing solution --- sklearn/neighbors/_binary_tree.pxi | 20 ++++---------------- sklearn/neighbors/tests/test_ball_tree.py | 5 ++--- 2 files changed, 6 insertions(+), 19 deletions(-) diff --git a/sklearn/neighbors/_binary_tree.pxi b/sklearn/neighbors/_binary_tree.pxi index cf69ced4c05d6..cf4a77c9f4bed 100755 --- a/sklearn/neighbors/_binary_tree.pxi +++ b/sklearn/neighbors/_binary_tree.pxi @@ -1054,22 +1054,10 @@ cdef class BinaryTree: raise ValueError("leaf_size must be greater than or equal to 1") longest_data = max(len(item) for item in data) - padded_data = [] - padding = False - - for item in data: - if len(item) < longest_data: - item = np.asarray(item) - padded_item = np.zeros(longest_data) - padded_item[:item.shape[0]] = item - padded_data.append(padded_item) - padding = True - else: - padded_data.append(np.asarray(item)) - if padding: - warnings.warn("Not all elements had the same number of dimensions" - " - proceeding after extending those with zeros") - data = np.asarray(padded_data) + if np.any([len(item) < longest_data for item in data]): + raise ValueError("Input points must have " + "the same number of dimensions") + n_samples = data.shape[0] n_features = data.shape[1] self.data_arr = np.asarray(data, dtype=DTYPE, order='C') diff --git a/sklearn/neighbors/tests/test_ball_tree.py b/sklearn/neighbors/tests/test_ball_tree.py index e8d90ca36709e..92fe3ae4f4f39 100644 --- a/sklearn/neighbors/tests/test_ball_tree.py +++ b/sklearn/neighbors/tests/test_ball_tree.py @@ -70,7 +70,6 @@ def test_query_haversine(): def test_different_dimension_size(): X = [(1, 2, 3), (2, 5), (5, 5, 1, 2)] Y = np.array(X) - msg = ("Not all elements had the same number of dimensions" - " - proceeding after extending those with zeros") - with pytest.warns(UserWarning, match=msg): + msg = ("Input points must have the same number of dimensions") + with pytest.raises(ValueError, match=msg): BallTree(Y)