Skip to content

Commit

Permalink
MNT/TST Replace boston by synthetic dataset in ensemble test_forest (s…
Browse files Browse the repository at this point in the history
  • Loading branch information
lucyleeow authored and jayzed82 committed Oct 22, 2020
1 parent 42e5713 commit 9399227
Showing 1 changed file with 26 additions and 27 deletions.
53 changes: 26 additions & 27 deletions sklearn/ensemble/tests/test_forest.py
Expand Up @@ -72,12 +72,9 @@
iris.data = iris.data[perm]
iris.target = iris.target[perm]

# also load the boston dataset
# and randomly permute it
boston = datasets.load_boston()
perm = rng.permutation(boston.target.size)
boston.data = boston.data[perm]
boston.target = boston.target[perm]
# Make regression dataset
X_reg, y_reg = datasets.make_regression(n_samples=500, n_features=10,
random_state=1)

# also make a hastie_10_2 dataset
hastie_X, hastie_y = datasets.make_hastie_10_2(n_samples=20, random_state=1)
Expand Down Expand Up @@ -159,29 +156,29 @@ def test_iris(name, criterion):
check_iris_criterion(name, criterion)


def check_boston_criterion(name, criterion):
# Check consistency on dataset boston house prices.
def check_regression_criterion(name, criterion):
# Check consistency on regression dataset.
ForestRegressor = FOREST_REGRESSORS[name]

reg = ForestRegressor(n_estimators=5, criterion=criterion,
random_state=1)
reg.fit(boston.data, boston.target)
score = reg.score(boston.data, boston.target)
assert score > 0.94, ("Failed with max_features=None, criterion %s "
reg.fit(X_reg, y_reg)
score = reg.score(X_reg, y_reg)
assert score > 0.93, ("Failed with max_features=None, criterion %s "
"and score = %f" % (criterion, score))

reg = ForestRegressor(n_estimators=5, criterion=criterion,
max_features=6, random_state=1)
reg.fit(boston.data, boston.target)
score = reg.score(boston.data, boston.target)
assert score > 0.95, ("Failed with max_features=6, criterion %s "
reg.fit(X_reg, y_reg)
score = reg.score(X_reg, y_reg)
assert score > 0.92, ("Failed with max_features=6, criterion %s "
"and score = %f" % (criterion, score))


@pytest.mark.parametrize('name', FOREST_REGRESSORS)
@pytest.mark.parametrize('criterion', ("mse", "mae", "friedman_mse"))
def test_boston(name, criterion):
check_boston_criterion(name, criterion)
def test_regression(name, criterion):
check_regression_criterion(name, criterion)


def check_regressor_attributes(name):
Expand Down Expand Up @@ -384,12 +381,10 @@ def check_oob_score(name, X, y, n_estimators=20):
n_samples = X.shape[0]
est.fit(X[:n_samples // 2, :], y[:n_samples // 2])
test_score = est.score(X[n_samples // 2:, :], y[n_samples // 2:])

if name in FOREST_CLASSIFIERS:
assert abs(test_score - est.oob_score_) < 0.1
else:
assert test_score > est.oob_score_
assert est.oob_score_ > .8
assert abs(test_score - est.oob_score_) < 0.1

# Check warning if not enough estimators
with np.errstate(divide="ignore", invalid="ignore"):
Expand All @@ -411,10 +406,10 @@ def test_oob_score_classifiers(name):

@pytest.mark.parametrize('name', FOREST_REGRESSORS)
def test_oob_score_regressors(name):
check_oob_score(name, boston.data, boston.target, 50)
check_oob_score(name, X_reg, y_reg, 50)

# csc matrix
check_oob_score(name, csc_matrix(boston.data), boston.target, 50)
check_oob_score(name, csc_matrix(X_reg), y_reg, 50)


def check_oob_score_raise_error(name):
Expand Down Expand Up @@ -475,11 +470,13 @@ def check_parallel(name, X, y):
@pytest.mark.parametrize('name', FOREST_CLASSIFIERS_REGRESSORS)
def test_parallel(name):
if name in FOREST_CLASSIFIERS:
ds = iris
X = iris.data
y = iris.target
elif name in FOREST_REGRESSORS:
ds = boston
X = X_reg
y = y_reg

check_parallel(name, ds.data, ds.target)
check_parallel(name, X, y)


def check_pickle(name, X, y):
Expand All @@ -500,11 +497,13 @@ def check_pickle(name, X, y):
@pytest.mark.parametrize('name', FOREST_CLASSIFIERS_REGRESSORS)
def test_pickle(name):
if name in FOREST_CLASSIFIERS:
ds = iris
X = iris.data
y = iris.target
elif name in FOREST_REGRESSORS:
ds = boston
X = X_reg
y = y_reg

check_pickle(name, ds.data[::2], ds.target[::2])
check_pickle(name, X[::2], y[::2])


def check_multioutput(name):
Expand Down

0 comments on commit 9399227

Please sign in to comment.