[jax2tf] Incorrect out of bound index handling for lax.dynamic_slice
To reproduce, add the following test to jax/experimental/jax2tf/tests/primitives_test.py:
    def test_dynamic_slice_oob(self):
      v = np.array([1, 2])
      f_jax = jax.jit(lambda a: lax.dynamic_slice(a, , ))
      self.ConvertAndCompare(f_jax, v)
The issue is that dynamic_slice clamps its start indices, but when we translate it into a tf.slice we lose this behaviour, so the tf2xla translation asserts on the out-of-bounds index.
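To make the clamping concrete, here is a minimal NumPy sketch that assumes XLA's DynamicSlice rule: each start index is clamped into [0, dim - size] so that the whole slice fits in the operand. `clamp_start_indices` is a hypothetical helper written for illustration and is not part of jax2tf.

```python
import numpy as np

def clamp_start_indices(start_indices, operand_shape, slice_sizes):
    """Clamp each start index into [0, dim - size], as XLA's DynamicSlice does."""
    start = np.asarray(start_indices, dtype=np.int64)
    upper = np.asarray(operand_shape, dtype=np.int64) - np.asarray(slice_sizes, dtype=np.int64)
    return np.clip(start, 0, upper)

# A size-2 slice of a length-2 array must start at 0, so start index 1 is clamped to 0.
print(clamp_start_indices([1], (2,), (2,)))  # [0]
# An index far past the end is clamped to the last valid start.
print(clamp_start_indices([5], (2,), (1,)))  # [1]
```

A plain tf.slice does not perform this adjustment, which is why the translated graph fails instead of returning the clamped slice.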
The two solutions I see are either to clamp the indices during translation so that we can still use tf.slice (and hope that this won't impact performance too much), or to use tfxla.dynamic_slice, which has the clamping semantics.
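The first option can be sketched in NumPy, assuming the same clamping rule as XLA's DynamicSlice. Both `strict_slice` (a hypothetical stand-in for tf.slice's requirement that begin + size stay within each dimension) and `translate_dynamic_slice` are illustrative names, not jax2tf code; a real translation would work on TF tensors, for example with tf.clip_by_value, rather than on NumPy arrays.

```python
import numpy as np

def strict_slice(operand, begin, size):
    """Hypothetical stand-in for tf.slice: reject any window that leaves the operand."""
    operand = np.asarray(operand)
    for b, s, dim in zip(begin, size, operand.shape):
        if b < 0 or b + s > dim:
            raise IndexError(f"slice [{b}, {b + s}) out of bounds for dimension {dim}")
    return operand[tuple(slice(b, b + s) for b, s in zip(begin, size))]

def translate_dynamic_slice(operand, start_indices, slice_sizes):
    """Clamp the start indices first, so the strict slice never sees an OOB window."""
    operand = np.asarray(operand)
    upper = np.asarray(operand.shape) - np.asarray(slice_sizes)
    begin = np.clip(np.asarray(start_indices), 0, upper)
    return strict_slice(operand, [int(b) for b in begin], list(slice_sizes))

v = np.array([1, 2])
# strict_slice(v, [1], [2]) would raise; clamping moves the start to 0 first.
print(translate_dynamic_slice(v, [1], [2]))  # [1 2]
```

The extra cost of this approach is the per-dimension subtract and clip on the start indices, which is the performance concern raised above.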
majnemer
Can scatter ignore out-of-bounds indices?