
Binary search function in lax_control_flow_test.py doesn't work as anticipated

I'm trying to find the roots of a function differentiably using lax.custom_root. As an initial test, I've been using the binary_search routine defined in lax_control_flow_test.py.
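For readers less familiar with lax.custom_root, here is a minimal sketch of the pattern, under a few assumptions. The function gets a hypothetical extra parameter c (c = 0.1 recovers the function quoted below). Newton's method stands in for the test file's binary_search as the solve callable. The tangent solve is the scalar division y / g(1.0). The names newton_solve, scalar_tangent_solve and find_root are illustrative and not part of JAX. Newton's method is only reliable here when started near the root, because tanh is nearly flat in its tails.

```python
import jax
import jax.numpy as jnp
from jax import lax


def newton_solve(func, x0, num_iters=20):
    # Newton's method with a fixed number of steps. custom_root does not
    # differentiate through this loop, so the iteration only has to land
    # on the root; it does not need to be differentiable itself.
    def step(_, x):
        return x - func(x) / jax.grad(func)(x)
    return lax.fori_loop(0, num_iters, step, x0)


def scalar_tangent_solve(g, y):
    # g is the linearization of f at the root, g(t) = f'(x*) * t for a
    # scalar problem, so solving g(t) = y is a single division.
    return y / g(1.0)


def find_root(c, x0=3.0):
    # Hypothetical parameterized version of the function in the question;
    # c = 0.1 gives f(x) = -(1 + tanh(x - 4)) / 2 + 0.1.
    f = lambda x: -(1.0 + jnp.tanh(x - 4.0)) / 2.0 + c
    # An explicit float32 initial guess keeps the loop carry dtype stable.
    x0 = jnp.asarray(x0, dtype=jnp.float32)
    return lax.custom_root(f, x0, newton_solve, scalar_tangent_solve)


root = find_root(0.1)                # about 2.9014 (exactly 4 - log(3))
droot_dc = jax.grad(find_root)(0.1)  # about 5.556 (= 1 / 0.18)
```

Because custom_root differentiates the root through the implicit function theorem rather than through the solver's iterations, the gradient should match the analytic value dx*/dc = 2 / (1 - (2c - 1)^2), which is 1 / 0.18 ≈ 5.556 at c = 0.1.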

The function is very similar in shape to this one: f = lambda x: -(1+jnp.tanh(x-4))/2 + 0.1, which has a root at about 2.9. No matter which starting bracket I use, this routine doesn't find the root of f. As a cross-check, scipy.optimize.fsolve does find the root from most reasonable starting values.
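One possible explanation, offered as an assumption rather than a confirmed diagnosis: if the test's binary_search picks the half-interval only from whether func(midpoint) > 0, it implicitly assumes the function is increasing over the bracket. The f above is decreasing, so such a routine keeps raising the lower bound past the root and ends at the top of the bracket. The sketch below contrasts a hypothetical increasing-only bisection (written to mimic that assumption, not copied from the test file) with a sign-aware bisection that compares f(midpoint) against the sign at the lower end of the bracket. Both use a fixed iteration count and assume a bracket, here [0, 10], whose endpoints have opposite signs.

```python
import jax.numpy as jnp
from jax import lax


def bisect_increasing_only(func, low, high, num_iters=60):
    # Illustrative variant: keeps [low, mid] whenever func(mid) > 0, which
    # is only correct if func is increasing on [low, high].
    def body(_, state):
        low, high = state
        mid = 0.5 * (low + high)
        keep_lower = func(mid) > 0
        return jnp.where(keep_lower, low, mid), jnp.where(keep_lower, mid, high)

    init = (jnp.asarray(low, jnp.float32), jnp.asarray(high, jnp.float32))
    low, high = lax.fori_loop(0, num_iters, body, init)
    return 0.5 * (low + high)


def bisect(func, low, high, num_iters=60):
    # Sign-aware bisection: keep whichever half still brackets a sign
    # change, so increasing and decreasing functions both work.
    # Assumes func(low) and func(high) have opposite signs.
    low = jnp.asarray(low, jnp.float32)
    high = jnp.asarray(high, jnp.float32)

    def body(_, state):
        low, high, f_low = state
        mid = 0.5 * (low + high)
        f_mid = func(mid)
        # The root is in [low, mid] iff the sign differs between its ends.
        keep_lower = jnp.sign(f_mid) != jnp.sign(f_low)
        return (jnp.where(keep_lower, low, mid),
                jnp.where(keep_lower, mid, high),
                jnp.where(keep_lower, f_low, f_mid))

    low, high, _ = lax.fori_loop(0, num_iters, body, (low, high, func(low)))
    return 0.5 * (low + high)


f = lambda x: -(1 + jnp.tanh(x - 4)) / 2 + 0.1
print(bisect_increasing_only(f, 0.0, 10.0))  # drifts to the upper end, ~10.0
print(bisect(f, 0.0, 10.0))                  # ~2.9014
```

Either function should be usable as the solve argument of lax.custom_root through a small wrapper such as lambda func, x0: bisect(func, 0.0, 10.0), which ignores the initial guess and relies on the fixed bracket.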

I hope this is somehow useful, and not just stemming from my apparent lack of ability to understand the midpoint method (:

google/jax

mattjj

@shoyer mind taking a look at this one?
