From 43772c0f44c17b008dc8d513acac710752c42aa7 Mon Sep 17 00:00:00 2001 From: Hen Ri Date: Fri, 26 Feb 2021 21:53:20 +0100 Subject: [PATCH] make mlr3cluster usable for pipelines: store task in the prediction --- R/MeasureClustInternal.R | 4 ++-- R/PredictionClust.R | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/R/MeasureClustInternal.R b/R/MeasureClustInternal.R index b269cd9d..6cc3975a 100644 --- a/R/MeasureClustInternal.R +++ b/R/MeasureClustInternal.R @@ -20,8 +20,8 @@ MeasureClustInternal = R6Class("MeasureClustInternal", } ), private = list( - .score = function(prediction, task, ...) { - X = as.matrix(task$data(rows = prediction$row_ids)) + .score = function(prediction, ...) { + X = as.matrix(prediction$data$task$data(rows = prediction$row_ids)) if (!is.double(X)) { # clusterCrit does not convert lgls/ints storage.mode(X) = "double" } diff --git a/R/PredictionClust.R b/R/PredictionClust.R index f953569d..2e9bd223 100644 --- a/R/PredictionClust.R +++ b/R/PredictionClust.R @@ -39,7 +39,8 @@ PredictionClust = R6Class("PredictionClust", #' @param check (`logical(1)`)\cr #' If `TRUE`, performs some argument checks and predict type conversions. initialize = function(task = NULL, row_ids = task$row_ids, partition = NULL, prob = NULL, check = TRUE) { - pdata = list(row_ids = row_ids, partition = partition, prob = prob) + pdata = list(row_ids = row_ids, partition = partition, + prob = prob, task = task) pdata = discard(pdata, is.null) class(pdata) = c("PredictionDataClust", "PredictionData")