Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[R-package] User-friendly redesign for lightgbm() #4968

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 12 additions & 0 deletions R-package/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ S3method("dimnames<-",lgb.Dataset)
S3method(dim,lgb.Dataset)
S3method(dimnames,lgb.Dataset)
S3method(get_field,lgb.Dataset)
S3method(lightgbm,data.frame)
S3method(lightgbm,dgCMatrix)
S3method(lightgbm,formula)
S3method(lightgbm,matrix)
S3method(predict,lgb.Booster)
S3method(print,lgb.Booster)
S3method(set_field,lgb.Dataset)
Expand Down Expand Up @@ -38,11 +42,16 @@ export(saveRDS.lgb.Booster)
export(set_field)
export(slice)
import(methods)
importClassesFrom(Matrix,CsparseMatrix)
importClassesFrom(Matrix,dgCMatrix)
importClassesFrom(Matrix,sparseMatrix)
importClassesFrom(Matrix,sparseVector)
importFrom(Matrix,Matrix)
importFrom(R6,R6Class)
importFrom(data.table,":=")
importFrom(data.table,as.data.table)
importFrom(data.table,data.table)
importFrom(data.table,is.data.table)
importFrom(data.table,rbindlist)
importFrom(data.table,set)
importFrom(data.table,setnames)
Expand All @@ -51,8 +60,11 @@ importFrom(data.table,setorderv)
importFrom(graphics,barplot)
importFrom(graphics,par)
importFrom(jsonlite,fromJSON)
importFrom(methods,as)
importFrom(methods,is)
importFrom(parallel,detectCores)
importFrom(stats,quantile)
importFrom(utils,head)
importFrom(utils,modifyList)
importFrom(utils,read.delim)
useDynLib(lib_lightgbm , .registration = TRUE)
204 changes: 157 additions & 47 deletions R-package/R/lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Booster <- R6::R6Class(
best_score = NA_real_,
params = list(),
record_evals = list(),
data_processor = NULL,

# Finalize will free up the handles
finalize = function() {
Expand Down Expand Up @@ -497,6 +498,10 @@ Booster <- R6::R6Class(

self$restore_handle()

if (!is.null(self$data_processor)) {
data <- self$data_processor$process_new_data(data)
}

if (is.null(num_iteration)) {
num_iteration <- self$best_iter
}
Expand All @@ -510,19 +515,20 @@ Booster <- R6::R6Class(
modelfile = private$handle
, params = params
)
return(
predictor$predict(
data = data
, start_iteration = start_iteration
, num_iteration = num_iteration
, rawscore = rawscore
, predleaf = predleaf
, predcontrib = predcontrib
, header = header
, reshape = reshape
)
pred <- predictor$predict(
data = data
, start_iteration = start_iteration
, num_iteration = num_iteration
, rawscore = rawscore
, predleaf = predleaf
, predcontrib = predcontrib
, header = header
, reshape = reshape
)

if (!predleaf && !is.null(self$data_processor)) {
pred <- self$data_processor$process_predictions(pred, predcontrib)
}
return(pred)
},

# Transform into predictor
Expand Down Expand Up @@ -729,10 +735,60 @@ Booster <- R6::R6Class(

#' @name predict.lgb.Booster
#' @title Predict method for LightGBM model
#' @description Predicted values based on class \code{lgb.Booster}
#' @param object Object of class \code{lgb.Booster}
#' @param data a \code{matrix} object, a \code{dgCMatrix} object or
#' a character representing a path to a text file (CSV, TSV, or LibSVM)
#' @description Predict values on new data based on a boosting model (class \code{lgb.Booster}).
#' @param object Object of class \code{lgb.Booster} from which to make predictions.
#' @param newdata New data on which to make predictions. Allowed types are:\itemize{
#' \item `data.frame`, \bold{only if} the model object was produced through the \link{lightgbm}
#' interface. If the input to \link{lightgbm} was a `formula` or a `data.frame` with
#' categorical columns (`factor` or `character`), then \bold{only} `data.frame` inputs will
#' be accepted here. Columns will be taken according to the names that they had in the data
#' that they were passed to the model (i.e. the input here will be reordered if the order
#' does not match, and will be subsetted if it has additional columns).
#' \item `matrix` from base R. Will be converted to numeric if it isn't already.
#' \item `dgCMatrix` from package `Matrix`.
#' \item `character` with a single entry representing a path to a text file in CSV, TSV,
#' or SVMLight / LibSVM formats.
#' }
#' Other input types are not allowed.
#'
#' Note that, if using the `formula` interface, the user is responsible for making
#' factor variables' levels match to those that were passed in the data to which the model
#' was fitted, and if the model was not produced through the \link{lightgbm} interface
#' (e.g. through \link{lgb.train} or \link{lgb.cv}), then the user is responsible for
#' handling the encoding of categorical variables.
#' @param type Type of prediction to output. Allowed types are:\itemize{
#' \item `"score"`, which will output the predicted score according to the function
#' objective function being optimized (equivalent to `"link"` in base R's `glm`) - for
#' example, for `objective="binary"`, it will output probabilities, while for
#' `objective="regression"`, it will output predicted values. For objective functions other
#' than multi-class classification, the result will be a numeric vector with number of rows
#' matching to `nrow(newdata)`. For multi-class classification, if passing `reshape=TRUE`,
#' it will output a matrix with columns matching to the number of classes (and if the model
#' object was produced through the \link{lightgbm} interface instead of through
#' \link{lgb.train} or \link{lgb.cv}, it will have class names as column names if available),
#' and if passing `reshape=FALSE`, will output a numeric vector with these same results in
#' row-major order.
#' \item `"class"` (only for binary and multi-class classification objectives), which will
#' output the class with the highest predicted score. If the model object was produced through
#' the \link{lightgbm} interface and the label was a factor variable, the result will be a
#' factor variable with levels matching to classes, otherwise it will be an integer vector
#' with indicating the class number.
#' \item `"raw"`, which will output the non-transformed numbers (sum of predictions from
#' boosting iterations' results) from which the score is produced for a given objective
#' function - for example, for `objective="binary"`, this corresponds to log-odds. The
#' output type is the same as for `type="score"`.
#' \item `"leaf"`, which will output the index of the terminal node / leaf at which
#' each observations falls in each tree in the model, outputted as as integers. If passing
#' `reshape=TRUE`, the result will be a matrix with number of columns matching to number of
#' trees, otherwise it will be a vector with this same matrix in row-major order.
#' \item `"contrib"`, which will return the per-feature contributions for each prediction.
#' If passing `reshape=TRUE`, the result will be a matrix with number of columns matching
#' to number of features that the model saw while fitting, otherwise will be a vector with
#' this same matrix outputted in row-major order. If the model object was produced through
#' the \link{lightgbm} interface, `reshape=TRUE` is passed, and the data to which the model
#' was fit had column names, then the output matrix will have column names corresponding to
#' the feature names.
#' }
#' @param start_iteration int or None, optional (default=None)
#' Start index of the iteration to predict.
#' If None or <= 0, starts from the first iteration.
Expand All @@ -741,26 +797,25 @@ Booster <- R6::R6Class(
#' If None, if the best iteration exists and start_iteration is None or <= 0, the
#' best iteration is used; otherwise, all iterations from start_iteration are used.
#' If <= 0, all iterations from start_iteration are used (no limits).
#' @param rawscore whether the prediction should be returned in the for of original untransformed
#' sum of predictions from boosting iterations' results. E.g., setting \code{rawscore=TRUE}
#' for logistic regression would result in predictions for log-odds instead of probabilities.
#' @param predleaf whether predict leaf index instead.
#' @param predcontrib return per-feature contributions for each record.
#' @param header only used for prediction for text file. True if text file has header
#' @param reshape whether to reshape the vector of predictions to a matrix form when there are several
#' prediction outputs per case.
#' prediction outputs per case. When using `reshape=FALSE`, the output will
#' be in row-major order (contrary to R matrices which assume column-major order).
#' If passing `reshape=TRUE` and `newdata` has row names, the output will also have those
#' row names.
#' @param index1 When producing outputs that correspond to some numeration (such as
#' `type="class"` or `type="leaf"`), whether to make these outputs have a numeration
#' starting at 1 or at zero. Note that the underlying lightgbm core library uses zero-based
#' numeration, thus `index1=FALSE` will be slightly faster.
#' @param params a list of additional named parameters. See
#' \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#predict-parameters}{
#' the "Predict Parameters" section of the documentation} for a list of parameters and
#' valid values.
#' @param ... ignored
#' @return For regression or binary classification, it returns a vector of length \code{nrows(data)}.
#' For multiclass classification, either a \code{num_class * nrows(data)} vector or
#' a \code{(nrows(data), num_class)} dimension matrix is returned, depending on
#' the \code{reshape} value.
#'
#' When \code{predleaf = TRUE}, the output is a matrix object with the
#' number of columns corresponding to the number of trees.
#' @param ... Ignored.
#' @return Either a matrix with number of rows matching to the number of rows in `newdata`, or
#' a vector with number of entries matching to rows in `newdata`, or a vector representing a
#' matrix in row-major order with number of entries matching to `nrow(newdata)*n_outputs`;
#' depending on the requested `type` and `reshape` parameter.
#'
#' @examples
#' \donttest{
Expand Down Expand Up @@ -797,21 +852,44 @@ Booster <- R6::R6Class(
#' @importFrom utils modifyList
#' @export
predict.lgb.Booster <- function(object,
data,
newdata,
type = c("score", "class", "raw", "leaf", "contrib"),
start_iteration = NULL,
num_iteration = NULL,
rawscore = FALSE,
predleaf = FALSE,
predcontrib = FALSE,
header = FALSE,
reshape = FALSE,
reshape = TRUE,
index1 = TRUE,
params = list(),
...) {

if (!lgb.is.Booster(x = object)) {
stop("predict.lgb.Booster: object should be an ", sQuote("lgb.Booster"))
}

if (!is.character(type)) {
stop("'type' must be a character variable.")
}
type <- type[1L]
allowed_type <- c("score", "class", "raw", "leaf", "contrib")
if (!(type %in% allowed_type)) {
stop(sprintf("'type' must be one of the following: %s"
, paste(allowed_type, collapse = ", ")))
}
if (type == "class") {
reshape <- TRUE
}
rawscore <- type == "raw"
predleaf <- type == "leaf"
predcontrib <- type == "contrib"
if (type == "class") {
classif_objectives <- c("binary", "multiclass", "multiclassova")
if (!(object$params$objective %in% classif_objectives)) {
stop(sprintf(paste0("Passed prediction 'type=class', but model is not a classifier"
, "(objective: %s).")
, object$params$objective))
}
}

additional_params <- list(...)
if (length(additional_params) > 0L) {
warning(paste0(
Expand All @@ -821,19 +899,51 @@ predict.lgb.Booster <- function(object,
))
}

return(
object$predict(
data = data
, start_iteration = start_iteration
, num_iteration = num_iteration
, rawscore = rawscore
, predleaf = predleaf
, predcontrib = predcontrib
, header = header
, reshape = reshape
, params = params
)
pred <- object$predict(
data = newdata
, start_iteration = start_iteration
, num_iteration = num_iteration
, rawscore = rawscore
, predleaf = predleaf
, predcontrib = predcontrib
, header = header
, reshape = reshape
, params = params
)
if (type == "class") {
if (object$params$objective == "binary") {
pred <- as.integer(pred >= 0.5)
if (NROW(object$data_processor$label_levels)) {
pred <- pred + 1L
attributes(pred)$levels <- object$data_processor$label_levels
attributes(pred)$class <- "factor"
} else if (index1) {
pred <- pred + 1L
}
} else {
cnames <- colnames(pred)
pred <- max.col(pred)
if (NROW(cnames)) {
if (!is.integer(pred)) {
pred <- as.integer(pred)
}
attributes(pred)$levels <- cnames
attributes(pred)$class <- "factor"
} else if (!index1) {
pred <- pred - 1L
}
}
} else if (type == "leaf" && index1) {
pred <- pred + 1L
}
if (reshape && NROW(row.names(newdata))) {
if (is.null(dim(pred))) {
names(pred) <- row.names(newdata)
} else {
row.names(pred) <- row.names(newdata)
}
}
return(pred)
}

#' @name print.lgb.Booster
Expand Down