Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support CIs for non-decomposable measures #9

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ Imports:
R6,
withr
Suggests:
testthat (>= 3.0.0)
testthat (>= 3.0.0),
rpart
Remotes:
mlr-org/mlr3
Config/testthat/edition: 3
Expand Down
36 changes: 27 additions & 9 deletions R/MeasureAbstractCi.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#' The measure for which to calculate a confidence interval. Must have `$obs_loss` if `requires_obs_loss` is `TRUE`.
#' @param resamplings (`character()`)\cr
#' To which resampling classes this measure can be applied.
#' @param requires_obs_loss (`logical(1)`)\cr
#' Whether the inference method requires a pointwise loss function.
#' @template param_param_set
#' @template param_packages
#' @template param_label
Expand All @@ -28,7 +30,8 @@
#' @section Inheriting:
#' To define a new CI method, inherit from the abstract base class and implement the private method:
#' `.ci: function(tbl: data.table, rr: ResampleResult, param_vals: named list()) -> numeric(3)`
#' Here, `tbl` contains the columns `loss`, `row_id` and `iteration`, which are the pointwise loss,
#' If `requires_obs_loss` is set to `TRUE`, `tbl` contains the columns `loss`, `row_id` and `iteration`, which are the pointwise loss,
#' the identifier of the observation and the resampling iteration.
#' Otherwise, `tbl` contains the result of `rr$score()` with the name of the loss column set to `"loss"`.
#' It should return a vector containing the `estimate`, `lower` and `upper` boundary in that order.
#'
Expand All @@ -49,19 +52,28 @@ MeasureAbstractCi = R6Class("MeasureAbstractCi",
measure = NULL,
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(measure = NULL, param_set = ps(), packages = character(), resamplings, label, delta_method = FALSE) {
initialize = function(measure = NULL, param_set = ps(), packages = character(), resamplings, label, delta_method = FALSE,
requires_obs_loss = TRUE) { # nolint
private$.delta_method = assert_flag(delta_method, na.ok = TRUE)
self$measure = if (test_string(measure)) {
msr(measure)
} else {
private$.requires_obs_loss = assert_flag(requires_obs_loss)
if (test_string(measure)) measure = msr(measure)
self$measure = measure

if (private$.requires_obs_loss) {
assert(
check_class(measure, "Measure"),
check_false(inherits(measure, "MeasureCi")),
check_function(measure$obs_loss),
combine = "and",
.var.name = "Argument measure must be a scalar Measure with a pointwise loss function (has $obs_loss field)"
)
measure
} else {
assert(
check_class(measure, "Measure"),
check_false(inherits(measure, "MeasureCi")),
combine = "and",
.var.name = "Argument measure must be a scalar Measure."
)
}

param_set = c(param_set,
Expand Down Expand Up @@ -108,8 +120,13 @@ MeasureAbstractCi = R6Class("MeasureAbstractCi",
}

param_vals = self$param_set$get_values()
tbl = rr$obs_loss(self$measure)
names(tbl)[names(tbl) == self$measure$id] = "loss"
tbl = if (private$.requires_obs_loss) {
rr$obs_loss(self$measure)
} else {
rr$score(self$measure)
}
setnames(tbl, self$measure$id, "loss")

ci = private$.ci(tbl, rr, param_vals)
if (!is.null(self$measure$trafo)) {
ci = private$.trafo(ci)
Expand All @@ -121,10 +138,11 @@ MeasureAbstractCi = R6Class("MeasureAbstractCi",
}
),
private = list(
.requires_obs_loss = NULL,
.delta_method = FALSE,
.trafo = function(ci) {
if (!private$.delta_method) {
stopf("Measure '%s' has a trafo, but the CI does handle it", self$measure$id)
stopf("Measure '%s' has a trafo, but the CI does not handle it", self$measure$id)
}
measure = self$measure
# delta-rule
Expand Down
3 changes: 2 additions & 1 deletion R/MeasureCIConZ.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#' @description
#' The conservative-z confidence intervals based on the [`ResamplingPairedSubsampling`].
#' Because the variance estimate is obtained using only `n / 2` observations, it tends to be conservative.
#' This inference method can also be applied to non-decomposable losses.
#' @section Parameters:
#' Only those from [`MeasureAbstractCi`].
#' @template param_measure
Expand All @@ -22,6 +23,7 @@ MeasureCiConZ = R6Class("MeasureCiConZ",
measure = measure,
resamplings = "ResamplingPairedSubsampling",
label = "Conservative-Z CI",
requires_obs_loss = FALSE,
delta_method = TRUE
)
}
Expand All @@ -30,7 +32,6 @@ MeasureCiConZ = R6Class("MeasureCiConZ",
.ci = function(tbl, rr, param_vals) {
repeats_in = rr$resampling$param_set$values$repeats_in
repeats_out = rr$resampling$param_set$values$repeats_out
tbl = tbl[, list(loss = mean(get("loss"))), by = "iteration"]

estimate = tbl[get("iteration") <= repeats_in, mean(get("loss"))]

Expand Down
4 changes: 3 additions & 1 deletion R/MeasureCICorT.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#' Corrected-T confidence intervals based on [`ResamplingSubsampling`][mlr3::ResamplingSubsampling].
#' A heuristic factor is applied to correct for the dependence between the iterations.
#' The confidence intervals tend to be liberal.
#' This inference method can also be applied to non-decomposable losses.
#' @section Parameters:
#' Only those from [`MeasureAbstractCi`].
#' @template param_measure
Expand All @@ -29,6 +30,7 @@ MeasureCiCorrectedT = R6Class("MeasureCiCorrectedT",
measure = measure,
resamplings = "ResamplingSubsampling",
label = "Corrected-T CI",
requires_obs_loss = FALSE,
delta_method = TRUE
)
}
Expand All @@ -45,7 +47,7 @@ MeasureCiCorrectedT = R6Class("MeasureCiCorrectedT",
n2 = n - n1

# the different mu in the rows are the mu_j
mus = tbl[, list(estimate = mean(get("loss"))), by = "iteration"]$estimate
mus = tbl$loss
# the global estimator
estimate = mean(mus)
# The naive SD estimate (does not take correlation between folds into account)
Expand Down
1 change: 1 addition & 0 deletions R/MeasureCIHoldout.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#' @name mlr_measures_ci_holdout
#' @description
#' Standard holdout CI.
#' This inference method can only be applied to decomposable losses.
#' @section Parameters:
#' Only those from [`MeasureAbstractCi`].
#' @template param_measure
Expand Down
1 change: 1 addition & 0 deletions R/MeasureCINaiveCV.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#' Confidence intervals for cross-validation.
#' The method is asymptotically exact for the so called *Test Error* as defined by Bayle et al. (2020).
#' For the (expected) risk, the confidence intervals tend to be too liberal.
#' This inference method can only be applied to decomposable losses.
#' @section Parameters:
#' Those from [`MeasureAbstractCi`], as well as:
#' * `variance` :: `"all-pairs"` or `"within-fold"`\cr
Expand Down
1 change: 1 addition & 0 deletions R/MeasureCiNestedCV.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#' @name mlr_measures_ci_ncv
#' @description
#' Confidence Intervals based on [`ResamplingNestedCV`][ResamplingNestedCV], including bias-correction.
#' This inference method can only be applied to decomposable losses.
#' @section Parameters:
#' Those from [`MeasureAbstractCi`], as well as:
#' * `bias` :: `logical(1)`\cr
Expand Down
13 changes: 8 additions & 5 deletions README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -82,22 +82,25 @@ autoplot(bmr, "ci", msr("ci", "classif.ce"))

Note that:

* Confidence Intervals can only be obtained for measures that are based on pointwise loss functions, i.e. have an `$obs_loss` field.
* Some methods require measures that are based on pointwise loss functions, i.e. measures that have an `$obs_loss` field.
* An inference method does not exist for every resampling method.
* There are combinations of datasets and learners where inference methods can fail.

## Features

* Additional Resampling Methods
* Confidence Intervals for the Generalization Error for some resampling methods
* Confidence Intervals for the Generalization Error for some resampling methods


## Inference Methods

```{r, echo = FALSE}
```{r, echo = TRUE}
content = as.data.table(mlr3::mlr_measures, objects = TRUE)[startsWith(get("key"), "ci."),]
content$resamplings = map(content$object, "resamplings")
content = content[, c("key", "label", "resamplings")]
content$resamplings = map(content$object, function(x) paste0(gsub("Resampling", "", x$resamplings), collapse = ", "))
content[["only pointwise loss"]] = map_chr(content$object, function(object) {
if (get_private(object)$.requires_obs_loss) "yes" else "false"
})
content = content[, c("key", "label", "resamplings", "only pointwise loss")]
knitr::kable(content, format = "markdown", col.names = tools::toTitleCase(names(content)))
```

Expand Down
28 changes: 19 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ autoplot(bmr, "ci", msr("ci", "classif.ce"))

Note that:

- Confidence Intervals can only be obtained for measures that are based
on pointwise loss functions, i.e. have an `$obs_loss` field.
- Some methods require measures that are based on pointwise loss
functions, i.e. measures that have an `$obs_loss` field.
- An inference method does not exist for every resampling method.
- There are combinations of datasets and learners where inference
methods can fail.
Expand All @@ -89,13 +89,23 @@ Note that:

## Inference Methods

| Key | Label | Resamplings |
|:------------|:------------------|:-----------------------------|
| ci.con_z | Conservative-Z CI | ResamplingPairedSubsampling |
| ci.cor_t | Corrected-T CI | ResamplingSubsampling |
| ci.holdout | Holdout CI | ResamplingHoldout |
| ci.naive_cv | Naive CV CI | ResamplingCV , ResamplingLOO |
| ci.ncv | Nested CV CI | ResamplingNestedCV |
``` r
content = as.data.table(mlr3::mlr_measures, objects = TRUE)[startsWith(get("key"), "ci."),]
content$resamplings = map(content$object, function(x) paste0(gsub("Resampling", "", x$resamplings), collapse = ", "))
content[["only pointwise loss"]] = map_chr(content$object, function(object) {
if (get_private(object)$.requires_obs_loss) "yes" else "false"
})
content = content[, c("key", "label", "resamplings", "only pointwise loss")]
knitr::kable(content, format = "markdown", col.names = tools::toTitleCase(names(content)))
```

| Key | Label | Resamplings | Only Pointwise Loss |
|:------------|:------------------|:------------------|:--------------------|
| ci.con_z | Conservative-Z CI | PairedSubsampling | false |
| ci.cor_t | Corrected-T CI | Subsampling | false |
| ci.holdout | Holdout CI | Holdout | yes |
| ci.naive_cv | Naive CV CI | CV, LOO | yes |
| ci.ncv | Nested CV CI | NestedCV | yes |

## Bugs, Questions, Feedback

Expand Down
9 changes: 7 additions & 2 deletions man/mlr_measures_abstract_ci.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_measures_ci_con_z.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_measures_ci_cor_t.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_measures_ci_holdout.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_measures_ci_naive_cv.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_measures_ci_ncv.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 15 additions & 5 deletions tests/testthat/helpers.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
expect_ci_measure = function(id, resampling, task = tsk("boston_housing"),
symmetric = TRUE, stratum = "chas", ...) {
expect_ci_measure = function(id, resampling, symmetric = TRUE, ...) {
check = function(m, rr) {
m = m$clone(deep = TRUE)
get("expect_measure", envir = .GlobalEnv)(m)
Expand All @@ -26,13 +25,24 @@ expect_ci_measure = function(id, resampling, task = tsk("boston_housing"),
expect_true(ci2[2L] >= ci1[2L])
expect_true(ci2[3L] <= ci1[3L])
}
task = tsk("boston_housing")
rr = resample(task, lrn("regr.featureless"), resampling)
check(msr(id, measure = "regr.rmse", within_range = FALSE), rr)
check(msr(id, measure = "regr.mse", within_range = FALSE), rr)

task$col_roles$stratum = "chas"
rr_strat = resample(task, lrn("regr.featureless"), resampling)
check(msr(id, measure = "regr.rmse", within_range = FALSE), rr)
check(msr(id, measure = "regr.mse", within_range = FALSE), rr)
}
check(msr(id, measure = "regr.rmse", within_range = FALSE), rr_strat)
check(msr(id, measure = "regr.mse", within_range = FALSE), rr_strat)

if (!mlr3misc::require_namespaces("rpart", quietly = TRUE)) return(NULL)

# decomposable vs. non-decomposable
rr = resample(tsk("sonar"), lrn("classif.rpart", predict_type = "prob"), resampling)
if (!get_private(msr(id, "regr.mse"))$.requires_obs_loss) {
m = msr(id, "classif.auc")
check(m, rr)
} else {
expect_error(msr(id, "classif.auc"))
}
}