Skip to content

Commit

Permalink
Release patch release 3.9.1 (#520)
Browse files Browse the repository at this point in the history
* Release patch release 3.9.1

Apply patch: Fix loading from XGBoost 2.0 dev (#509)

* Removed deprecated parameter size_leaf_vector

* Handle iteration_indptr field in JSON

* Additional fix

* More fixes

* [CI] Use XGBoost 2.0

* Formatting check
  • Loading branch information
hcho3 committed Sep 13, 2023
1 parent d63e48c commit 346d925
Show file tree
Hide file tree
Showing 14 changed files with 103 additions and 67 deletions.
1 change: 0 additions & 1 deletion .readthedocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,3 @@ sphinx:
python:
install:
- requirements: docs/requirements.txt
system_packages: true
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
cmake_policy(SET CMP0091 NEW)
set(CMAKE_FIND_NO_INSTALL_PREFIX TRUE FORCE)
cmake_minimum_required (VERSION 3.16)
project(treelite LANGUAGES CXX C VERSION 3.9.0)
project(treelite LANGUAGES CXX C VERSION 3.9.1)

# check MSVC version
if(MSVC)
Expand Down
2 changes: 1 addition & 1 deletion ops/conda_env/dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ dependencies:
- llvm-openmp
- cython
- lightgbm
- xgboost
- cpplint
- pylint
- awscli
- pip
- pip:
- cibuildwheel
- xgboost>=2.0
4 changes: 2 additions & 2 deletions ops/cpp-python-coverage.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ echo "##[section]Building Treelite..."
mkdir build/
cd build/
cmake .. -DTEST_COVERAGE=ON -DCMAKE_BUILD_TYPE=Debug -DBUILD_CPP_TEST=ON -GNinja
ninja
ninja install -v
cd ..

echo "##[section]Running Google C++ tests..."
./build/treelite_cpp_test

echo "##[section]Build Cython extension..."
cd tests/cython
python setup.py build_ext --inplace
pip install -vvv .
cd ../..

echo "##[section]Running Python integration tests..."
Expand Down
2 changes: 1 addition & 1 deletion python/treelite/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3.9.0
3.9.1
2 changes: 1 addition & 1 deletion runtime/java/treelite4j/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
<modelVersion>4.0.0</modelVersion>
<groupId>ml.dmlc</groupId>
<artifactId>treelite4j</artifactId>
<version>3.9.0</version>
<version>3.9.1</version>
<packaging>jar</packaging>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
Expand Down
2 changes: 1 addition & 1 deletion runtime/python/treelite_runtime/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3.9.0
3.9.1
10 changes: 7 additions & 3 deletions src/frontend/xgboost.cc
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ class XGBTree {
nodes[nodes[nid].cleft() ].set_parent(nid, true);
nodes[nodes[nid].cright()].set_parent(nid, false);
}
inline void Load(PeekableInputStream* fi) {
inline void Load(PeekableInputStream* fi, LearnerModelParam const& mparam) {
TREELITE_CHECK_EQ(fi->Read(&param, sizeof(TreeParam)), sizeof(TreeParam))
<< "Ill-formed XGBoost model file: can't read TreeParam";
TREELITE_CHECK_GT(param.num_nodes, 0)
Expand All @@ -309,13 +309,17 @@ class XGBTree {
TREELITE_CHECK_EQ(fi->Read(stats.data(), sizeof(NodeStat) * stats.size()),
sizeof(NodeStat) * stats.size())
<< "Ill-formed XGBoost model file: cannot read specified number of nodes";
if (param.size_leaf_vector != 0) {
if (param.size_leaf_vector != 0 && mparam.major_version < 2) {
uint64_t len;
TREELITE_CHECK_EQ(fi->Read(&len, sizeof(len)), sizeof(len))
<< "Ill-formed XGBoost model file";
if (len > 0) {
CONSUME_BYTES(fi, sizeof(bst_float) * len);
}
} else if (mparam.major_version == 2) {
TREELITE_CHECK_EQ(param.size_leaf_vector, 1)
<< "Multi-target models are not supported with binary serialization. "
<< "Please save the XGBoost model using the JSON format.";
}
TREELITE_CHECK_EQ(param.num_roots, 1)
<< "Invalid XGBoost model file: treelite does not support trees "
Expand Down Expand Up @@ -378,7 +382,7 @@ inline std::unique_ptr<treelite::Model> ParseStream(std::istream& fi) {
<< "Invalid XGBoost model file: num_trees must be 0 or greater";
for (int i = 0; i < gbm_param_.num_trees; ++i) {
xgb_trees_.emplace_back();
xgb_trees_.back().Load(fp.get());
xgb_trees_.back().Load(fp.get(), mparam_);
}
if (mparam_.major_version < 1 || (mparam_.major_version == 1 && mparam_.minor_version < 6)) {
// In XGBoost 1.6, num_roots is used as num_parallel_tree, so don't check
Expand Down
12 changes: 8 additions & 4 deletions src/frontend/xgboost_json.cc
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,8 @@ bool GBTreeModelHandler::StartArray() {
return (push_key_handler<ArrayHandler<treelite::Tree<float, float>, RegTreeHandler>,
std::vector<treelite::Tree<float, float>>>(
"trees", output.model->trees) ||
push_key_handler<ArrayHandler<int>, std::vector<int>>("tree_info", output.tree_info));
push_key_handler<ArrayHandler<int>, std::vector<int>>("tree_info", output.tree_info) ||
push_key_handler<IgnoreHandler>("iteration_indptr"));
}

bool GBTreeModelHandler::StartObject() {
Expand All @@ -377,7 +378,8 @@ bool GBTreeModelHandler::StartObject() {
}

bool GBTreeModelHandler::is_recognized_key(const std::string& key) {
return (key == "trees" || key == "tree_info" || key == "gbtree_model_param");
return (key == "trees" || key == "tree_info" || key == "gbtree_model_param"
|| key == "iteration_indptr");
}

/******************************************************************************
Expand Down Expand Up @@ -460,7 +462,8 @@ bool ObjectiveHandler::StartObject() {
push_key_handler<IgnoreHandler>("lambda_rank_param") ||
push_key_handler<IgnoreHandler>("aft_loss_param") ||
push_key_handler<IgnoreHandler>("pseduo_huber_param") ||
push_key_handler<IgnoreHandler>("pseudo_huber_param"));
push_key_handler<IgnoreHandler>("pseudo_huber_param") ||
push_key_handler<IgnoreHandler>("lambdarank_param"));
}

bool ObjectiveHandler::String(const char *str, std::size_t length, bool) {
Expand All @@ -474,7 +477,8 @@ bool ObjectiveHandler::is_recognized_key(const std::string& key) {
return (key == "reg_loss_param" || key == "poisson_regression_param"
|| key == "tweedie_regression_param" || key == "softmax_multiclass_param"
|| key == "lambda_rank_param" || key == "aft_loss_param"
|| key == "pseduo_huber_param" || key == "pseudo_huber_param" || key == "name");
|| key == "pseduo_huber_param" || key == "pseudo_huber_param"
|| key == "lambdarank_param" || key == "name");
}

/******************************************************************************
Expand Down
40 changes: 40 additions & 0 deletions tests/cython/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
cmake_minimum_required(VERSION 3.18)

project(treelite_serializer_ext LANGUAGES CXX)

find_package(
Python
COMPONENTS Interpreter Development.Module
REQUIRED)

find_program(CYTHON "cython")

find_package(Treelite REQUIRED)

add_custom_command(
OUTPUT serializer.cpp
DEPENDS serializer.pyx
VERBATIM
COMMAND "${CYTHON}" "${PROJECT_SOURCE_DIR}/serializer.pyx" --output-file
"${PROJECT_BINARY_DIR}/serializer.cpp")

if(DEFINED ENV{CONDA_PREFIX})
set(CMAKE_PREFIX_PATH "$ENV{CONDA_PREFIX};${CMAKE_PREFIX_PATH}")
message(STATUS "Detected Conda environment, CMAKE_PREFIX_PATH set to: ${CMAKE_PREFIX_PATH}")
if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
message(STATUS "No CMAKE_INSTALL_PREFIX argument detected, setting to: $ENV{CONDA_PREFIX}")
set(CMAKE_INSTALL_PREFIX $ENV{CONDA_PREFIX})
endif()
else()
message(STATUS "No Conda environment detected")
endif()

python_add_library(serializer MODULE "${PROJECT_BINARY_DIR}/serializer.cpp" WITH_SOABI)
target_link_libraries(serializer PRIVATE treelite::treelite)
set_target_properties(serializer
PROPERTIES
POSITION_INDEPENDENT_CODE ON
CXX_STANDARD 17
CXX_STANDARD_REQUIRED ON)

install(TARGETS serializer DESTINATION "${PROJECT_SOURCE_DIR}")
7 changes: 7 additions & 0 deletions tests/cython/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[build-system]
requires = ["scikit-build-core", "cython"]
build-backend = "scikit_build_core.build"

[project]
name = "example"
version = "0.0.1"
27 changes: 0 additions & 27 deletions tests/cython/setup.py

This file was deleted.

26 changes: 15 additions & 11 deletions tests/python/test_gtil.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np
import pytest
import scipy
from hypothesis import assume, given, settings
from hypothesis import given, settings
from hypothesis.strategies import data as hypothesis_callback
from hypothesis.strategies import floats, integers, just, sampled_from
from sklearn.datasets import load_svmlight_file
Expand Down Expand Up @@ -198,16 +198,15 @@ def test_skl_hist_gradient_boosting_with_categorical():
treelite.sklearn.import_model(clf)


@pytest.mark.parametrize("objective",
[
"reg:linear",
"reg:squarederror",
"reg:squaredlogerror",
"reg:pseudohubererror",
])
@given(
dataset=standard_regression_datasets(),
objective=sampled_from(
[
"reg:linear",
"reg:squarederror",
"reg:squaredlogerror",
"reg:pseudohubererror",
]
),
model_format=sampled_from(["binary", "json"]),
num_boost_round=integers(min_value=5, max_value=50),
num_parallel_tree=integers(min_value=1, max_value=5),
Expand All @@ -218,9 +217,14 @@ def test_xgb_regression(
):
# pylint: disable=too-many-locals
"""Test XGBoost with regression data"""

# See https://github.com/dmlc/xgboost/pull/9574
if objective == "reg:pseudohubererror":
pytest.xfail("XGBoost 2.0 has a bug in the serialization of Pseudo-Huber error")

X, y = dataset
if objective == "reg:squaredlogerror":
assume(np.all(y > -1))
y = np.where(y <= -1, -0.9, y)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, shuffle=False
)
Expand Down Expand Up @@ -330,7 +334,7 @@ def test_xgb_multiclass_classifier(
("count:poisson", 4),
("rank:pairwise", 5),
("rank:ndcg", 5),
("rank:map", 5),
("rank:map", 2),
],
),
model_format=sampled_from(["binary", "json"]),
Expand Down
33 changes: 19 additions & 14 deletions tests/python/test_xgboost_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import numpy as np
import pytest
from hypothesis import assume, given, settings
from hypothesis import given, settings
from hypothesis.strategies import integers, lists, sampled_from
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
Expand All @@ -27,16 +27,15 @@
pytest.skip("XGBoost not installed; skipping", allow_module_level=True)


@pytest.mark.parametrize("objective",
[
"reg:linear",
"reg:squarederror",
"reg:squaredlogerror",
"reg:pseudohubererror",
])
@given(
toolchain=sampled_from(os_compatible_toolchains()),
objective=sampled_from(
[
"reg:linear",
"reg:squarederror",
"reg:squaredlogerror",
"reg:pseudohubererror",
]
),
model_format=sampled_from(["binary", "json"]),
num_parallel_tree=integers(min_value=1, max_value=10),
dataset=standard_regression_datasets(),
Expand All @@ -45,9 +44,14 @@
def test_xgb_regression(toolchain, objective, model_format, num_parallel_tree, dataset):
# pylint: disable=too-many-locals
"""Test a random regression dataset"""

# See https://github.com/dmlc/xgboost/pull/9574
if objective == "reg:pseudohubererror":
pytest.xfail("XGBoost 2.0 has a bug in the serialization of Pseudo-Huber error")

X, y = dataset
if objective == "reg:squaredlogerror":
assume(np.all(y > -1))
y = np.where(y <= -1, -0.9, y)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, shuffle=False
)
Expand All @@ -59,6 +63,7 @@ def test_xgb_regression(toolchain, objective, model_format, num_parallel_tree, d
"verbosity": 0,
"objective": objective,
"num_parallel_tree": num_parallel_tree,
"base_score": 0.0
}
num_round = 10
bst = xgb.train(
Expand Down Expand Up @@ -96,7 +101,7 @@ def test_xgb_regression(toolchain, objective, model_format, num_parallel_tree, d
assert predictor.num_feature == dtrain.num_col()
assert predictor.num_class == 1
assert predictor.pred_transform == "identity"
assert predictor.global_bias == 0.5
assert predictor.global_bias == 0.0
assert predictor.sigmoid_alpha == 1.0
dmat = treelite_runtime.DMatrix(X_test, dtype="float32")
out_pred = predictor.predict(dmat)
Expand Down Expand Up @@ -184,7 +189,7 @@ def test_xgb_iris(
("count:poisson", 4, math.log(0.5)),
("rank:pairwise", 5, 0.5),
("rank:ndcg", 5, 0.5),
("rank:map", 5, 0.5),
("rank:map", 2, 0.5),
],
ids=[
"binary:logistic",
Expand Down Expand Up @@ -276,7 +281,8 @@ def test_xgb_deserializers(toolchain, dataset):
)
dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test, label=y_test)
param = {"max_depth": 8, "eta": 1, "silent": 1, "objective": "reg:linear"}
param = {"max_depth": 8, "eta": 1, "silent": 1, "objective": "reg:linear",
"base_score": 0.5}
num_round = 10
bst = xgb.train(
param,
Expand Down Expand Up @@ -417,7 +423,6 @@ def test_xgb_dart(tmpdir, toolchain, model_format):
assert predictor.num_feature == dtrain.num_col()
assert predictor.num_class == 1
assert predictor.pred_transform == "sigmoid"
np.testing.assert_almost_equal(predictor.global_bias, 0, decimal=5)
assert predictor.sigmoid_alpha == 1.0
dmat = treelite_runtime.DMatrix(X, dtype="float32")
out_pred = predictor.predict(dmat)
Expand Down

0 comments on commit 346d925

Please sign in to comment.