
subsampling LOO estimates with diff-est-srs-wor start #496

Open
wants to merge 105 commits into base: master

Conversation

@avehtari (Collaborator) commented Apr 10, 2024

Work in progress.

Tested with

  • family="normal" and stats %in% c("elpd", "mlpd", "gmpd", "mse", "rmse")
  • family="bernoulli" and stats %in% c("acc","pctcorr")
  • family="poisson" usings trad and latent approaches
  • tests pass expect some plot tests fail with svglite: ... Graphics API version mismatch

Notes

  • n_loo matters only if validate_search = TRUE (with FALSE there is no speed advantage); see the usage sketch after this list
  • in cv_varsel, if nloo < n and the fast PSIS-LOO result is not yet available, the fast PSIS-LOO result is computed
  • in cv_varsel, if nloo < n, the fast PSIS-LOO result is stored in slot $summaries_fast
  • the subsampling indices are stored in slot $loo_inds
  • the actual subsampling estimation happens in summary_funs.R get_stat()
  • removed some NA checking; need to recheck whether it has to be put back
  • improved quantiles for "mse" when the value is close to 0 (the lower quantile can no longer become negative)
  • "auc" is not supported (complicated)

Next

  • add support for incrementally increasing nloo?

tagging @n-kall

avehtari requested a review from fweber144 on Apr 12, 2024
@avehtari (Collaborator, Author) commented Apr 15, 2024

I wanted to use R2, and since I had rewritten the summary stats anyway, I added R2 and made mse, rmse, and R2 all use only the normal approximation, with as much shared computation as possible.

With the added R2 support, this PR will also close #483.

@n-kall (Collaborator) commented Apr 16, 2024

I can take a look

n-kall self-assigned this on Apr 16, 2024
avehtari requested a review from n-kall on Apr 16, 2024
avehtari self-assigned this and unassigned n-kall on Apr 16, 2024
Resolved (outdated) review comments on: R/cv_varsel.R (×5), R/glmfun.R, R/misc.R
@fweber144 (Collaborator) commented Apr 21, 2024

I have added some comments, but I'm not done with the review yet.

Besides, I think the documentation needs to be updated (at least re-roxygenized, but progressr should also be mentioned), perhaps the vignettes as well, and I haven't run R CMD check (including the unit tests) yet. The NEWS file would also need to be updated.

fweber144 mentioned this pull request on Apr 21, 2024
value.se <- weighted.sd((mu - y)^2 - (mu.bs - y)^2, wcv,
                        na.rm = TRUE) /
  sqrt(n_notna)
# Use normal approximation for mse and delta method for rmse and R2
Collaborator

Is this supposed to refer to the standard error estimation method or the CI method? The first part ("normal approximation") refers to a CI method, but the second part ("delta method") to a standard error estimation method.
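
For reference, a minimal numeric sketch of the delta-method step being asked about (placeholder numbers, not projpred code): with g(x) = sqrt(x), the SE of the RMSE follows from the SE of the MSE.

mse_hat <- 1.7   # placeholder point estimate of the MSE
mse_se  <- 0.3   # placeholder standard error of the MSE estimate
## Delta method: Var(g(X)) is approx. g'(mu)^2 * Var(X); here g(x) = sqrt(x),
## so g'(x) = 1 / (2 * sqrt(x)):
rmse_hat <- sqrt(mse_hat)
rmse_se  <- mse_se / (2 * sqrt(mse_hat))

A normal approximation would then only enter when turning rmse_hat and rmse_se into an interval.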

Comment on lines +502 to +503
# Compute mean and variance in log scale by matching the variance of a
# log-normal approximation
Collaborator

Is it common to assume a log-normal distribution as the sampling distribution of the MSE and RMSE estimators? I haven't seen that yet (I think), but it might be perfectly fine (motivated by the central limit theorem, I guess).
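
For context, this is the moment matching such a log-normal approximation implies (same relations as in the code quoted further below; toy numbers, not projpred code):

value    <- 2.5   # toy MSE estimate
value_se <- 0.4   # toy standard error of that estimate
mul  <- log(value^2 / sqrt(value_se^2 + value^2))  # mean on the log scale
varl <- log(1 + value_se^2 / value^2)              # variance on the log scale
## Quantiles on the original scale are obtained by exponentiating normal
## quantiles on the log scale, which keeps them positive:
exp(qnorm(c(0.025, 0.975), mean = mul, sd = sqrt(varl)))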

# store for later calculations
mse_e <- value
if (!is.null(summaries_baseline)) {
  # delta=TRUE, variance of difference of two normally distributed
Collaborator

There is something missing at the end; perhaps "random variables"?

((mu_baseline - y)^2 - mse_b))[loo_inds],
y_idx = loo_inds,
w = wobs)
cov_mse_e_b <- srs_diffe$y_hat / n_full^2
Collaborator

Just for my understanding: This procedure for estimating the covariance between mse_e and mse_b assumes that the summands within mse_e and mse_b coming from different observations are uncorrelated, right? At first, I thought this was violated here because mu, mu_baseline, and summaries_fast$mu are model-based and hence there is potential for cross-observations dependencies, but then I realized that mu, mu_baseline, and summaries_fast$mu all are based on the leave-one-out principle, so is this the reason why we can assume a cross-observation correlation of zero here?
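
For readers less familiar with the estimator referenced in the PR title, this is the standard difference estimator under simple random sampling without replacement (notation ours, not projpred's): with fast approximations $\tilde{y}_i$ available for all $n$ observations and exact values $y_i$ computed only for the subsample $S$ of size $m$,

\hat{t} = \sum_{i=1}^{n} \tilde{y}_i + \frac{n}{m} \sum_{i \in S} d_i,
\qquad d_i = y_i - \tilde{y}_i,
\qquad
\widehat{\operatorname{Var}}(\hat{t}) = n^2 \Bigl(1 - \frac{m}{n}\Bigr) \frac{s_d^2}{m},
\qquad
s_d^2 = \frac{1}{m - 1} \sum_{i \in S} (d_i - \bar{d})^2 .

The variance formula treats the $d_i$ as exchangeable across observations; whether that is exactly the assumption relied on for the covariance term asked about above is for the author to confirm.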

@avehtari (Collaborator, Author)

When I changed several bootstraps to analytic approximations and improved other approximations, I thought the math I was writing in the code was so trivial that I didn't write down all the derivations and assumptions separately. Now I see I should have done that, as it also takes me a bit of time to re-check any of these when you ask a question, so they are not as obvious as I thought. If you like, I can some day write up the equations and assumptions for easier checking. Until then: based on the tests, every approximation I wrote is at least as accurate as the earlier bootstrap, but much faster.

R/summary_funs.R Outdated
mse_y <- mean(wobs * (mean(y) - y)^2)
value <- 1 - mse_e / mse_y - ifelse(is.null(summaries_baseline), 0, 1 - mse_b / mse_y)
# the first-order Taylor approximation of the variance
var_mse_y <- .weighted_sd((mean(y) - y)^2, wobs)^2 / n_full
Collaborator

For mean(y), don't we need to take wobs into account? I think this is similar to line var_mse_b <- .weighted_sd((mu_baseline - y)^2, wobs)^2 / n_full where the parameter estimates from which mu_baseline is computed also take wobs into account.

Collaborator

This also concerns several other occurrences of mean(y) here in get_stat().
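
A minimal sketch of the suggested change (toy vectors; y and wobs play the roles they have in get_stat()):

y    <- c(1.2, 0.4, 2.1)
wobs <- c(1, 2, 1)
mean(y)                 # current code: unweighted mean, ignores wobs
weighted.mean(y, wobs)  # suggested: take the observation weights into account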

fweber144 mentioned this pull request on Sep 2, 2024
R/summary_funs.R Outdated
Comment on lines 387 to 393
if (!is.null(summaries_baseline)) {
  # delta=TRUE
  mse_e <- mse_e - mse_b
}
value_se <- sqrt((value_se^2 -
                    2 * mse_e / mse_y * cov_mse_e_y +
                    (mse_e / mse_y)^2 * var_mse_y) / mse_y^2)
Collaborator

If my understanding from stan-dev/loo#205 (comment) is correct, then I think we would need a trivariate delta method in the !is.null(summaries_baseline) case (because mse_b comes in, too). I haven't checked whether such a trivariate delta method would give the same formula as used here. Have you checked this?
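
For concreteness, this is what a trivariate delta method would look like in our notation (a sketch only, not a statement about what the current code computes): with $g(m_e, m_b, m_y) = 1 - (m_e - m_b) / m_y$,

\operatorname{Var}\bigl(g(\hat{m}_e, \hat{m}_b, \hat{m}_y)\bigr) \approx \nabla g^{\top} \Sigma \, \nabla g,
\qquad
\nabla g = \Bigl(-\frac{1}{m_y},\; \frac{1}{m_y},\; \frac{m_e - m_b}{m_y^2}\Bigr)^{\top},

where $\Sigma$ is the 3×3 covariance matrix of $(\hat{m}_e, \hat{m}_b, \hat{m}_y)$. Whether this reduces to the formula in the quoted code depends on which covariance terms are assumed to be zero.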

@fweber144 (Collaborator)

The tests run if I change in setup.R

nobsv <- 41L

to

nobsv <- 43L

but then of course all the results change. But this shows it's a problem with the small nobsv and random data

For me,

nobsv <- 43L

does not work (runs into some error, similar to what nobsv <- 41L did for you). However,

nobsv <- 39L

works for me. Does it work for you as well (on master)? Then I would pick that for the time being. Ultimately, it would be desirable to completely revise the tests because currently, we mainly test the "big" user-level functions, with the tests for all the "small" internal functions being quite short or not existing at all. The principle of testing should rather be to test the underlying functions extensively, because then it is easier to keep the tests for the "big" (and hence slow) user-level functions short.

@avehtari (Collaborator, Author) commented Sep 6, 2024

With nobsv <- 39L I get [ FAIL 0 | WARN 902 | SKIP 2 | PASS 60545 ]

@fweber144 (Collaborator)

With nobsv <- 39L I get [ FAIL 0 | WARN 902 | SKIP 2 | PASS 60545 ]

That sounds good. The warnings probably arise from the first creation of the snapshots. If you are running the tests via R CMD check, then from the second run on, you can avoid these warnings by (locally) removing the entry ^tests/testthat/bfits$ from the .Rbuildignore file. For the two skips, I don't know where they are coming from, but they are probably due to a suggested package that is not installed.

Since this solution seems to be working for you, I will push a commit to master (and merge it here) that changes nobsv to 39. As mentioned above, this is only a quick workaround.

This fixes commit 0d73c8e. However, before commit 0d73c8e, `is.null(mu_baseline)` should never have occurred because if `summaries_baseline` was `NULL`, then `mu_baseline` was set to `0` (and if `summaries_baseline` was not `NULL`, then `mu_baseline` was set to `summaries_baseline$mu`, which should not be `NULL` either). Hence, this fixup does not only fix commit 0d73c8e, but also the incorrect behavior that existed before it.
# log-normal approximation
# https://en.wikipedia.org/wiki/Log-normal_distribution#Arithmetic_moments
mul <- log(value^2 / sqrt(value_se^2 + value^2))
varl <- log(1 + value_se^2 / value^2)
Collaborator

Would it make sense to use log1p() here (for numerical stability)?
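
A quick illustration of the numerical point (plain R, independent of projpred):

x <- 1e-12
log(1 + x)  # approx. 1.000089e-12: the rounding in `1 + x` costs several significant digits
log1p(x)    # 1e-12, accurate for small x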

@fweber144 (Collaborator)

While working on the tests, I discovered a few bugs and fixed them. However (apart from what is still open in the comments above), we still need to think about the following:

  1. In cv_varsel.vsel(), is it correct to treat summaries_fast like the search-related arguments? Or do we need to set summaries_fast to NULL if not usable, just like rk_foldwise?
  2. In case of not-subsampled LOO, I think we can "usually" (for latent projection, this requires a well-behaving family$latent_ilink() function) expect no NAs in reference and standard submodel summaries (mu and lppd). In case of subsampled LOO, I think we can expect no NAs in reference and fast submodel summaries. So in such cases, should (unexpected) NAs in the corresponding summaries be left as-is so that downstream code "sees" them? What I have in mind is for example whether we can remove the na.rm = TRUE part from

    projpred/R/summary_funs.R

    Lines 305 to 306 in 7469aea

    value <- sum(lppd - lppd_baseline, na.rm = TRUE)
value_se <- sd(lppd - lppd_baseline, na.rm = TRUE) * sqrt(n_full)
    so that the default of na.rm = FALSE is used.

For the tests, I think it would make sense to add subsampled-LOO cv_varsel() calls already in setup.R so that all post-processing functions are tested for such subsampled-LOO cv_varsel() output. When doing so, we should pay attention to cover augmented-data and latent projection cases.

Finally, while experimenting with the tests, I discovered some strange NAs in the summaries_sub (of .tabulate_stats()) after running

cvvs_nloo <- suppressWarnings(do.call(cv_varsel, c(
  list(object = refmods[[args_cvvs_i$tstsetup_ref]],
       nloo = nloo_tst),
  excl_nonargs(args_cvvs_i)
)))
(slightly adapted, namely using excl_nonargs(args_cvvs_i, nms_excl_add = "validate_search")) on tstsetup <- "rstanarm.glm.cumul.stdformul.without_wobs.with_offs.latent.default_meth.default_cvmeth.default_search_trms". This needs to be investigated.

`validate_search = FALSE`, the search is not run again when creating `summaries_fast`. Only the performance evaluation (including the re-projections required for it) is run again. Hence, it would be inconsistent to treat `summaries_fast` like the search-related arguments of `cv_varsel.refmodel()` when calling `cv_varsel.refmodel()` from within `cv_varsel.vsel()`. Thus, a lot of code related to `summaries_fast` can be removed, which is done here.
@avehtari (Collaborator, Author)

In cv_varsel.vsel(), is it correct to treat summaries_fast like the search-related arguments? Or do we need to set summaries_fast to NULL if not usable, just like rk_foldwise?

I think I have forgotten the context to understand this question

In case of not-subsampled LOO, I think we can "usually" (for latent projection, this requires a well-behaving family$latent_ilink() function) expect no NAs in reference and standard submodel summaries (mu and lppd). In case of subsampled LOO, I think we can expect no NAs in reference and fast submodel summaries. So in such cases, should (unexpected) NAs in the corresponding summaries be left as-is so that downstream code "sees" them? What I have in mind is for example whether we can remove the na.rm = TRUE part from

Fine for me. NA handling was there before I started this PR, and I left most of that there.

@fweber144 (Collaborator)

In cv_varsel.vsel(), is it correct to treat summaries_fast like the search-related arguments? Or do we need to set summaries_fast to NULL if not usable, just like rk_foldwise?

I think I have forgotten the context to understand this question

No problem, I think I figured this out and pushed commit 0dcfcf2 which resolves this.

In case of not-subsampled LOO, I think we can "usually" (for latent projection, this requires a well-behaving family$latent_ilink() function) expect no NAs in reference and standard submodel summaries (mu and lppd). In case of subsampled LOO, I think we can expect no NAs in reference and fast submodel summaries. So in such cases, should (unexpected) NAs in the corresponding summaries be left as-is so that downstream code "sees" them? What I have in mind is for example whether we can remove the na.rm = TRUE part from

Fine for me. NA handling was there before I started this PR, and I left most of that there.

Ok, I will check this in detail and will commit changes (if possible).

why some test snapshots changed unexpectedly).
@fweber144 (Collaborator)

I just pushed commit 6151b0c which reverts changes that are unrelated to subsampled LOO-CV because I observed some unexpected changes in the test snapshots. The state before reverting these subsampling-unrelated changes is now in branch misc_from_fix-subsampling. It turned out that the unexpected snapshot changes were probably due to the changes concerning SIS and PSIS in lines

projpred/R/cv_varsel.R

Lines 797 to 891 in 6151b0c

if (nrow(log_lik_ref) > 1) {
  # Use loo::sis() if the projected draws (i.e., the draws resulting
  # from the clustering or thinning) have nonconstant weights:
  if (refdist_eval$const_wdraws_prj) {
    # Internally, loo::psis() doesn't perform the Pareto smoothing if the
    # number of draws is small (as indicated by object `no_psis_eval`, see
    # below). In projpred, this can occur, e.g., if users request a number
    # of projected draws (for performance evaluation, either after
    # clustering or thinning the reference model's posterior draws) that is
    # much smaller than the default of 400. In order to throw a customized
    # warning message (and to avoid the calculation of Pareto k-values, see
    # loo issue stan-dev/loo#227), object `no_psis_eval` indicates whether
    # loo::psis() would perform the Pareto smoothing or not (for the
    # decision rule, see loo:::n_pareto() and loo:::enough_tail_samples(),
    # keeping in mind that we have `r_eff = 1` for all observations here).
    S_for_psis_eval <- nrow(log_lik_ref)
    no_psis_eval <- ceiling(min(0.2 * S_for_psis_eval,
                                3 * sqrt(S_for_psis_eval))) < 5
    if (no_psis_eval) {
      if (getOption("projpred.warn_psis", TRUE)) {
        warning(
          "Using standard importance sampling (SIS), as the number of ",
          "draws or clusters is too small for PSIS. For improved ",
          "accuracy, increase the number of draws or clusters, or use ",
          "K-fold-CV."
        )
      }
      # Use loo::sis().
      # In principle, we could rely on loo::psis() here (because in such a
      # case, it would internally switch to SIS automatically), but using
      # loo::sis() explicitly is safer because if the loo package changes
      # its decision rule, we would get a mismatch between our customized
      # warning here and the IS method used by loo. See also loo issue
      # stan-dev/loo#227.
      importance_sampling_nm <- "sis"
    } else {
      # Use loo::psis().
      # Usually, we have a small number of projected draws here (400 by
      # default), which means that the 'loo' package will automatically
      # perform the regularization from Vehtari et al. (2024,
      # <https://jmlr.org/papers/v25/19-556.html>, appendix G).
      importance_sampling_nm <- "psis"
    }
  } else {
    if (getOption("projpred.warn_psis", TRUE)) {
      warning(
        "The projected draws used for the performance evaluation have ",
        "different (i.e., nonconstant) weights, so using standard ",
        "importance sampling (SIS) instead of Pareto-smoothed importance ",
        "sampling (PSIS). In general, PSIS is recommended over SIS."
      )
    }
    # Use loo::sis().
    importance_sampling_nm <- "sis"
  }
  importance_sampling_func <- get(importance_sampling_nm,
                                  asNamespace("loo"))
  mssgs_warns_capt <- capt_mssgs_warns(
    sub_psisloo <- importance_sampling_func(-log_lik_ref, cores = 1,
                                            r_eff = NA)
  )
  mssgs_warns_capt <- setdiff(mssgs_warns_capt, "")
  # Filter out Pareto k-value warnings (we throw a customized one instead):
  mssgs_warns_capt <- grep(
    "Some Pareto k diagnostic values are (too|slightly) high",
    mssgs_warns_capt, value = TRUE, invert = TRUE
  )
  if (length(mssgs_warns_capt) > 0) {
    warning(mssgs_warns_capt)
  }
  if (importance_sampling_nm == "psis") {
    pareto_k_eval <- loo::pareto_k_values(sub_psisloo)
    warn_pareto(
      n07 = sum(pareto_k_eval > .ps_khat_threshold(dim(psisloo)[1])), n = n,
      khat_threshold = .ps_khat_threshold(dim(sub_psisloo)[1]),
      warn_txt = paste0(
        "Some (%d / %d) Pareto k's for the reference model's PSIS-LOO ",
        "weights given ",
        ifelse(clust_used_eval,
               paste0(nclusters_pred, " clustered "),
               paste0(ndraws_pred, " posterior ")),
        "draws are > %s."
      )
    )
  }
  lw_sub <- weights(sub_psisloo)
} else {
  lw_sub <- matrix(0, nrow = nrow(log_lik_ref), ncol = ncol(log_lik_ref))
}
# Take into account that clustered draws usually have different weights:
lw_sub <- lw_sub + log(refdist_eval$wdraws_prj)
# This re-weighting requires a re-normalization (as.array() is applied to
# have stricter consistency checks, see `?sweep`):
lw_sub <- sweep(lw_sub, 2, as.array(apply(lw_sub, 2, log_sum_exp)))
(which used to read

projpred/R/cv_varsel.R

Lines 819 to 902 in 0dcfcf2

if (nrow(log_lik_ref) > 1) {
  # Take into account that clustered draws usually have different weights:
  lw_sub <- log_lik_ref + log(refdist_eval$wdraws_prj)
  # This re-weighting requires a re-normalization (as.array() is applied to
  # have stricter consistency checks, see `?sweep`):
  lw_sub <- sweep(lw_sub, 2, as.array(apply(lw_sub, 2, log_sum_exp)))
  # Internally, loo::psis() doesn't perform the Pareto smoothing if the
  # number of draws is small (as indicated by object `no_psis_eval`, see
  # below). In projpred, this can occur, e.g., if users request a number
  # of projected draws (for performance evaluation, either after
  # clustering or thinning the reference model's posterior draws) that is
  # much smaller than the default of 400. In order to throw a customized
  # warning message (and to avoid the calculation of Pareto k-values, see
  # loo issue stan-dev/loo#227), object `no_psis_eval` indicates whether
  # loo::psis() would perform the Pareto smoothing or not (for the
  # decision rule, see loo:::n_pareto() and loo:::enough_tail_samples(),
  # keeping in mind that we have `r_eff = 1` for all observations here).
  S_for_psis_eval <- nrow(log_lik_ref)
  no_psis_eval <- ceiling(min(0.2 * S_for_psis_eval,
                              3 * sqrt(S_for_psis_eval))) < 5
  if (no_psis_eval) {
    if (getOption("projpred.warn_psis", TRUE)) {
      message(
        "Using standard importance sampling (SIS) due to a small number of",
        ifelse(refit_prj,
               ifelse(!is.null(nclusters_pred),
                      " clusters",
                      " draws (from thinning)"),
               ifelse(!is.null(nclusters),
                      " clusters",
                      " draws (from thinning)"))
      )
    }
    # Use loo::sis().
    # In principle, we could rely on loo::psis() here (because in such a
    # case, it would internally switch to SIS automatically), but using
    # loo::sis() explicitly is safer because if the loo package changes
    # its decision rule, we would get a mismatch between our customized
    # warning here and the IS method used by loo. See also loo issue
    # stan-dev/loo#227.
    importance_sampling_nm <- "sis"
  } else {
    # Use loo::psis().
    # Usually, we have a small number of projected draws here (400 by
    # default), which means that the 'loo' package will automatically
    # perform the regularization from Vehtari et al. (2022,
    # <https://doi.org/10.48550/arXiv.1507.02646>, appendix G).
    importance_sampling_nm <- "psis"
  }
  importance_sampling_func <- get(importance_sampling_nm,
                                  asNamespace("loo"))
  mssgs_warns_capt <- capt_mssgs_warns(
    sub_psisloo <- importance_sampling_func(-log_lik_ref, cores = 1,
                                            r_eff = NA)
  )
  mssgs_warns_capt <- setdiff(mssgs_warns_capt, "")
  # Filter out Pareto k-value warnings (we throw a customized one instead):
  mssgs_warns_capt <- grep(
    "Some Pareto k diagnostic values are (too|slightly) high",
    mssgs_warns_capt, value = TRUE, invert = TRUE
  )
  if (length(mssgs_warns_capt) > 0) {
    warning(mssgs_warns_capt)
  }
  if (importance_sampling_nm == "psis") {
    pareto_k_eval <- loo::pareto_k_values(sub_psisloo)
    warn_pareto(
      n07 = sum(pareto_k_eval > .ps_khat_threshold(dim(psisloo)[1])), n = n,
      khat_threshold = .ps_khat_threshold(dim(sub_psisloo)[1]),
      warn_txt = paste0(
        "Some (%d / %d) Pareto k's for the reference model's PSIS-LOO ",
        "weights given ",
        ifelse(clust_used_eval,
               paste0(nclusters_pred, " clustered "),
               paste0(ndraws_pred, " posterior ")),
        "draws are > %s."
      )
    )
  }
  lw_sub <- weights(sub_psisloo)
} else {
  lw_sub <- matrix(0, nrow = nrow(log_lik_ref), ncol = ncol(log_lik_ref))
}
before reverting the subsampling-unrelated changes). I think there is a bug in the SIS / PSIS changes, which I will comment on above (I had some further questions concerning those lines a while ago). I would suggest keeping this PR restricted to subsampled LOO-CV as much as possible and creating a separate PR for the subsampling-unrelated changes (which are now in branch misc_from_fix-subsampling) after this PR has been merged. And, of course, there is still branch mixed_deltas_plot, which needs to be worked on and eventually merged as well.
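
As a side note on the SIS/PSIS decision rule quoted above, the no_psis_eval condition can be checked in isolation (standalone R; this only reproduces the rule ceiling(min(0.2 * S, 3 * sqrt(S))) < 5, nothing else from cv_varsel.R):

no_psis_for <- function(S) ceiling(min(0.2 * S, 3 * sqrt(S))) < 5
sapply(c(10, 20, 21, 100, 400), no_psis_for)
## TRUE TRUE FALSE FALSE FALSE, i.e. with r_eff = 1, SIS would be used for
## fewer than 21 draws or clusters and PSIS otherwise.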
