Matthew Johnson (mattjj) · research scientist @ Google Brain · Google · San Francisco · people.csail.mit.edu/~mattjj

push event google/jax

Peter Hawkins

commit sha d958f3007da36913d235e435c1df765d56ba1722

Change JAX type promotion to prefer inexact types. (#1815)

Change the JAX type promotion table to prefer inexact types during type promotion. NumPy's type promotion rules tend to promote aggressively to float64, which isn't a very accelerator-friendly behavior, since not all accelerators (e.g., TPUs) support 64-bit floating point types. Even on accelerators that do support 64-bit floating point types (e.g., GPUs), promotion to a 64-bit type comes with a significant performance cost. This change makes JAX type promotion between inexact and exact types closer to PyTorch's promotion semantics, which are a better fit for modern accelerators, e.g.:

```
import numpy as onp
from jax import numpy as np

In [1]: onp.promote_types(onp.float32, onp.int32)
Out[1]: dtype('float64')

In [2]: onp.promote_types(onp.float16, onp.int64)
Out[2]: dtype('float64')

In [3]: np.promote_types(onp.float32, onp.int32)
Out[3]: dtype('float32')

In [4]: np.promote_types(onp.float16, onp.int64)
Out[4]: dtype('float16')
```

This change is in preparation for enabling x64 mode by default on all platforms.

view details

Matthew Johnson

commit sha 0c0137d787830d8ebd584c4610f7932f3787cab6

avoid compiling trivial programs from partial_eval also minor clean up in api_test.py

view details

Matthew Johnson

commit sha 5a5db5f2089292c57d43c79fe237484b19720f8f

add lazy sub-language, fuse into op-by-op computations Also removes the DeviceConstant system.

view details

push time in 5 hours

delete branch google/jax

delete branch : avoid-trivial-compilations

delete time in 5 hours

push event google/jax

Matthew Johnson

commit sha 0c0137d787830d8ebd584c4610f7932f3787cab6

avoid compiling trivial programs from partial_eval also minor clean up in api_test.py

view details

push time in 5 hours

PR merged google/jax

avoid compiling trivial programs from partial_eval cla: yes

Follow-up work for subsequent PRs

  • [ ] update pmap logic with the same

On master, we often end up compiling and executing trivial computations as a result of partial evaluation, where "trivial computations" means computations that are just tuple rearrangements of their inputs, possibly with some constants mixed in. More concretely, it means jaxprs with no eqns.

For example, executing this:

@jit 
def f(x):
  @jit 
  def g(x):
    return x + 1
  return g(x)

f(2)

Compiles and executes XLA computations corresponding to these jaxprs (from inserting a print(jaxpr) in _xla_computation in xla.py):

{ lambda  ;  ; a.
  let
  in [*] }

{ lambda  ;  ; a.
  let b = xla_call[ backend=None
                    device=None ] a
        { lambda  ;  ; a.
          let b = add a 1
          in [b] } [  ;  ]
  in [b] }

Notice the first jaxpr, which really does get compiled and executed!

Here's a variant:

@jit 
def f(x):
  @jit 
  def g(x, y):
    return (y, x + 1, 3)
  return g(x, 2)

f(2)
{ lambda  ;  ; a b.
  let
  in [b, *, *] }

{ lambda  ;  ; a.
  let b c d = xla_call[ backend=None
                        device=None ] a *
        { lambda  ;  ; a b.
          let c = add a 1
          in [*, c, *] } [  ;  ]
  in [*, c, *] }

Basically, nested jits will always produce these trivial computations, but so will things like make_jaxpr and xla_computation.

We've just been putting up with these extra computations because no one complained, but while working on #1668 I noticed another ill effect: if jit computations force their arguments, then compiling and executing trivial computations like this causes more forcing than we might want! That was the last straw, though to be clear the bigger reason for avoiding these compilations is simply that they incur totally unnecessary cost.

Luckily, the fix is easy: we just check when a jaxpr has no eqns (i.e. when it's just a tuple rearrangement and/or constant-output computation) and in that case don't compile anything, and instead just do the rearrangement in Python.
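To make the shape of that check concrete, here is a minimal sketch (not the code in this PR) of detecting and evaluating a trivial jaxpr in Python; it assumes the jaxpr data structures' eqns, constvars, invars, and outvars attributes, with literal outputs represented as core.Literal:

from jax import core

def is_trivial(jaxpr):
  # No equations: the outputs are just rearranged inputs and/or constants.
  return not jaxpr.eqns

def eval_trivial(jaxpr, consts, *args):
  # Perform the rearrangement directly in Python instead of compiling it.
  env = dict(zip(jaxpr.constvars, consts))
  env.update(zip(jaxpr.invars, args))
  return [v.val if isinstance(v, core.Literal) else env[v]
          for v in jaxpr.outvars]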

This PR also includes some unrelated api_test.py cleanups.

+89 -60

1 comment

2 changed files

mattjj

pr closed time in 5 hours

pull request comment google/jax

avoid compiling trivial programs from partial_eval

@dougalm gave the verbal LGTM:

[image]

mattjj

comment created time in 5 hours

push event google/jax

Matthew Johnson

commit sha b0c7586095092c5626d7c527ac7beeb528450740

make xla.force non-side-effecting

view details

push time in 6 hours

push event google/jax

Matthew Johnson

commit sha 460b0c7e1f4959014871eec9056f9b438af875fd

remove cruft from rebase

view details

push time in 8 hours

push event google/jax

George Necula

commit sha 437e6db8a1c0373fefc045f9dfd3693b00024212

Disable linalg_test.py::NumpyLinalgTest.testPinv on TPU and GPU This failed in google3 presubmits.

view details

George Necula

commit sha 120270cb47e0276021afa93ceca9149aee9dca35

Refined the test disabling for only TPU

view details

Peter Hawkins

commit sha 17813eab20604f89cccdd45ba06431b276d73839

Simplify np.cross. Add a jit decorator. (#1810) * Simplify np.cross. Add a jit decorator.

view details

George Necula

commit sha 31eb4fc1f3f673d0707b3eb4d1acf26a1166d15d

Merge pull request #1812 from gnecula/bug_fix Disable linalg_test.py::NumpyLinalgTest.testPinv on TPU and GPU

view details

George Necula

commit sha eca0d98ffde30573a5b95a8e9c2d3d10387e2e43

Increase test tolerance for float16 for LaxBackedNumpyTests.testCross Due to failure in google3 presubmit

view details

George Necula

commit sha 4a42e5d83006fbda7842b0fb7e947d1eddc42d9b

Merge pull request #1813 from gnecula/bug_fix Increase test tolerance for float16 for LaxBackedNumpyTests.testCross

view details

Matthew Johnson

commit sha c1aeaf511cb38c9d7a3174446d0525877256e6c9

xla_computation option to instantiate const output

view details

Matthew Johnson

commit sha 0899673363e4189dc670792bd3f0317795196b75

switch xla_computation instantiate outputs default

view details

Skye Wanderman-Milne

commit sha d113416844604c616bf11e8650c5018a15524ea6

Update WORKSPACE. We haven't published jaxlib 0.1.27 yet so I'm leaving the version as-is.

view details

Matthew Johnson

commit sha 7f4c2fcb6bee990db88d1eed6e02e2e898b9af3e

bump version for pypi

view details

James Bradbury

commit sha 1a82da37a31c49a94df15944eb7029f475936914

log compiles in pmap (#1817)

view details

Matthew Johnson

commit sha 36826585ee03781386db0c64a359ca4f6859b2ba

wip lazy sublanguage

view details

Matthew Johnson

commit sha dc93867d5e1ae93d956a55d2d5c02bcf5ac1241e

make jit strict in its arguments (i.e. force args) This change is to avoid recompiles. See comment: https://github.com/google/jax/pull/1668#issuecomment-561699616 Thanks @hawkinsp for help with this. Also, make force(x) update x's device_buffer reference.

view details

Matthew Johnson

commit sha 2a6987bb187552bd81e7f3ab0b0c2eadcec0d209

revise np.arange

view details

Matthew Johnson

commit sha 557a25d9c8cd094da6a28d6f5fd268e92453311b

performance

view details

push time in 8 hours

push event google/jax

Matthew Johnson

commit sha 7f4c2fcb6bee990db88d1eed6e02e2e898b9af3e

bump version for pypi

view details

push time in 19 hours

PR opened google/jax

avoid compiling trivial programs from partial_eval

On master, we often end up compiling and executing trivial computations as a result of partial evaluation, where "trivial computations" means computations that are just tuple rearrangements of their inputs, possibly with some constants mixed in. More concretely, it means jaxprs with no eqns.

For example, executing this:

@jit 
def f(x):
  @jit 
  def g(x):
    return x + 1
  return g(x)

f(2)

Compiles and executes XLA computations corresponding to these jaxprs (from inserting a print(jaxpr) in _xla_computation in xla.py):

{ lambda  ;  ; a.
  let
  in [*] }

{ lambda  ;  ; a.
  let b = xla_call[ backend=None
                    device=None ] a
        { lambda  ;  ; a.
          let b = add a 1
          in [b] } [  ;  ]
  in [b] }

Notice the first jaxpr, which really does get compiled and executed!

Here's a variant:

@jit 
def f(x):
  @jit 
  def g(x, y):
    return (y, x + 1, 3)
  return g(x, 2)

f(2)
{ lambda  ;  ; a b.
  let
  in [b, *, *] }

{ lambda  ;  ; a.
  let b c d = xla_call[ backend=None
                        device=None ] a *
        { lambda  ;  ; a b.
          let c = add a 1
          in [*, c, *] } [  ;  ]
  in [*, c, *] }

Basically, nested jits will always produce these trivial computations, but so will things like make_jaxpr and xla_computation.

We've just been putting up with these extra computations because no one complained, but while working on #1668 I noticed another ill effect: if jit computations force their arguments, then compiling and executing trivial computations like this causes more forcing than we might want! That was the last straw, though to be clear the bigger reason for avoiding these compilations is simply that they incur totally unnecessary cost.

Luckily, the fix is easy: we just check when a jaxpr has no eqns (i.e. when it's just a tuple rearrangement and/or constant-output computation) and in that case don't compile anything, and instead just do the rearrangement in Python.

+89 -60

0 comments

2 changed files

pr created time in 19 hours

create branch google/jax

branch : avoid-trivial-compilations

created branch time in 19 hours

push event google/jax

Matthew Johnson

commit sha e8b46869b69d658878fc57b22de85bc7e9e86c73

performance

view details

push time in a day

delete branch google/jax

delete branch : xla-computation-instantiate-consts

delete time in a day

push event google/jax

Matthew Johnson

commit sha c1aeaf511cb38c9d7a3174446d0525877256e6c9

xla_computation option to instantiate const output

view details

Matthew Johnson

commit sha 0899673363e4189dc670792bd3f0317795196b75

switch xla_computation instantiate outputs default

view details

push time in a day

issue comment google/jax

@differentiable decorator for classes

I haven't parsed the details here yet, but I think there is a use case (based on conversations long ago with rxwei, jekbradbury, sharadmv, and dougalm) where one wants to differentiate with respect to a container (ie a product type) that contains, say, a float and an int. That's potentially different from using the pytree mechanism to shuttle off the int part, since we need to pay attention to that int part for, say, jit.

We likely need to teach the AD system that the tangent space for integral values is core.unit, modeling the trivial vector space. (Alternatively we could have the differentiation api handle pytrees differently. Dynamic typing gives us a lot of options!)

Sorry if the above is cryptic. I wanted to jot down some quick thoughts without yet delving into the excellent points made in this thread already.
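For reference, here's a small sketch of the pytree-based workaround mentioned above, with the int field shuttled off as static auxiliary data; the container and field names are made up for illustration:

from jax import grad, tree_util

class FloatAndInt:
  def __init__(self, x, n):
    self.x = x  # float payload: a differentiable leaf
    self.n = n  # int payload: carried as static aux data, invisible to AD

def _flatten(c):
  return (c.x,), c.n                  # (children, aux_data)

def _unflatten(n, children):
  (x,) = children
  return FloatAndInt(x, n)

tree_util.register_pytree_node(FloatAndInt, _flatten, _unflatten)

def f(c):
  return c.n * c.x ** 2               # n affects the value but has no tangent

print(grad(f)(FloatAndInt(3.0, 2)).x)  # 12.0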

cgarciae

comment created time in a day

push event google/jax

Matthew Johnson

commit sha 6009efa2e17061bfa23cdf55c5b75f8c8975da00

switch xla_computation instantiate outputs default

view details

push time in a day

create branch google/jax

branch : xla-computation-instantiate-consts

created branch time in a day

push event google/jax

Matthew Johnson

commit sha 4ffeaf82a6f25d93f36503e68d497fc7c4b47a3f

make jit strict in its arguments (i.e. force args) This change is to avoid recompiles. See comment: https://github.com/google/jax/pull/1668#issuecomment-561699616 Thanks @hawkinsp for help with this. Also, make force(x) update x's device_buffer reference.

view details

Matthew Johnson

commit sha d996829270b33d0108ba86094170e911e1e62843

revise np.arange

view details

push time in a day

push event google/jax

Matthew Johnson

commit sha c1777faedd9293a705968aa67a7044a720c4873d

make jit strict in its arguments (i.e. force args) This change is to avoid recompiles. See comment: https://github.com/google/jax/pull/1668#issuecomment-561699616 Thanks @hawkinsp for help with this. Also, make force(x) update x's device_buffer reference.

view details

push time in a day

push event google/jax

Matthew Johnson

commit sha 73c3b3ed9232f70d439e9d0247582c29c8796197

make jit strict in its arguments (i.e. force args) This change is to avoid recompiles. See comment: https://github.com/google/jax/pull/1668#issuecomment-561699616 Thanks @hawkinsp for help with this. Also, make force(x) update x's device_buffer reference.

view details

push time in a day

push event google/jax

Skye Wanderman-Milne

commit sha 12a62c1f33ab3cf32c2e2157016f745f834c831c

Bump jaxlib version to 0.1.37 and update WORKSPACE.

view details

Skye Wanderman-Milne

commit sha 5b6c9325ed47b29d9182b0480206ba15b5787500

Fix WORKSPACE hash

view details

Peter Hawkins

commit sha d6b18fbb51d78cc2eb3177736e9ed52f925fbc6a

Add some missing NumPy constants: euler_gamma, NZERO and PZERO. (#1809) I avoided adding the deprecated aliases for inf and nan.

view details

Matthew Johnson

commit sha 9503baf475f2922f1df82455e3f597b28fa1d437

wip lazy sublanguage

view details

Matthew Johnson

commit sha 924f6d6c151ca159783bb029c1d062e238f70083

make jit strict in its arguments (i.e. force args) This change is to avoid recompiles. See comment: https://github.com/google/jax/pull/1668#issuecomment-561699616 Thanks @hawkinsp for help with this. Also, make force(x) update x's device_buffer reference.

view details

push time in a day

pull request comment google/jax

Lazy sublanguage

Running the mnist_vae.py example, I noticed something unfortunate: we were compiling the update loop (run_epoch) twice!

First call's signature:

WARNING:absl:Compiling run_epoch for args (ArgSpec(aval=ShapedArray(uint32[2]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(2,), dims=(0,)), xla_shape=u32[2]{0}), ArgSpec(aval=ShapedArray(float32[784,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(784, 512), dims=(0, 1)), xla_shape=f32[784,512]{1,0}), ArgSpec(aval=ShapedArray(float32[784,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(784, 512), dims=(None, None)), xla_shape=f32[]), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(0,)), xla_shape=f32[512]{0}), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(None,)), xla_shape=f32[]), ArgSpec(aval=ShapedArray(float32[512,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 512), dims=(0, 1)), xla_shape=f32[512,512]{1,0}), ArgSpec(aval=ShapedArray(float32[512,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 512), dims=(None, None)), xla_shape=f32[]), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(0,)), xla_shape=f32[512]{0}), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(None,)), xla_shape=f32[]), ArgSpec(aval=ShapedArray(float32[512,10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 10), dims=(0, 1)), xla_shape=f32[512,10]{1,0}), ArgSpec(aval=ShapedArray(float32[512,10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 10), dims=(None, None)), xla_shape=f32[]), ArgSpec(aval=ShapedArray(float32[10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(10,), dims=(0,)), xla_shape=f32[10]{0}), ArgSpec(aval=ShapedArray(float32[10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(10,), dims=(None,)), xla_shape=f32[]), ArgSpec(aval=ShapedArray(float32[512,10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 10), dims=(0, 1)), xla_shape=f32[512,10]{1,0}), ArgSpec(aval=ShapedArray(float32[512,10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 10), dims=(None, None)), xla_shape=f32[]), ArgSpec(aval=ShapedArray(float32[10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(10,), dims=(0,)), xla_shape=f32[10]{0}), ArgSpec(aval=ShapedArray(float32[10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(10,), dims=(None,)), xla_shape=f32[]), ArgSpec(aval=ShapedArray(float32[10,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(10, 512), dims=(0, 1)), xla_shape=f32[10,512]{1,0}), ArgSpec(aval=ShapedArray(float32[10,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(10, 512), dims=(None, None)), xla_shape=f32[]), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(0,)), xla_shape=f32[512]{0}), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(None,)), xla_shape=f32[]), ArgSpec(aval=ShapedArray(float32[512,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 512), dims=(0, 1)), xla_shape=f32[512,512]{1,0}), ArgSpec(aval=ShapedArray(float32[512,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 512), dims=(None, None)), xla_shape=f32[]), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(0,)), xla_shape=f32[512]{0}), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(None,)), xla_shape=f32[]), ArgSpec(aval=ShapedArray(float32[512,784]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 784), dims=(0, 1)), xla_shape=f32[512,784]{1,0}), ArgSpec(aval=ShapedArray(float32[512,784]), lazy_expr=LazyExpr(input=ArrayVar(), 
shape=(512, 784), dims=(None, None)), xla_shape=f32[]), ArgSpec(aval=ShapedArray(float32[784]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(784,), dims=(0,)), xla_shape=f32[784]{0}), ArgSpec(aval=ShapedArray(float32[784]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(784,), dims=(None,)), xla_shape=f32[])).

Second and subsequent calls' signatures:

WARNING:absl:Compiling run_epoch for args (ArgSpec(aval=ShapedArray(uint32[2]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(2,), dims=(0,)), xla_shape=u32[2]{0}), ArgSpec(aval=ShapedArray(float32[784,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(784, 512), dims=(0, 1)), xla_shape=f32[784,512]{1,0}), ArgSpec(aval=ShapedArray(float32[784,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(784, 512), dims=(0, 1)), xla_shape=f32[784,512]{1,0}), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(0,)), xla_shape=f32[512]{0}), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(0,)), xla_shape=f32[512]{0}), ArgSpec(aval=ShapedArray(float32[512,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 512), dims=(0, 1)), xla_shape=f32[512,512]{1,0}), ArgSpec(aval=ShapedArray(float32[512,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 512), dims=(0, 1)), xla_shape=f32[512,512]{1,0}), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(0,)), xla_shape=f32[512]{0}), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(0,)), xla_shape=f32[512]{0}), ArgSpec(aval=ShapedArray(float32[512,10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 10), dims=(0, 1)), xla_shape=f32[512,10]{1,0}), ArgSpec(aval=ShapedArray(float32[512,10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 10), dims=(0, 1)), xla_shape=f32[512,10]{1,0}), ArgSpec(aval=ShapedArray(float32[10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(10,), dims=(0,)), xla_shape=f32[10]{0}), ArgSpec(aval=ShapedArray(float32[10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(10,), dims=(0,)), xla_shape=f32[10]{0}), ArgSpec(aval=ShapedArray(float32[512,10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 10), dims=(0, 1)), xla_shape=f32[512,10]{1,0}), ArgSpec(aval=ShapedArray(float32[512,10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 10), dims=(0, 1)), xla_shape=f32[512,10]{1,0}), ArgSpec(aval=ShapedArray(float32[10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(10,), dims=(0,)), xla_shape=f32[10]{0}), ArgSpec(aval=ShapedArray(float32[10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(10,), dims=(0,)), xla_shape=f32[10]{0}), ArgSpec(aval=ShapedArray(float32[10,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(10, 512), dims=(0, 1)), xla_shape=f32[10,512]{1,0}), ArgSpec(aval=ShapedArray(float32[10,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(10, 512), dims=(0, 1)), xla_shape=f32[10,512]{1,0}), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(0,)), xla_shape=f32[512]{0}), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(0,)), xla_shape=f32[512]{0}), ArgSpec(aval=ShapedArray(float32[512,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 512), dims=(0, 1)), xla_shape=f32[512,512]{1,0}), ArgSpec(aval=ShapedArray(float32[512,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 512), dims=(0, 1)), xla_shape=f32[512,512]{1,0}), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(0,)), xla_shape=f32[512]{0}), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(0,)), xla_shape=f32[512]{0}), ArgSpec(aval=ShapedArray(float32[512,784]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 784), dims=(0, 1)), xla_shape=f32[512,784]{1,0}), 
ArgSpec(aval=ShapedArray(float32[512,784]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 784), dims=(0, 1)), xla_shape=f32[512,784]{1,0}), ArgSpec(aval=ShapedArray(float32[784]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(784,), dims=(0,)), xla_shape=f32[784]{0}), ArgSpec(aval=ShapedArray(float32[784]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(784,), dims=(0,)), xla_shape=f32[784]{0})).

The difference is in the broadcasts, like this pair of entries:

ArgSpec(aval=ShapedArray(float32[784,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(784, 512), dims=(None, None)), xla_shape=f32[]),
ArgSpec(aval=ShapedArray(float32[784,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(784, 512), dims=(0, 1)), xla_shape=f32[784,512]{1,0}),

Because of the np.zeros used to initialize optimizers in optimizers.py, on the first call to run_epoch we're fusing the formation of those lazy zeros into the run_epoch computation, meaning we compile for a signature with scalar arguments and broadcasts on the inputs. Then on the second call to run_epoch we need to recompile for dense array inputs. (If we add an xla.force call to the zeros creation, like xla.force(np.zeros_like(x0)) to the init function of momentum in optimizers.py, we don't have to recompile, as on master.)

Two possible solutions:

  1. make jit force its arguments, as on master (while still keeping op-by-op application and jit closure nonstrict)
  2. change optimizers.py, and similar user code, to call xla.force in cases like this

Discussing with @hawkinsp, we think Option 1 sounds like the better heuristic. (As an extension, if we want it, we could add a nonstrict_argnums option to jit to control this behavior, which might be useful for library writers.) Forcing users to think about laziness and change their code, as we'd need to change optimizers.py, is costly, and we don't yet have real use cases where this strict-jit policy would be problematic.
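For concreteness, here's what Option 2 would look like in an optimizer's init function. This is only an illustration of the alternative we're not taking; xla.force is the forcing helper from this branch, and the import path is an assumption:

import jax.numpy as np
from jax.interpreters import xla  # assumes this branch exposes force() here

def momentum_init(x0):
  # Materialize the lazy broadcasted zeros so the first jitted step compiles
  # against a dense array signature instead of a lazily-broadcast scalar.
  v0 = xla.force(np.zeros_like(x0))
  return x0, v0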

mattjj

comment created time in a day

push event google/jax

Matthew Johnson

commit sha 07c346153cc836996e53df8b7582a7f8c39b2702

wip lazy sublanguage

view details

push time in a day

push event google/jax

Matthew Johnson

commit sha be9cc31bcbf755e740fb22802cfefad638fc1cbb

wip lazy sublanguage

view details

push time in 2 days

push event google/jax

George Necula

commit sha 5c15dda2c973598ef4130e5cd3979c74dda32cf2

Changed api.make_jaxpr to return a TypedJaxpr
* A TypedJaxpr contains more useful information (consts, types)
* Also forced the instantiation of constants when producing the jaxpr.

Before:
>>> print(api.make_jaxpr(lambda x: 1.)(0.))
{ lambda  ;  ; a. let  in [*] }

After this change:
>>> print(api.make_jaxpr(lambda x: 1.)(0.))
{ lambda  ;  ; a. let  in [1.0] }

view details

George Necula

commit sha 603258ebb825d61978fe47da34bc69af94ad3040

Fixed a couple of tests

view details

George Necula

commit sha 2bb74b627e0382f27f2f11e197d93e58f859f3fe

Ensure jaxpr.eqn.params are printed sorted, so we get deterministic output

view details

George Necula

commit sha b0ffbaf1f6af38cdd0087dad3b2ac7147cfb6c50

Fixed also a notebook that has gone stale

view details

Matthew Johnson

commit sha 9a8523603c0d4efbdd82b2bd174060ac7825c051

Add experimental rematerialization decorator We want to allow users to control how reverse-mode autodiff saves values from the forward pass. In particular, we want it to be easy to signal that a function shouldn't have any of its intermediate residuals stored for the backward pass, and instead those values should be recomputed from the function's saved inputs. (This feature is especially handy for accelerators on which memory access is much more expensive than FLOPs are.) In JAX terms, since we implement reverse-mode as a composition of forward-mode, partial evaluation, and transposition, we want users to control how partial evaluation behaves. See https://github.com/google/jax/pull/1749 for more. Co-authored-by: Dougal Maclaurin <dougalm@google.com>

view details

Matthew Johnson

commit sha b2b5049eb52d69c836a72a77f50822ec27f5b9b2

try remat_call partial-eval into two remat_calls The idea here was for the resulting jaxpr to have a purely nonlinear remat_call and a linear one with no primals to evaluate. (I wanted to avoid having to recurse into all calls in _eval_primal in backward_pass.) But the issue is that makes jaxprs not round-trippable, since the first remat_call, depending only on constants, would get partial-eval'd away at the first attempted round-trip. And we round-trip in partial_eval_jaxpr, particularly for partial eval of scan. That meant remat of scan didn't work, and that's no good!

view details

Matthew Johnson

commit sha ac251046fcbe0a940ab79fe49d58227e61f7c675

make remat_call partial-eval into one remat_call

view details

Matthew Johnson

commit sha 115d365a92ef038426b5d2777943948463c2725b

raise error if we do concrete aval FLOPs w/o remat

view details

George Necula

commit sha 0cb3b433b516efd94046d8346ad397a04e5521fa

Change in how we print sorted params for eqns

view details

George Necula

commit sha a47f365c924269029b2366870ebe69020ddf7785

Cleaned some test warnings. Specifically: * lax_control_flow_test.py:...: DeprecationWarning: invalid escape sequence \( * Deprecated assertRaisesRegexp, replace with assertRaisesRegex

view details

George Necula

commit sha 2b0b04fcadc9ea952c14dbde04894312e2c7b757

Merge remote-tracking branch 'upstream/master' into jaxpr_pp

view details

George Necula

commit sha 3b97c5f792c0053b283eec77c7e638a2907f8f15

Updated uses of make_jaxpr in new code

view details

George Necula

commit sha 0bc081ec9896984feffbdbdbd2e044c0e1603751

Merge pull request #1766 from gnecula/jaxpr_pp Changed api.make_jaxpr to return a TypedJaxpr

view details

George Necula

commit sha fc73e50e0417299ca84ac52197af3a2b42adce3e

Merge pull request #1785 from gnecula/bug_fix3 Cleaned some test warnings.

view details

Tuan Nguyen

commit sha 0ebf8488ae53737f77edc35bae04fc3e53e90a7a

Implement np.flip with axis = None (#1783) * super minimal starter code * Update optimizers.py * implement flip with axis = None

view details

Srinivas Vasudevan

commit sha 6d2eb6790ee4ba7b6fd9d6117668b42e54c17e55

Add betaln, a wrapper for the Beta function (scipy.special.betaln). (#1788) * Add betaln, a wrapper for the Beta function (scipy.special.betaln). * Use infix operators for addition and multiplication.

view details

Peter Hawkins

commit sha f0d93333791db927a98e8029878c4c027b97bb49

Document functions in jax.nn. (#1795)

view details

Peter Hawkins

commit sha f3c8af49e78b04542535824be9e8d3c95b0cc778

Fix bugs in handling of convolutions whose LHS has spatial size 0. (#1794) * Fix bugs in handling of convolutions whose LHS has spatial size 0. * Use onp.shape to compute shapes.

view details

Peter Hawkins

commit sha 8782860d0bdc84a52ee4d33db3d48abe3ec34892

Relax test tolerances to fix test flakiness.

view details

Peter Hawkins

commit sha 441ad4dbbdc5f38d9f61017621228d7dbf994a57

Relax test tolerances for scipy test.

view details

push time in 2 days

pull request comment google/jax

Implement np.linalg.pinv

@shoyer Since it's been a while on our end, WDYT about landing and iterating? Do you want to take on the remaining todos (from your comments), or should I? (I don't mind doing it, if it means merging this excellent feature sooner!)

TuanNguyen27

comment created time in 2 days

issue closed google/jax

jax not working in multiprocessing

import numpy as np
from functools import partial
import jax.numpy as jnp
from jax import jit, random
from collections import namedtuple
from multiprocessing import Process, Queue
from pickle import dumps


class Brain(namedtuple("Brain", ("w1", "b1", "w2", "b2"))):
    def __sub__(self, other):
        return Brain(
            w1=self.w1 - other.w1,
            b1=self.b1 - other.b1,
            w2=self.w2 - other.w2,
            b2=self.b2 - other.b2,
        )

    def __mul__(self, scalar):
        return Brain(
            w1=self.w1 * scalar,
            b1=self.b1 * scalar,
            w2=self.w2 * scalar,
            b2=self.b2 * scalar,
        )

    __rmul__ = __mul__


def get_brain(
    input_size: int, hidden_size: int, output_size: int, max_memory: int, seed: int
):
    key = random.PRNGKey(seed)
    w1 = random.truncated_normal(
        key, lower=0, upper=0.1, shape=(input_size, hidden_size)
    )
    w2 = random.truncated_normal(
        key, lower=0, upper=0.1, shape=(hidden_size, output_size)
    )
    b1 = jnp.zeros(shape=(hidden_size,))
    b2 = jnp.zeros(shape=(output_size,))
    return Brain(w1=w1, b1=b1, w2=w2, b2=b2)


@jit
def forward(brain: Brain, data: np.ndarray):
    o1 = jnp.matmul(data, brain.w1) + brain.b1
    a1 = jnp.tanh(o1)
    o2 = jnp.matmul(a1, brain.w2) + brain.b2
    a2 = o2 - jnp.expand_dims(jnp.log(jnp.exp(o2).sum(axis=1)), 1)
    return a2


def worker(queue):
    import jax.numpy as jnp
    from jax import grad, jit

    @jit
    def forward(brain: Brain, data: np.ndarray):
        o1 = jnp.matmul(data, brain.w1) + brain.b1
        a1 = jnp.tanh(o1)
        o2 = jnp.matmul(a1, brain.w2) + brain.b2
        a2 = o2 - jnp.expand_dims(jnp.log(jnp.exp(o2).sum(axis=1)), 1)
        return a2

    @jit
    def loss(brain: Brain, data: np.ndarray, labels: np.ndarray):
        pred = forward(brain, data)
        loss = jnp.mean(-(labels * pred).sum(1))
        return loss

    @jit
    def grad_loss(brain: Brain, data: np.ndarray, labels: np.ndarray):
        return partial(grad(loss), data=data, labels=labels)(brain)

    @jit
    def sgd(brain: Brain, data: np.ndarray, labels: np.ndarray, learning_rate: float):
        g = grad_loss(brain, data, labels)
        brain = brain - g * learning_rate
        return brain

    while True:
        brain, data, labels, epoch, learning_rate = queue.get()
        for i in range(epoch):
            brain = sgd(brain, data, labels, learning_rate)
        break  # when run via multiprocessing, control flow does not even get here; nothing gets returned from forward


if __name__ == "__main__":
    brain = get_brain(100, 200, 9, 1000, 1)
    data = np.random.normal(size=(1000, 100))
    labels = np.random.uniform(0, 1, size=(1000, 9))
    queue = Queue(10)
    workers = []
    for i in range(2):
        p = Process(target=worker, args=(queue,))
        p.start()  # does not work
        workers.append(p)
    for i in range(10):
        queue.put((brain, data, labels, 1000, 1))
    worker(queue)  # works

closed time in 2 days

dchatterjee172

issue comment google/jax

jax not working in multiprocessing

Thanks for the question!

dchatterjee172

comment created time in 2 days

Pull request review comment google/jax

Support atrous conv in same padded convolution and add warning if use…

 def conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation=None,
   if type(dimension_numbers) is not ConvDimensionNumbers:
     dimension_numbers = conv_dimension_numbers(
         lhs.shape, rhs.shape, dimension_numbers)
-  if isinstance(padding, str):
-    lhs_perm, rhs_perm, _ = dimension_numbers
-    padding = padtype_to_pads(
-        onp.take(lhs.shape, lhs_perm)[2:], onp.take(rhs.shape, rhs_perm)[2:],
-        window_strides, padding)
   if lhs_dilation is None:
     lhs_dilation = (1,) * (lhs.ndim - 2)
+  elif isinstance(padding, str) and not len(lhs_dilation) == lhs_dilation.count(1):
+    warnings.warn(

Instead of a warning, how about an error? (Also, maybe the message could say more about how to specify the required padding explicitly, i.e. what the alternative to a string is.)
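For reference, the alternative to a padding string is a sequence of explicit (low, high) pads, one pair per spatial dimension. A small sketch (shapes and values arbitrary, not taken from this PR):

import numpy as onp
from jax import lax

lhs = onp.zeros((1, 3, 8, 8), onp.float32)   # NCHW input
rhs = onp.zeros((4, 3, 3, 3), onp.float32)   # OIHW kernel
out = lax.conv_general_dilated(
    lhs, rhs,
    window_strides=(1, 1),
    padding=[(2, 2), (2, 2)],   # explicit (low, high) pads per spatial dim
    rhs_dilation=(2, 2))        # atrous convolution: effective 5x5 kernel
print(out.shape)                # (1, 4, 8, 8), i.e. 'SAME'-sized output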

tamaranorman

comment created time in 2 days

Pull request review comment google/jax

Support atrous conv in same padded convolution and add warning if use…

 def conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation=None,
   if type(dimension_numbers) is not ConvDimensionNumbers:
     dimension_numbers = conv_dimension_numbers(
         lhs.shape, rhs.shape, dimension_numbers)
-  if isinstance(padding, str):
-    lhs_perm, rhs_perm, _ = dimension_numbers
-    padding = padtype_to_pads(
-        onp.take(lhs.shape, lhs_perm)[2:], onp.take(rhs.shape, rhs_perm)[2:],
-        window_strides, padding)
   if lhs_dilation is None:
     lhs_dilation = (1,) * (lhs.ndim - 2)
+  elif isinstance(padding, str) and not len(lhs_dilation) == lhs_dilation.count(1):
+    warnings.warn(
+        "String padding is not set-up correctly for transposed convolution "
+        "using this op. Please either exaclty specify the required padding or "

Typo: "exaclty" -> "Exactly"

tamaranorman

comment created time in 2 days

Pull request review comment google/jax

add optional `length` argument to scan

 def scan(f, init, xs):
     xs: the value of type ``[a]`` over which to scan along the leading axis,
       where ``[a]`` can be an array or any pytree (nested Python
       tuple/list/dict) thereof with consistent leading axis sizes.
+    length: optional integer specifying the number of loop iterations, which
+      must agree with the sizes of leading axes of the arrays in ``xs`` (but can
+      be used to perform scans where no input ``xs`` are needed).

Good question! I believe that if xs is an empty container, then the second argument to f will be the same empty container. That is, if xs=None then the second argument to f is None, but if xs=() then the second argument to f is (), and if xs=[()] then the second argument to f is [()].
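As a quick illustration of the xs=None case this length argument enables (a sketch assuming this PR's API):

from jax import lax

def f(carry, x):
  assert x is None   # with xs=None there is no per-iteration input
  return carry + 1, None

carry, ys = lax.scan(f, 0, None, length=3)
print(carry)  # 3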

mattjj

comment created time in 2 days

delete branch google/jax

delete branch : scan-length-arg

delete time in 3 days

push event google/jax

Matthew Johnson

commit sha 09f94a1e3d6a4f43d82f43e85ed1dc74f5290b93

add optional `length` argument to scan

view details

Matthew Johnson

commit sha ac2af106ed35556841a6fc3bf643b0e200ca0fa4

adjust scan docstring (thanks @shoyer)

view details

push time in 3 days

PR merged google/jax

add optional `length` argument to scan cla: yes

cc @shoyer

+39 -10

3 comments

2 changed files

mattjj

pr closed time in 3 days

push event google/jax

Matthew Johnson

commit sha 991f626eac49694afa262aa5ebaf336c0213f002

adjust scan docstring (thanks @shoyer)

view details

push time in 3 days

pull request comment google/jax

add optional `length` argument to scan

Great idea to update the docstring; done! I'm not sold on making xs=None a default, though, because I don't like default argument values in general.

mattjj

comment created time in 3 days

push event google/jax

Matthew Johnson

commit sha c183dd1295321225bb05d6f2604555384a6f9f48

adjust scan docstring (thanks @shoyer)

view details

push time in 3 days

push event google/jax

wang12tao

commit sha 51686f43d390e209923b476d04d352bb2a340f01

Make get_compile_options API accept 2D device assignment.

view details

push time in 3 days

PR merged google/jax

Make get_compile_options API accept 2D device assignment. cla: yes
+32 -1

0 comments

2 changed files

wang12tao

pr closed time in 3 days

PR opened google/jax

add optional `length` argument to scan
+33 -6

0 comments

2 changed files

pr created time in 3 days

create branch google/jax

branch : scan-length-arg

created branch time in 3 days

pull request comment google/jax

Adding `broadcast_argnums` to `pmap` for allowing similar behaviour t…

Thanks for opening this! I think it's likely we'll merge it but we need to first sort out an API question about how to encode "static" and "broadcast" separately. One option might be to have a static_broadcasted_argnums like in this PR, then separately have an in_axes like in vmap to specify non-static broadcasted argnums.

cc @skye

botev

comment created time in 3 days

push event google-research/autoconj

Matthew Johnson

commit sha 710d95ce4862f764cbdc96b399244394a63bf885

remove some fmap cruft

view details

Matthew MacKay

commit sha 8a167210a202a1bfdeab87d34e6e66391cad08a3

copying over pgm support (WIP)

view details

push time in 5 days

issue comment google/jax

FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated

I think it’s okay as-is: we support the deprecated numpy behavior, but when our tests compare against numpy it generates the warning. (Actually, I guess we could have the test case slurp up the warning... I only just realized that might be one of the solutions you have in mind.)

That said I’m interested to hear @shoyer ’s thoughts!
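Here's a tiny sketch of the "slurp up the warning" idea (illustrative only; the helper name is made up): compute the NumPy reference value inside a catch_warnings block so the deprecated-indexing FutureWarning doesn't leak out of the test:

import warnings
import numpy as onp

def numpy_reference(x, idx):
  # Swallow the FutureWarning that non-tuple multidimensional indexing raises.
  with warnings.catch_warnings():
    warnings.simplefilter("ignore", FutureWarning)
    return x[idx]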

gnecula

comment created time in 7 days

push event google/jax

Matthew Johnson

commit sha 9232549620941dae0b4062155711d6b43c74ace8

add lax.delta/lax.broadcasted_eye to lazy language

view details

push time in 7 days

pull request comment google/jax

Cleaned some test warnings.

I often ignore warnings by running pytest -n auto tests -W ignore. We could clean up the tests not to generate any warnings, but I'm not sure how much work that will take. We might also want to wait until the dust settles on @hawkinsp 's dtype-handling changes.

gnecula

comment created time in 7 days

push event google/jax

Matthew Johnson

commit sha 42834312d8ee43274e9d47f39a35cd10b0f4108e

add lax.delta/lax.broadcasted_eye to lazy language

view details

push time in 8 days

push event google/jax

Matthew Johnson

commit sha 362845f43cf842fbf81a9ccdfed4e4a67a6178ff

add lax.eye, lax.tri to lazy language

view details

push time in 8 days

push event google/jax

Matthew Johnson

commit sha 162856b9003c1840cffe66171218fdaaa7eadef0

add lax.eye, lax.tri to lazy language

view details

push time in 8 days

push event google/jax

Matthew Johnson

commit sha f032af2356280e66d097fddeb3f66b8a659b5782

add lax.eye to lazy language

view details

Matthew Johnson

commit sha 9bf7deca105d8a564180f1d3cba5da5020dd2fd9

add lax.tri to lazy language

view details

push time in 8 days

push event google/jax

Matthew Johnson

commit sha 50f8e74e4872e2deae7088f00ca27ccfcde561a7

add lazy identity

view details

push time in 8 days

delete branch google/jax

delete branch : remat2

delete time in 8 days

pull request comment google/jax

Add experimental rematerialization decorator

cc @joschu

mattjj

comment created time in 8 days

push event google/jax

Matthew Johnson

commit sha 9a8523603c0d4efbdd82b2bd174060ac7825c051

Add experimental rematerialization decorator We want to allow users to control how reverse-mode autodiff saves values from the forward pass. In particular, we want it to be easy to signal that a function shouldn't have any of its intermediate residuals stored for the backward pass, and instead those values should be recomputed from the function's saved inputs. (This feature is especially handy for accelerators on which memory access is much more expensive than FLOPs are.) In JAX terms, since we implement reverse-mode as a composition of forward-mode, partial evaluation, and transposition, we want users to control how partial evaluation behaves. See https://github.com/google/jax/pull/1749 for more. Co-authored-by: Dougal Maclaurin <dougalm@google.com>

view details

Matthew Johnson

commit sha b2b5049eb52d69c836a72a77f50822ec27f5b9b2

try remat_call partial-eval into two remat_calls The idea here was for the resulting jaxpr to have a purely nonlinear remat_call and a linear one with no primals to evaluate. (I wanted to avoid having to recurse into all calls in _eval_primal in backward_pass.) But the issue is that makes jaxprs not round-trippable, since the first remat_call, depending only on constants, would get partial-eval'd away at the first attempted round-trip. And we round-trip in partial_eval_jaxpr, particularly for partial eval of scan. That meant remat of scan didn't work, and that's no good!

view details

Matthew Johnson

commit sha ac251046fcbe0a940ab79fe49d58227e61f7c675

make remat_call partial-eval into one remat_call

view details

Matthew Johnson

commit sha 115d365a92ef038426b5d2777943948463c2725b

raise error if we do concrete aval FLOPs w/o remat

view details

push time in 8 days

PR merged google/jax

Add experimental rematerialization decorator cla: yes

The problem

We want to allow users to control how reverse-mode autodiff saves values from the forward pass. In particular, we want it to be easy to signal that a function shouldn't have any of its intermediate residuals stored for the backward pass, and instead those values should be recomputed from the function's saved inputs. This allows users greater control over the time/memory tradeoffs of their computations. (This feature is especially handy for accelerators on which memory access is much more expensive than FLOPs are.)

In JAX terms, since we implement reverse-mode as a composition of forward-mode, partial evaluation, and transposition, we want users to control how partial evaluation behaves.

This PR adds a checkpoint decorator (alias: remat) for this purpose.

Fixes #1732.

API examples

Example 1: grad-of-jit

Consider this grad-of-jit situation:

import jax.numpy as np
from jax import jit, grad, remat

@remat
def g(x):
  return np.sin(np.sin(x))

@jit
def f(x):
  return g(x)

grad(f)(2.0)

If we run this without the remat decorator, here are the jaxprs we end up lowering to XLA, one for the forward pass and one for the backward pass (by adding a print(jaxpr) to xla._xla_callable):

{ lambda  ;  ; a b.
  let c = sin a
      d = sin c
      e = cos a
      f = cos c
  in [d, *, e, f] }

{ lambda  ;  ; a b c.
  let d = mul c b
      e = mul d a
  in [e] }

Here's what we see with the remat decorator:

{ lambda  ;  ; a b.
  let c = sin a
      d = sin c
  in [d, *, a] }

{ lambda  ;  ; a b.
  let c = remat_call a b
        { lambda  ;  ; a b.
          let c = sin a
              d = cos c
              e = mul b d
              f = cos a
              g = mul e f
          in [g] } [  ;  ]
  in [c] }

Notice how there are no residuals passed into the backward pass as constants, and we're computing sin(2.0) both in the forward pass and in the backward pass.

Example 2: jit-of-grad

Here's a jit-of-grad situation:

import jax.numpy as np
from jax import jit, value_and_grad, remat

@remat
def g(x):
  return np.sin(np.sin(x))

jit(value_and_grad(g))(2.0)
{ lambda  ;  ; a.
  let b = sin a
      c = sin b
      d = remat_call a 1.0
        { lambda  ;  ; a b.
          let c = sin a
              d = cos c
              e = mul b d
              f = cos a
              g = mul e f
          in [g] } [  ;  ]
  in [c, d] }

Again we're computing sin(2.0) twice. Not shown is the fact that the translation rule for lowering remat_call to XLA HLO includes some HLO widgets that will foil XLA's CSE optimization across remat_call boundaries, even though it's all being lowered to one computation here.

Example 3: differentiating Python control flow under remat

import jax.numpy as np
from functools import partial
from jax import jit, grad, linearize, remat, make_jaxpr

@partial(remat, concrete=True)
def g(x):
  if x > 0.:
    return np.sin(np.sin(x))
  else:
    return np.cos(np.cos(x))

def f(x):
  x = 3. * x
  x = g(x)
  x = 4. * x
  return x

print(grad(f)(2.))
jaxpr = make_jaxpr(linearize(f, 2.)[1])(1.)
print(jaxpr)
11.075182

{ lambda  ;  ; a.
  let b = mul a 3.0
      c d = remat_call 6.0 b
        { lambda  ;  ; a b.
          let c = sin a
              d = sin c
              e = cos a
              f = mul b e
              g = cos c
              h = mul f g
          in [d, h] } [  ;  ]
      e = mul d 4.0
  in [e] }

Functions with remat can still support differentiating through Python control flow! At least, when you pass concrete=True. There's a tradeoff here: to handle Python control flow, in some cases involving jit we might end up doing redundant FLOPs, as in the Three Sines Problem described below, so we wanted to make this feature opt-in.

Notice that in this last jaxpr, the remat_call is applied to an argument that is a literal 6.0. We've prevented that from being partially evaluated through the called jaxpr on the JAX side, and moreover the lowering of remat_call will prevent XLA from doing that constant folding as well.

Example 4: scanning a remat function

import jax.numpy as np
from jax import lax
from jax import vjp, remat, make_jaxpr

def f(c, x):
  return np.sin(c), None

def foo(x):
  y, _ = lax.scan(remat(f), x, np.arange(3.))
  return y

_, foo_vjp = vjp(foo, 4.)
foo_vjp(1.)  # added a `print(jaxpr)` in ad.backward_pass
{ lambda c d ;  ; a b.
  let e f = scan[ forward=True
                  length=3
                  jaxpr={ lambda  ;  ; a b c d e.
                          let f g = remat_call d e b
                                { lambda  ;  ; a b c.
                                  let d = sin a
                                      e = cos a
                                      f = mul c e
                                  in [d, f] } [  ;  ]
                          in [*, g] }
                  num_consts=0
                  num_carry=2
                  linear=[True, True, True, False, False] ] * b * c d
  in [*, f] }

Notice the sin and cos computation in the scanned function! Compare to not using the remat decorator:

{ lambda c ;  ; a b.
  let d e = scan[ forward=True
                  length=3
                  jaxpr={ lambda  ;  ; a b c d.
                          let e = mul b d
                          in [*, e] }
                  num_consts=0
                  num_carry=2
                  linear=[True, True, True, False] ] * b * c
  in [*, e] }

Example 5: binomial checkpointing

def binom_checkpoint(funs):
  if len(funs) == 1:
    return funs[0]
  elif len(funs) == 2:
    f1, f2 = funs
    return lambda x: f1(f2(x))
  else:
    f1 = binom_checkpoint(funs[:len(funs)//2])
    f2 = binom_checkpoint(funs[len(funs)//2:])
    return lambda x: f1(remat(f2)(x))

# forward pass
f = binom_checkpoint([np.cos, np.sin, np.cos, np.sin])
print(make_jaxpr(f)(4.))

# forward and backward pass
f = binom_checkpoint([np.sin, np.sin, np.sin, np.sin])
print(make_jaxpr(value_and_grad(f))(4.))

# longer forward pass
f = binom_checkpoint([np.cos, np.sin, np.cos, np.sin,
                      np.cos, np.sin, np.cos, np.sin])
print(make_jaxpr(f)(4.))
# forward pass
{ lambda  ;  ; a.
  let b = remat_call[ concrete=False ] a
        { lambda  ;  ; a.
          let b = sin a
              c = cos b
          in [c] } [  ;  ]
      c = sin b
      d = cos c
  in [d] }

# forward and backward pass
{ lambda  ;  ; a.
  let b = sin a
      c = sin b
      d = sin c
      e = sin d
      f = cos d
      g = mul 1.0 f
      h = cos c
      i = mul g h
      j = remat_call[ concrete=False ] a i
        { lambda  ;  ; a b.
          let c = sin a
              d = cos c
              e = mul b d
              f = cos a
              g = mul e f
          in [g] } [  ;  ]
  in [e, j] }

# longer forward pass
{ lambda  ;  ; a.
  let b = remat_call[ concrete=False ] a
        { lambda  ;  ; a.
          let b = remat_call[ concrete=False ] a
                { lambda  ;  ; a.
                  let b = sin a
                      c = cos b
                  in [c] } [  ;  ]
              c = sin b
              d = cos c
          in [d] } [  ;  ]
      c = remat_call[ concrete=False ] b
        { lambda  ;  ; a.
          let b = sin a
              c = cos b
          in [c] } [  ;  ]
      d = sin c
      e = cos d
  in [e] }

Implementation

The basic design here is to create a new call primitive, like core.call_p (which doesn't necessarily raise the abstraction level of its inputs), that has a special partial evaluation rule: stage out the full computation, not just the parts that couldn't be partially evaluated. We call the new primitive remat_call_p.

The implementation of the partial evaluation rule has three main steps:

  1. trace the full jaxpr of the called function, treating all inputs as "unknown" (but potentially using concrete avals for the arguments which are actually "known", if concrete=True),
  2. process the full jaxpr to determine which outputs are meant to be "known"/"unknown", and to prune any extra computation,
  3. compute any "known" values we need (as an optimization, extract values from any concrete avals instead of re-evaluating them).

Otherwise, it looks a lot like the standard call partial evaluation rule, which is JaxprTrace.process_call.

To make remat_call_p transposable, we need to upgrade our ad.backward_pass code to do some nonlinear evaluation (to rematerialize the residuals needed).

To translate remat_call_p to XLA in a way that prevents XLA's CSE optimization from foiling our rematerialization plans, we use a special widget that (according to friendly XLA team members) will in effect force a false data dependence to avoid CSE.

This change also adjusts our abstract evaluation rules for concrete values not to promote to the shaped level. (It was a long time ago, but I think we had that policy only because we didn't have a use case for performing FLOPs while creating jaxprs.) To ensure that this change doesn't accidentally lead to more FLOPs in user code, I tried running the test suite with an assert False in the new concrete-evaluation path and checked that it was never hit except in the new remat tests.

The upgrade to ad.backward_pass was co-authored with @dougalm, and is essentially part of #1719.

+512 -45

3 comments

11 changed files

mattjj

pr closed time in 8 days

issue closed google/jax

add explicit checkpointing (rematerialization) control

Analogous to checkpoint in Autograd (but handling closed-over tracers).

closed time in 8 days

mattjj

push event google/jax

Roy Frostig

commit sha 16484352685af5934facd32d07cd9664192e57ac

enable kernel regression example test

view details

Roy Frostig

commit sha 6b39104a6d37c7a9add44ad15e5967a025bc42ad

Merge branch 'master' into kernel-example-test

view details

Peter Hawkins

commit sha 67038321f8b33bdab99c1fc3d003719553769c61

Revert support for building a non-GPU build with --config=cuda enabled. (#1757) It turns out there are implicit CUDA dependencies inside the TF libraries used by JAX, so the attempt to disable GPU dependencies conditionally didn't work.

view details

Peter Hawkins

commit sha 534d812b5777cc2da5bba11130626498f6aa2667

Add a handwritten ThreeFry2x32 CUDA kernel. (#1756) In principle, JAX should not need a hand-written CUDA kernel for the ThreeFry2x32 algorithm. In practice XLA aggressively inlines, which causes compilation times on GPU to blow up when compiling potentially many copies of the PRNG kernel in a program. As a workaround, we add a hand-written CUDA kernel mostly to reduce compilation time. When XLA becomes smarter about compiling this particular hash function, we should be able to remove the hand-written kernel once again.

view details

Peter Hawkins

commit sha d1aa01874d2e25fba07c23aaa78cc3d623d17c21

Fix BUILD file formatting.

view details

Peter Hawkins

commit sha 3b7d92db79442b64fc24f695b80e95cbcbe825c8

Add missing pybind11 dependency.

view details

George Necula

commit sha 4e89d43a75a9d7a9d803ba8777b867c487e55ed6

Added JAX pytrees notebook Also added docstrings to the tree_util module.

view details

George Necula

commit sha b12a8019c8711afaef9b4d1a9f437bf944575cee

Update docs/notebooks/JAX_pytrees.ipynb Co-Authored-By: Stephan Hoyer <shoyer@google.com>

view details

George Necula

commit sha 8777864c96d91d205cdf96236f4ad86818f957c6

Minor edits

view details

George Necula

commit sha 132102498bd81867a7a07e578444d4639f70ba1a

Minor edit

view details

George Necula

commit sha 159690ae026fc8364cefc22990914a15cb10e106

Merge pull request #1758 from gnecula/bug_fix Added notebook for PyTrees

view details

Chris Jones

commit sha ace14bbb942bb1dd532fb07d3524e141b85808fc

Remove `join_pvals` This function appears to be unused.

view details

Peter Hawkins

commit sha 34dfbc8ae6c697118b6a18faf02547aaab5946dc

Add error checking to PRNG CUDA kernel. (#1760) Refactor error checking code into a common helper library.

view details

Roy Frostig

commit sha 82ce209ae694bdbffa60155edbe34f4535e26d14

Merge branch 'master' into kernel-example-test

view details

Matthew Johnson

commit sha 36c882ba469fc009dba1a29df640c72eec12b36b

raise an error on jit-of-multi-host-pmap (#1761) Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>

view details

Roy Frostig

commit sha 1b2350bf36bc2eabfacea8f409eb04e74514f69f

Merge pull request #1730 from google/kernel-example-test enable kernel regression example test

view details

Skye Wanderman-Milne

commit sha f415f266b890e7b1476879c108f6e846d7f93177

Remove 'backend' argument from device_put. (#1762) The appropriate Backend is instead inferred from the 'device' argument. This is a first step towards removing the 'backend' argument from more functions.

view details

George Necula

commit sha 227a91220b51ed883e28023e1f7eea5662af28f5

Update minimum jaxlib to 0.1.36. This is needed in part to pull in new Device.platform from Tensorflow. See #1764.

view details

George Necula

commit sha 8547050b09318414190b88d29a2d0914ccc325a3

Merge pull request #1765 from gnecula/bug_fix Update minimum jaxlib to 0.1.36.

view details

Peter Hawkins

commit sha 7a9f1f3f1c90b112fc1830da68a89b514716acce

Pin the minimum jaxlib version in travis.yml. (#1767)

view details

push time in 8 days

push event google/jax

Matthew Johnson

commit sha 887ec1d948b990e186828d1201010c2034dec662

raise error if we do concrete aval FLOPs w/o remat

view details

push time in 8 days

create branch google/jax

branch : dynamic-scoping

created branch time in 8 days

push event google/jax

Matthew Johnson

commit sha 10f6010d0d94fb4d9b51832699071fb9ac880787

make remat_call partial-eval into one remat_call

view details

push time in 8 days

pull request comment google/jax

Add experimental rematerialization decorator

I learned something interesting! In 9987f73 I tried to make the partial eval rule of remat_call stage out two remat_calls, which in the context of reverse-mode ad would have the first only involve nonlinear computations (not depending on the tangents) and output residuals, while the second was linear. The aim was to be cleaner conceptually (don't mix linear and nonlinear in a single call), and also to avoid unnecessary recursing into calls when evaluating nonlinear constants in ad.backward_pass.

But that ran into an interesting problem: it would form jaxprs like this:

{ lambda  ;  ; a b.
  let c d = remat_call[ concrete=False ] 2.0 *
        { lambda  ;  ; a b.
          let c = sin a
              e = cos a
              f = cos c
          in [e, f] } [  ;  ]
      e f = remat_call[ concrete=False ] * b c d
        { lambda  ;  ; a b c e.
          let d = mul b c
              f = mul d e
          in [*, f] } [  ;  ]
  in [*, f] }

and that is not "round-trippable" in the sense that if we do a make_jaxpr on an eval_jaxpr the first remat_call will get partial-evaluated away (it has no data dependence on the input arguments). We use round-tripping like that in partial_eval_jaxpr e.g. to implement partial eval of scan. That meant that by trying to stage out separate nonlinear and linear remat_calls, I broke scan(remat(f), ...).

So I reverted to staging out a single (mixed linear/nonlinear) remat_call, and just special-casing it in ad.backward_pass to avoid unnecessary recursions into other call primitives (which are always linearized by our partial-eval-of-jvp machinery).

mattjj

comment created time in 8 days

push event google/jax

George Necula

commit sha 227a91220b51ed883e28023e1f7eea5662af28f5

Update minimum jaxlib to 0.1.36. This is needed in part to pull in new Device.platform from Tensorflow. See #1764.

view details

George Necula

commit sha 8547050b09318414190b88d29a2d0914ccc325a3

Merge pull request #1765 from gnecula/bug_fix Update minimum jaxlib to 0.1.36.

view details

Peter Hawkins

commit sha 7a9f1f3f1c90b112fc1830da68a89b514716acce

Pin the minimum jaxlib version in travis.yml. (#1767)

view details

Matthew Johnson

commit sha 2867e4be082237e2b184d7b533dfcbfa31b24f63

fix grad of jit caching bug Co-authored-by: Dougal Maclaurin <dougalm@google.com>

view details

Peter Hawkins

commit sha fbc9446afa13b85a68484e8240aec07cfcb0fb8f

Fix some missing docstrings for Numpy functions. (#1768)

view details

Skye Wanderman-Milne

commit sha 3ae4a41320d6a116656626bfd0d87cde6c928335

Add "loops" BUILD target. (#1771)

view details

Matthew Johnson

commit sha b7579492690b1d94da89b7f1d1b6ddcfadbaacae

fix pulldown bugs

view details

Peter Buchlovsky

commit sha 8df1ccf42b6025e6e966df8a4ce034c27bdbff7e

Make jax.numpy.broadcast_to consistent with numpy. (#1773) * Make jax.numpy.broadcast_to consistent with numpy. jax.numpy.broadcast_to(10.0, ()) should return array(10.0) and not 10.0. * Improve broadcast_to test.

view details
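A quick sketch of the behavior this commit describes (using onp for NumPy and jnp for jax.numpy): broadcasting a Python scalar to shape () now yields a 0-d array in both libraries instead of passing the bare float through.

import numpy as onp
import jax.numpy as jnp

assert isinstance(onp.broadcast_to(10.0, ()), onp.ndarray)
assert isinstance(jnp.broadcast_to(10.0, ()), jnp.ndarray)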

Peter Hawkins

commit sha 5c96d83ea6247a8d27cd1f1a591d3f41fb8c0b64

Simplify einsum implementation. (#1774) XLA's DotGeneral operator has been generalized so we no longer need the _dot_general wrapper. Avoids the need for unnecessary reshapes.

view details

Peter Hawkins

commit sha da6a474a63bea7d7d27f3ee112cff75be6693a74

Simplify jax.numpy.tensordot by using lax.dot_general. (#1775)

view details
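Both simplifications route jnp.einsum and jnp.tensordot through lax.dot_general directly. A minimal illustration of that primitive (shapes and dimension numbers chosen arbitrarily): dimension_numbers is ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)).

import jax.numpy as jnp
from jax import lax

a = jnp.ones((2, 3, 4))
b = jnp.ones((2, 4, 5))
# Batch over axis 0 and contract a's axis 2 with b's axis 1: a batched matmul.
c = lax.dot_general(a, b, dimension_numbers=(((2,), (1,)), ((0,), (0,))))
print(c.shape)  # (2, 3, 5)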

Tom Hennigan

commit sha ec79adccbb22730124210ef801d7e6ecfbf88316

source sync PiperOrigin-RevId: 282633556

view details

George Necula

commit sha c1d8d3f74d422222fe173d8e0ef5b05f9e2fd300

Add error checking that arguments of jvp are tuples

view details

George Necula

commit sha e0706ff86476271bfeb1b3e0055818c343fdf862

Relaxed check to allow both tuples and lists

view details

George Necula

commit sha 96f075db13787bd81d7ec40e17d4f21ba0c95299

Merge pull request #1777 from gnecula/bug_fix Add error checking that arguments of jvp are tuples

view details

Matthew Johnson

commit sha 6931489733a568a98a452ba3581f70bc3d7d1dea

update version for pypi

view details

Matthew Johnson

commit sha 80036744b35de220b8989b0e0bfa54eb956ab62c

Add experimental rematerialization decorator We want to allow users to control how reverse-mode autodiff saves values from the forward pass. In particular, we want it to be easy to signal that a function shouldn't have any of its intermediate residuals stored for the backward pass, and instead those values should be recomputed from the function's saved inputs. (This feature is especially handy for accelerators on which memory access is much more expensive than FLOPs are.) In JAX terms, since we implement reverse-mode as a composition of forward-mode, partial evaluation, and transposition, we want users to control how partial evaluation behaves. See https://github.com/google/jax/pull/1749 for more. Co-authored-by: Dougal Maclaurin <dougalm@google.com>

view details

Matthew Johnson

commit sha 9987f73c41724cbc8fd707728e6f16192aa9bc36

try remat_call partial-eval into two remat_calls The idea here was for the resulting jaxpr to have a purely nonlinear remat_call and a linear one with no primals to evaluate. (I wanted to avoid having to recurse into all calls in _eval_primal in backward_pass.) But the issue is that this makes jaxprs not round-trippable, since the first remat_call, depending only on constants, would get partial-eval'd away at the first attempted round-trip. And we round-trip in partial_eval_jaxpr, particularly for partial eval of scan. That meant remat of scan didn't work, and that's no good!

view details

Matthew Johnson

commit sha a58cf307bf7f20138c82780a1457960ca1650c41

make remat_call partial-eval into one remat_call

view details

push time in 8 days

Pull request review commentgoogle/jax

Changed api.make_jaxpr to return a TypedJaxpr

  def jaxpr_maker(*args, **kwargs):
    wrapped = lu.wrap_init(fun)
    jax_args, in_tree = tree_flatten((args, kwargs))
    jaxtree_fun, out_tree = flatten_fun(wrapped, in_tree)
-    pvals = map(pv_like, jax_args)
-    jaxpr, _, _ = pe.trace_to_jaxpr(jaxtree_fun, pvals)
-    return jaxpr
+    in_pvals = map(pv_like, jax_args)
+    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jaxtree_fun, in_pvals, instantiate=True)

You're right, I think we need instantiate=True to form a TypedJaxpr. Ignore my suggestion!

gnecula

comment created time in 8 days

Pull request review commentgoogle/jax

Changed api.make_jaxpr to return a TypedJaxpr

  def pp_eqn(eqn):
          >> pp(' [ {} ; {} ]'.format(pp_vars(const_vars),
                                      pp_vars(bound_vars))))
  return (pp('{} = '.format(lhs)) >>
-          pp(eqn.primitive.name) >> pp_kv_pairs(eqn.params.items())
+          pp(eqn.primitive.name) >> pp_kv_pairs(eqn.params)

Maybe rename pp_kv_pairs to pp_dict or something, since the name no longer reflects the type?

Also, because tuple sorting is lexicographic, I think you could alternatively keep the previous function pp_kv_pairs unmodified and write pp_kv_pairs(sorted(eqn.params.items())) here.

gnecula

comment created time in 8 days

Pull request review commentgoogle/jax

Changed api.make_jaxpr to return a TypedJaxpr

  def jaxpr_maker(*args, **kwargs):
    wrapped = lu.wrap_init(fun)
    jax_args, in_tree = tree_flatten((args, kwargs))
    jaxtree_fun, out_tree = flatten_fun(wrapped, in_tree)
-    pvals = map(pv_like, jax_args)
-    jaxpr, _, _ = pe.trace_to_jaxpr(jaxtree_fun, pvals)
-    return jaxpr
+    in_pvals = map(pv_like, jax_args)
+    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jaxtree_fun, in_pvals, instantiate=True)

One theoretical downside to always having instantiate=True is that we can't then use make_jaxpr to illustrate what gets hoisted out of a function when e.g. we stage it out to XLA. One option might be to add an instantiate argument to make_jaxpr with a default value of True. Then it has the right default behavior but still lets us make the point about hoisting (which has been useful a time or two in answering JAXers' questions IIRC).

I don't have a strong feeling there though, so whichever you think is best!
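For context, a rough illustration of the hoisting in question: computations that depend only on constants are performed during tracing and show up as hoisted constants rather than as equations in the jaxpr (the instantiate flag additionally controls whether fully-known outputs appear in the jaxpr at all). This is a hypothetical example with today's make_jaxpr; the exact printed form varies across versions.

import jax
import jax.numpy as jnp

def f(x):
  c = jnp.sin(2.0)   # depends only on a constant, so it's computed at trace time
  return c * x       # only this multiply is staged into the jaxpr

print(jax.make_jaxpr(f)(3.0))
# roughly: { lambda a ; b. let c = mul a b in (c,) }, with a bound to the hoisted sin(2.0)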

gnecula

comment created time in 8 days

issue commentpyro-ppl/brmp

Memory leak using numpyro backend

FYI, after google/jax#1721 compilation cache entries are deleted once the underlying Python callable is deleted, but the cache size can still grow without bound. (Previously we had a limit on the number of cache entries, but we had to remove it in google/jax#1721 to support weakrefs.) We plan to write our own logic to limit the cache size (not just the number of entries, by querying XLA for the size of each compiled program), likely with an LRU eviction policy, but we're not acting on that urgently because, as far as we know, no one is getting compilation cache OOMs.

If you know otherwise, i.e. if you ever see the compilation cache blowing up, let us know on the JAX issue tracker!
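Not JAX's actual implementation, just a sketch of the kind of size-bounded LRU eviction described above; the real version would query XLA for each executable's size and key entries per callable.

from collections import OrderedDict

class SizeBoundedLRUCache:
  def __init__(self, max_bytes):
    self.max_bytes = max_bytes
    self.total = 0
    self.entries = OrderedDict()            # key -> (value, size_in_bytes)

  def get(self, key):
    value, size = self.entries.pop(key)     # raises KeyError on a miss
    self.entries[key] = (value, size)       # reinsert as most-recently-used
    return value

  def put(self, key, value, size):
    if key in self.entries:
      _, old_size = self.entries.pop(key)
      self.total -= old_size
    self.entries[key] = (value, size)
    self.total += size
    while self.total > self.max_bytes and len(self.entries) > 1:
      _, (_, evicted_size) = self.entries.popitem(last=False)  # drop least-recently-used
      self.total -= evicted_size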

null-a

comment created time in 8 days

issue commentpyro-ppl/brmp

Memory leak using numpyro backend

Just pushed jax 0.1.53 to pypi. (If we're ever slow at releases and you need one, just open an issue on our issue tracker. It only takes a few seconds, but we often forget to do it!)

null-a

comment created time in 8 days

push eventgoogle/jax

Matthew Johnson

commit sha 6931489733a568a98a452ba3581f70bc3d7d1dea

update version for pypi

view details

push time in 8 days

startedgehring/fax

started time in 8 days

pull request commentgoogle/jax

Wrap optix rules in a namedtuple with init/apply

Nice!

tomhennigan

comment created time in 9 days

push eventgoogle/jax

Tom Hennigan

commit sha ec79adccbb22730124210ef801d7e6ecfbf88316

source sync PiperOrigin-RevId: 282633556

view details

push time in 9 days

PR merged google/jax

Wrap optix rules in a namedtuple with init/apply cla: yes

Wrap optix rules in a namedtuple with init/update.

In general we've found it convenient to be able to keep track of pairs of functions as a single object. For example:

# Init.
loss_fn = hk.transform(loss_fn)
optimizer = optix.sgd(...)

# Initial state.
params = loss_fn.init(rng, ...)
opt_state = optimizer.init(params)

# Step.
grads = jax.grad(loss_fn.apply)(params, ...)
updates, opt_state = optimizer.update(grads, opt_state)
params = optix.apply_update(params, updates)

PiperOrigin-RevId: 282633556
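The wrapper itself isn't shown in this excerpt; here is a minimal sketch of the pattern (the InitUpdate name and the plain-SGD rule are illustrative, not necessarily optix's exact code):

from typing import Any, Callable, NamedTuple

import jax

class InitUpdate(NamedTuple):
  init: Callable[[Any], Any]          # params -> opt_state
  update: Callable[[Any, Any], Any]   # (grads, opt_state) -> (updates, opt_state)

def sgd(learning_rate):
  def init_fn(params):
    return ()                         # plain SGD keeps no optimizer state
  def update_fn(grads, state):
    updates = jax.tree_util.tree_map(lambda g: -learning_rate * g, grads)
    return updates, state
  return InitUpdate(init_fn, update_fn)

# optimizer = sgd(1e-3)
# opt_state = optimizer.init(params)
# updates, opt_state = optimizer.update(grads, opt_state)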

+22 -21

4 comments

2 changed files

tomhennigan

pr closed time in 9 days

push eventgoogle/jax

Matthew Johnson

commit sha 370a422914eda22587b08612202e2c9e16ccfb35

wip lazy sublanguage

view details

push time in 9 days

Pull request review commentgoogle/jax

Make jax.numpy.broadcast_to consistent with numpy.

   def testBroadcastToIntIssue1548(self):
     self.assertAllClose(lnp.broadcast_to(1, (3, 2)), onp.ones((3, 2)),
                         check_dtypes=False)

+  def testBroadcastToOnScalar(self):
+    self.assertTrue(isinstance(lnp.broadcast_to(10.0, ()), lnp.ndarray))

WDYT about also adding a line like

self.assertIsInstance(onp.broadcast_to(10.0, ()), onp.ndarray)

in the spirit of self-documenting code?

petebu

comment created time in 9 days

Pull request review commentgoogle/jax

Make jax.numpy.broadcast_to consistent with numpy.

   def testBroadcastToIntIssue1548(self):
     self.assertAllClose(lnp.broadcast_to(1, (3, 2)), onp.ones((3, 2)),
                         check_dtypes=False)

+  def testBroadcastToOnScalar(self):
+    self.assertTrue(isinstance(lnp.broadcast_to(10.0, ()), lnp.ndarray))

self.assertIsInstance?

petebu

comment created time in 9 days

push eventgoogle/jax

Matthew Johnson

commit sha b7579492690b1d94da89b7f1d1b6ddcfadbaacae

fix pulldown bugs

view details

push time in 9 days

push eventgoogle/jax

Matthew Johnson

commit sha ae572960a20ac10dd2a7d4e7867f23dd0299e65b

wip lazy sublanguage

view details

push time in 9 days

delete branch google/jax

delete branch : grad-of-jit-caching

delete time in 9 days

push eventgoogle/jax

Matthew Johnson

commit sha 2867e4be082237e2b184d7b533dfcbfa31b24f63

fix grad of jit caching bug Co-authored-by: Dougal Maclaurin <dougalm@google.com>

view details

push time in 9 days

PR merged google/jax

fix grad of jit caching bug cla: yes
+33 -6

0 comments

4 changed files

mattjj

pr closed time in 9 days

push eventgoogle/jax

Skye Wanderman-Milne

commit sha f415f266b890e7b1476879c108f6e846d7f93177

Remove 'backend' argument from device_put. (#1762) The appropriate Backend is instead inferred from the 'device' argument. This is a first step towards removing the 'backend' argument from more functions.

view details

Matthew Johnson

commit sha 22b7c9622176f849e5aaa79ce593904b52a20992

fix grad of jit caching bug Co-authored-by: Dougal Maclaurin <dougalm@google.com>

view details

Matthew Johnson

commit sha 57dd913834a54dee921047af7c78be0374e83c47

Add experimental rematerialization decorator We want to allow users to control how reverse-mode autodiff saves values from the forward pass. In particular, we want it to be easy to signal that a function shouldn't have any of its intermediate residuals stored for the backward pass, and instead those values should be recomputed from the function's saved inputs. (This feature is especially handy for accelerators on which memory access is much more expensive than FLOPs are.) In JAX terms, since we implement reverse-mode as a composition of forward-mode, partial evaluation, and transposition, we want users to control how partial evaluation behaves. See https://github.com/google/jax/pull/1749 for more. Co-authored-by: Dougal Maclaurin <dougalm@google.com>

view details

push time in 9 days

push eventgoogle/jax

Matthew Johnson

commit sha 22b7c9622176f849e5aaa79ce593904b52a20992

fix grad of jit caching bug Co-authored-by: Dougal Maclaurin <dougalm@google.com>

view details

push time in 9 days

PR opened google/jax

fix grad of jit caching bug
+33 -6

0 comments

4 changed files

pr created time in 9 days

create branch google/jax

branch : grad-of-jit-caching

created branch time in 9 days

push eventgoogle/jax

Matthew Johnson

commit sha f4c9a09866ec0eefcd101cf5371c5e561567773b

Add experimental rematerialization decorator We want to allow users to control how reverse-mode autodiff saves values from the forward pass. In particular, we want it to be easy to signal that a function shouldn't have any of its intermediate residuals stored for the backward pass, and instead those values should be recomputed from the function's saved inputs. (This feature is especially handy for accelerators on which memory access is much more expensive than FLOPs are.) In JAX terms, since we implement reverse-mode as a composition of forward-mode, partial evaluation, and transposition, we want users to control how partial evaluation behaves. See https://github.com/google/jax/pull/1749 for more. Co-authored-by: Dougal Maclaurin <dougalm@google.com>

view details

push time in 10 days

push eventgoogle/jax

Roy Frostig

commit sha 16484352685af5934facd32d07cd9664192e57ac

enable kernel regression example test

view details

Roy Frostig

commit sha 6b39104a6d37c7a9add44ad15e5967a025bc42ad

Merge branch 'master' into kernel-example-test

view details

Peter Hawkins

commit sha 45a1ba0bbcfc64e9c65f9d2589bb474a59312c99

Make more tests pass on TPU. (#1752)

view details

Peter Hawkins

commit sha 67038321f8b33bdab99c1fc3d003719553769c61

Revert support for building a non-GPU build with --config=cuda enabled. (#1757) It turns out there are implicit CUDA dependencies inside the TF libraries used by JAX, so the attempt to disable GPU dependencies conditionally didn't work.

view details

Peter Hawkins

commit sha 534d812b5777cc2da5bba11130626498f6aa2667

Add a handwritten ThreeFry2x32 CUDA kernel. (#1756) In principle, JAX should not need a hand-written CUDA kernel for the ThreeFry2x32 algorithm. In practice XLA aggressively inlines, which causes compilation times on GPU to blow up when compiling potentially many copies of the PRNG kernel in a program. As a workaround, we add a hand-written CUDA kernel mostly to reduce compilation time. When XLA becomes smarter about compiling this particular hash function, we should be able to remove the hand-written kernel once again.

view details

Peter Hawkins

commit sha d1aa01874d2e25fba07c23aaa78cc3d623d17c21

Fix BUILD file formatting.

view details

Peter Hawkins

commit sha 3b7d92db79442b64fc24f695b80e95cbcbe825c8

Add missing pybind11 dependency.

view details

George Necula

commit sha 4e89d43a75a9d7a9d803ba8777b867c487e55ed6

Added JAX pytrees notebook Also added docstrings to the tree_util module.

view details

George Necula

commit sha b12a8019c8711afaef9b4d1a9f437bf944575cee

Update docs/notebooks/JAX_pytrees.ipynb Co-Authored-By: Stephan Hoyer <shoyer@google.com>

view details

George Necula

commit sha 8777864c96d91d205cdf96236f4ad86818f957c6

Minor edits

view details

George Necula

commit sha 132102498bd81867a7a07e578444d4639f70ba1a

Minor edit

view details

George Necula

commit sha 159690ae026fc8364cefc22990914a15cb10e106

Merge pull request #1758 from gnecula/bug_fix Added notebook for PyTrees

view details

Chris Jones

commit sha ace14bbb942bb1dd532fb07d3524e141b85808fc

Remove `join_pvals` This function appears to be unused.

view details

Peter Hawkins

commit sha 34dfbc8ae6c697118b6a18faf02547aaab5946dc

Add error checking to PRNG CUDA kernel. (#1760) Refactor error checking code into a common helper library.

view details

Roy Frostig

commit sha 82ce209ae694bdbffa60155edbe34f4535e26d14

Merge branch 'master' into kernel-example-test

view details

Matthew Johnson

commit sha 36c882ba469fc009dba1a29df640c72eec12b36b

raise an error on jit-of-multi-host-pmap (#1761) Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>

view details

Roy Frostig

commit sha 1b2350bf36bc2eabfacea8f409eb04e74514f69f

Merge pull request #1730 from google/kernel-example-test enable kernel regression example test

view details

Matthew Johnson

commit sha 14b8bab98f7c820d6b796442bf5844ed43b6d56f

Add experimental rematerialization decorator We want to allow users to control how reverse-mode autodiff saves values from the forward pass. In particular, we want it to be easy to signal that a function shouldn't have any of its intermediate residuals stored for the backward pass, and instead those values should be recomputed from the function's saved inputs. (This feature is especially handy for accelerators on which memory access is much more expensive than FLOPs are.) In JAX terms, since we implement reverse-mode as a composition of forward-mode, partial evaluation, and transposition, we want users to control how partial evaluation behaves. The basic design here is to create a new call primitive, like `core.call_p` (which doesn't raise the abstraction level of its inputs), that has a special partial evaluation rule: stage out the full computation, not just the parts that couldn't be partially evaluated. We call the new primitive `remat_call_p`. The implementation of the partial evaluation rule has three main steps: 1. trace the full jaxpr of the called function, treating all inputs as "unknown" (but potentially using concrete avals for the arguments which are actually "known"), 2. process the full jaxpr to determine which outputs are meant to be "known"/"unknown", and to prune any extra computation, 3. compute any "known" values we need (and as an optimization, extract values from any concrete avals). Otherwise, it looks a lot like the standard call partial evaluation rule, which is JaxprTrace.process_call. To make `remat_call_p` transposable, we need to upgrade our `ad.backward_pass` code to do some nonlinear evaluation (to rematerialize the residuals needed). To translate `remat_call_p` to XLA in a way that prevents XLA's CSE optimization from foiling our rematerialization plans, we use a special widget that (according to friendly XLA team members) will in effect force a false data dependence to avoid CSE. This change also adjusts our abstract evaluation rules for concrete values not to promote to the shaped level. (It was a long time ago, but I think we had that policy only because we didn't have a use case for performing FLOPs while creating jaxprs.) The upgrade to ad.backward_pass was co-authored with @dougalm, and is essentially part of #1719. Fixes #1732. Co-authored-by: Dougal Maclaurin <dougalm@google.com>

view details
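A small usage sketch of the decorator this commit introduces (exposed as jax.remat; the nested sines are just a toy function): wrapping g tells reverse-mode autodiff to recompute g's intermediates on the backward pass rather than saving them as residuals.

import jax
import jax.numpy as jnp

@jax.remat
def g(x):
  # The inner sin's output is recomputed during the backward pass
  # instead of being stored from the forward pass.
  return jnp.sin(jnp.sin(x))

def loss(x):
  return jnp.sum(g(x) ** 2)

grads = jax.grad(loss)(jnp.arange(3.0))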

push time in 10 days

push eventgoogle/jax

Matthew Johnson

commit sha ef95cf497d5cb62a8cc7184e66d8f96d7589aa54

raise an error on jit-of-multi-host-pmap Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>

view details

push time in 10 days

PR opened google/jax

raise an error on jit-of-multi-host-pmap

We could handle this case correctly, but for now we raise an error rather than silently producing incorrect values.

+22 -1

0 comments

1 changed file

pr created time in 10 days

create branch google/jax

branch : jit-of-multi-host-pmap

created branch time in 10 days

push eventgoogle/jax

Matthew Johnson

commit sha 5d5c8523dec486aa2d7ba427a82634f1475f2f76

Add experimental rematerialization decorator We want to allow users to control how reverse-mode autodiff saves values from the forward pass. In particular, we want it to be easy to signal that a function shouldn't have any of its intermediate residuals stored for the backward pass, and instead those values should be recomputed from the function's saved inputs. (This feature is especially handy for accelerators on which memory access is much more expensive than FLOPs are.) In JAX terms, since we implement reverse-mode as a composition of forward-mode, partial evaluation, and transposition, we want users to control how partial evaluation behaves. The basic design here is to create a new call primitive, like `core.call_p` (which doesn't raise the abstraction level of its inputs), that has a special partial evaluation rule: stage out the full computation, not just the parts that couldn't be partially evaluated. We call the new primitive `remat_call_p`. The implementation of the partial evaluation rule has three main steps: 1. trace the full jaxpr of the called function, treating all inputs as "unknown" (but potentially using concrete avals for the arguments which are actually "known"), 2. process the full jaxpr to determine which outputs are meant to be "known"/"unknown", and to prune any extra computation, 3. compute any "known" values we need (and as an optimization, extract values from any concrete avals). Otherwise, it looks a lot like the standard call partial evaluation rule, which is JaxprTrace.process_call. To make `remat_call_p` transposable, we need to upgrade our `ad.backward_pass` code to do some nonlinear evaluation (to rematerialize the residuals needed). To translate `remat_call_p` to XLA in a way that prevents XLA's CSE optimization from foiling our rematerialization plans, we use a special widget that (according to friendly XLA team members) will in effect force a false data dependence to avoid CSE. This change also adjusts our abstract evaluation rules for concrete values not to promote to the shaped level. (It was a long time ago, but I think we had that policy only because we didn't have a use case for performing FLOPs while creating jaxprs.) The upgrade to ad.backward_pass was co-authored with @dougalm, and is essentially part of #1719. Fixes #1732. Co-authored-by: Dougal Maclaurin <dougalm@google.com>

view details

push time in 10 days

pull request commentgoogle/jax

Add experimental rematerialization decorator

@dougalm spotted an issue: when tracing on concrete values (to enable Python control flow) we might still incur extra work if we have jit mixed in in the right way. Concretely, here's the "three sines" problem he came up with:

from jax import jit, grad, remat
import jax.numpy as np

@jit
def g(x):
  return f(2., x)

@remat
def f(x, y):
  return np.sin(x) * y

grad(g)(3.)  # sin_p.impl evaluated three times!

Running this, sin(2.) is evaluated three times rather than the intended two. The issue is that the residual value sin(2.0) is computed but thrown away when tracing the function on concrete avals (i.e. in the first of the three steps in the OP, it's thrown away when mul has one concrete and one shape-abstracted-by-jit argument), then recomputed when we need to evaluate the known outputs (i.e. in the third of the three steps in the OP). I believe jit is essential here (and that's why JAX is harder than Autograd!).

The fix we want to try is to make tracing on concrete values (to support Python control flow) optional, at the risk of doing redundant FLOPs when jit is involved in slightly tricky ways.

mattjj

comment created time in 10 days

push eventgoogle/jax

Chris Jones

commit sha ace14bbb942bb1dd532fb07d3524e141b85808fc

Remove `join_pvals` This function appears to be unused.

view details

push time in 10 days

PR merged google/jax

Remove `join_pvals` cla: yes

This function appears to be unused.

+0 -22

0 comments

1 changed file

chr1sj0nes

pr closed time in 10 days

push eventgoogle/jax

Matthew Johnson

commit sha d1b78a7332fb21c2aefa017c20fa3e2f8c1d4780

wip lazy sublanguage

view details

push time in 11 days

push eventgoogle/jax

Matthew Johnson

commit sha a4faaaf124a45ca54f57a8e4e4d356a3fbfd22c3

wip lazy sublanguage

view details

push time in 11 days

push eventgoogle/jax

Matthew Johnson

commit sha 29b6546e635758d222ee10134dd9d7353625cad6

wip lazy sublanguage

view details

push time in 11 days
