Skip to content

Commit

Permalink
[R] Fix single sample prediction. (#7524)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jan 13, 2022
1 parent 3e2d751 commit 88d54c6
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
6 changes: 4 additions & 2 deletions R-package/R/xgb.Booster.R
Expand Up @@ -435,7 +435,8 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
lapply(seq_len(n_groups), function(g) arr[g, , ])
} else {
## remove the first axis (group)
as.matrix(arr[1, , ])
dn <- dimnames(arr)
matrix(arr[1, , ], nrow = dim(arr)[2], ncol = dim(arr)[3], dimnames = c(dn[2], dn[3]))
}
} else if (predinteraction) {
## Predict interaction
Expand All @@ -447,7 +448,8 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
lapply(seq_len(n_groups), function(g) arr[g, , , ])
} else {
## remove the first axis (group)
arr[1, , , ]
arr <- arr[1, , , , drop = FALSE]
array(arr, dim = dim(arr)[2:4], dimnames(arr)[2:4])
}
} else {
## Normal prediction
Expand Down
25 changes: 25 additions & 0 deletions R-package/tests/testthat/test_interactions.R
Expand Up @@ -157,3 +157,28 @@ test_that("multiclass feature interactions work", {
# sums WRT columns must be close to feature contributions
expect_lt(max(abs(apply(intr, c(1, 2, 3), sum) - aperm(cont, c(3, 1, 2)))), 0.00001)
})


test_that("SHAP single sample works", {
train <- agaricus.train
test <- agaricus.test
booster <- xgboost(
data = train$data,
label = train$label,
max_depth = 2,
nrounds = 4,
objective = "binary:logistic",
)

predt <- predict(
booster,
newdata = train$data[1, , drop = FALSE], predcontrib = TRUE
)
expect_equal(dim(predt), c(1, dim(train$data)[2] + 1))

predt <- predict(
booster,
newdata = train$data[1, , drop = FALSE], predinteraction = TRUE
)
expect_equal(dim(predt), c(1, dim(train$data)[2] + 1, dim(train$data)[2] + 1))
})

0 comments on commit 88d54c6

Please sign in to comment.