Time points used in survival learners predicted matrix (distr) #387

Open
bblodfon opened this issue Sep 26, 2024 · 1 comment
@bblodfon
Collaborator

bblodfon commented Sep 26, 2024

Hi Byron (@bcjaeger)!

I ran a small benchmark related to this PR: across all mlr3 survival learners that produce a survival matrix (distr predict type in mlr3proba), which time points are used as columns? The results are below. In short, most learners use all the unique time points of the train set. This plays a large role when computing metrics such as the IBS and in keeping comparisons between learners fair:

``` r
library(mlr3proba)
#> Loading required package: mlr3
library(mlr3extralearners)

lrn_ids = mlr_learners$keys("^surv")
# remove some learners (DL models, take too much time: bart, mboost has issues, etc.)
lrn_ids = lrn_ids[!grepl(pattern = "blackboost|mboost|deep|pchazard|coxtime|priority|dnn|loghaz|gamboost", lrn_ids)]
# remove learners that don't predict `distr`
lrn_ids = lapply(lrn_ids, function(id) {
  learner = lrn(id)
  if ("distr" %in% learner$predict_types) {
    id
  } else {
    NULL
  }
}) |> unlist()

lrn_ids # ~18 survival learners
#>  [1] "surv.akritas"     "surv.aorsf"       "surv.bart"        "surv.cforest"    
#>  [5] "surv.coxboost"    "surv.coxph"       "surv.ctree"       "surv.cv_coxboost"
#>  [9] "surv.cv_glmnet"   "surv.flexible"    "surv.glmnet"      "surv.kaplan"     
#> [13] "surv.nelson"      "surv.parametric"  "surv.penalized"   "surv.ranger"     
#> [17] "surv.rfsrc"       "surv.xgboost.cox"
task = tsk("gbcs")
set.seed(42)
part = partition(task, ratio = 0.5)

# keep different time points sets to check later
train_times = task$unique_times(part$train)
train_event_times = task$unique_event_times(part$train)

test_times = task$times(part$test)
test_status = task$status(part$test)
test_event_times = sort(unique(test_times[test_status == 1]))
test_times = sort(unique(test_times))

all_times = task$unique_times()
all_event_times = task$unique_event_times()

res = lapply(lrn_ids, function(id) {
  print(id)
  learner = lrn(id)

  if (id == "surv.parametric") {
    learner$param_set$set_values(.values = list(discrete = TRUE))
  }

  if (id == "surv.bart") {
    learner$param_set$set_values(
      # low settings to make computation faster
      .values = list(nskip = 1, ndpost = 3, keepevery = 2, mc.cores = 14)
    )
  }

  if (id == "surv.cforest") {
    learner$param_set$set_values(.values = list(cores = 14))
  }

  if (id == "surv.ranger") {
    learner$param_set$set_values(.values = list(num.threads = 14))
  }

  learner$train(task, part$train)
  p = learner$predict(task, part$test)
  times = as.numeric(colnames(p$data$distr))

  # return discrete times for which we have the predicted S(times)
  times
})
#> [1] "surv.akritas"
#> [1] "surv.aorsf"
#> [1] "surv.bart"
#> [1] "surv.cforest"
#> [1] "surv.coxboost"
#> [1] "surv.coxph"
#> [1] "surv.ctree"
#> [1] "surv.cv_coxboost"
#> [1] "surv.cv_glmnet"
#> [1] "surv.flexible"
#> [1] "surv.glmnet"
#> Warning: Multiple lambdas have been fit. Lambda will be set to 0.01 (see
#> parameter 's').
#> [1] "surv.kaplan"
#> [1] "surv.nelson"
#> [1] "surv.parametric"
#> [1] "surv.penalized"
#> [1] "surv.ranger"
#> [1] "surv.rfsrc"
#> [1] "surv.xgboost.cox"

names(res) = lrn_ids

# example times:
head(res$surv.aorsf)
#> [1]  72 177 210 294 311 323

which_times = lapply(lrn_ids, function(id) {
  times = res[[id]]
  #print(id)

  lgl_list = suppressWarnings(list(
    train = all(times == train_times),
    train_event = all(times == train_event_times),
    test = all(times == test_times),
    test_event = all(times == test_event_times),
    all = all(times == all_times),
    all_events = all(times == all_event_times)
  ))

  names(which(mlr3misc::map_lgl(lgl_list, isTRUE)))
})

names(which_times) = lrn_ids

# Results: which time points are used by each learner in the predicted survival matrix?
which_times
#> $surv.akritas
#> character(0)
#> 
#> $surv.aorsf
#> [1] "test_event"
#> 
#> $surv.bart
#> [1] "train"
#> 
#> $surv.cforest
#> [1] "train"
#> 
#> $surv.coxboost
#> [1] "train"
#> 
#> $surv.coxph
#> [1] "train"
#> 
#> $surv.ctree
#> [1] "train"
#> 
#> $surv.cv_coxboost
#> [1] "train"
#> 
#> $surv.cv_glmnet
#> [1] "train"
#> 
#> $surv.flexible
#> [1] "train"
#> 
#> $surv.glmnet
#> [1] "train"
#> 
#> $surv.kaplan
#> [1] "train"
#> 
#> $surv.nelson
#> [1] "train"
#> 
#> $surv.parametric
#> character(0)
#> 
#> $surv.penalized
#> character(0)
#> 
#> $surv.ranger
#> [1] "train_event"
#> 
#> $surv.rfsrc
#> [1] "train_event"
#> 
#> $surv.xgboost.cox
#> [1] "train"
```

Created on 2024-09-26 with reprex v2.1.1
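
One caveat about the check in `which_times`: `all(times == train_times)` with warnings suppressed can return TRUE for vectors of different lengths, because R recycles the shorter one. A stricter comparison could look like the sketch below. `same_times` is a hypothetical helper name and the numbers are made up; it is not part of mlr3proba and was not used to produce the results above.

```r
# Hypothetical helper (not part of mlr3proba): TRUE only if both vectors
# have the same length and the same values (up to numeric tolerance).
same_times = function(a, b) {
  length(a) == length(b) && isTRUE(all.equal(a, b))
}

# With recycling, `==` can match vectors of different lengths:
x = c(1, 2)
y = c(1, 2, 1, 2)
all(x == y)      # TRUE, because x is recycled to c(1, 2, 1, 2)
same_times(x, y) # FALSE, the lengths differ
same_times(y, y) # TRUE
```

Swapping `all(times == train_times)` for `same_times(times, train_times)` would also make `suppressWarnings()` unnecessary.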

  • surv.aorsf uses the unique event time points from the test set (code). Maybe it's a good idea to change that and harmonize with the rest of the RSFs (for some reason these learners use the unique event time points from the train set), which I think is the learner$model$event_times slot you have in your learner?
  • akritas and parametric have an ntime argument (default 150) that "spreads out" the time points of the train set. The reason for this was efficiency (to NOT have too many time points). We could change the default to use the unique train time points from the model$y[, "time"] slot, and users who want fewer time points can still set ntime.
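
As a rough illustration of what "spreading out" the train time points could look like, here is a minimal sketch. It is an assumption for illustration only: `thin_times` is a hypothetical helper, the time values are made up, and this is not the code that surv.akritas or surv.parametric actually use for ntime.

```r
# Hypothetical helper: keep at most `ntime` of the unique, sorted time points,
# picked at evenly spaced positions; keep all of them if `ntime` is NULL.
thin_times = function(times, ntime = NULL) {
  times = sort(unique(times))
  if (is.null(ntime) || ntime >= length(times)) {
    return(times)
  }
  idx = unique(round(seq(1, length(times), length.out = ntime)))
  times[idx]
}

train_times = c(5, 8, 8, 12, 20, 21, 30, 41, 55, 60, 72) # made-up values
thin_times(train_times)            # all 10 unique time points
thin_times(train_times, ntime = 4) # 5 20 41 72
```

With `ntime = NULL` as the default, the predicted matrix would use every unique train time point, and setting ntime would remain an opt-in way to reduce the number of columns.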
@bcjaeger
Contributor

Hey John! I think harmonizing is a good idea, and it's much easier to align aorsf with the other learners than to align the other learners with aorsf. I think my rationale was that evaluating model predictions at the times when events occur should be more efficient than evaluating them at times around those points. It also avoids missing event times in the testing data that occur before the first or after the last event time in the training data. But in most cases I think the event times will be very similar in training versus testing data.
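
To make the point about test event times outside the train range concrete, here is a small sketch with made-up numbers. It assumes the predicted survival curve is treated as a right-continuous step function on the train event times, which is an assumption for illustration and not code from aorsf or mlr3proba.

```r
# Made-up survival curve: S(t) predicted at the train event times
train_event_times = c(10, 20, 35, 50)
surv_probs = c(0.9, 0.7, 0.4, 0.2)

# Right-continuous step function: S(t) = 1 before the first time point,
# and the last predicted value is carried forward after the last one.
S = stepfun(train_event_times, c(1, surv_probs))

# Hypothetical test event times, two of them outside the train range
test_event_times = c(4, 20, 42, 61)
S(test_event_times)
#> [1] 1.0 0.7 0.4 0.2
```

Under this assumption, a curve defined on the train event times can still be evaluated at any test time: times before the first train event get S(t) = 1 and times after the last one reuse the last predicted value.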
