
jax `odeint` fails against scipy `odeint`

It looks like jax.experimental.ode.odeint fails to match the result from scipy.integrate.odeint.

import jax
import jax.numpy as np
from jax.experimental.ode import odeint
import scipy.integrate as osp_integrate

def benchmark_odeint():
    def simple(y, t):
        return y
    
    tspace = np.array((0., 15.))
    
    y0 = np.ones(1)
    
    with jax.disable_jit():
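        # The exact solution of dy/dt = y with y(0) = 1 is exp(t); the `1 +` shifts
        # this "analytic" value up by 1, which is small next to the gap reported below.
        analytic_result = 1 + np.exp(tspace[-1])
        scipy_result = osp_integrate.odeint(simple, y0, tspace)[-1][0]
        jax_result = odeint(simple, np.array(y0), np.array(tspace))[-1][0]
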
        print("Analytic\t\t\t{:.8f}".format(analytic_result))
        print("Scipy\t\t\t\t{:.8f}".format(scipy_result))
        print("JAX\t\t\t\t{:.8f}".format(jax_result))
        print("abs(scipy - analytic)\t\t{:.8f}".format(np.abs(analytic_result - scipy_result)))
        print("abs(scipy - jax)\t\t{:.8f}".format(np.abs(scipy_result - jax_result)))
        print("abs(analytic - jax)\t\t{:.8f}".format(np.abs(analytic_result - jax_result)))

benchmark_odeint()

Result:

Analytic			3269018.50000000
Scipy				3269018.46060479
JAX				3268983.00000000
abs(scipy - analytic)		0.00000000
abs(scipy - jax)		35.50000000
abs(analytic - jax)		35.50000000
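A quick back-of-the-envelope check puts the 35.5 gap in relative terms. This NumPy sketch reuses the two printed values and takes exp(15) as the reference, since the exact solution of dy/dt = y with y(0) = 1 is e^t (the script's `1 + np.exp(...)` is off by one). The only assumption is that the printed digits are the values the solvers returned.

```python
import numpy as onp

# Values copied from the printed output above.
scipy_result = 3269018.46060479
jax_result = 3268983.0

# Exact solution of dy/dt = y with y(0) = 1, evaluated at t = 15.
exact = onp.exp(15.0)

eps32 = float(onp.finfo(onp.float32).eps)  # float32 machine epsilon (~1.19e-7)
rel_err_scipy = abs(scipy_result - exact) / exact
rel_err_jax = abs(jax_result - exact) / exact

print(f"exact e^15           {exact:.8f}")
print(f"rel. error (scipy)   {rel_err_scipy:.2e}")
print(f"rel. error (JAX)     {rel_err_jax:.2e}  = {rel_err_jax / eps32:.0f} x float32 eps")
# Distance between adjacent float32 numbers near 3.27e6.
print(f"float32 spacing      {onp.spacing(onp.float32(exact))}")
```

The JAX error comes out at roughly 1e-5, a few dozen float32 epsilons, while SciPy (float64 throughout) is around 3e-7. That pattern is consistent with float32 rounding error accumulating over many solver steps, though this check alone does not rule out other causes.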

Here is a link to a colab notebook to reproduce these results.

The tests in ode.py only check the `pend` system; the same comparison fails for the `swoop` system when t is of the same order as in the example above, and the `decay` system hangs for larger t.
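One way to cover more than one system is a small harness that compares both solvers on an arbitrary right-hand side at every requested time. This is a sketch under stated assumptions: `compare_with_scipy`, `growth`, and `exp_decay` are hypothetical names (they are not the `swoop` or `decay` systems from ode.py), the tolerances are arbitrary choices, and 64-bit mode is enabled so float32 rounding does not dominate the comparison.

```python
import jax

# Compare in float64; set this before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import numpy as onp
import scipy.integrate as osp_integrate
from jax.experimental.ode import odeint


def compare_with_scipy(f, y0, ts, rtol=1e-5, atol=1e-6):
    """Hypothetical helper: True if JAX and SciPy odeint agree at every time in ts."""
    ys_scipy = osp_integrate.odeint(f, onp.asarray(y0), onp.asarray(ts))
    ys_jax = odeint(f, jnp.asarray(y0), jnp.asarray(ts))
    return bool(onp.allclose(onp.asarray(ys_jax), ys_scipy, rtol=rtol, atol=atol))


# Illustrative right-hand sides with signature f(y, t), usable by both solvers.
def growth(y, t):
    return y  # exact solution exp(t)


def exp_decay(y, t):
    return -y  # exact solution exp(-t)


ts = onp.linspace(0.0, 15.0, 16)
for name, f in [("growth", growth), ("exp_decay", exp_decay)]:
    print(name, compare_with_scipy(f, onp.ones(1), ts))
```

Using both `rtol` and `atol` matters here: for `exp_decay` the solution shrinks toward zero, where a purely relative check would be dominated by each solver's absolute tolerance.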

google/jax

Answer from shoyer

Note that JAX defaults to float32 precision instead of float64: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#Double-(64bit)-precision

Does that explain the difference?
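To test this suggestion, here is a minimal sketch that turns on 64-bit mode and reruns the same comparison. It assumes a recent JAX release where `jax.config.update("jax_enable_x64", True)` is available (older releases exposed the same flag through `from jax.config import config`), and the flag has to be set before any JAX arrays are created.

```python
import jax

# Enable 64-bit floats; set this before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import numpy as onp
import scipy.integrate as osp_integrate
from jax.experimental.ode import odeint


def simple(y, t):
    # dy/dt = y; with y(0) = 1 the exact solution is exp(t).
    return y


tspace = jnp.array([0.0, 15.0])
y0 = jnp.ones(1)  # float64 now that x64 is enabled
exact = onp.exp(15.0)

scipy_result = osp_integrate.odeint(simple, onp.asarray(y0), onp.asarray(tspace))[-1, 0]
jax_result = float(odeint(simple, y0, tspace)[-1, 0])

print("dtype               ", y0.dtype)
print("rel. error (scipy)  ", abs(scipy_result - exact) / exact)
print("rel. error (JAX)    ", abs(jax_result - exact) / exact)
```

Setting the environment variable `JAX_ENABLE_X64=True` before starting Python has the same effect as the `jax.config.update` call.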

