diff --git a/R-package/R/xgb.Booster.R b/R-package/R/xgb.Booster.R index 426c78a95942..487e6957fd74 100644 --- a/R-package/R/xgb.Booster.R +++ b/R-package/R/xgb.Booster.R @@ -1,7 +1,7 @@ # Construct an internal xgboost Booster and return a handle to it. # internal utility function xgb.Booster.handle <- function(params = list(), cachelist = list(), - modelfile = NULL) { + modelfile = NULL, handle = NULL) { if (typeof(cachelist) != "list" || !all(vapply(cachelist, inherits, logical(1), what = 'xgb.DMatrix'))) { stop("cachelist must be a list of xgb.DMatrix objects") @@ -20,7 +20,7 @@ xgb.Booster.handle <- function(params = list(), cachelist = list(), return(handle) } else if (typeof(modelfile) == "raw") { ## A memory buffer - bst <- xgb.unserialize(modelfile) + bst <- xgb.unserialize(modelfile, handle) xgb.parameters(bst) <- params return (bst) } else if (inherits(modelfile, "xgb.Booster")) { @@ -129,7 +129,7 @@ xgb.Booster.complete <- function(object, saveraw = TRUE) { stop("argument type must be xgb.Booster") if (is.null.handle(object$handle)) { - object$handle <- xgb.Booster.handle(modelfile = object$raw) + object$handle <- xgb.Booster.handle(modelfile = object$raw, handle = object$handle) } else { if (is.null(object$raw) && saveraw) { object$raw <- xgb.serialize(object$handle) diff --git a/R-package/R/xgb.unserialize.R b/R-package/R/xgb.unserialize.R index 411225f89769..e666eb0550b6 100644 --- a/R-package/R/xgb.unserialize.R +++ b/R-package/R/xgb.unserialize.R @@ -1,11 +1,21 @@ #' Load the instance back from \code{\link{xgb.serialize}} #' #' @param buffer the buffer containing booster instance saved by \code{\link{xgb.serialize}} +#' @param handle An \code{xgb.Booster.handle} object which will be overwritten with +#' the new deserialized object. Must be a null handle (e.g. when loading the model through +#' `readRDS`). If not provided, a new handle will be created. +#' @return An \code{xgb.Booster.handle} object. 
#' #' @export -xgb.unserialize <- function(buffer) { +xgb.unserialize <- function(buffer, handle = NULL) { cachelist <- list() - handle <- .Call(XGBoosterCreate_R, cachelist) + if (is.null(handle)) { + handle <- .Call(XGBoosterCreate_R, cachelist) + } else { + if (!is.null.handle(handle)) + stop("'handle' is not null/empty. Cannot overwrite existing handle.") + .Call(XGBoosterCreateInEmptyObj_R, cachelist, handle) + } tryCatch( .Call(XGBoosterUnserializeFromBuffer_R, handle, buffer), error = function(e) { diff --git a/R-package/man/xgb.unserialize.Rd b/R-package/man/xgb.unserialize.Rd index 7a11c5c5eca2..d191d77d4ac3 100644 --- a/R-package/man/xgb.unserialize.Rd +++ b/R-package/man/xgb.unserialize.Rd @@ -4,10 +4,17 @@ \alias{xgb.unserialize} \title{Load the instance back from \code{\link{xgb.serialize}}} \usage{ -xgb.unserialize(buffer) +xgb.unserialize(buffer, handle = NULL) } \arguments{ \item{buffer}{the buffer containing booster instance saved by \code{\link{xgb.serialize}}} + +\item{handle}{An \code{xgb.Booster.handle} object which will be overwritten with +the new deserialized object. Must be a null handle (e.g. when loading the model through +`readRDS`). If not provided, a new handle will be created.} +} +\value{ +An \code{xgb.Booster.handle} object. } \description{ Load the instance back from \code{\link{xgb.serialize}} diff --git a/R-package/src/init.c b/R-package/src/init.c index 141e46e8964a..f8f71843bd8e 100644 --- a/R-package/src/init.c +++ b/R-package/src/init.c @@ -17,6 +17,7 @@ Check these declarations against the C/Fortran source code. 
/* .Call calls */ extern SEXP XGBoosterBoostOneIter_R(SEXP, SEXP, SEXP, SEXP); extern SEXP XGBoosterCreate_R(SEXP); +extern SEXP XGBoosterCreateInEmptyObj_R(SEXP, SEXP); extern SEXP XGBoosterDumpModel_R(SEXP, SEXP, SEXP, SEXP); extern SEXP XGBoosterEvalOneIter_R(SEXP, SEXP, SEXP, SEXP); extern SEXP XGBoosterGetAttrNames_R(SEXP); @@ -49,6 +50,7 @@ extern SEXP XGBGetGlobalConfig_R(); static const R_CallMethodDef CallEntries[] = { {"XGBoosterBoostOneIter_R", (DL_FUNC) &XGBoosterBoostOneIter_R, 4}, {"XGBoosterCreate_R", (DL_FUNC) &XGBoosterCreate_R, 1}, + {"XGBoosterCreateInEmptyObj_R", (DL_FUNC) &XGBoosterCreateInEmptyObj_R, 2}, {"XGBoosterDumpModel_R", (DL_FUNC) &XGBoosterDumpModel_R, 4}, {"XGBoosterEvalOneIter_R", (DL_FUNC) &XGBoosterEvalOneIter_R, 4}, {"XGBoosterGetAttrNames_R", (DL_FUNC) &XGBoosterGetAttrNames_R, 1}, diff --git a/R-package/src/xgboost_R.cc b/R-package/src/xgboost_R.cc index 8bff4212a718..ccb193bc7833 100644 --- a/R-package/src/xgboost_R.cc +++ b/R-package/src/xgboost_R.cc @@ -272,6 +272,21 @@ SEXP XGBoosterCreate_R(SEXP dmats) { return ret; } +SEXP XGBoosterCreateInEmptyObj_R(SEXP dmats, SEXP R_handle) { + R_API_BEGIN(); + int len = length(dmats); + std::vector<void*> dvec; + for (int i = 0; i < len; ++i) { + dvec.push_back(R_ExternalPtrAddr(VECTOR_ELT(dmats, i))); + } + BoosterHandle handle; + CHECK_CALL(XGBoosterCreate(BeginPtr(dvec), dvec.size(), &handle)); + R_SetExternalPtrAddr(R_handle, handle); + R_RegisterCFinalizerEx(R_handle, _BoosterFinalizer, TRUE); + R_API_END(); + return R_NilValue; +} + SEXP XGBoosterSetParam_R(SEXP handle, SEXP name, SEXP val) { R_API_BEGIN(); CHECK_CALL(XGBoosterSetParam(R_ExternalPtrAddr(handle), diff --git a/R-package/src/xgboost_R.h b/R-package/src/xgboost_R.h index 4647c7de233a..95a62aaedde4 100644 --- a/R-package/src/xgboost_R.h +++ b/R-package/src/xgboost_R.h @@ -116,6 +116,14 @@ XGB_DLL SEXP XGDMatrixNumCol_R(SEXP handle); */ XGB_DLL SEXP XGBoosterCreate_R(SEXP dmats); + +/*! 
+ * \brief create xgboost learner, saving the pointer into an existing R object + * \param dmats a list of dmatrix handles that will be cached + * \param R_handle a clean R external pointer (not holding any object) + */ +XGB_DLL SEXP XGBoosterCreateInEmptyObj_R(SEXP dmats, SEXP R_handle); + /*! * \brief set parameters * \param handle handle diff --git a/R-package/tests/testthat/test_helpers.R b/R-package/tests/testthat/test_helpers.R index 5638f70cb554..19709cb38875 100644 --- a/R-package/tests/testthat/test_helpers.R +++ b/R-package/tests/testthat/test_helpers.R @@ -238,12 +238,13 @@ if (grepl('Windows', Sys.info()[['sysname']]) || test_that("xgb.Booster serializing as R object works", { saveRDS(bst.Tree, 'xgb.model.rds') bst <- readRDS('xgb.model.rds') - if (file.exists('xgb.model.rds')) file.remove('xgb.model.rds') dtrain <- xgb.DMatrix(sparse_matrix, label = label) expect_equal(predict(bst.Tree, dtrain), predict(bst, dtrain), tolerance = float_tolerance) expect_equal(xgb.dump(bst.Tree), xgb.dump(bst)) xgb.save(bst, 'xgb.model') if (file.exists('xgb.model')) file.remove('xgb.model') + bst <- readRDS('xgb.model.rds') + if (file.exists('xgb.model.rds')) file.remove('xgb.model.rds') nil_ptr <- new("externalptr") class(nil_ptr) <- "xgb.Booster.handle" expect_true(identical(bst$handle, nil_ptr)) diff --git a/R-package/tests/testthat/test_model_compatibility.R b/R-package/tests/testthat/test_model_compatibility.R index d94f17f29ce7..0f13bdc73f40 100644 --- a/R-package/tests/testthat/test_model_compatibility.R +++ b/R-package/tests/testthat/test_model_compatibility.R @@ -83,6 +83,7 @@ test_that("Models from previous versions of XGBoost can be loaded", { if (is_rds && compareVersion(model_xgb_ver, '1.1.1.1') < 0) { booster <- readRDS(model_file) expect_warning(predict(booster, newdata = pred_data)) + booster <- readRDS(model_file) expect_warning(run_booster_check(booster, name)) } else { if (is_rds) {