Skip to content

Commit

Permalink
MAINT: stats.ContinuousDistribution: lower-precision RNG generation, …
Browse files Browse the repository at this point in the history
…fix docs?
  • Loading branch information
mdhaber committed Apr 26, 2024
1 parent 159d927 commit 2725f45
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 12 deletions.
15 changes: 9 additions & 6 deletions scipy/stats/_distribution_infrastructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def _isnull(x):
_NO_CACHE = "no_cache"

# TODO:
# Test sample dtypes
# Add dtype kwarg (especially for distributions with no parameters)
# When drawing endpoint/out-of-bounds values of a parameter, draw them from
# the endpoints/out-of-bounds region of the full `domain`, not `typical`.
# Distributions without shape parameters probably need to accept a `dtype` parameter;
Expand Down Expand Up @@ -4206,15 +4208,16 @@ def sample(self, shape=(), *, method=None, rng=None, qmc_engine=None):
rng = self._validate_rng(rng) or self.rng or np.random.default_rng()

if qmc_engine is None:
return self._sample_dispatch(sample_shape, full_shape, method=method,
rng=rng, **self._parameters)
res = self._sample_dispatch(sample_shape, full_shape, method=method,
rng=rng, **self._parameters)
else:
# needs input validation for qrng
d = int(np.prod(full_shape[1:]))
length = full_shape[0] if full_shape else 1
qrng = qmc_engine(d=d, seed=rng)
return self._qmc_sample_dispatch(length, full_shape, method=method,
qrng=qrng, **self._parameters)
res = self._qmc_sample_dispatch(length, full_shape, method=method,
qrng=qrng, **self._parameters)
return res.astype(self._dtype, copy=False)

@_dispatch
def _sample_dispatch(self, sample_shape, full_shape, *, method, rng, **params):
Expand All @@ -4229,7 +4232,7 @@ def _sample_formula(self, sample_shape, full_shape, *, rng, **params):
raise NotImplementedError(self._not_implemented)

def _sample_inverse_transform(self, sample_shape, full_shape, *, rng, **params):
uniform = rng.uniform(size=full_shape)
uniform = rng.random(size=full_shape, dtype=self._dtype)
return self._icdf_dispatch(uniform, **params)

@_dispatch
Expand All @@ -4246,7 +4249,7 @@ def _qmc_sample_formula(self, length, full_shape, *, qrng, **params):

def _qmc_sample_inverse_transform(self, length, full_shape, *, qrng, **params):
uniform = qrng.random(length)
uniform = np.reshape(uniform, full_shape)
uniform = np.reshape(uniform, full_shape).astype(self._dtype)
return self._icdf_dispatch(uniform, **params)

### Moments
Expand Down
Loading

0 comments on commit 2725f45

Please sign in to comment.