Merge pull request #309 from mlr-org/base_margin
support base_margin for xgboost
sebffischer authored Sep 6, 2024
2 parents afb3ac7 + bb1f03b commit b58ba3b
Showing 7 changed files with 88 additions and 14 deletions.
NEWS.md (3 changes: 2 additions & 1 deletion)
@@ -1,7 +1,8 @@
 # mlr3learners (development version)
 
+* feat: use `base_margin` in xgboost learners (#205)
 * bugfix: validation for learner `lrn("regr.xgboost")` now works properly. Previously the training data was used.
-* feat: add weights for logistic regression again, which were incorrectlu removed
+* feat: add weights for logistic regression again, which were incorrectly removed
   in a previous release (#265)
 * BREAKING_CHANGE: When using internal tuning for xgboost learners, the `eval_metric` must now be set.
   This achieves that one needs to make the conscious decision which performance metric to use for
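
For reference, a minimal usage sketch of the new parameter, mirroring the tests added in this PR: `base_margin` names a task feature whose values are passed to xgboost as per-row starting margins.

library(mlr3)
library(mlr3learners)

# regression: offset the boosting by the "qsec" feature (as in the new tests)
task = tsk("mtcars")
learner = lrn("regr.xgboost", nrounds = 5, base_margin = "qsec")
learner$train(task)
learner$predict(task)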
R/LearnerClassifXgboost.R (28 changes: 22 additions & 6 deletions)
@@ -95,6 +95,7 @@ LearnerClassifXgboost = R6Class("LearnerClassifXgboost",
         alpha = p_dbl(0, default = 0, tags = "train"),
         approxcontrib = p_lgl(default = FALSE, tags = "predict"),
         base_score = p_dbl(default = 0.5, tags = "train"),
+        base_margin = p_uty(default = NULL, tags = "train", custom_check = crate({function(x) check_character(x, len = 1, null.ok = TRUE, min.chars = 1)})),
         booster = p_fct(c("gbtree", "gblinear", "dart"), default = "gbtree", tags = c("train", "control")),
         callbacks = p_uty(default = list(), tags = "train"),
         colsample_bylevel = p_dbl(0, 1, default = 1, tags = "train"),
@@ -244,13 +245,24 @@ LearnerClassifXgboost = R6Class("LearnerClassifXgboost",
       )
 
       data = task$data(cols = task$feature_names)
-      # recode to 0:1 to that for the binary case the positive class translates to 1 (#32)
+      # recode to 0:1 so that for the binary case the positive class translates to 1 (#32)
       # note that task$truth() is guaranteed to have the factor levels in the right order
       label = nlvls - as.integer(task$truth())
-      data = xgboost::xgb.DMatrix(data = as_numeric_matrix(data), label = label)
+      xgb_data = xgboost::xgb.DMatrix(data = as_numeric_matrix(data), label = label)
 
       if ("weights" %in% task$properties) {
-        xgboost::setinfo(data, "weight", task$weights$weight)
+        xgboost::setinfo(xgb_data, "weight", task$weights$weight)
       }
 
+      base_margin = pv$base_margin
+      pv$base_margin = NULL # silence xgb.train message
+      if (!is.null(base_margin)) {
+        # base_margin must be a task feature and works only with
+        # binary classification objectives
+        assert(check_true(base_margin %in% task$feature_names),
+          check_true(startsWith(pv$objective, "binary")),
+          combine = "and")
+        xgboost::setinfo(xgb_data, "base_margin", data[[base_margin]])
+      }
+
       # the last element in the watchlist is used as the early stopping set
@@ -262,8 +274,12 @@ LearnerClassifXgboost = R6Class("LearnerClassifXgboost",
       if (!is.null(internal_valid_task)) {
         test_data = internal_valid_task$data(cols = internal_valid_task$feature_names)
         test_label = nlvls - as.integer(internal_valid_task$truth())
-        test_data = xgboost::xgb.DMatrix(data = as_numeric_matrix(test_data), label = test_label)
-        pv$watchlist = c(pv$watchlist, list(test = test_data))
+        xgb_test_data = xgboost::xgb.DMatrix(data = as_numeric_matrix(test_data), label = test_label)
+        if (!is.null(base_margin)) {
+          xgboost::setinfo(xgb_test_data, "base_margin", test_data[[base_margin]])
+        }
+
+        pv$watchlist = c(pv$watchlist, list(test = xgb_test_data))
       }
 
       # set internal validation measure
@@ -293,7 +309,7 @@ LearnerClassifXgboost = R6Class("LearnerClassifXgboost",
         pv$maximize = !measure$minimize
       }
 
-      invoke(xgboost::xgb.train, data = data, .args = pv)
+      invoke(xgboost::xgb.train, data = xgb_data, .args = pv)
     },
 
     .predict = function(task) {
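
As the new assert enforces, classification supports the offset only for binary objectives; training on a multiclass task errors. A sketch of both cases, following the new tests:

library(mlr3)
library(mlr3learners)

# binary task: works, the offset is taken from feature "V9"
learner = lrn("classif.xgboost", nrounds = 5, base_margin = "V9", predict_type = "prob")
learner$train(tsk("sonar"))

# multiclass task: fails the startsWith(pv$objective, "binary") check
learner = lrn("classif.xgboost", nrounds = 5, base_margin = "Petal.Length")
try(learner$train(tsk("iris")))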
R/LearnerRegrXgboost.R (23 changes: 18 additions & 5 deletions)
@@ -72,6 +72,7 @@ LearnerRegrXgboost = R6Class("LearnerRegrXgboost",
         alpha = p_dbl(0, default = 0, tags = "train"),
         approxcontrib = p_lgl(default = FALSE, tags = "predict"),
         base_score = p_dbl(default = 0.5, tags = "train"),
+        base_margin = p_uty(default = NULL, tags = "train", custom_check = crate({function(x) check_character(x, len = 1, null.ok = TRUE, min.chars = 1)})),
         booster = p_fct(c("gbtree", "gblinear", "dart"), default = "gbtree", tags = "train"),
         callbacks = p_uty(default = list(), tags = "train"),
         colsample_bylevel = p_dbl(0, 1, default = 1, tags = "train"),
@@ -200,10 +201,18 @@ LearnerRegrXgboost = R6Class("LearnerRegrXgboost",
 
       data = task$data(cols = task$feature_names)
       target = task$data(cols = task$target_names)
-      data = xgboost::xgb.DMatrix(data = as_numeric_matrix(data), label = data.matrix(target))
+      xgb_data = xgboost::xgb.DMatrix(data = as_numeric_matrix(data), label = data.matrix(target))
 
       if ("weights" %in% task$properties) {
-        xgboost::setinfo(data, "weight", task$weights$weight)
+        xgboost::setinfo(xgb_data, "weight", task$weights$weight)
       }
 
+      base_margin = pv$base_margin
+      pv$base_margin = NULL # silence xgb.train message
+      if (!is.null(base_margin)) {
+        # base_margin must be a task feature
+        assert_true(base_margin %in% task$feature_names)
+        xgboost::setinfo(xgb_data, "base_margin", data[[base_margin]])
+      }
+
       # the last element in the watchlist is used as the early stopping set
@@ -214,8 +223,12 @@ LearnerRegrXgboost = R6Class("LearnerRegrXgboost",
       if (!is.null(internal_valid_task)) {
         test_data = internal_valid_task$data(cols = task$feature_names)
         test_target = internal_valid_task$data(cols = task$target_names)
-        test_data = xgboost::xgb.DMatrix(data = as_numeric_matrix(test_data), label = data.matrix(test_target))
-        pv$watchlist = c(pv$watchlist, list(test = test_data))
+        xgb_test_data = xgboost::xgb.DMatrix(data = as_numeric_matrix(test_data), label = data.matrix(test_target))
+        if (!is.null(base_margin)) {
+          xgboost::setinfo(xgb_test_data, "base_margin", test_data[[base_margin]])
+        }
+
+        pv$watchlist = c(pv$watchlist, list(test = xgb_test_data))
       }
 
       # set internal validation measure
Expand All @@ -235,7 +248,7 @@ LearnerRegrXgboost = R6Class("LearnerRegrXgboost",
pv$maximize = !measure$minimize
}

invoke(xgboost::xgb.train, data = data, .args = pv)
invoke(xgboost::xgb.train, data = xgb_data, .args = pv)
},
#' Returns the `$best_iteration` when early stopping is activated.
.predict = function(task) {
Expand Down
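
For background (not part of this diff): in plain xgboost, `base_margin` is attached to the `xgb.DMatrix` via `setinfo()` and supplies per-row starting scores, so boosting fits the target relative to that offset; this is the call the learner now makes with the chosen feature column. A minimal standalone sketch, using the base R `mtcars` data:

library(xgboost)

x = data.matrix(mtcars[, setdiff(names(mtcars), "mpg")])
dtrain = xgboost::xgb.DMatrix(data = x, label = mtcars$mpg)

# per-row starting margin: training starts from these values instead of base_score
xgboost::setinfo(dtrain, "base_margin", mtcars$qsec)

bst = xgboost::xgb.train(params = list(objective = "reg:squarederror"), data = dtrain, nrounds = 5)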
inst/paramtest/test_paramtest_classif.xgboost.R (2 changes: 1 addition & 1 deletion)
@@ -17,7 +17,7 @@ add_params_xgboost = x %>%
   # values which do not match regex
   append(values = c("interaction_constraints", "monotone_constraints", "base_score")) %>%
   # only defined in help page but not in signature or website
-  append(values = "lambda_bias")
+  append(values = c("lambda_bias", "base_margin"))
 
 test_that("classif.xgboost", {
   learner = lrn("classif.xgboost", nrounds = 1L)
inst/paramtest/test_paramtest_regr.xgboost.R (2 changes: 1 addition & 1 deletion)
@@ -17,7 +17,7 @@ add_params_xgboost = x %>%
   # values which do not match regex
   append(values = c("interaction_constraints", "monotone_constraints", "base_score")) %>%
   # only defined in help page but not in signature or website
-  append(values = "lambda_bias")
+  append(values = c("lambda_bias", "base_margin"))
 
 test_that("regr.xgboost", {
   learner = lrn("regr.xgboost", nrounds = 1L)
tests/testthat/test_classif_xgboost.R (25 changes: 25 additions & 0 deletions)
@@ -366,3 +366,28 @@ test_that("mlr3measures are equal to internal measures", {
   expect_equal(log_mlr3$test_classif.ce, log_internal$test_error)
 
 })
+
+
+test_that("base_margin", {
+  # input checks
+  expect_error(lrn("classif.xgboost", base_margin = 1), "Must be of type")
+  expect_error(lrn("classif.xgboost", base_margin = ""), "have at least 1 characters")
+  expect_error(lrn("classif.xgboost", base_margin = c("a", "b")), "have length 1")
+
+  # base_margin not a feature
+  task = tsk("iris")
+  learner = lrn("classif.xgboost", base_margin = "not_a_feature")
+  expect_error(learner$train(task), "base_margin %in%")
+
+  # base_margin is a feature but objective is multiclass
+  learner = lrn("classif.xgboost", base_margin = "Petal.Length")
+  expect_error(learner$train(task), "startsWith")
+
+  # predictions change
+  task = tsk("sonar") # binary classification task
+  l1 = lrn("classif.xgboost", nrounds = 5, predict_type = "prob")
+  l2 = lrn("classif.xgboost", nrounds = 5, base_margin = "V9", predict_type = "prob")
+  p1 = l1$train(task)$predict(task)
+  p2 = l2$train(task)$predict(task)
+  expect_false(all(p1$prob[, 1L] == p2$prob[, 1L]))
+})
tests/testthat/test_regr_xgboost.R (19 changes: 19 additions & 0 deletions)
@@ -226,3 +226,22 @@ test_that("mlr3measures are equal to internal measures", {
 
   expect_equal(log_mlr3, log_internal)
 })
+
+test_that("base_margin", {
+  # input checks
+  expect_error(lrn("regr.xgboost", base_margin = 1), "Must be of type")
+  expect_error(lrn("regr.xgboost", base_margin = ""), "have at least 1 characters")
+  expect_error(lrn("regr.xgboost", base_margin = c("a", "b")), "have length 1")
+
+  # base_margin not a feature
+  task = tsk("mtcars")
+  learner = lrn("regr.xgboost", base_margin = "not_a_feature")
+  expect_error(learner$train(task), "base_margin %in%")
+
+  # predictions change
+  l1 = lrn("regr.xgboost", nrounds = 5)
+  l2 = lrn("regr.xgboost", nrounds = 5, base_margin = "qsec")
+  p1 = l1$train(task)$predict(task)
+  p2 = l2$train(task)$predict(task)
+  expect_false(all(p1$response == p2$response))
+})
