Merge pull request #836 from mlr-org/glrn_shortcuts
AB shortcuts for GraphLearner
mb706 authored Oct 17, 2024
2 parents c823f18 + 9e3289a commit 750636f
Showing 6 changed files with 135 additions and 26 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
@@ -1,6 +1,8 @@
# mlr3pipelines 0.7.0-9000

* New down-sampling PipeOps for imbalanced data: `PipeOpTomek` / `po("tomek")` and `PipeOpNearmiss` / `po("nearmiss")`
+* `GraphLearner` has new active bindings/methods as shortcuts for active bindings/methods of the underlying `Graph`:
+`$pipeops`, `$edges`, `$pipeops_param_set`, and `$pipeops_param_set_values` as well as `$ids()` and `$plot()`.

# mlr3pipelines 0.7.0

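For orientation, here is a minimal sketch of how the shortcuts announced above can be used; the pipeline, task, and output comments are illustrative and not part of this changeset (assumes mlr3 and mlr3pipelines are attached):

```r
library(mlr3)
library(mlr3pipelines)

# A GraphLearner wrapping a small scale -> rpart pipeline.
glrn = as_learner(po("scale") %>>% lrn("classif.rpart"))

glrn$ids(sorted = TRUE)        # PipeOp ids, topologically sorted
glrn$pipeops$scale             # direct access to a single PipeOp
glrn$edges                     # connection table of the wrapped Graph
glrn$pipeops_param_set_values  # per-PipeOp parameter values
glrn$plot(horizontal = TRUE)   # plot the wrapped Graph
```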
5 changes: 3 additions & 2 deletions R/Graph.R
@@ -91,11 +91,12 @@
#' Takes a list of `Graph`s or [`PipeOp`]s (or objects that can be automatically converted into `Graph`s or [`PipeOp`]s,
#' see [`as_graph()`] and [`as_pipeop()`]) as inputs and joins them in a serial `Graph` coming after `self`, as if
#' connecting them using [`%>>%`].
-#' * `plot(html)` \cr
-#' (`logical(1)`) -> `NULL` \cr
+#' * `plot(html = FALSE, horizontal = FALSE)` \cr
+#' (`logical(1)`, `logical(1)`) -> `NULL` \cr
#' Plot the [`Graph`], using either the \pkg{igraph} package (for `html = FALSE`, default) or
#' the `visNetwork` package for `html = TRUE` producing a [`htmlWidget`][htmlwidgets::htmlwidgets].
#' The [`htmlWidget`][htmlwidgets::htmlwidgets] can be rescaled using [`visOptions`][visNetwork::visOptions].
+#' For `html = FALSE`, the orientation of the plotted graph can be controlled through `horizontal`.
#' * `print(dot = FALSE, dotname = "dot", fontsize = 24L)` \cr
#' (`logical(1)`, `character(1)`, `integer(1)`) -> `NULL` \cr
#' Print a representation of the [`Graph`] on the console. If `dot` is `FALSE`, output is a table with one row for each contained [`PipeOp`] and
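A brief, hedged sketch of the extended `plot()` signature (the static plot needs the igraph package, the htmlwidget needs visNetwork; the pipeline is illustrative):

```r
library(mlr3pipelines)

gr = po("scale") %>>% po("pca")
gr$plot()                   # static igraph plot, default orientation
gr$plot(horizontal = TRUE)  # same plot with the new horizontal layout
gr$plot(html = TRUE)        # interactive visNetwork htmlwidget instead
```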
85 changes: 70 additions & 15 deletions R/GraphLearner.R
@@ -47,15 +47,25 @@
#' contain the model. Use `graph_model` to access the trained [`Graph`] after `$train()`. Read-only.
#' * `graph_model` :: [`Learner`][mlr3::Learner]\cr
#' [`Graph`] that is being wrapped. This [`Graph`] contains a trained state after `$train()`. Read-only.
+#' * `pipeops` :: named `list` of [`PipeOp`] \cr
+#' Contains all [`PipeOp`]s in the underlying [`Graph`], named by the [`PipeOp`]'s `$id`s. Shortcut for `$graph_model$pipeops`. See [`Graph`] for details.
+#' * `edges` :: [`data.table`][data.table::data.table] with columns `src_id` (`character`), `src_channel` (`character`), `dst_id` (`character`), `dst_channel` (`character`)\cr
+#' Table of connections between the [`PipeOp`]s in the underlying [`Graph`]. Shortcut for `$graph$edges`. See [`Graph`] for details.
+#' * `param_set` :: [`ParamSet`][paradox::ParamSet]\cr
+#' Parameters of the underlying [`Graph`]. Shortcut for `$graph$param_set`. See [`Graph`] for details.
+#' * `pipeops_param_set` :: named `list()`\cr
+#' Named list containing the [`ParamSet`][paradox::ParamSet]s of all [`PipeOp`]s in the [`Graph`]. See there for details.
+#' * `pipeops_param_set_values` :: named `list()`\cr
+#' Named list containing the set parameter values of all [`PipeOp`]s in the [`Graph`]. See there for details.
#' * `internal_tuned_values` :: named `list()` or `NULL`\cr
-#' The internal tuned parameter values collected from all `PipeOp`s.
+#' The internal tuned parameter values collected from all [`PipeOp`]s.
#' `NULL` is returned if the learner is not trained or none of the wrapped learners supports internal tuning.
#' * `internal_valid_scores` :: named `list()` or `NULL`\cr
-#' The internal validation scores as retrieved from the `PipeOps`.
-#' The names are prefixed with the respective IDs of the `PipeOp`s.
+#' The internal validation scores as retrieved from the [`PipeOp`]s.
+#' The names are prefixed with the respective IDs of the [`PipeOp`]s.
#' `NULL` is returned if the learner is not trained or none of the wrapped learners supports internal validation.
#' * `validate` :: `numeric(1)`, `"predefined"`, `"test"` or `NULL`\cr
-#' How to construct the validation data. This also has to be configured for the individual `PipeOp`s such as
+#' How to construct the validation data. This also has to be configured for the individual [`PipeOp`]s such as
#' `PipeOpLearner`, see [`set_validate.GraphLearner`].
#' For more details on the possible values, see [`mlr3::Learner`].
#' * `marshaled` :: `logical(1)`\cr
@@ -75,6 +85,16 @@
#'
#' @section Methods:
#' Methods inherited from [`Learner`][mlr3::Learner], as well as:
+#' * `ids(sorted = FALSE)` \cr
+#' (`logical(1)`) -> `character` \cr
+#' Get IDs of all [`PipeOp`]s. This is in order that [`PipeOp`]s were added if
+#' `sorted` is `FALSE`, and topologically sorted if `sorted` is `TRUE`.
+#' * `plot(html = FALSE, horizontal = FALSE)` \cr
+#' (`logical(1)`, `logical(1)`) -> `NULL` \cr
+#' Plot the [`Graph`], using either the \pkg{igraph} package (for `html = FALSE`, default) or
+#' the `visNetwork` package for `html = TRUE` producing a [`htmlWidget`][htmlwidgets::htmlwidgets].
+#' The [`htmlWidget`][htmlwidgets::htmlwidgets] can be rescaled using [`visOptions`][visNetwork::visOptions].
+#' For `html = FALSE`, the orientation of the plotted graph can be controlled through `horizontal`.
#' * `marshal`\cr
#' (any) -> `self`\cr
#' Marshal the model.
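Since `$pipeops` is documented above as a shortcut for `$graph_model$pipeops`, it exposes the trained PipeOp states once the learner has been trained; a hedged sketch (task, pipeline, and ids are illustrative):

```r
library(mlr3)
library(mlr3pipelines)

glrn = as_learner(po("scale") %>>% lrn("classif.rpart"))
glrn$train(tsk("iris"))

# $pipeops delegates to $graph_model, so the PipeOps carry their trained state:
names(glrn$pipeops)                       # "scale", "classif.rpart"
glrn$pipeops$scale$state                  # scaling parameters learned during $train()
glrn$pipeops$classif.rpart$learner_model  # the trained rpart Learner

# $edges and $pipeops_param_set describe the structure of the untrained Graph:
glrn$edges
names(glrn$pipeops_param_set)
```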
@@ -104,11 +124,11 @@
#' This works well for simple [`Graph`]s that do not modify features too much, but may give unexpected results for `Graph`s that
#' add new features or move information between features.
#'
-#' As an example, consider a feature `A`` with missing values, and a feature `B`` that is used for imputatoin, using a [`po("imputelearner")`][PipeOpImputeLearner].
-#' In a case where the following [`Learner`][mlr3::Learner] performs embedded feature selection and only selects feature A,
-#' the `selected_features()` method could return only feature `A``, and `$importance()` may even report 0 for feature `B`.
-#' This would not be entirbababababely accurate when considering the entire `GraphLearner`, as feature `B` is used for imputation and would therefore have an impact on predictions.
-#' The following should therefore only be used if the `Graph` is known to not have an impact on the relevant properties.
+#' As an example, consider a feature `A` with missing values, and a feature `B` that is used for imputation, using a [`po("imputelearner")`][PipeOpImputeLearner].
+#' In a case where the following [`Learner`][mlr3::Learner] performs embedded feature selection and only selects feature `A`,
+#' the `selected_features()` method could return only feature `A`, and `$importance()` may even report 0 for feature `B`.
+#' This would not be entirely accurate when considering the entire `GraphLearner`, as feature `B` is used for imputation and would therefore have an impact on predictions.
+#' The following should therefore only be used if the [`Graph`] is known to not have an impact on the relevant properties.
#'
#' * `importance()`\cr
#' () -> `numeric`\cr
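A hedged sketch of the scenario described above, assuming the `pima` task (which contains missing values) and `classif.rpart` (which supports `$importance()` and `$selected_features()`); it is an illustration only, not part of this changeset:

```r
library(mlr3)
library(mlr3pipelines)

glrn = as_learner(
  po("imputelearner", learner = lrn("regr.rpart")) %>>% lrn("classif.rpart")
)
glrn$train(tsk("pima"))

# Both calls delegate to the base rpart learner: a feature used only inside the
# imputation PipeOp may be missing here even though it influences predictions.
glrn$selected_features()
glrn$importance()
```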
@@ -286,6 +306,12 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
} else {
stopf("Baselearner %s of %s does not implement '$loglik()'.", base_learner$id, self$id)
}
+},
+ids = function(sorted = FALSE) {
+private$.graph$ids(sorted = sorted)
+},
+plot = function(html = FALSE, horizontal = FALSE, ...) {
+private$.graph$plot(html = html, horizontal = horizontal, ...)
}
),
active = list(
@@ -339,12 +365,6 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
}
pt
},
-param_set = function(rhs) {
-if (!missing(rhs) && !identical(rhs, self$graph$param_set)) {
-stop("param_set is read-only.")
-}
-self$graph$param_set
-},
graph = function(rhs) {
if (!missing(rhs) && !identical(rhs, private$.graph)) stop("graph is read-only")
private$.graph
@@ -360,6 +380,41 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
g$state = self$model
g
}
+},
+pipeops = function(rhs) {
+if (!missing(rhs) && (!identical(rhs, self$graph_model$pipeops))) {
+stop("pipeops is read-only")
+}
+self$graph_model$pipeops
+},
+edges = function(rhs) {
+if (!missing(rhs) && !identical(rhs, private$.graph$edges)) {
+stop("edges is read-only")
+}
+private$.graph$edges
+},
+param_set = function(rhs) {
+if (!missing(rhs) && !identical(rhs, self$graph$param_set)) {
+stop("param_set is read-only.")
+}
+self$graph$param_set
+},
+pipeops_param_set = function(rhs) {
+value = map(self$graph$pipeops, "param_set")
+if (!missing(rhs) && !identical(value, rhs)) {
+stop("pipeops_param_set is read-only")
+}
+value
+},
+pipeops_param_set_values = function(rhs) {
+if (!missing(rhs)) {
+assert_list(rhs)
+assert_names(names(rhs), permutation.of = names(self$graph$pipeops))
+for (n in names(rhs)) {
+self$graph$pipeops[[n]]$param_set$values = rhs[[n]]
+}
+}
+map(self$graph$pipeops, function(x) x$param_set$values)
}
),
private = list(
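A sketch of the setter semantics implemented in `pipeops_param_set_values` above: assignment expects a named list with exactly one entry per PipeOp (enforced via `assert_names()`), while the other new bindings reject modification; names and parameters here are illustrative:

```r
library(mlr3)
library(mlr3pipelines)

glrn = as_learner(po("scale") %>>% lrn("classif.rpart"))

# Read-modify-write of all per-PipeOp values; the assigned list must contain
# exactly the PipeOp ids of the Graph as names.
vals = glrn$pipeops_param_set_values
vals$classif.rpart$maxdepth = 3
glrn$pipeops_param_set_values = vals

# Partial lists are rejected, and the structural bindings are read-only:
# glrn$pipeops_param_set_values = list(scale = list())  # error: names mismatch
# glrn$edges = data.table::data.table()                 # error: read-only
```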
5 changes: 3 additions & 2 deletions man/Graph.Rd

Some generated files are not rendered by default.

36 changes: 29 additions & 7 deletions man/mlr_learners_graph.Rd

Some generated files are not rendered by default.

28 changes: 28 additions & 0 deletions tests/testthat/test_GraphLearner.R
@@ -127,6 +127,7 @@ test_that("graphlearner parameters behave as they should", {
dblrn = mlr_learners$get("classif.debug")
dblrn$param_set$values$save_tasks = TRUE

+# Graph ParamSet
dbgr = PipeOpScale$new() %>>% PipeOpLearner$new(dblrn)

expect_subset(c("scale.center", "scale.scale", "classif.debug.x"), dbgr$param_set$ids())
@@ -163,6 +164,7 @@ test_that("graphlearner parameters behave as they should", {
expect_equal(dbgr$pipeops$classif.debug$param_set$values$x, 0.5)
expect_equal(dbgr$pipeops$classif.debug$learner$param_set$values$x, 0.5)

+# Graph Learner ParamSet
dblrn = mlr_learners$get("classif.debug")
dblrn$param_set$values$message_train = 1
dblrn$param_set$values$message_predict = 1
@@ -177,6 +179,32 @@

expect_mapequal(gl$param_set$values,
list(classif.debug.message_predict = 0, classif.debug.message_train = 1, classif.debug.warning_predict = 0, classif.debug.warning_train = 1))

+# GraphLearner AB shortcuts
+gl = GraphLearner$new(dbgr)
+
+# GraphLearner AB $pipeops
+expect_no_error({gl$pipeops$classif.debug$param_set$values$x = 0.5})
+expect_equal(gl$pipeops$classif.debug$param_set$values$x, 0.5)
+expect_equal(gl$graph_model$pipeops$classif.debug$param_set$values$x, 0.5)
+
+# GraphLearner AB $pipeops_param_set
+expect_no_error({gl$pipeops_param_set$classif.debug$values$x = 0})
+expect_equal(gl$pipeops_param_set$classif.debug$values$x, 0)
+expect_equal(gl$graph_model$pipeops$classif.debug$param_set$values$x, 0)
+
+# GraphLearner AB $pipeops_param_set_values
+expect_no_error({gl$pipeops_param_set_values$classif.debug$x = 1})
+expect_equal(gl$pipeops_param_set_values$classif.debug$x, 1)
+expect_equal(gl$graph_model$pipeops$classif.debug$param_set$values$x, 1)
+
+# Change param_set pointer should throw error
+expect_error({gl$pipeops$scale$param_set = ps()})
+expect_error({gl$pipeops_param_set$scale = ps()})
+# Lists with wrong properties should not be accepted
+expect_error({gl$pipeops_param_set_values = list()})
+expect_error({gl$pipeops_param_set_values = list(x = 5)})

})

test_that("graphlearner type inference", {
