Skip to content

Commit

Permalink
tests for grouped data (except optimal prop), tests for reworked task…
Browse files Browse the repository at this point in the history
…_filter_ex
  • Loading branch information
advieser committed Oct 15, 2024
1 parent d46e9cd commit 0283529
Showing 1 changed file with 67 additions and 16 deletions.
83 changes: 67 additions & 16 deletions tests/testthat/test_pipeop_subsample.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,53 +73,104 @@ test_that("PipeOpSubsample works stratified", {
table(list(Species = rep(c("setosa", "versicolor", "virginica"), 100))))
})

test_that("PipeOpSubsample works for grouped data", {
test_that("PipeOpSubsample - Grouped Data - Equal group sizes", {
op = PipeOpSubsample$new()

# Data with groups of the same size
test_df = data.frame(
target = runif(3000),
x1 = runif(3000),
x2 = runif(3000),
grp = sample(paste0("g", 1:100), 3000, replace = TRUE)
grp = rep(paste0("g", 1:100), 30)
)
task = TaskRegr$new(id = "test", backend = test_df, target = "target")
task$set_col_roles("grp", "group")

train_out = op$train(list(task))[[1L]]
# grouped data are kept together
# TESTCASE: default frac, replace = FALSE
train_out = op$train(list(task))[[1]]
# Grouped data are kept together
grps_out = table(train_out$groups$group)
expect_equal(
grps_out,
table(task$groups$group)[names(grps_out)]
)
expect_equal(grps_out, table(task$groups$group)[names(grps_out)])
# Proportion optimal
train_out$nrow / task$nrow
op$param_set$values$frac

# TESTCASE: changed frac, replace = TRUE
op$param_set$set_values(frac = 2, replace = TRUE)
train_out = op$train(list(task))[[1L]]
# grouped data are kept together
train_out = op$train(list(task))[[1]]
# Grouped data are kept together
grps_out = table(train_out$groups$group)
expect_equal(
grps_out,
table(task$groups$group)[names(grps_out)]
)
expect_equal(grps_out, table(task$groups$group)[names(grps_out)])
# Proportion optimal

# use_groups = TRUE and stratify = TRUE should throw an error
# TESTCASE: Exclude some rows from row_roles$use and one group completely
task$row_roles$use = setdiff(seq(1, 2800), task$groups[group == "g1", row_id])
train_out = op$train(list(task))[[1]]
# test that all sampled rows are in row_roles$use
expect_in(train_out$row_ids, task$row_roles$use)

# TESTCASE: Set some rows to be included multiple times in row_roles$use
task$row_roles$use = c(task$row_roles$use, seq(1:50))
expect_no_error(op$train(list(task)))

# TESTCASE: use_groups = TRUE and stratify = TRUE should throw an error
op$param_set$set_values(stratify = TRUE, use_groups = TRUE)
expect_error(op$train(list(task)))

})

test_that("PipeOpSubsample - Grouped data - Large variance in group sizes", {
op = PipeOpSubsample$new()

# Data with one very large group
test_df = data.frame(
target = runif(3000),
x1 = runif(3000),
x2 = runif(3000),
grp = c(
sample(paste0("g", 1:50), 1500, replace = TRUE), # small groups
sample(paste0("G", 1:5), 1500, replace = TRUE) # large groups
)
)
task = TaskRegr$new(id = "test", backend = test_df, target = "target")
task$set_col_roles("grp", "group")

# TESTCASE: default frac, replace = FALSE
train_out = op$train(list(task))[[1]]
# Grouped data are kept together
grps_out = table(train_out$groups$group)
expect_equal(grps_out, table(task$groups$group)[names(grps_out)])
# Proportion optimal

# TESTCASE: changed frac, replace = TRUE
op$param_set$set_values(frac = 2, replace = TRUE)
train_out = op$train(list(task))[[1]]
# grouped data are kept together
grps_out = table(train_out$groups$group)
expect_equal(grps_out, table(task$groups$group)[names(grps_out)])
# Proportion optimal

})

test_that("task filter utility function", {
task = mlr_tasks$get("iris")

rowidx = as.integer(c(1, 2, 3, 2, 1, 2, 3, 2, 1)) # annoying and unnecessary mlr3 type strictness

tfiltered = task_filter_ex(task$clone(), rowidx)

expect_equal(tfiltered$data(), task$data(rows = rowidx))

task$select(c("Petal.Length", "Petal.Width"))
tfiltered = task_filter_ex(task$clone(), rowidx)
expect_equal(tfiltered$data(), task$data(rows = rowidx))

task$set_col_roles("Petal.Length", "group")
tfiltered = task_filter_ex(task$clone(), rowidx)
expect_equal(tfiltered$data(), task$data(rows = rowidx))
expect_equal(tfiltered$groups$group, task$groups[rowidx]$group)

task$select(character(0))
tfiltered = task_filter_ex(task$clone(), rowidx)
expect_equal(tfiltered$data(), task$data(rows = rowidx))
expect_equal(tfiltered$groups$group, task$groups[rowidx]$group)
})

0 comments on commit 0283529

Please sign in to comment.