Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #5 from dmlc/master
update from dmlc/xgboost
Showing 40 changed files with 1,065 additions and 242 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# install development version of caret library that contains xgboost models | ||
devtools::install_github("topepo/caret/pkg/caret") | ||
require(caret) | ||
require(xgboost) | ||
require(data.table) | ||
require(vcd) | ||
require(e1071) | ||
|
||
# Load Arthritis dataset in memory. | ||
data(Arthritis) | ||
# Create a copy of the dataset with the data.table package (data.table is fully compatible with R data frames, but its syntax is more consistent and it performs very well). | ||
df <- data.table(Arthritis, keep.rownames = FALSE) | ||
|
||
# Let's add some new categorical features to see if they help. Of course these features are highly correlated with the Age feature. Usually that is not a good thing in ML, but tree algorithms (including boosted trees) are able to select the best features, even when features are highly correlated. | ||
# For the first feature we create age groups by rounding the real age. Note that we transform it to a factor (categorical data) so the algorithm treats the groups as independent values. | ||
df[,AgeDiscret:= as.factor(round(Age/10,0))] | ||
|
||
# Here is an even stronger simplification of the real age, with an arbitrary split at 30 years old. I chose this value based on nothing. We will see later whether simplifying the information based on arbitrary values is a good strategy (I am sure you already have an idea of how well it will work!). | ||
df[,AgeCat:= as.factor(ifelse(Age > 30, "Old", "Young"))] | ||
|
||
# We remove ID as there is nothing to learn from this feature (it will just add some noise as the dataset is small). | ||
df[,ID:=NULL] | ||
|
||
#-------------Basic Training using XGBoost in caret Library----------------- | ||
# Set up control parameters for caret::train | ||
# Here we use 10-fold cross-validation repeated twice, with random search for tuning hyperparameters. | ||
fitControl <- trainControl(method = "repeatedcv", number = 10, repeats = 2, search = "random") | ||
# Train an xgbTree model using caret::train | ||
model <- train(factor(Improved)~., data = df, method = "xgbTree", trControl = fitControl) | ||
|
||
# Instead of trees as boosters, you can also fit a linear regression or logistic regression model using xgbLinear | ||
# model <- train(factor(Improved)~., data = df, method = "xgbLinear", trControl = fitControl) | ||
|
||
# See model results | ||
print(model) |
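
The two derived age features in the caret demo above are easy to check on a handful of values. The sketch below uses base R only and a made-up `age` vector (an assumption for illustration, not rows from the Arthritis data) to show what `round(Age/10, 0)` and the arbitrary split at 30 produce.

```r
# Hypothetical ages, chosen only for illustration (not taken from Arthritis).
age <- c(27, 41, 59, 30, 68)

# Decade-style bucket: divide by 10, round to the nearest integer, then
# convert to a factor so that a model treats each bucket as its own category.
age_discret <- as.factor(round(age / 10, 0))

# Arbitrary two-way split at 30; exactly 30 falls in "Young" because the
# comparison is strict.
age_cat <- as.factor(ifelse(age > 30, "Old", "Young"))

print(data.frame(age, age_discret, age_cat))
```

Both new columns are deterministic functions of `Age`, which is why the demo describes them as highly correlated with it.
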
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
library(testthat) | ||
library(xgboost) | ||
|
||
test_check("xgboost") |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
require(xgboost) | ||
|
||
context("basic functions") | ||
|
||
data(agaricus.train, package='xgboost') | ||
data(agaricus.test, package='xgboost') | ||
train = agaricus.train | ||
test = agaricus.test | ||
|
||
test_that("train and predict", { | ||
bst = xgboost(data = train$data, label = train$label, max.depth = 2, | ||
eta = 1, nthread = 2, nround = 2, objective = "binary:logistic") | ||
pred = predict(bst, test$data) | ||
}) | ||
|
||
|
||
test_that("early stopping", { | ||
res = xgb.cv(data = train$data, label = train$label, max.depth = 2, nfold = 5, | ||
eta = 0.3, nthread = 2, nround = 20, objective = "binary:logistic", | ||
early.stop.round = 3, maximize = FALSE) | ||
expect_true(nrow(res)<20) | ||
bst = xgboost(data = train$data, label = train$label, max.depth = 2, | ||
eta = 0.3, nthread = 2, nround = 20, objective = "binary:logistic", | ||
early.stop.round = 3, maximize = FALSE) | ||
pred = predict(bst, test$data) | ||
}) | ||
|
||
test_that("save_period", { | ||
bst = xgboost(data = train$data, label = train$label, max.depth = 2, | ||
eta = 0.3, nthread = 2, nround = 20, objective = "binary:logistic", | ||
save_period = 10, save_name = "xgb.model") | ||
pred = predict(bst, test$data) | ||
}) |
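
The "early stopping" test above expects cross-validation to finish in fewer than 20 rounds when the evaluation metric stops improving. As a rough illustration of that idea, here is a base R sketch of a patience-based stopping rule. It is not xgboost's implementation; the function name `early_stop`, its `patience` argument, and the error values are hypothetical, and the library's exact stopping round may differ.

```r
# Return the round at which training would stop and the best round seen,
# given one evaluation value per boosting round.
# `patience` plays the role of the early-stopping window (hypothetical name).
early_stop <- function(metric, patience, maximize = FALSE) {
  best_value <- if (maximize) -Inf else Inf
  best_round <- 0L
  for (i in seq_along(metric)) {
    improved <- if (maximize) metric[i] > best_value else metric[i] < best_value
    if (improved) {
      best_value <- metric[i]
      best_round <- i
    } else if (i - best_round >= patience) {
      # No improvement for `patience` consecutive rounds: stop here.
      return(list(stopped_at = i, best_round = best_round))
    }
  }
  # No early stop triggered: every round was used.
  list(stopped_at = length(metric), best_round = best_round)
}

# Made-up test error per round (illustrative values only).
test_error <- c(0.50, 0.40, 0.35, 0.36, 0.37, 0.36, 0.30)
res <- early_stop(test_error, patience = 3, maximize = FALSE)
print(res)
```

With `maximize = FALSE`, lower values count as improvements, which matches an error-style metric like the one used in the test.
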
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
context('Test models with custom objective') | ||
|
||
require(xgboost) | ||
|
||
test_that("custom objective works", { | ||
data(agaricus.train, package='xgboost') | ||
data(agaricus.test, package='xgboost') | ||
dtrain <- xgb.DMatrix(agaricus.train$data, label = agaricus.train$label) | ||
dtest <- xgb.DMatrix(agaricus.test$data, label = agaricus.test$label) | ||
|
||
watchlist <- list(eval = dtest, train = dtrain) | ||
num_round <- 2 | ||
|
||
logregobj <- function(preds, dtrain) { | ||
labels <- getinfo(dtrain, "label") | ||
preds <- 1/(1 + exp(-preds)) | ||
grad <- preds - labels | ||
hess <- preds * (1 - preds) | ||
return(list(grad = grad, hess = hess)) | ||
} | ||
evalerror <- function(preds, dtrain) { | ||
labels <- getinfo(dtrain, "label") | ||
err <- as.numeric(sum(labels != (preds > 0)))/length(labels) | ||
return(list(metric = "error", value = err)) | ||
} | ||
|
||
param <- list(max.depth=2, eta=1, nthread = 2, silent=1, | ||
objective=logregobj, eval_metric=evalerror) | ||
|
||
bst <- xgb.train(param, dtrain, num_round, watchlist) | ||
expect_equal(class(bst), "xgb.Booster") | ||
expect_equal(length(bst$raw), 1064) | ||
attr(dtrain, 'label') <- getinfo(dtrain, 'label') | ||
|
||
logregobjattr <- function(preds, dtrain) { | ||
labels <- attr(dtrain, 'label') | ||
preds <- 1/(1 + exp(-preds)) | ||
grad <- preds - labels | ||
hess <- preds * (1 - preds) | ||
return(list(grad = grad, hess = hess)) | ||
} | ||
param <- list(max.depth=2, eta=1, nthread = 2, silent=1, | ||
objective=logregobjattr, eval_metric=evalerror) | ||
bst <- xgb.train(param, dtrain, num_round, watchlist) | ||
expect_equal(class(bst), "xgb.Booster") | ||
expect_equal(length(bst$raw), 1064) | ||
}) |
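
The custom objective in the test above returns `grad = preds - labels` and `hess = preds * (1 - preds)` after passing the raw margin through a sigmoid. A quick way to convince yourself that these are the first and second derivatives of the logistic loss with respect to the margin is a finite-difference check. The sketch below is base R only; the helper names and the `margin` and `label` values are assumptions made for illustration.

```r
sigmoid <- function(m) 1 / (1 + exp(-m))

# Logistic loss per observation as a function of the raw margin m.
logloss <- function(m, y) {
  p <- sigmoid(m)
  -(y * log(p) + (1 - y) * log(1 - p))
}

# Analytic derivatives, in the same form as the custom objective.
grad_analytic <- function(m, y) sigmoid(m) - y
hess_analytic <- function(m) sigmoid(m) * (1 - sigmoid(m))

# Made-up margins and labels (illustrative values only).
margin <- c(-2, -0.5, 0, 1.5)
label  <- c(0, 1, 1, 0)

# Central finite differences in the margin.
h <- 1e-4
grad_numeric <- (logloss(margin + h, label) - logloss(margin - h, label)) / (2 * h)
hess_numeric <- (logloss(margin + h, label) - 2 * logloss(margin, label) +
                 logloss(margin - h, label)) / h^2

# Both maximum absolute differences should be close to zero.
print(max(abs(grad_numeric - grad_analytic(margin, label))))
print(max(abs(hess_numeric - hess_analytic(margin))))
```

The same reasoning explains why `evalerror` thresholds at `preds > 0`: with a custom objective the predictions are raw margins, and a margin of 0 corresponds to a probability of 0.5.
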
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
context('Test generalized linear models') | ||
|
||
require(xgboost) | ||
|
||
test_that("glm works", { | ||
data(agaricus.train, package='xgboost') | ||
data(agaricus.test, package='xgboost') | ||
dtrain <- xgb.DMatrix(agaricus.train$data, label = agaricus.train$label) | ||
dtest <- xgb.DMatrix(agaricus.test$data, label = agaricus.test$label) | ||
expect_equal(class(dtrain), "xgb.DMatrix") | ||
expect_equal(class(dtest), "xgb.DMatrix") | ||
param <- list(objective = "binary:logistic", booster = "gblinear", | ||
nthread = 2, alpha = 0.0001, lambda = 1) | ||
watchlist <- list(eval = dtest, train = dtrain) | ||
num_round <- 2 | ||
bst <- xgb.train(param, dtrain, num_round, watchlist) | ||
ypred <- predict(bst, dtest) | ||
expect_equal(length(getinfo(dtest, 'label')), 1611) | ||
}) |