From 16797f928162176c63936af37579a392cc4756a9 Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 17 Feb 2022 05:40:37 +0800 Subject: [PATCH] Ensure models with categorical splits don't use old binary format. --- src/tree/tree_model.cc | 1 + tests/python-gpu/test_gpu_updaters.py | 2 +- tests/python/test_basic_models.py | 23 +++++++++++++++++++++++ 3 files changed, 25 insertions(+), 1 deletion(-) diff --git a/src/tree/tree_model.cc b/src/tree/tree_model.cc index 07e0f1439574..f42624bbeb5a 100644 --- a/src/tree/tree_model.cc +++ b/src/tree/tree_model.cc @@ -868,6 +868,7 @@ void RegTree::Save(dmlc::Stream* fo) const { CHECK_EQ(param.num_nodes, static_cast<int>(stats_.size())); CHECK_EQ(param.deprecated_num_roots, 1); CHECK_NE(param.num_nodes, 0); + CHECK(!HasCategoricalSplit()) << "Please JSON/UBJSON for saving models with categorical splits."; if (DMLC_IO_NO_ENDIAN_SWAP) { fo->Write(&param, sizeof(TreeParam)); diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py index 86200f335a47..5c5d19644fdc 100644 --- a/tests/python-gpu/test_gpu_updaters.py +++ b/tests/python-gpu/test_gpu_updaters.py @@ -70,7 +70,7 @@ def test_categorical_32_cat(self): self.cputest.run_categorical_basic(rows, cols, rounds, cats, "gpu_hist") @pytest.mark.skipif(**tm.no_cupy()) - def test_invalid_categorical(self): + def test_invalid_category(self): self.cputest.run_invalid_category("gpu_hist") @pytest.mark.skipif(**tm.no_cupy()) diff --git a/tests/python/test_basic_models.py b/tests/python/test_basic_models.py index cbb7b1fd9268..9597632021f8 100644 --- a/tests/python/test_basic_models.py +++ b/tests/python/test_basic_models.py @@ -381,6 +381,29 @@ def validate_model(parameters): 'objective': 'multi:softmax'} validate_model(parameters) + def test_categorical_model_io(self): + X, y = tm.make_categorical(256, 16, 71, False) + Xy = xgb.DMatrix(X, y, enable_categorical=True) + booster = xgb.train({"tree_method": "approx"}, Xy, num_boost_round=16) + predt_0 = booster.predict(Xy) 
+ + with tempfile.TemporaryDirectory() as tempdir: + path = os.path.join(tempdir, "model.binary") + with pytest.raises(ValueError, match=r".*JSON/UBJSON.*"): + booster.save_model(path) + + path = os.path.join(tempdir, "model.json") + booster.save_model(path) + booster = xgb.Booster(model_file=path) + predt_1 = booster.predict(Xy) + np.testing.assert_allclose(predt_0, predt_1) + + path = os.path.join(tempdir, "model.ubj") + booster.save_model(path) + booster = xgb.Booster(model_file=path) + predt_1 = booster.predict(Xy) + np.testing.assert_allclose(predt_0, predt_1) + @pytest.mark.skipif(**tm.no_sklearn()) def test_attributes(self): from sklearn.datasets import load_iris