
Custom VJPs for external functions

Hi! I want to define custom gradients for a simulation in order to do sensitivity analysis. I have been using autograd for this, but since it is no longer actively developed I want to switch to jax. In autograd I would write something like this:

from autograd import grad
from autograd.extend import primitive, defvjp
import simulation  # external, non-differentiable simulation code

@primitive
def sim(params):
    results = simulation.run(params)
    return results

def sim_vjp(ans, params):
    def vjp(g):
        # custom gradient code goes here; `gradient` stands in for
        # the cotangent with respect to `params`
        return gradient
    return vjp

defvjp(sim, sim_vjp)

In autograd, this worked fine and I was able to chain this together with some other differentiable transformations and get gradients out of the whole thing. From what I was able to gather, the above would be written in jax as follows:

import jax
import simulation

@jax.custom_transforms
def sim(params):
    results = simulation.run(params)
    return results

def sim_vjp(ans, params):
    def vjp(g):
        # custom gradient code goes here
        return gradient
    return vjp

jax.defvjp_all(sim, sim_vjp)

However, this throws "Exception: Tracer can't be used with raw numpy functions.", which I assume is because the simulation code does not use jax. Are custom gradients in jax no longer black boxes, as they were in autograd? That is, is this a fundamental limitation, or have I screwed something up? Do I need to implement this using lax primitives, and if so, how?

I would be grateful for a minimal example implementing this for an arbitrary non-jax function. The following code, for example, works in autograd:

from autograd import grad
from autograd.extend import primitive, defvjp
from scipy.ndimage import gaussian_filter

@primitive
def filter(img):
    return gaussian_filter(img, 1)

def filter_vjp(ans, img):
    def vjp(g):
        # The Gaussian kernel is symmetric, so (up to boundary handling)
        # the filter is its own adjoint: apply it to the cotangent.
        return gaussian_filter(g, 1)
    return vjp

defvjp(filter, filter_vjp)

How would one translate this so it works in jax? Thanks so much!
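Update: a sketch of one possible translation, assuming a newer jax that provides jax.custom_vjp and jax.pure_callback (both added well after custom_transforms, so verify against the version you have installed). pure_callback lets traced code call back into ordinary Python, which is what keeps gaussian_filter a black box under jax's tracing:

import jax
import jax.numpy as jnp
from scipy.ndimage import gaussian_filter

def _filter_callback(img):
    # Runs outside of tracing, so plain NumPy/SciPy is fine here.
    return gaussian_filter(img, 1)

@jax.custom_vjp
def filter(img):
    # pure_callback needs the result shape/dtype declared up front.
    out = jax.ShapeDtypeStruct(img.shape, img.dtype)
    return jax.pure_callback(_filter_callback, out, img)

def filter_fwd(img):
    # No residuals to save: the backward pass only needs the cotangent.
    return filter(img), None

def filter_bwd(_, g):
    # The Gaussian kernel is symmetric, so (up to boundary handling)
    # the filter is its own adjoint; return one cotangent per input.
    out = jax.ShapeDtypeStruct(g.shape, g.dtype)
    return (jax.pure_callback(_filter_callback, out, g),)

filter.defvjp(filter_fwd, filter_bwd)

grads = jax.grad(lambda x: filter(x).sum())(jnp.ones((4, 4)))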


Answer from tpr0p:

I'm having the same issue. I cannot get the example code from the defvjp_all documentation to work.

System information: OS: Linux (Ubuntu 18.04), Python: 3.7.2

Build information: I built from source following the instructions in the README.

git clone https://github.com/google/jax
cd jax
python build/build.py
pip install -e build  # install jaxlib (includes XLA)
pip install -e .      # install jax

Minimal code to reproduce (jax_test.py):

import jax
import numpy as np

@jax.custom_transforms
def f(x):
    return np.square(x)

def f_vjp(x):
    return f(x), lambda g: 2 * x * g

jax.defvjp_all(f, f_vjp)

def main():
    jax.grad(f, 0)(1.)

if __name__ == "__main__":
    main()

Stack trace:

Traceback (most recent call last):
  File "jax_test.py", line 23, in <module>
    main()
  File "jax_test.py", line 19, in main
    jax.grad(f, 0)(1.)
  File "/home/tcpropson/repos/jax/jax/api.py", line 341, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
  File "/home/tcpropson/repos/jax/jax/api.py", line 387, in value_and_grad_f
    ans, vjp_py = vjp(f_partial, *dyn_args)
  File "/home/tcpropson/repos/jax/jax/api.py", line 1002, in vjp
    out_primal, out_vjp = ad.vjp(jaxtree_fun, primals_flat)
  File "/home/tcpropson/repos/jax/jax/interpreters/ad.py", line 105, in vjp
    out_primal, pval, jaxpr, consts = linearize(traceable, *primals)
  File "/home/tcpropson/repos/jax/jax/interpreters/ad.py", line 94, in linearize
    jaxpr, out_pval, consts = pe.trace_to_jaxpr(jvpfun, in_pvals)
  File "/home/tcpropson/repos/jax/jax/interpreters/partial_eval.py", line 400, in trace_to_jaxpr
    jaxpr, (out_pval, consts, env) = fun.call_wrapped(pvals)
  File "/home/tcpropson/repos/jax/jax/linear_util.py", line 149, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/tcpropson/repos/jax/jax/api.py", line 1175, in __call__
    jaxpr, _, consts = pe.trace_to_jaxpr(jaxtree_fun, pvals_in, instantiate=True)
  File "/home/tcpropson/repos/jax/jax/interpreters/partial_eval.py", line 400, in trace_to_jaxpr
    jaxpr, (out_pval, consts, env) = fun.call_wrapped(pvals)
  File "/home/tcpropson/repos/jax/jax/linear_util.py", line 149, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "jax_test.py", line 10, in f
    return np.square(x)
  File "/home/tcpropson/repos/jax/jax/core.py", line 287, in __array__
    raise Exception("Tracer can't be used with raw numpy functions. "
Exception: Tracer can't be used with raw numpy functions. You might have
  import numpy as np
instead of
  import jax.numpy as np
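As the exception text says, grad still traces the primal f itself (the trace above bottoms out in pe.trace_to_jaxpr calling f), so the body of f has to use jax.numpy; defvjp_all does not turn the primal into a black box. A minimal sketch of the repro with that one change, following the tuple-of-cotangents convention from the defvjp_all docstring (worth verifying against your version):

import jax
import jax.numpy as jnp

@jax.custom_transforms
def f(x):
    return jnp.square(x)  # jax.numpy, so tracing f succeeds

def f_vjp(x):
    # defvjp_all expects (primal_out, vjp), where the vjp returns
    # a tuple of cotangents, one per positional argument.
    return f(x), lambda g: (2. * x * g,)

jax.defvjp_all(f, f_vjp)

def main():
    print(jax.grad(f, 0)(1.))  # expect 2.0

if __name__ == "__main__":
    main()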
