
google/jax 8959

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

hawkinsp/ZTopo 5

Topographic Map Viewer

hawkinsp/flax 0

Flax is a neural network library for JAX that is designed for flexibility.

hawkinsp/legion 0

The Legion Parallel Programming System

hawkinsp/numpy 0

The fundamental package for scientific computing with Python.

hawkinsp/numpyro 0

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.

hawkinsp/opt_einsum 0

⚡️Optimizing einsum functions in NumPy, Tensorflow, Dask, and more with contraction order optimization.

issue comment google/jax

No details in tensorboard trace profile

Is this on CPU or GPU? The tracing only really shows anything interesting at the moment on GPU.

AdrienCorenflos

comment created time in 3 days

Pull request review comment google/jax

Add test of jax.numpy vs numpy call signatures

 def f(x):
     check_grads(f, (1.,), order=1)

+  def testWrappedSignaturesMatch(self):
+    """Test that jax.numpy function signatures match numpy."""
+    # Note: testing signatures for an exact match is problematic, because numpy's signatures
+    # evolve from release to release (mainly adding new parameters when necessary), and jax.numpy
+    # adds additional jax-specific parameters to some functions. Because of this, the test checks
+    # that jax.numpy function parameters are a superset of the wrapped numpy function parameters.
+    # That way, when tested against older numpy versions (where functions often have fewer parameters),
+    # the test should pass, and when run against new numpy releases, the tests will fail indicating
+    # where jax functionality must be upgraded to match.
+
+    # TODO(jakevdp): fix the following signatures.
+    skip = {
+      'allclose', 'amax', 'amin', 'angle', 'argmax', 'argmin', 'around', 'broadcast_to',

Perhaps it would make sense to have a whitelist of (function, argument_name) pairs?

e.g., for some of these we may choose not to implement one or more parameters (for instance, I'm not sure order makes sense on jnp.full), but in that case we should not suppress the check entirely.

This list would also be more readable and informative if it included the argument names.
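
For illustration, a rough sketch of what such a (function, parameter) skip list could look like; the helper name and the example entries are hypothetical, not taken from the actual test:

import inspect
import numpy as np
import jax.numpy as jnp

# Hypothetical skip list: suppress individual (function, parameter) pairs we
# deliberately don't implement, while still checking the rest of the signature.
UNIMPLEMENTED_PARAMS = {
    ('full', 'order'),  # illustrative entry only
    ('full', 'like'),   # newer NumPy releases add parameters such as 'like'
}

def check_params_superset(name):
    np_params = set(inspect.signature(getattr(np, name)).parameters)
    jnp_params = set(inspect.signature(getattr(jnp, name)).parameters)
    missing = {(name, p) for p in np_params - jnp_params}
    assert missing <= UNIMPLEMENTED_PARAMS, f"unexpected missing parameters: {missing}"

check_params_superset('full')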

jakevdp

comment created time in 4 days

push event google/jax

Peter Hawkins

commit sha 47ac612214c923ba77ea7c781e3aef5026f68f8c

Update XLA. (#3623)

view details

push time in 4 days

PR merged google/jax

Bump XLA version. cla: yes

Update jaxlib release notes.

This XLA version includes a workaround for issue #2874.

+3 -3

0 comment

1 changed file

hawkinsp

pr closed time in 4 days

push event hawkinsp/jax

Julius Kunze

commit sha 98a32f2c5999c69d9367bf2bd9dcaacaaab114c2

Remove special case for scan masking The scan masking rule is now simplified to a standard masking rule, and the special case handling in masking.py has been removed.

view details

Julius Kunze

commit sha 0a51c5b669842c85abb5c65e7852cf6f134cb640

Minor

view details

Julius Kunze

commit sha c39b6a40652218fbd67923ea57499a6cfc2d75fa

Remove unused scan shape rule

view details

Julius Kunze

commit sha dfd2d8c564b304343090b85643c1f2abcf2515fa

Fix dtype, minor

view details

Peter Hawkins

commit sha dd040de18d43b52d5e3dc50bb1cb7610d638bdc0

Bump XLA version. (#3424) Update jaxlib release notes.

view details

Roy Frostig

commit sha ddea95e6e0e8f87f739cd3a4ee7aeac0b48d4e99

reuse commonly-typed input/output values when joining partially evaluated conditional branches

view details

Roy Frostig

commit sha 4a836ff8ae8b38ab7b949da386e208584f83bc85

factor autodiff and vmap tests out from lax_test

view details

Matthew Johnson

commit sha ae9df752de75344d231c7ec973fdd562a91d8913

add docstring to ravel_pytree

view details

Matthew Johnson

commit sha 159a61b2f7f52423338c232c2656fe016d253e75

deflake

view details

Matthew Johnson

commit sha 269da0ae584cfe840f34e9f871f13c28e2772de5

Merge pull request #3425 from google/document-ravel-pytree add docstring to ravel_pytree

view details

Matthew Johnson

commit sha 1b88fba57c735ae909debece1d640ba3c3f67459

fix and better explain complex JVPs / VJPs fixes #3433

view details

Matthew Johnson

commit sha 0c29cc15b9580e194407b981ce963d37410be49d

fix typos

view details

Matthew Johnson

commit sha b2105ab370a4567aaf4eed910395f20a2bda67d0

Merge pull request #3434 from google/autodiff-cookbook-complex-bug-fixes fix and better explain complex JVPs / VJPs

view details

Matthew Johnson

commit sha 482067640578a40f088e5556a702090d12c26d5a

add jax.numpy.concatenate(..., axis=None) support fixes #3419

view details

Matthew Johnson

commit sha 29fa935ca56dd6aa8fc688a30f860c300fe93bd6

fix vmap-of-pmap mapped_invars logic bug fixes #3399 This crept in via #1959, but more importantly it shows we don't have good test coverage here!

view details

Matthew Johnson

commit sha 021c02ce341276eadf46c912135f8e54b1fc1cbf

Merge pull request #3436 from google/issue3419 add jax.numpy.concatenate(..., axis=None) support

view details

Matthew Johnson

commit sha c9d1b99e51e02af658de72f842236eba0fb1fca2

Merge pull request #3439 from google/issue3399 fix vmap-of-pmap mapped_invars logic bug

view details

Stephan Hoyer

commit sha 3deada9ede0008bcfeb05118e9a9a0634e0f360c

Document valid enum values for precision. (#3441) This is a little tricky to figure out otherwise.

view details

George Necula

commit sha 0e804296766384763bfbb8cd6e2758b623d919ef

Added jax2tf test about primitive coverage (#3420)

view details

George Necula

commit sha 26b6ebaf0d418a25b3902a289c48ef6e7f389c5e

[jax2tf] Fixed the handling of `core.unit` in control-flow primitives. (#3432) * Fixed the handling of `core.unit` in control-flow primitives. * Remove add_any from the list of unimplemented tf ops

view details

push time in 4 days

PR opened google/jax

Bump XLA version.

Update jaxlib release notes.

This XLA version includes a workaround for issue #2874.

+22 -3

0 comment

2 changed files

pr created time in 4 days

push event hawkinsp/jax

Peter Hawkins

commit sha 8d6fa4695ed0d9fde42b4207d372d770aaf3936f

Bump XLA version. Update jaxlib release notes.

view details

push time in 4 days

push event google/jax

Peter Hawkins

commit sha 141fabbbf581c952135330aeeeb449833f3bbdf7

Reimplement argmin/argmax using a single pass variadic reduction. (#3611)

view details

push time in 4 days

PR merged google/jax

Reimplement argmin/argmax using a single pass variadic reduction. cla: yes

Fixes #3602

+125 -27

2 comments

7 changed files

hawkinsp

pr closed time in 4 days

issue closed google/jax

Inconsistencies and divergence depending on use of JIT

It seems that on some machines computational results differ significantly if jit is applied.

I have come across this odd behavior in an implementation of a batched Monte Carlo integration. On some machines, when part of the code is jit transformed, the results are significantly off and some inf values occur. This result seems to depend on ostensibly irrelevant code (adding zero times a no-nan expression), and the specific sampling method. Due to this nature I could not pin down the issue to a single expression; neither am I entirely sure I haven't missed something. Below is a description of the algorithm, as minimal as I could make it with the error occurring, followed by a summary of the different behaviors.

The code

The code consists of the following steps:

  1. Sample complex points as the solution to a polynomial equation (p+tq for fixed p, q such that sum((p+tq)^5)=0).
  2. For each sample point (each point consists of 4 complex numbers) compute a weight using jax.grad.
  3. Take the mean over batch_size of these weights as one batch-step.
  4. Iterate over a given number of batches and add up all means obtained from the batch-steps.

The following is a script taking two true/false arguments: whether to apply jit and whether to use the fori_loop (the combination true false takes very long to compile). It contains parts to save the samples, so the weights that should have been obtained can be checked afterwards, and it can be ruled out that the error occurs because something changes about the sampling (up to numerical error the samples are the same independent of jit use -- as they should be, since the same keys are used).

from jax.config import config
config.update("jax_enable_x64", True)

import jax
import jax.numpy as jnp
import numpy as onp
import functools
import sys


def pop(arr, index):
    # drop and return value at index in arr
    rolled = jnp.roll(arr, -index, axis=0)
    return rolled[0], jnp.roll(rolled[1:], index, axis=0)


def insert_col(mat, col, index):
    # insert column col at index into matrix mat
    mat = jnp.roll(mat, -index, axis=1)
    mat = jnp.concatenate([col.reshape(-1, 1), mat], axis=1)
    return jnp.roll(mat, index, axis=1)


@functools.partial(jax.grad, argnums=0, holomorphic=True)
def grad_eqn(z, p):
    return jnp.sum(z ** 5) + p * jnp.prod(z)


@functools.partial(jax.vmap, in_axes=(0, None))
def weight(z, p):
    grads = grad_eqn(z, p)
    dep = jnp.argmax(jnp.abs(grads))
    
    grad_max, grad_rest = pop(grads, dep)
    col = (-grad_rest / grad_max)[:, None]
    
    mat = jnp.concatenate((jnp.eye(3, dtype=jnp.complex_), col), axis=1)
    mat = mat @ mat.T.conj()
    det = jnp.linalg.det(mat).real

    return 1 / det 


def sample_sphere(key, count, dim):
    points = jax.random.normal(key, (count, dim))
    return points / jnp.linalg.norm(points, axis=1, keepdims=True)


@jax.vmap
def solve_poly(p, q):
    # polynomial in t given by (q + t * p)**5
    coeffs = jnp.array([q**5, 5 * p**1 * q**4, 10 * p**2 * q**3,
                        10 * p**3 * q**2, 5 * p**4 * q, p**5])
    coeffs = jnp.sum(coeffs, axis=1)

    roots = jnp.roots(coeffs, strip_zeros=False)
    return p.reshape(1, -1) + roots.reshape(-1, 1) * q.reshape(1, -1)


def sample_poly(key, count):
    # solution has multiplicity 5, need count / 5 q's and p's
    base_count = jnp.ceil(count / 5).astype(int)

    # sample base_count complex p's and q's
    pqs = sample_sphere(key, 2, 2 * base_count * 5)
    ps, qs = (pqs[0] + 1j * pqs[1]).reshape(2, base_count, 5)

    sol = solve_poly(ps, qs)
    return sol.reshape(-1, 5)[:count, :]


@jax.vmap
def divide_largest(z):
    # divide by and drop largest absolute entry
    z0, z = pop(z, jnp.argmax(jnp.abs(z)))
    return z / z0


def monte_carlo(key, batches, batch_size, fori=True):
    keys = jax.random.split(key, batches)

    def batch_step(i, data):
        mean, samples = data
        key_sample = keys[i]

        zs = sample_poly(key_sample, batch_size)
        zs = divide_largest(zs)
        # save samples
        samples = jax.ops.index_update(samples, jax.ops.index[i, :, :], zs)

        weights = weight(zs, jnp.array(0.))
        return mean + jnp.mean(weights), samples

    mean = jnp.array(0.)
    samples = jnp.zeros((batches, batch_size, 4), dtype=jnp.complex_)

    if fori:
        mean, samples = jax.lax.fori_loop(0, batches, batch_step, (mean, samples))
    else:
        for i in range(batches):
            mean, samples = batch_step(i, (mean, samples))

    return mean, samples


if __name__ == '__main__':
    key = jax.random.PRNGKey(0)

    apply_jit = len(sys.argv) > 1 and sys.argv[1] == 'true'
    fori_loop = len(sys.argv) > 2 and sys.argv[2] == 'true'

    niter = 10
    batches = 51
    batch_size = 1000

    mc = functools.partial(monte_carlo, fori=fori_loop,
            batches=batches, batch_size=batch_size)

    if apply_jit:
        mc = jax.jit(mc)
        save_name = 'samples-jit-%i.npy' 
    else:
        save_name = 'samples-%i.npy'

    # skip some keys
    for i in range(25):
        _, key = jax.random.split(key)

    for i in range(niter):
        k0, key = jax.random.split(key)
        mean, samples = mc(k0)
        print(mean)
        # save samples to manually check computations
        # onp.save(save_name % i, samples)

Behavior

As noted, the sample values do not differ depending on jit and fori_loop combination. Computing the weights and means of weights manually from saved sample values always gives finite numerical values which are consistent with the ones obtained by no jit and no fori_loop use ($ python script.py false false). Depending on the computer, both cases in which fori_loop is used may give wrong values containing inf's. This occurred on both local machines I have tested with. Running the same code on colab, however, gives the right (and same) results in all combinations (which is why I suspect there is an underlying issue, not one in the code).

The following are the first 10 results obtained with the above script in two different environments and various combinations of jit and fori_loop:

XPS; false, false | XPS; true, true | XPS; false, true | Colab; true, true
38.87047976604907 35.23827002862667 35.167724443321404 38.904195431290844
38.85501838205715 inf 35.21379197263621 38.875554832009456
38.87232142336747 35.07552733029048 35.16629159384102 38.9613642029768
38.82467883296542 35.268796318296 35.18550169177784 38.86870896981942
38.875347911324106 35.065090432638506 35.12925896136021 38.91082515791209
38.81607498879701 35.045350301233476 35.087313851691306 38.84038161357294
38.884758144142545 35.204243102525254 35.19112069680813 38.97964735892668
38.884639882640634 inf 35.23049623201075 38.907215776623836
38.96790493327401 inf 35.311082582397795 38.90340030598595
38.91302814793844 35.26023361519001 35.243122471869846 38.87890524435126

None of the complexities in the code seem to be removable without making the behavior disappear.

  • If the sampling is replaced by just uniform or normal random numbers no more inf's appear and all combinations give the same results.
  • Removing the wrapper function (which uses the fori_loop) around a batched step, and instead just returning the mean from a single batch removes the issue and all results are the same.
  • The gradient is taken of a function sum(z ** 5) + p * prod(z) where for p always p=0 is passed. Given this fact, the gradient can be manually replaced with 5 * z**4 (in the real application the gradient would potentially be less simple), which again removes the erroneous behavior.
  • The most peculiar dependency on the specific implementation is the following: since p=0, the term + p * prod(z) should not change the results. Removing it, however, also removes the issue (no nans and values ~38 not ~35). Even if present in the modified form + 0 * jnp.nan_to_num(p * jnp.prod(z)) it reintroduces the error.

Summary

The erroneous values seem to be connected with the use of fori_loop and the gradient of prod multiplied such that it should vanish. The behavior seems contingent on the random sampling used, making it difficult to narrow down a specific expression that is responsible. Specifically, computing the weights after the samples are computed gives the right results and doesn't reproduce the erroneous behavior. Any thoughts about how to narrow this down would be appreciated.

Testing environment

All tests were run with the CPU version of jax and jaxlib installed via pip. The current jax version on colab is 0.1.69.

The numerical results above were obtained on a Dell XPS 13 with i5-7200U CPU, jax version 0.1.70, and python 3.8.3. I also saw the same behavior on a desktop machine with jax 0.1.72, Xeon Silver 4114 CPU, and python 3.6.9. I'm not sure what other environment variables may be relevant.

closed time in 4 days

mathisgerdes

issue comment google/jax

How to select the jax release version

For general questions about JAX, I suggest using the new "Discussions" tab on Github rather than using issues.

For your specific question: you can form Python complex numbers using complex() as you state, although that probably won't work if the arguments are NumPy arrays rather than Python literals. Alternatively, I usually just write x + y * 1j or similar, which works fine with JAX arrays x and y.
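
For example, a tiny illustration of the second suggestion:

import jax.numpy as jnp

x = jnp.array([1.0, 2.0])
y = jnp.array([3.0, 4.0])
z = x + y * 1j   # elementwise complex values from two real JAX arrays
print(z)         # [1.+3.j 2.+4.j]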

XDongiang

comment created time in 4 days

issue comment google/jax

How to select the jax release version

I'm sorry, we don't provide Docker images at the moment. Our supported installation instructions use pip wheels: https://github.com/google/jax#installation

XDongiang

comment created time in 4 days

issue closed google/jax

How to select the jax release version

Taking a look again, it looks to me like this issue is resolved at head.

My version of the benchmark looks like this:

from jax import device_put
import numpy as onp
import jax.numpy as np
from jax import vmap
from functools import partial
import time
from jax import jit
from jax import grad

def invm_plus(Pb,Pc):
    Pbc = Pb + Pc
    _Pbc = Pbc * np.array([-1,-1,-1,1])
    return np.sum(Pbc * _Pbc,axis=1)

def invm(Pbc):
    _Pbc = Pbc * np.array([-1,-1,-1,1])
    return np.sum(Pbc * _Pbc,axis=1)

def BW(m_,w_,Sbc):
    i = complex(0,1)
    gamma=np.sqrt(m_*m_*(m_*m_+w_*w_))
    k = np.sqrt(2*np.sqrt(2)*m_*np.abs(w_)*gamma/np.pi/np.sqrt(m_*m_+gamma))
    return k/(m_*m_ - Sbc - i*m_*w_)

def phase(theta, rho):
    i = complex(0,1)
    return rho * np.exp(theta*i)

def _abs(bw_):
    conjbw = np.conj(bw_)
    return np.real(bw_*conjbw)

Kp = onp.random.sample(80000*4).reshape(80000,4)
Km = onp.random.sample(80000*4).reshape(80000,4)
Pip = onp.random.sample(80000*4).reshape(80000,4)
Pim = onp.random.sample(80000*4).reshape(80000,4)
phif001 = onp.random.sample(80000*2).reshape(80000,2)
phif021 = onp.random.sample(80000*2).reshape(80000,2)

phif0 = np.asarray([phif001,phif021])
phi = invm_plus(Kp,Km)
f0 = invm_plus(Pip,Pim)

phim = np.array([2.,1.,1.,1.,2.,1.,1.,1.])
phiw = np.array([1.,2.,1.,1.,2.,1.,1.,1.])
f0m = np.array([1.,1.,1.,3.,1.,1.,1.,1.])
f0w = np.array([1.,1.,1.,1.,1.,1.,1.,1.])
const = np.array([[2.,1.,1.,1.,1.,1.,1.,1.],[1.,1.,1.,1.,1.,1.,1.,1.]])
rho = np.array([1.,1.,2.,1.,1.,1.,1.,1.])
theta = np.array([1.,1.,1.,1.,3.,1.,3.,1.])


def BW_f0(phim,phiw,f0m,f0w,phi,f0):
    return vmap(partial(BW,Sbc=phi))(phim,phiw) * vmap(partial(BW,Sbc=f0))(f0m,f0w)

def phase_f0(theta_,rho_):
    result = vmap(phase)(theta_,rho_)
    return result

def test_pw(phim,phiw,f0m,f0w,const,rho,theta,phif0,phi,f0):
    ph = phase_f0(theta,rho)
    bw = BW_f0(phim,phiw,f0m,f0w,phi,f0)
    const_phase = np.einsum('ij,j->ij',const,ph)
    _phif0 = np.einsum('ijk,il->ljk',phif0,const_phase)
    _phif0 = np.einsum('ijk,ij->jk',_phif0,bw)
    _phif0 = np.real(np.sum(_abs(_phif0),axis=1))
    return -np.sum(np.log(_phif0))

test_pw_jit = jit(test_pw)

print(test_pw_jit(phim,phiw,f0m,f0w,const,rho,theta,phif0,phi,f0))

m = (0,1,2,3,4,5,6)
grad_test_pw = jit(grad(test_pw_jit,argnums=m))
_ = grad_test_pw(phim,phiw,f0m,f0w,const,rho,theta,phif0,phi,f0)[0].block_until_ready()
s = time.time()
print(grad_test_pw(phim,phiw,f0m,f0w,const,rho,theta,phif0,phi,f0))
e = time.time()
print("time : ", e - s)

def test_pw1(phim,phiw,f0m,f0w,theta,rho,const,phif001,phif021,phi,f0):
    temp = np.zeros(80000)
    for i in range(8):
        bw = BW(phim[i],phiw[i],phi) * BW(f0m[i],f0w[i],f0)
        _phif001 = phif001.T * bw *const[0,i]
        _phif021 = phif021.T * bw *const[1,i]
        _phif0 = (_phif001 + _phif021) * phase(theta[i],rho[i])
        temp = temp + _phif0
    data = np.real(np.sum(_abs(temp),axis=0))
    return -np.sum(np.log(data))

grad_test_pw1 = jit(grad(test_pw1,argnums=m))

#print(test_pw1_jit(phim,phiw,f0m,f0w,theta,rho,const,phif001,phif021,phi,f0))
_ = grad_test_pw1(phim,phiw,f0m,f0w,theta,rho,const,phif001,phif021,phi,f0)[0].block_until_ready()
s = time.time()
print(grad_test_pw1(phim,phiw,f0m,f0w,theta,rho,const,phif001,phif021,phi,f0))
e = time.time()
print("time : ", e - s)

Both versions seem roughly equal in performance, around 4-5ms on a P100.

Originally posted by @hawkinsp in https://github.com/google/jax/issues/1763#issuecomment-614885455

Please tell me the JAX version you are using. Does the version of CUDA and cuDNN matter?

closed time in 4 days

XDongiang

issue comment google/jax

How to select the jax release version

I'd suggest using the current jaxlib (0.1.50) and jax (0.1.72). I don't think the CUDA/CuDNN versions matter that much. Hope that helps!

XDongiang

comment created time in 4 days

pull request comment google/jax

Reimplement argmin/argmax using a single pass variadic reduction.

Done. I added a GPU-specific fallback to the original 2-pass algorithm.

hawkinsp

comment created time in 4 days

push event hawkinsp/jax

Peter Hawkins

commit sha a5d77d2707b26334c1ed59f67532dfff69dddbe3

Add missing dtype arguments.

view details

push time in 4 days

push event hawkinsp/jax

Jake Vanderplas

commit sha db8f66d508d536e1312291ba966905ddf910582d

Rework type support for lax cumulative reductions (#3609)

view details

Roy Frostig

commit sha e808681f6c096e95be8532e4e901f5d410e0fb58

add jit-consistency checks to a few lax_numpy tests

view details

Matthew Johnson

commit sha eb2a22758898b0470b24ee79d271723873d58956

fix reduction repeated axis error (#3618) * fix reduction repeated axis error * deflake

view details

Matthew Johnson

commit sha 107689e91f0940327a807e2a9890da3071e6feac

improve vmap axis spec structure mismatch errors (#3619) * improve vmap axis spec structure mismatch errors fixes #3613 * deflake

view details

Peter Hawkins

commit sha f852bcdb774da7c290ab7ad726be934d43e2138a

Reimplement argmin/argmax using a single pass variadic reduction.

view details

Peter Hawkins

commit sha 8276d465e43050a5d750b62e0977dd499390e3c8

Add argmin/argmax to jax2tf.

view details

Peter Hawkins

commit sha 89a860b4cc9f112d893beec7ee5030dc0864ae71

Fix comparator name.

view details

Peter Hawkins

commit sha c0ff2fcff1e23649d56f282cd2996a86bd1fedde

Set a zero jvp on argmin/argmax.

view details

Peter Hawkins

commit sha 65359502a506fc1bd843d73bed7fced12348c2af

Fix jax2tf rules.

view details

Peter Hawkins

commit sha 5f74ac5cdf21aba914a3ef91ca6e1d132fe07596

Fix typo.

view details

Peter Hawkins

commit sha b1a53c4cb9750a4ed4aa94d64be19ae02799299b

Add fallback to 2-pass algorithm for GPU. Make index dtype configurable.

view details

push time in 4 days

push event hawkinsp/jax

Peter Hawkins

commit sha 50a97bb9d56226817c707bc7ad377a6bc4f8bfec

Fix typo.

view details

push time in 5 days

push event hawkinsp/jax

Peter Hawkins

commit sha 47620da01dac682ebeb0b07e477422a961ca7b7c

Fix jax2tf rules.

view details

push time in 5 days

push event hawkinsp/jax

Peter Hawkins

commit sha 1702faaf8d854fe123982d127dda3cea3b9e37f1

Set a zero jvp on argmin/argmax.

view details

push time in 5 days

issue comment google/jax

Inconsistencies and divergence depending on use of JIT

PR #3611 seems to fix the problem for me!

mathisgerdes

comment created time in 5 days

push event hawkinsp/jax

Peter Hawkins

commit sha 5034b328d5423c3d06cc6ca8b8d7a151fca6cb5b

Fix comparator name.

view details

push time in 5 days

push event hawkinsp/jax

Peter Hawkins

commit sha 9c7263b53c99377a7e19769cb2f7e99aeacc0cb7

Add argmin/argmax to jax2tf.

view details

push time in 5 days

PR opened google/jax

Reimplement argmin/argmax using a single pass variadic reduction.

Fixes #3602

+100 -27

0 comment

6 changed files

pr created time in 5 days

create branch hawkinsp/jax

branch : argmin

created branch time in 5 days

issue comment google/jax

Inconsistencies and divergence depending on use of JIT

I think what's happening here is that our 2-pass implementation of argmax is breaking for this benchmark because XLA chooses to recompute the input for each pass, and it ends up at least a little bit different, breaking the exact equality the algorithm expects.

Here's a smaller reproduction of what I believe to be going wrong:

import functools

import jax
import jax.numpy as jnp
import numpy as onp

from jax.config import config
config.update("jax_enable_x64", True)

@functools.partial(jax.grad, argnums=0, holomorphic=True)
def grad_eqn(z, p):
  return jnp.sum(z**5) + p * jnp.prod(z)

@jax.vmap
def foo(pp):
  w = pp
  grads = grad_eqn(w, jnp.array(0.))
  dep = jnp.argmax(jnp.abs(grads))
  return dep

pp = onp.array([
    [
        0.013914733367254616 + 9.3094023595025388e-01j, 0.7261907418672447 -
        3.5631603853758342e-01j, 0.4192612993988415 + 1.5784670831992959e-01j,
        -0.3979203076971916 + 9.0370153633300654e-01j
    ],
    [
        0.7412515559421278 + 5.8209475619094264e-01j, -0.7882039695924775 +
        8.2271693304329360e-02j, 0.2473007941817975 + 2.9011353191281253e-01j,
        0.3953279913354844 - 8.7640244493755098e-02j
    ],
    [
        0.42107995148850474 - 5.2786845011285044e-01j,
        -0.9752016614993378 + 4.6493984002048627e-02j, 0.14899616328175808 -
        2.5660252776982036e-02j, 0.7597534752633921 - 2.9817585059867030e-01j
    ],
    [
        0.07919062080298368 + 2.0443288812872526e-01j, -0.9144440581831745 +
        5.0375688239721973e-02j, 0.7189319378618075 - 4.7781041939994917e-01j,
        0.28616184673974954 + 5.8406319450724586e-01j
    ],
    [
        -0.9124599088445945 + 2.0523072700406764e-02j,
        -0.49751861383647644 + 6.6934941229283518e-01j, -0.09127256955394417 -
        2.5171143504173576e-01j, 0.6288139554943047 - 6.0843626330134115e-01j
    ],
    [
        0.10383981031500984 - 5.8961771394978757e-01j,
        0.16722015050067904 - 5.4033783049084971e-01j,
        -0.3570686716415365 + 7.9098843585535017e-01j,
        -0.36717711934106606 - 8.7325230201503201e-01j
    ],
    [
        -0.4246625385849623 - 1.0866918975223600e-01j, -0.08009235275156984 -
        5.0556054151587737e-01j, 0.4353126751296819 - 1.7484451023920469e-01j,
        -0.29468982612935085 + 9.4805376113112527e-01j
    ],
    [
        -0.33924468705124716 + 9.3201718347059126e-01j, -0.31334017694099114 -
        1.0582535646176845e-01j, 0.6724881365502651 + 2.2066280675715644e-01j,
        -0.5552091363575214 - 1.4490197928805444e-02j
    ],
    [
        0.6526283624759489 + 4.9140814443027613e-02j, 0.723922878773778 +
        3.9038679286342901e-02j, 0.762488834126774 - 6.2337884354704665e-01j,
        -0.39085070216092266 + 7.9604864742611015e-01j
    ],
    [
        0.48937363384974575 - 5.1103393766115682e-02j, 0.7539474954530498 -
        3.3604089724531272e-01j, 0.7479090648182917 - 6.4041309134555557e-01j,
        -0.34829014191239205 + 3.2357526022662669e-01j
    ],
    [
        -0.09058883804930377 + 3.5194499540548291e-01j,
        -0.013080375527656812 - 5.3701680483099290e-01j, -0.7690608912513537 -
        7.1176000810729834e-02j, 0.7408780750401596 - 5.8939023080126030e-01j
    ],
    [
        -0.1448148911478361 + 1.1680988499664302e-01j,
        0.6357579332535874 + 4.8007163633817268e-01j, -0.20356990725787427 -
        9.2667400504246511e-01j, -0.486415685914065 - 6.7109379995585339e-01j
    ],
    [
        0.6208929968646449 - 2.6040937438741862e-01j,
        0.18721272656033006 + 4.2043707645400447e-01j, -0.29536647248428616 +
        7.3555331996715753e-02j, -0.9936616656553029 + 2.7635687670210948e-02j
    ],
    [
        0.12004262077972877 + 1.2780756181079903e-01j, -0.3058488919814152 +
        9.5148878693497396e-01j, 0.3781176412486092 - 1.4516939708606397e-01j,
        0.20414911926428903 + 2.8066844435687388e-01j
    ],
    [
        0.44393385508190725 + 4.1301429499018250e-01j,
        0.775660315846601 - 6.1108689231453883e-01j, -0.22979601723668555 +
        5.6425408946759868e-02j, 0.6348633420533681 - 2.1792338815000339e-01j
    ],
    [
        0.5903309747549792 + 5.7825608084895608e-04j, -0.9688715927965774 -
        8.0116867317087063e-03j, 0.6118501784052073 + 4.1552864146077695e-01j,
        0.011088080369843475 - 1.2357486664766727e-01j
    ]
],
               dtype=onp.complex128)
nojit = foo(pp)
withjit = jax.jit(foo)(pp)
print("  no jit: ", nojit)
print("with jit: ", withjit)
onp.testing.assert_allclose(nojit, withjit)

Output:

  no jit:  [3 0 1 1 0 3 3 0 2 2 3 2 3 1 1 1]
with jit:  [                  3                   0                   1
                   1                   0                   3
                   3                   0                   2
                   2                   3                   2
                   3                   1 9223372036854775807
 9223372036854775807]
Traceback (most recent call last):
  File "u.py", line 109, in <module>
    onp.testing.assert_allclose(nojit, withjit)
  File "/Users/phawkins/.pyenv/versions/py3.7.4/lib/python3.7/site-packages/numpy/testing/_private/utils.py", line 1528, in assert_allclose
    verbose=verbose, header=header, equal_nan=equal_nan)
  File "/Users/phawkins/.pyenv/versions/py3.7.4/lib/python3.7/site-packages/numpy/testing/_private/utils.py", line 840, in assert_array_compare
    raise AssertionError(msg)
AssertionError:
Not equal to tolerance rtol=1e-07, atol=0

Mismatched elements: 2 / 16 (12.5%)
Max absolute difference: 9223372036854775806
Max relative difference: 1.
 x: array([3, 0, 1, 1, 0, 3, 3, 0, 2, 2, 3, 2, 3, 1, 1, 1], dtype=int64)
 y: array([                  3,                   0,                   1,
                         1,                   0,                   3,
                         3,                   0,                   2,...

The fix is probably to switch our argmax implementation to use a 1-pass algorithm via a variadic reduction.
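
For reference, here is a minimal JAX-level sketch of the single-pass idea, carrying a (value, index) pair through one reduction so the input is read only once; this is only an illustration, not the actual XLA variadic-reduction implementation:

import jax
import jax.numpy as jnp

def argmax_single_pass(x):
  # Carry the running (max value, index of max) pair through a single scan.
  def body(carry, elem):
    best_val, best_idx = carry
    val, idx = elem
    better = val > best_val
    return (jnp.where(better, val, best_val),
            jnp.where(better, idx, best_idx)), None

  idxs = jnp.arange(x.shape[0])
  (_, best_idx), _ = jax.lax.scan(body, (x[0], jnp.array(0)), (x, idxs))
  return best_idx

print(argmax_single_pass(jnp.array([3., 1., 7., 7., 2.])))  # 2 (first maximum)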

mathisgerdes

comment created time in 5 days

push event google/jax

Peter Hawkins

commit sha 420ef4e0a8bfb33386dcf0ed977ceb9aa40c58b8

Fix shape rule for lax.pad for input dimensions of size 0. (#3608)

view details

push time in 5 days

PR merged google/jax

Fix shape rule for lax.pad for input dimensions of size 0. cla: yes

Fixes #3599

+4 -3

0 comment

2 changed files

hawkinsp

pr closed time in 5 days

issue closed google/jax

lax.pad breaks for zero-sized inputs

from jax import lax, numpy as jnp
out = lax.pad(jnp.ones((0,)), 0., ((1, 1, 1),))

print(out.shape) # (1,)
print(out)       # [0. 0.]
print(out[0])    # RuntimeError: Invalid argument: Argument does not match host shape or layout of computation parameter 0: want f32[1]{0}, got f32[2]{0}

Pad works as expected for zero-sized inputs with non-interior padding (i.e. padding config ((1, 1, 0),)), so I guess this should also work (or at least give an error).

closed time in 5 days

JuliusKunze

issue comment google/jax

Inconsistencies and divergence depending on use of JIT

Can you share the output of lscpu on all the hosts you've tried this code on? I'm speculating that this might be because of a compiler bug, and one reason this might differ across machines is that the compiler may be targeting different CPU architectures.

mathisgerdes

comment created time in 5 days

PR opened google/jax

Fix shape rule for lax.pad for input dimensions of size 0.

Fixes #3599

+4 -3

0 comment

2 changed files

pr created time in 5 days

create branch hawkinsp/jax

branch : pad

created branch time in 5 days

issue comment google/jax

solve and triangular_solve fail to return Inf for batches of singular matrices on CPU

Wouldn't it suffice to look for the first 0 on the diagonal of the input matrix?

(I'm also just trying to understand what API contract you expect, because the usual contract says "this is an illegal input". It's possible we can make the XLA algorithm mimic the behavior of the usual TRSM algorithm in this case, but it's not clear to me the behavior in the singular case is actually well defined without also fixing the choice of algorithm.)
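
For concreteness, a minimal sketch of that kind of diagonal check (my own illustration, not something JAX currently does):

import jax.numpy as jnp

def first_zero_on_diagonal(a):
  # Index of the first zero pivot on the diagonal, or -1 if there is none.
  # Works on batches of matrices via the trailing two axes.
  diag = jnp.diagonal(a, axis1=-2, axis2=-1)
  is_zero = diag == 0
  return jnp.where(is_zero.any(axis=-1), jnp.argmax(is_zero, axis=-1), -1)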

dpfau

comment created time in 6 days

issue comment google/jax

solve and triangular_solve fail to return Inf for batches of singular matrices on CPU

I'm wondering if you care about the values returned if the matrix is singular, or whether you would be happy to get, say, a matrix full of NaNs out for that batch element. Note that, say, scipy.linalg.solve_triangular would raise a singular matrix exception in the corresponding situation.

dpfau

comment created time in 6 days

issue comment google/jax

random uniform with dtype bfloat16 crashes on TPU backend

I should add that it would be possible to work around this bug by avoiding U16/S16 types on the JAX side. Since we have larger integer types available (U32, S32), this is possible, if slightly annoying to do.

levskaya

comment created time in 6 days

issue comment google/jax

random uniform with dtype bfloat16 crashes on TPU backend

Google internal bug b/156977343.

levskaya

comment created time in 6 days

issue comment google/jax

Add VJP expm

In the absence of the ability to define a custom transpose rule which allows JAX to derive the VJP from the JVP, I think the best bet for supporting both forward and reverse-mode autodiff is to upgrade expm_frechet into a first-class JAX primitive. This isn't much harder than the current custom_jvp approach, but would use internal APIs. See: https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html
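
To sketch what that might look like (a rough outline only, following the internal APIs described in that notebook; the names here are illustrative and the JVP/transpose rules are left as comments):

import jax.scipy.linalg as jsl
from jax import core
from jax.interpreters import ad  # where the JVP/transpose rules would be registered

expm_p = core.Primitive("expm")

def expm(a):
  return expm_p.bind(a)

# Concrete evaluation defers to the existing implementation; the output has the
# same shape and dtype as the input, so abstract evaluation can return it as-is.
expm_p.def_impl(jsl.expm)
expm_p.def_abstract_eval(lambda a: a)

# Remaining work: a JVP rule built on expm_frechet plus a transpose rule so
# that reverse-mode autodiff works too, roughly:
#   ad.defjvp(expm_p, ...)
#   ad.primitive_transposes[expm_p] = ...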

LoicRaillon

comment created time in 6 days

issue comment google/jax

solve and triangular_solve fail to return Inf for batches of singular matrices on CPU

For the batched case, the implementation switches to a completely different algorithm (actually, the same implementation used on TPU): https://github.com/google/jax/blob/7b57dc8c8043163a5e649ba66143ccef880d7d58/jax/lax_linalg.py#L436

For the batch-1 case, we call LAPACK TRSM, which acts as you say.

The batched case calls into XLA, which uses an algorithm inspired by MAGMA that inverts diagonal blocks: https://github.com/tensorflow/tensorflow/blob/bd006c354f11f9045d344f3e48b47be9f8368dac/tensorflow/compiler/xla/service/triangular_solve_expander.cc#L439

dpfau

comment created time in 7 days

Pull request review comment google/jax

Add cummax and cummin

 def _cumprod_jvp_rule(primals, tangents, *, axis: int):
   return api.jvp(partial(_cumprod_prefix_scan, axis=axis), primals, tangents)

+def _cummax_jvp_rule(primals, tangents, *, axis: int, unit: Number):

Can we share the implementation here and make it parametric?

ekelsen

comment created time in 7 days

Pull request review comment google/jax

Add cummax and cummin

 def _cumred_batch_rule(prim, batched_args, batch_dims, *, axis: int):
 batching.primitive_batchers[cumprod_p] = partial(_cumred_batch_rule, cumprod_p)

+cummax_p = standard_primitive(

I suspect we can share all of these with cumprod too by adding a helper function, perhaps something like this:

cummax_p = _cumulative_reductive(lax.max, ...)
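
For illustration, here's the same parametric idea sketched at user level with lax.associative_scan (not the lax-internal helper being proposed; the helper name is made up):

import jax.numpy as jnp
from jax import lax

def _make_cumulative(op):
  # One shared implementation, parameterized by the combining operation.
  def cumulative(x):
    return lax.associative_scan(op, x)
  return cumulative

cummax = _make_cumulative(jnp.maximum)
cummin = _make_cumulative(jnp.minimum)

print(cummax(jnp.array([1., 3., 2., 5., 4.])))  # [1. 3. 3. 5. 5.]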
ekelsen

comment created time in 7 days

issue closed google/jax

Implement matrix exponential

This is scipy.linalg.expm. It probably makes sense to port the TF implementation in terms of Pade approximants (https://github.com/tensorflow/tensorflow/blob/r2.0/tensorflow/python/ops/linalg/linalg_impl.py#L228-L339), which itself is a port from Eigen.

closed time in 7 days

jekbradbury

issue comment google/jax

Implement matrix exponential

expm is implemented, together with its JVP. I'm going to close this issue, leaving https://github.com/google/jax/issues/3447 open for the reverse-mode derivative.

jekbradbury

comment created time in 7 days

issue closed google/jax

Gradient of functions involving `expm`

Hi, the following attempt to take the gradient of a function involving a matrix exponential

import jax.numpy as jnp
import jax
from jax.scipy.linalg import expm

def foo(mat):
    return jnp.sum(expm(mat))

grad = jax.grad(foo)(jnp.eye(2))

is failing with the following assertion error

~/anaconda3/envs/envname/lib/python3.7/site-packages/jax/interpreters/ad.py in linearize(traceable, *primals, **kwargs)
     96   pval_primals, pval_tangents = tree_unflatten(out_tree(), out_pvals)
     97   aval_primals, const_primals = unzip2(pval_primals)
---> 98   assert all(aval_primal is None for aval_primal in aval_primals)
     99   if not has_aux:
    100     return const_primals, pval_tangents, jaxpr, consts

AssertionError: 

I'm using jax.__version__ == 0.1.62 and jaxlib.__version__ == 0.1.43, thanks!

closed time in 7 days

danieljtait

issue comment google/jax

Gradient of functions involving `expm`

The forward-mode derivative was merged as part of #2062 and https://github.com/google/jax/issues/3447 is open for the reverse-mode derivative. I'm going to close this issue since I think the remaining work is a duplicate of #3447

danieljtait

comment created time in 7 days

pull request comment google/jax

add lax.associative_scan to docs

The docs appear to have several rendering problems: https://jax--3583.org.readthedocs.build/en/3583/_autosummary/jax.lax.associative_scan.html#jax.lax.associative_scan

Please fix?

mattjj

comment created time in 7 days

push event google/jax

Sri Hari Krishna Narayanan

commit sha 7b57dc8c8043163a5e649ba66143ccef880d7d58

Issue1635 expm frechet (#2062) * Implement Frechet derivatives for expm. * Update expm to use the current custom gradients API. Make some stylistic fixes. Co-authored-by: Peter Hawkins <phawkins@google.com>

view details

push time in 7 days

PR merged google/jax

Issue1635 expm frechet cla: yes

This is an implementation of scipy.linalg.expm_frechet to compute the derivatives of matrix exponentiation. It was mentioned in #1635 as a way to provide derivatives for jax.scipy.linalg.expm, which has itself been implemented in #1940. This pull request does not set the defvjp rules to use scipy.linalg.expm_frechet because testing such an enhancement seems nontrivial. EDIT: The defjvp rule has been added.

There are a few things to note:

  1. scipy.linalg.expm_frechet_algo_64 creates an identity matrix in the following way: ident = np.identity(n). This is problematic because it does not ensure that ident and A have the same dtype. Note further that scipy.linalg.expm does set the dtype of the identity matrix to be the same as A. The dtype of the output of scipy.linalg.expm_frechet can therefore be different from its arguments. For this reason, _CheckAgainstNumpy is called with check_dtypes=False.

Major thanks to @jhueckelheim for identifying this problem.

  2. There was a need to set the dtype of the constant 2.0 within _diff_pade13 to A.dtype to ensure numerical correctness. Originally the code was: A = A * 2.0**-s. We found that the dtype of the expression A * 2.0**-s is not necessarily the same as A. The code was changed to: two = np.array([2.0], A.dtype); A = A * two[0]**-s.
+352 -94

16 comments

3 changed files

sriharikrishna

pr closed time in 7 days

push event sriharikrishna/jax

Peter Hawkins

commit sha 6daf44219ac13853baf4e811d9a10e33c9aa9001

Fix flake8 error.

view details

push time in 7 days

push event google/jax

Peter Hawkins

commit sha 1f2025e12f469d04deff45fb7ba8f12530254fe2

Incorporate a few recent NumPy API extensions. (#3586)

view details

push time in 7 days

PR merged google/jax

Small updates to match NumPy. cla: yes
+18 -9

0 comment

1 changed file

hawkinsp

pr closed time in 7 days

push event google/jax

Peter Hawkins

commit sha 17fc8b75c26d27c55be022618b6b567e59611ace

Implement DeviceArray.tobytes(). (#3585) Move tolist() implementation onto DeviceArray; there's no reason to populate it dynamically.

view details

push time in 7 days

PR merged google/jax

Implement DeviceArray.tobytes(). cla: yes

Move tolist() implementation onto DeviceArray; there's no reason to populate it dynamically.

Fixes https://github.com/google/jax/issues/1791

+11 -1

0 comment

3 changed files

hawkinsp

pr closed time in 7 days

issue closed google/jax

converting JAX arrays to bytes

The method I can use:

import jax.numpy as np
import struct
import numpy as onp
from random import random

def a():
    a = np.zeros((100,200,)) + random()
    h, w = a.shape
    shape = struct.pack('>II',h,w)
    encoded = shape + onp.asarray(a).tobytes()
%timeit a()


The slowest run took 90.44 times longer than the fastest. This could mean that an intermediate result is being cached.
1000 loops, best of 3: 608 µs per loop

Now if I only change np to onp

import jax.numpy as np
import struct
import numpy as onp
from random import random

def a():
    a = onp.zeros((100,200,)) + random()
    h, w = a.shape
    shape = struct.pack('>II',h,w)
    encoded = shape + onp.asarray(a).tobytes()
%timeit a()


The slowest run took 61.96 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 47.6 µs per loop

Is there any particular reason why the tobytes method is not there natively for jax.numpy?

closed time in 7 days

dchatterjee172

push event sriharikrishna/jax

Matthew Johnson

commit sha cc53aa956b0571efb3b0237dd87d92d509f8b1fd

skip new optix test on tpu (cf. #2350)

view details

Peter Hawkins

commit sha 9fd69a04ea6072479513ce45391d4fbe37998715

Replace uses of ExecutePerReplica with ExecuteOnLocalDevices. (#2394) ExecutePerReplica is deprecated, and ExecuteOnLocalDevices is now available via the minimum jaxlib version.

view details

Matthew Johnson

commit sha ebbcbad547e56357e600cd8c19232d4b91cf4f00

allow vmap in_axes to be a list, fixes #2367 (#2395)

view details

Matthew Johnson

commit sha cfbdb65ad8637e883679a0d0516acdc28dd9e8ea

add register_pytree_node_class, fixes #2396 (#2400) Co-authored-by: Stephan Hoyer <shoyer@google.com> Co-authored-by: Stephan Hoyer <shoyer@google.com>

view details

Jordan Hoffmann

commit sha ffa03403c9c6d9f071125cdc793c9b88e421cf13

Add jnp vs np out of bounds indexing to Sharp Bits nb (#2378)

view details

Peter Hawkins

commit sha 419961f9dd1365f741c68bd88daee3c6f89b43d7

Check for invalid shapes in broadcast_in_dim and fail gracefully.

view details

Matthew Johnson

commit sha cdf188af2fd4f256c2c5c390ec0d09ed321212d0

add raises-exception notebook cell metadata (#2402)

view details

Daniel Johnson

commit sha 2dfeaeb63fa9e884ef5b76bc43cf99b2c5a5c04f

Allow zero tolerance for jax.test_util.tolerance (#2393) Currently, if a user passes any falsy value to jax.test_util.tolerance, it is changed to the default value. This makes sense when the value passed is None, but not when the value passed is 0 (which indicates a desired tolerance of exactly 0). Disables failing tests for now.

view details

Trevor Cai

commit sha 620bf4300b74c298a5e0133e08f60f76700cf37f

[remat] Change remat lowering to XLA::Conditional (#2391) * [remat] Change remat lowering to XLA::Conditional `jax.remat` creates rematerializing passes that don't have data dependencies on the actual loss-computing forward pass. This means that the XLA scheduler was free to schedule the remat forward pass before the loss-computing pass, defeating the goal of saving accelerator memory with `jax.remat`. In practice, it sometimes did for my workloads. This change expresses the lowering of remat_call(f) as: Conditional(true, inputs, f, inputs, dummy_f). In the common case of `jax.grad(jax.remat(f))`, the content of the lowered remat_call are both the forwards & backwards; that is, the incoming cotangents are part of the args. Additionally, Conditional (AFAIK) is un-inlineable in the sense that it doesn't execute until all its inputs (e.g. cotangents!) are available. Downsides: - AFAICT, we can no longer interleave computation in/outside the rematerialized block. - Potentially, lower performance. I do not observe this in my tests. * provide no replication info for subcomputation params

view details

George Necula

commit sha 61b430eeb40aeef3254f50dbcb79271e7ab3db96

Added more documentation for how to fix notebook build failures (#2404)

view details

Peter Hawkins

commit sha cf41f7682fef099fe1810bd49e64c9439e2d4f3d

Add np.linalg and np.fft functions to documentation. (#2407)

view details

Skye Wanderman-Milne

commit sha 58feed2bcb6802d2b560712648d11a441c82909e

jax.lax.nextafter test fix. (#2408) Fixes #2403.

view details

Matthew Johnson

commit sha 7f0463e2c9d7dfbe9c451baabb7c603eac6a4b3d

remove input shapes from params of some primitives (#2410) Long, long ago, when JAX was first born, we realized that we couldn't transpose this jaxpr: { lambda ; a. let b = reduce_sum[ axes=(0,) ] a in b } The problem was that the transpose of a reduce-sum is a broadcast, but because jaxprs didn't have shape information available, we didn't know what input shape to broadcast to! Our hack was to have the primitives that required shape information for transposition to acquire it into their parameters, so that we'd produce jaxprs like this one: { lambda ; a. let b = reduce_sum[ axes=(0,) input_shape=(3,) ] a in b } That's not only aesthetically unpleasant, but also it meant we were limiting an (unused) capability of the system: ideally we should be able to trace a reduce-sum jaxpr without specializing on shape information (e.g. at the Unshaped level) and only require shape specialization for transposition. (Good thing no one actually traces at Unshaped...) But at long last @chr1sj0nes in #2299 added avals to jaxprs, so that shape information (or whatever information with which the jaxpr was specialized out of Python) is in the jaxpr itself. So we could finally remove these shapes-in-params warts! That's exactly what this commit does! Co-authored-by: Roy Frostig <frostig@google.com> Co-authored-by: Roy Frostig <frostig@google.com>

view details

Peter Hawkins

commit sha 271041b499029d639bade9e1c77d3dc64a89f9cd

Update a regression test to test size-zero device to device transfers. (#2411)

view details

Matthew Johnson

commit sha c41c4de5a54e16998fdb6b9e238f301d6299d078

lower fori_loop to scan when possible (#2414) When a fori_loop specialized on trip count is to be evaluated, it's preferable to generate a scan rather than a while_loop because the former is reverse-mode differentiable while the latter is not. Otherwise they're essentially the same; in particular, no extensive inputs/outputs arise unless reverse-mode autodiff is applied. Also fixes #2412.

view details

Matthew Johnson

commit sha 47df7b95c44d27a2bb78636c7642a60cdb622402

change the xla representation of JAX's unit (#2416) * change the xla representation of JAX's unit Previously the representation of JAX's unit value (a sentinel / placeholder) was an empty tuple, but by changing the representation to something else we can further reduce our dependence on runtime tuples. This commit makes the representation fairly easy to change. There are three functions in xla.py that define the representation. Here are versions that would keep the old XLA representation as an empty tuple: ``` def _make_unit(c): return c.Tuple() def _make_abstract_unit(_): return xc.Shape.tuple_shape(()) def _device_put_unit(_, device): return xc.Buffer.make_tuple((), device, backend=xb.get_device_backend(device)) ``` The new representation is as a trivial array. An alternative representation would be nothing at all: we don't need to generate XLA computations that have representations of JAX units. While that alterntaive is probably the best choice, it seemed like it would require a bit more refactoring/bookkeeping (e.g. to allow XLA computations to have a smaller number of outputs than the corresponding JAX function), and would also mean the XLA representation would be a step further removed from the jaxpr representation. So I stuck with a trivial array for now. The mapping from JAX types to XLA types need not be invertible. However, XLA translation rules currently don't take as arguments the corresponding JAX types (abstract values), and there were a few cases where we relied on checking whether an argument's XLA type was that of an empty tuple so as to determine if we were effectively operating on a JAX unit. In particular, the AD-related primitive add_jaxvals_p could in principle add two units, and get lowered to an XLA addition on the unit representation. Previously, the translation rule for add_jaxvals_p checked the XLA type so that adding two empty tuples didn't produce any XLA operation; now it adds its inputs, and so if unit is represented as a trivial array we could be inserting trivial scalar adds where we had none before. However, if that case is ever possible, it doesn't come up in our tests (which I checked by keeping the representation as an empty tuple and then asserting an XLA tuple type is never seen by that translation rule). * add comment about JAX<->XLA array types assumption

view details

Matthew Johnson

commit sha e84a621184967618997e2b0018fa88979735a7cd

new jet implementation, with conv-based rules Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu> Co-authored-by: David Duvenaud <duvenaud@cs.toronto.edu>

view details

Matthew Johnson

commit sha a21fdf8669437a7a052983e18112b3379b955290

more jet rules and tests Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu>

view details

Matthew Johnson

commit sha 7adf9fe84f5b522ee9120a7ff7763ea0708fc394

add more jet rules! Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu> Co-authored-by: Jacob Kelly <jacob.jin.kelly@gmail.com> Co-authored-by: David Duvenaud <duvenaud@cs.toronto.edu>

view details

Jacob Kelly

commit sha ddd52c47301d48cb65b6c7098a164b99362efa3a

adding div and linear prims Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu>

view details

push time in 7 days

pull request comment google/jax

Issue1635 expm frechet

I rebased this PR on top of head, and updated it to use the new custom_jvp API.

sriharikrishna

comment created time in 7 days

PR opened google/jax

Small updates to match NumPy.
+18 -9

0 comment

1 changed file

pr created time in 7 days

push event hawkinsp/jax

Peter Hawkins

commit sha 1909fe7dc51d3652163d08d962797ced0aab110b

Incorporate a few recent NumPy API extensions.

view details

push time in 7 days

push event hawkinsp/jax

Peter Hawkins

commit sha 40f5cce2bf9cc62100e0a4f296cd60feaa51e696

Incorporate a few recent NumPy API extensions.

view details

push time in 7 days

create branch hawkinsp/jax

branch : numpyfixes

created branch time in 7 days

PR opened google/jax

Implement DeviceArray.tobytes().

Move tolist() implementation onto DeviceArray; there's no reason to populate it dynamically.

Fixes https://github.com/google/jax/issues/1791

+11 -1

0 comment

3 changed files

pr created time in 7 days

create branch hawkinsp/jax

branch : tobytes

created branch time in 7 days

issue closed google/jax

How to trace memory usage?

I got an out-of-memory error during the backward pass, after running forward and backward four times plus one extra forward pass. I have no idea why it keeps increasing memory usage. Do you have any method to trace memory usage? I know JAX initially allocates 90% of memory, but I want to trace actual usage.

File ".local/lib/python3.6/site-packages/jax/api.py", line 398, in value_and_grad_f g = vjp_py(onp.ones((), dtype=dtype)) File ".local/lib/python3.6/site-packages/jax/api_util.py", line 66, in apply_flat_fun_nokwargs ans = fun(*args) File ".local/lib/python3.6/site-packages/jax/interpreters/ad.py", line 114, in vjp_ _, arg_cts = backward_pass(jaxpr, consts, (), dummy_args, dummy_primals_and_cts) File ".local/lib/python3.6/site-packages/jax/interpreters/ad.py", line 178, in backward_pass eqn.params, subjaxpr, sub_consts, sub_freevar_vals, invals, cts_in) File ".local/lib/python3.6/site-packages/jax/interpreters/ad.py", line 458, in call_transpose out_flat = primitive.bind(fun, *all_args, **params) File ".local/lib/python3.6/site-packages/jax/core.py", line 569, in call_bind outs = primitive.impl(f, *args, **params) File ".local/lib/python3.6/site-packages/jax/interpreters/xla.py", line 369, in _xla_call_impl return compiled_fun(*args) File ".local/lib/python3.6/site-packages/jax/interpreters/xla.py", line 403, in _execute_compiled out_bufs = compiled.Execute(input_bufs).destructure() RuntimeError: Resource exhausted: Out of memory while trying to allocate 2692743312 bytes.

closed time in 7 days

neon5d

issue comment google/jax

How to trace memory usage?

We just added a device memory profiler to JAX that should be exactly what you need to track down this sort of memory leak.

https://jax.readthedocs.io/en/latest/device_memory_profiling.html
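
For reference, a minimal usage sketch based on that documentation page (it describes jax.profiler.save_device_memory_profile, which writes a pprof-format profile):

import jax
import jax.numpy as jnp
import jax.profiler

x = jnp.ones((1000, 1000))  # allocate something on the device
jax.profiler.save_device_memory_profile("memory.prof")
# Then inspect the profile with, e.g.: pprof --web memory.prof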

I tried to reproduce the OOM you describe with JAX from head and it doesn't seem to reproduce. The memory usage appears to remain constant at around 4GB. I've attached a device memory profile visualization as produced by the memory profiling tool (on CPU). profile002.pdf

Please feel free to reopen if you have any more problems!

neon5d

comment created time in 7 days

issue closed google/jax

Could not find cublas.h

I'm using the CUDA 10.1 install on Arch Linux and got the following error. The problem seems to be that although I specified a path of /opt/cuda, the script looks for the includes in /opt/cuda/lib64.

$ python build.py --enable_march_native --enable_mkl_dnn --enable_cuda --cuda_path /opt/cuda
...
Bazel binary path: ./bazel-0.24.1-linux-x86_64
Python binary path: ~/miniconda3/envs/pytorch_dev/bin/python
MKL-DNN enabled: yes
-march=native: yes
CUDA enabled: yes
CUDA toolkit path: /opt/cuda
CUDNN library path: /opt/cuda

Building XLA and installing it in the jaxlib source tree...
ERROR: Skipping ':install_xla_in_source_tree': error loading package 'build': in ~/.cache/bazel/_bazel_david/9d74510baba13ae20478b5800bcfdf90/external/org_tensorflow/tensorflow/core/platform/default/cuda_build_defs.bzl: Encountered error while reading extension file 'cuda/build_defs.bzl': no such package '@local_config_cuda//cuda': Traceback (most recent call last):
	File "~/.cache/bazel/_bazel_david/9d74510baba13ae20478b5800bcfdf90/external/org_tensorflow/third_party/gpus/cuda_configure.bzl", line 1266
		_create_local_cuda_repository(repository_ctx)
	File "~/.cache/bazel/_bazel_david/9d74510baba13ae20478b5800bcfdf90/external/org_tensorflow/third_party/gpus/cuda_configure.bzl", line 988, in _create_local_cuda_repository
		_get_cuda_config(repository_ctx)
	File "~/.cache/bazel/_bazel_david/9d74510baba13ae20478b5800bcfdf90/external/org_tensorflow/third_party/gpus/cuda_configure.bzl", line 714, in _get_cuda_config
		find_cuda_config(repository_ctx, ["cuda", "cudnn"])
	File "~/.cache/bazel/_bazel_david/9d74510baba13ae20478b5800bcfdf90/external/org_tensorflow/third_party/gpus/cuda_configure.bzl", line 694, in find_cuda_config
		auto_configure_fail(("Failed to run find_cuda_config...))
	File "~/.cache/bazel/_bazel_david/9d74510baba13ae20478b5800bcfdf90/external/org_tensorflow/third_party/gpus/cuda_configure.bzl", line 325, in auto_configure_fail
		fail(("\n%sCuda Configuration Error:%...)))

Cuda Configuration Error: Failed to run find_cuda_config.py: Could not find any cublas_api.h matching version '' in any subdirectory:
        ''
        'include'
        'include/cuda'
        'include/*-linux-gnu'
        'extras/CUPTI/include'
        'include/cuda/CUPTI'
of:
        '/opt/cuda/extras/CUPTI/lib64'
        '/opt/cuda/lib64'
        '/opt/cuda/nvvm/lib64'
        '/usr'
        '/usr/lib'
        '/usr/lib/intel'
        '/usr/lib/libfakeroot'
        '/usr/lib/openmpi'

As a workaround I created a symlink from /usr/include/cuda to /opt/cuda/include, but that's not an ideal solution.

closed time in 7 days

dhpollack

issue comment google/jax

Could not find cublas.h

Seems like this issue may be moot now, especially given the last comment. Closing; feel free to reopen if you are still having problems!

dhpollack

comment created time in 7 days

issue closed google/jax

Buffer donation to a jit function

Below is a CNN applied iteratively to a 2 GB input. It produces a 4 x 2 GB = 8 GB peak memory consumption.

import jax.numpy as np
import jax.random as random
from jax import lax
from jax import jit

@jit
def f(x):
  for _ in range(10):
    x = lax.conv_general_dilated(x, np.ones((3, 3, 1, 1)), (1, 1), 'SAME', 
                                 dimension_numbers=('NHWC', 'HWIO', 'NHWC'))
  return x

x = random.normal(random.PRNGKey(1), (2**19, 2**5, 2**5, 1))  
# (2**20, 2**5, 2**5, 1)) OOMs!
x = f(x)

Without JIT, the peak memory consumption is 2 x 2 GB = 4 GB, as expected.

Would be great to achieve a comparable memory usage with JIT by input buffer donation to the jit function (not sure on the exact terminology).

Thanks a lot!

closed time in 7 days

romanngg

issue comment google/jax

Buffer donation to a jit function

Buffer donation has been checked in!
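
For reference, a minimal sketch of applying it to the example above via jit's donate_argnums argument (donation is honored on backends that support it):

import functools
import jax.numpy as np
from jax import jit, lax

@functools.partial(jit, donate_argnums=(0,))
def f(x):
  # Donating x allows XLA to reuse its buffer for the output where supported.
  for _ in range(10):
    x = lax.conv_general_dilated(x, np.ones((3, 3, 1, 1)), (1, 1), 'SAME',
                                 dimension_numbers=('NHWC', 'HWIO', 'NHWC'))
  return x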

romanngg

comment created time in 7 days

issue closed google/jax

Compilation error (nv-nsight-cu-cli not found)

I am trying to compile JAX from source with GPU support. My GCC version is 8.3. When I run

$ python build/build.py --enable_cuda

It looks for the file /usr/local/cuda-10.1/NsightCompute-*/nv-nsight-cu-cli but it doesn't exist.

I fixed it temporarily by creating the following soft link.

$ ln -s /usr/local/cuda-10.1/NsightCompute-*/nv-nsight-cu-cli.orig /usr/local/cuda-10.1/NsightCompute-*/nv-nsight-cu-cli

closed time in 7 days

adityaiitb

issue comment google/jax

Compilation error (nv-nsight-cu-cli not found)

I suspect this issue is obsolete now; closing. Please feel free to reopen if you can still reproduce it!

adityaiitb

comment created time in 7 days

issue closedgoogle/jax

matmul slow for complex dtypes

Hey, thanks for the great work here!

I noticed that matmuls for complex dtypes are ~20x to 25x slower on my MacBook than they are for real dtypes. Here is a simple script that does the timing. I'm wondering whether this is expected, and what I can do to speed it up. Thanks!

import numpy as np
import jax
from jax import config
config.update('jax_enable_x64',True)
import time


@jax.jit    
def do_matvec_simple(matrix, vector):
    res = 0
    for _ in range(100):
        res += matrix @ vector
    return res

@jax.jit    
def do_matmul_simple(matrix1, matrix2):
    res = 0
    for _ in range(100):
        res += matrix1 @ matrix2
    return res
    
def run_timings_dot(dtype, D):
    matrix = jax.numpy.array(np.random.rand(D,D).astype(dtype))
    vector = jax.numpy.array(np.random.rand(D).astype(dtype))
    t1=time.time()
    for _ in range(100):
        res = matrix @ vector
    res.block_until_ready()
    print(f'loop over 100 matrix-vector muls in dtype {np.dtype(dtype).name}', time.time() -t1)

    res = do_matvec_simple(matrix, vector)
    res.block_until_ready()
    t1 = time.time()
    res = do_matvec_simple(matrix, vector)
    res.block_until_ready()
    print(f'jit 100 do_matvec_simple for dtype {np.dtype(dtype).name}', time.time() - t1)
    
def run_timings_matmul_simple(dtype, D):
    A = jax.numpy.array(np.random.rand(D,D).astype(dtype))
    B = jax.numpy.array(np.random.rand(D,D).astype(dtype))
    t1=time.time()
    for _ in range(100):
        res = A@B
    res.block_until_ready()
    print(f'loop over 100 matrix-matrix muls in dtype {np.dtype(dtype).name}', time.time() -t1)

    res = do_matmul_simple(A,B)
    res.block_until_ready()
    t1 = time.time()
    res = do_matmul_simple(A,B)
    res.block_until_ready()
    print(f'jit 100 do_matmul_simple for dtype {np.dtype(dtype).name}', time.time() - t1)        
    

print('########  timings for matrix-vector  ###########')
print('      ----------   float64   --------------')
run_timings_dot(np.float64, 1000)
print('      ----------   complex128   --------------')
run_timings_dot(np.complex128, 1000)
print()
print()
print('########  timings for matrix-matrix  ###########')
print('      ----------   float64   --------------')
run_timings_matmul_simple(np.float64, 400)
print('      ----------   complex128   --------------')
run_timings_matmul_simple(np.complex128, 400)

Update: disabling double precision seems to increase the slowdown to ~100x.

closed time in 8 days

mganahl

issue commentgoogle/jax

matmul slow for complex dtypes

This bug should be fixed in jaxlib 0.1.48 or newer (jaxlib 0.1.50 was released yesterday).
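To confirm which versions you have installed, something like the following should work (jaxlib ships its version string in a version module):

import jax
from jaxlib import version as jaxlib_version

print("jax:   ", jax.__version__)
print("jaxlib:", jaxlib_version.__version__)  # want >= 0.1.48 for this fix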

mganahl

comment created time in 8 days

PR merged google/jax

Add a heap profiler API and document it. cla: yes
+1239 -5

1 comment

9 changed files

hawkinsp

pr closed time in 9 days

push eventgoogle/jax

Peter Hawkins

commit sha 5116fd47aa9cf44888b88c685c380422bf0fdc50

Add a heap profiler API and document it. (#3576)

view details

push time in 9 days

pull request commentgoogle/jax

Add a heap profiler API and document it.

I'm going to merge this given that Skye has reviewed; if Jake has comments, I'm happy to address them in a follow-up PR.
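As a quick illustration of the API added in this PR (treat the exact name as an assumption if you are on a different release), the device memory profile can be dumped in pprof format and then inspected with pprof:

import jax
import jax.numpy as np
import jax.profiler

x = np.ones((1024, 1024))  # keep something live on the default device

# Write a snapshot of live device allocations; view it with e.g.
# `pprof --web memory.prof` (the gperftools pprof also works).
jax.profiler.save_device_memory_profile("memory.prof")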

hawkinsp

comment created time in 9 days

push eventhawkinsp/jax

Peter Hawkins

commit sha 0332dbdf6d6884af42d690b2bdb393d0e03c5257

Add more crossreferences.

view details

push time in 9 days

push eventhawkinsp/jax

Peter Hawkins

commit sha 0bb119ca4a9eb229d5179a4f94c73dd2b2cdddd5

Incorporate review comments.

view details

push time in 9 days

push eventhawkinsp/jax

Peter Hawkins

commit sha 71c1a120f1f64f3ed59f007362d4f2404eedd5f2

Fix flake8 warnings.

view details

push time in 9 days

push eventhawkinsp/jax

Peter Hawkins

commit sha d22604482dccc7975929a2e8ce1b72e004e03e2d

Add note about gperftools version of pprof.

view details

push time in 9 days

PR opened google/jax

Add a heap profiler API and document it.
+1224 -5

0 comment

9 changed files

pr created time in 9 days

create barnchhawkinsp/jax

branch : heapprofile

created branch time in 9 days

created taggoogle/jax

tagjaxlib-v0.1.50

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

created time in 9 days

push eventgoogle/jax

Peter Hawkins

commit sha b576417e5bb760fbb13097b7e3277f2715c06965

Update README for jaxlib 0.1.50 release. (#3574)

view details

push time in 9 days

PR merged google/jax

Update README for jaxlib 0.1.50 release. cla: yes
+8 -10

0 comment

1 changed file

hawkinsp

pr closed time in 9 days

PR opened google/jax

Update README for jaxlib 0.1.50 release.
+8 -10

0 comment

1 changed file

pr created time in 9 days

push eventhawkinsp/jax

Daniel Johnson

commit sha 98d46d39aa9138fbd8143411b99dc42ef0a36ad3

Implement indexing helpers

view details

Daniel Johnson

commit sha 72593f0489315cefe685f36f6c89ae8b81dfc860

Modify syntax to `x.at[idx].set(y)` and similar.

view details

Daniel Johnson

commit sha ab0a903172fdff74287b047ae7e0e26b6fc1c0fc

Merge branch 'master' into index_sugar

view details

Daniel Johnson

commit sha 9e429907c1236e600473d2ff31008cb81e9d1395

Add support for `mul`

view details

Daniel Johnson

commit sha f4f67b065f5b4663474df8e91785524101387a36

Remove unused textwrap

view details

Peter Hawkins

commit sha d356c41f77401fb796a763f19beda65f86300f50

Release jaxlib 0.1.44. (#2740)

view details

Jamie Townsend

commit sha 708107ebe30826d428b56cebe553b0e9016a0c80

Add numpy.rint to lax numpy (#2724) * Add numpy.rint to lax numpy * Use round_to_nearest_even for numpy.rint * Add rint to jax.numpy docs * Fix np.rint float promotion

view details

Chris Jones

commit sha 903b50ebf102b3152d60a7cc0e465bdeeee2c8f5

Some cleanup and reformatting in `xla.py`. - Make creation of a few dictionaries more readable. - Use f-strings where possible. - Remove unused imports and function parameters. - Don't format string before passing to `log` function.

view details

Chris Jones

commit sha f7070e56d19b0cbdac7d6f2474bae6dde673d612

Merge branch 'master' into changelist/306845248

view details

Peter Hawkins

commit sha 42887172b04c8933d628d5101221d38b4815114f

Add a regression test that runs the same computation on all devices that are present. (#2741)

view details

Roy Frostig

commit sha d906c89e0bf39f6b41d360db27b1c70b9b29fdad

fix scipy_signal_test convolve failures

view details

Matthew Johnson

commit sha 6d889efd30d8c1b13eceb2493e0cb3b726fbec0a

Merge pull request #2743 from chr1sj0nes/changelist/306845248 Some cleanup and reformatting in `xla.py`.

view details

Daniel Johnson

commit sha a5efe842562af81cae4df7aca95e52981dea2802

Update names and documentation.

view details

Daniel Johnson

commit sha 835949a36a22df40e3206a29c84a520e19793c40

Fix typo

view details

Jake VanderPlas

commit sha 3265b8c3c7839a2e805ab32d3a314e512127e8e3

Thread precision through np.convolve & np.correlate

view details

Jake VanderPlas

commit sha aeb0d036cb71af55212af2b112d698a8b9bd739d

set precision=HIGHEST only for TPU test

view details

Adam Paszke

commit sha 8c00b35f21e51163dcbd8b427d4c861b6d6cc13a

Add FIXMEs for AD type errors

view details

George Necula

commit sha 7d716b8306d6df0c012e9bfffeb1df439e3e560d

Add a simple form of partial evaluation for while_loop. (#2497) The issue that I wanted to fix was that when running grad(while_loop), the error was a cryptic assertion failure (that all primals are known after linearization, in ad.py:linearize). I could not figure out how to detect before that assertion that we are doing a reverse AD for while_loop. So, I implemented a simple form of partial evaluation, to allow the primals after linearization to be known, so that the code proceeds and can then fail gracefully when trying to transpose the while. This is not a proper implementation of partial evaluation. The known outputs are computed early, properly. But the unknown outputs are computed by a *whole* computation of, including the known parts. Fixes issue: #2129

view details

Peter Hawkins

commit sha 9a5b8d626a484d7dfded25bee7b408ce7d37816b

Assert that reduction computations don't have constants. (#2754) This case wouldn't work anyway, because there's no good way to pass constants to an XLA reducer.

view details

Stephan Hoyer

commit sha e6f0b8d87dbc48e3b36ec980485a19a8543508d1

Raise an error if stop_gradient is called on non-arrays (#2750) * Raise an error if stop_gradient is called on non-arrays * Fix incorrect usage of stop_gradient in solve() * fix *other* misuse of stop_gradient

view details

push time in 9 days

PR merged google/jax

Fix test failures on GPU. cla: yes
+4 -4

0 comment

4 changed files

hawkinsp

pr closed time in 9 days

push eventgoogle/jax

Peter Hawkins

commit sha 66cea0277cf3f02ed3975465761c061a89e084a0

Fix test failures on GPU. (#3572)

view details

push time in 9 days

PR opened google/jax

Fix test failures on GPU.
+4 -4

0 comment

4 changed files

pr created time in 9 days

create barnchhawkinsp/jax

branch : gpuerrors

created branch time in 9 days

pull request commentgoogle/jax

Refactor tests to define common functionality in test_util

Yes, unfortunately parameterizing the device under test is tricky because it needs to happen after flag parsing but before test-case instantiation. This is why we have an unusual idiom of parsing the flags early in tests, right after the module imports. If you can think of a better scheme for running the same tests on multiple devices that works in all the configurations in which we need to run tests, I'm certainly open to it!
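For concreteness, the idiom looks roughly like this in a test module (a sketch using the helpers jax's own tests use; names may differ slightly between releases):

from absl.testing import absltest

import jax.test_util as jtu
from jax.config import config

# Parse absl flags (including the device-under-test flag) right after the
# imports, before any test cases are instantiated.
config.parse_flags_with_absl()


class ExampleTest(jtu.JaxTestCase):

  def testTrivial(self):
    self.assertEqual(1 + 1, 2)


if __name__ == "__main__":
  absltest.main()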

jakevdp

comment created time in 9 days

push eventgoogle/jax

Peter Hawkins

commit sha 8f86b139fe78b9e65bcdec7f4244cfc105b8b629

Update docker script for CUDA 11. (#3571)

view details

push time in 9 days

PR merged google/jax

Update docker script for CUDA 11 and libcudnn8. cla: yes
+8 -4

0 comment

1 changed file

hawkinsp

pr closed time in 9 days

PR opened google/jax

Update docker script for CUDA 11 and libcudnn8.
+8 -4

0 comment

1 changed file

pr created time in 9 days

push eventhawkinsp/jax

Peter Hawkins

commit sha a141cc6e8d36ff10e28180683588bedf5432df1a

Make CUDA wheels manylinux2010 compliant, add CUDA 11, drop CUDA 9.2 (#3555) * Use dynamic loading to locate CUDA libraries in jaxlib. This should allow jaxlib CUDA wheels to be manylinux2010 compliant. * Tag CUDA jaxlib wheels as manylinux2010. Drop support for CUDA 9.2, add support for CUDA 11.0. * Reorder CUDA imports.

view details

Peter Hawkins

commit sha 2a6fc316c3e2de99abdf4a97656abbed44a1c626

Update XLA in preparation for a new jaxlib release (0.1.50). (#3560)

view details

Matthew Johnson

commit sha db80ca5dd87d27e1800834c2828298aa3f6d8992

allow closures for odeint dynamics functions (#3562) * allow closures for odeint dynamics functions fixes #2718, #3557 * add tests for odeint dynamics closing over tracers

view details

Matthew Johnson

commit sha c2501d1bef7e97e597c2108c86645f265a249326

update version and changelog for pypi (#3564)

view details

Matthew Johnson

commit sha 062ce297ddf9056ca7743a2d262b0db070eb4553

removed stale faq entries (#3565)

view details

Jake Vanderplas

commit sha 4b1bb1890937e9fb81de210305c77c13fb4952ec

Fix logsumexp test (#3563) Previously, this test could generate an axis out of bounds error with large num_generated_cases. (discovered in the process of testing #3561)

view details

Matthew Johnson

commit sha 26c6c3a457414577f7232699095877d6c86d032d

fix error when doing forward-mode of odeint (#3566) fixes #3558

view details

Peter Hawkins

commit sha 99a7e2c9ef8e9a9d4218b64ccd4f2990e08a9c12

Update docker script for CUDA 11.

view details

push time in 9 days

issue commentgoogle/jax

Feature request: add jax[cudaversion] to pypi

#3555 solves one of the problems (manylinux2010 compliance), but not all of them. The others are:

  • we would need a PyPI size-limit exception
  • we would probably need to consider carefully how to structure our packages if there are multiple variants that could be installed at the same time.
KristianHolsheimer

comment created time in 9 days

issue commentgoogle/jax

Using non-nvidia GPU

There's no technical blocker to using JAX on AMD GPUs. We on the JAX team simply don't have access to any AMD GPUs at the moment to develop or test the necessary changes (which are probably not that large, given that most of the necessary work has already been done in the context of TensorFlow).

Contributions are welcome!

ricardobarroslourenco

comment created time in 9 days

issue commentgoogle/jax

float32 vs float64 issues

Any news?

benjaminpope

comment created time in 9 days

issue commentgoogle/jax

illegal memory access while two GPUs is used

One thing I'd like to double-check: you aren't also using TensorFlow in the same process, are you? There's some sort of multi-GPU interaction between the two that I haven't yet tracked down. https://github.com/google/jax/issues/2742

EyalRozenberg1

comment created time in 9 days

issue commentgoogle/jax

CUDA issues

I don't think this is an XLA or JAX issue in particular; it sounds to me like the Dockerfile needs fixing. Do any CUDA programs work in this configuration (e.g., TF or PyTorch)? I'm not sure there's much we can do here; is there a way you can install the matching CUDA toolkit version in the Docker configuration?

ericmjl

comment created time in 9 days

issue commentgoogle/jax

Wrong outputs from FFT on GPU

I've asked the XLA folks to work around the NVIDIA IRFFT bug by cloning the input for all IRFFTs. Hopefully that can happen relatively quickly.

(If any NVIDIA folks are reading this, this is partners bug #2959622.)

Interestingly, PyTorch also seems to have experienced the same bug: https://github.com/pytorch/pytorch/issues/34551

kohr-h

comment created time in 9 days
