diff --git a/benchmarks/dsolve.py b/benchmarks/dsolve.py index 575241b..0f10706 100644 --- a/benchmarks/dsolve.py +++ b/benchmarks/dsolve.py @@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- import sympy -from sympy import symbols, dsolve, Eq, Function, exp +from sympy import S, symbols, dsolve, Eq, Function, exp +from sympy.solvers.ode.riccati import match_riccati, solve_riccati def _make_ode_01(): # du/dt = -k0*u @@ -16,11 +17,52 @@ def _make_ode_01(): return Eq(v(t).diff(t), dvdt_), params +def _make_riccati_particular(): + # Particular solution solver for the Riccati ODE - + # f'(x) = b_0 + b_1*f(x) + b_2*f(x)**2 + # where b_0, b_1, b_2 are rational functions of x + f = Function('f') + x = symbols('x') + + eq = Eq(f(x).diff(x), x*f(x) + 2*x + (3*x - \ + 2)*f(x)**2/(4*x + 2) + (8*x**2 - 7*x + 26)/(\ + 16*x**3 - 24*x**2 + 8) - S(3)/2) + + # Check if equation matches and get b0, b1, b2 + _, (b0, b1, b2) = match_riccati(eq, f, x) + + return (f(x), x, b0, b1, b2) + + +def _make_riccati_general(): + # General solution solver for the Riccati ODE - + # f'(x) = b_0 + b_1*f(x) + b_2*f(x)**2 + # where b_0, b_1, b_2 are rational functions of x + f = Function('f') + x = symbols('x') + + eq = f(x).diff(x) + (3*x**2 + 1)*f(x)**2/x + (6*x**2 \ + - x + 3)*f(x)/(x*(x - 1)) + (3*x**2 - 2*x + 2)/(x*(x \ + - 1)**2) + + hint = "1st_rational_riccati" + + return eq, f(x), hint + + class TimeDsolve01: def setup(self): self.ode, self.params = _make_ode_01() + self.geneq, self.func, self.hint = _make_riccati_general() + self.parteq, self.args = _make_riccati_particular() def time_dsolve(self): t, y, y0, k = self.params dsolve(self.ode, y[1](t)) + + def time_riccati_partsol(self): + sols = solve_riccati(*self.args) + + def time_riccati_gensol(self): + dsolve(self.eq, self.func, hint=self.hint) diff --git a/benchmarks/tests/test_dsolve.py b/benchmarks/tests/test_dsolve.py index a57ecc3..545ea79 100644 --- a/benchmarks/tests/test_dsolve.py +++ b/benchmarks/tests/test_dsolve.py @@ -1,9 +1,10 @@ from __future__ import absolute_import import sympy -from sympy import dsolve, Ne, exp, refine +from sympy import dsolve, Ne, Eq, exp, refine, Symbol +from sympy.solvers.ode.riccati import solve_riccati -from benchmarks.dsolve import _make_ode_01 +from benchmarks.dsolve import _make_ode_01, _make_riccati_particular, _make_riccati_general def test_make_ode_01(): ode, params = _make_ode_01() @@ -15,3 +16,18 @@ def test_make_ode_01(): int_const = [fs for fs in refined.free_symbols if fs not in ignore][0] ref = int_const*exp(-k[1]*t) - exp(-k[0]*t)*k[0]*y0[0]/(k[0] - k[1]) assert (refined.rhs - ref).simplify() == 0 + + +def test_riccati_particular(): + fx, x, b0, b1, b2 = _make_riccati_particular() + sol = solve_riccati(fx, x, b0, b1, b2) + assert sol == [Eq(fx, (1 - 4*x)/(2*x - 2))] + + +def test_riccati_general(): + eq, fx, hint = _make_riccati_general() + x = list(fx.atoms(Symbol))[0] + gensol = dsolve(eq, hint=hint) + C1 = Symbol('C1') + assert gensol == Eq(fx, (-C1 - x**3 + x**2 - \ + 2*x + 1)/(C1*x - C1 + x**4 - x**3 + x**2 - 2*x + 1))