google/jax 6669

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Semantic Parser with Execution

Stores paper references, outputs to bib/html, does basic sanity checking on bib entries

a safe, concurrent, practical language

Vowpal Wabbit is a machine learning system which pushes the frontier of machine learning with techniques such as online, hashing, allreduce, reductions, learning2search, active, and interactive learning.

push event google/jax

commit sha 664a4e123d83fb1d17cd31451fbebc6f1568707a

VJP of cond, via partial eval + transpose (#2091) VJP (grad) of lax.cond, via partial eval + transpose Co-authored-by: Matthew Johnson <mattjj@google.com>
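A minimal sketch of what reverse-mode support for `lax.cond` enables, assuming the current `lax.cond(pred, true_fun, false_fun, operand)` signature (older releases used a different argument order). The function `f` and its branches are illustrative only: `jax.grad` differentiates through whichever branch the predicate selects.

```python
import jax
from jax import lax


def f(x):
    # Both branches are traced; the one selected by `x > 0` runs at run time.
    return lax.cond(x > 0, lambda x: x ** 2, lambda x: -3.0 * x, x)


# Reverse mode (VJP) follows the branch that was actually taken.
print(jax.grad(f)(2.0))   # 4.0, derivative of x**2 at x = 2
print(jax.grad(f)(-1.0))  # -3.0, derivative of -3x
```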

push time in 19 days

PR merged google/jax

pr closed time in 19 days

push event google/jax

commit sha 6c07d317e2e5590463291e66d0adb90ae81dc3f6

add grad of cond to changelog

push time in 19 days

push event google/jax

commit sha 2341bc1dadadda526b4b5b175513e6ea8a0d4511

fix tests affected by transpose of lax.cond Co-authored-by: Matthew Johnson <mattjj@google.com>

push time in 19 days

push event google/jax

commit sha bc48bd1ea12ad7c3d3d50e3e88b1e55504283cd6

test grad of lax.cond with a closure
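A sketch of the closure case this test exercises, assuming the current `lax.cond` signature; `g` is a hypothetical example. The branches close over `y` instead of receiving it as an operand, and the gradient with respect to `y` still flows through the selected branch.

```python
import jax
from jax import lax


def g(x, y):
    # Both branches close over `y`; only `x` is passed as the cond operand.
    return lax.cond(x > 0, lambda x: x * y, lambda x: x + y, x)


print(jax.grad(g, argnums=1)(2.0, 3.0))   # 2.0, since d(x*y)/dy = x
print(jax.grad(g, argnums=1)(-2.0, 3.0))  # 1.0, since d(x+y)/dy = 1
```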

push time in 19 days

push event google/jax

commit sha 070455c6d3f1228709c7b3c705576cf8d07b4f8a

fix lax.cond jit test

commit sha 792c3f3a893ce1ce1bfd5b45e8a9e37f123cd943

WIP transpose of lax.cond Co-authored-by: Matthew Johnson <mattjj@google.com>

commit sha 393993e78f957cc5131a2561dc44ed9b101281fe

WIP debugging higher-order transpose of lax.cond Co-authored-by: Matthew Johnson <mattjj@google.com>

push time in 20 days

push event google/jax

commit sha a481b77a7d23f7c48bf217ecd288612042df2ba1

more lax.cond grad tests

push time in 21 days

push event google/jax

commit sha 4eee4caa6d4ec21bbe2a6bdadc3d1cc8ee1082e8

WIP transpose of lax.cond Co-authored-by: Matthew Johnson <mattjj@google.com>

push time in 21 days

push event google/jax

commit sha b3f4d7083cded6c89040d02a1ad3b073ea5977a2

WIP partial eval of lax.cond Co-authored-by: Matthew Johnson <mattjj@google.com>

commit sha 03e1bbf91a414dfd6e31f1a8eb53a7ddfba3a68c

partial eval of lax.cond Co-authored-by: Matthew Johnson <mattjj@google.com>
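One user-visible place where partial evaluation of `lax.cond` matters is `jax.linearize`, which runs the JVP and partially evaluates it: the primal output is computed immediately, and the tangent computation is staged out as a reusable linear function. A sketch, assuming the current `lax.cond` signature; `f` is illustrative.

```python
import jax
import jax.numpy as jnp
from jax import lax


def f(x):
    # sin on the positive branch, cos otherwise.
    return lax.cond(x > 0, jnp.sin, jnp.cos, x)


# Primal evaluated now; f_jvp is the staged-out linear (tangent) part.
y, f_jvp = jax.linearize(f, 1.0)
print(y)           # sin(1.0)
print(f_jvp(1.0))  # cos(1.0), the derivative of sin at 1.0
print(f_jvp(2.0))  # 2 * cos(1.0), because f_jvp is linear in its input
```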

commit sha 13c1227180a56a682cabb2d8a9eb9f7aa4174e2e

more lax.cond tests Co-authored-by: Matthew Johnson <mattjj@google.com>

push time in 22 days

PR opened google/jax

pr created time in 22 days

push event google/jax

commit sha 8449c4af9bd5c4d9f58a1232087eb0fd0d11f14e

implement JVP of cond Co-authored-by: Matthew Johnson <mattjj@google.com>
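A sketch of forward-mode differentiation through `lax.cond` with `jax.jvp`, assuming the current `lax.cond` signature; `f` is a hypothetical example. The tangent output is the directional derivative of whichever branch ran.

```python
import jax
from jax import lax


def f(x):
    return lax.cond(x > 0, lambda x: x ** 3, lambda x: 2.0 * x, x)


primal, tangent = jax.jvp(f, (2.0,), (1.0,))
print(primal, tangent)  # 8.0 12.0, since d(x**3)/dx = 3 * 2**2

primal, tangent = jax.jvp(f, (-1.0,), (1.0,))
print(primal, tangent)  # -2.0 2.0, since d(2x)/dx = 2
```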

commit sha 363f9e07fc157c97678a26dc7b70d4df3ad07a5d

Merge pull request #2045 from google/ad-cond implement JVP of cond

push time in a month

PR opened google/jax

pr created time in a month

push event google/jax

commit sha afb8af19ff7474561c3c904e03c63dbf8f57de3f

implement JVP of while loop Co-authored-by: Matthew Johnson <mattjj@google.com>

commit sha 335ecb97b838ea0185c26ec744d63c3a096a858b

test JVP of while loop, and fix the nonzero tangent calculation in the JVP rule
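The JVP rule for `lax.while_loop` has to track which loop-carried values actually have nonzero tangents. A sketch of a loop where only part of the carry is differentiable (the names are illustrative): the integer counter `i` carries no tangent, while `x` does.

```python
import jax
from jax import lax


def f(x):
    def cond_fun(carry):
        i, _ = carry
        return i < 4

    def body_fun(carry):
        i, x = carry
        # `i` is an integer counter with no tangent; only `x` is differentiated.
        return i + 1, x * 1.5

    _, out = lax.while_loop(cond_fun, body_fun, (0, x))
    return out


primal, tangent = jax.jvp(f, (1.0,), (1.0,))
print(primal, tangent)  # both equal 1.5**4 = 5.0625
```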

commit sha 28f70cc8f8ac4fc4b98ce85f51b389cec8635704

Merge pull request #1980 from google/jvp-while implement JVP of while loop. closes #650

push time in a month

issue closed google/jax

Are there plans to support jacfwd for while_loops? TensorFlow, AFAIK, supports gradients through while_loops in backprop. I'm really interested in JAX's forward-mode support as well.

closed time in a month
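A sketch of the `jacfwd`-through-`while_loop` case the issue asks about. The example function is hypothetical: its trip count depends on the value of `x`, so it cannot be written as a loop with fixed bounds. As far as I know, reverse-mode `jax.grad` through `lax.while_loop` is still unsupported; forward mode is the supported path.

```python
import jax
from jax import lax


def f(x):
    # Keep doubling x until it reaches 10; the iteration count depends on x.
    return lax.while_loop(lambda x: x < 10.0, lambda x: 2.0 * x, x)


print(f(3.0))              # 12.0 (3 -> 6 -> 12)
print(jax.jacfwd(f)(3.0))  # 4.0, since f(x) = 4x near x = 3
```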

proteneer

PR merged google/jax

```
In [1]: from jax import lax, jvp
In [2]: def f(x): return lax.fori_loop(0, 3, lambda i, x: x * 2, x)
In [3]: f(2.)
Out[3]: DeviceArray(16., dtype=float32)
In [4]: jvp(f, (2.,), (1.,))
Out[4]: (DeviceArray(16., dtype=float32), DeviceArray(8., dtype=float32))
In [5]: def f(x): return lax.fori_loop(0, 3, lambda i, x: x * (i+1), x)
In [6]: f(2.)
Out[6]: DeviceArray(12., dtype=float32)
In [7]: jvp(f, (2.,), (1.,))
Out[7]: (DeviceArray(12., dtype=float32), DeviceArray(6., dtype=float32))
```

pr closed time in a month

push event google/jax

commit sha 335ecb97b838ea0185c26ec744d63c3a096a858b

test JVP of while loop, and fix the nonzero tangent calculation in the JVP rule

push time in a month

PR opened google/jax

pr created time in a month

push event google/jax

commit sha 1ca9e9b251ebdea78df168d7a86f4719f49c553b

Concatenate error messages under numpy.{zeros,ones,full}. Closes #1822

push time in a month

issue closed google/jax

Better error messaging for lax.full

Calling jnp.zeros or jnp.ones with a float instead of an integer leads to a misleading error message:

```
> jnp.zeros(1.0)
...
TypeError: `full` requires shapes to be concrete. If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions instead.
```

This is because lax.full ignores the (correct) error message created by lax._canonicalize_shape and replaces it with the one shown to the user.

The code could be improved by showing the original error message instead, or appending the original error message in case it is not easily detectable if the error is caused by jit or an actual type error.
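The suggestion above amounts to exception chaining: keep the generic `jit` hint, but append the underlying message instead of discarding it. A minimal pure-Python sketch of that pattern. `canonicalize_shape` and `full_shape` are hypothetical stand-ins, not JAX's actual internals, and the exact wording of the real messages may differ.

```python
def canonicalize_shape(shape):
    """Hypothetical stand-in for lax._canonicalize_shape: accept ints only."""
    dims = tuple(shape) if isinstance(shape, (tuple, list)) else (shape,)
    for d in dims:
        if not isinstance(d, int):
            raise TypeError(
                f"Shapes must be 1D sequences of integer scalars, got {shape!r}")
    return dims


def full_shape(shape):
    """Hypothetical shape check for a `full`-like function.

    Rather than replacing the original error, it keeps the jit hint and
    appends the underlying message, so a plain type error stays visible.
    """
    try:
        return canonicalize_shape(shape)
    except TypeError as err:
        hint = ("`full` requires shapes to be concrete. If using `jit`, try using "
                "`static_argnums` or applying `jit` to smaller subfunctions instead.")
        raise TypeError(f"{hint}\n{err}") from err


print(full_shape(3))  # (3,)
try:
    full_shape(1.0)
except TypeError as err:
    print(err)  # the jit hint, followed by the integer-shape message
```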

closed time in a month

thomaskeck

push event google/jax

commit sha 16484352685af5934facd32d07cd9664192e57ac

enable kernel regression example test

commit sha 6b39104a6d37c7a9add44ad15e5967a025bc42ad

Merge branch 'master' into kernel-example-test

commit sha 82ce209ae694bdbffa60155edbe34f4535e26d14

Merge branch 'master' into kernel-example-test

commit sha 1b2350bf36bc2eabfacea8f409eb04e74514f69f

Merge pull request #1730 from google/kernel-example-test enable kernel regression example test

push time in 3 months

push event google/jax

commit sha c601950b1bf5762dce0cd6be18be6a7a5e7e3473

Build Mac wheels for Python 3.8 with scipy 1.3.2. (#1739) scipy 1.3.1 never had a Python 3.8 wheel.

commit sha c77cdb51441cfeb198f63a7619409e2bd2d3647c

Bump README to jaxlib 0.1.35

commit sha eff7b45dbacad0358246130c70a1b44901945bda

Add jit decorators to most functions in jax.scipy.linalg. (#1741)

commit sha 27aa76e6a6895fa62e2fe23555178c59df411c70

Add precision to jax.numpy functions that use lax.dot_general (#1728) * Add precision to jax.numpy functions that use lax.dot_general * Test precision argument * check default precision * test with jaxprs * Document precision
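A short sketch of the `precision` argument on `jax.numpy` functions backed by `lax.dot_general`. The arrays are illustrative. `lax.Precision.HIGHEST` requests full-precision computation on backends (such as TPU) whose default may use reduced precision; on CPU the result is typically identical.

```python
import jax.numpy as jnp
from jax import lax

a = jnp.ones((2, 3), dtype=jnp.float32)
b = jnp.ones((3, 4), dtype=jnp.float32)

# `precision` is forwarded to the underlying lax.dot_general call.
c = jnp.dot(a, b, precision=lax.Precision.HIGHEST)
d = jnp.matmul(a, b, precision=lax.Precision.DEFAULT)
print(c.shape, c[0, 0])  # (2, 4) 3.0
```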

commit sha 39e17398ee1c0b249856209056a37ba99cce426e

jaxlib build improvements (#1742)

commit sha 9dfb3cac28d269a3104942134bfb12abbd60b5dc

Relax test tolerances to fix flakiness. (#1743)

commit sha 1314fb7cb10dc5d484aef3f44feadc0814712ad5

Bump jaxlib version to 0.1.36 and update WORKSPACE.

commit sha dc5a599a9c9cc7eca8a86a234debcc55f84c18a0

Fix bug in jax repeat which caused a value error for repeat arguments containing 0. (#1740)
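A sketch of the behavior this fix concerns: a repeat count of 0 should drop the corresponding element, and a total repeat count of 0 should yield an empty array rather than an error. The input values are illustrative.

```python
import jax.numpy as jnp

x = jnp.array([10, 20, 30])

# Per-element counts: 10 is repeated 0 times, 20 twice, 30 once.
y = jnp.repeat(x, jnp.array([0, 2, 1]))
print(y)  # [20 20 30]

# A scalar count of 0 repeats every element zero times.
empty = jnp.repeat(x, 0)
print(empty.shape)  # (0,)
```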

commit sha 15d276ce0ca7e8c99598fe07fdcf33514230e627

Bump README to jaxlib 0.1.36

commit sha 9c966f9fb54f01753ca434b429524796cc1beb5f

Fix `as_abstract_value`. The `JaxprTracerTuple` appears to no longer exist, and the function could return `None` instead of raising an error if `pv` was a type other than `AbstractValue`.

commit sha 20a7e7b3f49287ac59bfdef6ee80195209a143b6

Remove `as_abstract_val`.

commit sha b358c27c92f614b6ab24e7c99ea5902a4da92e39

replace x.shape with onp.shape(x) in random.py fixes #1748 (thanks @vitchyr)
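Why `onp.shape(x)` is more robust than `x.shape`: Python scalars and lists have no `.shape` attribute, while `numpy.shape` accepts any array-like. A small illustration with plain NumPy:

```python
import numpy as np

# Python scalars and lists have no `.shape` attribute...
print(hasattr(2.0, "shape"))       # False
# ...but np.shape accepts any array-like input.
print(np.shape(2.0))               # ()
print(np.shape([[1, 2, 3]]))       # (1, 3)
print(np.shape(np.zeros((4, 5))))  # (4, 5)
```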

commit sha 6f3cb1c3eef772268ea061158196d9d33412bb37

Add jax.devices(), etc. to the docs.

commit sha 3d1d140acd059138fb5f7c1d9be61a7edf290c05

Disable failing test (#1744)

commit sha 1dcddde4a0eb71230a4190e63775133d483b956b

Add jax.numpy.dtype as an alias of numpy.dtype. (#1750)

commit sha a9d1b770e428f197c6b296df884e868d66c434ab

Relax test tolerances to reduce flakiness. (#1751) * Relax test tolerances to reduce flakiness. * Relax test tolerance for np.cov test.

commit sha 9f86b53af386a586936d8743b843da20f955024f

Revert LaxBackedNumpyTests.testCov to use all_dtypes

commit sha 8f2a050e1ed0817d368d3c38c92fa4e7b4f8301c

fix cov test tol Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>

commit sha 45a1ba0bbcfc64e9c65f9d2589bb474a59312c99

Make more tests pass on TPU. (#1752)

commit sha 67038321f8b33bdab99c1fc3d003719553769c61

Revert support for building a non-GPU build with --config=cuda enabled. (#1757) It turns out there are implicit CUDA dependencies inside the TF libraries used by JAX, so the attempt to disable GPU dependencies conditionally didn't work.

push time in 3 months

push event google/jax

commit sha 1817cab012a239181dd65c78facaae4d597f5bce

Use double-where trick to avoid NaNs in grad(sinc)
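A sketch of the double-where trick, using a sinc-like function written here for illustration (not necessarily the exact code in the commit). With a single `jnp.where`, the discarded branch `sin(x)/x` is still differentiated at `x = 0`, and dividing its zero cotangent by `x = 0` produces NaN. An inner `where` first replaces the problematic input with a harmless value, so both branches stay finite.

```python
import jax
import jax.numpy as jnp


def naive_sinc(x):
    # The untaken branch sin(x)/x is still differentiated at x = 0,
    # where its backward pass computes 0 / 0 = NaN.
    return jnp.where(x == 0, 1.0, jnp.sin(x) / x)


def safe_sinc(x):
    is_zero = x == 0
    # Inner where: substitute 1.0 so sin(safe_x)/safe_x is finite everywhere.
    safe_x = jnp.where(is_zero, 1.0, x)
    # Outer where: select the correct value at x = 0.
    return jnp.where(is_zero, 1.0, jnp.sin(safe_x) / safe_x)


print(jax.grad(naive_sinc)(0.0))  # nan
print(jax.grad(safe_sinc)(0.0))   # 0.0
```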

commit sha b7d11ab90db071d54211f09117d1b57765c933db

Bump jaxlib version to 0.1.35

commit sha ee36818a589d9c5d82b7b2f636c2f6c88cf9c796

Add bfloat16 support to JAX. (#1720) bfloat16 support is still immature, but this PR adds some initial support. Fixes #76, at least enough that we can declare it fixed and open specific issues for specific bfloat16 problems. The main awkwardness that this change deals with is that classic NumPy doesn't understand bfloat16 promotion rules, so we must: implement our own type promotion operators that understand bfloat16 types; and wrap a number of the reference implementations in tests to temporarily cast to float32 for computation.

commit sha c60f3fd65d7762d7037d84a425bfb2cc25d15c3d

Minor documentation fixes. (#1734)

commit sha 2b0cde3648e3f405b87558e3a3eff352a6c377a8

Fix test failure for jax.numpy.signbit(bfloat16) on TPU. (#1735)
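A sketch of the bfloat16 promotion behavior that JAX implements on its own, since classic NumPy does not know about bfloat16. The values are illustrative. In JAX's promotion lattice, bfloat16 and float16 are siblings, so mixing them (or mixing bfloat16 with float32) promotes to float32.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0], dtype=jnp.bfloat16)
print(x.dtype)  # bfloat16

print(jnp.promote_types(jnp.bfloat16, jnp.float32))  # float32
print(jnp.promote_types(jnp.bfloat16, jnp.float16))  # float32
print((x + jnp.ones(2, dtype=jnp.float32)).dtype)    # float32
```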

commit sha a8a19e196ce07dc0c20864672804de36967612e8

Implement batching rule for lax._select_and_gather_add (#1736)

commit sha a8c5b49fda6ab7ad6f65fd0ac9975740f26a84b9

Merge pull request #1722 from google/jb/sinc-double-where Use double-where trick to avoid NaNs in grad(sinc)

commit sha 6b39104a6d37c7a9add44ad15e5967a025bc42ad

Merge branch 'master' into kernel-example-test

push time in 3 months

pull request comment google/jax

Remove `dot` primitive in favor of reusing `dot_general`

Looks good with the masking rule!
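Why a separate `dot` primitive is redundant: a plain matrix product is a special case of `lax.dot_general`, whose `dimension_numbers` are `((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch))`. A sketch with illustrative arrays:

```python
import jax.numpy as jnp
from jax import lax

a = jnp.arange(6.0).reshape(2, 3)
b = jnp.arange(12.0).reshape(3, 4)

# Contract a's axis 1 with b's axis 0 and use no batch axes: an ordinary matmul.
dn = (((1,), (0,)), ((), ()))
out = lax.dot_general(a, b, dn)
print(jnp.allclose(out, jnp.dot(a, b)))  # True

# With a shared leading batch axis, the batch dimension comes first in the output.
x = jnp.ones((5, 2, 3))
y = jnp.ones((5, 3, 4))
batched = lax.dot_general(x, y, (((2,), (1,)), ((0,), (0,))))
print(batched.shape)  # (5, 2, 4)
```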

comment created time in 4 months

push event google/jax

commit sha cd29780fe68f1b096f21f4f3b8d4fb3b2785c0e4

minor README formatting and grammar adjustment

push time in 4 months

push event google/jax

commit sha db3b0dd03683dd744f33ea7b50e53a4cf5854c23

fix missing delimiter in notebook source

push time in 4 months

issue opened google/jax

index_update shape error not caught before reaching XLA

```
import jax.numpy as np
import jax.ops as jo
f = lambda x, X: jo.index_update(X, jo.index[0], x)
x = np.zeros(2)
X = np.zeros(2)
# RuntimeError: Invalid argument: Updates tensor must be of rank 0; got 1.:
# This is a bug in JAX's shape-checking rules; please report it!
print(f(x, X))
```

The jaxpr for `f` in the context of `x, X` is:

```
{ lambda c ; ; a b.
  let d = broadcast[ sizes=() ] a
      e = scatter[ updates_shape=(2,)
                   update_jaxpr={ lambda ; ; a b.
                                  let
                                  in [b] }
                   dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))
                   update_consts=() ] b c d
  in [e] }
```
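For comparison, a sketch using the `.at[...].set(...)` syntax that, as far as I know, replaced `jax.ops.index_update` in later JAX releases. Writing a scalar into `X[0]` is well-typed; writing the whole vector `x` into that scalar slot is a shape mismatch, which recent versions appear to reject in Python (via broadcasting checks) before reaching XLA. The exact exception type and message may vary by version.

```python
import jax.numpy as jnp

x = jnp.zeros(2)
X = jnp.zeros(2)

# The slot X[0] has shape (), so a scalar update is well-typed.
updated = X.at[0].set(1.0)
print(updated)  # [1. 0.]
print(X.at[0].set(x[0]))

# A shape-(2,) update into a shape-() slot is a shape error.
try:
    X.at[0].set(x)
except (ValueError, TypeError) as err:
    print(type(err).__name__, err)
```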

created time in 5 months

push event google/jax

commit sha d1c66614e81e88d548b6651dc3895029d805d073

add a "last" symbol for vmap axis specs, use it in `api.jacfwd`. tests and fixes #1372 Co-authored-by: Matthew Johnson <mattjj@google.com>

commit sha 180f280ee89d4326d8d90665b506bc9fffe6e8af

Merge branch 'master' into issue1372

commit sha ba5bcf2af522e95b019988f7b1e5302e5c0116e2

Merge pull request #1390 from google/issue1372 add a "last" symbol for vmap axis specs, use it in `api.jacfwd`

push time in 5 months

issue closed google/jax

Incorrect reshaping after forward-of-reverse off-diagonal second-order autodiff

The following code sample computes blocks of a Hessian by composition of autodiff with itself. Whenever forward-mode autodiff is composed atop (reverse- or forward-mode) autodiff to compute off-diagonal blocks, a shape-related error occurs, due to what appears to be incorrect output-reshaping logic.

```
from jax.api import *
import jax.numpy as np
def quad(x):
    return np.dot(np.dot(np.ones(x.shape * 2), x), x)
def f(x, u):
    return quad(x) + quad(u)
x, u = np.ones(5), np.ones(2)
rev = jacrev # `rev = grad` yields the same outcomes below
fwd = jacfwd
# Diagonal entries - all OK
rev(rev(f, 0), 0)(x, u)
rev(fwd(f, 0), 0)(x, u)
fwd(rev(f, 0), 0)(x, u)
fwd(fwd(f, 0), 0)(x, u)
rev(rev(f, 1), 1)(x, u)
rev(fwd(f, 1), 1)(x, u)
fwd(rev(f, 1), 1)(x, u)
fwd(fwd(f, 1), 1)(x, u)
# Off-diagonal entries by reverse-mode on the outside - all OK
rev(rev(f, 1), 0)(x, u)
rev(fwd(f, 1), 0)(x, u)
rev(rev(f, 0), 1)(x, u)
rev(fwd(f, 0), 1)(x, u)
# Off-diagonal entries by forward-mode on the outside - all fail with:
#
# RuntimeError: Invalid argument: Input dimension should be either 1 or equal to
# the output dimension it is broadcasting into; the 0th operand dimension is X,
# the 0th output dimension is Y: This is a bug in JAX's shape-checking rules;
# please report it!
fwd(rev(f, 1), 0)(x, u) # X = 2, Y = 5
fwd(fwd(f, 1), 0)(x, u) # X = 2, Y = 5
fwd(rev(f, 0), 1)(x, u) # X = 5, Y = 2
fwd(fwd(f, 0), 1)(x, u) # X = 5, Y = 2
```

Similar errors can be triggered in the analogous cases under `make_jaxpr`, rather than by standard evaluation:

```
make_jaxpr(rev(rev(f)))(x, u) # OK
make_jaxpr(rev(fwd(f)))(x, u) # OK
make_jaxpr(fwd(rev(f)))(x, u) # OK
make_jaxpr(fwd(fwd(f)))(x, u) # OK
# ...
make_jaxpr(rev(rev(f, 1), 0))(x, u) # OK
make_jaxpr(rev(fwd(f, 1), 0))(x, u) # OK
# ValueError: cannot reshape array of size 10 into shape (5,5)
make_jaxpr(fwd(rev(f, 1), 0))(x, u)
make_jaxpr(fwd(fwd(f, 1), 0))(x, u)
```
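For reference, a sketch of the same off-diagonal blocks written against the current `jax` namespace (rather than `jax.api`). Since `quad(x)` equals `(sum x)**2`, the diagonal blocks should be `2 * ones` and the off-diagonal blocks all zeros, because `f` is separable in `x` and `u`. `jax.hessian` is documented as forward-over-reverse, so with `argnums=(0, 1)` it exercises exactly the failing compositions.

```python
import jax
import jax.numpy as jnp


def quad(x):
    # x.shape * 2 repeats the shape tuple, e.g. (5,) -> (5, 5).
    return jnp.dot(jnp.dot(jnp.ones(x.shape * 2), x), x)


def f(x, u):
    return quad(x) + quad(u)


x, u = jnp.ones(5), jnp.ones(2)

# Forward-over-reverse off-diagonal block: d/dx of (df/du), shape (2, 5).
H_ux = jax.jacfwd(jax.jacrev(f, 1), 0)(x, u)
# All blocks at once: H[i][j] has shape args[i].shape + args[j].shape.
H = jax.hessian(f, argnums=(0, 1))(x, u)
print(H_ux.shape, H[0][1].shape)  # (2, 5) (5, 2)
```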

closed time in 5 months

froystig

PR merged google/jax

Extending the vmap API in this way resolves ambiguity in how to specify the final dimension as the batched axis.
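I can't confirm that the "last" symbol from this PR survives in current releases. As far as I know, current JAX accepts negative integers such as `-1` in `in_axes` and `out_axes` for the same purpose. A sketch under that assumption, with illustrative inputs:

```python
import jax
import jax.numpy as jnp

x = jnp.arange(6.0).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

# Map over the final axis: each call sees one column x[:, j] of shape (2,).
col_sums = jax.vmap(jnp.sum, in_axes=-1)(x)
print(col_sums)  # [3. 5. 7.]

# out_axes=-1 likewise puts the mapped axis last in the output.
doubled = jax.vmap(lambda c: 2.0 * c, in_axes=-1, out_axes=-1)(x)
print(doubled.shape)  # (2, 3)
```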

pr closed time in 5 months

PR opened google/jax

pr created time in 5 months

push event google/jax

commit sha 16484ccbed49d1afeb08b3f6311476c19b49b604

add a "Citing JAX" section to the README. fixes #1359

push time in 5 months

issue closed google/jax

Add citation guide to repo and README

I would like to cite JAX in a publication but am unsure how to do this.

To make this easier, the repo should contain a `CITATION.bib` file with a preferred BibTeX citation (pointing either to this repo or to a corresponding JAX publication, if there is one).

That information could also be included at the bottom of the README.

closed time in 5 months

jessebett

issue opened google/jax

Incorrect reshaping after forward-of-reverse off-diagonal second-order autodiff

created time in 5 months

issue opened google/jax

`jit` obscures unbound axis errors

Evaluating a parallel primitive without a parallel context (that is, with an unbound axis variable) typically raises a legible error. However, the error is obscured under `jit`:

```
from jax import lax, jit
import jax.numpy as np
(lambda x: lax.psum(x, 'i'))(np.ones(2)) # NameError: unbound axis name: i
jit(lambda x: lax.psum(x, 'i'))(np.ones(2)) # ValueError: max() arg is an empty sequence
```
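For contrast, a sketch of the same `psum` with its axis name bound. Here `vmap(..., axis_name='i')` is used as the binding context because it runs on a single device; `pmap` would bind the name the same way across devices. The binding is also visible to a jitted function traced inside the mapped context.

```python
import jax
import jax.numpy as jnp
from jax import lax

# vmap's axis_name binds 'i', so psum has an axis to reduce over.
total = jax.vmap(lambda x: lax.psum(x, 'i'), axis_name='i')(jnp.ones(3))
print(total)  # [3. 3. 3.]

# The binding still holds when the mapped function is jitted.
jitted = jax.vmap(jax.jit(lambda x: lax.psum(x, 'i')), axis_name='i')
print(jitted(jnp.ones(3)))  # [3. 3. 3.]
```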

created time in 5 months