From 4c04f4023d23e238324ec4c844d90dd6a45420aa Mon Sep 17 00:00:00 2001
From: fis
Date: Fri, 25 Jun 2021 18:46:12 +0800
Subject: [PATCH 01/10] Error on get_split_value_histogram when feature is categorical.

---
 python-package/xgboost/core.py    | 26 ++++++++++++++++++++++++--
 tests/python/test_with_sklearn.py |  7 +++++++
 2 files changed, 31 insertions(+), 2 deletions(-)

diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index f4fe1b3967fe..d521293e3836 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -2381,9 +2381,31 @@ def get_split_value_histogram(self, feature, fmap='', bins=None,
         nph = np.column_stack((nph[1][1:], nph[0]))
         nph = nph[nph[:, 1] > 0]
 
+        if nph.size == 0:
+            ft = self.feature_types
+            fn = self.feature_names
+            if fn is None:
+                # Let xgboost generate the feature names.
+                fn = self.get_fscore().keys()
+            index = -1
+            try:
+                index = fn.index(feature)
+                feature_t = ft[index]
+            except (ValueError, AttributeError, TypeError):
+                # None.index: attr err, None[0]: type err, fn.index(-1): value err
+                feature_t = None
+                pass
+            if feature_t == "categorical":
+                raise ValueError(
+                    "Split value histogram doesn't support categorical split."
+                )
+
         if as_pandas and PANDAS_INSTALLED:
             return DataFrame(nph, columns=['SplitValue', 'Count'])
         if as_pandas and not PANDAS_INSTALLED:
-            sys.stderr.write(
-                "Returning histogram as ndarray (as_pandas == True, but pandas is not installed).")
+            warnings.warn(
+                "Returning histogram as ndarray"
+                " (as_pandas == True, but pandas is not installed).",
+                UserWarning
+            )
         return nph
diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py
index d44d0e3af571..78bca9d96ed9 100644
--- a/tests/python/test_with_sklearn.py
+++ b/tests/python/test_with_sklearn.py
@@ -628,6 +628,13 @@ def test_split_value_histograms():
     assert gbdt.get_split_value_histogram("f28", bins=5).shape[0] == 2
     assert gbdt.get_split_value_histogram("f28", bins=None).shape[0] == 2
 
+    X, y = tm.make_categorical(1000, 10, 13, False)
+    reg = xgb.XGBRegressor(tree_method="gpu_hist", enable_categorical=True)
+    reg.fit(X, y)
+
+    with pytest.raises(ValueError, match="doesn't"):
+        reg.get_booster().get_split_value_histogram("3", bins=5)
+
 
 def test_sklearn_random_state():
     clf = xgb.XGBClassifier(random_state=402)

From 346d940c727c7e19ccc4847bcd8d184722af23ac Mon Sep 17 00:00:00 2001
From: fis
Date: Fri, 25 Jun 2021 19:46:09 +0800
Subject: [PATCH 02/10] Support df.

---
 python-package/xgboost/core.py          | 20 +++++++++++++++++---
 tests/python-gpu/test_gpu_parse_tree.py | 15 +++++++++++++++
 tests/python/test_parse_tree.py         |  1 +
 3 files changed, 33 insertions(+), 3 deletions(-)
 create mode 100644 tests/python-gpu/test_gpu_parse_tree.py

diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index d521293e3836..35473122fefa 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -2251,6 +2251,7 @@ def trees_to_dataframe(self, fmap=''):
         node_ids = []
         fids = []
         splits = []
+        categories = []
         y_directs = []
         n_directs = []
         missings = []
@@ -2275,6 +2276,7 @@ def trees_to_dataframe(self, fmap=''):
                     node_ids.append(int(re.findall(r'\b\d+\b', parse[0])[0]))
                     fids.append('Leaf')
                     splits.append(float('NAN'))
+                    categories.append(float('NAN'))
                     y_directs.append(float('NAN'))
                     n_directs.append(float('NAN'))
                     missings.append(float('NAN'))
@@ -2284,14 +2286,26 @@ def trees_to_dataframe(self, fmap=''):
                 else:
                     # parse string
                     fid = arr[1].split(']')
-                    parse = fid[0].split('<')
+                    if fid[0].find("<") != -1:
+                        # numerical
+                        parse = fid[0].split('<')
+                        splits.append(float(parse[1]))
+                        categories.append(None)
+                    elif fid[0].find(":{") != -1:
+                        # categorical
+                        parse = fid[0].split(":")
+                        cats = parse[1][1:-1]  # strip the {}
+                        cats = cats.split(",")
+                        splits.append(float("NAN"))
+                        categories.append(cats if cats else None)
+                    else:
+                        raise ValueError("Failed to parse model text dump.")
                     stats = re.split('=|,', fid[1])
 
                     # append to lists
                     tree_ids.append(i)
                     node_ids.append(int(re.findall(r'\b\d+\b', arr[0])[0]))
                     fids.append(parse[0])
-                    splits.append(float(parse[1]))
                     str_i = str(i)
                     y_directs.append(str_i + '-' + stats[1])
                     n_directs.append(str_i + '-' + stats[3])
@@ -2303,7 +2317,7 @@ def trees_to_dataframe(self, fmap=''):
         df = DataFrame({'Tree': tree_ids, 'Node': node_ids, 'ID': ids,
                         'Feature': fids, 'Split': splits, 'Yes': y_directs,
                         'No': n_directs, 'Missing': missings, 'Gain': gains,
-                        'Cover': covers})
+                        'Cover': covers, "Categories": categories})
 
         if callable(getattr(df, 'sort_values', None)):
             # pylint: disable=no-member
diff --git a/tests/python-gpu/test_gpu_parse_tree.py b/tests/python-gpu/test_gpu_parse_tree.py
new file mode 100644
index 000000000000..47396abd2dfe
--- /dev/null
+++ b/tests/python-gpu/test_gpu_parse_tree.py
@@ -0,0 +1,15 @@
+import sys
+import xgboost as xgb
+
+sys.path.append("tests/python")
+import testing as tm
+
+
+def test_tree_to_df_categorical():
+    X, y = tm.make_categorical(100, 10, 31, False)
+    Xy = xgb.DMatrix(X, y, enable_categorical=True)
+    booster = xgb.train({"tree_method": "gpu_hist"}, Xy, num_boost_round=10)
+    df = booster.trees_to_dataframe()
+    for _, x in df.iterrows():
+        if x["Feature"] != "Leaf":
+            assert len(x["Categories"]) == 1
diff --git a/tests/python/test_parse_tree.py b/tests/python/test_parse_tree.py
index 90180cf6a094..2fc8cb314827 100644
--- a/tests/python/test_parse_tree.py
+++ b/tests/python/test_parse_tree.py
@@ -41,6 +41,7 @@ def test_trees_to_dataframe(self):
         # method being tested
         df = bst.trees_to_dataframe()
+        print(df)
 
         # test for equality of gains
         gain_from_df = df[df.Feature != 'Leaf'][['Gain']].sum()
         assert np.allclose(gain_from_dump, gain_from_df)

From 672cd2beeb314bb7dafed976df561d43c922de59 Mon Sep 17 00:00:00 2001
From: fis
Date: Fri, 25 Jun 2021 20:06:41 +0800
Subject: [PATCH 03/10] Singular.

---
 python-package/xgboost/core.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index 35473122fefa..1274a334e26f 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -2317,7 +2317,7 @@ def trees_to_dataframe(self, fmap=''):
         df = DataFrame({'Tree': tree_ids, 'Node': node_ids, 'ID': ids,
                         'Feature': fids, 'Split': splits, 'Yes': y_directs,
                         'No': n_directs, 'Missing': missings, 'Gain': gains,
-                        'Cover': covers, "Categories": categories})
+                        'Cover': covers, "Category": categories})
 
         if callable(getattr(df, 'sort_values', None)):
             # pylint: disable=no-member

From 70eb07732d2fcc9fa1f6127304211dd1e4a85659 Mon Sep 17 00:00:00 2001
From: fis
Date: Fri, 25 Jun 2021 20:08:28 +0800
Subject: [PATCH 04/10] cleanup.

---
 python-package/xgboost/core.py  | 3 +--
 tests/python/test_parse_tree.py | 1 -
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index 1274a334e26f..514810a2eb76 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -2225,7 +2225,7 @@ def get_score(
             results[feat] = float(score)
         return results
 
-    def trees_to_dataframe(self, fmap=''):
+    def trees_to_dataframe(self, fmap=''):  # pylint: disable=too-many-statements
         """Parse a boosted tree model text dump into a pandas DataFrame structure.
 
         This feature is only defined when the decision tree model is chosen as base
@@ -2408,7 +2408,6 @@ def get_split_value_histogram(self, feature, fmap='', bins=None,
             except (ValueError, AttributeError, TypeError):
                 # None.index: attr err, None[0]: type err, fn.index(-1): value err
                 feature_t = None
-                pass
             if feature_t == "categorical":
                 raise ValueError(
                     "Split value histogram doesn't support categorical split."
diff --git a/tests/python/test_parse_tree.py b/tests/python/test_parse_tree.py
index 2fc8cb314827..90180cf6a094 100644
--- a/tests/python/test_parse_tree.py
+++ b/tests/python/test_parse_tree.py
@@ -41,7 +41,6 @@ def test_trees_to_dataframe(self):
         # method being tested
         df = bst.trees_to_dataframe()
-        print(df)
 
         # test for equality of gains
         gain_from_df = df[df.Feature != 'Leaf'][['Gain']].sum()
         assert np.allclose(gain_from_dump, gain_from_df)

From 862909ca5a25e8fe8c9bfa5ff81011d8146888f6 Mon Sep 17 00:00:00 2001
From: fis
Date: Fri, 25 Jun 2021 20:16:46 +0800
Subject: [PATCH 05/10] fix test.

---
 tests/python-gpu/test_gpu_parse_tree.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python-gpu/test_gpu_parse_tree.py b/tests/python-gpu/test_gpu_parse_tree.py
index 47396abd2dfe..dac4d1dce2ac 100644
--- a/tests/python-gpu/test_gpu_parse_tree.py
+++ b/tests/python-gpu/test_gpu_parse_tree.py
@@ -12,4 +12,4 @@ def test_tree_to_df_categorical():
     df = booster.trees_to_dataframe()
     for _, x in df.iterrows():
         if x["Feature"] != "Leaf":
-            assert len(x["Categories"]) == 1
+            assert len(x["Category"]) == 1

From 1bf07317c403a820d5af0ea95310af902d2b5b98 Mon Sep 17 00:00:00 2001
From: fis
Date: Fri, 25 Jun 2021 20:36:07 +0800
Subject: [PATCH 06/10] Unused code.

---
 python-package/xgboost/core.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index 514810a2eb76..4e6e5407a8f9 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -2401,7 +2401,6 @@ def get_split_value_histogram(self, feature, fmap='', bins=None,
             if fn is None:
                 # Let xgboost generate the feature names.
                 fn = self.get_fscore().keys()
-            index = -1
             try:
                 index = fn.index(feature)
                 feature_t = ft[index]

From 7c76d74048c7385b7b06ede3e29bb04c30975b6c Mon Sep 17 00:00:00 2001
From: fis
Date: Fri, 25 Jun 2021 21:04:08 +0800
Subject: [PATCH 07/10] Fix test.

---
 tests/python-gpu/test_gpu_parse_tree.py | 10 ++++++++++
 tests/python/test_with_sklearn.py       |  7 -------
 2 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/tests/python-gpu/test_gpu_parse_tree.py b/tests/python-gpu/test_gpu_parse_tree.py
index dac4d1dce2ac..8033fb9852d1 100644
--- a/tests/python-gpu/test_gpu_parse_tree.py
+++ b/tests/python-gpu/test_gpu_parse_tree.py
@@ -1,4 +1,5 @@
 import sys
+import pytest
 import xgboost as xgb
 
 sys.path.append("tests/python")
 import testing as tm
@@ -13,3 +14,12 @@ def test_tree_to_df_categorical():
     for _, x in df.iterrows():
         if x["Feature"] != "Leaf":
             assert len(x["Category"]) == 1
+
+
+def test_split_value_histograms():
+    X, y = tm.make_categorical(1000, 10, 13, False)
+    reg = xgb.XGBRegressor(tree_method="gpu_hist", enable_categorical=True)
+    reg.fit(X, y)
+
+    with pytest.raises(ValueError, match="doesn't"):
+        reg.get_booster().get_split_value_histogram("3", bins=5)
diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py
index 78bca9d96ed9..d44d0e3af571 100644
--- a/tests/python/test_with_sklearn.py
+++ b/tests/python/test_with_sklearn.py
@@ -628,13 +628,6 @@ def test_split_value_histograms():
     assert gbdt.get_split_value_histogram("f28", bins=5).shape[0] == 2
     assert gbdt.get_split_value_histogram("f28", bins=None).shape[0] == 2
 
-    X, y = tm.make_categorical(1000, 10, 13, False)
-    reg = xgb.XGBRegressor(tree_method="gpu_hist", enable_categorical=True)
-    reg.fit(X, y)
-
-    with pytest.raises(ValueError, match="doesn't"):
-        reg.get_booster().get_split_value_histogram("3", bins=5)
-
 
 def test_sklearn_random_state():
     clf = xgb.XGBClassifier(random_state=402)

From d13d334809c67c8bc961cde508175d804bf37157 Mon Sep 17 00:00:00 2001
From: fis
Date: Sat, 26 Jun 2021 02:49:30 +0800
Subject: [PATCH 08/10] Note about the flaky test.

---
 tests/python-gpu/test_gpu_updaters.py | 46 +++++++++++++++++----------
 1 file changed, 30 insertions(+), 16 deletions(-)

diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py
index 3c3a7e045058..8eb492895405 100644
--- a/tests/python-gpu/test_gpu_updaters.py
+++ b/tests/python-gpu/test_gpu_updaters.py
@@ -32,15 +32,14 @@ def train_result(param, dmat, num_rounds):
 
 
 class TestGPUUpdaters:
-    @given(parameter_strategy, strategies.integers(1, 20),
-           tm.dataset_strategy)
+    @given(parameter_strategy, strategies.integers(1, 20), tm.dataset_strategy)
     @settings(deadline=None)
     def test_gpu_hist(self, param, num_rounds, dataset):
-        param['tree_method'] = 'gpu_hist'
+        param["tree_method"] = "gpu_hist"
         param = dataset.set_params(param)
         result = train_result(param, dataset.get_dmat(), num_rounds)
         note(result)
-        assert tm.non_increasing(result['train'][dataset.metric])
+        assert tm.non_increasing(result["train"][dataset.metric])
 
     def run_categorical_basic(self, rows, cols, rounds, cats):
         onehot, label = tm.make_categorical(rows, cols, cats, True)
@@ -49,25 +48,40 @@ def run_categorical_basic(self, rows, cols, rounds, cats):
         by_etl_results = {}
         by_builtin_results = {}
 
-        parameters = {'tree_method': 'gpu_hist', 'predictor': 'gpu_predictor'}
+        parameters = {"tree_method": "gpu_hist", "predictor": "gpu_predictor"}
 
         m = xgb.DMatrix(onehot, label, enable_categorical=True)
-        xgb.train(parameters, m,
-                  num_boost_round=rounds,
-                  evals=[(m, 'Train')], evals_result=by_etl_results)
+        xgb.train(
+            parameters,
+            m,
+            num_boost_round=rounds,
+            evals=[(m, "Train")],
+            evals_result=by_etl_results,
+        )
 
         m = xgb.DMatrix(cat, label, enable_categorical=True)
-        xgb.train(parameters, m,
-                  num_boost_round=rounds,
-                  evals=[(m, 'Train')], evals_result=by_builtin_results)
+        xgb.train(
+            parameters,
+            m,
+            num_boost_round=rounds,
+            evals=[(m, "Train")],
+            evals_result=by_builtin_results,
+        )
+
+        # There are guidelines on how to specify tolerance based on considering output as
+        # random variables. But here the tree construction is extremely sensitive to
+        # floating point errors. A 1e-5 error in a histogram bin can lead to an entirely
+        # different tree. So even though the test is quite lenient, hypothesis can still
+        # pick up falsifying examples from time to time.
         np.testing.assert_allclose(
-            np.array(by_etl_results['Train']['rmse']),
-            np.array(by_builtin_results['Train']['rmse']),
-            rtol=1e-3)
-        assert tm.non_increasing(by_builtin_results['Train']['rmse'])
+            np.array(by_etl_results["Train"]["rmse"]),
+            np.array(by_builtin_results["Train"]["rmse"]),
+            rtol=1e-3,
+        )
+        assert tm.non_increasing(by_builtin_results["Train"]["rmse"])
 
     @given(strategies.integers(10, 400), strategies.integers(3, 8),
-           strategies.integers(1, 5), strategies.integers(4, 7))
+           strategies.integers(1, 2), strategies.integers(4, 7))
     @settings(deadline=None)
     @pytest.mark.skipif(**tm.no_pandas())
     def test_categorical(self, rows, cols, rounds, cats):

From 49ae2c4b8b60d6e1f5c2e585be2b18ad45502876 Mon Sep 17 00:00:00 2001
From: fis
Date: Sat, 26 Jun 2021 16:03:51 +0800
Subject: [PATCH 09/10] full range.

---
 python-package/xgboost/core.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index 4e6e5407a8f9..61356fc52e24 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -2400,7 +2400,7 @@ def get_split_value_histogram(self, feature, fmap='', bins=None,
             fn = self.feature_names
             if fn is None:
                 # Let xgboost generate the feature names.
-                fn = self.get_fscore().keys()
+                fn = ["f{0}".format(i) for i in range(self.num_features())]
             try:
                 index = fn.index(feature)
                 feature_t = ft[index]

From 470380652b479f458766408881abc42acab1036e Mon Sep 17 00:00:00 2001
From: fis
Date: Sat, 26 Jun 2021 16:12:34 +0800
Subject: [PATCH 10/10] Small fix to test.

---
 tests/python-gpu/test_gpu_updaters.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py
index 8eb492895405..11140a7083dc 100644
--- a/tests/python-gpu/test_gpu_updaters.py
+++ b/tests/python-gpu/test_gpu_updaters.py
@@ -50,7 +50,7 @@ def run_categorical_basic(self, rows, cols, rounds, cats):
 
         parameters = {"tree_method": "gpu_hist", "predictor": "gpu_predictor"}
 
-        m = xgb.DMatrix(onehot, label, enable_categorical=True)
+        m = xgb.DMatrix(onehot, label, enable_categorical=False)
         xgb.train(
             parameters,
             m,