
Question: random uniform with dtype bfloat16 crashes on TPU backend

repro:

from jax import numpy as jnp
from jax import random
x = random.uniform(random.PRNGKey(0), (3,), dtype=jnp.bfloat16)  # crashes on TPU, runs on CPU/GPU

observation:
- crashes on TPU backends (both internal and Cloud TPU, as far as I can tell)
- CPU and GPU backends don't seem to crash

expect: even if bfloat16 isn't supported by random.* on TPU, it would be better to raise a clear error than to crash; a sketch of that kind of guard is below.
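
For illustration only, a minimal sketch of the kind of guard meant above; the dtype whitelist and the wrapper name are assumptions for this example, not JAX's actual internal check:

import numpy as np
from jax import numpy as jnp
from jax import random

# Hypothetical whitelist of dtypes assumed safe for random.uniform on this backend.
_UNIFORM_DTYPES = {np.dtype(jnp.float32), np.dtype(jnp.float64)}

def uniform_checked(key, shape, dtype=jnp.float32):
    # Fail with a clear Python error instead of letting the backend crash.
    if np.dtype(dtype) not in _UNIFORM_DTYPES:
        raise TypeError(f"random.uniform: dtype {np.dtype(dtype)} is not supported on this backend")
    return random.uniform(key, shape, dtype=dtype)

x = uniform_checked(random.PRNGKey(0), (3,))                    # ok: float32
# uniform_checked(random.PRNGKey(0), (3,), dtype=jnp.bfloat16)  # raises TypeError instead of crashing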

google/jax

Answer (majnemer):

I think this should work now.
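
For anyone landing here, the original repro on a recent jax/jaxlib should now return a bfloat16 array instead of crashing (the exact version that fixed it isn't stated in this thread):

from jax import numpy as jnp
from jax import random

x = random.uniform(random.PRNGKey(0), (3,), dtype=jnp.bfloat16)
print(x.dtype)  # expected: bfloat16
print(x)        # three uniform samples in [0, 1)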

