push event google/jax

commit sha d958f3007da36913d235e435c1df765d56ba1722

Change JAX type promotion to prefer inexact types. (#1815)

Change the JAX type promotion table to prefer inexact types during type promotion. NumPy's type promotion rules tend to promote aggressively to float64, which isn't a very accelerator-friendly behavior when not all accelerators (e.g., TPUs) support 64-bit floating point types. Even on accelerators that do support 64-bit floating point types (e.g., GPUs), promotion to a 64-bit type comes with a significant performance cost. This change makes JAX type promotion between inexact and exact types closer to PyTorch's promotion semantics, which are a better fit for modern accelerators. For example:

```
import numpy as onp
from jax import numpy as np

In [1]: onp.promote_types(onp.float32, onp.int32)
Out[1]: dtype('float64')

In [2]: onp.promote_types(onp.float16, onp.int64)
Out[2]: dtype('float64')

In [3]: np.promote_types(onp.float32, onp.int32)
Out[3]: dtype('float32')

In [4]: np.promote_types(onp.float16, onp.int64)
Out[4]: dtype('float16')
```

This change is in preparation for enabling x64 mode by default on all platforms.

commit sha 0c0137d787830d8ebd584c4610f7932f3787cab6

avoid compiling trivial programs from partial_eval. Also minor cleanup in api_test.py.

commit sha 5a5db5f2089292c57d43c79fe237484b19720f8f

add lazy sub-language, fuse into op-by-op computations. Also removes the DeviceConstant system.

push time in 5 hours

push event google/jax

commit sha 0c0137d787830d8ebd584c4610f7932f3787cab6

avoid compiling trivial programs from partial_eval. Also minor cleanup in api_test.py.

push time in 5 hours

PR merged google/jax

**Follow-up work** for subsequent PRs:

- [ ] update `pmap` logic with the same
On master, we often end up compiling and executing trivial computations as a result of partial evaluation, where "trivial computations" means computations that are just tuple rearrangements of their inputs, possibly with some constants mixed in. More concretely, it means jaxprs with no eqns.

For example, executing this:

```
@jit
def f(x):
@jit
def g(x):
return x + 1
return g(x)
f(2)
```

Compiles and executes XLA computations corresponding to these jaxprs (from inserting a `print(jaxpr)` in `_xla_computation` in xla.py):

```
{ lambda ; ; a.
let
in [*] }
{ lambda ; ; a.
let b = xla_call[ backend=None
device=None ] a
{ lambda ; ; a.
let b = add a 1
in [b] } [ ; ]
in [b] }
```

Notice the first jaxpr, which really does get compiled and executed!

Here's a variant:

```
@jit
def f(x):
@jit
def g(x, y):
return (y, x + 1, 3)
return g(x, 2)
f(2)
```

```
{ lambda ; ; a b.
let
in [b, *, *] }
{ lambda ; ; a.
let b c d = xla_call[ backend=None
device=None ] a *
{ lambda ; ; a b.
let c = add a 1
in [*, c, *] } [ ; ]
in [*, c, *] }
```

Basically, nested `jit`s will always produce these trivial computations, but so will things like `make_jaxpr` and `xla_computation`.

We've just been putting up with these extra computations because no one complained, but when working on #1668 I noticed it has another ill effect: if `jit` computations force their arguments, then compiling and executing trivial computations like this causes more forcing than we might want! That was the last straw, though to be clear the bigger reason for avoiding these recompiles is just that they incur a totally unnecessary cost.

Luckily, the fix is easy: we just check when a jaxpr has no eqns (i.e. when it's just a tuple rearrangement and/or constant-output computation) and in that case don't compile anything, and instead just do the rearrangement in Python.
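The check can be sketched in miniature. The following is an illustrative toy, not JAX's internals; `TrivialJaxpr` and `execute_trivial` are hypothetical stand-ins for the real jaxpr machinery:

```python
# Toy model of the fix: a jaxpr with no eqns can only rearrange its inputs
# and constants, so we evaluate it directly in Python instead of compiling
# and executing an XLA computation. (Hypothetical names, not JAX's API.)
from dataclasses import dataclass, field

@dataclass
class TrivialJaxpr:
    invars: list                                # input variable names
    constvars: list                             # constant variable names
    outvars: list                               # output variable names
    eqns: list = field(default_factory=list)    # no eqns => trivial

def execute_trivial(jaxpr, consts, *args):
    """Evaluate a no-eqn jaxpr by pure Python rearrangement."""
    assert not jaxpr.eqns, "only trivial jaxprs can skip compilation"
    env = dict(zip(jaxpr.invars, args))
    env.update(zip(jaxpr.constvars, consts))
    return [env[v] for v in jaxpr.outvars]

# A jaxpr like `{ lambda ; ; a b. let in [b, a] }` just swaps its inputs:
swap = TrivialJaxpr(invars=["a", "b"], constvars=[], outvars=["b", "a"])
print(execute_trivial(swap, [], 1, 2))  # [2, 1]
```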

This PR also includes some unrelated api_test.py cleanups.

pr closed time in 5 hours

pull request comment google/jax

avoid compiling trivial programs from partial_eval

@dougalm gave the verbal LGTM.

comment created time in 5 hours

push event google/jax

commit sha b0c7586095092c5626d7c527ac7beeb528450740

make xla.force non-side-effecting

push time in 6 hours

push event google/jax

commit sha 460b0c7e1f4959014871eec9056f9b438af875fd

remove cruft from rebase

push time in 8 hours

push event google/jax

commit sha 437e6db8a1c0373fefc045f9dfd3693b00024212

Disable linalg_test.py::NumpyLinalgTest.testPinv on TPU and GPU. This failed in google3 presubmits.

commit sha 120270cb47e0276021afa93ceca9149aee9dca35

Refined the test disabling for only TPU

commit sha 17813eab20604f89cccdd45ba06431b276d73839

Simplify np.cross. Add a jit decorator. (#1810) * Simplify np.cross. Add a jit decorator.

commit sha 31eb4fc1f3f673d0707b3eb4d1acf26a1166d15d

Merge pull request #1812 from gnecula/bug_fix Disable linalg_test.py::NumpyLinalgTest.testPinv on TPU and GPU

commit sha eca0d98ffde30573a5b95a8e9c2d3d10387e2e43

Increase test tolerance for float16 for LaxBackedNumpyTests.testCross, due to a failure in google3 presubmit.

commit sha 4a42e5d83006fbda7842b0fb7e947d1eddc42d9b

Merge pull request #1813 from gnecula/bug_fix Increase test tolerance for float16 for LaxBackedNumpyTests.testCross

commit sha c1aeaf511cb38c9d7a3174446d0525877256e6c9

xla_computation option to instantiate const output

commit sha 0899673363e4189dc670792bd3f0317795196b75

switch xla_computation instantiate outputs default

commit sha d113416844604c616bf11e8650c5018a15524ea6

Update WORKSPACE. We haven't published jaxlib 0.1.27 yet so I'm leaving the version as-is.

commit sha 7f4c2fcb6bee990db88d1eed6e02e2e898b9af3e

bump version for pypi

commit sha 1a82da37a31c49a94df15944eb7029f475936914

log compiles in pmap (#1817)

commit sha 36826585ee03781386db0c64a359ca4f6859b2ba

wip lazy sublanguage

commit sha dc93867d5e1ae93d956a55d2d5c02bcf5ac1241e

make jit strict in its arguments (i.e. force args) This change is to avoid recompiles. See comment: https://github.com/google/jax/pull/1668#issuecomment-561699616 Thanks @hawkinsp for help with this. Also, make force(x) update x's device_buffer reference.

commit sha 2a6987bb187552bd81e7f3ab0b0c2eadcec0d209

revise np.arange

commit sha 557a25d9c8cd094da6a28d6f5fd268e92453311b

performance

push time in 8 hours

push event google/jax

commit sha 7f4c2fcb6bee990db88d1eed6e02e2e898b9af3e

bump version for pypi

push time in 19 hours

PR opened google/jax

On master, we often end up compiling and executing trivial computations as a result of partial evaluation, where "trivial computations" means computations that are just tuple rearrangements of their inputs, possibly with some constants mixed in. More concretely, it means jaxprs with no eqns.

For example, executing this:

```
@jit
def f(x):
@jit
def g(x):
return x + 1
return g(x)
f(2)
```

Compiles and executes XLA computations corresponding to these jaxprs (from inserting a `print(jaxpr)` in `_xla_computation` in xla.py):

```
{ lambda ; ; a.
let
in [*] }
{ lambda ; ; a.
let b = xla_call[ backend=None
device=None ] a
{ lambda ; ; a.
let b = add a 1
in [b] } [ ; ]
in [b] }
```

Notice the first jaxpr, which really does get compiled and executed!

Here's a variant:

```
@jit
def f(x):
@jit
def g(x, y):
return (y, x + 1, 3)
return g(x, 2)
f(2)
```

```
{ lambda ; ; a b.
let
in [b, *, *] }
{ lambda ; ; a.
let b c d = xla_call[ backend=None
device=None ] a *
{ lambda ; ; a b.
let c = add a 1
in [*, c, *] } [ ; ]
in [*, c, *] }
```

Basically, nested `jit`s will always produce these trivial computations, but so will things like `make_jaxpr` and `xla_computation`.

We've just been putting up with these extra computations because no one complained, but when working on #1668 I noticed it has another ill effect: if `jit` computations force their arguments, then compiling and executing trivial computations like this causes more forcing than we might want! That was the last straw, though to be clear the bigger reason for avoiding these recompiles is just that they incur a totally unnecessary cost.

Luckily, the fix is easy: we just check when a jaxpr has no eqns (i.e. when it's just a tuple rearrangement and/or constant-output computation) and in that case don't compile anything, and instead just do the rearrangement in Python.

pr created time in 19 hours

push event google/jax

commit sha e8b46869b69d658878fc57b22de85bc7e9e86c73

performance

push time in a day

push event google/jax

commit sha c1aeaf511cb38c9d7a3174446d0525877256e6c9

xla_computation option to instantiate const output

commit sha 0899673363e4189dc670792bd3f0317795196b75

switch xla_computation instantiate outputs default

push time in a day

issue comment google/jax

@differentiable decorator for classes

I haven't parsed the details here yet, but I think there is a use case (based on conversations long ago with rxwei, jekbradbury, sharadmv, and dougalm) where one wants to differentiate with respect to a container (i.e., a product type) that contains, say, a float and an int. That's potentially different from using the pytree mechanism to shuttle off the int part, since we need to pay attention to that int part for, say, `jit`.

We likely need to teach the AD system that the tangent space for integral values is `core.unit`, modeling the trivial vector space. (Alternatively, we could have the differentiation API handle pytrees differently. Dynamic typing gives us a lot of options!)
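As a toy illustration of that idea (not JAX's implementation; `UNIT` here is a hypothetical stand-in for `core.unit`), an integer leaf can be given a one-element tangent space that addition collapses to:

```python
# Toy model: integer leaves get a trivial tangent space with a single
# "unit" element, so a container mixing a float and an int still has a
# well-defined tangent. (Illustrative names only, not JAX internals.)
UNIT = "unit"  # the sole element of the trivial vector space

def tangent_zero(leaf):
    # floats get an ordinary zero tangent; ints get the unit tangent
    return UNIT if isinstance(leaf, int) else 0.0

def tangent_add(a, b):
    # addition involving the trivial space always returns its only element
    if a == UNIT or b == UNIT:
        return UNIT
    return a + b

primal = {"weight": 1.5, "count": 3}  # a container with a float and an int
tangent = {k: tangent_zero(v) for k, v in primal.items()}
print(tangent)  # {'weight': 0.0, 'count': 'unit'}
```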

Sorry if the above is cryptic. I wanted to jot down some quick thoughts without yet delving into the excellent points made in this thread already.

comment created time in a day

push event google/jax

commit sha 6009efa2e17061bfa23cdf55c5b75f8c8975da00

switch xla_computation instantiate outputs default

push time in a day

push event google/jax

commit sha 4ffeaf82a6f25d93f36503e68d497fc7c4b47a3f

make jit strict in its arguments (i.e. force args) This change is to avoid recompiles. See comment: https://github.com/google/jax/pull/1668#issuecomment-561699616 Thanks @hawkinsp for help with this. Also, make force(x) update x's device_buffer reference.

commit sha d996829270b33d0108ba86094170e911e1e62843

revise np.arange

push time in a day

push event google/jax

commit sha c1777faedd9293a705968aa67a7044a720c4873d

make jit strict in its arguments (i.e. force args) This change is to avoid recompiles. See comment: https://github.com/google/jax/pull/1668#issuecomment-561699616 Thanks @hawkinsp for help with this. Also, make force(x) update x's device_buffer reference.

push time in a day

push event google/jax

commit sha 73c3b3ed9232f70d439e9d0247582c29c8796197

push time in a day

push event google/jax

commit sha 12a62c1f33ab3cf32c2e2157016f745f834c831c

Bump jaxlib version to 0.1.37 and update WORKSPACE.

commit sha 5b6c9325ed47b29d9182b0480206ba15b5787500

Fix WORKSPACE hash

commit sha d6b18fbb51d78cc2eb3177736e9ed52f925fbc6a

Add some missing NumPy constants: euler_gamma, NZERO and PZERO. (#1809) I avoided adding the deprecated aliases for inf and nan.

commit sha 9503baf475f2922f1df82455e3f597b28fa1d437

wip lazy sublanguage

commit sha 924f6d6c151ca159783bb029c1d062e238f70083

push time in a day

pull request comment google/jax

Running the mnist_vae.py example, I noticed something unfortunate: we were compiling the update loop (`run_epoch`) twice!

First call's signature:

```
WARNING:absl:Compiling run_epoch for args (ArgSpec(aval=ShapedArray(uint32[2]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(2,), dims=(0,)), xla_shape=u32[2]{0}), ArgSpec(aval=ShapedArray(float32[784,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(784, 512), dims=(0, 1)), xla_shape=f32[784,512]{1,0}), ArgSpec(aval=ShapedArray(float32[784,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(784, 512), dims=(None, None)), xla_shape=f32[]), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(0,)), xla_shape=f32[512]{0}), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(None,)), xla_shape=f32[]), ArgSpec(aval=ShapedArray(float32[512,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 512), dims=(0, 1)), xla_shape=f32[512,512]{1,0}), ArgSpec(aval=ShapedArray(float32[512,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 512), dims=(None, None)), xla_shape=f32[]), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(0,)), xla_shape=f32[512]{0}), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(None,)), xla_shape=f32[]), ArgSpec(aval=ShapedArray(float32[512,10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 10), dims=(0, 1)), xla_shape=f32[512,10]{1,0}), ArgSpec(aval=ShapedArray(float32[512,10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 10), dims=(None, None)), xla_shape=f32[]), ArgSpec(aval=ShapedArray(float32[10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(10,), dims=(0,)), xla_shape=f32[10]{0}), ArgSpec(aval=ShapedArray(float32[10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(10,), dims=(None,)), xla_shape=f32[]), ArgSpec(aval=ShapedArray(float32[512,10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 10), dims=(0, 1)), xla_shape=f32[512,10]{1,0}), ArgSpec(aval=ShapedArray(float32[512,10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 10), dims=(None, 
None)), xla_shape=f32[]), ArgSpec(aval=ShapedArray(float32[10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(10,), dims=(0,)), xla_shape=f32[10]{0}), ArgSpec(aval=ShapedArray(float32[10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(10,), dims=(None,)), xla_shape=f32[]), ArgSpec(aval=ShapedArray(float32[10,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(10, 512), dims=(0, 1)), xla_shape=f32[10,512]{1,0}), ArgSpec(aval=ShapedArray(float32[10,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(10, 512), dims=(None, None)), xla_shape=f32[]), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(0,)), xla_shape=f32[512]{0}), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(None,)), xla_shape=f32[]), ArgSpec(aval=ShapedArray(float32[512,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 512), dims=(0, 1)), xla_shape=f32[512,512]{1,0}), ArgSpec(aval=ShapedArray(float32[512,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 512), dims=(None, None)), xla_shape=f32[]), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(0,)), xla_shape=f32[512]{0}), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(None,)), xla_shape=f32[]), ArgSpec(aval=ShapedArray(float32[512,784]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 784), dims=(0, 1)), xla_shape=f32[512,784]{1,0}), ArgSpec(aval=ShapedArray(float32[512,784]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 784), dims=(None, None)), xla_shape=f32[]), ArgSpec(aval=ShapedArray(float32[784]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(784,), dims=(0,)), xla_shape=f32[784]{0}), ArgSpec(aval=ShapedArray(float32[784]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(784,), dims=(None,)), xla_shape=f32[])).
```

Second and subsequent calls' signatures:

```
WARNING:absl:Compiling run_epoch for args (ArgSpec(aval=ShapedArray(uint32[2]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(2,), dims=(0,)), xla_shape=u32[2]{0}), ArgSpec(aval=ShapedArray(float32[784,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(784, 512), dims=(0, 1)), xla_shape=f32[784,512]{1,0}), ArgSpec(aval=ShapedArray(float32[784,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(784, 512), dims=(0, 1)), xla_shape=f32[784,512]{1,0}), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(0,)), xla_shape=f32[512]{0}), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(0,)), xla_shape=f32[512]{0}), ArgSpec(aval=ShapedArray(float32[512,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 512), dims=(0, 1)), xla_shape=f32[512,512]{1,0}), ArgSpec(aval=ShapedArray(float32[512,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 512), dims=(0, 1)), xla_shape=f32[512,512]{1,0}), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(0,)), xla_shape=f32[512]{0}), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(0,)), xla_shape=f32[512]{0}), ArgSpec(aval=ShapedArray(float32[512,10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 10), dims=(0, 1)), xla_shape=f32[512,10]{1,0}), ArgSpec(aval=ShapedArray(float32[512,10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 10), dims=(0, 1)), xla_shape=f32[512,10]{1,0}), ArgSpec(aval=ShapedArray(float32[10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(10,), dims=(0,)), xla_shape=f32[10]{0}), ArgSpec(aval=ShapedArray(float32[10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(10,), dims=(0,)), xla_shape=f32[10]{0}), ArgSpec(aval=ShapedArray(float32[512,10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 10), dims=(0, 1)), xla_shape=f32[512,10]{1,0}), ArgSpec(aval=ShapedArray(float32[512,10]), lazy_expr=LazyExpr(input=ArrayVar(), 
shape=(512, 10), dims=(0, 1)), xla_shape=f32[512,10]{1,0}), ArgSpec(aval=ShapedArray(float32[10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(10,), dims=(0,)), xla_shape=f32[10]{0}), ArgSpec(aval=ShapedArray(float32[10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(10,), dims=(0,)), xla_shape=f32[10]{0}), ArgSpec(aval=ShapedArray(float32[10,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(10, 512), dims=(0, 1)), xla_shape=f32[10,512]{1,0}), ArgSpec(aval=ShapedArray(float32[10,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(10, 512), dims=(0, 1)), xla_shape=f32[10,512]{1,0}), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(0,)), xla_shape=f32[512]{0}), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(0,)), xla_shape=f32[512]{0}), ArgSpec(aval=ShapedArray(float32[512,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 512), dims=(0, 1)), xla_shape=f32[512,512]{1,0}), ArgSpec(aval=ShapedArray(float32[512,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 512), dims=(0, 1)), xla_shape=f32[512,512]{1,0}), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(0,)), xla_shape=f32[512]{0}), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(0,)), xla_shape=f32[512]{0}), ArgSpec(aval=ShapedArray(float32[512,784]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 784), dims=(0, 1)), xla_shape=f32[512,784]{1,0}), ArgSpec(aval=ShapedArray(float32[512,784]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 784), dims=(0, 1)), xla_shape=f32[512,784]{1,0}), ArgSpec(aval=ShapedArray(float32[784]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(784,), dims=(0,)), xla_shape=f32[784]{0}), ArgSpec(aval=ShapedArray(float32[784]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(784,), dims=(0,)), xla_shape=f32[784]{0})).
```

The difference is in the broadcasts, like this pair of entries:

```
ArgSpec(aval=ShapedArray(float32[784,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(784, 512), dims=(None, None)), xla_shape=f32[]),
```

```
ArgSpec(aval=ShapedArray(float32[784,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(784, 512), dims=(0, 1)), xla_shape=f32[784,512]{1,0}),
```

Because of the `np.zeros` used to initialize optimizers in optimizers.py, on the first call to `run_epoch` we're fusing the formation of those lazy zeros into the `run_epoch` computation, meaning we compile for a signature with scalar arguments and broadcasts on the inputs. Then on the second call to `run_epoch` we need to recompile for dense array inputs. (If we add an `xla.force` call to the zeros creation, like `xla.force(np.zeros_like(x0))`, in the `init` function of `momentum` in optimizers.py, we don't have to recompile, as on master.)
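A toy model of the cache behavior described above, with made-up signature tuples standing in for the `ArgSpec`s (illustrative only, not JAX's compile cache): because the key includes each argument's lazy expression, the lazily broadcast zeros from the first epoch and the dense arrays from the second miss each other in the cache.

```python
# Toy compile-cache simulation: the cache key encodes each argument's
# representation, so a lazily broadcast zeros value and a dense array of
# the same logical shape produce different keys and force a recompile.
compile_cache = {}
compile_count = 0

def compiled_run(arg_specs):
    global compile_count
    key = tuple(arg_specs)
    if key not in compile_cache:
        compile_count += 1          # simulate an XLA compilation
        compile_cache[key] = key
    return compile_cache[key]

lazy_zeros  = ("f32[784,512]", "dims=(None, None)")  # broadcast of a scalar
dense_array = ("f32[784,512]", "dims=(0, 1)")        # materialized array

compiled_run([lazy_zeros])    # first epoch: zeros from the optimizer init
compiled_run([dense_array])   # second epoch: updated (dense) parameters
print(compile_count)  # 2 -- the recompile we'd like to avoid
```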

Two possible solutions:

- make `jit` force its arguments, as on master (while still keeping op-by-op application and `jit` closure nonstrict)
- change optimizers.py, and similar user code, to call `xla.force` in cases like this

Discussing with @hawkinsp, we think Option 1 sounds like a better heuristic. (As an extension, if we want it, we could add a `nonstrict_argnums` to `jit` to control this behavior, maybe useful for library writers.) Forcing users to think about laziness and change their code, as we'd need to change optimizers.py, is costly, and we don't yet have real use cases where this strict-jit policy would be problematic.

comment created time in a day

push event google/jax

commit sha 07c346153cc836996e53df8b7582a7f8c39b2702

wip lazy sublanguage

push time in a day

push event google/jax

commit sha be9cc31bcbf755e740fb22802cfefad638fc1cbb

wip lazy sublanguage

push time in 2 days

push event google/jax

commit sha 5c15dda2c973598ef4130e5cd3979c74dda32cf2

Changed api.make_jaxpr to return a TypedJaxpr

* A TypedJaxpr contains more useful information (consts, types)
* Also forced the instantiation of constants when producing the jaxpr.

Before:

```
>>> print(api.make_jaxpr(lambda x: 1.)(0.))
{ lambda ; ; a. let in [*] }
```

After this change:

```
>>> print(api.make_jaxpr(lambda x: 1.)(0.))
{ lambda ; ; a. let in [1.0] }
```

commit sha 603258ebb825d61978fe47da34bc69af94ad3040

Fixed a couple of tests

commit sha 2bb74b627e0382f27f2f11e197d93e58f859f3fe

Ensure jaxpr.eqn.params are printed sorted, so we get deterministic output

commit sha b0ffbaf1f6af38cdd0087dad3b2ac7147cfb6c50

Fixed also a notebook that has gone stale

commit sha 9a8523603c0d4efbdd82b2bd174060ac7825c051

Add experimental rematerialization decorator

We want to allow users to control how reverse-mode autodiff saves values from the forward pass. In particular, we want it to be easy to signal that a function shouldn't have any of its intermediate residuals stored for the backward pass, and instead those values should be recomputed from the function's saved inputs. (This feature is especially handy for accelerators on which memory access is much more expensive than FLOPs are.) In JAX terms, since we implement reverse-mode as a composition of forward-mode, partial evaluation, and transposition, we want users to control how partial evaluation behaves. See https://github.com/google/jax/pull/1749 for more.

Co-authored-by: Dougal Maclaurin <dougalm@google.com>

commit sha b2b5049eb52d69c836a72a77f50822ec27f5b9b2

try remat_call partial-eval into two remat_calls

The idea here was for the resulting jaxpr to have a purely nonlinear remat_call and a linear one with no primals to evaluate. (I wanted to avoid having to recurse into all calls in _eval_primal in backward_pass.) But the issue is that this makes jaxprs not round-trippable, since the first remat_call, depending only on constants, would get partial-eval'd away at the first attempted round-trip. And we round-trip in partial_eval_jaxpr, particularly for partial eval of scan. That meant remat of scan didn't work, and that's no good!

commit sha ac251046fcbe0a940ab79fe49d58227e61f7c675

make remat_call partial-eval into one remat_call

commit sha 115d365a92ef038426b5d2777943948463c2725b

raise error if we do concrete aval FLOPs w/o remat

commit sha 0cb3b433b516efd94046d8346ad397a04e5521fa

Change in how we print sorted params for eqns

commit sha a47f365c924269029b2366870ebe69020ddf7785

Cleaned some test warnings. Specifically: * lax_control_flow_test.py:...: DeprecationWarning: invalid escape sequence \( * Deprecated assertRaisesRegexp, replace with assertRaisesRegex

commit sha 2b0b04fcadc9ea952c14dbde04894312e2c7b757

Merge remote-tracking branch 'upstream/master' into jaxpr_pp

commit sha 3b97c5f792c0053b283eec77c7e638a2907f8f15

Updated uses of make_jaxpr in new code

commit sha 0bc081ec9896984feffbdbdbd2e044c0e1603751

Merge pull request #1766 from gnecula/jaxpr_pp Changed api.make_jaxpr to return a TypedJaxpr

commit sha fc73e50e0417299ca84ac52197af3a2b42adce3e

Merge pull request #1785 from gnecula/bug_fix3 Cleaned some test warnings.

commit sha 0ebf8488ae53737f77edc35bae04fc3e53e90a7a

Implement np.flip with axis = None (#1783) * super minimal starter code * Update optimizers.py * implement flip with axis = None

commit sha 6d2eb6790ee4ba7b6fd9d6117668b42e54c17e55

Add betaln, a wrapper for the Beta function (scipy.special.betaln). (#1788) * Add betaln, a wrapper for the Beta function (scipy.special.betaln). * Use infix operators for addition and multiplication.

commit sha f0d93333791db927a98e8029878c4c027b97bb49

Document functions in jax.nn. (#1795)

commit sha f3c8af49e78b04542535824be9e8d3c95b0cc778

Fix bugs in handling of convolutions whose LHS has spatial size 0. (#1794) * Fix bugs in handling of convolutions whose LHS has spatial size 0. * Use onp.shape to compute shapes.

commit sha 8782860d0bdc84a52ee4d33db3d48abe3ec34892

Relax test tolerances to fix test flakiness.

commit sha 441ad4dbbdc5f38d9f61017621228d7dbf994a57

Relax test tolerances for scipy test.

push time in 2 days

pull request comment google/jax

@shoyer Since it's been a while on our end, WDYT about landing and iterating? Do you want to take on the remaining todos (from your comments), or should I? (I don't mind doing it, if it means merging this excellent feature sooner!)

comment created time in 2 days

issue closed google/jax

jax not working in multiprocessing

```
import numpy as np
from functools import partial
import jax.numpy as jnp
from jax import jit, random
from collections import namedtuple
from multiprocessing import Process, Queue
from pickle import dumps


class Brain(namedtuple("Brain", ("w1", "b1", "w2", "b2"))):
    def __sub__(self, other):
        return Brain(
            w1=self.w1 - other.w1,
            b1=self.b1 - other.b1,
            w2=self.w2 - other.w2,
            b2=self.b2 - other.b2,
        )

    def __mul__(self, scalar):
        return Brain(
            w1=self.w1 * scalar,
            b1=self.b1 * scalar,
            w2=self.w2 * scalar,
            b2=self.b2 * scalar,
        )

    __rmul__ = __mul__


def get_brain(
    input_size: int, hidden_size: int, output_size: int, max_memory: int, seed: int
):
    key = random.PRNGKey(seed)
    w1 = random.truncated_normal(
        key, lower=0, upper=0.1, shape=(input_size, hidden_size)
    )
    w2 = random.truncated_normal(
        key, lower=0, upper=0.1, shape=(hidden_size, output_size)
    )
    b1 = jnp.zeros(shape=(hidden_size,))
    b2 = jnp.zeros(shape=(output_size,))
    return Brain(w1=w1, b1=b1, w2=w2, b2=b2)


@jit
def forward(brain: Brain, data: np.ndarray):
    o1 = jnp.matmul(data, brain.w1) + brain.b1
    a1 = jnp.tanh(o1)
    o2 = jnp.matmul(a1, brain.w2) + brain.b2
    a2 = o2 - jnp.expand_dims(jnp.log(jnp.exp(o2).sum(axis=1)), 1)
    return a2


def worker(queue):
    import jax.numpy as jnp
    from jax import grad, jit

    @jit
    def forward(brain: Brain, data: np.ndarray):
        o1 = jnp.matmul(data, brain.w1) + brain.b1
        a1 = jnp.tanh(o1)
        o2 = jnp.matmul(a1, brain.w2) + brain.b2
        a2 = o2 - jnp.expand_dims(jnp.log(jnp.exp(o2).sum(axis=1)), 1)
        return a2

    @jit
    def loss(brain: Brain, data: np.ndarray, labels: np.ndarray):
        pred = forward(brain, data)
        loss = jnp.mean(-(labels * pred).sum(1))
        return loss

    @jit
    def grad_loss(brain: Brain, data: np.ndarray, labels: np.ndarray):
        return partial(grad(loss), data=data, labels=labels)(brain)

    @jit
    def sgd(brain: Brain, data: np.ndarray, labels: np.ndarray, learning_rate: float):
        g = grad_loss(brain, data, labels)
        brain = brain - g * learning_rate
        return brain

    while True:
        brain, data, labels, epoch, learning_rate = queue.get()
        for i in range(epoch):
            brain = sgd(brain, data, labels, learning_rate)
        break  # if multiprocess, the control flow does not even come here; nothing gets returned from forward


if __name__ == "__main__":
    brain = get_brain(100, 200, 9, 1000, 1)
    data = np.random.normal(size=(1000, 100))
    labels = np.random.uniform(0, 1, size=(1000, 9))
    queue = Queue(10)
    workers = []
    for i in range(2):
        p = Process(target=worker, args=(queue,))
        p.start()  # does not work
        workers.append(p)
    for i in range(10):
        queue.put((brain, data, labels, 1000, 1))
    worker(queue)  # works
```

closed time in 2 days by dchatterjee172

issue comment google/jax

jax not working in multiprocessing

Thanks for the question!

comment created time in 2 days

Pull request review comment google/jax

Support atrous conv in same padded convolution and add warning if use…

```diff
 def conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation=None,
   if type(dimension_numbers) is not ConvDimensionNumbers:
     dimension_numbers = conv_dimension_numbers(
         lhs.shape, rhs.shape, dimension_numbers)
-  if isinstance(padding, str):
-    lhs_perm, rhs_perm, _ = dimension_numbers
-    padding = padtype_to_pads(
-        onp.take(lhs.shape, lhs_perm)[2:], onp.take(rhs.shape, rhs_perm)[2:],
-        window_strides, padding)
   if lhs_dilation is None:
     lhs_dilation = (1,) * (lhs.ndim - 2)
+  elif isinstance(padding, str) and not len(lhs_dilation) == lhs_dilation.count(1):
+    warnings.warn(
```

Instead of a warning, how about an error? (Also, maybe the message could say more about how to specify the required padding explicitly, i.e. what the alternative to a string is.)

comment created time in 2 days

Pull request review comment google/jax

Support atrous conv in same padded convolution and add warning if use…

```diff
 def conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation=None,
   if type(dimension_numbers) is not ConvDimensionNumbers:
     dimension_numbers = conv_dimension_numbers(
         lhs.shape, rhs.shape, dimension_numbers)
-  if isinstance(padding, str):
-    lhs_perm, rhs_perm, _ = dimension_numbers
-    padding = padtype_to_pads(
-        onp.take(lhs.shape, lhs_perm)[2:], onp.take(rhs.shape, rhs_perm)[2:],
-        window_strides, padding)
   if lhs_dilation is None:
     lhs_dilation = (1,) * (lhs.ndim - 2)
+  elif isinstance(padding, str) and not len(lhs_dilation) == lhs_dilation.count(1):
+    warnings.warn(
+        "String padding is not set-up correctly for transposed convolution "
+        "using this op. Please either exaclty specify the required padding or "
```

Typo: "exaclty" -> "exactly"

comment created time in 2 days

Pull request review comment google/jax

add optional `length` argument to scan

```diff
 def scan(f, init, xs):
     xs: the value of type ``[a]`` over which to scan along the leading axis,
       where ``[a]`` can be an array or any pytree (nested Python tuple/list/dict)
       thereof with consistent leading axis sizes.
+    length: optional integer specifying the number of loop iterations, which
+      must agree with the sizes of leading axes of the arrays in ``xs`` (but can
+      be used to perform scans where no input ``xs`` are needed).
```

Good question! I believe that if `xs` is an empty container, then the second argument to `f` will be the same empty container. That is, if `xs=None` then the second argument to `f` is `None`, but if `xs=()` then the second argument to `f` is `()`, and if `xs=[()]` then the second argument to `f` is `[()]`.
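For intuition, here is a pure-Python sketch of scan's semantics with the new `length` argument (illustrative only, not JAX's implementation):

```python
# Pure-Python model of scan with an optional `length`: when xs is None,
# f's second argument is None on every iteration and `length` drives the
# loop; otherwise the length of xs does.
def scan(f, init, xs=None, length=None):
    if xs is None:
        assert length is not None, "need length when there are no inputs"
        xs = [None] * length
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys

# Scan with no inputs at all: a counter driven purely by `length`.
final, ys = scan(lambda c, x: (c + 1, c), 0, xs=None, length=5)
print(final, ys)  # 5 [0, 1, 2, 3, 4]
```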

comment created time in 2 days

push event google/jax

commit sha 09f94a1e3d6a4f43d82f43e85ed1dc74f5290b93

add optional `length` argument to scan

commit sha ac2af106ed35556841a6fc3bf643b0e200ca0fa4

adjust scan docstring (thanks @shoyer)

push time in 3 days

push event google/jax

commit sha 991f626eac49694afa262aa5ebaf336c0213f002

adjust scan docstring (thanks @shoyer)

push time in 3 days

pull request comment google/jax

add optional `length` argument to scan

Great idea to update the docstring; done! I'm not sold on making `xs=None` a default, though, because I don't like default argument values in general.

comment created time in 3 days

push event google/jax

commit sha c183dd1295321225bb05d6f2604555384a6f9f48

adjust scan docstring (thanks @shoyer)

push time in 3 days

push event google/jax

commit sha 51686f43d390e209923b476d04d352bb2a340f01

Make get_compile_options API accept 2D device assignment.

push time in 3 days

PR merged google/jax

pr closed time in 3 days

pull request comment google/jax

Adding `broadcast_argnums` to `pmap` for allowing similar behaviour t…

Thanks for opening this! I think it's likely we'll merge it, but we need to first sort out an API question about how to encode "static" and "broadcast" separately. One option might be to have a `static_broadcasted_argnums` like in this PR, then separately have an `in_axes` like in `vmap` to specify non-static broadcasted argnums.

cc @skye

comment created time in 3 days

push event google-research/autoconj

commit sha 710d95ce4862f764cbdc96b399244394a63bf885

remove some fmap cruft

commit sha 8a167210a202a1bfdeab87d34e6e66391cad08a3

copying over pgm support (WIP)

push time in 5 days

issue comment google/jax

FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated

I think it’s okay as-is: we support the deprecated numpy behavior, but when our tests compare against numpy it generates the warning. (Actually, I guess we could have the test case slurp up the warning... I only just realized that might be one of the solutions you have in mind.)

That said I’m interested to hear @shoyer’s thoughts!
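For reference, a test could slurp up the warning with `warnings.catch_warnings`; a sketch (the `noisy_comparison` body here is a stand-in for a test whose numpy comparison triggers the indexing FutureWarning):

```python
import warnings

def noisy_comparison():
    # Stand-in for a test body whose numpy comparison triggers the
    # "non-tuple sequence for multidimensional indexing" FutureWarning.
    warnings.warn("Using a non-tuple sequence for multidimensional indexing "
                  "is deprecated", FutureWarning)
    return True

# "Slurp up" the warning so it doesn't pollute the test output:
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    assert noisy_comparison()

assert any(issubclass(w.category, FutureWarning) for w in caught)
```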

comment created time in 7 days

push event google/jax

commit sha 9232549620941dae0b4062155711d6b43c74ace8

add lax.delta/lax.broadcasted_eye to lazy language

push time in 7 days

pull request comment google/jax

I often ignore warnings by running `pytest -n auto tests -W ignore`. We could clean up the tests not to generate any warnings, but I'm not sure how much work that will take. We might also want to wait until the dust settles on @hawkinsp's dtype-handling changes.

comment created time in 7 days


push event google/jax

commit sha 42834312d8ee43274e9d47f39a35cd10b0f4108e

add lax.delta/lax.broadcasted_eye to lazy language

push time in 8 days

push event google/jax

commit sha 362845f43cf842fbf81a9ccdfed4e4a67a6178ff

add lax.eye, lax.tri to lazy language

push time in 8 days

push event google/jax

commit sha 162856b9003c1840cffe66171218fdaaa7eadef0

add lax.eye, lax.tri to lazy language

push time in 8 days

push event google/jax

commit sha f032af2356280e66d097fddeb3f66b8a659b5782

add lax.eye to lazy language

commit sha 9bf7deca105d8a564180f1d3cba5da5020dd2fd9

add lax.tri to lazy language

push time in 8 days

push event google/jax

commit sha 50f8e74e4872e2deae7088f00ca27ccfcde561a7

add lazy identity

push time in 8 days

pull request comment google/jax

Add experimental rematerialization decorator

cc @joschu

comment created time in 8 days

push event google/jax

commit sha 9a8523603c0d4efbdd82b2bd174060ac7825c051

Add experimental rematerialization decorator We want to allow users to control how reverse-mode autodiff saves values from the forward pass. In particular, we want it to be easy to signal that a function shouldn't have any of its intermediate residuals stored for the backward pass, and instead those values should be recomputed from the function's saved inputs. (This feature is especially handy for accelerators on which memory access is much more expensive than FLOPs are.) In JAX terms, since we implement reverse-mode as a composition of forward-mode, partial evaluation, and transposition, we want users to control how partial evaluation behaves. See https://github.com/google/jax/pull/1749 for more. Co-authored-by: Dougal Maclaurin <dougalm@google.com>

commit sha b2b5049eb52d69c836a72a77f50822ec27f5b9b2

try remat_call partial-eval into two remat_calls The idea here was for the resulting jaxpr to have a purely nonlinear remat_call and a linear one with no primals to evaluate. (I wanted to avoid having to recurse into all calls in _eval_primal in backward_pass.) But the issue is that makes jaxprs not round-trippable, since the first remat_call, depending only on constants, would get partial-eval'd away at the first attempted round-trip. And we round-trip in partial_eval_jaxpr, particularly for partial eval of scan. That meant remat of scan didn't work, and that's no good!

commit sha ac251046fcbe0a940ab79fe49d58227e61f7c675

make remat_call partial-eval into one remat_call

commit sha 115d365a92ef038426b5d2777943948463c2725b

raise error if we do concrete aval FLOPs w/o remat

push time in 8 days

PR merged google/jax

## The problem

We want to allow users to control how reverse-mode autodiff saves values from the forward pass. In particular, we want it to be easy to signal that a function shouldn't have any of its intermediate residuals stored for the backward pass, and instead those values should be recomputed from the function's saved inputs. This allows users greater control over the time/memory tradeoffs of their computations. (This feature is especially handy for accelerators on which memory access is much more expensive than FLOPs are.)

In JAX terms, since we implement reverse-mode as a composition of forward-mode, partial evaluation, and transposition, we want users to control how partial evaluation behaves.

This PR adds a `checkpoint` decorator (alias: `remat`) for this purpose.

Fixes #1732.

## API examples

### Example 1: grad-of-jit

Consider this grad-of-jit situation:

```
import jax.numpy as np
from jax import jit, grad, remat

@remat
def g(x):
  return np.sin(np.sin(x))

@jit
def f(x):
  return g(x)

grad(f)(2.0)
```

If we run this *without* the remat decorator, here are the jaxprs we end up lowering to XLA, one for the forward pass and one for the backward pass (by adding a `print(jaxpr)` to `xla._xla_callable`):

```
{ lambda ; ; a b.
let c = sin a
d = sin c
e = cos a
f = cos c
in [d, *, e, f] }
{ lambda ; ; a b c.
let d = mul c b
e = mul d a
in [e] }
```

Here's what we see *with* the `remat` decorator:

```
{ lambda ; ; a b.
let c = sin a
d = sin c
in [d, *, a] }
{ lambda ; ; a b.
let c = remat_call a b
{ lambda ; ; a b.
let c = sin a
d = cos c
e = mul b d
f = cos a
g = mul e f
in [g] } [ ; ]
in [c] }
```

Notice how there are no residuals passed into the backward pass as constants, and we're computing `sin(2.0)` both in the forward pass and in the backward pass.
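(As an aside, a quick sanity check one can run: `remat` changes what's saved for the backward pass, not the gradient's value.)

```python
import jax.numpy as jnp
from jax import grad, jit, remat

def g(x):
    return jnp.sin(jnp.sin(x))

# Rematerialization recomputes residuals on the backward pass, so the
# gradient value must match the un-remat'd version.
with_remat = grad(jit(remat(g)))(2.0)
without = grad(jit(g))(2.0)
assert abs(float(with_remat) - float(without)) < 1e-6
```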

### Example 2: jit-of-grad

Here's a jit-of-grad situation:

```
import jax.numpy as np
from jax import jit, grad, remat

@remat
def g(x):
  return np.sin(np.sin(x))

jit(grad(g))(2.0)
```

```
{ lambda ; ; a.
let b = sin a
c = sin b
d = remat_call a 1.0
{ lambda ; ; a b.
let c = sin a
d = cos c
e = mul b d
f = cos a
g = mul e f
in [g] } [ ; ]
in [c, d] }
```

Again we're computing `sin(2.0)` twice. Not shown is the fact that the translation rule for lowering `remat_call` to XLA HLO includes some HLO widgets that will foil XLA's CSE optimization across `remat_call` boundaries, even though it's all being lowered to one computation here.

### Example 3: differentiating Python control flow under `remat`

```
from functools import partial
import jax.numpy as np
from jax import jit, grad, linearize, remat, make_jaxpr

@partial(remat, concrete=True)
def g(x):
  if x > 0.:
    return np.sin(np.sin(x))
  else:
    return np.cos(np.cos(x))

def f(x):
  x = 3. * x
  x = g(x)
  x = 4. * x
  return x

print(grad(f)(2.))
jaxpr = make_jaxpr(linearize(f, 2.)[1])(1.)
print(jaxpr)
```

```
11.075182
{ lambda ; ; a.
let b = mul a 3.0
c d = remat_call 6.0 b
{ lambda ; ; a b.
let c = sin a
d = sin c
e = cos a
f = mul b e
g = cos c
h = mul f g
in [d, h] } [ ; ]
e = mul d 4.0
in [e] }
```

Functions with `remat` can still support differentiating through Python control flow! At least, when you pass `concrete=True`. There's a tradeoff here: to handle Python control flow, in some cases involving `jit` we might end up doing redundant FLOPs, as in the Three Sines Problem described below, so we wanted to make this feature opt-in.

Notice that in this last jaxpr, the `remat_call` is applied to an argument that is a literal `6.0`. We've prevented that from being partially evaluated through the called jaxpr on the JAX side, and moreover the lowering of `remat_call` will prevent XLA from doing that constant folding as well.

### Example 4: scanning a `remat` function

```
import jax.numpy as np
from jax import lax
from jax import vjp, remat, make_jaxpr

def f(c, x):
  return np.sin(c), None

def foo(x):
  y, _ = lax.scan(remat(f), x, np.arange(3.))
  return y

_, foo_vjp = vjp(foo, 4.)
foo_vjp(1.)  # added a `print(jaxpr)` in ad.backward_pass
```

```
{ lambda c d ; ; a b.
let e f = scan[ forward=True
length=3
jaxpr={ lambda ; ; a b c d e.
let f g = remat_call d e b
{ lambda ; ; a b c.
let d = sin a
e = cos a
f = mul c e
in [d, f] } [ ; ]
in [*, g] }
num_consts=0
num_carry=2
linear=[True, True, True, False, False] ] * b * c d
in [*, f] }
```

Notice the `sin` and `cos` computation in the scanned function! Compare to not using the `remat` decorator:

```
{ lambda c ; ; a b.
let d e = scan[ forward=True
length=3
jaxpr={ lambda ; ; a b c d.
let e = mul b d
in [*, e] }
num_consts=0
num_carry=2
linear=[True, True, True, False] ] * b * c
in [*, e] }
```

### Example 5: binomial checkpointing

```
import jax.numpy as np
from jax import remat, make_jaxpr, value_and_grad

def binom_checkpoint(funs):
  if len(funs) == 1:
    return funs[0]
  elif len(funs) == 2:
    f1, f2 = funs
    return lambda x: f1(f2(x))
  else:
    f1 = binom_checkpoint(funs[:len(funs)//2])
    f2 = binom_checkpoint(funs[len(funs)//2:])
    return lambda x: f1(remat(f2)(x))

# forward pass
f = binom_checkpoint([np.cos, np.sin, np.cos, np.sin])
print(make_jaxpr(f)(4.))

# forward and backward pass
f = binom_checkpoint([np.sin, np.sin, np.sin, np.sin])
print(make_jaxpr(value_and_grad(f))(4.))

# longer forward pass
f = binom_checkpoint([np.cos, np.sin, np.cos, np.sin,
                      np.cos, np.sin, np.cos, np.sin])
print(make_jaxpr(f)(4.))
```

```
# forward pass
{ lambda ; ; a.
let b = remat_call[ concrete=False ] a
{ lambda ; ; a.
let b = sin a
c = cos b
in [c] } [ ; ]
c = sin b
d = cos c
in [d] }
# forward and backward pass
{ lambda ; ; a.
let b = sin a
c = sin b
d = sin c
e = sin d
f = cos d
g = mul 1.0 f
h = cos c
i = mul g h
j = remat_call[ concrete=False ] a i
{ lambda ; ; a b.
let c = sin a
d = cos c
e = mul b d
f = cos a
g = mul e f
in [g] } [ ; ]
in [e, j] }
# longer forward pass
{ lambda ; ; a.
let b = remat_call[ concrete=False ] a
{ lambda ; ; a.
let b = remat_call[ concrete=False ] a
{ lambda ; ; a.
let b = sin a
c = cos b
in [c] } [ ; ]
c = sin b
d = cos c
in [d] } [ ; ]
c = remat_call[ concrete=False ] b
{ lambda ; ; a.
let b = sin a
c = cos b
in [c] } [ ; ]
d = sin c
e = cos d
in [e] }
```
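A quick way to sanity-check the structure above is to count the `remat` boundaries the recursion inserts; this mirrors `binom_checkpoint`'s split in pure Python:

```python
def count_remats(n):
    """Count the remat-wrapped calls binom_checkpoint inserts for n functions."""
    if n <= 2:
        return 0  # base cases compose directly, with no remat boundary
    return 1 + count_remats(n // 2) + count_remats(n - n // 2)

# Matches the jaxprs above: one remat_call for four functions, three for eight.
assert count_remats(4) == 1
assert count_remats(8) == 3
```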

## Implementation

The basic design here is to create a new call primitive, like `core.call_p` (which doesn't *necessarily* raise the abstraction level of its inputs), that has a special partial evaluation rule: stage out the full computation, not just the parts that couldn't be partially evaluated. We call the new primitive `remat_call_p`.

The implementation of the partial evaluation rule has three main steps:

- trace the full jaxpr of the called function, treating all inputs as "unknown" (but potentially using concrete avals for the arguments which are actually "known", if `concrete=True`),
- process the full jaxpr to determine which outputs are meant to be "known"/"unknown", and to prune any extra computation,
- compute any "known" values we need (as an optimization, extract values from any concrete avals instead of re-evaluating them).
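As a toy illustration of the second step (pruning computation not needed for the requested outputs), here's a sketch on a made-up equation list; this is not JAX's actual machinery:

```python
# Toy "jaxpr": a list of (output_var, primitive, input_vars) equations.
def prune_unused(eqns, outputs):
    """Keep only equations whose outputs are (transitively) needed."""
    needed = set(outputs)
    kept = []
    for out, prim, ins in reversed(eqns):
        if out in needed:
            kept.append((out, prim, ins))
            needed.update(ins)
    return list(reversed(kept))

# Mirrors the forward-pass jaxpr from Example 1: if only `d` is needed,
# the cos equations fall away.
eqns = [("c", "sin", ["a"]),
        ("d", "sin", ["c"]),
        ("e", "cos", ["a"]),
        ("f", "cos", ["c"])]
assert prune_unused(eqns, ["d"]) == [("c", "sin", ["a"]), ("d", "sin", ["c"])]
```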

Otherwise, it looks a lot like the standard call partial evaluation rule, which is `JaxprTrace.process_call`.

To make `remat_call_p` transposable, we need to upgrade our `ad.backward_pass` code to do some nonlinear evaluation (to rematerialize the residuals needed).

To translate `remat_call_p` to XLA in a way that prevents XLA's CSE optimization from foiling our rematerialization plans, we use a special widget that (according to friendly XLA team members) will in effect force a false data dependence to avoid CSE.

This change also adjusts our abstract evaluation rules for concrete values not to promote to the shaped level. (It was a long time ago, but I think we had that policy only because we didn't have a use case for performing FLOPs while creating jaxprs.) To ensure that this change doesn't accidentally lead to more FLOPs in user code, I tried running the test suite with an `assert False` in the new concrete-evaluation path and checked that it was never hit except in the new `remat` tests.

The upgrade to ad.backward_pass was co-authored with @dougalm, and is essentially part of #1719.

pr closed time in 8 days

issue closed google/jax

add explicit checkpointing (rematerialization) control

Analogous to `checkpoint` in Autograd (but handling closed-over tracers).

closed time in 8 days

mattjj push event google/jax

commit sha 16484352685af5934facd32d07cd9664192e57ac

enable kernel regression example test

commit sha 6b39104a6d37c7a9add44ad15e5967a025bc42ad

Merge branch 'master' into kernel-example-test

commit sha 67038321f8b33bdab99c1fc3d003719553769c61

Revert support for building a non-GPU build with --config=cuda enabled. (#1757) It turns out there are implicit CUDA dependencies inside the TF libraries used by JAX, so the attempt to disable GPU dependencies conditionally didn't work.

commit sha 534d812b5777cc2da5bba11130626498f6aa2667

Add a handwritten ThreeFry2x32 CUDA kernel. (#1756) In principle, JAX should not need a hand-written CUDA kernel for the ThreeFry2x32 algorithm. In practice XLA aggresively inlines, which causes compilation times on GPU blow up when compiling potentially many copies of the PRNG kernel in a program. As a workaround, we add a hand-written CUDA kernel mostly to reduce compilation time. When XLA becomes smarter about compiling this particular hash function, we should be able to remove the hand-written kernel once again.

commit sha d1aa01874d2e25fba07c23aaa78cc3d623d17c21

Fix BUILD file formatting.

commit sha 3b7d92db79442b64fc24f695b80e95cbcbe825c8

Add missing pybind11 dependency.

commit sha 4e89d43a75a9d7a9d803ba8777b867c487e55ed6

Added JAX pytrees notebook Also added docstrings to the tree_util module.

commit sha b12a8019c8711afaef9b4d1a9f437bf944575cee

Update docs/notebooks/JAX_pytrees.ipynb Co-Authored-By: Stephan Hoyer <shoyer@google.com>

commit sha 8777864c96d91d205cdf96236f4ad86818f957c6

Minor edits

commit sha 132102498bd81867a7a07e578444d4639f70ba1a

Minor edit

commit sha 159690ae026fc8364cefc22990914a15cb10e106

Merge pull request #1758 from gnecula/bug_fix Added notebook for PyTrees

commit sha ace14bbb942bb1dd532fb07d3524e141b85808fc

Remove `join_pvals` This function appears to be unused.

commit sha 34dfbc8ae6c697118b6a18faf02547aaab5946dc

Add error checking to PRNG CUDA kernel. (#1760) Refactor error checking code into a common helper library.

commit sha 82ce209ae694bdbffa60155edbe34f4535e26d14

Merge branch 'master' into kernel-example-test

commit sha 36c882ba469fc009dba1a29df640c72eec12b36b

raise an error on jit-of-multi-host-pmap (#1761) Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>

commit sha 1b2350bf36bc2eabfacea8f409eb04e74514f69f

Merge pull request #1730 from google/kernel-example-test enable kernel regression example test

commit sha f415f266b890e7b1476879c108f6e846d7f93177

Remove 'backend' argument from device_put. (#1762) The appropriate Backend is instead inferred from the 'device' argument. This is a first step towards removing the 'backend' argument from more functions.

commit sha 227a91220b51ed883e28023e1f7eea5662af28f5

Update minimum jaxlib to 0.1.36. This is needed in part to pull in new Device.platform from Tensorflow. See #1764.

commit sha 8547050b09318414190b88d29a2d0914ccc325a3

Merge pull request #1765 from gnecula/bug_fix Update minimum jaxlib to 0.1.36.

commit sha 7a9f1f3f1c90b112fc1830da68a89b514716acce

Pin the minimum jaxlib version in travis.yml. (#1767)

push time in 8 days

push event google/jax

commit sha 887ec1d948b990e186828d1201010c2034dec662

raise error if we do concrete aval FLOPs w/o remat

push time in 8 days

push event google/jax

commit sha 10f6010d0d94fb4d9b51832699071fb9ac880787

make remat_call partial-eval into one remat_call

push time in 8 days

pull request comment google/jax

Add experimental rematerialization decorator

I learned something interesting! In 9987f73 I tried to make the partial eval rule of `remat_call` stage out *two* `remat_call`s, which in the context of reverse-mode ad would have the first only involve nonlinear computations (not depending on the tangents) and output residuals, while the second was linear. The aim was to be cleaner conceptually (don't mix linear and nonlinear in a single call), and also to avoid unnecessary recursing into calls when evaluating nonlinear constants in `ad.backward_pass`.

But that ran into an interesting problem: it would form jaxprs like this:

```
{ lambda ; ; a b.
let c d = remat_call[ concrete=False ] 2.0 *
{ lambda ; ; a b.
let c = sin a
e = cos a
f = cos c
in [e, f] } [ ; ]
e f = remat_call[ concrete=False ] * b c d
{ lambda ; ; a b c e.
let d = mul b c
f = mul d e
in [*, f] } [ ; ]
in [*, f] }
```

and that is not "round-trippable" in the sense that if we do a `make_jaxpr` on an `eval_jaxpr` the first `remat_call` will get partial-evaluated away (it has no data dependence on the input arguments). We use round-tripping like that in `partial_eval_jaxpr` e.g. to implement partial eval of `scan`. That meant that by trying to stage out separate nonlinear and linear `remat_call`s, I broke `scan(remat(f), ...)`.

So I reverted to staging out a single (mixed linear/nonlinear) `remat_call`, and just special-casing it in `ad.backward_pass` to avoid unnecessary recursions into other call primitives (which are always linearized by our partial-eval-of-jvp machinery).

comment created time in 8 days

push event google/jax

commit sha 227a91220b51ed883e28023e1f7eea5662af28f5

Update minimum jaxlib to 0.1.36. This is needed in part to pull in new Device.platform from Tensorflow. See #1764.

commit sha 8547050b09318414190b88d29a2d0914ccc325a3

Merge pull request #1765 from gnecula/bug_fix Update minimum jaxlib to 0.1.36.

commit sha 7a9f1f3f1c90b112fc1830da68a89b514716acce

Pin the minimum jaxlib version in travis.yml. (#1767)

commit sha 2867e4be082237e2b184d7b533dfcbfa31b24f63

fix grad of jit caching bug Co-authored-by: Dougal Maclaurin <dougalm@google.com>

commit sha fbc9446afa13b85a68484e8240aec07cfcb0fb8f

Fix some missing docstrings for Numpy functions. (#1768)

commit sha 3ae4a41320d6a116656626bfd0d87cde6c928335

Add "loops" BUILD target. (#1771)

commit sha b7579492690b1d94da89b7f1d1b6ddcfadbaacae

fix pulldown bugs

commit sha 8df1ccf42b6025e6e966df8a4ce034c27bdbff7e

Make jax.numpy.broadcast_to consistent with numpy. (#1773) * Make jax.numpy.broadcast_to consistent with numpy. jax.numpy.broadcast(10.0, ()) should return array(10.0) and not 10.0. * Improve broadcast_to test.

commit sha 5c96d83ea6247a8d27cd1f1a591d3f41fb8c0b64

Simplify einsum implementation. (#1774) XLA's DotGeneral operator has been generalized so we no longer need the _dot_general wrapper. Avoids the need for unnecessary reshapes.

commit sha da6a474a63bea7d7d27f3ee112cff75be6693a74

Simplify jax.numpy.tensordot by using lax.dot_general. (#1775)

commit sha ec79adccbb22730124210ef801d7e6ecfbf88316

source sync PiperOrigin-RevId: 282633556

commit sha c1d8d3f74d422222fe173d8e0ef5b05f9e2fd300

Add error checking that arguments of jvp are tuples

commit sha e0706ff86476271bfeb1b3e0055818c343fdf862

Relaxed check to allow both tuples and lists

commit sha 96f075db13787bd81d7ec40e17d4f21ba0c95299

Merge pull request #1777 from gnecula/bug_fix Add error checking that arguments of jvp are tuples

commit sha 6931489733a568a98a452ba3581f70bc3d7d1dea

update version for pypi

commit sha 80036744b35de220b8989b0e0bfa54eb956ab62c

Add experimental rematerialization decorator We want to allow users to control how reverse-mode autodiff saves values from the forward pass. In particular, we want it to be easy to signal that a function shouldn't have any of its intermediate residuals stored for the backward pass, and instead those values should be recomputed from the function's saved inputs. (This feature is especially handy for accelerators on which memory access is much more expensive than FLOPs are.) In JAX terms, since we implement reverse-mode as a composition of forward-mode, partial evaluation, and transposition, we want users to control how partial evaluation behaves. See https://github.com/google/jax/pull/1749 for more. Co-authored-by: Dougal Maclaurin <dougalm@google.com>

commit sha 9987f73c41724cbc8fd707728e6f16192aa9bc36

try remat_call partial-eval into two remat_calls The idea here was for the resulting jaxpr to have a purely nonlinear remat_call and a linear one with no primals to evaluate. (I wanted to avoid having to recurse into all calls in _eval_primal in backward_pass.) But the issue is that makes jaxprs not round-trippable, since the first remat_call, depending only on constants, would get partial-eval'd away at the first attempted round-trip. And we round-trip in partial_eval_jaxpr, particularly for partial eval of scan. That meant remat of scan didn't work, and that's no good!

commit sha a58cf307bf7f20138c82780a1457960ca1650c41

make remat_call partial-eval into one remat_call

push time in 8 days

Pull request review comment google/jax

Changed api.make_jaxpr to return a TypedJaxpr

```
  def jaxpr_maker(*args, **kwargs):
    wrapped = lu.wrap_init(fun)
    jax_args, in_tree = tree_flatten((args, kwargs))
    jaxtree_fun, out_tree = flatten_fun(wrapped, in_tree)
-   pvals = map(pv_like, jax_args)
-   jaxpr, _, _ = pe.trace_to_jaxpr(jaxtree_fun, pvals)
-   return jaxpr
+   in_pvals = map(pv_like, jax_args)
+   jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jaxtree_fun, in_pvals, instantiate=True)
```

You're right, I think we need `instantiate=True` to form a TypedJaxpr. Ignore my suggestion!

comment created time in 8 days

Pull request review comment google/jax

Changed api.make_jaxpr to return a TypedJaxpr

```
  def pp_eqn(eqn):
    ...
            >> pp(' [ {} ; {} ]'.format(pp_vars(const_vars), pp_vars(bound_vars))))
    return (pp('{} = '.format(lhs)) >>
-           pp(eqn.primitive.name) >> pp_kv_pairs(eqn.params.items())
+           pp(eqn.primitive.name) >> pp_kv_pairs(eqn.params)
```

Maybe rename `pp_kv_pairs` to `pp_dict` or something, since the name no longer reflects the type?

Also because tuple sorting is lexical, I think you could alternatively keep the previous function `pp_kv_pairs` unmodified and have written here `pp_kv_pairs(sorted(eqn.params.items()))`.

comment created time in 8 days

Pull request review comment google/jax

Changed api.make_jaxpr to return a TypedJaxpr

```
  def jaxpr_maker(*args, **kwargs):
    wrapped = lu.wrap_init(fun)
    jax_args, in_tree = tree_flatten((args, kwargs))
    jaxtree_fun, out_tree = flatten_fun(wrapped, in_tree)
-   pvals = map(pv_like, jax_args)
-   jaxpr, _, _ = pe.trace_to_jaxpr(jaxtree_fun, pvals)
-   return jaxpr
+   in_pvals = map(pv_like, jax_args)
+   jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jaxtree_fun, in_pvals, instantiate=True)
```

One theoretical downside to always having `instantiate=True` is that we can't then use `make_jaxpr` to illustrate what gets hoisted out of a function when e.g. we stage it out to XLA. One option might be to add an `instantiate` argument to `make_jaxpr` with a default value of True. Then it has the right default behavior but still lets us make the point about hoisting (which has been useful a time or two in answering JAXers questions IIRC).

I don't have a strong feeling there though, so whichever you think is best!

comment created time in 8 days

issue comment pyro-ppl/brmp

Memory leak using numpyro backend

FYI after google/jax#1721 compilation cache entries will be deleted after the underlying Python callable is deleted, but the cache size can grow without bound. (Previously we had a limit on the number of cache entries, but we had to remove it in google/jax#1721 to support weakrefs.) We plan to write our own logic to limit the cache size (not just number of entries, by querying XLA for the size of each compiled program), likely with an LRU eviction policy, but we're not acting on that urgently because as far as we know no one is getting compilation cache OOMs.

If you know otherwise, i.e. if you ever see the compilation cache blowing up, let us know on the JAX issue tracker!

comment created time in 8 days

issue comment pyro-ppl/brmp

Memory leak using numpyro backend

Just pushed jax 0.1.53 to pypi. (If we're ever slow at releases and you need one, just open an issue on our issue tracker. It only takes a few seconds, but we often forget to do it!)

comment created time in 8 days

push event google/jax

commit sha 6931489733a568a98a452ba3581f70bc3d7d1dea

update version for pypi

push time in 8 days

issue comment google/jax

Slow XLA compile and high memory usage of `odeint` on high-dim problem

cc @dsweaver

comment created time in 8 days

started gehring/fax

started time in 8 days

pull request comment google/jax

Wrap optix rules in a namedtuple with init/apply

Nice!

comment created time in 9 days

push event google/jax

commit sha ec79adccbb22730124210ef801d7e6ecfbf88316

source sync PiperOrigin-RevId: 282633556

push time in 9 days

PR merged google/jax

Wrap optix rules in a namedtuple with init/update.

In general we've found it convenient to be able to keep track of pairs of functions as a single object. For example:

```
# Init.
loss_fn = hk.transform(loss_fn)
optimizer = optix.sgd(...)
# Initial state.
params = loss_fn.init(rng, ...)
opt_state = optimizer.init(params)
# Step.
grads = jax.grad(loss_fn.apply)(params, ...)
updates, opt_state = optimizer.update(grads, opt_state)
params = optix.apply_update(params, updates)
```

PiperOrigin-RevId: 282633556
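The init/update pairing generalizes beyond optix; here's a self-contained sketch of the pattern (hypothetical names, not optix's actual implementation):

```python
from typing import Callable, NamedTuple

class InitUpdate(NamedTuple):
    """Pair of functions tracked as a single optimizer object."""
    init: Callable
    update: Callable

def sgd(learning_rate):
    def init(params):
        return ()  # plain SGD keeps no optimizer state

    def update(grads, state):
        # Produce parameter updates; state passes through unchanged.
        updates = {k: -learning_rate * g for k, g in grads.items()}
        return updates, state

    return InitUpdate(init, update)

opt = sgd(0.1)
opt_state = opt.init({"w": 1.0})
updates, opt_state = opt.update({"w": 2.0}, opt_state)
assert abs(updates["w"] + 0.2) < 1e-12
```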

pr closed time in 9 days

push event google/jax

commit sha 370a422914eda22587b08612202e2c9e16ccfb35

wip lazy sublanguage

push time in 9 days

Pull request review comment google/jax

Make jax.numpy.broadcast_to consistent with numpy.

```
  def testBroadcastToIntIssue1548(self):
    self.assertAllClose(lnp.broadcast_to(1, (3, 2)), onp.ones((3, 2)),
                        check_dtypes=False)

+ def testBroadcastToOnScalar(self):
+   self.assertTrue(isinstance(lnp.broadcast_to(10.0, ()), lnp.ndarray))
```

WDYT about also adding a line like

```
self.assertIsInstance(onp.broadcast_to(10.0, ()), onp.ndarray)
```

in the spirit of self-documenting code?

comment created time in 9 days

Pull request review comment google/jax

Make jax.numpy.broadcast_to consistent with numpy.

```
  def testBroadcastToIntIssue1548(self):
    self.assertAllClose(lnp.broadcast_to(1, (3, 2)), onp.ones((3, 2)),
                        check_dtypes=False)

+ def testBroadcastToOnScalar(self):
+   self.assertTrue(isinstance(lnp.broadcast_to(10.0, ()), lnp.ndarray))
```

`self.assertIsInstance`?

comment created time in 9 days

push event google/jax

commit sha b7579492690b1d94da89b7f1d1b6ddcfadbaacae

fix pulldown bugs

push time in 9 days

push event google/jax

commit sha ae572960a20ac10dd2a7d4e7867f23dd0299e65b

wip lazy sublanguage

push time in 9 days

push event google/jax

commit sha 2867e4be082237e2b184d7b533dfcbfa31b24f63

fix grad of jit caching bug Co-authored-by: Dougal Maclaurin <dougalm@google.com>

push time in 9 days

push event google/jax

commit sha f415f266b890e7b1476879c108f6e846d7f93177

Remove 'backend' argument from device_put. (#1762) The appropriate Backend is instead inferred from the 'device' argument. This is a first step towards removing the 'backend' argument from more functions.

commit sha 22b7c9622176f849e5aaa79ce593904b52a20992

fix grad of jit caching bug Co-authored-by: Dougal Maclaurin <dougalm@google.com>

commit sha 57dd913834a54dee921047af7c78be0374e83c47

push time in 9 days

push event google/jax

commit sha 22b7c9622176f849e5aaa79ce593904b52a20992

fix grad of jit caching bug Co-authored-by: Dougal Maclaurin <dougalm@google.com>

push time in 9 days

push event google/jax

commit sha f4c9a09866ec0eefcd101cf5371c5e561567773b

push time in 10 days

push event google/jax

commit sha 16484352685af5934facd32d07cd9664192e57ac

enable kernel regression example test

commit sha 6b39104a6d37c7a9add44ad15e5967a025bc42ad

Merge branch 'master' into kernel-example-test

commit sha 45a1ba0bbcfc64e9c65f9d2589bb474a59312c99

Make more tests pass on TPU. (#1752)

commit sha 67038321f8b33bdab99c1fc3d003719553769c61

Revert support for building a non-GPU build with --config=cuda enabled. (#1757) It turns out there are implicit CUDA dependencies inside the TF libraries used by JAX, so the attempt to disable GPU dependencies conditionally didn't work.

commit sha 534d812b5777cc2da5bba11130626498f6aa2667

Add a handwritten ThreeFry2x32 CUDA kernel. (#1756) In principle, JAX should not need a hand-written CUDA kernel for the ThreeFry2x32 algorithm. In practice XLA aggresively inlines, which causes compilation times on GPU blow up when compiling potentially many copies of the PRNG kernel in a program. As a workaround, we add a hand-written CUDA kernel mostly to reduce compilation time. When XLA becomes smarter about compiling this particular hash function, we should be able to remove the hand-written kernel once again.

commit sha d1aa01874d2e25fba07c23aaa78cc3d623d17c21

Fix BUILD file formatting.

commit sha 3b7d92db79442b64fc24f695b80e95cbcbe825c8

Add missing pybind11 dependency.

commit sha 4e89d43a75a9d7a9d803ba8777b867c487e55ed6

Added JAX pytrees notebook Also added docstrings to the tree_util module.

commit sha b12a8019c8711afaef9b4d1a9f437bf944575cee

Update docs/notebooks/JAX_pytrees.ipynb Co-Authored-By: Stephan Hoyer <shoyer@google.com>

commit sha 8777864c96d91d205cdf96236f4ad86818f957c6

Minor edits

commit sha 132102498bd81867a7a07e578444d4639f70ba1a

Minor edit

commit sha 159690ae026fc8364cefc22990914a15cb10e106

Merge pull request #1758 from gnecula/bug_fix Added notebook for PyTrees

commit sha ace14bbb942bb1dd532fb07d3524e141b85808fc

Remove `join_pvals` This function appears to be unused.

commit sha 34dfbc8ae6c697118b6a18faf02547aaab5946dc

Add error checking to PRNG CUDA kernel. (#1760) Refactor error checking code into a common helper library.

commit sha 82ce209ae694bdbffa60155edbe34f4535e26d14

Merge branch 'master' into kernel-example-test

commit sha 36c882ba469fc009dba1a29df640c72eec12b36b

raise an error on jit-of-multi-host-pmap (#1761) Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>

commit sha 1b2350bf36bc2eabfacea8f409eb04e74514f69f

Merge pull request #1730 from google/kernel-example-test enable kernel regression example test

commit sha 14b8bab98f7c820d6b796442bf5844ed43b6d56f

Add experimental rematerialization decorator

We want to allow users to control how reverse-mode autodiff saves values from the forward pass. In particular, we want it to be easy to signal that a function shouldn't have any of its intermediate residuals stored for the backward pass, and instead those values should be recomputed from the function's saved inputs. (This feature is especially handy for accelerators on which memory access is much more expensive than FLOPs are.)

In JAX terms, since we implement reverse-mode as a composition of forward-mode, partial evaluation, and transposition, we want users to control how partial evaluation behaves. The basic design here is to create a new call primitive, like `core.call_p` (which doesn't raise the abstraction level of its inputs), that has a special partial evaluation rule: stage out the full computation, not just the parts that couldn't be partially evaluated. We call the new primitive `remat_call_p`.

The implementation of the partial evaluation rule has three main steps:

1. trace the full jaxpr of the called function, treating all inputs as "unknown" (but potentially using concrete avals for the arguments which are actually "known"),
2. process the full jaxpr to determine which outputs are meant to be "known"/"unknown", and to prune any extra computation,
3. compute any "known" values we need (and as an optimization, extract values from any concrete avals).

Otherwise, it looks a lot like the standard call partial evaluation rule, which is JaxprTrace.process_call.

To make `remat_call_p` transposable, we need to upgrade our `ad.backward_pass` code to do some nonlinear evaluation (to rematerialize the residuals needed). To translate `remat_call_p` to XLA in a way that prevents XLA's CSE optimization from foiling our rematerialization plans, we use a special widget that (according to friendly XLA team members) will in effect force a false data dependence to avoid CSE.

This change also adjusts our abstract evaluation rules for concrete values not to promote to the shaped level. (It was a long time ago, but I think we had that policy only because we didn't have a use case for performing FLOPs while creating jaxprs.)

The upgrade to ad.backward_pass was co-authored with @dougalm, and is essentially part of #1719. Fixes #1732.

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
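Assuming the decorator is exposed as `jax.remat` (the name suggested by the commit title; not confirmed by this feed), a minimal sketch of its intended use. Rematerialization changes only the memory/compute trade-off of the backward pass; the gradient value itself is unchanged:

```python
import jax
import jax.numpy as jnp

# Marking a function with jax.remat asks reverse-mode autodiff not to
# save its intermediate residuals; they are instead recomputed from the
# function's saved inputs during the backward pass.
@jax.remat
def h(x):
    return jnp.sin(jnp.sin(x))

# d/dx sin(sin(x)) = cos(sin(x)) * cos(x), with or without remat.
print(jax.grad(h)(1.0))
```

Trading compute for memory this way pays off on accelerators where memory traffic, not FLOPs, is the bottleneck.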

push time in 10 days

push event google/jax

commit sha ef95cf497d5cb62a8cc7184e66d8f96d7589aa54

raise an error on jit-of-multi-host-pmap Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>

push time in 10 days

PR opened google/jax

We can handle this case correctly, but for now we raise an error rather than silently producing incorrect values.

pr created time in 10 days

push event google/jax

commit sha 5d5c8523dec486aa2d7ba427a82634f1475f2f76

Add experimental rematerialization decorator

We want to allow users to control how reverse-mode autodiff saves values from the forward pass. In particular, we want it to be easy to signal that a function shouldn't have any of its intermediate residuals stored for the backward pass, and instead those values should be recomputed from the function's saved inputs. (This feature is especially handy for accelerators on which memory access is much more expensive than FLOPs are.)

In JAX terms, since we implement reverse-mode as a composition of forward-mode, partial evaluation, and transposition, we want users to control how partial evaluation behaves. The basic design here is to create a new call primitive, like `core.call_p` (which doesn't raise the abstraction level of its inputs), that has a special partial evaluation rule: stage out the full computation, not just the parts that couldn't be partially evaluated. We call the new primitive `remat_call_p`.

The implementation of the partial evaluation rule has three main steps:

1. trace the full jaxpr of the called function, treating all inputs as "unknown" (but potentially using concrete avals for the arguments which are actually "known"),
2. process the full jaxpr to determine which outputs are meant to be "known"/"unknown", and to prune any extra computation,
3. compute any "known" values we need (and as an optimization, extract values from any concrete avals).

Otherwise, it looks a lot like the standard call partial evaluation rule, which is JaxprTrace.process_call.

To make `remat_call_p` transposable, we need to upgrade our `ad.backward_pass` code to do some nonlinear evaluation (to rematerialize the residuals needed). To translate `remat_call_p` to XLA in a way that prevents XLA's CSE optimization from foiling our rematerialization plans, we use a special widget that (according to friendly XLA team members) will in effect force a false data dependence to avoid CSE.

This change also adjusts our abstract evaluation rules for concrete values not to promote to the shaped level. (It was a long time ago, but I think we had that policy only because we didn't have a use case for performing FLOPs while creating jaxprs.)

The upgrade to ad.backward_pass was co-authored with @dougalm, and is essentially part of #1719. Fixes #1732.

Co-authored-by: Dougal Maclaurin <dougalm@google.com>

push time in 10 days

pull request comment google/jax

Add experimental rematerialization decorator

@dougalm spotted an issue: when tracing on concrete values (to enable Python control flow) we might still incur extra work if we have `jit` mixed in in the right way. Concretely, here's the "three sines" problem he came up with:

```
from jax import jit, grad, remat
import jax.numpy as np

@jit
def g(x):
  return f(2., x)

@remat
def f(x, y):
  return np.sin(x) * y

grad(g)(3.)  # sin_p.impl evaluated three times!
```

Running this, `sin(2.)` is evaluated *three* times rather than the intended two. The issue is that the residual value `sin(2.0)` is computed but thrown away when tracing the function on concrete avals (i.e. in the first of the three steps in the OP, it's thrown away when `mul` has one concrete and one shape-abstracted-by-`jit` argument), then recomputed when we need to evaluate the known outputs (i.e. in the third of the three steps in the OP). I believe `jit` is essential here (and that's why JAX is harder than Autograd!).

The fix we want to try is to make tracing on concrete values (to support Python control flow) optional, at the risk of doing redundant FLOPs when `jit` is involved in slightly tricky ways.
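For reference, a runnable version of the "three sines" snippet (written against a public `jax.remat`/`jax.jit` API, which is an assumption here). The extra evaluation is purely a performance issue; the gradient value itself is unaffected:

```python
import jax
import jax.numpy as jnp

@jax.remat
def f(x, y):
    return jnp.sin(x) * y

@jax.jit
def g(y):
    return f(2., y)

# However many times sin(2.) is re-evaluated internally, the result
# is correct: d/dy [sin(2.) * y] = sin(2.)
print(jax.grad(g)(3.))
```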

comment created time in 10 days

push event google/jax

commit sha ace14bbb942bb1dd532fb07d3524e141b85808fc

Remove `join_pvals` This function appears to be unused.

push time in 10 days

PR merged google/jax

This function appears to be unused.

pr closed time in 10 days

push event google/jax

commit sha d1b78a7332fb21c2aefa017c20fa3e2f8c1d4780

wip lazy sublanguage

push time in 11 days

push event google/jax

commit sha a4faaaf124a45ca54f57a8e4e4d356a3fbfd22c3

wip lazy sublanguage

push time in 11 days

push event google/jax

commit sha 29b6546e635758d222ee10134dd9d7353625cad6

wip lazy sublanguage

push time in 11 days