Add KDE bandwidth selectors using biased or unbiased cross-validation #2384

Open
wants to merge 17 commits into
base: main

Conversation

sethaxen
Member

@sethaxen sethaxen commented Sep 19, 2024

Description

As discussed on Slack, the existing bandwidth selection methods for kde oversmooth draws from multimodal distributions with well-separated modes. This PR adds new bw options "ucv" and "bcv", which use unbiased or biased LOO-CV, respectively, to select the bandwidth.

Example

import numpy as np
import matplotlib.pyplot as plt
import arviz as az
az.style.use('arviz-doc')

rng = np.random.default_rng(123)
x = np.concatenate([rng.normal(0, 1, 1000), rng.normal(60, 1, 1000)])
fig, ax = plt.subplots()
az.plot_kde(x, ax=ax, plot_kwargs={"color": "k"}, label="default")
az.plot_kde(x, ax=ax, bw='ucv', plot_kwargs={"color": "C0"}, label="UCV")
az.plot_kde(x, ax=ax, bw='bcv', plot_kwargs={"color": "C1"}, label="BCV")
ax.set_xlabel("x")
ax.set_ylabel("Density")
ax.legend()

[Image: kde]

Checklist

  • Does the PR follow official PR format?
  • Has included a sample plot to visually illustrate the changes? (only for plot-related functions)
  • Is the new feature properly documented with an example?
  • Does the PR include new or updated tests to cover the new feature (using pytest fixture pattern)?
  • Is the code style correct (follows pylint and black guidelines)?
  • Is the new feature listed in the New features section of the changelog?

📚 Documentation preview 📚: https://arviz--2384.org.readthedocs.build/en/2384/

@sethaxen
Member Author

sethaxen commented Sep 20, 2024

Here are some example plots showing how KDEs with UCV and BCV bandwidths differ from those with the default bandwidth. In general, both methods seem to work pretty well when the original density does not have bounded support. BCV tends to smooth a little less than the default, while UCV smooths even less. Personally, I would still prefer the default or maybe BCV over UCV in almost every case except the multimodal one in the OP:
[Images: normal_kde, student-t_kde, exponential_kde, lognormal_kde, uniform_kde, beta_kde]

In terms of performance, the CV-based bandwidth selection methods are roughly 4 to 5 times faster than the default in the timings below. This is partly because _bw_isj is quite slow, but probably also because they tend to select lower bandwidths, which allows for fewer convolutions with the kernel.

In [1]: import arviz as az

In [2]: import numpy as np

In [3]: x = np.random.normal(0, 1, 10_000);

In [4]: %time [az.kde(x) for _ in range(1_000)];
CPU times: user 2.52 s, sys: 9.99 ms, total: 2.53 s
Wall time: 2.55 s

In [5]: %time [az.kde(x, bw='scott') for _ in range(1_000)];
CPU times: user 106 ms, sys: 0 ns, total: 106 ms
Wall time: 106 ms

In [6]: %time [az.kde(x, bw='ucv') for _ in range(1_000)];
CPU times: user 544 ms, sys: 228 μs, total: 544 ms
Wall time: 543 ms

In [7]: %time [az.kde(x, bw='bcv') for _ in range(1_000)];
CPU times: user 577 ms, sys: 73 μs, total: 577 ms
Wall time: 577 ms

@sethaxen sethaxen changed the title [WIP] Add KDE bandwidth selectors using biased or unbiased cross-validation Add KDE bandwidth selectors using biased or unbiased cross-validation Sep 20, 2024
@sethaxen sethaxen marked this pull request as ready for review September 20, 2024 19:17
@sethaxen
Member Author

Remaining pylint errors seem to be in code not touched in this PR

@sethaxen
Member Author

sethaxen commented Sep 22, 2024

I re-made the plots in #2384 (comment) using the explicit algorithm in the reference (i.e. without histogram binning) to see if the results were very different. For all but log-normal, the bandwidth from the explicit algorithm is indistinguishable from the one implemented here. For log-normal, explicit UCV looks like BCV.
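For readers following along, here is a minimal sketch of what an explicit (unbinned) UCV objective could look like for a Gaussian kernel. It assumes the common large-n form of the criterion, in which n - 1 is approximated by n. The function names are hypothetical, and this is not the code in this PR or in the reference.

```python
import numpy as np


def ucv_objective(h, x):
    """Explicit UCV criterion for a Gaussian kernel at bandwidth h (sketch).

    Uses every pair of points directly, so memory grows as O(n**2).
    """
    x = np.asarray(x, dtype=float)
    n = x.size
    # Indices of all unordered pairs (i < j).
    i, j = np.triu_indices(n, k=1)
    d2 = ((x[i] - x[j]) / h) ** 2
    # exp(-d^2/4) comes from the integrated squared KDE,
    # sqrt(8) * exp(-d^2/2) from the leave-one-out term.
    pair_sum = np.sum(np.exp(-d2 / 4) - np.sqrt(8.0) * np.exp(-d2 / 2))
    return (0.5 + pair_sum / n) / (n * h * np.sqrt(np.pi))


def ucv_grid_search(x, bandwidths):
    """Return the candidate bandwidth with the smallest UCV value."""
    values = np.array([ucv_objective(h, x) for h in bandwidths])
    return bandwidths[np.argmin(values)]
```

A grid search is used here only to keep the sketch short; a bounded scalar optimizer would be the more usual choice.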

@aloctavodia
Contributor

It seems UCV and BCV can be quite noisy compared to "default" except when distributions are normal-like, so their main use case could be multimodal densities? Did you compare against bw="isj"?

}


-def _get_bw(x, bw, grid_counts=None, x_std=None, x_range=None):
+def _get_bw(x, bw, grid_counts=None, bin_width=None, x_std=None, x_range=None): # pylint: disable=too-many-positional-arguments
Contributor

@aloctavodia aloctavodia Sep 23, 2024

Maybe we can disable "too-many..." globally. It's popping up in many places.

@sethaxen
Member Author

It seems UCV and BCV can be quite noisy compared to "default" except when distributions are normal-like, so their main use case could be multimodal densities?

I suspect it's more that the UCV and BCV criteria don't use any boundary correction. When the distribution is bounded with high density near the bounds, the mean integrated squared error (MISE) at large bandwidths is very high near the bound, while smaller bandwidths allow a sharper transition from low density on the wrong side of the bound to high density on the right side. This would explain why the bandwidth that minimizes this objective produces smooth KDEs for normal, beta, t, and GMM, while for exponential and uniform the result looks undersmoothed. (Log-normal looking weird is due to histogram binning, not the objective itself.)

I've worked out the equivalent objective for boundary correction using reflection. Reflection is not technically equivalent to what we do, but I suspect it would produce similar results. I'm working out efficient evaluation of the objective now.
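To make the idea concrete, here is a rough sketch of plain reflection applied to the KDE itself (not to the CV objective, which is still being worked out). It assumes a Gaussian kernel and a single reflection about each known bound. The function name and signature are hypothetical, and as noted above this is not the same as the boundary correction ArviZ currently uses.

```python
import numpy as np


def reflected_gaussian_kde(x, grid, bw, lower=None, upper=None):
    """Gaussian KDE evaluated on `grid`, with data reflected about the bounds."""
    x = np.asarray(x, dtype=float)
    grid = np.asarray(grid, dtype=float)
    # Augment the sample with its mirror image about each bound.
    parts = [x]
    if lower is not None:
        parts.append(2 * lower - x)
    if upper is not None:
        parts.append(2 * upper - x)
    augmented = np.concatenate(parts)
    z = (grid[:, None] - augmented[None, :]) / bw
    # Normalize by the ORIGINAL sample size so that the mass which leaked
    # past a bound is folded back inside the support.
    dens = np.exp(-0.5 * z**2).sum(axis=1) / (x.size * bw * np.sqrt(2 * np.pi))
    # The density is defined to be zero outside the support.
    if lower is not None:
        dens = np.where(grid < lower, 0.0, dens)
    if upper is not None:
        dens = np.where(grid > upper, 0.0, dens)
    return dens
```

With only a lower bound, each point and its mirror image together integrate to one over the support, so the estimate stays normalized. With both bounds and a single reflection, normalization is only approximate unless the bandwidth is small relative to the width of the support.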

Did you compare against bw="isj"?

Not directly, but I will.

I also located the source for how the R stats package selects the UCV- and BCV-based bandwidths, and they seem to use the same approach as this PR, except that they set the lower bound for the optimization to 0.1 * bw_max. See https://github.com/wch/r-source/blob/24e4ab74c81594bdbf9e05fd60b60cff2efe4215/src/library/stats/R/bandwidths.R#L107-L147 and https://github.com/wch/r-source/blob/24e4ab74c81594bdbf9e05fd60b60cff2efe4215/src/library/stats/src/bandwidths.c#L39-L72. Any differences between our results and theirs should be due to differences in binning and in the optimization algorithm.
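As a rough illustration of the binned approach with a bounded search, here is a sketch using NumPy and SciPy. Several things in it are assumptions rather than facts from this thread: the upper bound bw_max = 1.144 * std * n**(-1/5), the default of 1000 bins, the binning scheme, and all function names. It is not the implementation in this PR or in R.

```python
import numpy as np
from scipy.optimize import minimize_scalar


def binned_pair_counts(x, n_bins=1000):
    """Count unordered pairs of points by the distance between their bins."""
    x = np.asarray(x, dtype=float)
    bin_width = (x.max() - x.min()) * 1.01 / n_bins
    idx = np.floor(x / bin_width).astype(int)
    counts = np.bincount(idx - idx.min())
    # Autocorrelation of the bin counts: entry k is sum_i c[i] * c[i + k].
    corr = np.correlate(counts, counts, mode="full")[counts.size - 1:]
    pair_counts = corr.astype(float)
    # Lag 0 holds sum(c**2); convert to unordered pairs within the same bin.
    pair_counts[0] = (pair_counts[0] - x.size) / 2
    return pair_counts, bin_width


def ucv_binned(h, pair_counts, bin_width, n):
    """UCV criterion (Gaussian kernel) evaluated from binned pair counts."""
    d2 = (np.arange(pair_counts.size) * bin_width / h) ** 2
    term = np.exp(-d2 / 4) - np.sqrt(8.0) * np.exp(-d2 / 2)
    return (0.5 + np.sum(term * pair_counts) / n) / (n * h * np.sqrt(np.pi))


def bw_ucv(x, n_bins=1000, lower_frac=0.1):
    """Minimize the binned UCV criterion over [lower_frac * bw_max, bw_max]."""
    x = np.asarray(x, dtype=float)
    n = x.size
    # ASSUMPTION: oversmoothing-style upper bound, not taken from this thread.
    bw_max = 1.144 * np.std(x, ddof=1) * n ** (-0.2)
    pair_counts, bin_width = binned_pair_counts(x, n_bins)
    result = minimize_scalar(
        ucv_binned,
        bounds=(lower_frac * bw_max, bw_max),
        args=(pair_counts, bin_width, n),
        method="bounded",
    )
    return result.x
```

Setting lower_frac to a smaller value widens the search toward small bandwidths, which is one of the knobs where results could diverge between implementations.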


2 participants