diff --git a/mlr-org/gallery/survival/2024-07-30-discrete-time/index.qmd b/mlr-org/gallery/survival/2024-07-30-discrete-time/index.qmd
index cf5a3231..a4b0684c 100644
--- a/mlr-org/gallery/survival/2024-07-30-discrete-time/index.qmd
+++ b/mlr-org/gallery/survival/2024-07-30-discrete-time/index.qmd
@@ -1,7 +1,7 @@
 ---
-title: "Survival modeling in mlr3 using Bayesian Additive Regression Trees (BART)"
+title: "Discrete time-to-event reduction pipeline for survival tasks"
 description: |
-  Demonstrate use of survival BART on the lung dataset via mlr3proba and distr6.
+  Demonstration of the discrete-time reduction pipeline for survival analysis.
 author:
   - name: Andreas Bender, Philip Studener, John Zobolas
     orcid: #TODO add ORCIDs 0000-0002-3609-8674
@@ -16,26 +16,32 @@ bibliography: ../../bibliography.bib # TODO: Add references to discrete time her
 
 In this tutorial we illustrate how to perform *discrete time-to-event analysis* [@Tutz...] using the packages `mlr3` and `mlr3proba`.
 Discrete time-to-event analysis can be viewed as a reduction technique, where various survival tasks can be "reduced" to a more standard classification task.
-To do so, the data needs to be transformed into a specific format. Details will be provided later, but briefly, we partition the follow-up into a fixed amount of intervals and create a new binary outcome variable that indicates a subjects' status within each interval. We can use this new outcome to estimate probabilities for an event within an interval, conditional on time and features. These conditional probabilities can then be combined to calculate survival probabilities.
+To do so, the data needs to be transformed into a specific format. Details will be provided later, but briefly,
+
+ - we partition the follow-up into a fixed number of intervals,
+ - we create a new binary outcome variable that indicates a subject's status within each interval,
+ - we create a feature that represents the interval (or, more generally, time, which is simply an additional feature in the transformed data),
+ - we use the new outcome as the "label" to estimate probabilities for an event within an interval, conditional on the feature "time" and the other features,
+ - we combine these conditional probabilities (discrete hazards) to calculate survival probabilities.
 
 The advantage of this approach is that once the data are transformed, any classifier available in the `mlr3` ecosystem can be used for survival analysis.
 The disadvantage is that the data transformation can be cumbersome to perform manually and that it takes quite some boilerplate code to do it consistently across resampling iterations and to combine the conditional probabilities (discrete hazards) into survival probabilities.
 However, as we illustrate below, using the discrete time-to-event pipeline implemented in `mlr3proba` via `mlr3pipelines`, all of these details are abstracted away so that practitioners can focus on modeling their data.
+
+In the following sections we first provide more details and intuition on discrete time-to-event analysis and then illustrate how to perform the same analyses, as well as more advanced machine-learning-based ones, with `mlr3proba`.
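+
+The key relation throughout is the last step, which combines the discrete hazards into survival probabilities. As a sketch in our own notation (not tied to any particular package): if $h_j(x)$ denotes the discrete hazard, i.e. the probability of an event in interval $j$ for a subject with features $x$ who is still at risk at the start of that interval, then the survival probability at the end of interval $q$ is the cumulative product
+
+$$
+S(t_q \mid x) = P(T > t_q \mid x) = \prod_{j = 1}^{q} \bigl(1 - h_j(x)\bigr), \qquad h_j(x) = P(y = 1 \mid \text{interval } j,\, x).
+$$
+
+This is exactly the `cumprod(1 - disc_haz)` step that appears in the code further below.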
+
 ## Libraries
 
-```{r}
+```{r, message = FALSE}
 #| output: false
-library(mlr3extralearners)
+library(mlr3)
+library(mlr3learners) # contains logistic regression and other classifiers
+library(mlr3extralearners) # additional classifiers
 library(mlr3pipelines)
-# library(mlr3proba)
-devtools::load_all("~/Dropbox/GitHub/mlr/mlr3proba")
-library(survival)
-library(pammtools)
-library(distr6)
+library(mlr3proba)
+library(survival) # for survfit and the proportional hazards test
+library(pammtools) # helper functionality for discrete time-to-event analysis
 library(dplyr)
-library(tidyr)
-library(tibble)
 library(ggplot2)
 ```
 
@@ -51,24 +57,14 @@ The data set contains survival times (in days) of patients with stomach area tum
 
 - `charlson_score`: Comorbidity score
 
-Below we can see the first rows of the data set and check for missing values:
+Below we can see the first rows of the data set:
 
 ```{r}
 # Load data
 data("tumor", package = "pammtools")
 head(tumor)
-tsk_tumor$missings()
-```
-
-Next we look at the marginal survival probabilities using the Kaplan-Meier estimate:
-
-```{r}
 mean(tumor$status) # ca. 50% of subjects are censored during the follow-up
-tsk_tumor = TaskSurv$new(id = "tumor", backend = tumor, event = "status", time = "days")
-# create survival task
-autoplot(tsk_tumor) + ylim(c(0, 1))
 ```
-The median survival time is approximately 1500 days. The Probability of surviving beyond 3000 days is about 25%.
 
 ## Discrete time-to-event analysis
 
 In this section we first show how discrete time-to-event analysis works "by hand" for illustration.
 Later we show how to use `mlr3proba` to achieve the same.
 
+### Data Transformation
+
 The first step is to transform the data (we use the `as_ped` function from package `pammtools`).
 To do so, we first have to decide how to partition the follow-up into intervals. The general trade-off is:
+
 - high number of intervals $\rightarrow$ higher precision, but also higher variability and computational burden
 - low number of intervals $\rightarrow$ lower variability and computational burden, but lower precision
 
 For now, we simply divide the follow-up into 100 intervals of equal length rounded to whole days (so each interval covers approximately 38 days).
-The output below shows the original and first and last row of the transformed data.
+The output below shows the original data and the first and last row of the transformed data for subjects 1, 3 and 4.
+
 Note how each subject (id) has a different number of rows in the transformed data.
-This is because we only create intervals for a subject in which they were still at risk for an event (i.e. not censored or deceased in a previous interval).
-Subject 4 has only one entry, as the survival time is 33 days, thus it experiences an event in the first interval.
-The new outcome variable is called `ped_status` and indicates for each interval for which a subject was at risk whether the subject survived the interval (0) or experienced an event in the interval (1).
+
+This is because we only create intervals for a subject in which they were still at risk for an event (i.e. not censored or deceased in a previous interval):
+
+ - Subject 1 is censored after 579 days, i.e. in interval (577, 615], the 16th interval (this subject has therefore been at risk for the event in the first 16 intervals and is censored in the 16th interval)
+ - Subject 3 experiences an event after 308 days and is therefore at risk at the beginning of 8 intervals, experiencing the event in interval (269, 308], i.e. the 8th interval
+ - Subject 4 has only one entry, as its survival time is 33 days, so it experiences an event in the first interval
+
+The new outcome variable here is called `ped_status` and indicates, for each interval in which a subject was at risk, whether the subject survived the interval (0) or experienced an event in the interval (1).
+Per construction, a subject has status 0 in all intervals except the last one, which can be either 0 or 1.
 
 ```{r}
-cut <- seq(0, max(tumor$days), length.out = 100) |> round()
-disc_data <- pammtools::as_ped(data = tumor, Surv(days, status)~., cut = cut) |>
-  select(-offset)
-unique(disc_data$interval) # The intervals in the transformed data
+# create the cut points for partitioning
+cut = seq(0, max(tumor$days), length.out = 100) |> round()
+# create the transformed data
+disc_data = pammtools::as_ped(data = tumor, Surv(days, status) ~ ., cut = cut) |> select(-offset)
+unique(disc_data$interval)[1:5] # the first intervals in the transformed data
 tumor |> slice(c(1, 3, 4)) # original data
 disc_data |> filter(id %in% c(1, 3, 4)) |> group_by(id) |> summarize(number_intervals = n())
-disc_data |> filter(id %in% c(1, 3, 4)) |> group_by(id) |> slice(1, n())
+disc_data |> filter(id %in% c(1, 3, 4)) |> group_by(id) |> slice(1, n()) |> unique()
 ```
+
+Note that the number of events doesn't change between the original and the transformed data:
+
+```{r}
+sum(tumor$status)
+sum(disc_data$ped_status)
+```
+
+### Training a model
+
 Once we have the transformed data, we can fit a classifier (that returns class probabilities) to the new outcome variable to obtain predictions of the discrete-time hazard, i.e. the probability $P(y = 1 \mid T = t)$ (here $t$ is shorthand for the interval number $1, 2, 3, \ldots$).
 For illustration, we first use a logistic regression without covariates.
 Note, however, that since the 0/1 variable is the new outcome, the variables that represent time (e.g. `tend` or `interval`) can (and should) be viewed as features.
 We use them to estimate different event probabilities in different intervals (i.e. probabilities conditional on the feature time).
-
-```{r}
-m_0 <- glm(ped_status ~ interval, data = disc_data, family = binomial())
-coef(m_0)
-```
 The coefficients that we obtain from this model can be interpreted in terms of the discrete baseline hazards.
 Because of the reference coding, the intercept is the logit baseline hazard in the first interval and the other coefficients are the differences (on the logit scale) compared to the first interval.
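+
+Written out explicitly (a sketch in our own notation, for the model fitted in the next chunk), with $\beta_0$ the intercept and $\beta_j$ the coefficient of the dummy variable for interval $j$, the discrete baseline hazards are recovered via the inverse logit:
+
+$$
+h_1 = \operatorname{logit}^{-1}(\beta_0), \qquad h_j = \operatorname{logit}^{-1}(\beta_0 + \beta_j), \quad j = 2, \ldots, J,
+$$
+
+which is what `predict(m_0, type = "response")` returns below.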
-We can use those values to calculate the survival probability in for each time point (for comparison we also add the previous Kaplan-Meier estimate to the graph):
+
+```{r, message = FALSE}
+m_0 = glm(ped_status ~ interval, data = disc_data, family = binomial())
+length(coef(m_0)) # interval is a factor variable, thus we get one coefficient per interval
+coef(m_0)[1:5]
+```
+
+We can use those values to calculate the survival probability for each time point (for comparison, we also add the Kaplan-Meier estimate to the graph):
+
 ```{r}
-km <- survival::survfit(Surv(days, status)~1, data = tumor) |> broom::tidy()
-ndf <- disc_data |> pammtools::make_newdata(tend = unique(tend))
-ndf$disc_haz <- predict(m_0, newdata = ndf, type = "response")
-ndf$surv_prob <- cumprod(1-ndf$disc_haz)
-
-ggplot(ndf, aes(x = tend, y = surv_prob)) +
-  pammtools::geom_stephazard(aes(col = "Discrete-Time disc")) +
-  pammtools::geom_surv(aes(col = "Discrete-Time")) +
-  geom_step(data = km, aes(x = time, y = estimate, col = "Kaplan-Meier"), lty = 2) +
+# Kaplan-Meier estimate
+km = survival::survfit(Surv(days, status) ~ 1, data = tumor) |> broom::tidy()
+# prediction from the discrete-time baseline model:
+# first create a new data set for prediction (containing all time points/intervals)
+ndf = disc_data |> pammtools::make_newdata(tend = unique(tend))
+# predict the discrete hazard in each interval
+ndf$disc_haz = predict(m_0, newdata = ndf, type = "response")
+# calculate the survival probability at each interval
+ndf$surv_prob = cumprod(1 - ndf$disc_haz)
+# visualize
+p_univariate = ggplot(ndf, aes(x = tend, y = surv_prob)) +
+  geom_surv(aes(col = "Discrete-Time")) +
+  geom_surv(data = km, aes(x = time, y = estimate, col = "Kaplan-Meier"), lty = 2) +
   ylim(c(0, 1)) +
-  scale_color_discrete(name = "method") +
+  scale_color_discrete(name = "Method") +
   labs(y = "Survival Probability")
+p_univariate
 ```
 
 As you can see, there is barely any difference between the two estimates.
 This illustrates that essentially any survival time distribution can be approximated using this discrete-time approach.
 What we have shown here anecdotally can also be shown mathematically (TODO: cite Efron).
 
-## Modeling the baseline hazard or "what is a featureless discrete time model?"
-
-
-
 ## Covariate effect vs. Stratified baseline hazard
 
-Similar to a Cox model, in a standard discrete-time model (using logistic regression or similar), we have to decide if a covariate simply shifts the baseline hazard or if it alters the shape of a baseline hazard. In the survival analysis context this is also known as proportional hazards assumption or time-varying effects.
+Similar to a Cox model, in a standard discrete-time model (using logistic regression or similar), we have to decide if a covariate simply shifts the baseline hazard or if it alters the shape of the baseline hazard.
+In the survival analysis context this distinction is known as the proportional hazards assumption vs. time-varying effects.
 
 For illustration, we consider the `complications` variable in the `tumor` data set.
 For comparison, we first fit a standard Cox PH model, which indicates that complications have a substantial effect on survival (the hazard is about twice as large as for subjects without complications).
+An equivalent (proportional odds) model in discrete time is also given below.
 
 ```{r}
 table(tumor$complications)
-m_cox <- survival::coxph(Surv(days, status)~complications, data = tumor)
-summary(m_cox)
+cox_complications = survival::coxph(Surv(days, status) ~ complications, data = tumor, x = TRUE)
+exp(coef(cox_complications)) # hazard ratio of about 2
+disc_complications = glm(ped_status ~ interval + complications, data = disc_data, family = binomial())
+exp(coef(disc_complications)["complicationsyes"]) # odds/continuation ratio
+```
+
+The survival curves for the two groups are shown below:
+
+```{r, echo = FALSE}
+ndf_ph = disc_data |> make_newdata(tend = unique(tend), complications = unique(complications))
+ndf_ph$disc_haz = predict(disc_complications, newdata = ndf_ph, type = "response")
+ndf_ph = ndf_ph |> group_by(complications) |> mutate(surv_prob = cumprod(1 - disc_haz))
+p_ph = ggplot(ndf_ph, aes(x = tend, y = surv_prob, lty = complications)) +
+  geom_surv(aes(col = "Discrete-Time")) +
+  ylim(c(0, 1))
+p_ph
+```
+
+However, looking at the Schoenfeld residuals for the proportional hazards Cox model, we see that the PH assumption is not fulfilled.
+Specifically, the group of patients with complications during the operation appears to have much higher hazards compared to subjects without complications.
+
+```{r}
+ph_test = cox.zph(cox_complications)
+ph_test
+plot(ph_test)
+```
+
+One solution is to fit a model where the baseline hazard depends on the covariate value.
+In the Cox model context we speak of stratified baseline hazards. In the discrete-time framework, we can view this as an interaction between a feature that represents time and the feature `complications`.
+As we can see from the resulting figure, the two models again yield similar results (both allowing non-proportional hazards/odds between the two groups of patients).
+
+```{r}
+# stratified Cox model
+cox_complications_strata = coxph(Surv(days, status) ~ strata(complications), data = tumor)
+# "stratified" discrete-time model
+disc_complications_strata = glm(ped_status ~ interval * complications, data = disc_data, family = binomial())
+# create new data for prediction
+ndf2 = disc_data |> make_newdata(tend = unique(tend), complications = unique(complications))
+# predict discrete hazards from the model
+ndf2$disc_haz = predict(disc_complications_strata, newdata = ndf2, type = "response")
+# calculate survival probabilities (for each group)
+ndf2 = ndf2 |> group_by(complications) |> mutate(surv_prob = cumprod(1 - disc_haz))
+# visualize
+sp_cox = basehaz(cox_complications_strata) |>
+  group_by(strata) |>
+  mutate(surv_prob = exp(-hazard)) |>
+  rename(complications = strata, tend = time)
+p_stratified = ggplot(ndf2, aes(x = tend, y = surv_prob, lty = complications)) +
+  geom_surv(aes(col = "Discrete-Time")) +
+  geom_surv(data = sp_cox, aes(col = "Cox PH")) +
+  ylim(c(0, 1))
+p_stratified
+```
+
+## Discrete time-to-event analysis using mlr3proba
+
+As you have seen above, creating the transformed data, generating new data sets for prediction and computing predictions can become quite cumbersome.
+In the context of machine learning this is particularly true, as different classifiers have different interfaces for prediction and we often want to tune learners and evaluate predictive performance using different resampling strategies.
+
+In `mlr3proba` we therefore implemented a pipeline that takes care of all these details, so that the discrete-time approach can be integrated into the usual ML workflow.
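+
+Because the resulting pipeline is a regular `mlr3` graph learner, it plugs into the standard resampling and evaluation machinery. As a small preview (a sketch only, not evaluated here, and assuming the survival task `tsk_tumor` and the graph learner `pipeline_logreg` that are constructed below):
+
+```{r, eval = FALSE}
+# 3-fold cross-validation of the discrete-time pipeline;
+# since the pipeline returns survival predictions, survival measures such as
+# the integrated Brier (Graf) score can be used for evaluation
+rr = resample(tsk_tumor, pipeline_logreg, rsmp("cv", folds = 3))
+rr$aggregate(msr("surv.graf"))
+```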
+
+Concretely, the pipeline transforms a survival task into a classification task (using the strategies outlined above) and handles the prediction of the different quantities (discrete hazards and survival probabilities).
+
+The code below illustrates an example of this pipeline, called `survtoclassif_disctime`, where we
+
+- first create a survival task (this is essentially the original data, wrapped in a `TaskSurv` object)
+- specify a classification learner that will be used to estimate the discrete-time hazards and survival probabilities (here we use logistic regression, but any available mlr3 classification learner (TODO: link to overview) that returns probabilistic predictions could be used, as we show later)
+- specify the cut points used to partition the follow-up (here we use 100 equidistant intervals as before)
+- specify `rhs` (optional): for some learners, particularly parametric learners, it can be useful to specify which features should be included and how (here the `tend` variable is included as a factor variable in order to estimate a (different) discrete baseline hazard in each interval)
+- set `graph_learner = TRUE` to indicate that the whole pipeline will be treated as a single learner (this means that... TODO)
+
+```{r}
+# create survival task from the data
+tsk_tumor = TaskSurv$new(id = "tumor", backend = tumor, time = "days", event = "status")
+# choose a classification learner; note that we use a classif. learner rather than a surv. learner, i.e. any classification learner will do
+lrn_logreg = lrn("classif.log_reg")
+# define sequence of cut points (as before)
+cut = seq(0, max(tumor$days), length.out = 100) |> round()
+# define pipeline learner
+pipeline_logreg = ppl(
+  "survtoclassif_disctime",
+  learner = lrn_logreg,
+  rhs = "as.factor(tend)",
+  cut = cut,
+  graph_learner = TRUE)
+```
+
+Above, the `rhs` (the *r*ight *h*and *s*ide of the model formula specification) argument controls which features are included and how.
+This is particularly relevant for regression-type learners like logistic regression.
+
+```{r}
+# TODO: visualize graph learner
+```
+
+Once this graph learner has been defined, we can use it like any other learner, i.e. to train the model on the data and to generate predictions:
+
+```{r}
+pipeline_logreg$train(tsk_tumor)
+pred_logreg = pipeline_logreg$predict(tsk_tumor)
+pred_logreg
+```
+
+The predicted survival probabilities are stored in the `distr` column and are in this case the same for each subject, as we don't include subject-specific covariates:
+
+```{r}
+pred_logreg$data$distr[1, ]
+pred_logreg$data$distr[2, ]
+```
+
+We can also use other learners, such as `ranger` or `xgboost`:
+
+```{r, eval = TRUE}
+pipeline_ranger = ppl(
+  "survtoclassif_disctime",
+  learner = lrn("classif.ranger"),
+  rhs = "tend",
+  cut = cut,
+  graph_learner = TRUE)
+
+pipeline_ranger$train(tsk_tumor)
+pred_ranger = pipeline_ranger$predict(tsk_tumor)
+
+pipeline_xgboost = ppl(
+  "survtoclassif_disctime",
+  learner = lrn("classif.xgboost", nrounds = 100, eta = 0.12, max_depth = 1, objective = "binary:logistic", lambda = 0),
+  rhs = "tend",
+  cut = cut,
+  graph_learner = TRUE)
+
+pipeline_xgboost$train(tsk_tumor)
+pred_xgboost = pipeline_xgboost$predict(tsk_tumor)
+```
+
+We can now compare the predictions of the three pipelines with the Kaplan-Meier estimate:
+
+```{r, eval = TRUE}
+df = data.frame(
+  time = as.numeric(names(pred_logreg$data$distr[1, ])),
+  y1 = pred_logreg$data$distr[1, ], # logistic regression
+  y2 = pred_ranger$data$distr[1, ], # random forest
+  y3 = pred_xgboost$data$distr[1, ]) # xgboost
+
+p_univariate +
+  geom_surv(data = df, aes(x = time, y = y1, col = "Pipeline Logistic Regression")) +
+  # geom_surv(data = df, aes(x = time, y = y1, col = "Pipeline Logistic Regression Interpol")) +
+  geom_surv(data = df, aes(x = time, y = y2, col = "Pipeline Ranger")) +
+  geom_surv(data = df, aes(x = time, y = y3, col = "Pipeline XGBoost")) #+
+  # geom_vline(xintercept = cut, lty = 3)
+```
+
+We can see that all estimates are quite similar, except for the random forest, which estimates slightly higher survival probabilities in the beginning and doesn't "level off" at the end of the follow-up.
+"Discrete-Time" and "Pipeline Logistic Regression" are identical (which illustrates that our pipeline reproduces the 'manually' constructed estimate).
+
+Note that there is a subtle difference between the logistic regression pipeline and the random forest and xgboost pipelines.
+The former takes time (`tend`) as a factor variable, which forces the model to estimate a different baseline hazard per interval.
+For the other pipelines, we included time (`tend`) as a continuous feature and the algorithms have to decide at which time points to split (see also section TODO: link to featureless discrete-time model).
+
+## Including further features
+
+We have seen above that when we explicitly instruct the model to include interactions between complications and time, we obtain non-proportional hazards, similar to a model with stratified baseline hazards.
+
+When using machine learning models like random forest and xgboost, the hope would be that such interactions are detected automatically.
+
+Note that below we modify the xgboost learner slightly by allowing a tree depth of 2 in order to be able to learn interactions:
+
+```{r}
+pipeline_ranger2 = ppl(
+  "survtoclassif_disctime",
+  learner = lrn("classif.ranger"),
+  rhs = "tend + complications",
+  cut = cut,
+  graph_learner = TRUE)
+
+pipeline_ranger2$train(tsk_tumor)
+pred_ranger2 = pipeline_ranger2$predict(tsk_tumor)
+
+pipeline_xgboost2 = ppl(
+  "survtoclassif_disctime",
+  learner = lrn("classif.xgboost", nrounds = 100, eta = 0.12, max_depth = 2, objective = "binary:logistic", lambda = 0),
+  rhs = "tend + complications",
+  cut = cut,
+  graph_learner = TRUE)
+
+pipeline_xgboost2$train(tsk_tumor)
+pred_xgboost2 = pipeline_xgboost2$predict(tsk_tumor)
+# TODO: this is somewhat clunky and unnecessary.
+# can we do something like pipeline_xgboost2$predict(newdata), where
+# newdata is just a data.frame with all time points for each value of complications?
 ```
 
-However, looking at the Shoenfeld residuals for this model, we see that the PH assumption is not fullfilled:
+As can be seen below, the XGBoost learner picks up the interaction quite well, which could be further "improved" by allowing higher tree depths and more boosting rounds (`nrounds`).
+The random forest is more conservative, essentially estimating a proportional hazards type model.
+Note, however, that we essentially learn/overfit the training data here, so it is not clear which model would generalize better.
+
 ```{r}
-zph <- cox.zph(m_cox)
+df2 = data.frame(
+  time = as.numeric(names(pred_logreg$data$distr[1, ])),
+  rf_no = pred_ranger2$data$distr[1, ], # random forest, complications = no
+  rf_yes = pred_ranger2$data$distr[2, ], # random forest, complications = yes
+  xgb_yes = pred_xgboost2$data$distr[2, ], # xgboost, complications = yes
+  xgb_no = pred_xgboost2$data$distr[1, ]) # xgboost, complications = no
+
+p_stratified +
+  geom_surv(data = df2, aes(x = time, y = rf_no, col = "Pipeline RF", lty = "no")) +
+  geom_surv(data = df2, aes(x = time, y = rf_yes, col = "Pipeline RF", lty = "yes")) +
+  geom_surv(data = df2, aes(x = time, y = xgb_no, col = "Pipeline XGBoost", lty = "no")) +
+  geom_surv(data = df2, aes(x = time, y = xgb_yes, col = "Pipeline XGBoost", lty = "yes"))
 ```
+
+## Tuning of learners and out-of-sample performance
+
+
+## Additional Thoughts on Discrete Time-to-Event Analysis
+
+### Number of Intervals
+
+### Modeling the baseline hazard or "what is a featureless discrete-time model?"
+
+### Discrete-time vs. continuous time and interpolation
+
 
 ## References