Skip to content

Commit

Permalink
add tests for primary iters
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer committed Jul 16, 2024
1 parent 8b370dc commit e72ae6b
Show file tree
Hide file tree
Showing 9 changed files with 52 additions and 2 deletions.
2 changes: 1 addition & 1 deletion R/MeasureAbstractCi.R
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ MeasureAbstractCi = R6Class("MeasureAbstractCi",
range = self$measure$range,
minimize = self$measure$minimize,
average = "custom",
properties = self$measure$properties,
properties = "primary_iters",
predict_type = self$measure$predict_type,
packages = unique(c(self$measure$packages, "mlr3inference"), packages),
label = label
Expand Down
3 changes: 3 additions & 0 deletions R/MeasureCiNestedCV.R
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ MeasureCiNestedCV = R6Class("MeasureCiNestedCV",
s = qnorm(1 - param_vals$alpha / 2) * se
c(err_ncv - bias, err_ncv - bias - s, err_ncv - bias + s)
}
),
active = list(

)
)

Expand Down
7 changes: 7 additions & 0 deletions R/ResamplingNestedCV.R
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@ ResamplingNestedCV = R6::R6Class("ResamplingNestedCV",
assert_ro_binding(rhs)
pv = self$param_set$get_values()
pv$repeats * pv$folds^2
},
#' @field primary_iters (`integer()`)\cr
#' The primary iterations to be used for point estimation.
primary_iters = function(rhs) {
assert_ro_binding(rhs)
pvs = self$param_set$get_values()
as.vector(outer(seq_len(pvs$folds), pvs$folds^2 * seq(0, pvs$repeats - 1), `+`))
}
),
private = list(
Expand Down
7 changes: 7 additions & 0 deletions R/ResamplingPairedSubsampling.R
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,13 @@ ResamplingPairedSubsampling = R6Class("ResamplingPairedSubsampling",
iters = function(rhs) {
pvs = self$param_set$get_values()
(pvs$repeats_out * 2 + 1) * pvs$repeats_in
},
#' @field primary_iters (`integer()`)\cr
#' The primary iterations to be used for point estimation.
primary_iters = function(rhs) {
assert_ro_binding(rhs)
pvs = self$param_set$get_values()
pvs$repeats_in
}
)
)
Expand Down
3 changes: 3 additions & 0 deletions man/mlr_resamplings_ncv.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions man/mlr_resamplings_paired_subsampling.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion tests/testthat/test_MeasureAbstractCI.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ test_that("aggregation works", {

test_that("ci can be used with other measure", {
task = tsk("iris")
learner = lrn("classif.rpart")
learner = lrn("classif.featureless")
resampling = rsmp("holdout")

rr = resample(task, learner, resampling)
Expand Down
16 changes: 16 additions & 0 deletions tests/testthat/test_ResamplingNestedCV.R
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,19 @@ test_that("stratification works",{
expect_equal(length(unique(table(task$data(r$train_set(i), "Species")$Species))), 1)
})
})

test_that("primary iters", {
task = tsk("iris")$filter(c(1:30, 51:80))$droplevels()
task$col_roles$stratum = "Species"
r = rsmp("nested_cv", folds = 3, repeats = 1)
r$instantiate(task)
expect_equal(r$primary_iters, 1:3)
r$param_set$set_values(
folds = 4L, repeats = 1
)
expect_equal(r$primary_iters, 1:4)
r$param_set$set_values(
folds = 4L, repeats = 2
)
expect_equal(r$primary_iters, c(1:4, 17:20))
})
11 changes: 11 additions & 0 deletions tests/testthat/test_ResamplingPairedSubsampling.R
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,14 @@ test_that("uneven dataset size stratification", {
expect_equal(length(r$train_set(3)), 42)
expect_equal(length(r$test_set(3)), 28)
})

test_that("primary_iters", {
task = tsk("iris")
r = rsmp("paired_subsampling", repeats_in = 1, repeats_out = 1, ratio = 0.8)
r$instantiate(task)
expect_equal(r$primary_iters, 1L)
r$param_set$values$repeats_in = 2
expect_equal(r$primary_iters, 2L)
r$param_set$values$repeats_out = 2L
expect_equal(r$primary_iters, 2L)
})

0 comments on commit e72ae6b

Please sign in to comment.