diff --git a/autograd/numpy/fft.py b/autograd/numpy/fft.py index 5746dd3c..c8279f84 100644 --- a/autograd/numpy/fft.py +++ b/autograd/numpy/fft.py @@ -2,7 +2,7 @@ from builtins import zip import numpy.fft as ffto from .numpy_wrapper import wrap_namespace -from .numpy_vjps import match_complex +from .util import match_complex from . import numpy_wrapper as anp from autograd.extend import primitive, defvjp, vspace diff --git a/autograd/numpy/numpy_jvps.py b/autograd/numpy/numpy_jvps.py index a152bf27..29d26502 100644 --- a/autograd/numpy/numpy_jvps.py +++ b/autograd/numpy/numpy_jvps.py @@ -1,15 +1,19 @@ from . import numpy_wrapper as anp -from .numpy_vjps import (untake, balanced_eq, match_complex, replace_zero, - dot_adjoint_0, dot_adjoint_1, tensordot_adjoint_0, - tensordot_adjoint_1, nograd_functions) +from .numpy_vjps import (untake, balanced_eq, replace_zero, dot_adjoint_0, dot_adjoint_1, + tensordot_adjoint_0, tensordot_adjoint_1, nograd_functions) from autograd.extend import (defjvp, defjvp_argnum, def_linear, vspace, JVPNode, register_notrace) +from .util import def_ufunc_jps, def_ufunc_jps_inv_pair + from ..util import func from .numpy_boxes import ArrayBox for fun in nograd_functions: register_notrace(JVPNode, fun) +defjvp(anp.broadcast_to, 'same') +defjvp(anp._broadcast_to_adjoint, 'same') + defjvp(func(ArrayBox.__getitem__), 'same') defjvp(untake, 'same') @@ -18,47 +22,74 @@ lambda g, ans, args, kwargs, _: anp._array_from_scalar_or_array(args, kwargs, g)) # ----- Functions that are constant w.r.t. continuous inputs ----- + defjvp(anp.nan_to_num, lambda g, ans, x: anp.where(anp.isfinite(x), g, 0.)) -# ----- Binary ufuncs (linear) ----- -def_linear(anp.multiply) +# ----- Unary ufuncs ------ + +def_ufunc_jps(anp.negative, 'same') +def_ufunc_jps(anp.rad2deg, 'same') +def_ufunc_jps(anp.degrees, 'same') +def_ufunc_jps(anp.deg2rad, 'same') +def_ufunc_jps(anp.radians, 'same') +def_ufunc_jps(anp.abs, + (lambda ans, x: replace_zero(anp.conj(x), 0.) 
/ replace_zero(ans, 1.), 'cmul')) +def_ufunc_jps(anp.fabs, (lambda ans, x: anp.sign(x), 'mul')) # fabs doesn't take complex numbers. +def_ufunc_jps(anp.absolute, (lambda ans, x: anp.conj(x) / ans, 'cmul')) +def_ufunc_jps(anp.reciprocal, (lambda ans, x: -ans**2, 'mul' )) +def_ufunc_jps(anp.log10, (lambda ans, x: x * anp.log(10), 'div' )) +def_ufunc_jps(anp.sin, (lambda ans, x: anp.cos(x), 'mul' )) +def_ufunc_jps(anp.cos, (lambda ans, x: -anp.sin(x), 'mul' )) +def_ufunc_jps(anp.arcsin, (lambda ans, x: anp.sqrt(1 - x**2), 'div' )) +def_ufunc_jps(anp.arccos, (lambda ans, x:-anp.sqrt(1 - x**2), 'div' )) +def_ufunc_jps(anp.cosh, (lambda ans, x: anp.sinh(x), 'mul' )) +def_ufunc_jps(anp.arccosh, (lambda ans, x: anp.sqrt(x**2 - 1), 'div' )) +def_ufunc_jps(anp.sinc, (lambda ans, x: (anp.cos(anp.pi*x)-ans)/x, 'mul' )) +def_ufunc_jps(anp.real_if_close, 'cid') +def_ufunc_jps(anp.real, 'cid') +def_ufunc_jps(anp.imag, (lambda ans, x: -1j, 'cmul')) +def_ufunc_jps(anp.conj, 'same') +def_ufunc_jps(anp.conjugate, 'same') +def_ufunc_jps(anp.angle, (lambda ans, x: anp.conj(x * 1j)/anp.abs(x)**2, 'cmul')) + +def_ufunc_jps_inv_pair(anp.exp, anp.log, lambda ans, x: ans) +def_ufunc_jps_inv_pair(anp.exp2, anp.log2, lambda ans, x: ans * anp.log(2)) +def_ufunc_jps_inv_pair(anp.expm1, anp.log1p, lambda ans, x: ans + 1) +def_ufunc_jps_inv_pair(anp.tan, anp.arctan, lambda ans, x: 1 + ans**2) +def_ufunc_jps_inv_pair(anp.tanh, anp.arctanh, lambda ans, x: 1 - ans**2) +def_ufunc_jps_inv_pair(anp.sinh, anp.arcsinh, lambda ans, x: anp.sqrt(ans**2 + 1)) +def_ufunc_jps_inv_pair(anp.square, anp.sqrt, lambda ans, x: 2 * x) # ----- Binary ufuncs ----- -defjvp(anp.add, lambda g, ans, x, y : broadcast(g, ans), - lambda g, ans, x, y : broadcast(g, ans)) -defjvp(anp.subtract, lambda g, ans, x, y : broadcast(g, ans), - lambda g, ans, x, y : broadcast(-g, ans)) -defjvp(anp.divide, 'same', - lambda g, ans, x, y : - g * x / y**2) -defjvp(anp.maximum, lambda g, ans, x, y : g * balanced_eq(x, ans, y), - lambda g, 
ans, x, y : g * balanced_eq(y, ans, x)) -defjvp(anp.minimum, lambda g, ans, x, y : g * balanced_eq(x, ans, y), - lambda g, ans, x, y : g * balanced_eq(y, ans, x)) -defjvp(anp.fmax, lambda g, ans, x, y : g * balanced_eq(x, ans, y), - lambda g, ans, x, y : g * balanced_eq(y, ans, x)) -defjvp(anp.fmin, lambda g, ans, x, y : g * balanced_eq(x, ans, y), - lambda g, ans, x, y : g * balanced_eq(y, ans, x)) -defjvp(anp.logaddexp, lambda g, ans, x, y : g * anp.exp(x-ans), - lambda g, ans, x, y : g * anp.exp(y-ans)) -defjvp(anp.logaddexp2, lambda g, ans, x, y : g * 2**(x-ans), - lambda g, ans, x, y : g * 2**(y-ans)) -defjvp(anp.true_divide,'same', - lambda g, ans, x, y : - g * x / y**2) -defjvp(anp.mod, lambda g, ans, x, y : broadcast(g, ans), - lambda g, ans, x, y : -g * anp.floor(x/y)) -defjvp(anp.remainder, lambda g, ans, x, y : broadcast(g, ans), - lambda g, ans, x, y : -g * anp.floor(x/y)) -defjvp(anp.power, lambda g, ans, x, y : g * y * x ** anp.where(y, y - 1, 1.), - lambda g, ans, x, y : g * anp.log(replace_zero(x, 1.)) * x ** y) -defjvp(anp.arctan2, lambda g, ans, x, y : g * y / (x**2 + y**2), - lambda g, ans, x, y : g * -x / (x**2 + y**2)) + +def_ufunc_jps(anp.add, 'id', 'id') +def_ufunc_jps(anp.subtract, 'id', 'neg') +def_ufunc_jps(anp.multiply, 'same', 'same') +def_ufunc_jps(anp.divide, 'same', (lambda ans, x, y: -ans/y, 'mul')) +def_ufunc_jps(anp.maximum, (lambda ans, x, y: balanced_eq(x, ans, y), 'mul'), + (lambda ans, x, y: balanced_eq(y, ans, x), 'mul')) +def_ufunc_jps(anp.minimum, (lambda ans, x, y: balanced_eq(x, ans, y), 'mul'), + (lambda ans, x, y: balanced_eq(y, ans, x), 'mul')) +def_ufunc_jps(anp.fmax, (lambda ans, x, y: balanced_eq(x, ans, y), 'mul'), + (lambda ans, x, y: balanced_eq(y, ans, x), 'mul')) +def_ufunc_jps(anp.fmin, (lambda ans, x, y: balanced_eq(x, ans, y), 'mul'), + (lambda ans, x, y: balanced_eq(y, ans, x), 'mul')) +def_ufunc_jps(anp.logaddexp, (lambda ans, x, y: anp.exp(x-ans), 'mul'), + (lambda ans, x, y: anp.exp(y-ans), 'mul')) 
+def_ufunc_jps(anp.logaddexp2, (lambda ans, x, y: 2**(x-ans), 'mul'), + (lambda ans, x, y: 2**(y-ans), 'mul')) +def_ufunc_jps(anp.true_divide, 'same', (lambda ans, x, y: -ans/y, 'mul')) +def_ufunc_jps(anp.mod, 'id', (lambda ans, x, y: -anp.floor(x/y), 'mul')) +def_ufunc_jps(anp.remainder, 'id', (lambda ans, x, y: -anp.floor(x/y), 'mul')) +def_ufunc_jps(anp.power, (lambda ans, x, y: y * x ** anp.where(y, y - 1, 1.), 'mul'), + (lambda ans, x, y: anp.log(replace_zero(x, 1.)) * x ** y, 'mul')) +def_ufunc_jps(anp.hypot, (lambda ans, x, y: x / ans, 'mul'), + (lambda ans, x, y: y / ans, 'mul')) +def_ufunc_jps(anp.arctan2, (lambda ans, x, y: y / (x**2 + y**2), 'mul'), + (lambda ans, x, y:-x / (x**2 + y**2), 'mul')) # ----- Simple grads (linear) ----- -defjvp(anp.negative, 'same') -defjvp(anp.rad2deg, 'same') -defjvp(anp.degrees, 'same') -defjvp(anp.deg2rad, 'same') -defjvp(anp.radians, 'same') + defjvp(anp.reshape, 'same') defjvp(anp.roll, 'same') defjvp(anp.array_split, 'same') @@ -85,44 +116,14 @@ def_linear(anp.cross) # ----- Simple grads ----- -defjvp(anp.abs, - lambda g, ans, x : anp.real(g * replace_zero(anp.conj(x), 0.)) / replace_zero(ans, 1.)) -defjvp(anp.fabs, lambda g, ans, x : anp.sign(x) * g) # fabs doesn't take complex numbers. 
-defjvp(anp.absolute, lambda g, ans, x : anp.real(g * anp.conj(x)) / ans) -defjvp(anp.reciprocal, lambda g, ans, x : - g / x**2) -defjvp(anp.exp, lambda g, ans, x : ans * g) -defjvp(anp.exp2, lambda g, ans, x : ans * anp.log(2) * g) -defjvp(anp.expm1, lambda g, ans, x : (ans + 1) * g) -defjvp(anp.log, lambda g, ans, x : g / x) -defjvp(anp.log2, lambda g, ans, x : g / x / anp.log(2)) -defjvp(anp.log10, lambda g, ans, x : g / x / anp.log(10)) -defjvp(anp.log1p, lambda g, ans, x : g / (x + 1)) -defjvp(anp.sin, lambda g, ans, x : g * anp.cos(x)) -defjvp(anp.cos, lambda g, ans, x : - g * anp.sin(x)) -defjvp(anp.tan, lambda g, ans, x : g / anp.cos(x) **2) -defjvp(anp.arcsin, lambda g, ans, x : g / anp.sqrt(1 - x**2)) -defjvp(anp.arccos, lambda g, ans, x :-g / anp.sqrt(1 - x**2)) -defjvp(anp.arctan, lambda g, ans, x : g / (1 + x**2)) -defjvp(anp.sinh, lambda g, ans, x : g * anp.cosh(x)) -defjvp(anp.cosh, lambda g, ans, x : g * anp.sinh(x)) -defjvp(anp.tanh, lambda g, ans, x : g / anp.cosh(x) **2) -defjvp(anp.arcsinh, lambda g, ans, x : g / anp.sqrt(x**2 + 1)) -defjvp(anp.arccosh, lambda g, ans, x : g / anp.sqrt(x**2 - 1)) -defjvp(anp.arctanh, lambda g, ans, x : g / (1 - x**2)) -defjvp(anp.square, lambda g, ans, x : g * 2 * x) -defjvp(anp.sqrt, lambda g, ans, x : g * 0.5 * x**-0.5) -defjvp(anp.sinc, lambda g, ans, x : g * (anp.cos(anp.pi*x)*anp.pi*x - anp.sin(anp.pi*x))/(anp.pi*x**2)) + defjvp(anp.clip, lambda g, ans, x, a_min, a_max : g * anp.logical_and(ans != a_min, ans != a_max)) -defjvp(anp.real_if_close, lambda g, ans, x : match_complex(ans, g)) -defjvp(anp.real, lambda g, ans, x : anp.real(g)) -defjvp(anp.imag, lambda g, ans, x : match_complex(ans, -1j * g)) -defjvp(anp.conj, lambda g, ans, x : anp.conj(g)) -defjvp(anp.angle, lambda g, ans, x : match_complex(ans, g * anp.conj(x * 1j) / anp.abs(x)**2)) defjvp(anp.where, None, lambda g, ans, c, x=None, y=None : anp.where(c, g, anp.zeros(anp.shape(g))), lambda g, ans, c, x=None, y=None : anp.where(c, 
anp.zeros(g.shape), g)) # ----- Trickier grads ----- + defjvp(anp.kron, 'same', 'same') defjvp(anp.diff, 'same') defjvp(anp.repeat, 'same') @@ -226,15 +227,3 @@ def jvp(g, ans, *arys): defjvp(anp.atleast_3d, atleast_jvpmaker(anp.atleast_3d)) def_linear(anp.einsum) - -# TODO(mattjj): can we call np.broadcast_to or a related function instead? -def broadcast(x, target): - target_shape, target_ndim, target_dtype, target_iscomplex = anp.metadata(target) - while anp.ndim(x) < target_ndim: - x = anp.expand_dims(x, 0) - for axis, size in enumerate(anp.shape(x)): - if size == 1: - x = anp.repeat(x, target_shape[axis], axis=axis) - if target_iscomplex and not anp.iscomplexobj(x): - x = x + 0j # TODO(mattjj): this might promote the dtype - return x diff --git a/autograd/numpy/numpy_vjps.py b/autograd/numpy/numpy_vjps.py index 235a4218..e80a6093 100644 --- a/autograd/numpy/numpy_vjps.py +++ b/autograd/numpy/numpy_vjps.py @@ -4,9 +4,10 @@ import numpy as onp from ..util import func from . import numpy_wrapper as anp +from autograd.numpy.util import unbroadcast from .numpy_boxes import ArrayBox -from autograd.extend import (primitive, vspace, defvjp, defvjp_argnum, - SparseObject, VJPNode, register_notrace) +from autograd.extend import (primitive, vspace, defvjp, defvjp_argnum, SparseObject, VJPNode, + register_notrace) # ----- Non-differentiable functions ----- @@ -27,77 +28,8 @@ defvjp(anp.nan_to_num, lambda ans, x: lambda g: anp.where(anp.isfinite(x), g, 0.)) -# ----- Binary ufuncs ----- - -defvjp(anp.add, lambda ans, x, y : unbroadcast_f(x, lambda g: g), - lambda ans, x, y : unbroadcast_f(y, lambda g: g)) -defvjp(anp.multiply, lambda ans, x, y : unbroadcast_f(x, lambda g: y * g), - lambda ans, x, y : unbroadcast_f(y, lambda g: x * g)) -defvjp(anp.subtract, lambda ans, x, y : unbroadcast_f(x, lambda g: g), - lambda ans, x, y : unbroadcast_f(y, lambda g: -g)) -defvjp(anp.divide, lambda ans, x, y : unbroadcast_f(x, lambda g: g / y), - lambda ans, x, y : unbroadcast_f(y, lambda 
g: - g * x / y**2)) -defvjp(anp.maximum, lambda ans, x, y : unbroadcast_f(x, lambda g: g * balanced_eq(x, ans, y)), - lambda ans, x, y : unbroadcast_f(y, lambda g: g * balanced_eq(y, ans, x))) -defvjp(anp.minimum, lambda ans, x, y : unbroadcast_f(x, lambda g: g * balanced_eq(x, ans, y)), - lambda ans, x, y : unbroadcast_f(y, lambda g: g * balanced_eq(y, ans, x))) -defvjp(anp.fmax, lambda ans, x, y : unbroadcast_f(x, lambda g: g * balanced_eq(x, ans, y)), - lambda ans, x, y : unbroadcast_f(y, lambda g: g * balanced_eq(y, ans, x))) -defvjp(anp.fmin, lambda ans, x, y : unbroadcast_f(x, lambda g: g * balanced_eq(x, ans, y)), - lambda ans, x, y : unbroadcast_f(y, lambda g: g * balanced_eq(y, ans, x))) -defvjp(anp.logaddexp, lambda ans, x, y : unbroadcast_f(x, lambda g: g * anp.exp(x-ans)), - lambda ans, x, y : unbroadcast_f(y, lambda g: g * anp.exp(y-ans))) -defvjp(anp.logaddexp2, lambda ans, x, y : unbroadcast_f(x, lambda g: g * 2**(x-ans)), - lambda ans, x, y : unbroadcast_f(y, lambda g: g * 2**(y-ans))) -defvjp(anp.true_divide, lambda ans, x, y : unbroadcast_f(x, lambda g: g / y), - lambda ans, x, y : unbroadcast_f(y, lambda g: - g * x / y**2)) -defvjp(anp.mod, lambda ans, x, y : unbroadcast_f(x, lambda g: g), - lambda ans, x, y : unbroadcast_f(y, lambda g: -g * anp.floor(x/y))) -defvjp(anp.remainder, lambda ans, x, y : unbroadcast_f(x, lambda g: g), - lambda ans, x, y : unbroadcast_f(y, lambda g: -g * anp.floor(x/y))) -defvjp(anp.power, - lambda ans, x, y : unbroadcast_f(x, lambda g: g * y * x ** anp.where(y, y - 1, 1.)), - lambda ans, x, y : unbroadcast_f(y, lambda g: g * anp.log(replace_zero(x, 1.)) * x ** y)) -defvjp(anp.arctan2, lambda ans, x, y : unbroadcast_f(x, lambda g: g * y / (x**2 + y**2)), - lambda ans, x, y : unbroadcast_f(y, lambda g: g * -x / (x**2 + y**2))) -defvjp(anp.hypot, - lambda ans, x, y : unbroadcast_f(x, lambda g: g * x / ans), - lambda ans, x, y : unbroadcast_f(y, lambda g: g * y / ans)) - # ----- Simple grads ----- -defvjp(anp.negative, 
lambda ans, x: lambda g: -g) -defvjp(anp.abs, - lambda ans, x : lambda g: g * replace_zero(anp.conj(x), 0.) / replace_zero(ans, 1.)) -defvjp(anp.fabs, lambda ans, x : lambda g: anp.sign(x) * g) # fabs doesn't take complex numbers. -defvjp(anp.absolute, lambda ans, x : lambda g: g * anp.conj(x) / ans) -defvjp(anp.reciprocal, lambda ans, x : lambda g: - g / x**2) -defvjp(anp.exp, lambda ans, x : lambda g: ans * g) -defvjp(anp.exp2, lambda ans, x : lambda g: ans * anp.log(2) * g) -defvjp(anp.expm1, lambda ans, x : lambda g: (ans + 1) * g) -defvjp(anp.log, lambda ans, x : lambda g: g / x) -defvjp(anp.log2, lambda ans, x : lambda g: g / x / anp.log(2)) -defvjp(anp.log10, lambda ans, x : lambda g: g / x / anp.log(10)) -defvjp(anp.log1p, lambda ans, x : lambda g: g / (x + 1)) -defvjp(anp.sin, lambda ans, x : lambda g: g * anp.cos(x)) -defvjp(anp.cos, lambda ans, x : lambda g: - g * anp.sin(x)) -defvjp(anp.tan, lambda ans, x : lambda g: g / anp.cos(x) **2) -defvjp(anp.arcsin, lambda ans, x : lambda g: g / anp.sqrt(1 - x**2)) -defvjp(anp.arccos, lambda ans, x : lambda g:-g / anp.sqrt(1 - x**2)) -defvjp(anp.arctan, lambda ans, x : lambda g: g / (1 + x**2)) -defvjp(anp.sinh, lambda ans, x : lambda g: g * anp.cosh(x)) -defvjp(anp.cosh, lambda ans, x : lambda g: g * anp.sinh(x)) -defvjp(anp.tanh, lambda ans, x : lambda g: g / anp.cosh(x) **2) -defvjp(anp.arcsinh, lambda ans, x : lambda g: g / anp.sqrt(x**2 + 1)) -defvjp(anp.arccosh, lambda ans, x : lambda g: g / anp.sqrt(x**2 - 1)) -defvjp(anp.arctanh, lambda ans, x : lambda g: g / (1 - x**2)) -defvjp(anp.rad2deg, lambda ans, x : lambda g: g / anp.pi * 180.0) -defvjp(anp.degrees, lambda ans, x : lambda g: g / anp.pi * 180.0) -defvjp(anp.deg2rad, lambda ans, x : lambda g: g * anp.pi / 180.0) -defvjp(anp.radians, lambda ans, x : lambda g: g * anp.pi / 180.0) -defvjp(anp.square, lambda ans, x : lambda g: g * 2 * x) -defvjp(anp.sqrt, lambda ans, x : lambda g: g * 0.5 * x**-0.5) -defvjp(anp.sinc, lambda ans, x : lambda g: g * 
(anp.cos(anp.pi*x)*anp.pi*x - anp.sin(anp.pi*x))/(anp.pi*x**2)) defvjp(anp.reshape, lambda ans, x, shape, order=None : lambda g: anp.reshape(g, anp.shape(x), order=order)) defvjp(anp.roll, lambda ans, x, shift, axis=None : lambda g: anp.roll(g, -shift, axis=axis)) defvjp(anp.array_split, lambda ans, ary, idxs, axis=0 : lambda g: anp.concatenate(g, axis=axis)) @@ -123,12 +55,6 @@ anp.moveaxis(g, destination, source)) defvjp(anp.rollaxis, lambda ans, a, axis, start=0: lambda g: anp.rollaxis(g, start - 1, axis) if start > axis else anp.rollaxis(g, start, axis + 1)) -defvjp(anp.real_if_close, lambda ans, x : lambda g: match_complex(x, g)) -defvjp(anp.real, lambda ans, x : lambda g: match_complex(x, g)) -defvjp(anp.imag, lambda ans, x : lambda g: match_complex(x, -1j * g)) -defvjp(anp.conj, lambda ans, x : lambda g: anp.conj(g)) -defvjp(anp.conjugate, lambda ans, x: lambda g: anp.conj(g)) -defvjp(anp.angle, lambda ans, x : lambda g: match_complex(x, g * anp.conj(x * 1j) / anp.abs(x)**2)) defvjp(anp.where, None, lambda ans, c, x=None, y=None : lambda g: anp.where(c, g, anp.zeros(g.shape)), lambda ans, c, x=None, y=None : lambda g: anp.where(c, anp.zeros(g.shape), g)) @@ -541,31 +467,6 @@ def vjp(g): lambda ans, D, offset=0, axis1=0, axis2=1 : lambda g: anp.diagonal(g, offset, axis1, axis2)) -def match_complex(target, x): - target_iscomplex = anp.iscomplexobj(target) - x_iscomplex = anp.iscomplexobj(x) - if x_iscomplex and not target_iscomplex: - return anp.real(x) - elif not x_iscomplex and target_iscomplex: - return x + 0j - else: - return x - -def unbroadcast(x, target_meta, broadcast_idx=0): - target_shape, target_ndim, dtype, target_iscomplex = target_meta - while anp.ndim(x) > target_ndim: - x = anp.sum(x, axis=broadcast_idx) - for axis, size in enumerate(target_shape): - if size == 1: - x = anp.sum(x, axis=axis, keepdims=True) - if anp.iscomplexobj(x) and not target_iscomplex: - x = anp.real(x) - return x - -def unbroadcast_f(target, f): - target_meta = 
anp.metadata(target) - return lambda g: unbroadcast(f(g), target_meta) - def unbroadcast_einsum(x, target_meta, subscript): if Ellipsis not in subscript: return x @@ -576,6 +477,19 @@ def unbroadcast_einsum(x, target_meta, subscript): else: return unbroadcast(x, target_meta, subscript.index(Ellipsis)) +def _broadcast_to_vjpmaker(x_shape): + # Ensure that x can be garbage collected by only passing + # its shape to this closure. + return lambda g: anp._broadcast_to_adjoint(g, x_shape) + +def _broadcast_to_adjoint_vjpmaker(g_shape): + # Ensure that g can be garbage collected by only passing + # its shape to this closure. + return lambda x: anp.broadcast_to(x, g_shape) + +defvjp(anp.broadcast_to, lambda ans, x, ans_shp: _broadcast_to_vjpmaker(x.shape)) +defvjp(anp._broadcast_to_adjoint, lambda ans, g, ans_shp: _broadcast_to_adjoint_vjpmaker(g.shape)) + def balanced_eq(x, z, y): return (x == z) / (1.0 + (x == y)) diff --git a/autograd/numpy/numpy_wrapper.py b/autograd/numpy/numpy_wrapper.py index f14d5308..b8feb631 100644 --- a/autograd/numpy/numpy_wrapper.py +++ b/autograd/numpy/numpy_wrapper.py @@ -153,6 +153,15 @@ def metadata(A): def parse_einsum_input(*args): return _parse_einsum_input(args) +@primitive +def _broadcast_to_adjoint(x, shape): + while _np.ndim(x) > len(shape): + x = _np.sum(x, axis=0) + for axis, size in enumerate(shape): + if size == 1: + x = _np.sum(x, axis=axis, keepdims=True) + return x + @primitive def _astype(A, dtype, order='K', casting='unsafe', subok=True, copy=True): - return A.astype(dtype, order, casting, subok, copy) + return A.astype(dtype, order, casting, subok, copy) diff --git a/autograd/numpy/util.py b/autograd/numpy/util.py new file mode 100644 index 00000000..665f2580 --- /dev/null +++ b/autograd/numpy/util.py @@ -0,0 +1,173 @@ +from . 
import numpy_wrapper as anp +from autograd.core import defjvp, defvjp +from autograd.util import subval + +def match_complex(target, x): + target_iscomplex = anp.iscomplexobj(target) + x_iscomplex = anp.iscomplexobj(x) + if x_iscomplex and not target_iscomplex: + return anp.real(x) + elif not x_iscomplex and target_iscomplex: + return x + 0j + else: + return x + +def unbroadcast(x, target_meta, broadcast_idx=0): + target_shape, target_ndim, dtype, target_iscomplex = target_meta + x = anp._broadcast_to_adjoint(x, target_shape) + if anp.iscomplexobj(x) and not target_iscomplex: + x = anp.real(x) + return x + +def unbroadcast_f(target, f): + target_meta = anp.metadata(target) + return lambda g: unbroadcast(f(g), target_meta) + +def def_unary_ufunc_jps(ufunc, deriv_op): + jps = { + 'same': (lambda g, ans, x: ufunc(g), + lambda ans, x: ufunc), + 'cid': (lambda g, ans, x: match_complex(ans, g), + lambda ans, x: lambda g: match_complex(x , g)) + } + + linops = { + 'mul' : (lambda deriv: lambda g, ans, x: g * deriv(ans, x), + lambda deriv: lambda ans, x: lambda g, d=deriv(ans, x): g * d), + 'div' : (lambda deriv: lambda g, ans, x: g / deriv(ans, x), + lambda deriv: lambda ans, x: lambda g, d=deriv(ans, x): g / d), + 'cmul': (lambda deriv: lambda g, ans, x: match_complex(ans, g * deriv(ans, x)), + lambda deriv: lambda ans, x: lambda g, d=deriv(ans, x): match_complex(x, g * d)), + } + + if type(deriv_op) is tuple: + deriv, op = deriv_op + defjvp(ufunc, linops[op][0](deriv)) + defvjp(ufunc, linops[op][1](deriv)) + elif deriv_op is None: + defjvp(ufunc, None) + defvjp(ufunc, None) + else: + defjvp(ufunc, jps[deriv_op][0]) + defvjp(ufunc, jps[deriv_op][1]) + +def def_nary_ufunc_jps(ufunc, derivs_ops): + jps = { + 'same': (lambda argnum: lambda g, ans, *args: ufunc(*subval(args, argnum, g)), + lambda argnum: lambda ans, *args: + unbroadcast_f(args[argnum], lambda g: ufunc(*subval(args, argnum, g)))), + 'id': (lambda argnum: lambda g, ans, *args: match_complex(ans, 
anp.broadcast_to(g, ans.shape)), + lambda argnum: lambda ans, *args: unbroadcast_f(args[argnum], lambda g: g)), + 'neg': (lambda argnum: lambda g, ans, *args: match_complex(ans, anp.broadcast_to(-g, ans.shape)), + lambda argnum: lambda ans, *args: unbroadcast_f(args[argnum], lambda g: -g)) + } + + linops = { + 'mul': (lambda argnum, deriv: lambda g, ans, *args: g * deriv(ans, *args), + lambda argnum, deriv: lambda ans, *args: + unbroadcast_f(args[argnum], lambda g, d=deriv(ans, *args): g * d)), + 'div': (lambda argnum, deriv: lambda g, ans, *args: g / deriv(ans, *args), + lambda argnum, deriv: lambda ans, *args: + unbroadcast_f(args[argnum], lambda g, d=deriv(ans, *args): g / d)) + } + + def deriv_op_to_jp(idx, argnum, deriv_op): + if type(deriv_op) is tuple: + deriv, op = deriv_op + return linops[op][idx](argnum, deriv) + elif deriv_op is None: + return None + else: + return jps[deriv_op][idx](argnum) + + defjvp(ufunc, *[deriv_op_to_jp(0, argnum, deriv_op) + for argnum, deriv_op in enumerate(derivs_ops)]) + defvjp(ufunc, *[deriv_op_to_jp(1, argnum, deriv_op) + for argnum, deriv_op in enumerate(derivs_ops)]) + +def def_ufunc_jps(ufunc, *derivs_ops): + """ + Specify the derivatives of ufunc. Once this has been done the ufunc will + support both reverse and forward mode differentiation. + + The derivatives can be specified as follows. + + Unary ufuncs + ------------ + If the ufunc is unary (that is, if it takes one array valued argument), + then a single optional argument is required to specify the ufunc's + derivative. + + In the general case, this is done via a pair (deriv, op), where deriv is a + function taking in the output of the ufunc (ans), and its array argument + (x), and returning the derivative of the ufunc. + + Here 'derivative' means the elementwise derivative of the ufunc w.r.t. its + input. + + For example, for the ufunc np.sin, this is as simple as + >>> def deriv(ans, x): + ... return np.cos(x) + ... 
+ + Sometimes the output of the ufunc is useful, for example the derivative of + np.exp is np.exp, which is identical to ans, so the derivative of np.exp + can be efficiently implemented as + >>> def deriv(ans, x): + ... return ans + ... + + The other element of the pair is `op`, which should usually be set to + 'mul'. However, if the derivative of the ufunc is of the form + 1 / f(ans, x), then you can save some computation by using the pair + (f, 'div') to specify the derivative. The 'div' flags that the gradients + being propagated through this primitive should be divided by the result of + f, not multiplied. + + Some full examples: + >>> def_ufunc_jps(np.sin, (lambda ans, x: np.cos(x), 'mul')) + >>> def_ufunc_jps(np.exp, (lambda ans, x: ans, 'mul')) + >>> def_ufunc_jps(np.log, (lambda ans, x: x, 'div')) + + Special cases + ------------- + If the derivative of the ufunc is a constant, then you don't need to + specify its derivative and you can use just the string 'same' in place of + the pair (deriv, op). This says that it's ok to propagate the gradient + through this primitive by applying the ufunc itself to the gradient, and + neither x nor ans is relevant to this computation. + + For example, the derivative of np.negative (which simply negates its + inputs), is -1, so + >>> def_ufunc_jps(np.negative, 'same') + + will correctly set its derivative. + + N-ary ufuncs + ------------ + For ufuncs which take more than one array argument, the derivatives can be + specified by passing one (deriv, op) pair for each argument (you can use + None as a placeholder for args whose derivative you don't wish to define). + + You can use 'same' in exactly the same way as for unary ufuncs, and + additionally you can use 'id' when the derivative w.r.t. an arg is always + equal to 1, and 'neg' when it's always equal to -1. 
+ + Some examples: + >>> def_ufunc_jps(anp.divide, 'same', (lambda ans, x, y: -ans/y, 'mul')) + >>> def_ufunc_jps(anp.add, 'id', 'id') + >>> def_ufunc_jps(anp.subtract, 'id', 'neg') + """ + derivs_ops = list(derivs_ops) + if len(derivs_ops) == 1: + def_unary_ufunc_jps(ufunc, derivs_ops[0]) + elif len(derivs_ops) > 1: + def_nary_ufunc_jps(ufunc, derivs_ops) + +def def_ufunc_jps_inv_pair(ufunc, ufunc_inv, deriv): + """ + Define the derivatives for an inverse pair of unary ufuncs. deriv must be + the derivative of the first ufunc. + """ + def_ufunc_jps(ufunc, (deriv, 'mul')) + def_ufunc_jps(ufunc_inv, (lambda ans, x: deriv(x, ans), 'div')) diff --git a/autograd/scipy/special.py b/autograd/scipy/special.py index 6308f880..6abef774 100644 --- a/autograd/scipy/special.py +++ b/autograd/scipy/special.py @@ -1,23 +1,24 @@ from __future__ import absolute_import import scipy.special import autograd.numpy as np -from autograd.extend import primitive, defvjp -from autograd.numpy.numpy_vjps import unbroadcast_f +from autograd.numpy.util import def_ufunc_jps, def_ufunc_jps_inv_pair +from autograd.extend import primitive ### Beta function ### beta = primitive(scipy.special.beta) betainc = primitive(scipy.special.betainc) betaln = primitive(scipy.special.betaln) -defvjp(beta, - lambda ans, a, b: unbroadcast_f(a, lambda g: g * ans * (psi(a) - psi(a + b))), - lambda ans, a, b: unbroadcast_f(b, lambda g: g * ans * (psi(b) - psi(a + b)))) -defvjp(betainc, - lambda ans, a, b, x: unbroadcast_f(x, lambda g: g * np.power(x, a - 1) * np.power(1 - x, b - 1) / beta(a, b)), - argnums=[2]) -defvjp(betaln, - lambda ans, a, b: unbroadcast_f(a, lambda g: g * (psi(a) - psi(a + b))), - lambda ans, a, b: unbroadcast_f(b, lambda g: g * (psi(b) - psi(a + b)))) +def_ufunc_jps(beta, + (lambda ans, a, b: ans * (psi(a) - psi(a + b)), 'mul'), + (lambda ans, a, b: ans * (psi(b) - psi(a + b)), 'mul')) +def_ufunc_jps(betainc, + None, + None, + (lambda ans, a, b, x: np.power(x, a - 1) * np.power(1 - x, b - 1) 
/ beta(a, b), 'mul')) +def_ufunc_jps(betaln, + (lambda ans, a, b: psi(a) - psi(a + b), 'mul'), + (lambda ans, a, b: psi(b) - psi(a + b), 'mul')) ### Gamma functions ### polygamma = primitive(scipy.special.polygamma) @@ -31,24 +32,23 @@ rgamma = primitive(scipy.special.rgamma) multigammaln = primitive(scipy.special.multigammaln) -defvjp(gammasgn, None) -defvjp(polygamma, None, lambda ans, n, x: lambda g: g * polygamma(n + 1, x)) -defvjp(psi, lambda ans, x: lambda g: g * polygamma(1, x)) -defvjp(digamma, lambda ans, x: lambda g: g * polygamma(1, x)) -defvjp(gamma, lambda ans, x: lambda g: g * ans * psi(x)) -defvjp(gammaln, lambda ans, x: lambda g: g * psi(x)) -defvjp(rgamma, lambda ans, x: lambda g: g * psi(x) / -gamma(x)) -defvjp(multigammaln,lambda ans, a, d: lambda g: - g * np.sum(digamma(np.expand_dims(a, -1) - np.arange(d)/2.), -1), +def_ufunc_jps(gammasgn, None) +def_ufunc_jps(polygamma, None, (lambda ans, n, x: polygamma(n + 1, x), 'mul')) +def_ufunc_jps(psi, (lambda ans, x: polygamma(1, x), 'mul')) +def_ufunc_jps(digamma, (lambda ans, x: polygamma(1, x), 'mul')) +def_ufunc_jps(gamma, (lambda ans, x: ans * psi(x), 'mul')) +def_ufunc_jps(gammaln, (lambda ans, x: psi(x), 'mul')) +def_ufunc_jps(rgamma, (lambda ans, x: psi(x) / -gamma(x), 'mul')) +def_ufunc_jps(multigammaln, (lambda ans, a, d: + np.sum(digamma(np.expand_dims(a, -1) - np.arange(d)/2.), -1), 'mul'), None) def make_gammainc_vjp_arg1(sign): def gammainc_vjp_arg1(ans, a, x): - coeffs = sign * np.exp(-x) * np.power(x, a - 1) / gamma(a) - return unbroadcast_f(x, lambda g: g * coeffs) + return sign * np.exp(-x) * np.power(x, a - 1) / gamma(a) return gammainc_vjp_arg1 -defvjp(gammainc, make_gammainc_vjp_arg1(1), argnums=[1]) -defvjp(gammaincc, make_gammainc_vjp_arg1(-1), argnums=[1]) +def_ufunc_jps(gammainc, None, (make_gammainc_vjp_arg1(1), 'mul')) +def_ufunc_jps(gammaincc, None, (make_gammainc_vjp_arg1(-1), 'mul')) ### Bessel functions ### j0 = primitive(scipy.special.j0) @@ -58,33 +58,33 @@ def 
gammainc_vjp_arg1(ans, a, x): jn = primitive(scipy.special.jn) yn = primitive(scipy.special.yn) -defvjp(j0,lambda ans, x: lambda g: -g * j1(x)) -defvjp(y0,lambda ans, x: lambda g: -g * y1(x)) -defvjp(j1,lambda ans, x: lambda g: g * (j0(x) - jn(2, x)) / 2.0) -defvjp(y1,lambda ans, x: lambda g: g * (y0(x) - yn(2, x)) / 2.0) -defvjp(jn, None, lambda ans, n, x: lambda g: g * (jn(n - 1, x) - jn(n + 1, x)) / 2.0) -defvjp(yn, None, lambda ans, n, x: lambda g: g * (yn(n - 1, x) - yn(n + 1, x)) / 2.0) +def_ufunc_jps(j0, (lambda ans, x: -j1(x), 'mul')) +def_ufunc_jps(y0, (lambda ans, x: -y1(x), 'mul')) +def_ufunc_jps(j1, (lambda ans, x: (j0(x) - jn(2, x)) / 2.0, 'mul')) +def_ufunc_jps(y1, (lambda ans, x: (y0(x) - yn(2, x)) / 2.0, 'mul')) +def_ufunc_jps(jn, None, (lambda ans, n, x: (jn(n - 1, x) - jn(n + 1, x)) / 2.0, 'mul')) +def_ufunc_jps(yn, None, (lambda ans, n, x: (yn(n - 1, x) - yn(n + 1, x)) / 2.0, 'mul')) ### Error Function ### inv_root_pi = 0.56418958354775627928 erf = primitive(scipy.special.erf) erfc = primitive(scipy.special.erfc) -defvjp(erf, lambda ans, x: lambda g: 2.*g*inv_root_pi*np.exp(-x**2)) -defvjp(erfc,lambda ans, x: lambda g: -2.*g*inv_root_pi*np.exp(-x**2)) - - ### Inverse error function ### root_pi = 1.7724538509055159 erfinv = primitive(scipy.special.erfinv) erfcinv = primitive(scipy.special.erfcinv) -defvjp(erfinv,lambda ans, x: lambda g: g * root_pi / 2 * np.exp(erfinv(x)**2)) -defvjp(erfcinv,lambda ans, x: lambda g: -g * root_pi / 2 * np.exp(erfcinv(x)**2)) +def_ufunc_jps_inv_pair(erf, erfinv, lambda ans, x: 2.*inv_root_pi*np.exp(-x**2)) +def_ufunc_jps_inv_pair(erfc, erfcinv, lambda ans, x: -2.*inv_root_pi*np.exp(-x**2)) ### Logit and Expit ### logit = primitive(scipy.special.logit) expit = primitive(scipy.special.expit) -defvjp(logit,lambda ans, x: lambda g: g / ( x * (1 - x))) -defvjp(expit,lambda ans, x: lambda g: g * ans * (1 - ans)) +def_ufunc_jps_inv_pair(expit, logit, lambda ans, x: ans * (1 - ans)) + +### Relative entropy ### +rel_entr = 
primitive(scipy.special.rel_entr) + +def_ufunc_jps(rel_entr, (lambda ans, x, y: np.log(x / y) + 1, 'mul'), (lambda ans, x, y: - x / y, 'mul')) diff --git a/autograd/scipy/stats/beta.py b/autograd/scipy/stats/beta.py index a703ae6a..64a822d4 100644 --- a/autograd/scipy/stats/beta.py +++ b/autograd/scipy/stats/beta.py @@ -2,8 +2,8 @@ import autograd.numpy as np import scipy.stats -from autograd.extend import primitive, defvjp -from autograd.numpy.numpy_vjps import unbroadcast_f +from autograd.extend import primitive +from autograd.numpy.util import def_ufunc_jps from autograd.scipy.special import beta, psi cdf = primitive(scipy.stats.beta.cdf) @@ -19,12 +19,15 @@ def grad_beta_logpdf_arg1(x, a, b): def grad_beta_logpdf_arg2(x, a, b): return np.log1p(-x) - psi(b) + psi(a + b) -defvjp(cdf, lambda ans, x, a, b: unbroadcast_f(x, lambda g: g * np.power(x, a-1) * np.power(1-x, b-1) / beta(a, b)), argnums=[0]) -defvjp(logpdf, - lambda ans, x, a, b: unbroadcast_f(x, lambda g: g * grad_beta_logpdf_arg0(x, a, b)), - lambda ans, x, a, b: unbroadcast_f(a, lambda g: g * grad_beta_logpdf_arg1(x, a, b)), - lambda ans, x, a, b: unbroadcast_f(b, lambda g: g * grad_beta_logpdf_arg2(x, a, b))) -defvjp(pdf, - lambda ans, x, a, b: unbroadcast_f(x, lambda g: g * ans * grad_beta_logpdf_arg0(x, a, b)), - lambda ans, x, a, b: unbroadcast_f(a, lambda g: g * ans * grad_beta_logpdf_arg1(x, a, b)), - lambda ans, x, a, b: unbroadcast_f(b, lambda g: g * ans * grad_beta_logpdf_arg2(x, a, b))) +def_ufunc_jps(cdf, + (lambda ans, x, a, b: np.power(x, a-1) * np.power(1-x, b-1) / beta(a, b), 'mul'), + None, + None) +def_ufunc_jps(logpdf, + (lambda ans, x, a, b: grad_beta_logpdf_arg0(x, a, b), 'mul'), + (lambda ans, x, a, b: grad_beta_logpdf_arg1(x, a, b), 'mul'), + (lambda ans, x, a, b: grad_beta_logpdf_arg2(x, a, b), 'mul')) +def_ufunc_jps(pdf, + (lambda ans, x, a, b: ans * grad_beta_logpdf_arg0(x, a, b), 'mul'), + (lambda ans, x, a, b: ans * grad_beta_logpdf_arg1(x, a, b), 'mul'), + (lambda ans, x, a, 
b: ans * grad_beta_logpdf_arg2(x, a, b), 'mul')) diff --git a/autograd/scipy/stats/chi2.py b/autograd/scipy/stats/chi2.py index 8555739a..f7d14b53 100644 --- a/autograd/scipy/stats/chi2.py +++ b/autograd/scipy/stats/chi2.py @@ -2,8 +2,8 @@ import autograd.numpy as np import scipy.stats -from autograd.extend import primitive, defvjp -from autograd.numpy.numpy_vjps import unbroadcast_f +from autograd.extend import primitive +from autograd.numpy.util import def_ufunc_jps from autograd.scipy.special import gamma cdf = primitive(scipy.stats.chi2.cdf) @@ -13,6 +13,7 @@ def grad_chi2_logpdf(x, df): return np.where(df % 1 == 0, (df - x - 2) / (2 * x), 0) -defvjp(cdf, lambda ans, x, df: unbroadcast_f(x, lambda g: g * np.power(2., -df/2) * np.exp(-x/2) * np.power(x, df/2 - 1) / gamma(df/2)), argnums=[0]) -defvjp(logpdf, lambda ans, x, df: unbroadcast_f(x, lambda g: g * grad_chi2_logpdf(x, df)), argnums=[0]) -defvjp(pdf, lambda ans, x, df: unbroadcast_f(x, lambda g: g * ans * grad_chi2_logpdf(x, df)), argnums=[0]) +def_ufunc_jps(cdf, (lambda ans, x, df: (np.power(2., -df/2) * np.exp(-x/2) * + np.power(x, df/2 - 1) / gamma(df/2)), 'mul'), None) +def_ufunc_jps(logpdf, (lambda ans, x, df: (grad_chi2_logpdf(x, df)), 'mul'), None) +def_ufunc_jps(pdf, (lambda ans, x, df: (ans * grad_chi2_logpdf(x, df)), 'mul'), None) diff --git a/autograd/scipy/stats/gamma.py b/autograd/scipy/stats/gamma.py index 5b595099..56fd8547 100644 --- a/autograd/scipy/stats/gamma.py +++ b/autograd/scipy/stats/gamma.py @@ -2,8 +2,8 @@ import autograd.numpy as np import scipy.stats -from autograd.extend import primitive, defvjp -from autograd.numpy.numpy_vjps import unbroadcast_f +from autograd.extend import primitive +from autograd.numpy.util import def_ufunc_jps from autograd.scipy.special import gamma, psi cdf = primitive(scipy.stats.gamma.cdf) @@ -16,10 +16,13 @@ def grad_gamma_logpdf_arg0(x, a): def grad_gamma_logpdf_arg1(x, a): return np.log(x) - psi(a) -defvjp(cdf, lambda ans, x, a: unbroadcast_f(x, 
lambda g: g * np.exp(-x) * np.power(x, a-1) / gamma(a)), argnums=[0]) -defvjp(logpdf, - lambda ans, x, a: unbroadcast_f(x, lambda g: g * grad_gamma_logpdf_arg0(x, a)), - lambda ans, x, a: unbroadcast_f(a, lambda g: g * grad_gamma_logpdf_arg1(x, a))) -defvjp(pdf, - lambda ans, x, a: unbroadcast_f(x, lambda g: g * ans * grad_gamma_logpdf_arg0(x, a)), - lambda ans, x, a: unbroadcast_f(a, lambda g: g * ans * grad_gamma_logpdf_arg1(x, a))) +def_ufunc_jps(cdf, + (lambda ans, x, a: np.exp(-x) * np.power(x, a-1) / gamma(a), 'mul'), + None, + None) +def_ufunc_jps(logpdf, + (lambda ans, x, a: grad_gamma_logpdf_arg0(x, a), 'mul'), + (lambda ans, x, a: grad_gamma_logpdf_arg1(x, a), 'mul')) +def_ufunc_jps(pdf, + (lambda ans, x, a: ans * grad_gamma_logpdf_arg0(x, a), 'mul'), + (lambda ans, x, a: ans * grad_gamma_logpdf_arg1(x, a), 'mul')) diff --git a/autograd/scipy/stats/multivariate_normal.py b/autograd/scipy/stats/multivariate_normal.py index d098ce1f..fa2524bf 100644 --- a/autograd/scipy/stats/multivariate_normal.py +++ b/autograd/scipy/stats/multivariate_normal.py @@ -2,7 +2,7 @@ import scipy.stats import autograd.numpy as np -from autograd.numpy.numpy_vjps import unbroadcast_f +from autograd.numpy.util import unbroadcast_f from autograd.extend import primitive, defvjp diff --git a/autograd/scipy/stats/norm.py b/autograd/scipy/stats/norm.py index 5fa88a29..d999be72 100644 --- a/autograd/scipy/stats/norm.py +++ b/autograd/scipy/stats/norm.py @@ -2,42 +2,30 @@ from __future__ import absolute_import import scipy.stats import autograd.numpy as anp -from autograd.extend import primitive, defvjp -from autograd.numpy.numpy_vjps import unbroadcast_f +from autograd.extend import primitive +from autograd.numpy.util import def_ufunc_jps pdf = primitive(scipy.stats.norm.pdf) cdf = primitive(scipy.stats.norm.cdf) logpdf = primitive(scipy.stats.norm.logpdf) logcdf = primitive(scipy.stats.norm.logcdf) -defvjp(pdf, - lambda ans, x, loc=0.0, scale=1.0: - unbroadcast_f(x, lambda g: -g * ans 
* (x - loc) / scale**2), - lambda ans, x, loc=0.0, scale=1.0: - unbroadcast_f(loc, lambda g: g * ans * (x - loc) / scale**2), - lambda ans, x, loc=0.0, scale=1.0: - unbroadcast_f(scale, lambda g: g * ans * (((x - loc)/scale)**2 - 1.0)/scale)) +def_ufunc_jps(pdf, + (lambda ans, x, loc=0.0, scale=1.0: -ans * (x - loc) / scale**2, 'mul'), + (lambda ans, x, loc=0.0, scale=1.0: ans * (x - loc) / scale**2, 'mul'), + (lambda ans, x, loc=0.0, scale=1.0: ans * (((x - loc)/scale)**2 - 1.0)/scale, 'mul')) -defvjp(cdf, - lambda ans, x, loc=-1.0, scale=1.0: - unbroadcast_f(x, lambda g: g * pdf(x, loc, scale)) , - lambda ans, x, loc=0.0, scale=1.0: - unbroadcast_f(loc, lambda g: -g * pdf(x, loc, scale)), - lambda ans, x, loc=-1.0, scale=1.0: - unbroadcast_f(scale, lambda g: -g * pdf(x, loc, scale)*(x-loc)/scale)) +def_ufunc_jps(logpdf, + (lambda ans, x, loc=0.0, scale=1.0: -(x - loc) / scale**2, 'mul'), + (lambda ans, x, loc=0.0, scale=1.0: (x - loc) / scale**2, 'mul'), + (lambda ans, x, loc=0.0, scale=1.0: -1.0/scale + (x - loc)**2/scale**3, 'mul')) -defvjp(logpdf, - lambda ans, x, loc=0.0, scale=1.0: - unbroadcast_f(x, lambda g: -g * (x - loc) / scale**2), - lambda ans, x, loc=0.0, scale=1.0: - unbroadcast_f(loc, lambda g: g * (x - loc) / scale**2), - lambda ans, x, loc=-1.0, scale=1.0: - unbroadcast_f(scale, lambda g: g * (-1.0/scale + (x - loc)**2/scale**3))) +def_ufunc_jps(cdf, + (lambda ans, x, loc=0.0, scale=1.0: pdf(x, loc, scale), 'mul'), + (lambda ans, x, loc=0.0, scale=1.0: -pdf(x, loc, scale), 'mul'), + (lambda ans, x, loc=0.0, scale=1.0: -pdf(x, loc, scale)*(x-loc)/scale, 'mul')) -defvjp(logcdf, - lambda ans, x, loc=0.0, scale=1.0: - unbroadcast_f(x, lambda g: g * anp.exp(logpdf(x, loc, scale) - logcdf(x, loc, scale))), - lambda ans, x, loc=0.0, scale=1.0: - unbroadcast_f(loc, lambda g:-g * anp.exp(logpdf(x, loc, scale) - logcdf(x, loc, scale))), - lambda ans, x, loc=0.0, scale=1.0: - unbroadcast_f(scale, lambda g:-g * anp.exp(logpdf(x, loc, scale) - logcdf(x, loc, 
scale))*(x-loc)/scale)) +def_ufunc_jps(logcdf, + (lambda ans, x, loc=0.0, scale=1.0: anp.exp(logpdf(x, loc, scale) - ans), 'mul'), + (lambda ans, x, loc=0.0, scale=1.0:-anp.exp(logpdf(x, loc, scale) - ans), 'mul'), + (lambda ans, x, loc=0.0, scale=1.0:-anp.exp(logpdf(x, loc, scale) - ans)*(x-loc)/scale, 'mul')) diff --git a/autograd/scipy/stats/poisson.py b/autograd/scipy/stats/poisson.py index 381c32c1..c33d40a4 100644 --- a/autograd/scipy/stats/poisson.py +++ b/autograd/scipy/stats/poisson.py @@ -2,8 +2,8 @@ import autograd.numpy as np import scipy.stats -from autograd.extend import primitive, defvjp -from autograd.numpy.numpy_vjps import unbroadcast_f +from autograd.extend import primitive +from autograd.numpy.util import def_ufunc_jps cdf = primitive(scipy.stats.poisson.cdf) logpmf = primitive(scipy.stats.poisson.logpmf) @@ -12,6 +12,6 @@ def grad_poisson_logpmf(k, mu): return np.where(k % 1 == 0, k / mu - 1, 0) -defvjp(cdf, lambda ans, k, mu: unbroadcast_f(mu, lambda g: g * -pmf(np.floor(k), mu)), argnums=[1]) -defvjp(logpmf, lambda ans, k, mu: unbroadcast_f(mu, lambda g: g * grad_poisson_logpmf(k, mu)), argnums=[1]) -defvjp(pmf, lambda ans, k, mu: unbroadcast_f(mu, lambda g: g * ans * grad_poisson_logpmf(k, mu)), argnums=[1]) +def_ufunc_jps(cdf, None, (lambda ans, k, mu: -pmf(np.floor(k), mu), 'mul')) +def_ufunc_jps(logpmf, None, (lambda ans, k, mu: grad_poisson_logpmf(k, mu), 'mul')) +def_ufunc_jps(pmf, None, (lambda ans, k, mu: ans * grad_poisson_logpmf(k, mu), 'mul')) diff --git a/autograd/scipy/stats/t.py b/autograd/scipy/stats/t.py index 66453aa5..763d9ac8 100644 --- a/autograd/scipy/stats/t.py +++ b/autograd/scipy/stats/t.py @@ -2,8 +2,8 @@ from __future__ import absolute_import import scipy.stats import autograd.numpy as np -from autograd.extend import primitive, defvjp -from autograd.numpy.numpy_vjps import unbroadcast_f +from autograd.extend import primitive +from autograd.numpy.util import def_ufunc_jps from autograd.scipy.special import psi pdf = 
primitive(scipy.stats.t.pdf) @@ -24,34 +24,24 @@ def grad_tlogpdf_df(x, df, loc, scale): y = (x - loc)/scale return 0.5 * ((y**2 * (df+1))/(df * (y**2 + df)) - np.log(y**2 / df + 1) - 1.0/df -psi(df/2.0) + psi((df + 1)/2.0)) -defvjp(pdf, lambda ans, x, df, loc=0.0, scale=1.0: - unbroadcast_f(x, lambda g: g * ans * grad_tlogpdf_x( x, df, loc, scale)), - lambda ans, x, df, loc=0.0, scale=1.0: - unbroadcast_f(df, lambda g: g * ans * grad_tlogpdf_df( x, df, loc, scale)), - lambda ans, x, df, loc=0.0, scale=1.0: - unbroadcast_f(loc, lambda g: g * ans * grad_tlogpdf_loc( x, df, loc, scale)), - lambda ans, x, df, loc=0.0, scale=1.0: - unbroadcast_f(scale, lambda g: g * ans * grad_tlogpdf_scale(x, df, loc, scale))) +def_ufunc_jps(pdf, + (lambda ans, x, df, loc=0.0, scale=1.0: ans * grad_tlogpdf_x( x, df, loc, scale), 'mul'), + (lambda ans, x, df, loc=0.0, scale=1.0: ans * grad_tlogpdf_df( x, df, loc, scale), 'mul'), + (lambda ans, x, df, loc=0.0, scale=1.0: ans * grad_tlogpdf_loc( x, df, loc, scale), 'mul'), + (lambda ans, x, df, loc=0.0, scale=1.0: ans * grad_tlogpdf_scale(x, df, loc, scale), 'mul')) -defvjp(cdf, - lambda ans, x, df, loc=0.0, scale=1.0: - unbroadcast_f(x, lambda g: g * pdf(x, df, loc, scale)), - lambda ans, x, df, loc=0.0, scale=1.0: - unbroadcast_f(loc, lambda g: -g * pdf(x, df, loc, scale)), argnums=(0,2)) +def_ufunc_jps(logpdf, + (lambda ans, x, df, loc=0.0, scale=1.0: grad_tlogpdf_x( x, df, loc, scale), 'mul'), + (lambda ans, x, df, loc=0.0, scale=1.0: grad_tlogpdf_df( x, df, loc, scale), 'mul'), + (lambda ans, x, df, loc=0.0, scale=1.0: grad_tlogpdf_loc( x, df, loc, scale), 'mul'), + (lambda ans, x, df, loc=0.0, scale=1.0: grad_tlogpdf_scale(x, df, loc, scale), 'mul')) -defvjp(logpdf, - lambda ans, x, df, loc=0.0, scale=1.0: - unbroadcast_f(x, lambda g: g * grad_tlogpdf_x( x, df, loc, scale)), - lambda ans, x, df, loc=0.0, scale=1.0: - unbroadcast_f(df, lambda g: g * grad_tlogpdf_df( x, df, loc, scale)), - lambda ans, x, df, loc=0.0, scale=1.0: - 
unbroadcast_f(loc, lambda g: g * grad_tlogpdf_loc( x, df, loc, scale)), - lambda ans, x, df, loc=0.0, scale=1.0: - unbroadcast_f(scale, lambda g: g * grad_tlogpdf_scale(x, df, loc, scale))) +def_ufunc_jps(cdf, + (lambda ans, x, df, loc=0.0, scale=1.0: pdf(x, df, loc, scale), 'mul'), + None, + (lambda ans, x, df, loc=0.0, scale=1.0: -pdf(x, df, loc, scale), 'mul')) -defvjp(logcdf, - lambda ans, x, df, loc=0.0, scale=1.0: - unbroadcast_f(x, lambda g: g * np.exp(logpdf(x, df, loc, scale) - logcdf(x, df, loc, scale))), - lambda ans, x, df, loc=0.0, scale=1.0: - unbroadcast_f(loc, lambda g: -g * np.exp(logpdf(x, df, loc, scale) - logcdf(x, df, loc, scale))), - argnums=(0,2)) +def_ufunc_jps(logcdf, + (lambda ans, x, df, loc=0.0, scale=1.0: np.exp(logpdf(x, df, loc, scale) - ans), 'mul'), + None, + (lambda ans, x, df, loc=0.0, scale=1.0: -np.exp(logpdf(x, df, loc, scale) - ans), 'mul')) diff --git a/benchmarks/bench_numpy_vjps.py b/benchmarks/bench_numpy_vjps.py index 89645093..0d1b4f80 100644 --- a/benchmarks/bench_numpy_vjps.py +++ b/benchmarks/bench_numpy_vjps.py @@ -81,3 +81,13 @@ def time_tensordot_1_1(): def time_tensordot_1_2(): tensordot_1_2(A, B, G) +C = np.random.randn(200, 200, 5, 4) +D = np.random.randn(1, 1, 5, 4) +add_0 = lambda C, D, G: make_vjp(np.add, argnum=0)(C, D)[0](G) +tanh_0 = lambda C, G: make_vjp(np.tanh, argnum=0)(C)[0](G) + +def time_add_0(): + add_0(C, D, C) + +def time_tanh_0(): + tanh_0(C, C) diff --git a/tests/test_scipy.py b/tests/test_scipy.py index cfa7fda5..16f0a069 100644 --- a/tests/test_scipy.py +++ b/tests/test_scipy.py @@ -21,7 +21,7 @@ from scipy.signal import convolve as sp_convolve from autograd.test_util import combo_check, check_grads - from numpy_utils import unary_ufunc_check + from numpy_utils import unary_ufunc_check, binary_ufunc_check npr.seed(1) R = npr.randn @@ -42,45 +42,45 @@ def symmetrized_fun(*args, **kwargs): return symmetrized_fun ### Stats ### - def test_chi2_pdf(): combo_check(stats.chi2.pdf, [0])([R(4)**2 + 
1.1], [1, 2, 3]) - def test_chi2_cdf(): combo_check(stats.chi2.cdf, [0])([R(4)**2 + 1.1], [1, 2, 3]) - def test_chi2_logpdf(): combo_check(stats.chi2.logpdf, [0])([R(4)**2 + 1.1], [1, 2, 3]) - - def test_beta_cdf(): combo_check(stats.beta.cdf, [0]) ([U(0., 1., 4)], [R(4)**2 + 1.1], [R(4)**2 + 1.1]) - def test_beta_pdf(): combo_check(stats.beta.pdf, [0,1,2])([U(0., 1., 4)], [R(4)**2 + 1.1], [R(4)**2 + 1.1]) - def test_beta_logpdf(): combo_check(stats.beta.logpdf, [0,1,2])([U(0., 1., 4)], [R(4)**2 + 1.1], [R(4)**2 + 1.1]) - - def test_gamma_cdf(): combo_check(stats.gamma.cdf, [0]) ([R(4)**2 + 1.1], [R(4)**2 + 1.1]) - def test_gamma_pdf(): combo_check(stats.gamma.pdf, [0,1])([R(4)**2 + 1.1], [R(4)**2 + 1.1]) - def test_gamma_logpdf(): combo_check(stats.gamma.logpdf, [0,1])([R(4)**2 + 1.1], [R(4)**2 + 1.1]) - - def test_norm_pdf(): combo_check(stats.norm.pdf, [0,1,2])([R(4)], [R(4)], [R(4)**2 + 1.1]) - def test_norm_cdf(): combo_check(stats.norm.cdf, [0,1,2])([R(4)], [R(4)], [R(4)**2 + 1.1]) - def test_norm_logpdf(): combo_check(stats.norm.logpdf, [0,1,2])([R(4)], [R(4)], [R(4)**2 + 1.1]) - def test_norm_logcdf(): combo_check(stats.norm.logcdf, [0,1,2])([R(4)], [R(4)], [R(4)**2 + 1.1]) - - def test_norm_pdf_broadcast(): combo_check(stats.norm.pdf, [0,1,2])([R(4,3)], [R(1,3)], [R(4,1)**2 + 1.1]) - def test_norm_cdf_broadcast(): combo_check(stats.norm.cdf, [0,1,2])([R(4,3)], [R(1,3)], [R(4,1)**2 + 1.1]) - def test_norm_logpdf_broadcast(): combo_check(stats.norm.logpdf, [0,1,2])([R(4,3)], [R(1,3)], [R(4,1)**2 + 1.1]) - def test_norm_logcdf_broadcast(): combo_check(stats.norm.logcdf, [0,1,2])([R(4,3)], [R(1,3)], [R(4,1)**2 + 1.1]) - - def test_poisson_cdf(): combo_check(stats.poisson.cdf, [1])([np.round(R(4)**2)], [R(4)**2 + 1.1]) - def test_poisson_logpmf(): combo_check(stats.poisson.logpmf, [1])([np.round(R(4)**2)], [R(4)**2 + 1.1]) - def test_poisson_pmf(): combo_check(stats.poisson.pmf, [1])([np.round(R(4)**2)], [R(4)**2 + 1.1]) - - def test_poisson_cdf_broadcast(): 
combo_check(stats.poisson.cdf, [1])([np.round(R(4, 3)**2)], [R(4, 1)**2 + 1.1]) - def test_poisson_logpmf_broadcast(): combo_check(stats.poisson.logpmf, [1])([np.round(R(4, 3)**2)], [R(4, 1)**2 + 1.1]) - def test_poisson_pmf_broadcast(): combo_check(stats.poisson.pmf, [1])([np.round(R(4, 3)**2)], [R(4, 1)**2 + 1.1]) - - def test_t_pdf(): combo_check(stats.t.pdf, [0,1,2,3])([R(4)], [R(4)**2 + 2.1], [R(4)], [R(4)**2 + 2.1]) - def test_t_cdf(): combo_check(stats.t.cdf, [0,2])( [R(4)], [R(4)**2 + 2.1], [R(4)], [R(4)**2 + 2.1]) - def test_t_logpdf(): combo_check(stats.t.logpdf, [0,1,2,3])([R(4)], [R(4)**2 + 2.1], [R(4)], [R(4)**2 + 2.1]) - def test_t_logcdf(): combo_check(stats.t.logcdf, [0,2])( [R(4)], [R(4)**2 + 2.1], [R(4)], [R(4)**2 + 2.1]) - - def test_t_pdf_broadcast(): combo_check(stats.t.pdf, [0,1,2,3])([R(4,3)], [R(1,3)**2 + 2.1], [R(4,3)], [R(4,1)**2 + 2.1]) - def test_t_cdf_broadcast(): combo_check(stats.t.cdf, [0,2])( [R(4,3)], [R(1,3)**2 + 2.1], [R(4,3)], [R(4,1)**2 + 2.1]) - def test_t_logpdf_broadcast(): combo_check(stats.t.logpdf, [0,1,2,3])([R(4,3)], [R(1,3)**2 + 2.1], [R(4,3)], [R(4,1)**2 + 2.1]) - def test_t_logcdf_broadcast(): combo_check(stats.t.logcdf, [0,2])( [R(4,3)], [R(1,3)**2 + 2.1], [R(4,3)], [R(4,1)**2 + 2.1]) + def test_chi2_pdf(): combo_check(stats.chi2.pdf, [0], modes=['fwd', 'rev'])([R(4)**2 + 1.1], [1, 2, 3]) + def test_chi2_cdf(): combo_check(stats.chi2.cdf, [0], modes=['fwd', 'rev'])([R(4)**2 + 1.1], [1, 2, 3]) + def test_chi2_logpdf(): combo_check(stats.chi2.logpdf, [0], modes=['fwd', 'rev'])([R(4)**2 + 1.1], [1, 2, 3]) + + def test_beta_cdf(): combo_check(stats.beta.cdf, [0], modes=['fwd', 'rev'])([U(0., 1., 4)], [R(4)**2 + 1.1], [R(4)**2 + 1.1]) + def test_beta_pdf(): combo_check(stats.beta.pdf, [0,1,2], modes=['fwd', 'rev'])([U(0., 1., 4)], [R(4)**2 + 1.1], [R(4)**2 + 1.1]) + def test_beta_logpdf(): combo_check(stats.beta.logpdf, [0,1,2], modes=['fwd', 'rev'])([U(0., 1., 4)], [R(4)**2 + 1.1], [R(4)**2 + 1.1]) + + def 
test_gamma_cdf(): combo_check(stats.gamma.cdf, [0], modes=['fwd', 'rev'])([R(4)**2 + 1.1], [R(4)**2 + 1.1]) + def test_gamma_pdf(): combo_check(stats.gamma.pdf, [0,1], modes=['fwd', 'rev'])([R(4)**2 + 1.1], [R(4)**2 + 1.1]) + def test_gamma_logpdf(): combo_check(stats.gamma.logpdf, [0,1], modes=['fwd', 'rev'])([R(4)**2 + 1.1], [R(4)**2 + 1.1]) + + def test_norm_pdf(): combo_check(stats.norm.pdf, [0,1,2], modes=['fwd', 'rev'])([R(4)], [R(4)], [R(4)**2 + 1.1]) + def test_norm_cdf(): combo_check(stats.norm.cdf, [0,1,2], modes=['fwd', 'rev'])([R(4)], [R(4)], [R(4)**2 + 1.1]) + def test_norm_logpdf(): combo_check(stats.norm.logpdf, [0,1,2], modes=['fwd', 'rev'])([R(4)], [R(4)], [R(4)**2 + 1.1]) + def test_norm_logcdf(): combo_check(stats.norm.logcdf, [0,1,2], modes=['fwd', 'rev'])([R(4)], [R(4)], [R(4)**2 + 1.1]) + + def test_norm_pdf_broadcast(): combo_check(stats.norm.pdf, [0,1,2], modes=['fwd', 'rev'])([R(4,3)], [R(1,3)], [R(4,1)**2 + 1.1]) + def test_norm_cdf_broadcast(): combo_check(stats.norm.cdf, [0,1,2], modes=['fwd', 'rev'])([R(4,3)], [R(1,3)], [R(4,1)**2 + 1.1]) + def test_norm_logpdf_broadcast(): combo_check(stats.norm.logpdf, [0,1,2], modes=['fwd', 'rev'])([R(4,3)], [R(1,3)], [R(4,1)**2 + 1.1]) + def test_norm_logcdf_broadcast(): combo_check(stats.norm.logcdf, [0,1,2], modes=['fwd', 'rev'])([R(4,3)], [R(1,3)], [R(4,1)**2 + 1.1]) + + def test_poisson_cdf(): combo_check(stats.poisson.cdf, [1], modes=['fwd', 'rev'])([np.round(R(4)**2)], [R(4)**2 + 1.1]) + def test_poisson_logpmf(): combo_check(stats.poisson.logpmf, [1], modes=['fwd', 'rev'])([np.round(R(4)**2)], [R(4)**2 + 1.1]) + def test_poisson_pmf(): combo_check(stats.poisson.pmf, [1], modes=['fwd', 'rev'])([np.round(R(4)**2)], [R(4)**2 + 1.1]) + + def test_poisson_cdf_broadcast(): combo_check(stats.poisson.cdf, [1], modes=['fwd', 'rev'])([np.round(R(4, 3)**2)], [R(4, 1)**2 + 1.1]) + def test_poisson_logpmf_broadcast(): combo_check(stats.poisson.logpmf, [1], modes=['fwd', 'rev'])([np.round(R(4, 3)**2)], 
[R(4, 1)**2 + 1.1]) + def test_poisson_pmf_broadcast(): combo_check(stats.poisson.pmf, [1], modes=['fwd', 'rev'])([np.round(R(4, 3)**2)], [R(4, 1)**2 + 1.1]) + + def test_t_pdf(): combo_check(stats.t.pdf, [0,1,2,3], modes=['fwd', 'rev'])([R(4)], [R(4)**2 + 2.1], [R(4)], [R(4)**2 + 2.1]) + def test_t_cdf(): combo_check(stats.t.cdf, [0,2], modes=['fwd', 'rev'])( [R(4)], [R(4)**2 + 2.1], [R(4)], [R(4)**2 + 2.1]) + def test_t_logpdf(): combo_check(stats.t.logpdf, [0,1,2,3], modes=['fwd', 'rev'])([R(4)], [R(4)**2 + 2.1], [R(4)], [R(4)**2 + 2.1]) + def test_t_logcdf(): combo_check(stats.t.logcdf, [0,2], modes=['fwd', 'rev'])( [R(4)], [R(4)**2 + 2.1], [R(4)], [R(4)**2 + 2.1]) + + def test_t_pdf_broadcast(): combo_check(stats.t.pdf, [0,1,2,3], modes=['fwd', 'rev'])([R(4,3)], [R(1,3)**2 + 2.1], [R(4,3)], [R(4,1)**2 + 2.1]) + def test_t_cdf_broadcast(): combo_check(stats.t.cdf, [0,2], modes=['fwd', 'rev'])( [R(4,3)], [R(1,3)**2 + 2.1], [R(4,3)], [R(4,1)**2 + 2.1]) + def test_t_logpdf_broadcast(): combo_check(stats.t.logpdf, [0,1,2,3], modes=['fwd', 'rev'])([R(4,3)], [R(1,3)**2 + 2.1], [R(4,3)], [R(4,1)**2 + 2.1]) + def test_t_logcdf_broadcast(): combo_check(stats.t.logcdf, [0,2], modes=['fwd', 'rev'])( [R(4,3)], [R(1,3)**2 + 2.1], [R(4,3)], [R(4,1)**2 + 2.1]) def make_psd(mat): return np.dot(mat.T, mat) + np.eye(mat.shape[0]) def test_mvn_pdf(): combo_check(symmetrize_matrix_arg(mvn.pdf, 2), [0, 1, 2], [R(4)], [R(4)], [make_psd(R(4, 4))], allow_singular=[False]) @@ -138,12 +138,12 @@ def test_convolve_generalization(): assert npo.allclose(ag_convolve(A_35, A_342, axes=([1],[2]), dot_axes=([0], [0]), mode=mode)[2], sum([sp_convolve(A_35[i, :], A_342[i, 2, :], mode) - for i in range(3)])) + for i in range(3)])) assert npo.allclose(ag_convolve(A_2543, A_24232, axes=([1, 2],[2, 4]), dot_axes=([0, 3], [0, 3]), mode=mode)[2], sum([sum([sp_convolve(A_2543[i, :, :, j], - A_24232[i, 2, :, j, :], mode) - for i in range(2)]) for j in range(3)])) + A_24232[i, 2, :, j, :], mode) + for i 
in range(2)]) for j in range(3)])) def test_convolve(): combo_check(autograd.scipy.signal.convolve, [0,1])( @@ -165,35 +165,37 @@ def test_convolve_ignore_dot(): axes=[([1],[1])], dot_axes=[([0],[2]), ([0],[0])], mode=['full', 'valid']) ### Special ### - def test_beta(): combo_check(special.beta, [0,1])([R(4)**2 + 1.1], [R(4)**2 + 1.1]) - def test_betainc(): combo_check(special.betainc, [2]) ([R(4)**2 + 1.1], [R(4)**2 + 1.1], [U(0., 1., 4)]) - def test_betaln(): combo_check(special.betaln, [0,1])([R(4)**2 + 1.1], [R(4)**2 + 1.1]) - - def test_gammainc(): combo_check(special.gammainc, [1])([1], R(4)**2 + 1.3) - def test_gammaincc(): combo_check(special.gammaincc, [1])([1], R(4)**2 + 1.3) - def test_polygamma(): combo_check(special.polygamma, [1])([0], R(4)**2 + 1.3) - def test_jn(): combo_check(special.jn, [1])([2], R(4)**2 + 1.3) - def test_yn(): combo_check(special.yn, [1])([2], R(4)**2 + 1.3) - - def test_psi(): unary_ufunc_check(special.psi, lims=[0.3, 2.0], test_complex=False) - def test_digamma(): unary_ufunc_check(special.digamma, lims=[0.3, 2.0], test_complex=False) - def test_gamma(): unary_ufunc_check(special.gamma, lims=[0.3, 2.0], test_complex=False) - def test_gammaln(): unary_ufunc_check(special.gammaln, lims=[0.3, 2.0], test_complex=False) - def test_gammasgn(): unary_ufunc_check(special.gammasgn,lims=[0.3, 2.0], test_complex=False) - def test_rgamma() : unary_ufunc_check(special.rgamma, lims=[0.3, 2.0], test_complex=False) - def test_multigammaln(): combo_check(special.multigammaln, [0])([U(4., 5.), U(4., 5., (2,3))], - [1, 2, 3]) - - def test_j0(): unary_ufunc_check(special.j0, lims=[0.2, 20.0], test_complex=False) - def test_j1(): unary_ufunc_check(special.j1, lims=[0.2, 20.0], test_complex=False) - def test_y0(): unary_ufunc_check(special.y0, lims=[0.2, 20.0], test_complex=False) - def test_y1(): unary_ufunc_check(special.y1, lims=[0.2, 20.0], test_complex=False) - - def test_erf(): unary_ufunc_check(special.erf, lims=[-3., 3.], test_complex=True) 
- def test_erfc(): unary_ufunc_check(special.erfc, lims=[-3., 3.], test_complex=True) - - def test_erfinv(): unary_ufunc_check(special.erfinv, lims=[-0.95, 0.95], test_complex=False) - def test_erfcinv(): unary_ufunc_check(special.erfcinv, lims=[0.05, 1.95], test_complex=False) - - def test_logit(): unary_ufunc_check(special.logit, lims=[ 0.10, 0.90], test_complex=False) - def test_expit(): unary_ufunc_check(special.expit, lims=[-4.05, 4.95], test_complex=False) + def test_beta(): combo_check(special.beta, [0,1], modes=['fwd', 'rev'])([R(4)**2 + 1.1], [R(4)**2 + 1.1]) + def test_betainc(): combo_check(special.betainc, [2] , modes=['fwd', 'rev'])([R(4)**2 + 1.1], [R(4)**2 + 1.1], [U(0., 1., 4)]) + def test_betaln(): combo_check(special.betaln, [0,1], modes=['fwd', 'rev'])([R(4)**2 + 1.1], [R(4)**2 + 1.1]) + + def test_gammainc(): combo_check(special.gammainc, [1], modes=['fwd', 'rev'])([1], R(4)**2 + 1.3) + def test_gammaincc(): combo_check(special.gammaincc, [1], modes=['fwd', 'rev'])([1], R(4)**2 + 1.3) + def test_polygamma(): combo_check(special.polygamma, [1], modes=['fwd', 'rev'])([0], R(4)**2 + 1.3) + def test_jn(): combo_check(special.jn, [1], modes=['fwd', 'rev'])([2], R(4)**2 + 1.3) + def test_yn(): combo_check(special.yn, [1], modes=['fwd', 'rev'])([2], R(4)**2 + 1.3) + + def test_psi(): unary_ufunc_check(special.psi, lims=[0.3, 2.0], test_complex=False, modes=['fwd', 'rev']) + def test_digamma(): unary_ufunc_check(special.digamma, lims=[0.3, 2.0], test_complex=False, modes=['fwd', 'rev']) + def test_gamma(): unary_ufunc_check(special.gamma, lims=[0.3, 2.0], test_complex=False, modes=['fwd', 'rev']) + def test_gammaln(): unary_ufunc_check(special.gammaln, lims=[0.3, 2.0], test_complex=False, modes=['fwd', 'rev']) + def test_gammasgn(): unary_ufunc_check(special.gammasgn,lims=[0.3, 2.0], test_complex=False, modes=['fwd', 'rev']) + def test_rgamma() : unary_ufunc_check(special.rgamma, lims=[0.3, 2.0], test_complex=False, modes=['fwd', 'rev']) + def 
test_multigammaln(): combo_check(special.multigammaln, [0], modes=['fwd', 'rev'])([U(4., 5.), U(4., 5., (2,3))], + [1, 2, 3]) + + def test_j0(): unary_ufunc_check(special.j0, lims=[0.2, 20.0], test_complex=False, modes=['fwd', 'rev']) + def test_j1(): unary_ufunc_check(special.j1, lims=[0.2, 20.0], test_complex=False, modes=['fwd', 'rev']) + def test_y0(): unary_ufunc_check(special.y0, lims=[0.2, 20.0], test_complex=False, modes=['fwd', 'rev']) + def test_y1(): unary_ufunc_check(special.y1, lims=[0.2, 20.0], test_complex=False, modes=['fwd', 'rev']) + + def test_erf(): unary_ufunc_check(special.erf, lims=[-3., 3.], test_complex=True, modes=['fwd', 'rev']) + def test_erfc(): unary_ufunc_check(special.erfc, lims=[-3., 3.], test_complex=True, modes=['fwd', 'rev']) + + def test_erfinv(): unary_ufunc_check(special.erfinv, lims=[-0.95, 0.95], test_complex=False, modes=['fwd', 'rev']) + def test_erfcinv(): unary_ufunc_check(special.erfcinv, lims=[0.05, 1.95], test_complex=False, modes=['fwd', 'rev']) + + def test_logit(): unary_ufunc_check(special.logit, lims=[0.05, 0.95], test_complex=False, modes=['fwd', 'rev']) + def test_expit(): unary_ufunc_check(special.expit, lims=[-4.05, 4.95], test_complex=False, modes=['fwd', 'rev']) + + def test_rel_entr(): binary_ufunc_check(special.rel_entr, lims_A=[0.05, 1], lims_B=[0.05, 1], test_complex=False, modes=['fwd', 'rev'])