[jax2tf] Incorrect out of bound index handling for lax.dynamic_slice

To reproduce, add the following test to jax/experimental/jax2tf/tests/primitives_test.py:

def test_dynamic_slice_oob(self):
    v = np.array([1, 2])
    f_jax = jax.jit(lambda a: lax.dynamic_slice(a, [5], [1]))
    self.ConvertAndCompare(f_jax, v)

The issue is that dynamic_slice clamps out-of-bounds start indices, but when we translate it into a tf.slice we lose this behaviour, so the tf2xla translation asserts on the out-of-bounds index.
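To make the clamping concrete, here is a small sketch on the JAX side only, using the same array and indices as the test above. It does not touch jax2tf or TensorFlow; it only shows that the out-of-bounds start index 5 yields the same result as the largest valid start index, 1.

```python
import jax.numpy as jnp
import numpy as np
from jax import lax

v = jnp.array([1, 2])

# Start index 5 is out of bounds for a length-2 operand and slice size 1.
# lax.dynamic_slice clamps it into the valid range [0, 2 - 1], i.e. to 1.
oob = lax.dynamic_slice(v, [5], [1])
in_bounds = lax.dynamic_slice(v, [1], [1])

# Both calls return the last element of v.
np.testing.assert_array_equal(np.asarray(oob), np.asarray(in_bounds))
print(oob)  # [2]
```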

I see two solutions: either clamp the index during translation so we can keep using tf.slice (and hope that it won't impact performance too much), or use tfxla.dynamic_slice, which has the clamping semantics.
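The first option could look roughly like the following. This is only a sketch of the clamping arithmetic, written against NumPy rather than the actual jax2tf translation code; `dynamic_slice_reference` is a hypothetical helper name, and plain NumPy slicing stands in for a non-clamping slice such as tf.slice. The comparison against lax.dynamic_slice covers non-negative start indices only.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax


def dynamic_slice_reference(operand, start_indices, slice_sizes):
    """Hypothetical helper: clamp each start index, then take a plain slice."""
    operand = np.asarray(operand)
    slices = []
    for dim, start, size in zip(operand.shape, start_indices, slice_sizes):
        # Clamp so that the whole slice [start, start + size) fits in the axis.
        clamped = int(np.clip(start, 0, dim - size))
        slices.append(slice(clamped, clamped + size))
    return operand[tuple(slices)]


# 1-D case from the issue: start index 5 is clamped to 1.
v = np.array([1, 2])
ref_1d = dynamic_slice_reference(v, [5], [1])
jax_1d = np.asarray(lax.dynamic_slice(jnp.asarray(v), [5], [1]))
np.testing.assert_array_equal(ref_1d, jax_1d)

# 2-D case: starts (2, 3) with sizes (2, 2) on a 3x4 array are clamped to (1, 2).
x = np.arange(12).reshape(3, 4)
ref_2d = dynamic_slice_reference(x, (2, 3), (2, 2))
jax_2d = np.asarray(lax.dynamic_slice(jnp.asarray(x), (2, 3), (2, 2)))
np.testing.assert_array_equal(ref_2d, jax_2d)
```

In a real translation the clamp would have to be expressed with TensorFlow ops on possibly traced index values, which is where the performance concern mentioned above comes from.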

google/jax

Answer from majnemer

Can scatter ignore out of bound indices?
