random uniform with dtype bfloat16 crashes on TPU backend
from jax import numpy as jnp
from jax import random
x = random.uniform(random.PRNGKey(0), (3,), dtype=jnp.bfloat16)
observation:
- crashes on TPU backends (both internal and Cloud TPU, as far as I can tell)
- CPU and GPU backends don't seem to crash
expected: even if bfloat16 isn't supported by random.* on TPU, it would be better to raise an error than to crash
Answer from majnemer:
I think this should work now.
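
For installs where the crash still reproduces, one possible workaround is to sample in float32 and cast the result to bfloat16. This is only a sketch: it is not confirmed anywhere in this thread, and the helper name `uniform_bf16_via_f32` is hypothetical, not part of JAX.

```python
import jax.numpy as jnp
from jax import random


def uniform_bf16_via_f32(key, shape, minval=0.0, maxval=1.0):
    """Hypothetical helper: draw uniform samples in float32, then cast to bfloat16.

    Assumption: float32 sampling works on the backend in use.
    Caveat: the cast rounds, so a float32 value just below maxval
    may round up to exactly maxval in bfloat16.
    """
    x = random.uniform(key, shape, dtype=jnp.float32, minval=minval, maxval=maxval)
    return x.astype(jnp.bfloat16)


x = uniform_bf16_via_f32(random.PRNGKey(0), (3,))
print(x.dtype, x.shape)
```

Because of the rounding caveat, the half-open interval [minval, maxval) is not strictly guaranteed after the cast; clip the result if that matters.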