google/jax 9359

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Research language for array processing in the Haskell/ML family

A pedagogical implementation of Autograd

duvenaud/relax 133

Optimizing control variates for black-box gradient estimation

jacobjinkelly/easy-neural-ode 54

Code for the paper "Learning Differential Equations that are Easy to Solve"

Recognizing and exploiting conjugacy without a domain-specific language

Prototypes of differentiable differential equation solvers in JAX.

my fish configuration

my .vim

push event google/jax

commit sha 241267c275b14305ad85f0688f679d660ed90182

omnistaging on by default

push time in 14 hours

push event google/jax

commit sha e7b74caa79b9592ab3b97d7c899f9c7e030ae56a

omnistaging on by default

push time in 14 hours

issue comment google/jax

Fastest implementation of an elementwise gradient?

One more thought: to get the second derivatives, you just need to write `grad(grad(...))` inside the double vmap. Getting rid of batch dimensions makes things so much easier!
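For concreteness, here's a minimal sketch of that pattern, using a hypothetical scalar kernel `k` (not the code from this thread):

```
import jax
import jax.numpy as jnp

def k(x, y):
    # Hypothetical scalar kernel, standing in for the one in this thread.
    return jnp.exp(-0.5 * (x - y) ** 2)

# Second derivative in y, mapped over both data axes with nested vmaps.
d2k = jax.vmap(jax.vmap(jax.grad(jax.grad(k, argnums=1), argnums=1),
                        in_axes=(0, None)),
               in_axes=(None, 0))

xs = jnp.linspace(-1.0, 1.0, 3)
ys = jnp.linspace(0.0, 1.0, 4)
out = d2k(xs, ys)
print(out.shape)  # (4, 3): one scalar second derivative per (y, x) pair
```

The inner function only ever sees scalars, so `grad(grad(...))` applies directly.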

comment created time in 15 hours

push event google/jax

commit sha 69ddb9f8ad17948384fd398e0b22b30a3b34031a

omnistaging on by default

push time in 15 hours

issue comment google/jax

Fastest implementation of an elementwise gradient?

There may be a nicer way to write the nested vmap using the `vectorize` wrapper.
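A sketch of what that might look like, with a hypothetical scalar kernel `k` standing in for the real one:

```
import jax
import jax.numpy as jnp

def k(x, y):
    # Hypothetical scalar kernel, standing in for the one in this thread.
    return jnp.exp(-0.5 * (x - y) ** 2)

# jnp.vectorize broadcasts a scalar function over its array arguments,
# so the nested-vmap pattern collapses to a single wrapped call.
dk_dy = jnp.vectorize(jax.grad(k, argnums=1))

xs = jnp.linspace(-1.0, 1.0, 5)
print(dk_dy(xs, 0.0).shape)  # (5,)
```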

comment created time in 15 hours

issue comment google/jax

Fastest implementation of an elementwise gradient?

Thanks for the question! I agree with what @shoyer said.

To say a bit more: if f: R^n -> R^m, then we expect its Jacobian to be an m x n matrix. The function being differentiated here, namely `RBF.forward` with respect to its first argument, takes an input of shape (50, 1) and produces an output of shape (50, 50), which is why we expect to see a Jacobian of shape (50, 50, 50, 1). (If we want to think in flattened terms, we'd say that the input has dimension n=50 and the output has dimension m=2500.)

The trouble here is that we really just want to treat the data axes of size 50 as batch axes. That is, we want to think of the kernel as taking two vectors and outputting a scalar; we only carry along the batch axis in the implementation for vectorization efficiency.
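A small sketch of this shape rule, using size 5 instead of 50 and a stand-in `pairwise` function:

```
import jax
import jax.numpy as jnp

def pairwise(x):                           # x: (5, 1), a stand-in kernel
    return jnp.exp(-0.5 * (x - x.T) ** 2)  # output: (5, 5)

x = jnp.arange(5.0).reshape(5, 1)
J = jax.jacfwd(pairwise)(x)
print(J.shape)  # (5, 5, 5, 1): output shape followed by input shape
```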

But with `vmap`, we don't need to batch by hand, and as a result it can be easier to express the functions and derivatives we want. Concretely, take a look at the `forward_` and `dkdx2_` methods below (note the trailing underscores), and the assertion:

```
class RBF(Kernel):
    # ...
    def forward_(self, x1: np.ndarray, x2: np.ndarray, thetas: np.ndarray):
        assert thetas.shape == (1,)
        length_scale = thetas[0]
        dist_sq = np.vdot(x1, x1) + np.vdot(x2, x2) - 2 * np.vdot(x1, x2)
        return np.exp(-0.5 * dist_sq / length_scale**2)

class RBFGrad(RBF):
    def __init__(self, length_scale=1.0):
        super(RBFGrad, self).__init__(length_scale)
        self.dkdx1 = jit(jacfwd(super(RBFGrad, self).forward, argnums=0))
        self.dkdx2 = jit(jacfwd(super(RBFGrad, self).forward, argnums=1))
        self.dk2dx1dx2 = jit(jacfwd(jacrev(super(RBFGrad, self).forward, argnums=0), argnums=1))
        self.dkdx2_ = jit(vmap(vmap(grad(super().forward_, argnums=1), (0, None, None)), (None, 0, None)))

    def forward(self, x1: np.ndarray, x2: np.ndarray, thetas: np.ndarray):
        K = super().forward(x1, x2, thetas)
        dx2 = self.dkdx2(x1, x2, thetas).sum(-2)
        dx2_ = self.dkdx2_(x1, x2, thetas)
        assert np.allclose(dx2, dx2_)
        # ...
```

I rewrote the `forward_` method so that it more obviously applies to single vectors at a time, and also to use the expansion ||u - v||^2 = ||u||^2 + ||v||^2 - 2 u \cdot v, which gives us more matrix multiplies and I think is often faster (though one should benchmark against computing ||u - v||^2 more directly, as before). I also added the `dkdx2_` method, which uses `vmap` to do all the batching, and compared it against the old calculation (it may be faster to use `jvp` than `grad`, but this was convenient). That second bit, with the `vmap`s, is the main thing I wanted to illustrate, as it was basically @shoyer's advice.
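As a sketch of the `jvp` variant (again with a hypothetical scalar kernel `k`, not the code above): for a scalar argument, a jvp with a unit tangent gives the same value as `grad` without building a backward pass.

```
import jax
import jax.numpy as jnp

def k(x, y):
    # Hypothetical scalar kernel, standing in for the one above.
    return jnp.exp(-0.5 * (x - y) ** 2)

def dk_dy(x, y):
    # Forward mode: push a unit tangent through y instead of pulling back.
    return jax.jvp(lambda y: k(x, y), (y,), (1.0,))[1]

g1 = jax.grad(k, argnums=1)(0.3, 0.7)
g2 = dk_dy(0.3, 0.7)
print(g1, g2)
```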

One other piece of advice would be to put `jit` on `RBFGrad.forward`, because the more code XLA can see, the more it can optimize.

WDYT?

comment created time in 15 hours

issue closed google/jax

Broadcasting minval and maxval in jax.random.uniform

Hi JAX team!

Currently, it seems that broadcasting of minval and maxval in `jax.random.uniform` is not supported.
https://github.com/google/jax/blob/a169743f64e9ddf826cdad5c225e3e18a980a1db/jax/random.py#L343

In NumPy, this is supported:

```
np.random.uniform(low=jnp.zeros(2), high=jnp.ones(2), size=(10,2))
```

In JAX, this results in `TypeError: max got arrays of different rank: (2,), (10, 2).`:

```
jax.random.uniform(
    key=jax.random.PRNGKey(42),
    shape=(10, 2),
    minval=jnp.zeros(2),
    maxval=jnp.ones(2))
```

Manual broadcasting works around the issue:

```
jax.random.uniform(
    key=jax.random.PRNGKey(42),
    shape=(1, 2),
    minval=jnp.broadcast_to(jnp.zeros(2), (1, 2)),
    maxval=jnp.broadcast_to(jnp.ones(2), (1, 2)))
```

Is JAX's behaviour intended? If not, I am happy to send a PR fixing this.
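Another workaround, assuming the defaults of `jax.random.uniform`: sample in [0, 1) at the full output shape and rescale by hand, which goes through ordinary NumPy-style broadcasting:

```
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)
u = jax.random.uniform(key, shape=(10, 2))   # uniform in [0, 1)
lo, hi = jnp.zeros(2), jnp.ones(2)
sample = lo + u * (hi - lo)                  # (2,) broadcasts against (10, 2)
print(sample.shape)  # (10, 2)
```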

closed time in 16 hours

ethanluoyc

push event google/jax

commit sha 85eea219e3e5041f3f9d5e7884db62a27492d58c

omnistaging on by default

push time in 16 hours

issue comment google/jax

host_callback doesn't work inside grad(odeint)

Let me know if I can help! I looked at some of this bookkeeping recently.

comment created time in 17 hours

issue comment google/jax

any suggestions on how to improve performance for gradient step with odeint?

By the way, XLA:CPU (and hence JAX on CPU) has some known performance issues with float64, e.g. I believe the 64bit GEMM kernel being called is a slow Eigen one, while the 32bit one is from MKL-DNN. I mention it just because I don't want it to be a confounder, though I haven't looked at the details here enough to know if it could be relevant.
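A quick way to check which precision is actually in play (a sketch; `jax_enable_x64` is the relevant flag):

```
import jax
import jax.numpy as jnp

# JAX defaults to float32 even where NumPy would give float64; float64 must
# be enabled explicitly, so it's worth checking which dtype a workload uses.
print(jnp.ones(3).dtype)   # float32 by default

jax.config.update("jax_enable_x64", True)
print(jnp.ones(3).dtype)   # float64 once x64 is enabled
```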

I'd love if an outcome of this investigation is that we should add some alternative VJPs for odeint, because that sounds really fun!

comment created time in 17 hours

pull request comment google/jax

Added initial implementation of numpy equivalent for trim_zeros to jax

Hey @gsp-27 , thanks for the contribution!

It turns out that `trim_zeros` isn't a good candidate for JIT compilation with XLA, because the output shape depends on the values of the input. Because XLA programs have static shapes in their type system, that means we'd need to create and compile a separate program for most input values. That's why you're seeing issues with `_CompileAndCheck`.

We could still include an implementation of `trim_zeros` in `jax.numpy` and just allow it to error when people try to include it in `jit`ted functions, but that limits the upside, so it might also make sense just to leave it out. If we did want to include it, we'd likely need another implementation technique rather than the Python loop here, perhaps based on `lax.while_loop` (though I'm not sure about that).

Overall I'd recommend not adding `trim_zeros`, for the above reason.

What do you think?

comment created time in a day

issue closed google/jax

I tried to run the example code in the README to use JAX on TPU from a GCE VM, but I got a RuntimeError.

```
2020-08-11 08:26:26.301652: E external/org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.cc:1068] Failed to open the gRPC driver: 12: :
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/sj2660089/.local/lib/python3.7/site-packages/jax/random.py", line 85, in PRNGKey
    k1 = convert(np.bitwise_and(np.right_shift(seed, 32), 0xFFFFFFFF))
  File "/home/sj2660089/.local/lib/python3.7/site-packages/jax/random.py", line 81, in <lambda>
    convert = lambda k: lax.reshape(lax.convert_element_type(k, np.uint32), [1])
  File "/home/sj2660089/.local/lib/python3.7/site-packages/jax/lax/lax.py", line 397, in convert_element_type
    operand, new_dtype=new_dtype, old_dtype=old_dtype)
  File "/home/sj2660089/.local/lib/python3.7/site-packages/jax/core.py", line 276, in bind
    return self.impl(*args, **kwargs)
  File "/home/sj2660089/.local/lib/python3.7/site-packages/jax/interpreters/xla.py", line 224, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
  File "/home/sj2660089/.local/lib/python3.7/site-packages/jax/interpreters/xla.py", line 236, in xla_primitive_callable
    backend = xb.get_device_backend(device)
  File "/home/sj2660089/.local/lib/python3.7/site-packages/jax/lib/xla_bridge.py", line 176, in get_device_backend
    return get_backend(platform)
  File "/home/sj2660089/.local/lib/python3.7/site-packages/jax/lib/xla_bridge.py", line 170, in get_backend
    return backend(platform)
  File "/home/sj2660089/.local/lib/python3.7/site-packages/jax/lib/xla_bridge.py", line 148, in _get_tpu_driver_backend
    _tpu_backend = tpu_client.TpuBackend.create(worker=backend_target)
  File "/home/sj2660089/.local/lib/python3.7/site-packages/jaxlib/tpu_client.py", line 59, in create
    return _tpu_client.TpuClient.Get(worker)
RuntimeError: Unimplemented: Failed to connect to remote server at address: grpc://10.240.1.2:8470. Error from gRPC: . Details:
```

Any idea why this happens?

closed time in a day

JaySunnn

issue comment google/jax

@jekbradbury @skye you are the greatest!

comment created time in a day

push event google/jax

commit sha c564aca77710df0599715d4231b7d5b7dd46984a

skip more ode tests on gpu, b/c slow to compile (#4028)

push time in a day

issue comment google/jax

@skye can you take a look at this one?

comment created time in 2 days

started locuslab/monotone_op_net

started time in 2 days

push event google/jax

commit sha fe9f264b55f8b99c57f803db9eb7a2c8df897e9b

cumulative jet rules (#4000)

push time in 2 days

PR merged google/jax

Below is my attempt at adding jet rules for cumulative operations, following jvp rules defined in lax.py

pr closed time in 2 days

push event google/jax

commit sha 1e07712955939d6f8f461fc259b12a20808782b3

Fix typos in api.py docstrings (#4021)

push time in 2 days

PR merged google/jax

pr closed time in 2 days

pull request comment google/jax

Fix some typos in api.py docstrings

Thanks for catching these! Possibly related, my 'r' key is broken and often types 0 or 2 copies of the letter! :P

comment created time in 2 days

Pull request review comment google/jax

Initial version of vmap collectives

```
def process_primitive(self, primitive, tracers, params):
  vals_in, dims_in = unzip2((t.val, t.batch_dim) for t in tracers)
  if all(bdim is not_mapped for bdim in dims_in):
    return primitive.bind(*vals_in, **params)
+ elif config.omnistaging_enabled and primitive in collective_rules:
+   axes_names = params['axis_name']
+   if not isinstance(axes_names, (tuple, list)):
+     axes_names = (axes_names,)
+   for i, axis_name in enumerate(axes_names):
+     frame = core.axis_frame(axis_name)
+     if frame.tag is self.master:
```

Yeah, I agree.

comment created time in 2 days

push event google/jax

commit sha d100327bf33515cedca41eedfc30d6fb49de7ef2

Remove type restrictions (#4017) We support half16 on TPU

push time in 2 days

PR merged google/jax

We support half16 on TPU

pr closed time in 2 days

push event google/jax

commit sha 09d8ac14de8c9edc2c9a14becb4963bfdebda605

use fewer internal APIs in custom interp notebook (#4016)

push time in 2 days

issue closed google/jax

clarify "custom jaxpr interpreters" notebook uses internal APIs that may break

closed time in 2 days

mattjj

push event google/jax

commit sha 6a3b920507dcae7a4e4dfa513155222fa0c6feb1

make make_jaxpr work on tracer example args (#4014) (don't use xla.abstractify)

commit sha 265c3faa405260fb563dd492dabad0752b42e842

Remove type restrictions (#4011) We support s8, u8, s16, u16, half16 on TPU

commit sha f50216316f692d7c261aeedfc74a24e1424f9a42

use fewer internal APIs in custom interp notebook

push time in 2 days

push event google/jax

commit sha 265c3faa405260fb563dd492dabad0752b42e842

Remove type restrictions (#4011) We support s8, u8, s16, u16, half16 on TPU

push time in 2 days

PR merged google/jax

We support s8, u8, s16, u16, half16 on TPU

pr closed time in 2 days

issue comment google/jax

host_callback doesn't work inside grad(odeint)

@gnecula want to take this one?

comment created time in 2 days

push event google/jax

commit sha 6a3b920507dcae7a4e4dfa513155222fa0c6feb1

make make_jaxpr work on tracer example args (#4014) (don't use xla.abstractify)

push time in 2 days

Pull request review comment google/jax

DeviceArray.__iter__ returns DeviceArrays, without host sync

```
def __iter__(self):
  if self.ndim == 0:
    raise TypeError("iteration over a 0-d array")  # same as numpy error
  else:
-   return self._value.__iter__()
+   device = self.device_buffer.device()
+   if device is None or device.platform == 'cpu':
+     return iter(self._value)
```

Yeah, good point. We should revise this case so that it returns CPU DeviceArrays, maybe.

comment created time in 3 days

push event apaszke/jax

commit sha 2055d6ce421afd3dd8e278c5c26fdb6a5b3addeb

deflake

push time in 3 days

Pull request review comment google/jax

Initial version of vmap collectives

```
def process_primitive(self, primitive, tracers, params):
  vals_in, dims_in = unzip2((t.val, t.batch_dim) for t in tracers)
  if all(bdim is not_mapped for bdim in dims_in):
    return primitive.bind(*vals_in, **params)
+ elif config.omnistaging_enabled and primitive in collective_rules:
+   axes_names = params['axis_name']
+   if not isinstance(axes_names, (tuple, list)):
+     axes_names = (axes_names,)
+   for i, axis_name in enumerate(axes_names):
+     frame = core.axis_frame(axis_name)
+     if frame.tag is self.master:
```

I know you explained this to me before, but I forget: is the tag just to handle possible shadowing? If so, after this came up in a convo with James and Roy last week, I've started to think we should make shadowing illegal, which might also help us slightly simplify this logic.

For now, just flagging this for more discussion!

comment created time in 3 days

Pull request review comment google/jax

Initial version of vmap collectives

```
def _psum_transpose_rule(cts, axis_name, axis_index_groups):
pxla.parallel_pure_rules[psum_p] = lambda *args, shape: (x * prod(shape) for x in args)
ad.deflinear(psum_p, _psum_transpose_rule)
pxla.multi_host_supported_collectives.add(psum_p)
+ batching.collective_rules[psum_p] = \
+   lambda vals, dims, axis_size, **_: [v.sum(d) if d is not batching.not_mapped else
```

Currently we might need to lower this to a reduce-sum followed by a broadcast, as per this comment and this one. But also we might want to revise that, so that psum just means reduce-sum... I'm not sure what the best thing to do is, just wanted to flag this as a discussion topic!

comment created time in 3 days

pull request comment google/jax

Wow, awesome! Thanks for pushing on this.

(By the way, our GitHub CI tests have been really flakey for the last ~week, failing on dependency installation. So if you see dependency installation failures, assume they're not your fault!)

comment created time in 3 days

push event google/jax

commit sha d46ea969533dbfe4451517ac9a5e51dbdeee6d5d

add helper for flax to be omnistaging-compatible (#4004)

push time in 3 days

issue comment google/jax

why is this still not officially supported?

Well, we have a small team, and so even though there's a lot of work worth doing, we have to choose what to prioritize.

I believe JAX works great with WSL, so if that works on your Windows setup you might want to give it a try.

@ericmjl I love your attitude! We've benefitted a huge amount from open-source contributions, and I hope we can get a lot more over time! We're all on the same team here, doing our best to push things forward.

comment created time in 4 days

issue closed google/jax

random uniform with dtype bfloat16 crashes on TPU backend

repro:

```
from jax import numpy as jnp
from jax import random
x = random.uniform(random.PRNGKey(0), (3,), dtype=jnp.bfloat16)
```

Observations:

- crashes on TPU backends (both internal and cloud TPU, as far as I can tell)
- CPU and GPU backends don't seem to crash

Expected: even if bfloat16 isn't supported by `random.*` on TPU, it would be better to error out than to crash.

closed time in 4 days

levskaya

issue comment google/jax

random uniform with dtype bfloat16 crashes on TPU backend

Thanks, @majnemer! I just double-checked on an internal TPU colab and indeed it works. Thanks to your team for the fix.

comment created time in 4 days

issue closed google/jax

code:

```
def gumbel_sample(log_probs, temperature=1.0):
    """Gumbel sampling from a categorical distribution."""
    u = numpy.random.uniform(low=1e-6, high=1.0 - 1e-6, size=log_probs.shape)
    g = -np.log(-np.log(u))
    return np.argmax(log_probs + g * temperature, axis=-1)

def predict(num_chars, prefix):
    inp = [ord(c) for c in prefix]
    result = [c for c in prefix]
    max_len = len(prefix) + num_chars
    for _ in range(num_chars):
        cur_inp = np.array(inp + [0] * (max_len - len(inp)))
        outp = model(cur_inp[None, :])  # Add batch dim.
        next_char = gumbel_sample(outp[0, len(inp)])
        inp += [int(next_char)]
        if inp[-1] == 1:
            break  # EOS
        result.append(chr(int(next_char)))
    return "".join(result)

print(predict(32, ""))
```

Error:

```
LayerError                                Traceback (most recent call last)
<ipython-input-27-6f9f9afc30e6> in <module>
     22     return "".join(result)
     23
---> 24 print(predict(32, ""))

<ipython-input-27-6f9f9afc30e6> in predict(num_chars, prefix)
     12     for _ in range(num_chars):
     13         cur_inp = np.array(inp + [0] * (max_len - len(inp)))
---> 14         outp = model(cur_inp[None, :])  # Add batch dim.
     15         next_char = gumbel_sample(outp[0, len(inp)])
     16         inp += [int(next_char)]

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in __call__(self, x, weights, state, rng)
    165         self.state = state  # Needed if the model wasn't fully initialized.
    166         state = self.state
--> 167         outputs, new_state = self.pure_fn(x, weights, state, rng)
    168         self.state = new_state
    169         self.weights = weights

/opt/conda/lib/python3.7/site-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng, use_cache)
    448         name, trace = self._name, _short_traceback(skip=3)
    449         raise LayerError(name, 'pure_fn',
--> 450                          self._caller, signature(x), trace) from None
    451
    452     def output_signature(self, input_signature):

LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/<ipython-input-13-6f7ffe0c061e>, line 21
  layer input shapes: ShapeDtype{shape:(1, 32), dtype:int32}

  File [...]/trax/layers/combinators.py, line 88, in forward
    outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)

LayerError: Exception passing through layer Embedding_256_512 (in pure_fn):
  layer created in file [...]/<ipython-input-13-6f7ffe0c061e>, line 18
  layer input shapes: ShapeDtype{shape:(1, 32), dtype:int32}

  File [...]/trax/layers/core.py, line 150, in forward
    return jnp.take(self.weights, x, axis=0)
  File [...]/jax/numpy/lax_numpy.py, line 3298, in take
    slice_sizes=tuple(slice_sizes))
  File [...]/jax/lax/lax.py, line 835, in gather
    slice_sizes=canonicalize_shape(slice_sizes))
  File [...]/site-packages/jax/core.py, line 273, in bind
    return self.impl(*args, **kwargs)
  File [...]/jax/interpreters/xla.py, line 228, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
  File [...]/jax/interpreters/xla.py, line 262, in xla_primitive_callable
    *avals, **params)
  File [...]/jax/interpreters/xla.py, line 320, in primitive_computation
    raise RuntimeError(msg) from e
RuntimeError: Invalid argument: Slice size at index 0 in gather op is out of range, must be within [0, 1), got 1.:
This is a bug in JAX's shape-checking rules; please report it!
```

closed time in 4 days

Keshav15

issue comment google/jax

No worries, we always welcome issues, and also welcome them fixing themselves :)

comment created time in 4 days

issue closed google/jax

jax.random.multivariate_normal produces incorrect output with batches of covariance matrices

The tensordot on this line does not perform the intended broadcast matrix-vector multiply. I think you need this (or equivalent),

```
return mean + np.einsum('...ij,...j->...i', chol_factor, normal_samples)
```

Here's an example:

```
import jax.numpy as np
import jax.random
mean = np.zeros((10, 4))
cov = np.eye(4)[None,...].repeat(10, axis=0)
rng = jax.random.PRNGKey(0)
sample = jax.random.multivariate_normal(rng, mean, cov)
print(sample.shape) # is (10, 10, 4); should be (10, 4).
```

closed time in 4 days

slinderman

issue comment google/jax

jax.random.multivariate_normal produces incorrect output with batches of covariance matrices

Merged the fix from Scott in #4002!

comment created time in 4 days

push event google/jax

commit sha ea88c55f555780cdccc27c13d19fd393e695e120

Fixes and tests for jax.random.multivariate_normal (#4002)

* Fix bug #3997: change `jax.random.multivariate_normal` to handle batches of covariance matrices. It works as long as mean and covariance are broadcast-compatible, as specified in the docstring.
* Fix bug in multivariate_normal shape checking. Minor bug: should be checking for compatibility of `shape`, `mean`, and the last two dimensions of the _covariance_ matrix.
* Add test for multivariate_normal shapes. This test checks that `jax.random.multivariate_normal` produces the expected output shape for various combinations of event dimension and `mean`, `covariance`, and `shape` shapes.
* Fix linter issues in tests/random_test.py. Trimming trailing whitespace and 80 char limit.
* Really trimming whitespace in tests/random_test.py. Arg. Have to fix my editor to do this automatically.

push time in 4 days

PR merged google/jax

This PR addresses bug https://github.com/google/jax/issues/3997 and another shape issue exposed during testing.

pr closed time in 4 days

issue opened google/jax

clarify "custom jaxpr interpreters" notebook uses internal APIs that may break

created time in 4 days

issue comment google/jax

jax.random.multivariate_normal produces incorrect output with batches of covariance matrices

I left this bug in there just for you to find, Scott.

comment created time in 5 days

issue comment google/jax

jax.random.multivariate_normal produces incorrect output with batches of covariance matrices

Yeah! Let's get @slinderman's first of many JAX PRs here!

comment created time in 5 days

issue comment google/jax

any suggestions on how to improve performance for gradient step with odeint?

As a wild guess, not packing values together with `jnp.stack` (as in the return value of `SEIRD_mobility_coupled`) can help both compilation time and performance.

comment created time in 5 days

issue comment google/jax

any suggestions on how to improve performance for gradient step with odeint?

Is it slow to execute (after compiling), or just to compile? That is, if you try evaluating the gradient a second time, is it faster on the second evaluation?
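One way to check, sketched with a stand-in jitted function and `block_until_ready` to account for async dispatch:

```
import time
import jax
import jax.numpy as jnp

# Stand-in for the real computation; the first call to a jitted function
# traces and compiles, while later calls reuse the compiled executable.
f = jax.jit(lambda x: jnp.sin(x).sum())
x = jnp.ones(1000)

t0 = time.perf_counter()
f(x).block_until_ready()          # includes trace + compile time
t1 = time.perf_counter()
f(x).block_until_ready()          # pure execution time
t2 = time.perf_counter()
print(f"first call: {t1 - t0:.4f}s, second call: {t2 - t1:.4f}s")
```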

comment created time in 5 days

issue comment google/jax

Defining both custom_jvp and custom_vjp

Thanks for the question!

No, that's not currently possible. It's documented in the custom derivatives tutorial (see the subsection "Forward-mode autodiff cannot be used on the `jax.custom_vjp` function and will raise an error" for more).
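A minimal sketch of the error (the function `f` here is hypothetical):

```
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
    return jnp.sin(x)

def f_fwd(x):
    return jnp.sin(x), jnp.cos(x)   # primal output plus residual

def f_bwd(cos_x, g):
    return (cos_x * g,)             # cotangent for x

f.defvjp(f_fwd, f_bwd)

print(jax.grad(f)(0.0))             # reverse mode works: 1.0
try:
    jax.jvp(f, (0.0,), (1.0,))      # forward mode raises
except Exception as e:              # a TypeError in current JAX
    print("jvp raised:", e)
```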

Can you say more about why you might need to define both?

comment created time in 5 days

issue comment google/jax

inside pmap, grad(lambda x: psum(loss(x))) inconsistent with jvp(lambda x: psum(loss(x)))

Okay, back to it!

Notice that the Jacobian of the `agg_loss` function written with the broadcast is `[[0, 1], [0, 1]]`:

```
import jax.numpy as jnp
from jax import lax
from jax import jacfwd

def f(w, data):
    def agg_loss(w):
        return lax.broadcast((w * data).sum(0), (2,))
    return jacfwd(agg_loss)(w)

print(f(jnp.ones(2), jnp.arange(2)))
# [[0. 1.]
#  [0. 1.]]
```

So, while keeping the current (at HEAD) bulk array definition of `psum` as a reduce-sum *followed by a broadcast*, the SPMD AD semantics are consistent so long as we take `grad` to mean "compute the VJP against a ones vector *broadcast along all named axes*":

```
import jax.numpy as jnp
from jax import vjp, jvp, pmap, grad
from jax import lax

### reverse-mode

# At HEAD, we define this SPMD program:
def f(w, data):
    def agg_loss(w):
        return lax.psum(w * data, 'batch')
    return grad(agg_loss)(w)

print(pmap(f, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
# [0. 2.]

# To mean the same as this bulk array vjp-with-ones program (`grad` is always
# defined as vjp-with-ones plus an error check for scalar outputs that we don't
# include in the definition of SPMD semantics):
def grad2(f):
    def gradfun(x):
        ans, f_vjp = vjp(f, x)
        x_bar, = f_vjp(jnp.ones_like(ans))
        return x_bar
    return gradfun

def f(w, data):
    def agg_loss(w):
        return lax.broadcast((w * data).sum(), (2,))  # bulk array version of psum
    return grad2(agg_loss)(w)

print(f(jnp.ones(2), jnp.arange(2)))
# [0. 2.]

### forward-mode

# At HEAD, we define this SPMD program:
def f(w, data):
    def agg_loss(w):
        return lax.psum(w * data, 'batch')
    return jvp(agg_loss, (w,), (1.,))[1]

print(pmap(f, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
# [1. 1.]

# To mean the same as this bulk array jvp-with-ones program:
def f(w, data):
    def agg_loss(w):
        return lax.broadcast((w * data).sum(), (2,))  # bulk array version of psum
    return jvp(agg_loss, (w,), (jnp.ones_like(w),))[1]

print(f(jnp.ones(2), jnp.arange(2)))
# [1. 1.]
```

(In Autograd, like TF today, we used to define `grad` as vjp-with-ones, but exactly this sort of confusion is why we made it raise an error for non-scalar outputs. Yet we didn't make that check work with SPMD functions, in the sense that `grad` will in effect happily allow broadcasting along named mapped axes!)

If the semantics at HEAD are self-consistent, except for the error semantics of `grad`, do we need to change anything, other than perhaps to avoid this potential confusion by making the `grad` error semantics consistent between the positional and SPMD worlds?

Maybe yes. One problem with the current semantics is that if we make @sharadmv's use of `grad` here an error (rather than a vjp-with-ones), not only would that have been surprising to him, but it would also break pretty much all existing SPMD neural net training; people would have to write vjp-with-ones themselves, e.g. by defining `grad2` as above. Even then, the answers can be surprising: within the context of the SPMD function, it looks like we're calling `grad`/`grad2` on a scalar-input scalar-output function (but for the closed-over value of `data`, which is different on each device) with the same primal input value on every device, yet getting different `grad2` results on different devices (perhaps not noticing that if we looked at the primal *output* value we'd also have a different value on each device, which might make getting different gradients less surprising).

In any case, while I now think the semantics at HEAD are actually consistent (modulo the error semantics for `grad`), this example has clearly shown that they can be confusing, especially when closing over mapped data.

comment created time in 6 days

issue comment google/jax

inside pmap, grad(lambda x: psum(loss(x))) inconsistent with jvp(lambda x: psum(loss(x)))

I changed the issue title because I think the current (i.e. at HEAD) semantics make sense, but they can make `grad(f)(x)` and `jvp(f, (x,), (1.,))` differ for a scalar-input scalar-output `f` when inside a `pmap`, which seems worth revising. The current semantics just take a different correspondence to bulk array programs than the one in @jekbradbury's previous comment.

The original intention with `psum` was to correspond to a bulk array operation that included broadcasting, not just a reduce-sum:

```
pmap(lambda x: psum(x, 'i'), axis_name='i')
==
lambda x: lax.broadcast(x.sum(0), (num_devices,))
```
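On recent JAX, this correspondence can be checked on a single device using `vmap`, which also supports named axes and collectives (a sketch, not the pmap program itself):

```
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.arange(4.0)

# psum over a named mapped axis ...
mapped = jax.vmap(lambda v: lax.psum(v, 'i'), axis_name='i')(x)

# ... versus the bulk array reduce-sum-then-broadcast it is meant to match.
bulk = lax.broadcast(x.sum(0), (4,))

print(mapped)   # every element is the full sum
print(bulk)
```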

So given this `pmap` code:

```
def f(w, data):
    def agg_loss(w):
        return lax.psum(w * data, 'batch')
    return grad(agg_loss)(w)

print(pmap(f, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
```

we should expect it to behave like this bulk array version:

```
import jax.numpy as jnp
from jax import grad
from jax import lax

def f(w, data):
    def agg_loss(w):
        return lax.broadcast((w * data).sum(), (2,))
    return grad(agg_loss)(w)

print(f(jnp.ones(2), jnp.arange(2)))
```

But that code triggers an error, because in api.py we check that `grad` is only applied to scalar-output functions. Just for defining semantics, let's sidestep that error check:

```
import jax.numpy as jnp
from jax import vjp
from jax import lax

def grad(f):
    def gradfun(x):
        ans, f_vjp = vjp(f, x)
        x_bar, = f_vjp(jnp.ones_like(ans))
        return x_bar
    return gradfun

def f(w, data):
    def agg_loss(w):
        return lax.broadcast((w * data).sum(), (2,))
    return grad(agg_loss)(w)

print(f(jnp.ones(2), jnp.arange(2)))
# prints [0. 2.]
```

In effect, we had to write our own `grad` that was willing to produce a broadcast ones cotangent vector.

These SPMD semantics explain the reverse-mode answer in my first comment above.

The trouble with this definition, though, is that it can make `grad` disagree with `jvp` and with numerical differences on what appears locally to be a scalar-input scalar-output function when inside a pmap, as in the examples in my first comment. I want to write out an explanation of what's going on, but I've got to step away for a moment and wanted to send this comment first. To be continued!

In any case, I think @jekbradbury's proposal for revising the semantics is likely to be better. I just want to pin down both the old and new semantics as best we can.

comment created time in 6 days

push event google/jax

commit sha 161e81cebed0432c7d211ca33a837aadc2d7d0c0

support collectives with new pmap semantics

Co-authored-by: James Bradbury <jekbradbury@google.com>
Co-authored-by: Roy Frostig <frostig@google.com>

push time in 6 days

issue comment google/jax

`grad(lambda x: psum(loss(x)))` computes incorrect gradients

@jekbradbury that's brilliant! Let's do it.

comment created time in 7 days

pull request comment google/jax

Rm two unused lines of code from lax_parallel.psum_bind

This was a bad merge on my part! I believe this passes internal tests.

comment created time in 7 days

issue comment google/jax

Prevent statements with side effects from being re-ordered

Yeah that's my thinking too!

comment created time in 7 days

issue comment google/jax

Prevent statements with side effects from being re-ordered

The omnistaging change ensures that we trace the Python to a jaxpr in a way that preserves operation ordering, but I suspect we still need to thread tokens in the XLA HLO lowering. That is, JAX now has the tools to avoid reordering without user-level tokens, but XLA may still reorder things.

@hawkinsp does that sound right to you?

comment created time in 7 days

issue comment google/jax

Very slow compile when doing stochastic variational inference on the parameters of an ODE integrator

Thanks for raising this; that's too slow! Which backend is this on (CPU/GPU/TPU)? It might be faster or slower to compile on different backends (often slowest on CPU, fastest on TPU).

comment created time in 7 days

issue comment google/jax

Cannot use static scalar arrays in comparison when using omnistaging

I wonder if this is an interaction between the laziness of `np.ones` and omnistaging...

comment created time in 7 days

issue comment google/jax

Implement complex QR decomposition in HLO (TPU)

GE and LT are "greater than" and "less than", respectively. That error is saying that those operations aren't implemented for complex numbers. (NumPy implements them with a kind of weird convention, since complex numbers don't have a total ordering.)

comment created time in 7 days

issue comment google/jax

`grad(lambda x: psum(loss(x)))` computes incorrect gradients

I'm starting to agree with James that we might need to track spmd-symmetry in jaxprs, and reverse mode should do something interesting with that information. It's tricky: if we have a jaxpr to transpose with both symmetric and non-symmetric outputs, we may need to run two backward passes, one for the symmetric outputs (where we symmetrize the resulting input cotangents) and one for the non-symmetric outputs (where we don't).

comment created time in 7 days

issue commentgoogle/jax

`grad(lambda x: psum(loss(x)))` computes incorrect gradients

Thanks! One possible desideratum is not to do something different on guaranteed-symmetric values versus just-happens-to-be-symmetric values. Notice that in this example if `data` just happens to be symmetric, we get the right answer. Also, if we noticed (say in `backward_pass`) that the function to be transposed is symmetric (in the sense that when it's given a symmetric input it produces a symmetric output), and in that case always symmetrized the cotangents of the jaxpr inputs, then we'd compute the right answer in both guaranteed-symmetric (new!) and just-happens-to-be-symmetric (already got that one) cases.

comment created time in 7 days

issue commentgoogle/jax

`grad(lambda x: psum(loss(x)))` computes incorrect gradients

Can you define 'replicated thing' and 'non-replicated thing'? Do you mean a function (`agg_loss` in this case), and if so what's it mean to be replicated (maybe: has any collectives in it)?

comment created time in 7 days

issue commentgoogle/jax

`grad(lambda x: psum(loss(x)))` computes incorrect gradients

This is confusing! It seems to be an issue about how we define transposes of SPMD functions (i.e. it's not just a dumb software bug). No conclusions yet, but wanted to dump some thoughts here.

Here are some test programs:

```
import jax.numpy as jnp
from jax import jvp, pmap, grad, lax

def distributed(w, data):
    def agg_loss(w):
        return lax.psum(w * data, 'batch')
    return jvp(agg_loss, (w,), (1.,))[1]
print(pmap(distributed, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
# prints [1. 1.]

def distributed(w, data):
    def agg_loss(w):
        return lax.psum(w * data, 'batch')
    return (agg_loss(w + 1e-3) - agg_loss(w)) / 1e-3
print(pmap(distributed, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
# prints [1.0000466 1.0000466]

def distributed(w, data):
    def agg_loss(w):
        return lax.psum(w * data, 'batch')
    return grad(agg_loss)(w)
print(pmap(distributed, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
# prints [0. 2.]
```

The jaxpr being transposed in the last one is

```
{ lambda c ; b.
  let d = mul b c
      e = psum[ axis_index_groups=None
                axis_name=batch ] d
  in (e,) }
```

Operationally, the transpose evaluation proceeds like this: `grad` feeds in a cotangent for `e` of `1.0`. Then we compute a cotangent for `d` of `2.0`, reflecting the fact that if we perturb the value of `d` by epsilon (on each replica) then the value of `e` changes by 2 epsilon. Then we compute a cotangent for `b` by multiplying by `0.` on the first replica and `1.` on the second replica, leading to the final result `[0. 2.]`.

Beyond the numbers, we have a symmetry issue. The first two versions produce a symmetric-across-replicas result, which makes sense because `agg_loss`'s last operation is a `psum` and so its output must be symmetric across replicas. But the reverse-mode version can't produce a symmetric result because it multiplies by the mapped value `data` at the end. (The symmetric result makes the most sense because `agg_loss` always produces a symmetric result given a symmetric input.)
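To see these numbers without needing multiple devices, here's a hypothetical single-device model of the SPMD program above, where the two "replicas" live along an array axis and `psum` becomes a sum broadcast back to every replica (an illustrative sketch, not the actual pmap machinery):

```python
import jax.numpy as jnp
from jax import jvp, vjp

data = jnp.arange(2.)  # per-replica data, one entry per "replica"

def agg_loss(w):
    # models lax.psum(w * data, 'batch'): sum, then replicate the result
    return jnp.sum(w * data) * jnp.ones(2)

w = jnp.ones(2)

# Forward mode with a tangent of 1 on every replica: output moves by
# sum(data) = 1 on each replica.
print(jvp(agg_loss, (w,), (jnp.ones(2),))[1])  # [1. 1.]

# Reverse mode with a cotangent of 1 on every replica: the psum transpose
# doubles the cotangent, which then multiplies the per-replica data.
_, f_vjp = vjp(agg_loss, w)
g, = f_vjp(jnp.ones(2))
print(g)  # [0. 2.]

# Symmetrizing the input cotangents (a pmean, in SPMD terms) recovers the
# forward-mode answer.
print(jnp.mean(g) * jnp.ones(2))  # [1. 1.]
```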

Hmmmm...

comment created time in 7 days

Pull request review commentgoogle-research/dex-lang

(review diff context omitted: a large refactor of the Dex raytracer example, introducing `ObjectGeom`/`Surface`/`Object` data types and rewriting `sampleReflection`, `sdObject`, `sdScene`, `calcNormal`, `raymarch`, `sampleLightRadiance`, `trace`, `cameraRays`, and `takePicture`)

This should be called "the David trick" because he always suggests it and it always works.

comment created time in 7 days

issue commentgoogle/jax

Slow Python tracing time with odeint inside tfp.optimizer.lbfgs_minimize

Copying from that comment on #3370 for convenience, I think we just need to memoize this line to avoid retracing.
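A hedged sketch of the general idea (a toy, not the actual linked line): `jit` caches traces by Python function identity, so memoizing the construction of the jitted function ensures repeated calls reuse one callable, and hence one trace-cache entry.

```python
import functools
import jax

@functools.lru_cache(maxsize=None)
def make_step(dt):
    # Without the lru_cache, every call would build a fresh `step`,
    # and jit would re-trace it each time.
    @jax.jit
    def step(y):
        return y + dt * (-y)  # toy Euler step for dy/dt = -y
    return step

# Same static argument -> same cached function object -> no re-tracing.
assert make_step(0.1) is make_step(0.1)
print(make_step(0.1)(1.0))  # 0.9
```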

comment created time in 8 days

issue commentgoogle/jax

Bad interaction between batch trace and custom jvp trace

Thanks for flagging this, and the great repros. I think this is a duplicate of a known bug, which I should look up and link here. The basic issue is that nondiff_argnums works basically like lexical closure over tracers, but that check is disallowing lexical closure over tracers!

@NeilGirdhar has a nice solution, and I believe a PR that I *still* haven't gotten to.

@brianwa84 how blocking is this for your work? I'd like to get to it this week, but I'm just trying to prioritize stuff. Let me know :)
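In the meantime, one hedged workaround sketch (a toy function, not the repro): instead of marking a possibly-traced argument with `nondiff_argnums`, pass it as an ordinary argument and simply ignore its tangent in the custom JVP rule, so it batches and closes over tracers without tripping that check.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x, scale):
    # `scale` is a normal argument, not a nondiff_argnums entry, so it can
    # carry tracers (e.g. under vmap) without issue.
    return scale * jnp.sin(x)

@f.defjvp
def f_jvp(primals, tangents):
    x, scale = primals
    dx, _ = tangents  # deliberately drop the tangent of `scale`
    return f(x, scale), scale * jnp.cos(x) * dx

# Works under vmap because `scale` is just a batchable argument:
print(jax.vmap(lambda s: jax.grad(f)(1.0, s))(jnp.ones(3)))
```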

comment created time in 8 days

issue commentgoogle/jax

Performance regression in compilation of JITed Jacobian

Thanks for letting us know about this! (By the way, unfortunately 64bit matmul on CPU can be slow to execute too: #3832.)

@hawkinsp any thoughts? Is this one for XLA:CPU folks?

comment created time in 8 days

pull request commentgoogle/jax

Cleanup: update license copyrights to 2020

I think copyrights are meant to be from when the file was written, and so they shouldn't be uniformly bumped to the current year.

comment created time in 9 days

push eventgoogle/jax

commit sha 4e873f417ab4f3e68a102548167f2b0a005edfad

Avoid re-flattening in jit() when no donate_argnums are present. (#3945) Following the same special-casing of static_argnums, this should provide a speedup especially when the number of arguments provided is large.

push time in 10 days

pull request commentgoogle/jax

Avoid re-flattening in jit() and pmap() when no donate_argnums are present.

Thank you!

comment created time in 10 days

PR merged google/jax

Following the same special-casing of static_argnums, this should provide a speedup especially when the number of arguments provided is large.

pr closed time in 10 days

pull request commentgoogle/jax

Avoid lexically capturing the train_images value in MNIST VAE example.

If I understand @froystig correctly, #3238 would have the nice effect of turning this closed-over constant into a buffer kept on device. So that's pretty cool.

Seems like we should add `device_put` to this PR and merge it, and also follow up on #3238 (which has bigger picture considerations, but also might let us simplify this example back down again if we want to).
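The suggestion looks roughly like this sketch (toy shapes standing in for the real MNIST arrays): `device_put` the dataset once, so the jitted function closes over a device buffer instead of re-transferring a large NumPy constant.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Stand-in for the MNIST train_images array in the example.
train_images = np.ones((1000, 784), np.float32)
# Transfer once; the jitted function below closes over the device buffer.
train_images = jax.device_put(train_images)

@jax.jit
def mean_pixel(scale):
    return scale * jnp.mean(train_images)

print(mean_pixel(2.0))  # 2.0
```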

comment created time in 10 days