Binary search function in lax_control_flow_test.py doesn't work as expected
I'm trying to find the roots of a function differentiably using
lax.custom_root. As an initial test, I've been using the
binary_search routine defined in lax_control_flow_test.py.
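For context, here is a minimal sketch of how a bisection solver plugs into lax.custom_root. The helper below is a hypothetical reconstruction, not a copy of the test file. It assumes the routine treats func(mid) > 0 as meaning the root lies below mid, which is only correct when func is increasing on the bracket. The tangent solve uses the scalar form y / g(1.0) that the custom_root docstring suggests. The names increasing_bisection, scalar_tangent_solve and root are illustrative.

```python
import jax.numpy as jnp
from jax import lax


def increasing_bisection(func, x0, low=0.0, high=100.0):
    """Bisection that ASSUMES func is increasing on [low, high].

    Hypothetical reconstruction of the convention: func(mid) > 0 is taken
    to mean the root lies below mid.
    """
    del x0  # the initial guess is ignored; only the bracket matters
    low = jnp.asarray(low)
    high = jnp.asarray(high)

    def cond(state):
        lo, hi = state
        mid = 0.5 * (lo + hi)
        # Stop once the midpoint rounds to one of the endpoints.
        return (lo < mid) & (mid < hi)

    def body(state):
        lo, hi = state
        mid = 0.5 * (lo + hi)
        root_below_mid = func(mid) > 0
        lo = jnp.where(root_below_mid, lo, mid)
        hi = jnp.where(root_below_mid, mid, hi)
        return lo, hi

    solution, _ = lax.while_loop(cond, body, (low, high))
    return solution


def scalar_tangent_solve(g, y):
    # g is linear in its argument, so g(x) = g(1.0) * x and x = y / g(1.0).
    return y / g(1.0)


def root(func, low=0.0, high=100.0):
    # custom_root calls solve(func, initial_guess); the bracket is closed over.
    solve = lambda fn, x0: increasing_bisection(fn, x0, low, high)
    return lax.custom_root(func, 0.0, solve, scalar_tangent_solve)


increasing = lambda x: x ** 2 - 2.0                      # increasing on [0, 100]
decreasing = lambda x: -(1 + jnp.tanh(x - 4)) / 2 + 0.1  # the f from the question

print(root(increasing))             # close to sqrt(2)
print(root(decreasing, 0.0, 10.0))  # drifts to the top of the bracket, not ~2.9
```

If the real helper does follow this convention, a decreasing function is searched in the wrong direction. After the first step the root is no longer inside the bracket, so the search ends at one edge of the original bracket, whichever bracket is chosen.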
The function is very similar in shape to this one:
f = lambda x: -(1+jnp.tanh(x-4))/2 + 0.1, which has a root at ~2.9. No matter which starting bracket I use, this routine fails to find the root of
f. As a cross-check, scipy.optimize.fsolve finds the root from
most reasonable starting values.
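A sign-aware variant sidesteps the orientation question by comparing the sign at each midpoint with the sign at the lower end of the bracket. This is a sketch that assumes f(low) and f(high) have opposite, nonzero signs. The names sign_aware_bisection and root_of_f are hypothetical, and the constant 0.1 is generalized to a parameter c only to show that the result can be differentiated through lax.custom_root. Analytically, (1 + tanh(x - 4))/2 = c gives x = 4 + artanh(2c - 1). At c = 0.1 this is x = 4 - artanh(0.8) = 4 - ln 3, about 2.901, which matches the ~2.9 above. Implicit differentiation gives dx/dc = 2 / (1 - (2c - 1)^2) = 2 / 0.36, about 5.56.

```python
import math

import jax
import jax.numpy as jnp
from jax import lax


def sign_aware_bisection(func, x0, low=0.0, high=10.0):
    """Bisection on [low, high] for an increasing OR decreasing func.

    Assumes func(low) and func(high) have opposite, nonzero signs. x0 is
    ignored; it is accepted because lax.custom_root calls solve(func, guess).
    """
    del x0
    low = jnp.asarray(low)
    high = jnp.asarray(high)
    sign_low = jnp.sign(func(low))  # the root sits where the sign changes

    def cond(state):
        lo, hi = state
        mid = 0.5 * (lo + hi)
        # Stop once the midpoint rounds to one of the endpoints.
        return (lo < mid) & (mid < hi)

    def body(state):
        lo, hi = state
        mid = 0.5 * (lo + hi)
        # Same sign as at `low` means the sign change (the root) is above mid.
        root_above_mid = jnp.sign(func(mid)) == sign_low
        lo = jnp.where(root_above_mid, mid, lo)
        hi = jnp.where(root_above_mid, hi, mid)
        return lo, hi

    lo, hi = lax.while_loop(cond, body, (low, high))
    return 0.5 * (lo + hi)


def root_of_f(c, low=0.0, high=10.0):
    # f from the question, with the constant 0.1 generalized to a parameter c.
    def f(x):
        return -(1 + jnp.tanh(x - 4)) / 2 + c

    solve = lambda fn, x0: sign_aware_bisection(fn, x0, low, high)
    tangent_solve = lambda g, y: y / g(1.0)  # scalar case: g(x) = g(1.0) * x
    return lax.custom_root(f, 0.0, solve, tangent_solve)


x_star = root_of_f(0.1)
print(x_star, 4 - math.log(3))   # both about 2.901
print(jax.grad(root_of_f)(0.1))  # implicit derivative, about 2 / 0.36
```

Because custom_root differentiates through the optimality condition f(x*) = 0 rather than through the while_loop, the bisection itself never needs to be differentiable.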
I hope this is useful, and not just a symptom of my failing to understand the midpoint method (:
mattjj:
@shoyer mind taking a look at this one?