google/jax 8959

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Topographic Map Viewer

hawkinsp/hawkinsp.github.com 1

Web pages

Flax is a neural network library for JAX that is designed for flexibility.

The Legion Parallel Programming System

The fundamental package for scientific computing with Python.

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.

⚡️Optimizing einsum functions in NumPy, Tensorflow, Dask, and more with contraction order optimization.

issue comment google/jax

No details in tensorboard trace profile

Is this on CPU or GPU? The tracing only really shows anything interesting at the moment on GPU.

comment created time in 3 days

Pull request review comment google/jax

Add test of jax.numpy vs numpy call signatures

def f(x): check_grads(f, (1.,), order=1)

+ def testWrappedSignaturesMatch(self):
+   """Test that jax.numpy function signatures match numpy."""
+   # Note: testing signatures for an exact match is problematic, because numpy's signatures
+   # evolve from release to release (mainly adding new parameters when necessary), and jax.numpy
+   # adds additional jax-specific parameters to some functions. Because of this, the test checks
+   # that jax.numpy function parameters are a superset of the wrapped numpy function parameters.
+   # That way, when tested against older numpy versions (where functions often have fewer parameters),
+   # the test should pass, and when run against new numpy releases, the tests will fail indicating
+   # where jax functionality must be upgraded to match.
+
+   # TODO(jakevdp): fix the following signatures.
+   skip = {
+     'allclose', 'amax', 'amin', 'angle', 'argmax', 'argmin', 'around', 'broadcast_to',

Perhaps it would make sense to have a whitelist of `(function, argument_name)` pairs? E.g., for some of these we may choose not to implement one or more parameters (I'm not sure, say, `order` makes sense on `jnp.full`), but in that case we should not suppress the check entirely. This list would also be more readable and informative if it included the argument names.
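For illustration, the suggested per-argument whitelist could look something like the sketch below. All names here are hypothetical stand-ins (toy `np_full`/`jnp_full` functions rather than the real numpy/jax.numpy ones), just to show the superset check with `(function, argument_name)` skips:

```python
# Sketch of a per-argument whitelist for the signature check.
# `np_full` stands in for a wrapped numpy function, `jnp_full` for its wrapper.
import inspect

def np_full(shape, fill_value, dtype=None, order='C'):
    """Stand-in for the wrapped numpy function."""

def jnp_full(shape, fill_value, dtype=None):
    """Stand-in for the jax.numpy wrapper; deliberately omits `order`."""

# (function_name, argument_name) pairs we choose not to implement.
SKIP_ARGS = {('full', 'order')}

def missing_params(name, wrapped, wrapper):
    """Parameters of `wrapped` that `wrapper` lacks, minus the explicit skips."""
    wrapped_params = set(inspect.signature(wrapped).parameters)
    wrapper_params = set(inspect.signature(wrapper).parameters)
    return {(name, p) for p in wrapped_params - wrapper_params} - SKIP_ARGS

assert missing_params('full', np_full, jnp_full) == set()
```

This way a missing parameter fails the test unless it was skipped by name, rather than the whole function being exempt.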

comment created time in 4 days

push event google/jax

commit sha 47ac612214c923ba77ea7c781e3aef5026f68f8c

Update XLA. (#3623)

push time in 4 days

PR merged google/jax

Update jaxlib release notes.

This XLA version includes a workaround for issue #2874 .

pr closed time in 4 days

push event hawkinsp/jax

commit sha 98a32f2c5999c69d9367bf2bd9dcaacaaab114c2

Remove special case for scan masking

The scan masking rule is now simplified to a standard masking rule, and the special case handling in masking.py has been removed.

commit sha 0a51c5b669842c85abb5c65e7852cf6f134cb640

Minor

commit sha c39b6a40652218fbd67923ea57499a6cfc2d75fa

Remove unused scan shape rule

commit sha dfd2d8c564b304343090b85643c1f2abcf2515fa

Fix dtype, minor

commit sha dd040de18d43b52d5e3dc50bb1cb7610d638bdc0

Bump XLA version. (#3424)

Update jaxlib release notes.

commit sha ddea95e6e0e8f87f739cd3a4ee7aeac0b48d4e99

reuse commonly-typed input/output values when joining partially evaluated conditional branches

commit sha 4a836ff8ae8b38ab7b949da386e208584f83bc85

factor autodiff and vmap tests out from lax_test

commit sha ae9df752de75344d231c7ec973fdd562a91d8913

add docstring to ravel_pytree

commit sha 159a61b2f7f52423338c232c2656fe016d253e75

deflake

commit sha 269da0ae584cfe840f34e9f871f13c28e2772de5

Merge pull request #3425 from google/document-ravel-pytree

add docstring to ravel_pytree

commit sha 1b88fba57c735ae909debece1d640ba3c3f67459

fix and better explain complex JVPs / VJPs

fixes #3433

commit sha 0c29cc15b9580e194407b981ce963d37410be49d

fix typos

commit sha b2105ab370a4567aaf4eed910395f20a2bda67d0

Merge pull request #3434 from google/autodiff-cookbook-complex-bug-fixes

fix and better explain complex JVPs / VJPs

commit sha 482067640578a40f088e5556a702090d12c26d5a

add jax.numpy.concatenate(..., axis=None) support

fixes #3419

commit sha 29fa935ca56dd6aa8fc688a30f860c300fe93bd6

fix vmap-of-pmap mapped_invars logic bug

fixes #3399

This crept in via #1959, but more importantly it shows we don't have good test coverage here!

commit sha 021c02ce341276eadf46c912135f8e54b1fc1cbf

Merge pull request #3436 from google/issue3419

add jax.numpy.concatenate(..., axis=None) support

commit sha c9d1b99e51e02af658de72f842236eba0fb1fca2

Merge pull request #3439 from google/issue3399

fix vmap-of-pmap mapped_invars logic bug

commit sha 3deada9ede0008bcfeb05118e9a9a0634e0f360c

Document valid enum values for precision. (#3441)

This is a little tricky to figure out otherwise.

commit sha 0e804296766384763bfbb8cd6e2758b623d919ef

Added jax2tf test about primitive coverage (#3420)

commit sha 26b6ebaf0d418a25b3902a289c48ef6e7f389c5e

[jax2tf] Fixed the handling of `core.unit` in control-flow primitives. (#3432)

* Fixed the handling of `core.unit` in control-flow primitives.
* Remove add_any from the list of unimplemented tf ops

push time in 4 days

PR opened google/jax

Update jaxlib release notes.

This XLA version includes a workaround for issue #2874 .

pr created time in 4 days

push event hawkinsp/jax

commit sha 8d6fa4695ed0d9fde42b4207d372d770aaf3936f

Bump XLA version.

Update jaxlib release notes.

push time in 4 days

push event google/jax

commit sha 141fabbbf581c952135330aeeeb449833f3bbdf7

Reimplement argmin/argmax using a single pass variadic reduction. (#3611)

push time in 4 days

PR merged google/jax

Fixes #3602

pr closed time in 4 days

issue closed google/jax

Inconsistencies and divergence depending on use of JIT

It seems that on some machines computational results differ significantly if `jit` is applied.

I have come across this odd behavior in an implementation of a batched Monte Carlo integration. On some machines, when part of the code is `jit` transformed, the results are significantly off and some `inf` values occur. This result seems to depend on ostensibly irrelevant code (adding zero times a no-nan expression), and the specific sampling method. Due to this nature I could not pin down the issue to a single expression; neither am I entirely sure I haven't missed something. Below is a description of the algorithm, as minimal as I could make it with the error occurring, followed by a summary of the different behaviors.

### The code

The code consists of the following steps:

- Sample complex points as the solution to a polynomial equation (`p+tq` for fixed `p`, `q` such that `sum((p+tq)^5)=0`).
- For each sample point (each point consists of 4 complex numbers) compute a weight using `jax.grad`.
- Take the mean over `batch_size` of these weights as one batch-step.
- Iterate over a given number of `batches` and add up all means obtained from the batch-steps.

The following is a script taking two true/false arguments: whether to apply `jit` and whether to use the `fori_loop` (the combination `true` `false` takes very long to compile). It contains parts to save the samples, so the weights that should have been obtained can be checked afterwards, and it can be excluded that the error occurs because something changes about the sampling (up to numerical error the samples are the same independent of `jit` use -- as they should be, using the same keys).

```
from jax.config import config
config.update("jax_enable_x64", True)

import jax
import jax.numpy as jnp
import numpy as onp
import functools
import sys


def pop(arr, index):
    # drop and return value at index in arr
    rolled = jnp.roll(arr, -index, axis=0)
    return rolled[0], jnp.roll(rolled[1:], index, axis=0)


def insert_col(mat, col, index):
    # insert column col at index into matrix mat
    mat = jnp.roll(mat, -index, axis=1)
    mat = jnp.concatenate([col.reshape(-1, 1), mat], axis=1)
    return jnp.roll(mat, index, axis=1)


@functools.partial(jax.grad, argnums=0, holomorphic=True)
def grad_eqn(z, p):
    return jnp.sum(z ** 5) + p * jnp.prod(z)


@functools.partial(jax.vmap, in_axes=(0, None))
def weight(z, p):
    grads = grad_eqn(z, p)
    dep = jnp.argmax(jnp.abs(grads))
    grad_max, grad_rest = pop(grads, dep)
    col = (-grad_rest / grad_max)[:, None]
    mat = jnp.concatenate((jnp.eye(3, dtype=jnp.complex_), col), axis=1)
    mat = mat @ mat.T.conj()
    det = jnp.linalg.det(mat).real
    return 1 / det


def sample_sphere(key, count, dim):
    points = jax.random.normal(key, (count, dim))
    return points / jnp.linalg.norm(points, axis=1, keepdims=True)


@jax.vmap
def solve_poly(p, q):
    # polynomial in t given by (q + t * p)**5
    coeffs = jnp.array([q**5, 5 * p**1 * q**4, 10 * p**2 * q**3,
                        10 * p**3 * q**2, 5 * p**4 * q, p**5])
    coeffs = jnp.sum(coeffs, axis=1)
    roots = jnp.roots(coeffs, strip_zeros=False)
    return p.reshape(1, -1) + roots.reshape(-1, 1) * q.reshape(1, -1)


def sample_poly(key, count):
    # solution has multiplicity 5, need count / 5 q's and p's
    base_count = jnp.ceil(count / 5).astype(int)
    # sample base_count complex p's and q's
    pqs = sample_sphere(key, 2, 2 * base_count * 5)
    ps, qs = (pqs[0] + 1j * pqs[1]).reshape(2, base_count, 5)
    sol = solve_poly(ps, qs)
    return sol.reshape(-1, 5)[:count, :]


@jax.vmap
def divide_largest(z):
    # divide by and drop largest absolute entry
    z0, z = pop(z, jnp.argmax(jnp.abs(z)))
    return z / z0


def monte_carlo(key, batches, batch_size, fori=True):
    keys = jax.random.split(key, batches)

    def batch_step(i, data):
        mean, samples = data
        key_sample = keys[i]
        zs = sample_poly(key_sample, batch_size)
        zs = divide_largest(zs)
        # save samples
        samples = jax.ops.index_update(samples, jax.ops.index[i, :, :], zs)
        weights = weight(zs, jnp.array(0.))
        return mean + jnp.mean(weights), samples

    mean = jnp.array(0.)
    samples = jnp.zeros((batches, batch_size, 4), dtype=jnp.complex_)
    if fori:
        mean, samples = jax.lax.fori_loop(0, batches, batch_step, (mean, samples))
    else:
        for i in range(batches):
            mean, samples = batch_step(i, (mean, samples))
    return mean, samples


if __name__ == '__main__':
    key = jax.random.PRNGKey(0)
    apply_jit = len(sys.argv) > 1 and sys.argv[1] == 'true'
    fori_loop = len(sys.argv) > 2 and sys.argv[2] == 'true'
    niter = 10
    batches = 51
    batch_size = 1000
    mc = functools.partial(monte_carlo, fori=fori_loop,
                           batches=batches, batch_size=batch_size)
    if apply_jit:
        mc = jax.jit(mc)
        save_name = 'samples-jit-%i.npy'
    else:
        save_name = 'samples-%i.npy'
    # skip some keys
    for i in range(25):
        _, key = jax.random.split(key)
    for i in range(niter):
        k0, key = jax.random.split(key)
        mean, samples = mc(k0)
        print(mean)
        # save samples to manually check computations
        # onp.save(save_name % i, samples)
```

### Behavior

As noted, the sample values do not differ depending on the `jit` and `fori_loop` combination. Computing the weights and means of weights manually from saved sample values always gives finite numerical values which are consistent with the ones obtained with no `jit` and no `fori_loop` (`$ python script.py false false`).

Depending on the computer, both cases in which `fori_loop` is used may give wrong values containing `inf`'s. This occurred on both local machines I have tested with. Running the same code on colab, however, gives the right (and same) results in all combinations (which is why I suspect there is an underlying issue, not one in the code).

The following are the first 10 results obtained with the above script in two different environments and various combinations of `jit` and `fori_loop`:

| XPS; false, false | XPS; true, true | XPS; false, true | Colab; true, true |
|---|---|---|---|
| 38.87047976604907 | 35.23827002862667 | 35.167724443321404 | 38.904195431290844 |
| 38.85501838205715 | inf | 35.21379197263621 | 38.875554832009456 |
| 38.87232142336747 | 35.07552733029048 | 35.16629159384102 | 38.9613642029768 |
| 38.82467883296542 | 35.268796318296 | 35.18550169177784 | 38.86870896981942 |
| 38.875347911324106 | 35.065090432638506 | 35.12925896136021 | 38.91082515791209 |
| 38.81607498879701 | 35.045350301233476 | 35.087313851691306 | 38.84038161357294 |
| 38.884758144142545 | 35.204243102525254 | 35.19112069680813 | 38.97964735892668 |
| 38.884639882640634 | inf | 35.23049623201075 | 38.907215776623836 |
| 38.96790493327401 | inf | 35.311082582397795 | 38.90340030598595 |
| 38.91302814793844 | 35.26023361519001 | 35.243122471869846 | 38.87890524435126 |

None of the complexities in the code seem to be removable without making the behavior disappear.

- If the sampling is replaced by just uniform or normal random numbers, no more `inf`'s appear and all combinations give the same results.
- Removing the wrapper function (which uses the `fori_loop`) around a batched step, and instead just returning the mean from a single batch, removes the issue and all results are the same.
- The gradient is taken of a function `sum(z ** 5) + p * prod(z)` where always `p=0` is passed. Given this fact, the gradient can be manually replaced with `5 * z**4` (in the real application the gradient would potentially be less simple), which again removes the erroneous behavior.
- The most peculiar dependency on the specific implementation is the following: since `p=0`, the term `+ p * prod(z)` should not change the results. Removing it, however, also removes the issue (no `nan`s and values ~38, not ~35). Even if present in the modified form `+ 0 * jnp.nan_to_num(p * jnp.prod(z))` it reintroduces the error.

### Summary

The erroneous values seem to be connected with the use of `fori_loop` and the gradient of `prod` multiplied such that it should vanish. The behavior seems contingent on the random sampling used, making it difficult to narrow down a specific expression that is responsible. Specifically, computing the weights after the samples are computed gives the right results and doesn't reproduce the erroneous behavior. Any thoughts about how to narrow down here would be appreciated.

### Testing environment

All tests were run with the CPU version of jax and jaxlib installed via `pip`. The current jax version on colab is `0.1.69`.

The numerical results above were obtained on a Dell XPS 13 with i5-7200U CPU, jax version `0.1.70`, and python `3.8.3`. I also saw the same behavior on a desktop machine with jax `0.1.72`, Xeon Silver 4114 CPU, and python `3.6.9`. I'm not sure what other environment variables may be relevant.

closed time in 4 days

mathisgerdes

issue comment google/jax

How to select the jax release version

For general questions about JAX, I suggest using the new "Discussions" tab on Github rather than using issues.

For your specific question, you can form Python complex numbers using `complex` as you state, although that probably won't work if the arguments are numpy arrays, only Python literals. I usually just write `x + y * 1j` or similar, and that works fine with JAX arrays `x` and `y`.
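For illustration, the same pattern with plain NumPy arrays (JAX arrays behave the same way here):

```python
import numpy as np

# complex() works on Python scalars:
z0 = complex(1.0, 3.0)  # (1+3j)

# For arrays, use arithmetic with the imaginary unit instead:
x = np.array([1.0, 2.0])
y = np.array([3.0, 4.0])
z = x + y * 1j
print(z)  # [1.+3.j 2.+4.j]
```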

comment created time in 4 days

issue comment google/jax

How to select the jax release version

I'm sorry, we don't provide docker images at the moment. Our supported installation instructions use `pip` wheels: https://github.com/google/jax#installation

comment created time in 4 days

issue closed google/jax

How to select the jax release version

Taking a look again, it looks to me like this issue is resolved at head.

My version of the benchmark looks like this:

```
from jax import device_put
import numpy as onp
import jax.numpy as np
from jax import vmap
from functools import partial
import time
from jax import jit
from jax import grad


def invm_plus(Pb, Pc):
    Pbc = Pb + Pc
    _Pbc = Pbc * np.array([-1, -1, -1, 1])
    return np.sum(Pbc * _Pbc, axis=1)


def invm(Pbc):
    _Pbc = Pbc * np.array([-1, -1, -1, 1])
    return np.sum(Pbc * _Pbc, axis=1)


def BW(m_, w_, Sbc):
    i = complex(0, 1)
    gamma = np.sqrt(m_*m_*(m_*m_ + w_*w_))
    k = np.sqrt(2*np.sqrt(2)*m_*np.abs(w_)*gamma/np.pi/np.sqrt(m_*m_ + gamma))
    return k/(m_*m_ - Sbc - i*m_*w_)


def phase(theta, rho):
    i = complex(0, 1)
    return rho * np.exp(theta*i)


def _abs(bw_):
    conjbw = np.conj(bw_)
    return np.real(bw_*conjbw)


Kp = onp.random.sample(80000*4).reshape(80000, 4)
Km = onp.random.sample(80000*4).reshape(80000, 4)
Pip = onp.random.sample(80000*4).reshape(80000, 4)
Pim = onp.random.sample(80000*4).reshape(80000, 4)
phif001 = onp.random.sample(80000*2).reshape(80000, 2)
phif021 = onp.random.sample(80000*2).reshape(80000, 2)
phif0 = np.asarray([phif001, phif021])
phi = invm_plus(Kp, Km)
f0 = invm_plus(Pip, Pim)
phim = np.array([2., 1., 1., 1., 2., 1., 1., 1.])
phiw = np.array([1., 2., 1., 1., 2., 1., 1., 1.])
f0m = np.array([1., 1., 1., 3., 1., 1., 1., 1.])
f0w = np.array([1., 1., 1., 1., 1., 1., 1., 1.])
const = np.array([[2., 1., 1., 1., 1., 1., 1., 1.], [1., 1., 1., 1., 1., 1., 1., 1.]])
rho = np.array([1., 1., 2., 1., 1., 1., 1., 1.])
theta = np.array([1., 1., 1., 1., 3., 1., 3., 1.])


def BW_f0(phim, phiw, f0m, f0w, phi, f0):
    return vmap(partial(BW, Sbc=phi))(phim, phiw) * vmap(partial(BW, Sbc=f0))(f0m, f0w)


def phase_f0(theta_, rho_):
    result = vmap(phase)(theta_, rho_)
    return result


def test_pw(phim, phiw, f0m, f0w, const, rho, theta, phif0, phi, f0):
    ph = phase_f0(theta, rho)
    bw = BW_f0(phim, phiw, f0m, f0w, phi, f0)
    const_phase = np.einsum('ij,j->ij', const, ph)
    _phif0 = np.einsum('ijk,il->ljk', phif0, const_phase)
    _phif0 = np.einsum('ijk,ij->jk', _phif0, bw)
    _phif0 = np.real(np.sum(_abs(_phif0), axis=1))
    return -np.sum(np.log(_phif0))


test_pw_jit = jit(test_pw)
print(test_pw_jit(phim, phiw, f0m, f0w, const, rho, theta, phif0, phi, f0))
m = (0, 1, 2, 3, 4, 5, 6)
grad_test_pw = jit(grad(test_pw_jit, argnums=m))
_ = grad_test_pw(phim, phiw, f0m, f0w, const, rho, theta, phif0, phi, f0)[0].block_until_ready()
s = time.time()
print(grad_test_pw(phim, phiw, f0m, f0w, const, rho, theta, phif0, phi, f0))
e = time.time()
print("time : ", e - s)


def test_pw1(phim, phiw, f0m, f0w, theta, rho, const, phif001, phif021, phi, f0):
    temp = np.zeros(80000)
    for i in range(8):
        bw = BW(phim[i], phiw[i], phi) * BW(f0m[i], f0w[i], f0)
        _phif001 = phif001.T * bw * const[0, i]
        _phif021 = phif021.T * bw * const[1, i]
        _phif0 = (_phif001 + _phif021) * phase(theta[i], rho[i])
        temp = temp + _phif0
    data = np.real(np.sum(_abs(temp), axis=0))
    return -np.sum(np.log(data))


grad_test_pw1 = jit(grad(test_pw1, argnums=m))
# print(test_pw1_jit(phim, phiw, f0m, f0w, theta, rho, const, phif001, phif021, phi, f0))
_ = grad_test_pw1(phim, phiw, f0m, f0w, theta, rho, const, phif001, phif021, phi, f0)[0].block_until_ready()
s = time.time()
print(grad_test_pw1(phim, phiw, f0m, f0w, theta, rho, const, phif001, phif021, phi, f0))
e = time.time()
print("time : ", e - s)
```

Both versions seem roughly equal in performance, around 4-5ms on a P100.

*Originally posted by @hawkinsp in https://github.com/google/jax/issues/1763#issuecomment-614885455*

Please tell me the JAX version you are using. Does the version of CUDA and cuDNN matter?

closed time in 4 days

XDongiang

issue comment google/jax

How to select the jax release version

I'd suggest using the current `jaxlib` (0.1.50) and `jax` (0.1.72). I don't think the CUDA/CuDNN versions matter that much. Hope that helps!

comment created time in 4 days

pull request comment google/jax

Reimplement argmin/argmax using a single pass variadic reduction.

Done. I added a GPU-specific fallback to the original 2-pass algorithm.

comment created time in 4 days

push event hawkinsp/jax

commit sha a5d77d2707b26334c1ed59f67532dfff69dddbe3

Add missing dtype arguments.

push time in 4 days

push event hawkinsp/jax

commit sha db8f66d508d536e1312291ba966905ddf910582d

Rework type support for lax cumulative reductions (#3609)

commit sha e808681f6c096e95be8532e4e901f5d410e0fb58

add jit-consistency checks to a few lax_numpy tests

commit sha eb2a22758898b0470b24ee79d271723873d58956

fix reduction repeated axis error (#3618)

* fix reduction repeated axis error
* deflake

commit sha 107689e91f0940327a807e2a9890da3071e6feac

improve vmap axis spec structure mismatch errors (#3619)

* improve vmap axis spec structure mismatch errors, fixes #3613
* deflake

commit sha f852bcdb774da7c290ab7ad726be934d43e2138a

Reimplement argmin/argmax using a single pass variadic reduction.

commit sha 8276d465e43050a5d750b62e0977dd499390e3c8

Add argmin/argmax to jax2tf.

commit sha 89a860b4cc9f112d893beec7ee5030dc0864ae71

Fix comparator name.

commit sha c0ff2fcff1e23649d56f282cd2996a86bd1fedde

Set a zero jvp on argmin/argmax.

commit sha 65359502a506fc1bd843d73bed7fced12348c2af

Fix jax2tf rules.

commit sha 5f74ac5cdf21aba914a3ef91ca6e1d132fe07596

Fix typo.

commit sha b1a53c4cb9750a4ed4aa94d64be19ae02799299b

Add fallback to 2-pass algorithm for GPU. Make index dtype configurable.

push time in 4 days

push event hawkinsp/jax

commit sha 50a97bb9d56226817c707bc7ad377a6bc4f8bfec

Fix typo.

push time in 5 days

push event hawkinsp/jax

commit sha 47620da01dac682ebeb0b07e477422a961ca7b7c

Fix jax2tf rules.

push time in 5 days

push event hawkinsp/jax

commit sha 1702faaf8d854fe123982d127dda3cea3b9e37f1

Set a zero jvp on argmin/argmax.

push time in 5 days

issue comment google/jax

Inconsistencies and divergence depending on use of JIT

PR #3611 seems to fix the problem for me!

comment created time in 5 days

push event hawkinsp/jax

commit sha 5034b328d5423c3d06cc6ca8b8d7a151fca6cb5b

Fix comparator name.

push time in 5 days

push event hawkinsp/jax

commit sha 9c7263b53c99377a7e19769cb2f7e99aeacc0cb7

Add argmin/argmax to jax2tf.

push time in 5 days

PR opened google/jax

Fixes #3602

pr created time in 5 days

issue comment google/jax

Inconsistencies and divergence depending on use of JIT

I think what's happening here is that our 2-pass implementation of `argmax` is breaking for this benchmark because XLA chooses to recompute the input for each pass, and it ends up at least a little bit different, breaking the exact equality the algorithm expects.

Here's a smaller reproduction of what I believe to be going wrong:

```
import functools
import jax
import jax.numpy as jnp
import numpy as onp
from jax.config import config
config.update("jax_enable_x64", True)
@functools.partial(jax.grad, argnums=0, holomorphic=True)
def grad_eqn(z, p):
    return jnp.sum(z**5) + p * jnp.prod(z)

@jax.vmap
def foo(pp):
    w = pp
    grads = grad_eqn(w, jnp.array(0.))
    dep = jnp.argmax(jnp.abs(grads))
    return dep
pp = onp.array([
[
0.013914733367254616 + 9.3094023595025388e-01j, 0.7261907418672447 -
3.5631603853758342e-01j, 0.4192612993988415 + 1.5784670831992959e-01j,
-0.3979203076971916 + 9.0370153633300654e-01j
],
[
0.7412515559421278 + 5.8209475619094264e-01j, -0.7882039695924775 +
8.2271693304329360e-02j, 0.2473007941817975 + 2.9011353191281253e-01j,
0.3953279913354844 - 8.7640244493755098e-02j
],
[
0.42107995148850474 - 5.2786845011285044e-01j,
-0.9752016614993378 + 4.6493984002048627e-02j, 0.14899616328175808 -
2.5660252776982036e-02j, 0.7597534752633921 - 2.9817585059867030e-01j
],
[
0.07919062080298368 + 2.0443288812872526e-01j, -0.9144440581831745 +
5.0375688239721973e-02j, 0.7189319378618075 - 4.7781041939994917e-01j,
0.28616184673974954 + 5.8406319450724586e-01j
],
[
-0.9124599088445945 + 2.0523072700406764e-02j,
-0.49751861383647644 + 6.6934941229283518e-01j, -0.09127256955394417 -
2.5171143504173576e-01j, 0.6288139554943047 - 6.0843626330134115e-01j
],
[
0.10383981031500984 - 5.8961771394978757e-01j,
0.16722015050067904 - 5.4033783049084971e-01j,
-0.3570686716415365 + 7.9098843585535017e-01j,
-0.36717711934106606 - 8.7325230201503201e-01j
],
[
-0.4246625385849623 - 1.0866918975223600e-01j, -0.08009235275156984 -
5.0556054151587737e-01j, 0.4353126751296819 - 1.7484451023920469e-01j,
-0.29468982612935085 + 9.4805376113112527e-01j
],
[
-0.33924468705124716 + 9.3201718347059126e-01j, -0.31334017694099114 -
1.0582535646176845e-01j, 0.6724881365502651 + 2.2066280675715644e-01j,
-0.5552091363575214 - 1.4490197928805444e-02j
],
[
0.6526283624759489 + 4.9140814443027613e-02j, 0.723922878773778 +
3.9038679286342901e-02j, 0.762488834126774 - 6.2337884354704665e-01j,
-0.39085070216092266 + 7.9604864742611015e-01j
],
[
0.48937363384974575 - 5.1103393766115682e-02j, 0.7539474954530498 -
3.3604089724531272e-01j, 0.7479090648182917 - 6.4041309134555557e-01j,
-0.34829014191239205 + 3.2357526022662669e-01j
],
[
-0.09058883804930377 + 3.5194499540548291e-01j,
-0.013080375527656812 - 5.3701680483099290e-01j, -0.7690608912513537 -
7.1176000810729834e-02j, 0.7408780750401596 - 5.8939023080126030e-01j
],
[
-0.1448148911478361 + 1.1680988499664302e-01j,
0.6357579332535874 + 4.8007163633817268e-01j, -0.20356990725787427 -
9.2667400504246511e-01j, -0.486415685914065 - 6.7109379995585339e-01j
],
[
0.6208929968646449 - 2.6040937438741862e-01j,
0.18721272656033006 + 4.2043707645400447e-01j, -0.29536647248428616 +
7.3555331996715753e-02j, -0.9936616656553029 + 2.7635687670210948e-02j
],
[
0.12004262077972877 + 1.2780756181079903e-01j, -0.3058488919814152 +
9.5148878693497396e-01j, 0.3781176412486092 - 1.4516939708606397e-01j,
0.20414911926428903 + 2.8066844435687388e-01j
],
[
0.44393385508190725 + 4.1301429499018250e-01j,
0.775660315846601 - 6.1108689231453883e-01j, -0.22979601723668555 +
5.6425408946759868e-02j, 0.6348633420533681 - 2.1792338815000339e-01j
],
[
0.5903309747549792 + 5.7825608084895608e-04j, -0.9688715927965774 -
8.0116867317087063e-03j, 0.6118501784052073 + 4.1552864146077695e-01j,
0.011088080369843475 - 1.2357486664766727e-01j
]
],
dtype=onp.complex128)
nojit = foo(pp)
withjit = jax.jit(foo)(pp)
print(" no jit: ", nojit)
print("with jit: ", withjit)
onp.testing.assert_allclose(nojit, withjit)
```

Output:

```
no jit: [3 0 1 1 0 3 3 0 2 2 3 2 3 1 1 1]
with jit: [ 3 0 1
1 0 3
3 0 2
2 3 2
3 1 9223372036854775807
9223372036854775807]
Traceback (most recent call last):
File "u.py", line 109, in <module>
onp.testing.assert_allclose(nojit, withjit)
File "/Users/phawkins/.pyenv/versions/py3.7.4/lib/python3.7/site-packages/numpy/testing/_private/utils.py", line 1528, in assert_allclose
verbose=verbose, header=header, equal_nan=equal_nan)
File "/Users/phawkins/.pyenv/versions/py3.7.4/lib/python3.7/site-packages/numpy/testing/_private/utils.py", line 840, in assert_array_compare
raise AssertionError(msg)
AssertionError:
Not equal to tolerance rtol=1e-07, atol=0
Mismatched elements: 2 / 16 (12.5%)
Max absolute difference: 9223372036854775806
Max relative difference: 1.
x: array([3, 0, 1, 1, 0, 3, 3, 0, 2, 2, 3, 2, 3, 1, 1, 1], dtype=int64)
y: array([ 3, 0, 1,
1, 0, 3,
3, 0, 2,...
```

The fix is probably to switch our `argmax` implementation to use a 1-pass algorithm via a variadic reduction.
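For intuition, the one-pass approach carries a (value, index) pair through a single reduction so the input is read only once, instead of matching the max by exact equality in a second pass. A minimal pure-Python sketch of that reduction (not the XLA implementation):

```python
def argmax_single_pass(xs):
    # Combine (best_value, best_index) with each new element in one sweep;
    # ties keep the earlier index, matching argmax semantics.
    best_val, best_idx = xs[0], 0
    for i, v in enumerate(xs[1:], start=1):
        if v > best_val:
            best_val, best_idx = v, i
    return best_idx

print(argmax_single_pass([3.0, 7.0, 7.0, 1.0]))  # 1
```

Because the input is consumed exactly once, there is no second pass that could observe a recomputed (and slightly different) input.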

comment created time in 5 days

push event google/jax

commit sha 420ef4e0a8bfb33386dcf0ed977ceb9aa40c58b8

Fix shape rule for lax.pad for input dimensions of size 0. (#3608)

push time in 5 days

PR merged google/jax

Fixes #3599

pr closed time in 5 days

issue closed google/jax

lax.pad breaks for zero-sized inputs

```
from jax import lax, numpy as jnp
out = lax.pad(jnp.ones((0,)), 0., ((1, 1, 1),))
print(out.shape) # (1,)
print(out) # [0. 0.]
print(out[0]) # RuntimeError: Invalid argument: Argument does not match host shape or layout of computation parameter 0: want f32[1]{0}, got f32[2]{0}
```

Pad works as expected for zero-sized inputs for non-interior padding (i.e. padding config `((1, 1, 0),)`), so I guess this should also work (or at least give an error).
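For reference, with XLA's pad semantics (interior padding is inserted only between existing elements) the output length for config `(lo, hi, interior)` on an input of length `n` should be `lo + hi + n + max(n - 1, 0) * interior`, which gives 2 for this example, matching the printed values but not the reported shape `(1,)`. A quick check:

```python
def padded_size(n, lo, hi, interior):
    # Interior padding goes only between existing elements, hence max(n - 1, 0).
    return lo + hi + n + max(n - 1, 0) * interior

print(padded_size(0, 1, 1, 1))  # 2
print(padded_size(3, 1, 1, 1))  # 7
```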

closed time in 5 days

JuliusKunze

issue comment google/jax

Inconsistencies and divergence depending on use of JIT

Can you share the output of `lscpu` on all the hosts you've tried this code on? I'm speculating that this might be because of a compiler bug, and one reason this might differ across machines is that the compiler may be targeting different CPU architectures.

comment created time in 5 days

PR opened google/jax

Fixes #3599

pr created time in 5 days

issue comment google/jax

solve and triangular_solve fail to return Inf for batches of singular matrices on CPU

Wouldn't it suffice to look for the first 0 on the diagonal of the input matrix?

(I'm also just trying to understand what API contract you expect, because the usual contract says "this is an illegal input". It's possible we can make the XLA algorithm mimic the behavior of the usual TRSM algorithm in this case, but it's not clear to me the behavior in the singular case is actually well defined without also fixing the choice of algorithm.)
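The diagonal check mentioned above can be sketched with plain NumPy; a zero on the diagonal is exactly what makes a triangular matrix singular:

```python
import numpy as np

def first_zero_diag(a):
    """Index of the first zero on the diagonal, or -1 if none (nonsingular)."""
    zeros = np.flatnonzero(np.diagonal(a) == 0)
    return int(zeros[0]) if zeros.size else -1

lower = np.array([[2., 0., 0.],
                  [1., 0., 0.],
                  [4., 5., 3.]])
print(first_zero_diag(lower))  # 1
```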

comment created time in 6 days

issue comment google/jax

solve and triangular_solve fail to return Inf for batches of singular matrices on CPU

I'm wondering if you care about the values returned if the matrix is singular, or whether you would be happy to get, say, a matrix full of NaNs out for that batch element. Note that, say, `scipy.linalg.solve_triangular` would raise a `singular matrix` exception in the corresponding situation.

comment created time in 6 days

issue comment google/jax

random uniform with dtype bfloat16 crashes on TPU backend

I should add that it would be possible to work around this bug by avoiding U16/S16 types on the JAX side. Since we have larger integer types available (U32, S32) this is possible, if slightly annoying to do.
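A sketch of the workaround idea in NumPy terms (the JAX-side change would be analogous): draw 32-bit random bits and keep only the low half, so no native 16-bit integer ops are needed:

```python
import numpy as np

rng = np.random.default_rng(0)
# Generate u32 bits, then keep only the low 16 bits of each word.
bits32 = rng.integers(0, 2**32, size=4, dtype=np.uint32)
bits16 = (bits32 & np.uint32(0xFFFF)).astype(np.uint16)
assert bits16.dtype == np.uint16 and (bits16 <= 0xFFFF).all()
```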

comment created time in 6 days

issue comment google/jax

random uniform with dtype bfloat16 crashes on TPU backend

Google internal bug b/156977343.

comment created time in 6 days

issue comment google/jax

In the absence of the ability to define a custom transpose rule which allows JAX to derive the VJP from the JVP, I think the best bet for supporting both forward and reverse-mode autodiff is to upgrade `expm_frechet` into a first-class JAX primitive. This isn't much harder than the current `custom_jvp` approach, but would use internal APIs. See: https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html

comment created time in 6 days

issue comment google/jax

solve and triangular_solve fail to return Inf for batches of singular matrices on CPU

For the batched case the implementation switches to a completely different algorithm (actually, the same implementation used on TPU): https://github.com/google/jax/blob/7b57dc8c8043163a5e649ba66143ccef880d7d58/jax/lax_linalg.py#L436

For the batch 1 case, we call LAPACK TRSM which acts as you say.

The batched case calls into XLA, which uses an algorithm inspired by MAGMA that inverts diagonal blocks: https://github.com/tensorflow/tensorflow/blob/bd006c354f11f9045d344f3e48b47be9f8368dac/tensorflow/compiler/xla/service/triangular_solve_expander.cc#L439

comment created time in 7 days

Pull request review comment google/jax

def _cumprod_jvp_rule(primals, tangents, *, axis: int):
  return api.jvp(partial(_cumprod_prefix_scan, axis=axis), primals, tangents)

+def _cummax_jvp_rule(primals, tangents, *, axis: int, unit: Number):

Can we share the implementation here and make it parametric?

comment created time in 7 days

Pull request review comment google/jax

def _cumred_batch_rule(prim, batched_args, batch_dims, *, axis: int):

batching.primitive_batchers[cumprod_p] = partial(_cumred_batch_rule, cumprod_p)

+cummax_p = standard_primitive(

I suspect we can share all of these with `cumprod` too by adding a helper function, perhaps something like this:

```
cummax_p = _cumulative_reductive(lax.max, ...)
```
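The parametric idea can be illustrated outside of JAX internals: one helper, parameterized by the binary operation, yields each cumulative reduction. The names below are hypothetical stand-ins, not the actual jax code:

```python
import operator

def make_cumulative(binop):
    """Build a cumulative reduction (prefix scan) from a binary operation."""
    def cumred(xs):
        out, acc = [], None
        for x in xs:
            acc = x if acc is None else binop(acc, x)
            out.append(acc)
        return out
    return cumred

cummax = make_cumulative(max)
cumprod = make_cumulative(operator.mul)
print(cummax([1, 3, 2, 5, 4]))  # [1, 3, 3, 5, 5]
print(cumprod([1, 2, 3, 4]))    # [1, 2, 6, 24]
```

The real primitives would also share JVP and batching rules this way, with the op (and its identity/unit) as the only parameters.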

comment created time in 7 days

issue closed google/jax

This is `np.linalg.expm`. It probably makes sense to port the TF implementation in terms of Pade approximants (https://github.com/tensorflow/tensorflow/blob/r2.0/tensorflow/python/ops/linalg/linalg_impl.py#L228-L339), which itself is a port from Eigen.

closed time in 7 days

jekbradbury

issue comment google/jax

`expm` is implemented, together with its JVP. I'm going to close this issue, leaving https://github.com/google/jax/issues/3447 open for the reverse-mode derivative.

comment created time in 7 days

issue closed google/jax

Gradient of functions involving `expm`

Hi, the following attempt to take the gradient of a function involving a matrix exponential

```
import jax.numpy as jnp
import jax
from jax.scipy.linalg import expm

def foo(mat):
    return jnp.sum(expm(mat))

grad = jax.grad(foo)(jnp.eye(2))
```

is failing with the following assertion error

```
~/anaconda3/envs/envname/lib/python3.7/site-packages/jax/interpreters/ad.py in linearize(traceable, *primals, **kwargs)
96 pval_primals, pval_tangents = tree_unflatten(out_tree(), out_pvals)
97 aval_primals, const_primals = unzip2(pval_primals)
---> 98 assert all(aval_primal is None for aval_primal in aval_primals)
99 if not has_aux:
100 return const_primals, pval_tangents, jaxpr, consts
AssertionError:
```

I'm using `jax.__version__ == 0.1.62` and `jaxlib.__version__ == 0.1.43`, thanks!

closed time in 7 days

danieljtait

issue commentgoogle/jax

Gradient of functions involving `expm`

The forward-mode derivative was merged as part of #2062 and https://github.com/google/jax/issues/3447 is open for the reverse-mode derivative. I'm going to close this issue since I think the remaining work is a duplicate of #3447

comment created time in 7 days

pull request commentgoogle/jax

add lax.associative_scan to docs

The docs appear to have several rendering problems: https://jax--3583.org.readthedocs.build/en/3583/_autosummary/jax.lax.associative_scan.html#jax.lax.associative_scan

Please fix?

comment created time in 7 days

push eventgoogle/jax

commit sha 7b57dc8c8043163a5e649ba66143ccef880d7d58

Issue1635 expm frechet (#2062) * Implement Frechet derivatives for expm. * Update expm to use the current custom gradients API. Make some stylistic fixes. Co-authored-by: Peter Hawkins <phawkins@google.com>

push time in 7 days

PR merged google/jax

This is an implementation of `scipy.linalg.expm_frechet` to compute the derivatives of matrix exponentiation. It was mentioned in #1635 as a way to provide derivatives for `jax.scipy.linalg.expm`, which has itself been implemented in #1940. This pull request does not set the `defvjp` rules to use `scipy.linalg.expm_frechet` because testing such an enhancement seems nontrivial. **EDIT: The defjvp rule has been added**

There are a few things to note:

- `scipy.linalg.expm_frechet_algo_64` creates an identity matrix as `ident = np.identity(n)`. This is problematic because it does not ensure that `ident` and `A` have the same `dtype`. Note further that `scipy.linalg.expm` does set the `dtype` of the identity matrix to be the same as `A`'s. The `dtype` of the output of `scipy.linalg.expm_frechet` can therefore differ from that of its arguments. For this reason, `_CheckAgainstNumpy` is called with `check_dtypes=False`. Major thanks to @jhueckelheim for identifying this problem.
- There was a need to set the `dtype` of the constant 2.0 within `_diff_pade13` to `A.dtype` to ensure numerical correctness. Originally the code was `A = A * 2.0**-s`. We found that the `dtype` of the expression `A * 2.0**-s` is not necessarily the same as `A`'s, so the code was changed to `two = np.array([2.0], A.dtype)` followed by `A = A * two[0]**-s`.

pr closed time in 7 days

push eventsriharikrishna/jax

commit sha 6daf44219ac13853baf4e811d9a10e33c9aa9001

Fix flake8 error.

push time in 7 days

push eventgoogle/jax

commit sha 1f2025e12f469d04deff45fb7ba8f12530254fe2

Incorporate a few recent NumPy API extensions. (#3586)

push time in 7 days

PR merged google/jax

pr closed time in 7 days

push eventgoogle/jax

commit sha 17fc8b75c26d27c55be022618b6b567e59611ace

Implement DeviceArray.tobytes(). (#3585) Move tolist() implementation onto DeviceArray; there's no reason to populate it dynamically.

push time in 7 days

PR merged google/jax

Move tolist() implementation onto DeviceArray; there's no reason to populate it dynamically.

Fixes https://github.com/google/jax/issues/1791
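The length-prefixed serialization pattern from the motivating issue can be written in plain NumPy as follows (hypothetical helper names, shown only to illustrate the round trip):

```python
import struct
import numpy as np

def encode(a):
    """Prefix a 2-D array's raw bytes with its big-endian (height, width)."""
    h, w = a.shape
    return struct.pack('>II', h, w) + a.tobytes()

def decode(buf, dtype=np.float64):
    """Invert encode(): read the 8-byte header, then rebuild the array."""
    h, w = struct.unpack('>II', buf[:8])
    return np.frombuffer(buf[8:], dtype=dtype).reshape(h, w)
```

With `tobytes()` implemented directly on `DeviceArray`, the same pattern should work on JAX arrays without first going through `onp.asarray`.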

pr closed time in 7 days

issue closedgoogle/jax

converting JAX arrays to bytes

The method I can use,

```
import jax.numpy as np
import struct
import numpy as onp
from random import random
def a():
    a = np.zeros((100, 200,)) + random()
    h, w = a.shape
    shape = struct.pack('>II', h, w)
    encoded = shape + onp.asarray(a).tobytes()
%timeit a()
The slowest run took 90.44 times longer than the fastest. This could mean that an intermediate result is being cached.
1000 loops, best of 3: 608 µs per loop
```

Now if I change only `np` to `onp`:

```
import jax.numpy as np
import struct
import numpy as onp
from random import random
def a():
    a = onp.zeros((100, 200,)) + random()
    h, w = a.shape
    shape = struct.pack('>II', h, w)
    encoded = shape + onp.asarray(a).tobytes()
%timeit a()
The slowest run took 61.96 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 47.6 µs per loop
```

Is there any particular reason why the `tobytes` method is not there natively for `jax.numpy`?

closed time in 7 days

dchatterjee172

push eventsriharikrishna/jax

commit sha cc53aa956b0571efb3b0237dd87d92d509f8b1fd

skip new optix test on tpu (cf. #2350)

commit sha 9fd69a04ea6072479513ce45391d4fbe37998715

Replace uses of ExecutePerReplica with ExecuteOnLocalDevices. (#2394) ExecutePerReplica is deprecated, and ExecuteOnLocalDevices is now available via the minimum jaxlib version.

commit sha ebbcbad547e56357e600cd8c19232d4b91cf4f00

allow vmap in_axes to be a list, fixes #2367 (#2395)

commit sha cfbdb65ad8637e883679a0d0516acdc28dd9e8ea

add register_pytree_node_class, fixes #2396 (#2400) Co-authored-by: Stephan Hoyer <shoyer@google.com> Co-authored-by: Stephan Hoyer <shoyer@google.com>

commit sha ffa03403c9c6d9f071125cdc793c9b88e421cf13

Add jnp vs np out of bounds indexing to Sharp Bits nb (#2378)

commit sha 419961f9dd1365f741c68bd88daee3c6f89b43d7

Check for invalid shapes in broadcast_in_dim and fail gracefully.

commit sha cdf188af2fd4f256c2c5c390ec0d09ed321212d0

add raises-exception notebook cell metadata (#2402)

commit sha 2dfeaeb63fa9e884ef5b76bc43cf99b2c5a5c04f

Allow zero tolerance for jax.test_util.tolerance (#2393) Currently, if a user passes any falsy value to jax.test_util.tolerance, it is changed to the default value. This makes sense when the value passed is None, but not when the value passed is 0 (which indicates a desired tolerance of exactly 0). Disables failing tests for now.

commit sha 620bf4300b74c298a5e0133e08f60f76700cf37f

[remat] Change remat lowering to XLA::Conditional (#2391) * [remat] Change remat lowering to XLA::Conditional `jax.remat` creates rematerializing passes that don't have data dependencies on the actual loss-computing forward pass. This means that the XLA scheduler was free to schedule the remat forward pass before the loss-computing pass, defeating the goal of saving accelerator memory with `jax.remat`. In practice, it sometimes did for my workloads. This change expresses the lowering of remat_call(f) as: Conditional(true, inputs, f, inputs, dummy_f). In the common case of `jax.grad(jax.remat(f))`, the content of the lowered remat_call are both the forwards & backwards; that is, the incoming cotangents are part of the args. Additionally, Conditional (AFAIK) is un-inlineable in the sense that it doesn't execute until all its inputs (e.g. cotangents!) are available. Downsides: - AFAICT, we can no longer interleave computation in/outside the rematerialized block. - Potentially, lower performance. I do not observe this in my tests. * provide no replication info for subcomputation params

commit sha 61b430eeb40aeef3254f50dbcb79271e7ab3db96

Added more documentation for how to fix notebook build failures (#2404)

commit sha cf41f7682fef099fe1810bd49e64c9439e2d4f3d

Add np.linalg and np.fft functions to documentation. (#2407)

commit sha 58feed2bcb6802d2b560712648d11a441c82909e

jax.lax.nextafter test fix. (#2408) Fixes #2403.

commit sha 7f0463e2c9d7dfbe9c451baabb7c603eac6a4b3d

remove input shapes from params of some primitives (#2410) Long, long ago, when JAX was first born, we realized that we couldn't transpose this jaxpr: { lambda ; a. let b = reduce_sum[ axes=(0,) ] a in b } The problem was that the transpose of a reduce-sum is a broadcast, but because jaxprs didn't have shape information available, we didn't know what input shape to broadcast to! Our hack was to have the primitives that required shape information for transposition to acquire it into their parameters, so that we'd produce jaxprs like this one: { lambda ; a. let b = reduce_sum[ axes=(0,) input_shape=(3,) ] a in b } That's not only aesthetically unpleasant, but also it meant we were limiting an (unused) capability of the system: ideally we should be able to trace a reduce-sum jaxpr without specializing on shape information (e.g. at the Unshaped level) and only require shape specialization for transposition. (Good thing no one actually traces at Unshaped...) But at long last @chr1sj0nes in #2299 added avals to jaxprs, so that shape information (or whatever information with which the jaxpr was specialized out of Python) is in the jaxpr itself. So we could finally remove these shapes-in-params warts! That's exactly what this commit does! Co-authored-by: Roy Frostig <frostig@google.com> Co-authored-by: Roy Frostig <frostig@google.com>

commit sha 271041b499029d639bade9e1c77d3dc64a89f9cd

Update a regression test to test size-zero device to device transfers. (#2411)

commit sha c41c4de5a54e16998fdb6b9e238f301d6299d078

lower fori_loop to scan when possible (#2414) When a fori_loop specialized on trip count is to be evaluated, it's preferable to generate a scan rather than a while_loop because the former is reverse-mode differentiable while the latter is not. Otherwise they're essentially the same; in particular, no extensive inputs/outputs arise unless reverse-mode autodiff is applied. Also fixes #2412.
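The equivalence this commit exploits can be sketched in plain Python: with a static trip count, a fori_loop is just a scan that threads a carry through each step, and a scan — unlike while_loop — is reverse-mode differentiable. A reference-semantics sketch (not the actual lax lowering):

```python
def fori_loop_as_scan(lower, upper, body, init_val):
    """Reference semantics of fori_loop when the trip count is static."""
    carry = init_val
    for i in range(lower, upper):  # the scan threads `carry` through each step
        carry = body(i, carry)
    return carry
```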

commit sha 47df7b95c44d27a2bb78636c7642a60cdb622402

change the xla representation of JAX's unit (#2416) * change the xla representation of JAX's unit Previously the representation of JAX's unit value (a sentinel / placeholder) was an empty tuple, but by changing the representation to something else we can further reduce our dependence on runtime tuples. This commit makes the representation fairly easy to change. There are three functions in xla.py that define the representation. Here are versions that would keep the old XLA representation as an empty tuple: ``` def _make_unit(c): return c.Tuple() def _make_abstract_unit(_): return xc.Shape.tuple_shape(()) def _device_put_unit(_, device): return xc.Buffer.make_tuple((), device, backend=xb.get_device_backend(device)) ``` The new representation is as a trivial array. An alternative representation would be nothing at all: we don't need to generate XLA computations that have representations of JAX units. While that alterntaive is probably the best choice, it seemed like it would require a bit more refactoring/bookkeeping (e.g. to allow XLA computations to have a smaller number of outputs than the corresponding JAX function), and would also mean the XLA representation would be a step further removed from the jaxpr representation. So I stuck with a trivial array for now. The mapping from JAX types to XLA types need not be invertible. However, XLA translation rules currently don't take as arguments the corresponding JAX types (abstract values), and there were a few cases where we relied on checking whether an argument's XLA type was that of an empty tuple so as to determine if we were effectively operating on a JAX unit. In particular, the AD-related primitive add_jaxvals_p could in principle add two units, and get lowered to an XLA addition on the unit representation. 
Previously, the translation rule for add_jaxvals_p checked the XLA type so that adding two empty tuples didn't produce any XLA operation; now it adds its inputs, and so if unit is represented as a trivial array we could be inserting trivial scalar adds where we had none before. However, if that case is ever possible, it doesn't come up in our tests (which I checked by keeping the representation as an empty tuple and then asserting an XLA tuple type is never seen by that translation rule). * add comment about JAX<->XLA array types assumption

commit sha e84a621184967618997e2b0018fa88979735a7cd

new jet implementation, with conv-based rules Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu> Co-authored-by: David Duvenaud <duvenaud@cs.toronto.edu>

commit sha a21fdf8669437a7a052983e18112b3379b955290

more jet rules and tests Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu>

commit sha 7adf9fe84f5b522ee9120a7ff7763ea0708fc394

add more jet rules! Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu> Co-authored-by: Jacob Kelly <jacob.jin.kelly@gmail.com> Co-authored-by: David Duvenaud <duvenaud@cs.toronto.edu>

commit sha ddd52c47301d48cb65b6c7098a164b99362efa3a

adding div and linear prims Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu>

push time in 7 days

pull request commentgoogle/jax

I rebased this PR on top of head, and updated it to use the new `custom_jvp` API.

comment created time in 7 days

PR opened google/jax

pr created time in 7 days

push eventhawkinsp/jax

commit sha 1909fe7dc51d3652163d08d962797ced0aab110b

Incorporate a few recent NumPy API extensions.

push time in 7 days

push eventhawkinsp/jax

commit sha 40f5cce2bf9cc62100e0a4f296cd60feaa51e696

Incorporate a few recent NumPy API extensions.

push time in 7 days

PR opened google/jax

Move tolist() implementation onto DeviceArray; there's no reason to populate it dynamically.

Fixes https://github.com/google/jax/issues/1791

pr created time in 7 days

issue closedgoogle/jax

I got an out-of-memory error during the backward pass, after running four forward and backward passes plus one more forward pass. I have no idea why it keeps increasing memory usage. Do you have any method to trace memory usage? I know JAX initially allocates 90% of memory, but I want to trace actual usage.

```
File ".local/lib/python3.6/site-packages/jax/api.py", line 398, in value_and_grad_f
    g = vjp_py(onp.ones((), dtype=dtype))
File ".local/lib/python3.6/site-packages/jax/api_util.py", line 66, in apply_flat_fun_nokwargs
    ans = fun(*args)
File ".local/lib/python3.6/site-packages/jax/interpreters/ad.py", line 114, in vjp_
    _, arg_cts = backward_pass(jaxpr, consts, (), dummy_args, dummy_primals_and_cts)
File ".local/lib/python3.6/site-packages/jax/interpreters/ad.py", line 178, in backward_pass
    eqn.params, subjaxpr, sub_consts, sub_freevar_vals, invals, cts_in)
File ".local/lib/python3.6/site-packages/jax/interpreters/ad.py", line 458, in call_transpose
    out_flat = primitive.bind(fun, *all_args, **params)
File ".local/lib/python3.6/site-packages/jax/core.py", line 569, in call_bind
    outs = primitive.impl(f, *args, **params)
File ".local/lib/python3.6/site-packages/jax/interpreters/xla.py", line 369, in _xla_call_impl
    return compiled_fun(*args)
File ".local/lib/python3.6/site-packages/jax/interpreters/xla.py", line 403, in _execute_compiled
    out_bufs = compiled.Execute(input_bufs).destructure()
RuntimeError: Resource exhausted: Out of memory while trying to allocate 2692743312 bytes.
```

closed time in 7 days

neon5d

issue commentgoogle/jax

We just added a device memory profiler to JAX that should be exactly what you need to track down this sort of memory leak.

https://jax.readthedocs.io/en/latest/device_memory_profiling.html

I tried to reproduce the OOM you describe with JAX from head and it doesn't seem to reproduce. The memory usage appears to remain constant at around 4GB. I've attached a device memory profile visualization as produced by the memory profiling tool (on CPU). profile002.pdf

Please feel free to reopen if you have any more problems!

comment created time in 7 days

issue closedgoogle/jax

I'm using the install of CUDA 10.1 with Arch Linux and got the following error. The problem seems to be that although I specified a path of `/opt/cuda`, the script seems to be looking for the includes in `/opt/cuda/lib64`:

```
$ python build.py --enable_march_native --enable_mkl_dnn --enable_cuda --cuda_path /opt/cuda
...
Bazel binary path: ./bazel-0.24.1-linux-x86_64
Python binary path: ~/miniconda3/envs/pytorch_dev/bin/python
MKL-DNN enabled: yes
-march=native: yes
CUDA enabled: yes
CUDA toolkit path: /opt/cuda
CUDNN library path: /opt/cuda
Building XLA and installing it in the jaxlib source tree...
ERROR: Skipping ':install_xla_in_source_tree': error loading package 'build': in ~/.cache/bazel/_bazel_david/9d74510baba13ae20478b5800bcfdf90/external/org_tensorflow/tensorflow/core/platform/default/cuda_build_defs.bzl: Encountered error while reading extension file 'cuda/build_defs.bzl': no such package '@local_config_cuda//cuda': Traceback (most recent call last):
File "~/.cache/bazel/_bazel_david/9d74510baba13ae20478b5800bcfdf90/external/org_tensorflow/third_party/gpus/cuda_configure.bzl", line 1266
_create_local_cuda_repository(repository_ctx)
File "~/.cache/bazel/_bazel_david/9d74510baba13ae20478b5800bcfdf90/external/org_tensorflow/third_party/gpus/cuda_configure.bzl", line 988, in _create_local_cuda_repository
_get_cuda_config(repository_ctx)
File "~/.cache/bazel/_bazel_david/9d74510baba13ae20478b5800bcfdf90/external/org_tensorflow/third_party/gpus/cuda_configure.bzl", line 714, in _get_cuda_config
find_cuda_config(repository_ctx, ["cuda", "cudnn"])
File "~/.cache/bazel/_bazel_david/9d74510baba13ae20478b5800bcfdf90/external/org_tensorflow/third_party/gpus/cuda_configure.bzl", line 694, in find_cuda_config
auto_configure_fail(("Failed to run find_cuda_config...))
File "~/.cache/bazel/_bazel_david/9d74510baba13ae20478b5800bcfdf90/external/org_tensorflow/third_party/gpus/cuda_configure.bzl", line 325, in auto_configure_fail
fail(("\n%sCuda Configuration Error:%...)))
Cuda Configuration Error: Failed to run find_cuda_config.py: Could not find any cublas_api.h matching version '' in any subdirectory:
''
'include'
'include/cuda'
'include/*-linux-gnu'
'extras/CUPTI/include'
'include/cuda/CUPTI'
of:
'/opt/cuda/extras/CUPTI/lib64'
'/opt/cuda/lib64'
'/opt/cuda/nvvm/lib64'
'/usr'
'/usr/lib'
'/usr/lib/intel'
'/usr/lib/libfakeroot'
'/usr/lib/openmpi'
```

As a workaround I created a symlink from `/usr/include/cuda` to `/opt/cuda/include`, but that's not an ideal solution.

closed time in 7 days

dhpollack

issue commentgoogle/jax

Seems like this issue may be moot now, especially given the last comment. Closing; feel free to reopen if you are still having problems!

comment created time in 7 days

issue closedgoogle/jax

Buffer donation to a jit function

Below is a CNN applied iteratively to a 2Gb input. It produces a 4x2Gb = 8 Gb peak memory consumption.

```
import jax.numpy as np
import jax.random as random
from jax import lax
from jax import jit
@jit
def f(x):
    for _ in range(10):
        x = lax.conv_general_dilated(x, np.ones((3, 3, 1, 1)), (1, 1), 'SAME',
                                     dimension_numbers=('NHWC', 'HWIO', 'NHWC'))
    return x
x = random.normal(random.PRNGKey(1), (2**19, 2**5, 2**5, 1))
# (2**20, 2**5, 2**5, 1)) OOMs!
x = f(x)
```

Without JIT, the peak memory consumption is 2x2Gb = 4 Gb, as is expected.

Would be great to achieve a comparable memory usage with JIT by input buffer donation to the jit function (not sure on the exact terminology).

Thanks a lot!

closed time in 7 days

romanngg

issue commentgoogle/jax

Buffer donation to a jit function

Buffer donation has been checked in!

comment created time in 7 days

issue closedgoogle/jax

Compilation error (nv-nsight-cu-cli not found)

I am trying to compile JAX from source with GPU support. My GCC version is 8.3. When I run

$ python build/build.py --enable_cuda

It looks for the file /usr/local/cuda-10.1/NsightCompute-*/nv-nsight-cu-cli but it doesn't exist.

I fixed it temporarily by creating the following soft link.

$ ln -s /usr/local/cuda-10.1/NsightCompute-*/nv-nsight-cu-cli.orig /usr/local/cuda-10.1/NsightCompute-*/nv-nsight-cu-cli

closed time in 7 days

adityaiitb

issue commentgoogle/jax

Compilation error (nv-nsight-cu-cli not found)

I suspect this issue is obsolete now; closing. Please feel free to reopen if you can still reproduce it!

comment created time in 7 days

issue closedgoogle/jax

matmul slow for complex dtypes

Hey, thanks for the great work here!

I noticed that matmuls for complex dtypes are ~ 20x to 25x slower on my macbook than they are for real dtypes. Here is a simple code that does the timing. Wondering if this is expected, or what I can do to speed that up. Thanks!

```
import numpy as np
import jax
from jax import config
config.update('jax_enable_x64',True)
import time
@jax.jit
def do_matvec_simple(matrix, vector):
    res = 0
    for _ in range(100):
        res += matrix @ vector
    return res

@jax.jit
def do_matmul_simple(matrix1, matrix2):
    res = 0
    for _ in range(100):
        res += matrix1 @ matrix2
    return res

def run_timings_dot(dtype, D):
    matrix = jax.numpy.array(np.random.rand(D, D).astype(dtype))
    vector = jax.numpy.array(np.random.rand(D).astype(dtype))
    t1 = time.time()
    for _ in range(100):
        res = matrix @ vector
    res.block_until_ready()
    print(f'loop over 100 matrix-vector muls in dtype {np.dtype(dtype).name}', time.time() - t1)
    res = do_matvec_simple(matrix, vector)
    res.block_until_ready()
    t1 = time.time()
    res = do_matvec_simple(matrix, vector)
    res.block_until_ready()
    print(f'jit 100 do_matvec_simple for dtype {np.dtype(dtype).name}', time.time() - t1)

def run_timings_matmul_simple(dtype, D):
    A = jax.numpy.array(np.random.rand(D, D).astype(dtype))
    B = jax.numpy.array(np.random.rand(D, D).astype(dtype))
    t1 = time.time()
    for _ in range(100):
        res = A @ B
    res.block_until_ready()
    print(f'loop over 100 matrix-matrix muls in dtype {np.dtype(dtype).name}', time.time() - t1)
    res = do_matmul_simple(A, B)
    res.block_until_ready()
    t1 = time.time()
    res = do_matmul_simple(A, B)
    res.block_until_ready()
    print(f'jit 100 do_matmul_simple for dtype {np.dtype(dtype).name}', time.time() - t1)
print('######## timings for matrix-vector ###########')
print(' ---------- float64 --------------')
run_timings_dot(np.float64, 1000)
print(' ---------- complex128 --------------')
run_timings_dot(np.complex128, 1000)
print()
print()
print('######## timings for matrix-matrix ###########')
print(' ---------- float64 --------------')
run_timings_matmul_simple(np.float64, 400)
print(' ---------- complex128 --------------')
run_timings_matmul_simple(np.complex128, 400)
```

update: disabling double precision seems to increase the slowdown to ~ 100x
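A stripped-down, NumPy-only version of the timing harness (a hypothetical helper, no jit) isolates the raw BLAS cost being compared; in NumPy, real and complex matmuls both dispatch to optimized BLAS kernels (dgemm/zgemm), so the gap is usually only a few-fold — the much larger slowdown reported here turned out to be a jaxlib bug:

```python
import time
import numpy as np

def bench_matmul(dtype, d=64, reps=5):
    """Time `reps` d-by-d matmuls for the given dtype; returns seconds."""
    a = np.random.rand(d, d).astype(dtype)
    b = np.random.rand(d, d).astype(dtype)
    t0 = time.perf_counter()
    for _ in range(reps):
        _ = a @ b
    return time.perf_counter() - t0
```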

closed time in 8 days

mganahl

issue commentgoogle/jax

matmul slow for complex dtypes

This bug should be fixed in jaxlib 0.1.48 or newer (jaxlib 0.1.50 was released yesterday).

comment created time in 8 days

PR merged google/jax

pr closed time in 9 days

push eventgoogle/jax

commit sha 5116fd47aa9cf44888b88c685c380422bf0fdc50

Add a heap profiler API and document it. (#3576)

push time in 9 days

pull request commentgoogle/jax

Add a heap profiler API and document it.

I'm going to merge this given Skye has reviewed; if Jake has comments I'm happy to make them in a follow-up PR.

comment created time in 9 days

push eventhawkinsp/jax

commit sha 0332dbdf6d6884af42d690b2bdb393d0e03c5257

Add more crossreferences.

push time in 9 days

push eventhawkinsp/jax

commit sha 0bb119ca4a9eb229d5179a4f94c73dd2b2cdddd5

Incorporate review comments.

push time in 9 days

push eventhawkinsp/jax

commit sha 71c1a120f1f64f3ed59f007362d4f2404eedd5f2

Fix flake8 warnings.

push time in 9 days

push eventhawkinsp/jax

commit sha d22604482dccc7975929a2e8ce1b72e004e03e2d

Add note about gperftools version of pprof.

push time in 9 days

PR opened google/jax

pr created time in 9 days

created taggoogle/jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

created time in 9 days

push eventgoogle/jax

commit sha b576417e5bb760fbb13097b7e3277f2715c06965

Update README for jaxlib 0.1.50 release. (#3574)

push time in 9 days

PR merged google/jax

pr closed time in 9 days

PR opened google/jax

pr created time in 9 days

push eventhawkinsp/jax

commit sha 98d46d39aa9138fbd8143411b99dc42ef0a36ad3

Implement indexing helpers

commit sha 72593f0489315cefe685f36f6c89ae8b81dfc860

Modify syntax to `x.at[idx].set(y)` and similar.

commit sha ab0a903172fdff74287b047ae7e0e26b6fc1c0fc

Merge branch 'master' into index_sugar

commit sha 9e429907c1236e600473d2ff31008cb81e9d1395

Add support for `mul`

commit sha f4f67b065f5b4663474df8e91785524101387a36

Remove unused textwrap

commit sha d356c41f77401fb796a763f19beda65f86300f50

Release jaxlib 0.1.44. (#2740)

commit sha 708107ebe30826d428b56cebe553b0e9016a0c80

Add numpy.rint to lax numpy (#2724) * Add numpy.rint to lax numpy * Use round_to_nearest_even for numpy.rint * Add rint to jax.numpy docs * Fix np.rint float promotion

commit sha 903b50ebf102b3152d60a7cc0e465bdeeee2c8f5

Some cleanup and reformatting in `xla.py`. - Make creation of a few dictionaries more readable. - Use f-strings where possible. - Remove unused imports and function parameters. - Don't format string before passing to `log` function.

commit sha f7070e56d19b0cbdac7d6f2474bae6dde673d612

Merge branch 'master' into changelist/306845248

commit sha 42887172b04c8933d628d5101221d38b4815114f

Add a regression test that runs the same computation on all devices that are present. (#2741)

commit sha d906c89e0bf39f6b41d360db27b1c70b9b29fdad

fix scipy_signal_test convolve failures

commit sha 6d889efd30d8c1b13eceb2493e0cb3b726fbec0a

Merge pull request #2743 from chr1sj0nes/changelist/306845248 Some cleanup and reformatting in `xla.py`.

commit sha a5efe842562af81cae4df7aca95e52981dea2802

Update names and documentation.

commit sha 835949a36a22df40e3206a29c84a520e19793c40

Fix typo

commit sha 3265b8c3c7839a2e805ab32d3a314e512127e8e3

Thread precision through np.convolve & np.correlate

commit sha aeb0d036cb71af55212af2b112d698a8b9bd739d

set precision=HIGHEST only for TPU test

commit sha 8c00b35f21e51163dcbd8b427d4c861b6d6cc13a

Add FIXMEs for AD type errors

commit sha 7d716b8306d6df0c012e9bfffeb1df439e3e560d

Add a simple form of partial evaluation for while_loop. (#2497) The issue that I wanted to fix was that when running grad(while_loop), the error was a cryptic assertion failure (that all primals are known after linearization, in ad.py:linearize). I could not figure out how to detect before that assertion that we are doing a reverse AD for while_loop. So, I implemented a simple form of partial evaluation, to allow the primals after linearization to be known, so that the code proceeds and can then fail gracefully when trying to transpose the while. This is not a proper implementation of partial evaluation. The known outputs are computed early, properly. But the unknown outputs are computed by a *whole* computation of, including the known parts. Fixes issue: #2129

commit sha 9a5b8d626a484d7dfded25bee7b408ce7d37816b

Assert that reduction computations don't have constants. (#2754) This case wouldn't work anyway, because there's no good way to pass constants to an XLA reducer.

commit sha e6f0b8d87dbc48e3b36ec980485a19a8543508d1

Raise an error if stop_gradient is called on non-arrays (#2750) * Raise an error if stop_gradient is called on non-arrays * Fix incorrect usage of stop_gradient in solve() * fix *other* misuse of stop_gradient

push time in 9 days

PR merged google/jax

pr closed time in 9 days

push eventgoogle/jax

commit sha 66cea0277cf3f02ed3975465761c061a89e084a0

Fix test failures on GPU. (#3572)

push time in 9 days

PR opened google/jax

pr created time in 9 days

pull request commentgoogle/jax

Refactor tests to define common functionality in test_util

Yes, unfortunately parameterizing the device under test is tricky because it needs to happen after flag parsing but before test case instantiation. This is why we have an unusual idiom of parsing the flags early in tests, right after module imports. If you can think of a better scheme for running the same tests on multiple devices that works in all the configurations we need to run tests I'm certainly open to it!

comment created time in 9 days

push eventgoogle/jax

commit sha 8f86b139fe78b9e65bcdec7f4244cfc105b8b629

Update docker script for CUDA 11. (#3571)

push time in 9 days

PR merged google/jax

pr closed time in 9 days

PR opened google/jax

pr created time in 9 days

push eventhawkinsp/jax

commit sha a141cc6e8d36ff10e28180683588bedf5432df1a

Make CUDA wheels manylinux2010 compliant, add CUDA 11, drop CUDA 9.2 (#3555) * Use dynamic loading to locate CUDA libraries in jaxlib. This should allow jaxlib CUDA wheels to be manylinux2010 compliant. * Tag CUDA jaxlib wheels as manylinux2010. Drop support for CUDA 9.2, add support for CUDA 11.0. * Reorder CUDA imports.

commit sha 2a6fc316c3e2de99abdf4a97656abbed44a1c626

Update XLA in preparation for a new jaxlib release (0.1.50). (#3560)

commit sha db80ca5dd87d27e1800834c2828298aa3f6d8992

allow closures for odeint dynamics functions (#3562) * allow closures for odeint dynamics functions fixes #2718, #3557 * add tests for odeint dynamics closing over tracers

commit sha c2501d1bef7e97e597c2108c86645f265a249326

update version and changelog for pypi (#3564)

commit sha 062ce297ddf9056ca7743a2d262b0db070eb4553

removed stale faq entries (#3565)

commit sha 4b1bb1890937e9fb81de210305c77c13fb4952ec

Fix logsumexp test (#3563) Previously, this test could generate an axis out of bounds error with large num_generated_cases. (discovered in the process of testing #3561)

commit sha 26c6c3a457414577f7232699095877d6c86d032d

fix error when doing forward-mode of odeint (#3566) fixes #3558

commit sha 99a7e2c9ef8e9a9d4218b64ccd4f2990e08a9c12

Update docker script for CUDA 11.

push time in 9 days

issue commentgoogle/jax

Feature request: add jax[cudaversion] to pypi

#3555 solves one of the problems (manylinux2010 compilance), but not all of them. The others are:

- we would need a PyPi size limit exception
- we would need to consider carefully how to structure our packages if there are multiple variants that may be installed at the same time.

comment created time in 9 days

issue commentgoogle/jax

There's no technical blocker to using JAX on AMD GPUs. We on the JAX team simply don't have access to any AMD GPUs at the moment to develop or test the necessary changes (which are probably not that large, given most of the necessary work has been done in the context of TensorFlow.)

Contributions are welcome!

comment created time in 9 days

issue commentgoogle/jax

illegal memory access while two GPUs is used

One thing I'd like to double check: you aren't *also* using TensorFlow in the same process? There are some multi-GPU interactions between the two that I haven't yet tracked down. https://github.com/google/jax/issues/2742

comment created time in 9 days

issue commentgoogle/jax

I don't think this is an XLA or JAX issue in particular; it sounds to me like the docker file needs fixing. Do any CUDA programs work in this configuration (e.g., TF or PyTorch)? I'm not sure there's much we can do here; is there a way you can install the matching CUDA toolkit version in the docker configuration?

comment created time in 9 days

issue commentgoogle/jax

I've asked the XLA folks to work around the NVIDIA IRFFT bug by cloning the input for all IRFFTs. Hopefully that can happen relatively quickly.

(If any NVIDIA folks are reading this, this is partners bug # 2959622 )

Interestingly PyTorch also seems to have experienced the same bug: https://github.com/pytorch/pytorch/issues/34551

comment created time in 9 days