jax `odeint` fails against scipy `odeint`

It looks like `jax.experimental.ode.odeint` fails to match the result from `scipy.integrate.odeint`.

```python
import jax
import jax.numpy as np
from jax.experimental.ode import odeint
import scipy.integrate as osp_integrate


def benchmark_odeint():
    # Right-hand side of dy/dt = y.
    def simple(y, t):
        return y

    tspace = np.array((0., 15.))  # integrate from t = 0 to t = 15
    y0 = np.ones(1)               # initial condition y(0) = 1

    with jax.disable_jit():
        # Kept as posted. Note that the exact solution of dy/dt = y with
        # y(0) = 1 is exp(t), without the leading 1.
        analytic_result = 1 + np.exp(tspace[-1])
        scipy_result = osp_integrate.odeint(simple, y0, tspace)[-1][0]
        jax_result = odeint(simple, np.array(y0), np.array(tspace))[-1][0]

    print("Analytic\t\t\t{:.8f}".format(analytic_result))
    print("Scipy\t\t\t\t{:.8f}".format(scipy_result))
    print("JAX\t\t\t\t{:.8f}".format(jax_result))
    print("abs(scipy - analytic)\t\t{:.8f}".format(np.abs(analytic_result - scipy_result)))
    print("abs(scipy - jax)\t\t{:.8f}".format(np.abs(scipy_result - jax_result)))
    print("abs(analytic - jax)\t\t{:.8f}".format(np.abs(analytic_result - jax_result)))


benchmark_odeint()
```

Result:

```
Analytic			3269018.50000000
Scipy				3269018.46060479
JAX				3268983.00000000
abs(scipy - analytic)		0.00000000
abs(scipy - jax)		35.50000000
abs(analytic - jax)		35.50000000
```
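To put the absolute differences in context, the numbers above can be converted to relative errors and compared with the spacing of 32-bit floats at this magnitude. This is only a rough sketch: the helper names `relative_error` and `float32_spacing` are made up for illustration, and treating single precision as relevant is an assumption based on JAX defaulting to 32-bit floats unless `jax_enable_x64` is set.

```python
import math

# Values copied from the printed output above.
analytic = 3269018.5
scipy_result = 3269018.46060479
jax_result = 3268983.0


def relative_error(value, reference):
    """Absolute difference scaled by the magnitude of the reference."""
    return abs(value - reference) / abs(reference)


def float32_spacing(x):
    """Gap between adjacent float32 values near x (normal range only)."""
    _, exponent = math.frexp(x)      # x = m * 2**exponent with 0.5 <= |m| < 1
    return 2.0 ** (exponent - 24)    # float32 carries a 24-bit significand


jax_rel = relative_error(jax_result, analytic)
scipy_rel = relative_error(scipy_result, analytic)
spacing = float32_spacing(analytic)

print("relative error, JAX vs analytic:   {:.3e}".format(jax_rel))
print("relative error, scipy vs analytic: {:.3e}".format(scipy_rel))
print("float32 spacing near the result:   {}".format(spacing))
print("JAX error in float32 spacings:     {}".format(abs(jax_result - analytic) / spacing))
```

Under these assumptions the JAX result is off by roughly one part in 10^5, while the scipy difference is smaller than the float32 spacing at this magnitude, which would be consistent with `abs(scipy - analytic)` printing as exactly zero.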

Here is a link to a colab notebook to reproduce these results.

The tests in `ode.py` only cover the `pend` system. The same check fails for the `swoop` system with `t` on the same order as in the example above, and `decay` hangs for larger `t`.
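One way to extend coverage beyond `pend` might be a check against a closed-form solution. The sketch below is not the code in `ode.py`: it reuses dy/dt = y from the example above, enables 64-bit floats, passes explicit `rtol` and `atol` to `odeint`, and compares against exp(t). The names `growth` and `relative_error_at` and the chosen tolerances are hypothetical, and whether precision explains the mismatch reported here is an assumption, not something the report confirms.

```python
import jax

# Assumption: run in 64-bit precision. JAX uses 32-bit floats by default,
# and this flag should be set before any arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax.experimental.ode import odeint


def growth(y, t):
    # Right-hand side of dy/dt = y; the exact solution for y(0) = 1 is exp(t).
    return y


def relative_error_at(t_final, rtol=1e-10, atol=1e-10):
    """Relative error of odeint against exp(t_final) for y(0) = 1."""
    ts = jnp.array([0.0, t_final])
    y0 = jnp.ones(1)
    y_final = odeint(growth, y0, ts, rtol=rtol, atol=atol)[-1, 0]
    exact = jnp.exp(t_final)
    return float(jnp.abs(y_final - exact) / exact)


print("relative error at t = 15: {:.3e}".format(relative_error_at(15.0)))
```

Comparing relative rather than absolute error keeps the check meaningful when the solution grows to the order of 10^6, as it does here.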