Skip to content

Commit

Permalink
mlr3 upkeep
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer committed Aug 14, 2024
1 parent f5ae9c7 commit 8843d48
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 8 deletions.
2 changes: 1 addition & 1 deletion R/ResamplingNestedCV.R
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ ResamplingNestedCV = R6::R6Class("ResamplingNestedCV",
pv = self$param_set$get_values()
folds = pv$folds
repeats = pv$repeats
self$primary_iters = as.vector(outer(seq_len(pv$folds), pv$folds^2 * seq(0, pv$repeats - 1), `+`))
private$.primary_iters = as.vector(outer(seq_len(pv$folds), pv$folds^2 * seq(0, pv$repeats - 1), `+`))
map_dtr(seq(repeats), function(r) {
data.table(
row_id = ids,
Expand Down
2 changes: 1 addition & 1 deletion R/ResamplingPairedSubsampling.R
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ ResamplingPairedSubsampling = R6Class("ResamplingPairedSubsampling",
repeats_out = pvs$repeats_out
ratio = pvs$ratio

self$primary_iters = repeats_in
private$.primary_iters = repeats_in

n = length(ids)
n1 = round(n * ratio)
Expand Down
2 changes: 2 additions & 0 deletions man/mlr3inferr-package.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions tests/testthat/test_ResamplingNestedCV.R
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,15 @@ test_that("primary iters", {
task$col_roles$stratum = "Species"
r = rsmp("nested_cv", folds = 3, repeats = 1)
r$instantiate(task)
expect_equal(r$primary_iters, 1:3)
expect_equal(get_private(r)$.primary_iters, 1:3)
r$param_set$set_values(
folds = 4L, repeats = 1
)
r$instantiate(task)
expect_equal(r$primary_iters, 1:4)
expect_equal(get_private(r)$.primary_iters, 1:4)
r$param_set$set_values(
folds = 4L, repeats = 2
)
r$instantiate(task)
expect_equal(r$primary_iters, c(1:4, 17:20))
expect_equal(get_private(r)$.primary_iters, c(1:4, 17:20))
})
6 changes: 3 additions & 3 deletions tests/testthat/test_ResamplingPairedSubsampling.R
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,11 @@ test_that("primary_iters", {
task = tsk("iris")
r = rsmp("paired_subsampling", repeats_in = 1, repeats_out = 1, ratio = 0.8)
r$instantiate(task)
expect_equal(r$primary_iters, 1L)
expect_equal(get_private(r)$.primary_iters, 1L)
r$param_set$values$repeats_in = 2
r$instantiate(task)
expect_equal(r$primary_iters, 2L)
expect_equal(get_private(r)$.primary_iters, 2L)
r$instantiate(task)
r$param_set$values$repeats_out = 2L
expect_equal(r$primary_iters, 2L)
expect_equal(get_private(r)$.primary_iters, 2L)
})

0 comments on commit 8843d48

Please sign in to comment.