Skip to content

Commit

Permalink
feat: use mlr3 measures as internal validation measures (#300)
Browse files Browse the repository at this point in the history
* feat: use mlr3 measure as internal validation measure

* chore: style

* chore: add comments

* tests: xgboost

* refactor: remove field

* docs: update

* fix: measure class

* feat: add regr

* chore: remove browser

* tests: parameter

* Update R/LearnerClassifXgboost.R

Co-authored-by: Sebastian Fischer <[email protected]>

* Update R/LearnerRegrXgboost.R

Co-authored-by: Sebastian Fischer <[email protected]>

* fix: crate

* fix: test

---------

Co-authored-by: Sebastian Fischer <[email protected]>
  • Loading branch information
be-marc and sebffischer authored Aug 17, 2024
1 parent 254488b commit da24fe1
Show file tree
Hide file tree
Showing 8 changed files with 449 additions and 16 deletions.
98 changes: 92 additions & 6 deletions R/LearnerClassifXgboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@
#' In order to monitor the validation performance during the training, you can set the `$validate` field of the Learner.
#' For information on how to configure the valdiation set, see the *Validation* section of [`mlr3::Learner`].
#' This validation data can also be used for early stopping, which can be enabled by setting the `early_stopping_rounds` parameter.
#' The final (or in the case of early stopping best) validation scores can be accessed via `$internal_valid_scores`, and the
#' optimal `nrounds` via `$internal_tuned_values`.
#' The final (or in the case of early stopping best) validation scores can be accessed via `$internal_valid_scores`, and the optimal `nrounds` via `$internal_tuned_values`.
#' The internal validation measure can be set via the `eval_metric` parameter that can be a [mlr3::Measure], a function, or a character string for the internal xgboost measures.
#'
#' @templateVar id classif.xgboost
#' @template learner
#'
Expand Down Expand Up @@ -103,9 +104,8 @@ LearnerClassifXgboost = R6Class("LearnerClassifXgboost",
disable_default_eval_metric = p_lgl(default = FALSE, tags = "train"),
early_stopping_rounds = p_int(1L, default = NULL, special_vals = list(NULL), tags = "train"),
eta = p_dbl(0, 1, default = 0.3, tags = c("train", "control")),
eval_metric = p_uty(tags = "train"),
eval_metric = p_uty(tags = "train", custom_check = crate({function(x) check_true(any(is.character(x), is.function(x), inherits(x, "Measure")))})),
feature_selector = p_fct(c("cyclic", "shuffle", "random", "greedy", "thrifty"), default = "cyclic", tags = "train", depends = quote(booster == "gblinear")),
feval = p_uty(default = NULL, tags = "train"),
gamma = p_dbl(0, default = 0, tags = c("train", "control")),
grow_policy = p_fct(c("depthwise", "lossguide"), default = "depthwise", tags = "train", depends = quote(tree_method == "hist")),
interaction_constraints = p_uty(tags = "train"),
Expand Down Expand Up @@ -193,12 +193,14 @@ LearnerClassifXgboost = R6Class("LearnerClassifXgboost",
internal_valid_scores = function() {
self$state$internal_valid_scores
},

#' @field internal_tuned_values (named `list()` or `NULL`)
#' If early stopping is activated, this returns a list with `nrounds`,
#' which is extracted from `$best_iteration` of the model and otherwise `NULL`.
internal_tuned_values = function() {
self$state$internal_tuned_values
},

#' @field validate (`numeric(1)` or `character(1)` or `NULL`)
#' How to construct the internal validation data. This parameter can be either `NULL`,
#' a ratio, `"test"`, or `"predefined"`.
Expand Down Expand Up @@ -241,7 +243,6 @@ LearnerClassifXgboost = R6Class("LearnerClassifXgboost",
}
)


data = task$data(cols = task$feature_names)
# recode to 0:1 to that for the binary case the positive class translates to 1 (#32)
# note that task$truth() is guaranteed to have the factor levels in the right order
Expand All @@ -253,18 +254,45 @@ LearnerClassifXgboost = R6Class("LearnerClassifXgboost",
}

# the last element in the watchlist is used as the early stopping set

internal_valid_task = task$internal_valid_task
if (!is.null(pv$early_stopping_rounds) && is.null(internal_valid_task)) {
stopf("Learner (%s): Configure field 'validate' to enable early stopping.", self$id)
}

if (!is.null(internal_valid_task)) {
test_data = internal_valid_task$data(cols = internal_valid_task$feature_names)
test_label = nlvls - as.integer(internal_valid_task$truth())
test_data = xgboost::xgb.DMatrix(data = as_numeric_matrix(test_data), label = test_label)
pv$watchlist = c(pv$watchlist, list(test = test_data))
}

# set internal validation measure
if (inherits(pv$eval_metric, "Measure")) {
n_classes = length(task$class_names)
measure = pv$eval_metric

fun = if (pv$objective == "binary:logistic" && measure$predict_type == "prob" && inherits(measure, "MeasureBinarySimple")) {
xgboost_binary_binary_prob
} else if (pv$objective == "binary:logistic" && measure$predict_type == "prob" && inherits(measure, "MeasureClassifSimple")) {
xgboost_binary_classif_prob
} else if (pv$objective == "binary:logistic" && measure$predict_type == "response") {
xgboost_binary_response
} else if (pv$objective == "multi:softprob" && measure$predict_type == "prob") {
xgboost_multiclass_prob
} else if (pv$objective %in% c("multi:softmax", "multi:softprob") && measure$predict_type == "response") {
xgboost_multiclass_response
} else {
stop("Only 'binary:logistic', 'multi:softprob' and 'multi:softmax' objectives are supported.")
}

pv$eval_metric = mlr3misc::crate({function(pred, dtrain) {
scores = fun(pred, dtrain, measure, n_classes)
list(metric = measure$id, value = scores)
}}, n_classes, measure, fun)

pv$maximize = !measure$minimize
}

invoke(xgboost::xgb.train, data = data, .args = pv)
},

Expand Down Expand Up @@ -363,3 +391,61 @@ default_values.LearnerClassifXgboost = function(x, search_space, task, ...) { #

#' @include aaa.R
learners[["classif.xgboost"]] = LearnerClassifXgboost

# mlr3 measure to custom inner measure functions
xgboost_binary_binary_prob = function(pred, dtrain, measure, ...) {
# label is a vector of labels (0, 1)
truth = factor(xgboost::getinfo(dtrain, "label"), levels = c(0, 1))
# pred is a vector of log odds
# transform log odds to probabilities
pred = 1 / (1 + exp(-pred))
measure$fun(truth, pred, positive = "1")
}

xgboost_binary_classif_prob = function(pred, dtrain, measure, ...) {
# label is a vector of labels (0, 1)
truth = factor(xgboost::getinfo(dtrain, "label"), levels = c(0, 1))
# pred is a vector of log odds
# transform log odds to probabilities
pred = 1 / (1 + exp(-pred))
# multiclass measure needs a matrix of probabilities
pred_mat = matrix(c(pred, 1 - pred), ncol = 2)
colnames(pred_mat) = c("1", "0")
measure$fun(truth, pred_mat, positive = "1")
}

xgboost_binary_response = function(pred, dtrain, measure, ...) {
# label is a vector of labels (0, 1)
truth = factor(xgboost::getinfo(dtrain, "label"), levels = c(0, 1))
# pred is a vector of log odds
response = factor(as.integer(pred > 0), levels = c(0, 1))
measure$fun(truth, response)
}

xgboost_multiclass_prob = function(pred, dtrain, measure, n_classes, ...) {
# label is a vector of labels (0, 1, ..., n_classes - 1)
truth = factor(xgboost::getinfo(dtrain, "label"), levels = seq_len(n_classes) - 1L)

# pred is a vector of log odds for each class
# matrix must be filled by row
pred_mat = matrix(pred, ncol = n_classes, byrow = TRUE)
# transform log odds to probabilities
pred_exp = exp(pred_mat)
pred_mat = pred_exp / rowSums(pred_exp)
colnames(pred_mat) = levels(truth)

measure$fun(truth, pred_mat)
}

xgboost_multiclass_response = function(pred, dtrain, measure, n_classes, ...) {
# label is a vector of labels (0, 1, ..., n_classes - 1)
truth = factor(xgboost::getinfo(dtrain, "label"), levels = seq_len(n_classes) - 1L)

# pred is a vector of log odds for each class
# matrix must be filled by row
pred_mat = matrix(pred, ncol = n_classes, byrow = TRUE)

response = factor(max.col(pred_mat, ties.method = "random") - 1, levels = levels(truth))
measure$fun(truth, response)
}

21 changes: 18 additions & 3 deletions R/LearnerRegrXgboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,8 @@ LearnerRegrXgboost = R6Class("LearnerRegrXgboost",
disable_default_eval_metric = p_lgl(default = FALSE, tags = "train"),
early_stopping_rounds = p_int(1L, default = NULL, special_vals = list(NULL), tags = "train"),
eta = p_dbl(0, 1, default = 0.3, tags = "train"),
eval_metric = p_uty(default = "rmse", tags = "train"),
eval_metric = p_uty(default = "rmse", tags = "train", custom_check = crate({function(x) check_true(any(is.character(x), is.function(x), inherits(x, "Measure")))})),
feature_selector = p_fct(c("cyclic", "shuffle", "random", "greedy", "thrifty"), default = "cyclic", tags = "train", depends = quote(booster == "gblinear")),
feval = p_uty(default = NULL, tags = "train"),
gamma = p_dbl(0, default = 0, tags = "train"),
grow_policy = p_fct(c("depthwise", "lossguide"), default = "depthwise", tags = "train", depends = quote(tree_method == "hist")),
interaction_constraints = p_uty(tags = "train"),
Expand Down Expand Up @@ -208,7 +207,6 @@ LearnerRegrXgboost = R6Class("LearnerRegrXgboost",
}

# the last element in the watchlist is used as the early stopping set

internal_valid_task = task$internal_valid_task
if (!is.null(pv$early_stopping_rounds) && is.null(internal_valid_task)) {
stopf("Learner (%s): Configure field 'validate' to enable early stopping.", self$id)
Expand All @@ -220,6 +218,23 @@ LearnerRegrXgboost = R6Class("LearnerRegrXgboost",
pv$watchlist = c(pv$watchlist, list(test = test_data))
}

# set internal validation measure
if (inherits(pv$eval_metric, "Measure")) {
measure = pv$eval_metric

if (pv$objective %nin% c("reg:absoluteerror", "reg:squarederror")) {
stop("Only 'reg:squarederror' and 'reg:absoluteerror' objectives are supported.")
}

pv$eval_metric = mlr3misc::crate({function(pred, dtrain) {
truth = xgboost::getinfo(dtrain, "label")
scores = measure$fun(truth, pred)
list(metric = measure$id, value = scores)
}}, measure)

pv$maximize = !measure$minimize
}

invoke(xgboost::xgb.train, data = data, .args = pv)
},
#' Returns the `$best_iteration` when early stopping is activated.
Expand Down
3 changes: 2 additions & 1 deletion inst/paramtest/test_paramtest_classif.xgboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ test_that("classif.xgboost", {
"eval_metric", # handled by mlr3
"label", # handled by mlr3
"weight", # handled by mlr3
"nthread" # handled by mlr3
"nthread", # handled by mlr3
"feval" # handled via eval_metric parameter
)

ParamTest = run_paramtest(learner, fun, exclude, tag = "train")
Expand Down
3 changes: 2 additions & 1 deletion inst/paramtest/test_paramtest_regr.xgboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ test_that("regr.xgboost", {
"eval_metric", # handled by mlr3
"label", # handled by mlr3
"weight", # handled by mlr3
"nthread" # handled by mlr3
"nthread", # handled by mlr3
"feval" # handled via eval_metric parameter
)

ParamTest = run_paramtest(learner, fun, exclude, tag = "train")
Expand Down
5 changes: 2 additions & 3 deletions man/mlr_learners_classif.xgboost.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/mlr_learners_regr.xgboost.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit da24fe1

Please sign in to comment.