diff --git a/sklearn/neighbors/_binary_tree.pxi b/sklearn/neighbors/_binary_tree.pxi index 599a4e9cc6426..cf4a77c9f4bed 100755 --- a/sklearn/neighbors/_binary_tree.pxi +++ b/sklearn/neighbors/_binary_tree.pxi @@ -1053,9 +1053,13 @@ cdef class BinaryTree: if leaf_size < 1: raise ValueError("leaf_size must be greater than or equal to 1") + longest_data = max(len(item) for item in data) + if np.any([len(item) < longest_data for item in data]): + raise ValueError("Input points must have " + "the same number of dimensions") + n_samples = data.shape[0] n_features = data.shape[1] - self.data_arr = np.asarray(data, dtype=DTYPE, order='C') self.leaf_size = leaf_size self.dist_metric = DistanceMetric.get_metric(metric, **kwargs) diff --git a/sklearn/neighbors/tests/test_ball_tree.py b/sklearn/neighbors/tests/test_ball_tree.py index 8da703dbe207d..92fe3ae4f4f39 100644 --- a/sklearn/neighbors/tests/test_ball_tree.py +++ b/sklearn/neighbors/tests/test_ball_tree.py @@ -65,3 +65,11 @@ def test_query_haversine(): assert_array_almost_equal(dist1, dist2) assert_array_almost_equal(ind1, ind2) + + +def test_different_dimension_size(): + X = [(1, 2, 3), (2, 5), (5, 5, 1, 2)] + Y = np.array(X) + msg = ("Input points must have the same number of dimensions") + with pytest.raises(ValueError, match=msg): + BallTree(Y)