Skip to content

Commit

Permalink
ENH: stats.ContinuousDistribution: improve plot
Browse files Browse the repository at this point in the history
  • Loading branch information
mdhaber committed Apr 3, 2024
1 parent 8c788cb commit 5a8bd81
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 60 deletions.
105 changes: 91 additions & 14 deletions scipy/stats/_distribution_infrastructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
_NO_CACHE = "no_cache"

# TODO:
# When a parameter is invalid, set only the offending parameter to NaN (if possible)?
# fix QMC bug with size=() but distribution shape, say, 2
# kurtosis input validation test
# clip - ShiftedScaledNormal(loc=0, scale=0.01).ccdf(-7.32, method='quadrature') > 1
Expand Down Expand Up @@ -1712,7 +1713,7 @@ def __getattr__(self, item):
return super().__getattribute__(item)

if item in self._parameters:
return self._parameters[item]
return self._parameters[item][()]

return super().__getattribute__(item)

Expand Down Expand Up @@ -4424,8 +4425,68 @@ def logintegrand(x, order, logcenter, **kwargs):

### Convenience

def plot(self, x='x', y='pdf', *, t=('cdf', 0.001, 0.999), ax=None):
"""Plot a function of the distribution"""
def plot(self, x='x', y='pdf', *, t=('cdf', 0.0005, 0.9995), ax=None):
r"""Plot a function of the distribution
Convenience function for quick visualization of the distribution
underlying the random variable.
Parameters
----------
x, y : str, optional
String indicating the quantities to be used as the abscissa and
ordinate (horizontal and vertical coordinates), respectively.
Defaults are ``'x'`` (the domain of the random variable) and
``'pdf'`` (the probability density function). Valid values are:
'x', 'pdf', 'cdf', 'ccdf', 'icdf', 'iccdf', 'logpdf', 'logcdf',
'logccdf', 'ilogcdf', 'ilogccdf'.
t : 3-tuple of (str, float, float), optional
Tuple indicating the limits within which the quantities are plotted.
Default is ``('cdf', 0.0005, 0.9995)`` indicating that the central
99.9% of the distribution is to be shown. Valid values of the
string are the same as for ``x`` and ``y`` except for ``pdf``
and ``logpdf``.
ax : `matplotlib.axes.Axes`, optional
Axes on which to generate the plot. If not provided, use the
current axes.
Returns
-------
ax : `matplotlib.axes.Axes`
Axes on which the plot was generated.
The plot can be customized by manipulating this object.
Examples
--------
Instantiate a distribution with the desired parameters:
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> from scipy import stats
>>> X = stats.Normal(mu=1., sigma=2.)
Plot the PDF over the central 99.9% of the distribution.
Compare against a histogram of a random sample.
>>> ax = X.plot()
>>> sample = X.sample(10000, qmc_engine=stats.qmc.Halton)
>>> ax.hist(sample, density=True, bins=50, alpha=0.5)
>>> plt.show()
Plot ``logpdf(x)`` as a function of ``x`` in the left tail,
where the log of the CDF is between -10 and ``np.log(0.5)``.
>>> X.plot('x', 'logpdf', t=('logcdf', -10, np.log(0.5)))
>>> plt.show()
Plot the PDF of the normal distribution as a function of the
CDF for various values of the scale parameter.
>>> X = stats.Normal(mu=0., sigma=[0.5, 1., 2])
>>> X.plot('cdf', 'pdf')
>>> plt.show()
"""

# Strategy: given t limits, get quantile limits. Form grid of
# quantiles, compute requested x and y at quantiles, and plot.
Expand All @@ -4435,6 +4496,10 @@ def plot(self, x='x', y='pdf', *, t=('cdf', 0.001, 0.999), ax=None):
# a) quantiles or probabilities
# b) linearly or logarithmically spaced
# based on the specified `t`.
# TODO:
# - smart spacing of points
# - when the parameters of the distribution are an array,
# use the full range of abscissae for all curves

t_is_quantile = {'x', 'icdf', 'iccdf', 'ilogcdf', 'ilogccdf'}
t_is_probability = {'cdf', 'ccdf', 'logcdf', 'logccdf'}
Expand All @@ -4447,41 +4512,51 @@ def plot(self, x='x', y='pdf', *, t=('cdf', 0.001, 0.999), ax=None):
tlim = tlim[:, np.newaxis] if ndim else tlim

# pdf/logpdf are not valid for `t` because we can't easily invert them
message = (f'Argument `t` of {self.__class__.__name__}.plot "'
message = (f'Argument `t` of `{self.__class__.__name__}.plot` "'
f'must be one of {valid_t}')
if y_name not in valid_xy:
raise ValueError(message)

message = (f'Argument `x` of {self.__class__.__name__}.plot "'
message = (f'Argument `x` of `{self.__class__.__name__}.plot` "'
f'must be one of {valid_xy}')
if x_name not in valid_xy:
raise ValueError(message)

message = (f'Argument `y` of {self.__class__.__name__}.plot "'
message = (f'Argument `y` of `{self.__class__.__name__}.plot` "'
f'must be one of {valid_xy}')
if y_name not in valid_xy:
raise ValueError(message)

# This could just be a warning
message = (f'`{self.__class__.__name__}.plot` was called on a random '
'variable with at least one invalid shape parameters. When '
'a parameter is invalid, no plot can be shown.')
if self._any_invalid:
raise ValueError(message)

# We could automatically ravel, but do we want to? For now, raise.
message = ("To use `plot`, distribution parameters must be "
"scalars or arrays with one or fewer dimensions.")
if ndim > 1:
raise ValueError(message)

try:
import matplotlib # noqa: F401, E402
import matplotlib.pyplot as plt # noqa: F401, E402
except ModuleNotFoundError as exc:
message = ("`matplotlib` must be installed to use "
f"`{self.__class__.__name__}.plot`.")
raise ModuleNotFoundError(message) from exc

if ax is None:
import matplotlib.pyplot as plt
ax = plt.gca()
ax = plt.gca() if ax is None else ax

# get quantile limits given t limits
qlim = tlim if t_name in t_is_quantile else getattr(self, 'i'+t_name)(tlim)

message = (f"`{self.__class__.__name__}.plot` received invalid input for `t`: "
f"calling {'i'+t_name}({tlim}) produced {qlim}.")
if not np.all(np.isfinite(qlim)):
raise ValueError(message)

# form quantile grid
grid = np.linspace(0, 1, 300)
grid = grid[:, np.newaxis] if ndim else grid
Expand All @@ -4500,17 +4575,19 @@ def plot(self, x='x', y='pdf', *, t=('cdf', 0.001, 0.999), ax=None):
# only need a legend if distribution has parameters
if len(self._parameters):
label = []
param_names = list(self._parameters)
param_arrays = [np.atleast_1d(val)
for val in self._parameters.values()]
parameters = self._parameterization.parameters
param_names = list(parameters)
param_arrays = [np.atleast_1d(self._parameters[pname])
for pname in param_names]
for param_vals in zip(*param_arrays):
assignments = [f"{name} = {val}"
assignments = [f"{parameters[name].symbol} = {val:.4g}"
for name, val in zip(param_names, param_vals)]
label.append(", ".join(assignments))
ax.legend(label)

return ax


### Fitting
# All methods above treat the distribution parameters as fixed, and the
# variable argument may be a quantile or probability. The fitting functions
Expand Down
4 changes: 2 additions & 2 deletions scipy/stats/_new_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,9 @@ class Normal(ContinuousDistribution):
_sigma_domain = _RealDomain(endpoints=(0, oo))
_x_support = _RealDomain(endpoints=(-oo, oo), inclusive=(True, True))

_mu_param = _RealParameter('mu', symbol=r'\mu', domain=_mu_domain,
_mu_param = _RealParameter('mu', symbol=r'µ', domain=_mu_domain,
typical=(-1, 1))
_sigma_param = _RealParameter('sigma', symbol=r'\sigma', domain=_sigma_domain,
_sigma_param = _RealParameter('sigma', symbol=r'σ', domain=_sigma_domain,
typical=(0.5, 1.5))
_x_param = _RealParameter('x', domain=_x_support, typical=(-1, 1))

Expand Down
95 changes: 51 additions & 44 deletions scipy/stats/tests/distribution_infrastructure.ipynb

Large diffs are not rendered by default.

0 comments on commit 5a8bd81

Please sign in to comment.