diff --git a/R-package/R/xgb.Booster.R b/R-package/R/xgb.Booster.R index c19452925de3..46eba2633659 100644 --- a/R-package/R/xgb.Booster.R +++ b/R-package/R/xgb.Booster.R @@ -435,7 +435,8 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA lapply(seq_len(n_groups), function(g) arr[g, , ]) } else { ## remove the first axis (group) - as.matrix(arr[1, , ]) + dn <- dimnames(arr) + matrix(arr[1, , ], nrow = dim(arr)[2], ncol = dim(arr)[3], dimnames = c(dn[2], dn[3])) } } else if (predinteraction) { ## Predict interaction @@ -447,7 +448,8 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA lapply(seq_len(n_groups), function(g) arr[g, , , ]) } else { ## remove the first axis (group) - arr[1, , , ] + arr <- arr[1, , , , drop = FALSE] + array(arr, dim = dim(arr)[2:4], dimnames(arr)[2:4]) } } else { ## Normal prediction diff --git a/R-package/tests/testthat/test_interactions.R b/R-package/tests/testthat/test_interactions.R index 7b86537c0e06..e90467cdcf62 100644 --- a/R-package/tests/testthat/test_interactions.R +++ b/R-package/tests/testthat/test_interactions.R @@ -157,3 +157,28 @@ test_that("multiclass feature interactions work", { # sums WRT columns must be close to feature contributions expect_lt(max(abs(apply(intr, c(1, 2, 3), sum) - aperm(cont, c(3, 1, 2)))), 0.00001) }) + + +test_that("SHAP single sample works", { + train <- agaricus.train + test <- agaricus.test + booster <- xgboost( + data = train$data, + label = train$label, + max_depth = 2, + nrounds = 4, + objective = "binary:logistic", + ) + + predt <- predict( + booster, + newdata = train$data[1, , drop = FALSE], predcontrib = TRUE + ) + expect_equal(dim(predt), c(1, dim(train$data)[2] + 1)) + + predt <- predict( + booster, + newdata = train$data[1, , drop = FALSE], predinteraction = TRUE + ) + expect_equal(dim(predt), c(1, dim(train$data)[2] + 1, dim(train$data)[2] + 1)) +})